|
|
@ -17,15 +17,15 @@ insert a ton of junk data. |
|
|
|
import os |
|
|
|
import sys |
|
|
|
import random |
|
|
|
from io import StringIO |
|
|
|
import csv |
|
|
|
from typing import Any, Tuple |
|
|
|
from sqlalchemy import func |
|
|
|
from sqlalchemy.orm import Session |
|
|
|
from sqlalchemy.sql import text |
|
|
|
from faker import Faker |
|
|
|
|
|
|
|
from tildes.lib.database import get_session_from_config |
|
|
|
from tildes.lib.hash import hash_string |
|
|
|
from tildes.lib.markdown import convert_markdown_to_safe_html |
|
|
|
from tildes.models.group import Group |
|
|
|
from tildes.models.topic import Topic |
|
|
|
from tildes.models.user import User |
|
|
@ -105,7 +105,7 @@ def seed_dev_database() -> None: |
|
|
|
confirm_everything_looks_okay(db_session, ini_file_path) |
|
|
|
|
|
|
|
# NOTE: Creating records in bulk using the Tildes models is quite inefficient. |
|
|
|
# Instead, we'll insert using raw SQL and dictionaries. |
|
|
|
# Instead, we'll insert using PostgreSQL's COPY and raw CSV strings. |
|
|
|
if NUM_GROUPS > 0: |
|
|
|
seed_groups(db_session) |
|
|
|
|
|
|
@ -122,56 +122,71 @@ def seed_dev_database() -> None: |
|
|
|
def seed_groups(db_session: Session) -> None:
    """Insert dummy groups into database.

    Generates NUM_GROUPS rows (a unique Faker word as the path plus a fixed
    description), buffers them in memory as tab-delimited text, and bulk-loads
    them into the "groups" table with a single COPY. The COPY runs and is
    committed on a raw DBAPI connection obtained from the session's bind, not
    through db_session itself.
    """
    print(f"Generating {NUM_GROUPS} groups")

    groups_output = StringIO()
    csv_writer = csv.writer(groups_output, delimiter="\t")

    print_progress(0, NUM_GROUPS)

    for i in range(NUM_GROUPS):
        csv_writer.writerow(
            [
                fake.unique.word(),
                "automatically generated by seed_dev_database.py",
            ]
        )
        print_progress(i + 1, NUM_GROUPS)

    print("Writing groups to database")

    # Rewind the buffer so COPY reads it from the first row.
    groups_output.seek(0)

    raw_conn = db_session.get_bind().raw_connection()
    try:
        cursor = raw_conn.cursor()
        try:
            # Set timeout to an arbitrarily high number since COPY can take a while.
            cursor.execute("SET statement_timeout = '999999s'")
            cursor.copy_from(
                groups_output,
                "groups",
                columns=(
                    "path",
                    "short_description",
                ),
            )
        finally:
            cursor.close()

        # Only reached when the COPY succeeded.
        raw_conn.commit()
    finally:
        # Release the raw connection instead of leaving it checked out.
        raw_conn.close()
|
|
|
|
|
|
|
|
|
|
|
def seed_users(db_session: Session) -> None:
    """Insert dummy users into database.

    Generates NUM_USERS rows (a unique Faker username plus one shared password
    hash), buffers them in memory as tab-delimited text, and bulk-loads them
    into the "users" table with a single COPY. The COPY runs and is committed
    on a raw DBAPI connection obtained from the session's bind, not through
    db_session itself.
    """
    print(f"Generating {NUM_USERS} users")

    users_output = StringIO()
    csv_writer = csv.writer(users_output, delimiter="\t")

    print_progress(0, NUM_USERS)

    # Hash once and reuse: every generated user gets the password "password".
    password_hash = hash_string("password")

    for i in range(NUM_USERS):
        csv_writer.writerow(
            [
                fake.unique.user_name(),
                password_hash,
            ]
        )
        print_progress(i + 1, NUM_USERS)

    print("Writing users to database")

    # Rewind the buffer so COPY reads it from the first row.
    users_output.seek(0)

    raw_conn = db_session.get_bind().raw_connection()
    try:
        cursor = raw_conn.cursor()
        try:
            # Set timeout to an arbitrarily high number since COPY can take a while.
            cursor.execute("SET statement_timeout = '999999s'")
            cursor.copy_from(
                users_output,
                "users",
                columns=(
                    "username",
                    "password_hash",
                ),
            )
        finally:
            cursor.close()

        # Only reached when the COPY succeeded.
        raw_conn.commit()
    finally:
        # Release the raw connection instead of leaving it checked out.
        raw_conn.close()
