Browse Source

Merge branch 'database-seeding' into 'master'

Add development database seeding script

See merge request tildes/tildes!147
merge-requests/147/merge
Daniel Woznicki 2 months ago
parent
commit
0c2b91e87b
  1. 1
      tildes/requirements-dev.in
  2. 1
      tildes/requirements-dev.txt
  3. 289
      tildes/scripts/seed_dev_database.py

1
tildes/requirements-dev.in

@ -1,5 +1,6 @@
-r requirements.in -r requirements.in
black black
faker
freezegun freezegun
html5validator html5validator
mypy mypy

1
tildes/requirements-dev.txt

@ -15,6 +15,7 @@ click==8.0.1
cornice==5.2.0 cornice==5.2.0
decorator==5.0.9 decorator==5.0.9
dodgy==0.2.1 dodgy==0.2.1
faker==18.11.2
flake8==3.9.2 flake8==3.9.2
flake8-polyfill==1.0.2 flake8-polyfill==1.0.2
freezegun==1.1.0 freezegun==1.1.0

289
tildes/scripts/seed_dev_database.py

@ -0,0 +1,289 @@
# Copyright (c) 2023 Tildes contributors <code@tildes.net>
# SPDX-License-Identifier: AGPL-3.0-or-later
"""Script to seed a dev database with fake data.
If you run this script as is, no data will be inserted. To get this script to do
something, you must modify at least one of the following variables and change the value
to a number greater than zero.
- NUM_GROUPS
- NUM_TOPICS (also NUM_TAGS for this one)
- NUM_USERS
Note: This script is mean to be run on *dev databases only*. Running on production will
insert a ton of junk data.
"""
import os
import sys
import random
from io import StringIO
import csv
from typing import Any, Tuple
from sqlalchemy import func
from sqlalchemy.orm import Session
from faker import Faker
from tildes.lib.database import get_session_from_config
from tildes.lib.hash import hash_string
from tildes.models.group import Group
from tildes.models.topic import Topic
from tildes.models.user import User
class WeightedValues:
"""Mapping of values by weight. Useful for working with semi-random data."""
def __init__(self, default_value: Any, odds: list[Tuple[float, Any]]):
"""Create new weighted values."""
self.default_value = default_value
self.odds = odds
def get(self) -> Any:
"""Get a random value from the odds dict, or the default value.
Values with a higher ratio are more likely be to returned. For example, if this
class instance was constructed as
WeightedValues("dunno", {0.5: "hello", 0.2: "goodbye"})
then this function will have a 50% chance to return "hello", a 20% chance to
return "goodbye", and a remaining 30% chance to return "dunno".
"""
rand = random.random()
num = 0.0
for (ratio, value) in self.odds:
num += ratio
if rand < num:
return value
return self.default_value
# Number of each resource to insert during seeding.
# A value of `0` means no new records will be inserted.
NUM_GROUPS = 0
NUM_TOPICS = 0
NUM_TAGS = 0 # must be greater than 0 when creating topics
NUM_USERS = 0
TAG_NESTING_ODDS = WeightedValues(
1,
[
(0.25, 2), # "a.b"
(0.05, 3), # "a.b.c"
],
)
TOPIC_LINK_ODDS = WeightedValues(
"TEXT",
[
(0.4, "LINK"),
],
)
TAGS_PER_TOPIC_ODDS = WeightedValues(
1,
[
(0.2, 2),
(0.4, 3),
(0.1, 4),
(0.1, 5),
(0.1, 6),
],
)
fake = Faker()
def seed_dev_database() -> None:
"""Seed a dev database with one or more resource.
Note that not all resources are represented yet because I'm lazy.
"""
ini_file_path = get_ini_file_path()
db_session = get_session_from_config(ini_file_path)
# This function might call `sys.exit()` if the user does not okay the seeding.
confirm_everything_looks_okay(db_session, ini_file_path)
# NOTE: Creating records in bulk using the Tildes models is quite inefficient.
# Instead, we'll insert using PostgreSQL's COPY and raw CSV strings.
if NUM_GROUPS > 0:
seed_groups(db_session)
if NUM_USERS > 0:
seed_users(db_session)
if NUM_TOPICS > 0:
seed_topics(db_session)
db_session.commit()
print("Database seeding completed")
def seed_groups(db_session: Session) -> None:
"""Insert dummy groups into database."""
print(f"Generating {NUM_GROUPS} groups")
groups_output = StringIO()
csv_writer = csv.writer(groups_output, delimiter="\t")
print_progress(0, NUM_GROUPS)
for i in range(NUM_GROUPS):
csv_writer.writerow(
[
fake.unique.word(),
"automatically generated by seed_dev_database.py",
]
)
print_progress(i + 1, NUM_GROUPS)
print("Writing groups to database")
raw_conn = db_session.get_bind().raw_connection()
cursor = raw_conn.cursor()
groups_output.seek(0)
# Set timeout to an arbitrarily high number since COPY can take a while.
cursor.execute("SET statement_timeout = '999999s'")
cursor.copy_from(
groups_output,
"groups",
columns=(
"path",
"short_description",
),
)
raw_conn.commit()
def seed_users(db_session: Session) -> None:
"""Insert dummy users into database."""
print(f"Generating {NUM_USERS} users")
users_output = StringIO()
csv_writer = csv.writer(users_output, delimiter="\t")
print_progress(0, NUM_USERS)
password_hash = hash_string("password")
for i in range(NUM_USERS):
csv_writer.writerow(
[
fake.unique.user_name(),
password_hash,
]
)
print_progress(i + 1, NUM_USERS)
print("Writing users to database")
raw_conn = db_session.get_bind().raw_connection()
cursor = raw_conn.cursor()
users_output.seek(0)
# Set timeout to an arbitrarily high number since COPY can take a while.
cursor.execute("SET statement_timeout = '999999s'")
cursor.copy_from(
users_output,
"users",
columns=(
"username",
"password_hash",
),
)
raw_conn.commit()
def seed_topics(db_session: Session) -> None:
"""Insert dummy topics into database."""
print(f"Generating {NUM_TOPICS} topics")
topics_output = StringIO()
csv_writer = csv.writer(topics_output, delimiter="\t")
all_groups = db_session.query(Group).all()
if not all_groups:
raise Exception("At least one group is required when seeding topics")
all_users = db_session.query(User).all()
if not all_users:
raise Exception("At least one user is required when seeding topics")
all_tags = []
if NUM_TAGS > 0:
for _ in range(NUM_TAGS):
nesting_level = TAG_NESTING_ODDS.get()
tag = ".".join([fake.word() for _ in range(nesting_level)])
all_tags.append(tag)
if not all_tags:
raise Exception("At least one tag is required for seeding topics")
print_progress(0, NUM_TOPICS)
for i in range(NUM_TOPICS):
num_tags = TAGS_PER_TOPIC_ODDS.get()
topic_type = TOPIC_LINK_ODDS.get()
text_content = fake.sentence()
csv_writer.writerow(
[
random.choice(all_groups).group_id,
random.choice(all_users).user_id,
topic_type,
" ".join(fake.words(3)),
text_content if topic_type == "TEXT" else None,
f"<p>{text_content}</p>" if topic_type == "TEXT" else None,
fake.url() if topic_type == "LINK" else None,
f"{{{','.join([random.choice(all_tags) for _ in range(num_tags)])}}}",
]
)
print_progress(i + 1, NUM_TOPICS)
print("Writing topics to database")
raw_conn = db_session.get_bind().raw_connection()
cursor = raw_conn.cursor()
topics_output.seek(0)
# Set timeout to an arbitrarily high number since COPY can take a while.
cursor.execute("SET statement_timeout = '999999s'")
cursor.copy_from(
topics_output,
"topics",
columns=(
"group_id",
"user_id",
"topic_type",
"title",
"markdown",
"rendered_html",
"link",
"tags",
),
)
raw_conn.commit()
def get_ini_file_path() -> str:
"""Get the ini file path from environment, or the default development.ini path."""
if "INI_FILE" in os.environ:
return os.environ["INI_FILE"]
return os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "development.ini")
)
def confirm_everything_looks_okay(db_session: Session, ini_file_path: str) -> None:
"""Ask the user to confirm this database seeding.
This function calls `sys.exit` if they bail.
"""
num_existing_groups = db_session.query(func.count(Group.group_id)).scalar()
num_existing_topics = db_session.query(func.count(Topic.topic_id)).scalar()
num_existing_users = db_session.query(func.count(User.user_id)).scalar()
print("Confirm this database seeding.")
print("")
print(f" ini file: {ini_file_path}")
print(
f" groups: {NUM_GROUPS} groups will be inserted, "
f"{num_existing_groups} already exist"
)
print(
f" topics: {NUM_TOPICS} topic swill be inserted with {NUM_TAGS} "
f"possible tags, {num_existing_topics} already exist"
)
print(f" users: {NUM_USERS} will be inserted, {num_existing_users} already exist")
print("")
answer = input("Proceed with database seeding? [y/N] ")
if not answer.lower().startswith("y"):
sys.exit()
def print_progress(num_finished: int, num_total: int) -> None:
"""Print the progress for some task in "completed/total" format."""
print(f"Completed {num_finished}/{num_total}", end="\r")
if num_finished >= num_total:
print("")
print("Done", flush=True)
if __name__ == "__main__":
seed_dev_database()
Loading…
Cancel
Save