mirror of https://gitlab.com/tildes/tildes.git
Browse Source
Added script to seed a development database with dummy data. Currently supports groups, topics, and users.
merge-requests/147/head
Added script to seed a development database with dummy data. Currently supports groups, topics, and users.
merge-requests/147/head
3 changed files with 268 additions and 0 deletions
@ -0,0 +1,266 @@ |
|||
# Copyright (c) 2023 Tildes contributors <code@tildes.net> |
|||
# SPDX-License-Identifier: AGPL-3.0-or-later |
|||
|
|||
"""Script to seed a dev database with fake data. |
|||
|
|||
If you run this script as is, no data will be inserted. To get this script to do |
|||
something, you must modify at least one of the following variables and change the value |
|||
to a number greater than zero. |
|||
- NUM_GROUPS |
|||
- NUM_TOPICS (also NUM_TAGS for this one) |
|||
- NUM_USERS |
|||
|
|||
Note: This script is meant to be run on *dev databases only*. Running on production will |
|||
insert a ton of junk data. |
|||
""" |
|||
|
|||
import os |
|||
import sys |
|||
import random |
|||
from typing import Any, Tuple |
|||
from sqlalchemy import func |
|||
from sqlalchemy.orm import Session |
|||
from sqlalchemy.sql import text |
|||
from faker import Faker |
|||
|
|||
from tildes.lib.database import get_session_from_config |
|||
from tildes.lib.hash import hash_string |
|||
from tildes.lib.markdown import convert_markdown_to_safe_html |
|||
from tildes.models.group import Group |
|||
from tildes.models.topic import Topic |
|||
from tildes.models.user import User |
|||
|
|||
|
|||
class WeightedValues:
    """Mapping of values by weight. Useful for working with semi-random data."""

    def __init__(self, default_value: Any, odds: list[Tuple[float, Any]]):
        """Create new weighted values.

        `odds` is a list of (ratio, value) pairs, where each ratio is the chance
        (between 0.0 and 1.0) that `get` returns the paired value. The ratios are
        accumulated in list order, so pairs listed after the running total reaches
        1.0 can never be picked. `default_value` is returned whenever none of the
        pairs is picked.
        """
        self.default_value = default_value
        self.odds = odds

    def get(self) -> Any:
        """Get a random value from the odds list, or the default value.

        Values with a higher ratio are more likely to be returned. For example, if
        this class instance was constructed as

            WeightedValues("dunno", [(0.5, "hello"), (0.2, "goodbye")])

        then this function will have a 50% chance to return "hello", a 20% chance to
        return "goodbye", and a remaining 30% chance to return "dunno".
        """
        rand = random.random()
        # Walk the cumulative distribution: each value owns the slice of [0, 1) that
        # starts where the previous value's slice ended.
        threshold = 0.0
        for ratio, value in self.odds:
            threshold += ratio
            if rand < threshold:
                return value
        # The random number landed past every slice.
        return self.default_value
|||
|
|||
|
|||
# How many records of each resource a seeding run should create.
# Leaving a value at `0` means nothing is inserted for that resource.
NUM_GROUPS = 0
NUM_TOPICS = 0
NUM_TAGS = 0  # size of the tag pool; must be greater than 0 when creating topics
NUM_USERS = 0

# Number of dot-separated words in a generated tag (a single word by default).
TAG_NESTING_ODDS = WeightedValues(
    default_value=1,
    odds=[
        (0.25, 2),  # "a.b"
        (0.05, 3),  # "a.b.c"
    ],
)
# Type of a generated topic: text by default, a link 40% of the time.
TOPIC_LINK_ODDS = WeightedValues(default_value="TEXT", odds=[(0.4, "LINK")])
# Number of tags requested for a generated topic (one by default).
TAGS_PER_TOPIC_ODDS = WeightedValues(
    default_value=1,
    odds=[(0.2, 2), (0.4, 3), (0.1, 4), (0.1, 5), (0.1, 6)],
)

