# Copyright (c) 2023 Tildes contributors
# SPDX-License-Identifier: AGPL-3.0-or-later

"""Script to seed a dev database with fake data.

This file lives at tildes/scripts/seed_dev_database.py. It needs the `faker` dev
dependency, which is listed in requirements-dev.in and pinned as faker==18.11.2 in
requirements-dev.txt.

If you run this script as is, no data will be inserted. To get this script to do
something, you must modify at least one of the following variables and change the value
to a number greater than zero.
  - NUM_GROUPS
  - NUM_TOPICS (also NUM_TAGS for this one)
  - NUM_USERS

Note: This script is meant to be run on *dev databases only*. Running on production
will insert a ton of junk data.
"""

import csv
import os
import random
import sys
from contextlib import closing
from io import StringIO
from typing import Any, Sequence, Tuple

from faker import Faker
from sqlalchemy import func
from sqlalchemy.orm import Session

from tildes.lib.database import get_session_from_config
from tildes.lib.hash import hash_string
from tildes.models.group import Group
from tildes.models.topic import Topic
from tildes.models.user import User


class WeightedValues:
    """Randomly pick between values according to configured odds.

    Useful for working with semi-random data.
    """

    def __init__(self, default_value: Any, odds: list[Tuple[float, Any]]):
        """Create new weighted values.

        `odds` is a list of (ratio, value) pairs. Each ratio is the chance, between
        0 and 1, that `get()` returns the paired value. `default_value` is returned
        with whatever probability is left over.
        """
        self.default_value = default_value
        self.odds = odds

    def get(self) -> Any:
        """Get a random value from the odds list, or the default value.

        Values with a higher ratio are more likely to be returned. For example, if
        this class instance was constructed as

            WeightedValues("dunno", [(0.5, "hello"), (0.2, "goodbye")])

        then this function will have a 50% chance to return "hello", a 20% chance to
        return "goodbye", and a remaining 30% chance to return "dunno".
        """
        rand = random.random()
        # Walk the cumulative distribution until it passes the random draw.
        cumulative = 0.0
        for ratio, value in self.odds:
            cumulative += ratio
            if rand < cumulative:
                return value
        return self.default_value


# Number of each resource to insert during seeding.
# A value of `0` means no new records will be inserted.
NUM_GROUPS = 0
NUM_TOPICS = 0
NUM_TAGS = 0  # must be greater than 0 when creating topics
NUM_USERS = 0

TAG_NESTING_ODDS = WeightedValues(
    1,
    [
        (0.25, 2),  # "a.b"
        (0.05, 3),  # "a.b.c"
    ],
)
TOPIC_LINK_ODDS = WeightedValues(
    "TEXT",
    [
        (0.4, "LINK"),
    ],
)
TAGS_PER_TOPIC_ODDS = WeightedValues(
    1,
    [
        (0.2, 2),
        (0.4, 3),
        (0.1, 4),
        (0.1, 5),
        (0.1, 6),
    ],
)

fake = Faker()


def seed_dev_database() -> None:
    """Seed a dev database with one or more resources.

    Groups and users are seeded before topics, because every topic row needs an
    existing group and user to point at. Not all resource types are covered yet.
    """
    ini_file_path = get_ini_file_path()
    db_session = get_session_from_config(ini_file_path)

    # This function might call `sys.exit()` if the user does not okay the seeding.
    confirm_everything_looks_okay(db_session, ini_file_path)

    # NOTE: Creating records in bulk using the Tildes models is quite inefficient.
    # Instead, we'll insert using PostgreSQL's COPY and raw CSV strings.
    if NUM_GROUPS > 0:
        seed_groups(db_session)

    if NUM_USERS > 0:
        seed_users(db_session)

    if NUM_TOPICS > 0:
        seed_topics(db_session)

    db_session.commit()
    print("Database seeding completed")


def _copy_into_table(
    db_session: Session, table: str, columns: Sequence[str], data: StringIO
) -> None:
    """Bulk-load CSV rows from `data` into `table` using PostgreSQL's COPY.

    `data` must hold rows produced by `csv.writer` with its default dialect, with
    one field per entry in `columns`. The rows are loaded with `FORMAT csv` so that
    the writer's quoting is understood and empty fields (how `csv.writer` emits
    `None`) are stored as NULL rather than as empty strings.

    `table` and `columns` are interpolated into the SQL, so they must only ever be
    trusted, hard-coded identifiers.
    """
    column_list = ", ".join(f'"{column}"' for column in columns)
    copy_sql = f'COPY "{table}" ({column_list}) FROM STDIN WITH (FORMAT csv)'

    data.seek(0)
    raw_conn = db_session.get_bind().raw_connection()
    try:
        with closing(raw_conn.cursor()) as cursor:
            # Set timeout to an arbitrarily high number since COPY can take a while.
            cursor.execute("SET statement_timeout = '999999s'")
            cursor.copy_expert(copy_sql, data)
        raw_conn.commit()
    finally:
        # Always hand the connection back, even if the COPY failed.
        raw_conn.close()


def seed_groups(db_session: Session) -> None:
    """Insert NUM_GROUPS dummy groups into the database."""
    print(f"Generating {NUM_GROUPS} groups")
    groups_output = StringIO()
    csv_writer = csv.writer(groups_output)
    print_progress(0, NUM_GROUPS)
    for i in range(NUM_GROUPS):
        csv_writer.writerow(
            [
                fake.unique.word(),
                "automatically generated by seed_dev_database.py",
            ]
        )
        print_progress(i + 1, NUM_GROUPS)

    print("Writing groups to database")
    _copy_into_table(
        db_session, "groups", ("path", "short_description"), groups_output
    )


