mirror of https://gitlab.com/tildes/tildes.git
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
81 lines
2.5 KiB
81 lines
2.5 KiB
"""Script for doing the initial setup of database tables."""
|
|
|
|
import os
|
|
import subprocess
|
|
from typing import Optional
|
|
|
|
from alembic import command
|
|
from alembic.config import Config
|
|
from sqlalchemy.engine import Connectable, Engine
|
|
|
|
from tildes.lib.database import get_session_from_config
|
|
from tildes.models import DatabaseModel
|
|
from tildes.models.group import Group
|
|
from tildes.models.log import Log
|
|
from tildes.models.user import User
|
|
|
|
|
|
def initialize_db(config_path: str, alembic_config_path: Optional[str] = None) -> None:
    """Set up a fresh database: create tables, run init SQL, stamp Alembic.

    Args:
        config_path: path to the app config file used to open a db session.
        alembic_config_path: path to the Alembic config file. When it is falsy
            (None or empty), "alembic.ini" in the same directory as
            ``config_path`` is used instead.
    """
    session = get_session_from_config(config_path)
    engine = session.bind

    # build the schema from the models, then apply the raw SQL init scripts
    create_tables(engine)
    run_sql_scripts_in_dir("sql/init/", engine)

    if not alembic_config_path:
        # default: an alembic.ini sitting next to the app config file
        config_dir = os.path.dirname(config_path)
        alembic_config_path = os.path.join(config_dir, "alembic.ini")

    # record "head" as the current revision so later migrations start from here
    alembic_cfg = Config(alembic_config_path)
    command.stamp(alembic_cfg, "head")
|
|
|
|
|
|
def create_tables(connectable: Connectable) -> None:
    """Create the model tables, skipping any that must be created by hand.

    Tables whose names appear in ``Log.INHERITED_TABLES`` are left out (they
    rely on inheritance or otherwise need to be created manually).

    Args:
        connectable: the engine/connection the tables are created through.
    """
    metadata = DatabaseModel.metadata
    skipped_names = Log.INHERITED_TABLES

    tables_to_create = []
    for table in metadata.tables.values():
        if table.name in skipped_names:
            continue
        tables_to_create.append(table)

    metadata.create_all(connectable, tables=tables_to_create)
|
|
|
|
|
|
def run_sql_scripts_in_dir(path: str, engine: Engine) -> None:
    """Run all sql scripts in a directory (recursively) through psql.

    Every file ending in ".sql" under ``path`` is executed with ``psql -f``.
    Execution order is deterministic: within each directory the scripts run in
    sorted filename order, and subdirectories are visited in sorted name order
    after their parent directory's scripts.

    Only the username and database name are taken from ``engine.url``; any
    host, port or password in the URL are not passed to psql. The exit status
    of each psql invocation is not checked.

    Args:
        path: directory to search for ".sql" files.
        engine: engine whose URL supplies the psql username and database.
    """
    for root, dirs, files in os.walk(path):
        # os.walk reports entries in arbitrary, filesystem-dependent order.
        # Sorting makes the run order reproducible, which matters when one
        # script depends on objects created by another. Sorting ``dirs`` in
        # place also fixes the order os.walk descends into subdirectories.
        dirs.sort()
        sql_files = sorted(
            filename for filename in files if filename.endswith(".sql")
        )

        for sql_file in sql_files:
            subprocess.call(
                [
                    "psql",
                    "-U",
                    engine.url.username,
                    "-f",
                    os.path.join(root, sql_file),
                    engine.url.database,
                ]
            )
|
|
|
|
|
|
def insert_dev_data(config_path: str) -> None:
    """Seed a development database with a test user and a test group.

    Args:
        config_path: path to the app config file used to open a db session.
    """
    session = get_session_from_config(config_path)

    test_user = User("TestUser", "password")
    test_group = Group(
        "testing", "An automatically created group to use for testing purposes"
    )
    session.add_all([test_user, test_group])

    session.commit()
|