# Shared generator for the fake words, user names, sentences and URLs used below.
fake = Faker()
|||
|
|||
|
|||
def seed_dev_database() -> None:
    """Fill a dev database with whichever resources are enabled.

    The NUM_GROUPS, NUM_USERS and NUM_TOPICS constants decide what gets generated;
    a resource whose constant is 0 is skipped. Not every kind of resource is
    supported yet. Everything is committed in one go at the very end.
    """
    ini_file_path = get_ini_file_path()
    db_session = get_session_from_config(ini_file_path)

    # This may end the process through `sys.exit()` if the user declines.
    confirm_everything_looks_okay(db_session, ini_file_path)

    # NOTE: Going through the Tildes models is quite inefficient for bulk creation,
    # so every seeder inserts with raw SQL and plain dictionaries instead.
    # Topics come last because seeding them requires existing groups and users.
    seeding_steps = (
        (NUM_GROUPS, seed_groups),
        (NUM_USERS, seed_users),
        (NUM_TOPICS, seed_topics),
    )
    for requested, seeder in seeding_steps:
        if requested > 0:
            seeder(db_session)

    db_session.commit()
    print("Database seeding completed")
|||
|
|||
|
|||
def seed_groups(db_session: Session) -> None:
    """Generate NUM_GROUPS dummy groups and insert them into the database.

    Each group gets a unique fake word as its path. The rows are flushed to the
    database but not committed; committing is left to the caller.
    """
    print(f"Generating {NUM_GROUPS} groups")
    print_progress(0, NUM_GROUPS)
    pending_groups = []
    for created in range(1, NUM_GROUPS + 1):
        row = {
            "path": fake.unique.word(),
            "short_description": "automatically generated by seed_dev_database.py",
        }
        pending_groups.append(row)
        print_progress(created, NUM_GROUPS)

    print("Writing groups to database")
    insert_group = text(
        """
        INSERT INTO groups (path, short_description)
        VALUES (:path, :short_description)
        """
    )
    for row in pending_groups:
        db_session.execute(insert_group, row)
    db_session.flush()
|||
|
|||
|
|||
def seed_users(db_session: Session) -> None:
    """Generate NUM_USERS dummy users and insert them into the database.

    Each user gets a unique fake username. All of them share the password
    "password", which is hashed a single time and reused for every row. The rows
    are flushed to the database but not committed; committing is left to the
    caller.
    """
    print(f"Generating {NUM_USERS} users")
    print_progress(0, NUM_USERS)
    # Hash once up front instead of once per user.
    shared_password_hash = hash_string("password")
    pending_users = []
    for created in range(1, NUM_USERS + 1):
        row = {
            "username": fake.unique.user_name(),
            "password_hash": shared_password_hash,
        }
        pending_users.append(row)
        print_progress(created, NUM_USERS)

    print("Writing users to database")
    insert_user = text(
        """
        INSERT INTO users (username, password_hash)
        VALUES (:username, :password_hash)
        """
    )
    for row in pending_users:
        db_session.execute(insert_user, row)
    db_session.flush()
