mirror of https://gitlab.com/tildes/tildes.git
Browse Source
Merge branch 'database-seeding' into 'master'
Merge branch 'database-seeding' into 'master'
Add development database seeding script See merge request tildes/tildes!147merge-requests/147/merge
3 changed files with 291 additions and 0 deletions
@ -0,0 +1,289 @@ |
|||||
|
# Copyright (c) 2023 Tildes contributors <code@tildes.net> |
||||
|
# SPDX-License-Identifier: AGPL-3.0-or-later |
||||
|
|
||||
|
"""Script to seed a dev database with fake data. |
||||
|
|
||||
|
If you run this script as is, no data will be inserted. To get this script to do |
||||
|
something, you must modify at least one of the following variables and change the value |
||||
|
to a number greater than zero. |
||||
|
- NUM_GROUPS |
||||
|
- NUM_TOPICS (also NUM_TAGS for this one) |
||||
|
- NUM_USERS |
||||
|
|
||||
|
Note: This script is mean to be run on *dev databases only*. Running on production will |
||||
|
insert a ton of junk data. |
||||
|
""" |
||||
|
|
||||
|
import os |
||||
|
import sys |
||||
|
import random |
||||
|
from io import StringIO |
||||
|
import csv |
||||
|
from typing import Any, Tuple |
||||
|
from sqlalchemy import func |
||||
|
from sqlalchemy.orm import Session |
||||
|
from faker import Faker |
||||
|
|
||||
|
from tildes.lib.database import get_session_from_config |
||||
|
from tildes.lib.hash import hash_string |
||||
|
from tildes.models.group import Group |
||||
|
from tildes.models.topic import Topic |
||||
|
from tildes.models.user import User |
||||
|
|
||||
|
|
||||
|
class WeightedValues: |
||||
|
"""Mapping of values by weight. Useful for working with semi-random data.""" |
||||
|
|
||||
|
def __init__(self, default_value: Any, odds: list[Tuple[float, Any]]): |
||||
|
"""Create new weighted values.""" |
||||
|
self.default_value = default_value |
||||
|
self.odds = odds |
||||
|
|
||||
|
def get(self) -> Any: |
||||
|
"""Get a random value from the odds dict, or the default value. |
||||
|
|
||||
|
Values with a higher ratio are more likely be to returned. For example, if this |
||||
|
class instance was constructed as |
||||
|
|
||||
|
WeightedValues("dunno", {0.5: "hello", 0.2: "goodbye"}) |
||||
|
|
||||
|
then this function will have a 50% chance to return "hello", a 20% chance to |
||||
|
return "goodbye", and a remaining 30% chance to return "dunno". |
||||
|
""" |
||||
|
rand = random.random() |
||||
|
num = 0.0 |
||||
|
for (ratio, value) in self.odds: |
||||
|
num += ratio |
||||
|
if rand < num: |
||||
|
return value |
||||
|
return self.default_value |
||||
|
|
||||
|
|
||||
|
# Number of each resource to insert during seeding. |
||||
|
# A value of `0` means no new records will be inserted. |
||||
|
NUM_GROUPS = 0 |
||||
|
NUM_TOPICS = 0 |
||||
|
NUM_TAGS = 0 # must be greater than 0 when creating topics |
||||
|
NUM_USERS = 0 |
||||
|
|
||||
|
TAG_NESTING_ODDS = WeightedValues( |
||||
|
1, |
||||
|
[ |
||||
|
(0.25, 2), # "a.b" |
||||
|
(0.05, 3), # "a.b.c" |
||||
|
], |
||||
|
) |
||||
|
TOPIC_LINK_ODDS = WeightedValues( |
||||
|
"TEXT", |
||||
|
[ |
||||
|
(0.4, "LINK"), |
||||
|
], |
||||
|
) |
||||
|
TAGS_PER_TOPIC_ODDS = WeightedValues( |
||||
|
1, |
||||
|
[ |
||||
|
(0.2, 2), |
||||
|
(0.4, 3), |
||||
|
(0.1, 4), |
||||
|
(0.1, 5), |
||||
|
(0.1, 6), |
||||
|
], |
||||
|
) |
||||
|
|
||||
|
fake = Faker() |
||||
|
|
||||
|
|
||||
|
def seed_dev_database() -> None: |
||||
|
"""Seed a dev database with one or more resource. |
||||
|
|
||||
|
Note that not all resources are represented yet because I'm lazy. |
||||
|
""" |
||||
|
ini_file_path = get_ini_file_path() |
||||
|
db_session = get_session_from_config(ini_file_path) |
||||
|
|
||||
|
# This function might call `sys.exit()` if the user does not okay the seeding. |
||||
|
confirm_everything_looks_okay(db_session, ini_file_path) |
||||
|
|
||||
|
# NOTE: Creating records in bulk using the Tildes models is quite inefficient. |
||||
|
# Instead, we'll insert using PostgreSQL's COPY and raw CSV strings. |
||||
|
if NUM_GROUPS > 0: |
||||
|
seed_groups(db_session) |
||||
|
|
||||
|
if NUM_USERS > 0: |
||||
|
seed_users(db_session) |
||||
|
|
||||
|
if NUM_TOPICS > 0: |
||||
|
seed_topics(db_session) |
||||
|
|
||||
|
db_session.commit() |
||||
|
print("Database seeding completed") |
||||
|
|
||||
|
|
||||
|
def seed_groups(db_session: Session) -> None: |
||||
|
"""Insert dummy groups into database.""" |
||||
|
print(f"Generating {NUM_GROUPS} groups") |
||||
|
groups_output = StringIO() |
||||
|
csv_writer = csv.writer(groups_output, delimiter="\t") |
||||
|
print_progress(0, NUM_GROUPS) |
||||
|
for i in range(NUM_GROUPS): |
||||
|
csv_writer.writerow( |
||||
|
[ |
||||
|
fake.unique.word(), |
||||
|
"automatically generated by seed_dev_database.py", |
||||
|
] |
||||
|
) |
||||
|
print_progress(i + 1, NUM_GROUPS) |
||||
|
print("Writing groups to database") |
||||
|
raw_conn = db_session.get_bind().raw_connection() |
||||
|
cursor = raw_conn.cursor() |
||||
|
groups_output.seek(0) |
||||
|
# Set timeout to an arbitrarily high number since COPY can take a while. |
||||
|
cursor.execute("SET statement_timeout = '999999s'") |
||||
|
cursor.copy_from( |
||||
|
groups_output, |
||||
|
"groups", |
||||
|
columns=( |
||||
|
"path", |
||||
|
"short_description", |
||||
|
), |
||||
|
) |
||||
|
raw_conn.commit() |
||||
|
|
||||
|
|
||||
|
def seed_users(db_session: Session) -> None: |
||||
|
"""Insert dummy users into database.""" |
||||
|
print(f"Generating {NUM_USERS} users") |
||||
|
users_output = StringIO() |
||||
|
csv_writer = csv.writer(users_output, delimiter="\t") |
||||
|
print_progress(0, NUM_USERS) |
||||
|
password_hash = hash_string("password") |
||||
|
for i in range(NUM_USERS): |
||||
|
csv_writer.writerow( |
||||
|
[ |
||||
|
fake.unique.user_name(), |
||||
|
password_hash, |
||||
|
] |
||||
|
) |
||||
|
print_progress(i + 1, NUM_USERS) |
||||
|
print("Writing users to database") |
||||
|
raw_conn = db_session.get_bind().raw_connection() |
||||
|
cursor = raw_conn.cursor() |
||||
|
users_output.seek(0) |
||||
|
# Set timeout to an arbitrarily high number since COPY can take a while. |
||||
|
cursor.execute("SET statement_timeout = '999999s'") |
||||
|
cursor.copy_from( |
||||
|
users_output, |
||||
|
"users", |
||||
|
columns=( |
||||
|
"username", |
||||
|
"password_hash", |
||||
|
), |
||||
|
) |
||||
|
raw_conn.commit() |
||||
|
|
||||
|
|
||||
|
def seed_topics(db_session: Session) -> None: |
||||
|
"""Insert dummy topics into database.""" |
||||
|
print(f"Generating {NUM_TOPICS} topics") |
||||
|
topics_output = StringIO() |
||||
|
csv_writer = csv.writer(topics_output, delimiter="\t") |
||||
|
all_groups = db_session.query(Group).all() |
||||
|
if not all_groups: |
||||
|
raise Exception("At least one group is required when seeding topics") |
||||
|
all_users = db_session.query(User).all() |
||||
|
if not all_users: |
||||
|
raise Exception("At least one user is required when seeding topics") |
||||
|
all_tags = [] |
||||
|
if NUM_TAGS > 0: |
||||
|
for _ in range(NUM_TAGS): |
||||
|
nesting_level = TAG_NESTING_ODDS.get() |
||||
|
tag = ".".join([fake.word() for _ in range(nesting_level)]) |
||||
|
all_tags.append(tag) |
||||
|
if not all_tags: |
||||
|
raise Exception("At least one tag is required for seeding topics") |
||||
|
print_progress(0, NUM_TOPICS) |
||||
|
for i in range(NUM_TOPICS): |
||||
|
num_tags = TAGS_PER_TOPIC_ODDS.get() |
||||
|
topic_type = TOPIC_LINK_ODDS.get() |
||||
|
text_content = fake.sentence() |
||||
|
csv_writer.writerow( |
||||
|
[ |
||||
|
random.choice(all_groups).group_id, |
||||
|
random.choice(all_users).user_id, |
||||
|
topic_type, |
||||
|
" ".join(fake.words(3)), |
||||
|
text_content if topic_type == "TEXT" else None, |
||||
|
f"<p>{text_content}</p>" if topic_type == "TEXT" else None, |
||||
|
fake.url() if topic_type == "LINK" else None, |
||||
|
f"{{{','.join([random.choice(all_tags) for _ in range(num_tags)])}}}", |
||||
|
] |
||||
|
) |
||||
|
print_progress(i + 1, NUM_TOPICS) |
||||
|
print("Writing topics to database") |
||||
|
raw_conn = db_session.get_bind().raw_connection() |
||||
|
cursor = raw_conn.cursor() |
||||
|
topics_output.seek(0) |
||||
|
# Set timeout to an arbitrarily high number since COPY can take a while. |
||||
|
cursor.execute("SET statement_timeout = '999999s'") |
||||
|
cursor.copy_from( |
||||
|
topics_output, |
||||
|
"topics", |
||||
|
columns=( |
||||
|
"group_id", |
||||
|
"user_id", |
||||
|
"topic_type", |
||||
|
"title", |
||||
|
"markdown", |
||||
|
"rendered_html", |
||||
|
"link", |
||||
|
"tags", |
||||
|
), |
||||
|
) |
||||
|
raw_conn.commit() |
||||
|
|
||||
|
|
||||
|
def get_ini_file_path() -> str: |
||||
|
"""Get the ini file path from environment, or the default development.ini path.""" |
||||
|
if "INI_FILE" in os.environ: |
||||
|
return os.environ["INI_FILE"] |
||||
|
return os.path.abspath( |
||||
|
os.path.join(os.path.dirname(__file__), "..", "development.ini") |
||||
|
) |
||||
|
|
||||
|
|
||||
|
def confirm_everything_looks_okay(db_session: Session, ini_file_path: str) -> None: |
||||
|
"""Ask the user to confirm this database seeding. |
||||
|
|
||||
|
This function calls `sys.exit` if they bail. |
||||
|
""" |
||||
|
num_existing_groups = db_session.query(func.count(Group.group_id)).scalar() |
||||
|
num_existing_topics = db_session.query(func.count(Topic.topic_id)).scalar() |
||||
|
num_existing_users = db_session.query(func.count(User.user_id)).scalar() |
||||
|
print("Confirm this database seeding.") |
||||
|
print("") |
||||
|
print(f" ini file: {ini_file_path}") |
||||
|
print( |
||||
|
f" groups: {NUM_GROUPS} groups will be inserted, " |
||||
|
f"{num_existing_groups} already exist" |
||||
|
) |
||||
|
print( |
||||
|
f" topics: {NUM_TOPICS} topic swill be inserted with {NUM_TAGS} " |
||||
|
f"possible tags, {num_existing_topics} already exist" |
||||
|
) |
||||
|
print(f" users: {NUM_USERS} will be inserted, {num_existing_users} already exist") |
||||
|
print("") |
||||
|
answer = input("Proceed with database seeding? [y/N] ") |
||||
|
if not answer.lower().startswith("y"): |
||||
|
sys.exit() |
||||
|
|
||||
|
|
||||
|
def print_progress(num_finished: int, num_total: int) -> None: |
||||
|
"""Print the progress for some task in "completed/total" format.""" |
||||
|
print(f"Completed {num_finished}/{num_total}", end="\r") |
||||
|
if num_finished >= num_total: |
||||
|
print("") |
||||
|
print("Done", flush=True) |
||||
|
|
||||
|
|
||||
|
if __name__ == "__main__": |
||||
|
seed_dev_database() |
Write
Preview
Loading…
Cancel
Save
Reference in new issue