"""SQLite-backed storage helpers for registration codes and registered users.

All functions operate on the connection returned by :func:`get_db`, which is
cached on Flask's ``g`` object.
"""
import logging
import os
import sqlite3
import uuid
from datetime import datetime
from typing import List, Optional, Tuple

from flask import g

from models import RegistrationCode, RegisteredUser

# Location of the SQLite file; the directory defaults to the current one.
DATABASE = os.getenv("APP_DATA_DIRECTORY", ".") + "/data.db"

REGISTRATION_CODE_INSERT_SQL = (
    "INSERT INTO registration_codes"
    "(code, creationTime, expirationTime, usages, maxUsages) "
    "VALUES(?, ?, ?, ?, ?)"
)
REGISTRATION_CODE_UPDATE_SQL = (
    "UPDATE registration_codes "
    "SET expirationTime = ?, usages = ?, maxUsages = ? "
    "WHERE code = ?"
)
REGISTERED_USER_INSERT_SQL = (
    "INSERT INTO registered_users"
    "(registrationCode, username, registeredUserId, registeredTime) "
    "VALUES(?, ?, ?, ?) "
)

log = logging.getLogger(__name__)


def get_db():
    """Return the SQLite connection stored on ``g``, opening it on first use."""
    db = getattr(g, '_database', None)
    if db is None:
        db = g._database = sqlite3.connect(DATABASE)
    return db


def _query_db(query, args=(), one=False):
    """Run ``query`` with ``args`` and return the fetched rows.

    Args:
        query: SQL text, using ``?`` placeholders for parameters.
        args: Sequence of values bound to the placeholders.
        one: When true, return only the first row, or ``None`` if the
            query produced no rows. Otherwise return the list of all rows.
    """
    cur = get_db().execute(query, args)
    try:
        rows = cur.fetchall()
    finally:
        # Release the cursor even if fetching fails.
        cur.close()
    if one:
        return rows[0] if rows else None
    return rows


def get_registration_codes() -> List[RegistrationCode]:
    """Return every row of ``registration_codes`` as a ``RegistrationCode``."""
    return [
        RegistrationCode.from_db(row)
        for row in _query_db("SELECT * FROM registration_codes")
    ]


def get_registration_code(code: str) -> Optional[RegistrationCode]:
    """Return the registration code matching ``code``, or ``None`` if absent."""
    row: Optional[Tuple] = _query_db(
        "SELECT * FROM registration_codes WHERE code = ?", [code], one=True)
    if row is None:
        return None
    return RegistrationCode.from_db(row)


def add_registration_code(
        expiration_time: Optional[datetime] = None,
        max_usages: int = 1) -> RegistrationCode:
    """Create a registration code with a fresh UUID4 value and persist it.

    Args:
        expiration_time: Passed through to ``RegistrationCode``; may be
            ``None``.
        max_usages: Passed through to ``RegistrationCode``.

    Returns:
        The stored code, re-read from the database after the insert.
    """
    code = RegistrationCode(
        str(uuid.uuid4()),
        expiration_time=expiration_time,
        max_usages=max_usages)
    db = get_db()
    # The connection context manager commits on success and rolls back if
    # the statement raises.
    with db:
        db.execute(
            REGISTRATION_CODE_INSERT_SQL,
            [code.code, code.creation_time, code.expiration_time,
             code.usages, code.max_usages],
        )
    return get_registration_code(code.code)


def update_registration_code(code: RegistrationCode) -> RegistrationCode:
    """Write ``code``'s expiration time, usages and max usages to its row.

    Returns:
        The code as re-read from the database after the update.
    """
    db = get_db()
    with db:
        db.execute(
            REGISTRATION_CODE_UPDATE_SQL,
            [code.expiration_time, code.usages, code.max_usages, code.code],
        )
    return get_registration_code(code.code)


def expire_registration_code(code: str):
    """Set the expiration time of ``code`` to the current local time."""
    db = get_db()
    with db:
        db.execute(
            "UPDATE registration_codes SET expirationTime = ? WHERE code = ?",
            [datetime.now(), code],
        )


def delete_registration_code(code: str):
    """Delete the ``registration_codes`` row whose code equals ``code``."""
    db = get_db()
    with db:
        db.execute(
            "DELETE FROM registration_codes WHERE code = ?",
            [code],
        )


def get_registered_users() -> List[RegisteredUser]:
    """Return every row of ``registered_users`` as a ``RegisteredUser``."""
    return [
        RegisteredUser.from_db(row)
        for row in _query_db("SELECT * FROM registered_users")
    ]


def get_registered_user(username: str) -> Optional[RegisteredUser]:
    """Return the registered user named ``username``, or ``None`` if absent."""
    row: Optional[Tuple] = _query_db(
        "SELECT * FROM registered_users WHERE username = ?",
        [username], one=True)
    if row is None:
        return None
    return RegisteredUser.from_db(row)


def add_registered_user(code: str,
                        username: str,
                        matrix_user_id: str) -> RegisteredUser:
    """Record a user registered with ``code`` and count one usage of it.

    The usage increment and the user insert are committed together, so a
    failed insert does not consume a usage of the registration code.

    Args:
        code: Value of the registration code that was redeemed.
        username: Stored in the ``username`` column.
        matrix_user_id: Stored in the ``registeredUserId`` column.

    Returns:
        The stored user, re-read from the database by ``username``.

    Raises:
        ValueError: If no registration code with value ``code`` exists.
    """
    registration_code = get_registration_code(code)
    if registration_code is None:
        raise ValueError(f"Unknown registration code: {code!r}")

    db = get_db()
    # Single transaction: both statements are committed together, or rolled
    # back together if either one raises.
    with db:
        db.execute(
            REGISTRATION_CODE_UPDATE_SQL,
            [registration_code.expiration_time,
             registration_code.usages + 1,
             registration_code.max_usages,
             registration_code.code],
        )
        db.execute(
            REGISTERED_USER_INSERT_SQL,
            [code, username, matrix_user_id, datetime.now()],
        )
    return get_registered_user(username)