|||
|
|||
|
|||
def seed_topics(db_session: Session) -> None:
    """Insert dummy topics into database.

    Every topic is assigned to a random existing group and a random existing user,
    is randomly either a text topic (with fake markdown) or a link topic (with a
    fake URL), and receives a random selection of distinct tags drawn from a pool
    built from NUM_TAGS generated tags. The rows are flushed to the database but
    not committed; committing is left to the caller.

    Raises RuntimeError if the database holds no groups or no users, or if the tag
    pool is empty (NUM_TAGS is 0).
    """
    print(f"Generating {NUM_TOPICS} topics")

    all_groups = db_session.query(Group).all()
    if not all_groups:
        raise RuntimeError("At least one group is required when seeding topics")
    all_users = db_session.query(User).all()
    if not all_users:
        raise RuntimeError("At least one user is required when seeding topics")

    generated_tags = []
    for _ in range(NUM_TAGS):
        nesting_level = TAG_NESTING_ODDS.get()
        generated_tags.append(".".join(fake.word() for _ in range(nesting_level)))
    # fake.word() may repeat itself, so drop duplicates (keeping the original
    # order); otherwise a single topic could be given the same tag more than once.
    all_tags = list(dict.fromkeys(generated_tags))
    if not all_tags:
        raise RuntimeError("At least one tag is required for seeding topics")

    topics_to_insert = []
    print_progress(0, NUM_TOPICS)
    for _ in range(NUM_TOPICS):
        # sample() picks without repeats, so never ask for more tags than exist.
        num_tags = min(TAGS_PER_TOPIC_ODDS.get(), len(all_tags))
        topic_type = TOPIC_LINK_ODDS.get()
        # Only text topics carry markdown and its rendered HTML.
        if topic_type == "TEXT":
            markdown = fake.sentence()
            rendered_html = convert_markdown_to_safe_html(markdown)
        else:
            markdown = rendered_html = None
        topic = {
            "group_id": random.choice(all_groups).group_id,
            "user_id": random.choice(all_users).user_id,
            "topic_type": topic_type,
            "title": " ".join(fake.words(3)),
            "markdown": markdown,
            "rendered_html": rendered_html,
            "link": fake.url() if topic_type == "LINK" else None,
            "tags": random.sample(all_tags, num_tags),
        }
        topics_to_insert.append(topic)
        print_progress(len(topics_to_insert), NUM_TOPICS)

    print("Writing topics to database")
    # The colons of the `::ltree[]` cast are backslash-escaped so that text() does
    # not mistake them for the start of a bind parameter.
    statement = text(
        r"""
        INSERT INTO topics (group_id, user_id, topic_type, title, markdown,
            rendered_html, link, tags)
        VALUES (:group_id, :user_id, :topic_type, :title, :markdown, :rendered_html,
            :link, :tags\:\:ltree[])
        """
    )
    for topic in topics_to_insert:
        db_session.execute(statement, topic)
    db_session.flush()
|||
|
|||
|
|||
def get_ini_file_path() -> str:
    """Return the path of the ini file to load the database settings from.

    The INI_FILE environment variable wins whenever it is set (its value is
    returned untouched). Otherwise the result is the absolute path of
    `development.ini` in the parent of the directory containing this script.
    """
    try:
        return os.environ["INI_FILE"]
    except KeyError:
        script_dir = os.path.dirname(__file__)
        return os.path.abspath(os.path.join(script_dir, "..", "development.ini"))
|||
|
|||
|
|||
def confirm_everything_looks_okay(db_session: Session, ini_file_path: str) -> None:
    """Ask the user to confirm this database seeding.

    Prints the ini file in use together with how many groups, topics and users
    are about to be inserted and how many of each already exist, then prompts for
    confirmation on stdin.

    This function calls `sys.exit` if they bail, which is anything other than an
    answer starting with "y" or "Y" (surrounding whitespace is ignored).
    """
    num_existing_groups = db_session.query(func.count(Group.group_id)).scalar()
    num_existing_topics = db_session.query(func.count(Topic.topic_id)).scalar()
    num_existing_users = db_session.query(func.count(User.user_id)).scalar()

    print("Confirm this database seeding.")
    print("")
    print(f" ini file: {ini_file_path}")
    print(
        f" groups: {NUM_GROUPS} groups will be inserted, "
        f"{num_existing_groups} already exist"
    )
    print(
        f" topics: {NUM_TOPICS} topics will be inserted with {NUM_TAGS} "
        f"possible tags, {num_existing_topics} already exist"
    )
    print(f" users: {NUM_USERS} will be inserted, {num_existing_users} already exist")
    print("")

    # The default is "No": only an explicit yes lets the seeding continue.
    answer = input("Proceed with database seeding? [y/N] ")
    if not answer.strip().lower().startswith("y"):
        sys.exit()
|||
|
|||
|
|||
def print_progress(num_finished: int, num_total: int) -> None:
    """Print the progress for some task in "completed/total" format.

    The line ends with a carriage return instead of a newline, so consecutive
    calls overwrite each other on a terminal. Once `num_finished` reaches
    `num_total`, the line is terminated and a final "Done" is printed.
    """
    status_line = f"Completed {num_finished}/{num_total}"
    print(status_line, end="\r")
    if num_finished < num_total:
        return
    # Finished: move off the progress line before announcing completion.
    print("")
    print("Done", flush=True)
|||
|
|||
|
|||
# Only seed when this file is executed directly as a script, not when it is imported.
if __name__ == "__main__":
    seed_dev_database()
Write
Preview
Loading…
Cancel
Save
Reference in new issue