|
|
|
|
|
|
|
|
|
|
|
def seed_topics(db_session: Session) -> None: |
|
|
|
"""Insert dummy topics into database.""" |
|
|
|
print(f"Generating {NUM_TOPICS} topics") |
|
|
|
topics_to_insert = [] |
|
|
|
topics_output = StringIO() |
|
|
|
csv_writer = csv.writer(topics_output, delimiter="\t") |
|
|
|
all_groups = db_session.query(Group).all() |
|
|
|
if not all_groups: |
|
|
|
raise Exception("At least one group is required when seeding topics") |
|
|
@ -187,36 +202,44 @@ def seed_topics(db_session: Session) -> None: |
|
|
|
if not all_tags: |
|
|
|
raise Exception("At least one tag is required for seeding topics") |
|
|
|
print_progress(0, NUM_TOPICS) |
|
|
|
for _ in range(NUM_TOPICS): |
|
|
|
for i in range(NUM_TOPICS): |
|
|
|
num_tags = TAGS_PER_TOPIC_ODDS.get() |
|
|
|
topic_type = TOPIC_LINK_ODDS.get() |
|
|
|
markdown = fake.sentence() |
|
|
|
topic = { |
|
|
|
"group_id": random.choice(all_groups).group_id, |
|
|
|
"user_id": random.choice(all_users).user_id, |
|
|
|
"topic_type": topic_type, |
|
|
|
"title": " ".join(fake.words(3)), |
|
|
|
"markdown": markdown if topic_type == "TEXT" else None, |
|
|
|
"rendered_html": convert_markdown_to_safe_html(markdown) |
|
|
|
if topic_type == "TEXT" |
|
|
|
else None, |
|
|
|
"link": fake.url() if topic_type == "LINK" else None, |
|
|
|
"tags": [random.choice(all_tags) for _ in range(num_tags)], |
|
|
|
} |
|
|
|
topics_to_insert.append(topic) |
|
|
|
print_progress(len(topics_to_insert), NUM_TOPICS) |
|
|
|
text_content = fake.sentence() |
|
|
|
csv_writer.writerow( |
|
|
|
[ |
|
|
|
random.choice(all_groups).group_id, |
|
|
|
random.choice(all_users).user_id, |
|
|
|
topic_type, |
|
|
|
" ".join(fake.words(3)), |
|
|
|
text_content if topic_type == "TEXT" else None, |
|
|
|
f"<p>{text_content}</p>" if topic_type == "TEXT" else None, |
|
|
|
fake.url() if topic_type == "LINK" else None, |
|
|
|
f"{{{','.join([random.choice(all_tags) for _ in range(num_tags)])}}}", |
|
|
|
] |
|
|
|
) |
|
|
|
print_progress(i + 1, NUM_TOPICS) |
|
|
|
print("Writing topics to database") |
|
|
|
statement = text( |
|
|
|
r""" |
|
|
|
INSERT INTO topics (group_id, user_id, topic_type, title, markdown, |
|
|
|
rendered_html, link, tags) |
|
|
|
VALUES (:group_id, :user_id, :topic_type, :title, :markdown, :rendered_html, |
|
|
|
:link, :tags\:\:ltree[]) |
|
|
|
""" |
|
|
|
raw_conn = db_session.get_bind().raw_connection() |
|
|
|
cursor = raw_conn.cursor() |
|
|
|
topics_output.seek(0) |
|
|
|
# Set timeout to an arbitrarily high number since COPY can take a while. |
|
|
|
cursor.execute("SET statement_timeout = '999999s'") |
|
|
|
cursor.copy_from( |
|
|
|
topics_output, |
|
|
|
"topics", |
|
|
|
columns=( |
|
|
|
"group_id", |
|
|
|
"user_id", |
|
|
|
"topic_type", |
|
|
|
"title", |
|
|
|
"markdown", |
|
|
|
"rendered_html", |
|
|
|
"link", |
|
|
|
"tags", |
|
|
|
), |
|
|
|
) |
|
|
|
for topic in topics_to_insert: |
|
|
|
db_session.execute(statement, topic) |
|
|
|
db_session.flush() |
|
|
|
raw_conn.commit() |
|
|
|
|
|
|
|
|
|
|
|
def get_ini_file_path() -> str: |
|
|
|