def seed_users(db_session: Session) -> None:
    """Insert NUM_USERS dummy users into the database.

    Every generated user shares the password "password".
    """
    print(f"Generating {NUM_USERS} users")
    users_output = StringIO()
    csv_writer = csv.writer(users_output)
    print_progress(0, NUM_USERS)
    # Hash once and reuse it for every row instead of hashing per user.
    password_hash = hash_string("password")
    for i in range(NUM_USERS):
        csv_writer.writerow(
            [
                fake.unique.user_name(),
                password_hash,
            ]
        )
        print_progress(i + 1, NUM_USERS)

    print("Writing users to database")
    _copy_into_table(db_session, "users", ("username", "password_hash"), users_output)


def seed_topics(db_session: Session) -> None:
    """Insert NUM_TOPICS dummy topics into the database.

    Each topic is assigned to a random existing group and user, and is randomly
    either a text topic or a link topic (see TOPIC_LINK_ODDS).

    Raises RuntimeError if NUM_TAGS is not positive, or if the database has no
    groups or no users to attach topics to.
    """
    # Check the cheap setting first so a misconfiguration fails before any queries.
    if NUM_TAGS <= 0:
        raise RuntimeError("At least one tag is required for seeding topics")

    print(f"Generating {NUM_TOPICS} topics")

    # Only the ids are needed, so avoid loading full model objects.
    group_ids = [group_id for (group_id,) in db_session.query(Group.group_id)]
    if not group_ids:
        raise RuntimeError("At least one group is required when seeding topics")
    user_ids = [user_id for (user_id,) in db_session.query(User.user_id)]
    if not user_ids:
        raise RuntimeError("At least one user is required when seeding topics")

    # Tags are dot-separated paths, e.g. "a", "a.b" or "a.b.c".
    all_tags = [
        ".".join(fake.word() for _ in range(TAG_NESTING_ODDS.get()))
        for _ in range(NUM_TAGS)
    ]

    topics_output = StringIO()
    csv_writer = csv.writer(topics_output)
    print_progress(0, NUM_TOPICS)
    for i in range(NUM_TOPICS):
        num_tags = TAGS_PER_TOPIC_ODDS.get()
        topic_type = TOPIC_LINK_ODDS.get()
        is_text = topic_type == "TEXT"
        text_content = fake.sentence()
        # dict.fromkeys drops repeated picks (keeping order) so a topic never lists
        # the same tag twice.
        topic_tags = dict.fromkeys(random.choice(all_tags) for _ in range(num_tags))
        csv_writer.writerow(
            [
                random.choice(group_ids),
                random.choice(user_ids),
                topic_type,
                " ".join(fake.words(3)),
                text_content if is_text else None,
                # NOTE(review): the <p> wrapper is reconstructed from a garbled
                # original line - confirm it matches the intended rendered HTML.
                f"<p>{text_content}</p>" if is_text else None,
                fake.url() if topic_type == "LINK" else None,
                # PostgreSQL array literal, e.g. "{a,b.c}".
                "{" + ",".join(topic_tags) + "}",
            ]
        )
        print_progress(i + 1, NUM_TOPICS)

    print("Writing topics to database")
    _copy_into_table(
        db_session,
        "topics",
        (
            "group_id",
            "user_id",
            "topic_type",
            "title",
            "markdown",
            "rendered_html",
            "link",
            "tags",
        ),
        topics_output,
    )


def get_ini_file_path() -> str:
    """Get the ini file path from environment, or the default development.ini path.

    The INI_FILE environment variable wins when it is set. Otherwise the path is
    development.ini in the directory above this script.
    """
    if "INI_FILE" in os.environ:
        return os.environ["INI_FILE"]
    return os.path.abspath(
        os.path.join(os.path.dirname(__file__), "..", "development.ini")
    )


def confirm_everything_looks_okay(db_session: Session, ini_file_path: str) -> None:
    """Ask the user to confirm this database seeding.

    Prints the ini file in use, plus how many records will be inserted and how many
    already exist. This function calls `sys.exit` unless the answer starts with "y".
    """
    num_existing_groups = db_session.query(func.count(Group.group_id)).scalar()
    num_existing_topics = db_session.query(func.count(Topic.topic_id)).scalar()
    num_existing_users = db_session.query(func.count(User.user_id)).scalar()
    print("Confirm this database seeding.")
    print("")
    print(f" ini file: {ini_file_path}")
    print(
        f" groups: {NUM_GROUPS} groups will be inserted, "
        f"{num_existing_groups} already exist"
    )
    print(
        f" topics: {NUM_TOPICS} topics will be inserted with {NUM_TAGS} "
        f"possible tags, {num_existing_topics} already exist"
    )
    print(f" users: {NUM_USERS} will be inserted, {num_existing_users} already exist")
    print("")
    answer = input("Proceed with database seeding? [y/N] ")
    # Anything other than an explicit yes aborts, matching the "N" default.
    if not answer.strip().lower().startswith("y"):
        sys.exit()


def print_progress(num_finished: int, num_total: int) -> None:
    """Print the progress for some task in "completed/total" format.

    The carriage return makes each call overwrite the previous progress line. Once
    the task is finished, the line is ended and "Done" is printed.
    """
    print(f"Completed {num_finished}/{num_total}", end="\r")
    if num_finished >= num_total:
        print("")
        print("Done", flush=True)


if __name__ == "__main__":
    seed_dev_database()