From 18570a4ac3f4f5435337401a5e4c58ed73e327fd Mon Sep 17 00:00:00 2001 From: Daniel Woznicki Date: Tue, 4 Jul 2023 13:55:06 -0700 Subject: [PATCH] Modified seeding script to use PostgreSQL COPY function instead of INSERTs. This *greatly* speeds up the seeding process. Removed call to `convert_markdown_to_safe_html` since it was slow. --- tildes/scripts/seed_dev_database.py | 147 ++++++++++++++++------------ 1 file changed, 85 insertions(+), 62 deletions(-) diff --git a/tildes/scripts/seed_dev_database.py b/tildes/scripts/seed_dev_database.py index 19cda4f..16b5049 100644 --- a/tildes/scripts/seed_dev_database.py +++ b/tildes/scripts/seed_dev_database.py @@ -17,15 +17,15 @@ insert a ton of junk data. import os import sys import random +from io import StringIO +import csv from typing import Any, Tuple from sqlalchemy import func from sqlalchemy.orm import Session -from sqlalchemy.sql import text from faker import Faker from tildes.lib.database import get_session_from_config from tildes.lib.hash import hash_string -from tildes.lib.markdown import convert_markdown_to_safe_html from tildes.models.group import Group from tildes.models.topic import Topic from tildes.models.user import User @@ -105,7 +105,7 @@ def seed_dev_database() -> None: confirm_everything_looks_okay(db_session, ini_file_path) # NOTE: Creating records in bulk using the Tildes models is quite inefficient. - # Instead, we'll insert using raw SQL and dictionaries. + # Instead, we'll insert using PostgreSQL's COPY and raw CSV strings. if NUM_GROUPS > 0: seed_groups(db_session) @@ -122,56 +122,71 @@ def seed_dev_database() -> None: def seed_groups(db_session: Session) -> None: """Insert dummy groups into database.""" print(f"Generating {NUM_GROUPS} groups") - groups_to_insert = [] + groups_output = StringIO() + csv_writer = csv.writer(groups_output, delimiter="\t") print_progress(0, NUM_GROUPS) - for _ in range(NUM_GROUPS): - group = { - "path": fake.unique.word(), - "short_description": "automatically generated by seed_dev_database.py", - } - groups_to_insert.append(group) - print_progress(len(groups_to_insert), NUM_GROUPS) + for i in range(NUM_GROUPS): + csv_writer.writerow( + [ + fake.unique.word(), + "automatically generated by seed_dev_database.py", + ] + ) + print_progress(i + 1, NUM_GROUPS) print("Writing groups to database") - statement = text( - """ - INSERT INTO groups (path, short_description) - VALUES (:path, :short_description) - """ + raw_conn = db_session.get_bind().raw_connection() + cursor = raw_conn.cursor() + groups_output.seek(0) + # Set timeout to an arbitrarily high number since COPY can take a while. + cursor.execute("SET statement_timeout = '999999s'") + cursor.copy_from( + groups_output, + "groups", + columns=( + "path", + "short_description", + ), ) - for group in groups_to_insert: - db_session.execute(statement, group) - db_session.flush() + raw_conn.commit() def seed_users(db_session: Session) -> None: """Insert dummy users into database.""" print(f"Generating {NUM_USERS} users") - users_to_insert = [] + users_output = StringIO() + csv_writer = csv.writer(users_output, delimiter="\t") print_progress(0, NUM_USERS) password_hash = hash_string("password") - for _ in range(NUM_USERS): - user = { - "username": fake.unique.user_name(), - "password_hash": password_hash, - } - users_to_insert.append(user) - print_progress(len(users_to_insert), NUM_USERS) + for i in range(NUM_USERS): + csv_writer.writerow( + [ + fake.unique.user_name(), + password_hash, + ] + ) + print_progress(i + 1, NUM_USERS) print("Writing users to database") - statement = text( - """ - INSERT INTO users (username, password_hash) - VALUES (:username, :password_hash) - """ + raw_conn = db_session.get_bind().raw_connection() + cursor = raw_conn.cursor() + users_output.seek(0) + # Set timeout to an arbitrarily high number since COPY can take a while. + cursor.execute("SET statement_timeout = '999999s'") + cursor.copy_from( + users_output, + "users", + columns=( + "username", + "password_hash", + ), ) - for user in users_to_insert: - db_session.execute(statement, user) - db_session.flush() + raw_conn.commit() def seed_topics(db_session: Session) -> None: """Insert dummy topics into database.""" print(f"Generating {NUM_TOPICS} topics") - topics_to_insert = [] + topics_output = StringIO() + csv_writer = csv.writer(topics_output, delimiter="\t") all_groups = db_session.query(Group).all() if not all_groups: raise Exception("At least one group is required when seeding topics") @@ -187,36 +202,44 @@ def seed_topics(db_session: Session) -> None: if not all_tags: raise Exception("At least one tag is required for seeding topics") print_progress(0, NUM_TOPICS) - for _ in range(NUM_TOPICS): + for i in range(NUM_TOPICS): num_tags = TAGS_PER_TOPIC_ODDS.get() topic_type = TOPIC_LINK_ODDS.get() - markdown = fake.sentence() - topic = { - "group_id": random.choice(all_groups).group_id, - "user_id": random.choice(all_users).user_id, - "topic_type": topic_type, - "title": " ".join(fake.words(3)), - "markdown": markdown if topic_type == "TEXT" else None, - "rendered_html": convert_markdown_to_safe_html(markdown) - if topic_type == "TEXT" - else None, - "link": fake.url() if topic_type == "LINK" else None, - "tags": [random.choice(all_tags) for _ in range(num_tags)], - } - topics_to_insert.append(topic) - print_progress(len(topics_to_insert), NUM_TOPICS) + text_content = fake.sentence() + csv_writer.writerow( + [ + random.choice(all_groups).group_id, + random.choice(all_users).user_id, + topic_type, + " ".join(fake.words(3)), + text_content if topic_type == "TEXT" else None, + f"

{text_content}

" if topic_type == "TEXT" else None, + fake.url() if topic_type == "LINK" else None, + f"{{{','.join([random.choice(all_tags) for _ in range(num_tags)])}}}", + ] + ) + print_progress(i + 1, NUM_TOPICS) print("Writing topics to database") - statement = text( - r""" - INSERT INTO topics (group_id, user_id, topic_type, title, markdown, - rendered_html, link, tags) - VALUES (:group_id, :user_id, :topic_type, :title, :markdown, :rendered_html, - :link, :tags\:\:ltree[]) - """ + raw_conn = db_session.get_bind().raw_connection() + cursor = raw_conn.cursor() + topics_output.seek(0) + # Set timeout to an arbitrarily high number since COPY can take a while. + cursor.execute("SET statement_timeout = '999999s'") + cursor.copy_from( + topics_output, + "topics", + columns=( + "group_id", + "user_id", + "topic_type", + "title", + "markdown", + "rendered_html", + "link", + "tags", + ), ) - for topic in topics_to_insert: - db_session.execute(statement, topic) - db_session.flush() + raw_conn.commit() def get_ini_file_path() -> str: