From 1ce7bdea83a8a90f2f0f1205b1006887df9fb392 Mon Sep 17 00:00:00 2001 From: Daniel Woznicki Date: Mon, 3 Jul 2023 22:25:46 -0700 Subject: [PATCH 1/2] Added script to seed a development database with dummy data. Currently supports groups, topics, and users. --- tildes/requirements-dev.in | 1 + tildes/requirements-dev.txt | 1 + tildes/scripts/seed_dev_database.py | 266 ++++++++++++++++++++++++++++ 3 files changed, 268 insertions(+) create mode 100644 tildes/scripts/seed_dev_database.py diff --git a/tildes/requirements-dev.in b/tildes/requirements-dev.in index 09af633..49d0417 100644 --- a/tildes/requirements-dev.in +++ b/tildes/requirements-dev.in @@ -1,5 +1,6 @@ -r requirements.in black +faker freezegun html5validator mypy diff --git a/tildes/requirements-dev.txt b/tildes/requirements-dev.txt index b29fba1..4465b4c 100644 --- a/tildes/requirements-dev.txt +++ b/tildes/requirements-dev.txt @@ -15,6 +15,7 @@ click==8.0.1 cornice==5.2.0 decorator==5.0.9 dodgy==0.2.1 +faker==18.11.2 flake8==3.9.2 flake8-polyfill==1.0.2 freezegun==1.1.0 diff --git a/tildes/scripts/seed_dev_database.py b/tildes/scripts/seed_dev_database.py new file mode 100644 index 0000000..19cda4f --- /dev/null +++ b/tildes/scripts/seed_dev_database.py @@ -0,0 +1,266 @@ +# Copyright (c) 2023 Tildes contributors +# SPDX-License-Identifier: AGPL-3.0-or-later + +"""Script to seed a dev database with fake data. + +If you run this script as is, no data will be inserted. To get this script to do +something, you must modify at least one of the following variables and change the value +to a number greater than zero. + - NUM_GROUPS + - NUM_TOPICS (also NUM_TAGS for this one) + - NUM_USERS + +Note: This script is mean to be run on *dev databases only*. Running on production will +insert a ton of junk data. +""" + +import os +import sys +import random +from typing import Any, Tuple +from sqlalchemy import func +from sqlalchemy.orm import Session +from sqlalchemy.sql import text +from faker import Faker + +from tildes.lib.database import get_session_from_config +from tildes.lib.hash import hash_string +from tildes.lib.markdown import convert_markdown_to_safe_html +from tildes.models.group import Group +from tildes.models.topic import Topic +from tildes.models.user import User + + +class WeightedValues: + """Mapping of values by weight. Useful for working with semi-random data.""" + + def __init__(self, default_value: Any, odds: list[Tuple[float, Any]]): + """Create new weighted values.""" + self.default_value = default_value + self.odds = odds + + def get(self) -> Any: + """Get a random value from the odds dict, or the default value. + + Values with a higher ratio are more likely be to returned. For example, if this + class instance was constructed as + + WeightedValues("dunno", {0.5: "hello", 0.2: "goodbye"}) + + then this function will have a 50% chance to return "hello", a 20% chance to + return "goodbye", and a remaining 30% chance to return "dunno". + """ + rand = random.random() + num = 0.0 + for (ratio, value) in self.odds: + num += ratio + if rand < num: + return value + return self.default_value + + +# Number of each resource to insert during seeding. +# A value of `0` means no new records will be inserted. +NUM_GROUPS = 0 +NUM_TOPICS = 0 +NUM_TAGS = 0 # must be greater than 0 when creating topics +NUM_USERS = 0 + +TAG_NESTING_ODDS = WeightedValues( + 1, + [ + (0.25, 2), # "a.b" + (0.05, 3), # "a.b.c" + ], +) +TOPIC_LINK_ODDS = WeightedValues( + "TEXT", + [ + (0.4, "LINK"), + ], +) +TAGS_PER_TOPIC_ODDS = WeightedValues( + 1, + [ + (0.2, 2), + (0.4, 3), + (0.1, 4), + (0.1, 5), + (0.1, 6), + ], +) + +fake = Faker() + + +def seed_dev_database() -> None: + """Seed a dev database with one or more resource. + + Note that not all resources are represented yet because I'm lazy. + """ + ini_file_path = get_ini_file_path() + db_session = get_session_from_config(ini_file_path) + + # This function might call `sys.exit()` if the user does not okay the seeding. + confirm_everything_looks_okay(db_session, ini_file_path) + + # NOTE: Creating records in bulk using the Tildes models is quite inefficient. + # Instead, we'll insert using raw SQL and dictionaries. + if NUM_GROUPS > 0: + seed_groups(db_session) + + if NUM_USERS > 0: + seed_users(db_session) + + if NUM_TOPICS > 0: + seed_topics(db_session) + + db_session.commit() + print("Database seeding completed") + + +def seed_groups(db_session: Session) -> None: + """Insert dummy groups into database.""" + print(f"Generating {NUM_GROUPS} groups") + groups_to_insert = [] + print_progress(0, NUM_GROUPS) + for _ in range(NUM_GROUPS): + group = { + "path": fake.unique.word(), + "short_description": "automatically generated by seed_dev_database.py", + } + groups_to_insert.append(group) + print_progress(len(groups_to_insert), NUM_GROUPS) + print("Writing groups to database") + statement = text( + """ + INSERT INTO groups (path, short_description) + VALUES (:path, :short_description) + """ + ) + for group in groups_to_insert: + db_session.execute(statement, group) + db_session.flush() + + +def seed_users(db_session: Session) -> None: + """Insert dummy users into database.""" + print(f"Generating {NUM_USERS} users") + users_to_insert = [] + print_progress(0, NUM_USERS) + password_hash = hash_string("password") + for _ in range(NUM_USERS): + user = { + "username": fake.unique.user_name(), + "password_hash": password_hash, + } + users_to_insert.append(user) + print_progress(len(users_to_insert), NUM_USERS) + print("Writing users to database") + statement = text( + """ + INSERT INTO users (username, password_hash) + VALUES (:username, :password_hash) + """ + ) + for user in users_to_insert: + db_session.execute(statement, user) + db_session.flush() + + +def seed_topics(db_session: Session) -> None: + """Insert dummy topics into database.""" + print(f"Generating {NUM_TOPICS} topics") + topics_to_insert = [] + all_groups = db_session.query(Group).all() + if not all_groups: + raise Exception("At least one group is required when seeding topics") + all_users = db_session.query(User).all() + if not all_users: + raise Exception("At least one user is required when seeding topics") + all_tags = [] + if NUM_TAGS > 0: + for _ in range(NUM_TAGS): + nesting_level = TAG_NESTING_ODDS.get() + tag = ".".join([fake.word() for _ in range(nesting_level)]) + all_tags.append(tag) + if not all_tags: + raise Exception("At least one tag is required for seeding topics") + print_progress(0, NUM_TOPICS) + for _ in range(NUM_TOPICS): + num_tags = TAGS_PER_TOPIC_ODDS.get() + topic_type = TOPIC_LINK_ODDS.get() + markdown = fake.sentence() + topic = { + "group_id": random.choice(all_groups).group_id, + "user_id": random.choice(all_users).user_id, + "topic_type": topic_type, + "title": " ".join(fake.words(3)), + "markdown": markdown if topic_type == "TEXT" else None, + "rendered_html": convert_markdown_to_safe_html(markdown) + if topic_type == "TEXT" + else None, + "link": fake.url() if topic_type == "LINK" else None, + "tags": [random.choice(all_tags) for _ in range(num_tags)], + } + topics_to_insert.append(topic) + print_progress(len(topics_to_insert), NUM_TOPICS) + print("Writing topics to database") + statement = text( + r""" + INSERT INTO topics (group_id, user_id, topic_type, title, markdown, + rendered_html, link, tags) + VALUES (:group_id, :user_id, :topic_type, :title, :markdown, :rendered_html, + :link, :tags\:\:ltree[]) + """ + ) + for topic in topics_to_insert: + db_session.execute(statement, topic) + db_session.flush() + + +def get_ini_file_path() -> str: + """Get the ini file path from environment, or the default development.ini path.""" + if "INI_FILE" in os.environ: + return os.environ["INI_FILE"] + return os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "development.ini") + ) + + +def confirm_everything_looks_okay(db_session: Session, ini_file_path: str) -> None: + """Ask the user to confirm this database seeding. + + This function calls `sys.exit` if they bail. + """ + num_existing_groups = db_session.query(func.count(Group.group_id)).scalar() + num_existing_topics = db_session.query(func.count(Topic.topic_id)).scalar() + num_existing_users = db_session.query(func.count(User.user_id)).scalar() + print("Confirm this database seeding.") + print("") + print(f" ini file: {ini_file_path}") + print( + f" groups: {NUM_GROUPS} groups will be inserted, " + f"{num_existing_groups} already exist" + ) + print( + f" topics: {NUM_TOPICS} topic swill be inserted with {NUM_TAGS} " + f"possible tags, {num_existing_topics} already exist" + ) + print(f" users: {NUM_USERS} will be inserted, {num_existing_users} already exist") + print("") + answer = input("Proceed with database seeding? [y/N] ") + if not answer.lower().startswith("y"): + sys.exit() + + +def print_progress(num_finished: int, num_total: int) -> None: + """Print the progress for some task in "completed/total" format.""" + print(f"Completed {num_finished}/{num_total}", end="\r") + if num_finished >= num_total: + print("") + print("Done", flush=True) + + +if __name__ == "__main__": + seed_dev_database() From 18570a4ac3f4f5435337401a5e4c58ed73e327fd Mon Sep 17 00:00:00 2001 From: Daniel Woznicki Date: Tue, 4 Jul 2023 13:55:06 -0700 Subject: [PATCH 2/2] Modified seeding script to use PostgreSQL COPY function instead of INSERTs. This *greatly* speeds up the seeding process. Removed call to `convert_markdown_to_safe_html` since it was slow. --- tildes/scripts/seed_dev_database.py | 147 ++++++++++++++++------------ 1 file changed, 85 insertions(+), 62 deletions(-) diff --git a/tildes/scripts/seed_dev_database.py b/tildes/scripts/seed_dev_database.py index 19cda4f..16b5049 100644 --- a/tildes/scripts/seed_dev_database.py +++ b/tildes/scripts/seed_dev_database.py @@ -17,15 +17,15 @@ insert a ton of junk data. import os import sys import random +from io import StringIO +import csv from typing import Any, Tuple from sqlalchemy import func from sqlalchemy.orm import Session -from sqlalchemy.sql import text from faker import Faker from tildes.lib.database import get_session_from_config from tildes.lib.hash import hash_string -from tildes.lib.markdown import convert_markdown_to_safe_html from tildes.models.group import Group from tildes.models.topic import Topic from tildes.models.user import User @@ -105,7 +105,7 @@ def seed_dev_database() -> None: confirm_everything_looks_okay(db_session, ini_file_path) # NOTE: Creating records in bulk using the Tildes models is quite inefficient. - # Instead, we'll insert using raw SQL and dictionaries. + # Instead, we'll insert using PostgreSQL's COPY and raw CSV strings. if NUM_GROUPS > 0: seed_groups(db_session) @@ -122,56 +122,71 @@ def seed_dev_database() -> None: def seed_groups(db_session: Session) -> None: """Insert dummy groups into database.""" print(f"Generating {NUM_GROUPS} groups") - groups_to_insert = [] + groups_output = StringIO() + csv_writer = csv.writer(groups_output, delimiter="\t") print_progress(0, NUM_GROUPS) - for _ in range(NUM_GROUPS): - group = { - "path": fake.unique.word(), - "short_description": "automatically generated by seed_dev_database.py", - } - groups_to_insert.append(group) - print_progress(len(groups_to_insert), NUM_GROUPS) + for i in range(NUM_GROUPS): + csv_writer.writerow( + [ + fake.unique.word(), + "automatically generated by seed_dev_database.py", + ] + ) + print_progress(i + 1, NUM_GROUPS) print("Writing groups to database") - statement = text( - """ - INSERT INTO groups (path, short_description) - VALUES (:path, :short_description) - """ + raw_conn = db_session.get_bind().raw_connection() + cursor = raw_conn.cursor() + groups_output.seek(0) + # Set timeout to an arbitrarily high number since COPY can take a while. + cursor.execute("SET statement_timeout = '999999s'") + cursor.copy_from( + groups_output, + "groups", + columns=( + "path", + "short_description", + ), ) - for group in groups_to_insert: - db_session.execute(statement, group) - db_session.flush() + raw_conn.commit() def seed_users(db_session: Session) -> None: """Insert dummy users into database.""" print(f"Generating {NUM_USERS} users") - users_to_insert = [] + users_output = StringIO() + csv_writer = csv.writer(users_output, delimiter="\t") print_progress(0, NUM_USERS) password_hash = hash_string("password") - for _ in range(NUM_USERS): - user = { - "username": fake.unique.user_name(), - "password_hash": password_hash, - } - users_to_insert.append(user) - print_progress(len(users_to_insert), NUM_USERS) + for i in range(NUM_USERS): + csv_writer.writerow( + [ + fake.unique.user_name(), + password_hash, + ] + ) + print_progress(i + 1, NUM_USERS) print("Writing users to database") - statement = text( - """ - INSERT INTO users (username, password_hash) - VALUES (:username, :password_hash) - """ + raw_conn = db_session.get_bind().raw_connection() + cursor = raw_conn.cursor() + users_output.seek(0) + # Set timeout to an arbitrarily high number since COPY can take a while. + cursor.execute("SET statement_timeout = '999999s'") + cursor.copy_from( + users_output, + "users", + columns=( + "username", + "password_hash", + ), ) - for user in users_to_insert: - db_session.execute(statement, user) - db_session.flush() + raw_conn.commit() def seed_topics(db_session: Session) -> None: """Insert dummy topics into database.""" print(f"Generating {NUM_TOPICS} topics") - topics_to_insert = [] + topics_output = StringIO() + csv_writer = csv.writer(topics_output, delimiter="\t") all_groups = db_session.query(Group).all() if not all_groups: raise Exception("At least one group is required when seeding topics") @@ -187,36 +202,44 @@ def seed_topics(db_session: Session) -> None: if not all_tags: raise Exception("At least one tag is required for seeding topics") print_progress(0, NUM_TOPICS) - for _ in range(NUM_TOPICS): + for i in range(NUM_TOPICS): num_tags = TAGS_PER_TOPIC_ODDS.get() topic_type = TOPIC_LINK_ODDS.get() - markdown = fake.sentence() - topic = { - "group_id": random.choice(all_groups).group_id, - "user_id": random.choice(all_users).user_id, - "topic_type": topic_type, - "title": " ".join(fake.words(3)), - "markdown": markdown if topic_type == "TEXT" else None, - "rendered_html": convert_markdown_to_safe_html(markdown) - if topic_type == "TEXT" - else None, - "link": fake.url() if topic_type == "LINK" else None, - "tags": [random.choice(all_tags) for _ in range(num_tags)], - } - topics_to_insert.append(topic) - print_progress(len(topics_to_insert), NUM_TOPICS) + text_content = fake.sentence() + csv_writer.writerow( + [ + random.choice(all_groups).group_id, + random.choice(all_users).user_id, + topic_type, + " ".join(fake.words(3)), + text_content if topic_type == "TEXT" else None, + f"

{text_content}

" if topic_type == "TEXT" else None, + fake.url() if topic_type == "LINK" else None, + f"{{{','.join([random.choice(all_tags) for _ in range(num_tags)])}}}", + ] + ) + print_progress(i + 1, NUM_TOPICS) print("Writing topics to database") - statement = text( - r""" - INSERT INTO topics (group_id, user_id, topic_type, title, markdown, - rendered_html, link, tags) - VALUES (:group_id, :user_id, :topic_type, :title, :markdown, :rendered_html, - :link, :tags\:\:ltree[]) - """ + raw_conn = db_session.get_bind().raw_connection() + cursor = raw_conn.cursor() + topics_output.seek(0) + # Set timeout to an arbitrarily high number since COPY can take a while. + cursor.execute("SET statement_timeout = '999999s'") + cursor.copy_from( + topics_output, + "topics", + columns=( + "group_id", + "user_id", + "topic_type", + "title", + "markdown", + "rendered_html", + "link", + "tags", + ), ) - for topic in topics_to_insert: - db_session.execute(statement, topic) - db_session.flush() + raw_conn.commit() def get_ini_file_path() -> str: