Browse Source

Modified seeding script to use PostgreSQL COPY function instead of INSERTs. This *greatly* speeds up the seeding process.

Removed call to `convert_markdown_to_safe_html` since it was slow.
merge-requests/147/head
Daniel Woznicki 2 years ago
parent
commit
18570a4ac3
  1. 147
      tildes/scripts/seed_dev_database.py

147
tildes/scripts/seed_dev_database.py

@@ -17,15 +17,15 @@ insert a ton of junk data.
import os
import sys
import random
from io import StringIO
import csv
from typing import Any, Tuple
from sqlalchemy import func
from sqlalchemy.orm import Session
from sqlalchemy.sql import text
from faker import Faker
from tildes.lib.database import get_session_from_config
from tildes.lib.hash import hash_string
from tildes.lib.markdown import convert_markdown_to_safe_html
from tildes.models.group import Group
from tildes.models.topic import Topic
from tildes.models.user import User
@@ -105,7 +105,7 @@ def seed_dev_database() -> None:
confirm_everything_looks_okay(db_session, ini_file_path)
# NOTE: Creating records in bulk using the Tildes models is quite inefficient.
# Instead, we'll insert using raw SQL and dictionaries.
# Instead, we'll insert using PostgreSQL's COPY and raw CSV strings.
if NUM_GROUPS > 0:
seed_groups(db_session)
@@ -122,56 +122,71 @@ def seed_dev_database() -> None:
def seed_groups(db_session: Session) -> None:
    """Insert dummy groups into database.

    Rows are accumulated as tab-delimited CSV in an in-memory buffer and
    loaded in one shot with PostgreSQL's COPY (via the DBAPI cursor's
    copy_from), which is much faster than issuing one INSERT per row.
    """
    print(f"Generating {NUM_GROUPS} groups")
    groups_output = StringIO()
    csv_writer = csv.writer(groups_output, delimiter="\t")

    print_progress(0, NUM_GROUPS)
    for i in range(NUM_GROUPS):
        csv_writer.writerow(
            [
                # fake.unique guarantees no duplicate group paths in this run
                fake.unique.word(),
                "automatically generated by seed_dev_database.py",
            ]
        )
        print_progress(i + 1, NUM_GROUPS)

    print("Writing groups to database")
    # COPY isn't exposed through the SQLAlchemy Session, so drop down to the
    # raw DBAPI connection (assumes the psycopg2 driver, which provides
    # cursor.copy_from — TODO confirm against the project's DB config).
    raw_conn = db_session.get_bind().raw_connection()
    cursor = raw_conn.cursor()
    try:
        groups_output.seek(0)
        # Set timeout to an arbitrarily high number since COPY can take a while.
        cursor.execute("SET statement_timeout = '999999s'")
        cursor.copy_from(
            groups_output,
            "groups",
            columns=(
                "path",
                "short_description",
            ),
        )
        raw_conn.commit()
    finally:
        # Always release the cursor, even if COPY fails partway through.
        cursor.close()
def seed_users(db_session: Session) -> None:
    """Insert dummy users into database.

    Rows are accumulated as tab-delimited CSV in an in-memory buffer and
    loaded in one shot with PostgreSQL's COPY (via the DBAPI cursor's
    copy_from), which is much faster than issuing one INSERT per row.
    """
    print(f"Generating {NUM_USERS} users")
    users_output = StringIO()
    csv_writer = csv.writer(users_output, delimiter="\t")

    print_progress(0, NUM_USERS)
    # Hash once outside the loop — password hashing is deliberately slow,
    # and every dev user shares the same password anyway.
    password_hash = hash_string("password")
    for i in range(NUM_USERS):
        csv_writer.writerow(
            [
                # fake.unique guarantees no duplicate usernames in this run
                fake.unique.user_name(),
                password_hash,
            ]
        )
        print_progress(i + 1, NUM_USERS)

    print("Writing users to database")
    # COPY isn't exposed through the SQLAlchemy Session, so drop down to the
    # raw DBAPI connection (assumes the psycopg2 driver, which provides
    # cursor.copy_from — TODO confirm against the project's DB config).
    raw_conn = db_session.get_bind().raw_connection()
    cursor = raw_conn.cursor()
    try:
        users_output.seek(0)
        # Set timeout to an arbitrarily high number since COPY can take a while.
        cursor.execute("SET statement_timeout = '999999s'")
        cursor.copy_from(
            users_output,
            "users",
            columns=(
                "username",
                "password_hash",
            ),
        )
        raw_conn.commit()
    finally:
        # Always release the cursor, even if COPY fails partway through.
        cursor.close()
def seed_topics(db_session: Session) -> None:
"""Insert dummy topics into database."""
print(f"Generating {NUM_TOPICS} topics")
topics_to_insert = []
topics_output = StringIO()
csv_writer = csv.writer(topics_output, delimiter="\t")
all_groups = db_session.query(Group).all()
if not all_groups:
raise Exception("At least one group is required when seeding topics")
@@ -187,36 +202,44 @@ def seed_topics(db_session: Session) -> None:
if not all_tags:
raise Exception("At least one tag is required for seeding topics")
print_progress(0, NUM_TOPICS)
for _ in range(NUM_TOPICS):
for i in range(NUM_TOPICS):
num_tags = TAGS_PER_TOPIC_ODDS.get()
topic_type = TOPIC_LINK_ODDS.get()
markdown = fake.sentence()
topic = {
"group_id": random.choice(all_groups).group_id,
"user_id": random.choice(all_users).user_id,
"topic_type": topic_type,
"title": " ".join(fake.words(3)),
"markdown": markdown if topic_type == "TEXT" else None,
"rendered_html": convert_markdown_to_safe_html(markdown)
if topic_type == "TEXT"
else None,
"link": fake.url() if topic_type == "LINK" else None,
"tags": [random.choice(all_tags) for _ in range(num_tags)],
}
topics_to_insert.append(topic)
print_progress(len(topics_to_insert), NUM_TOPICS)
text_content = fake.sentence()
csv_writer.writerow(
[
random.choice(all_groups).group_id,
random.choice(all_users).user_id,
topic_type,
" ".join(fake.words(3)),
text_content if topic_type == "TEXT" else None,
f"<p>{text_content}</p>" if topic_type == "TEXT" else None,
fake.url() if topic_type == "LINK" else None,
f"{{{','.join([random.choice(all_tags) for _ in range(num_tags)])}}}",
]
)
print_progress(i + 1, NUM_TOPICS)
print("Writing topics to database")
statement = text(
r"""
INSERT INTO topics (group_id, user_id, topic_type, title, markdown,
rendered_html, link, tags)
VALUES (:group_id, :user_id, :topic_type, :title, :markdown, :rendered_html,
:link, :tags\:\:ltree[])
"""
raw_conn = db_session.get_bind().raw_connection()
cursor = raw_conn.cursor()
topics_output.seek(0)
# Set timeout to an arbitrarily high number since COPY can take a while.
cursor.execute("SET statement_timeout = '999999s'")
cursor.copy_from(
topics_output,
"topics",
columns=(
"group_id",
"user_id",
"topic_type",
"title",
"markdown",
"rendered_html",
"link",
"tags",
),
)
for topic in topics_to_insert:
db_session.execute(statement, topic)
db_session.flush()
raw_conn.commit()
def get_ini_file_path() -> str:

Loading…
Cancel
Save