You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
132 lines
3.7 KiB
132 lines
3.7 KiB
import logging
import os
import sqlite3
import uuid
from datetime import datetime
from typing import List, Optional, Tuple

from flask import g

from models import RegistrationCode, RegisteredUser
|
|
|
|
# Path of the SQLite database file: "data.db" inside the directory named by
# the APP_DATA_DIRECTORY environment variable (default: the current working
# directory). os.path.join avoids doubled separators when the variable ends
# with "/" and avoids resolving to "/data.db" when the variable is empty.
DATABASE = os.path.join(os.getenv("APP_DATA_DIRECTORY", "."), "data.db")


# Insert a registration code row.
# Parameters: (code, creationTime, expirationTime, usages, maxUsages).
REGISTRATION_CODE_INSERT_SQL = """INSERT INTO
registration_codes(code, creationTime, expirationTime, usages, maxUsages)
VALUES(?, ?, ?, ?, ?)"""

# Update the mutable columns of the registration code identified by `code`.
# Parameters: (expirationTime, usages, maxUsages, code).
REGISTRATION_CODE_UPDATE_SQL = """UPDATE registration_codes
SET expirationTime = ?, usages = ?, maxUsages = ?
WHERE code = ?"""

# Insert a registered user row.
# Parameters: (registrationCode, username, registeredUserId, registeredTime).
REGISTERED_USER_INSERT_SQL = """INSERT INTO
registered_users(registrationCode, username, registeredUserId, registeredTime)
VALUES(?, ?, ?, ?)
"""

log = logging.getLogger(__name__)
|
|
|
|
|
|
def get_db() -> sqlite3.Connection:
    """Return the SQLite connection cached on ``flask.g``, opening it if needed.

    On the first call (no ``g._database`` attribute yet) a connection to
    ``DATABASE`` is opened, logged, and stored as ``g._database``; later calls
    return that same object. This function never closes the connection —
    NOTE(review): presumably a teardown handler elsewhere does; verify.
    """
    db = getattr(g, '_database', None)
    if db is None:
        # Pass DATABASE as a logging argument rather than pre-formatting with
        # '%', so interpolation only happens if the record is actually emitted.
        log.info("Using database at: %s", DATABASE)
        db = g._database = sqlite3.connect(DATABASE)
    return db
|
|
|
|
|
|
def _query_db(query, args=(), one=False):
    """Execute ``query`` on the app database and return the fetched rows.

    :param query: SQL text, using ``?`` placeholders for parameters.
    :param args: Sequence of values bound to the placeholders.
    :param one: If true, return only the first row, or ``None`` when the
        query returned no rows.
    :return: List of all rows (tuples, since ``get_db`` sets no row factory),
        or a single row / ``None`` when ``one`` is true.
    """
    cur = get_db().execute(query, args)
    try:
        rows = cur.fetchall()
    finally:
        # Always release the cursor, even if fetching raises.
        cur.close()
    if one:
        return rows[0] if rows else None
    return rows
|
|
|
|
|
|
def get_registration_codes() -> List[RegistrationCode]:
    """Return every row of ``registration_codes`` as ``RegistrationCode`` objects.

    The previous annotation ``[RegistrationCode]`` was a list literal, not a
    type; ``List[RegistrationCode]`` expresses the same intent validly.
    """
    return [RegistrationCode.from_db(row)
            for row in _query_db("SELECT * FROM registration_codes")]
|
|
|
|
|
|
def get_registration_code(code: str) -> Optional[RegistrationCode]:
    """Fetch one registration code by its code string.

    :param code: Value of the ``code`` column to match.
    :return: The matching ``RegistrationCode``, or ``None`` if no row matches.
    """
    row: Optional[Tuple] = _query_db(
        "SELECT * FROM registration_codes WHERE code = ?", [code], one=True)
    if row is None:
        return None
    return RegistrationCode.from_db(row)
|
|
|
|
|
|
def add_registration_code(
        expiration_time: Optional[datetime] = None,
        max_usages: int = 1) -> RegistrationCode:
    """Create a new registration code with a random UUID, store it, and return it.

    :param expiration_time: Stored as ``expirationTime``. ``None`` is passed
        through to ``RegistrationCode`` as-is (presumably meaning "never
        expires" — confirm in ``models``).
    :param max_usages: Stored as ``maxUsages``.
    :return: The code as re-read from the database after the commit.
    """
    code = RegistrationCode(
        str(uuid.uuid4()), expiration_time=expiration_time, max_usages=max_usages)
    db = get_db()
    db.execute(
        REGISTRATION_CODE_INSERT_SQL,
        [code.code,
         code.creation_time,
         code.expiration_time,
         code.usages,
         code.max_usages]
    )
    db.commit()
    # Re-read so the caller sees the values as they round-trip through SQLite.
    return get_registration_code(code.code)
|
|
|
|
|
|
def update_registration_code(code: RegistrationCode) -> RegistrationCode:
    """Write ``code``'s expiration time, usage count and usage limit to the DB.

    The row is matched on ``code.code``. After committing, the row is read
    back via ``get_registration_code`` and returned.
    """
    connection = get_db()
    params = [code.expiration_time, code.usages, code.max_usages, code.code]
    connection.execute(REGISTRATION_CODE_UPDATE_SQL, params)
    connection.commit()
    return get_registration_code(code.code)
|
|
|
|
|
|
def expire_registration_code(code: str):
    """Set the expiration time of ``code`` to the current moment and commit.

    Uses ``datetime.now()``, i.e. a naive local-time timestamp.
    """
    connection = get_db()
    now = datetime.now()
    connection.execute(
        "UPDATE registration_codes SET expirationTime = ? WHERE code = ?",
        (now, code),
    )
    connection.commit()
|
|
|
|
|
|
def delete_registration_code(code: str):
    """Delete the ``registration_codes`` row whose code equals ``code`` and commit."""
    connection = get_db()
    connection.execute("DELETE FROM registration_codes WHERE code = ?", (code,))
    connection.commit()
|
|
|
|
|
|
def get_registered_users() -> List[RegisteredUser]:
    """Return every row of ``registered_users`` as ``RegisteredUser`` objects.

    The previous annotation ``[RegisteredUser]`` was a list literal, not a
    type; ``List[RegisteredUser]`` expresses the same intent validly.
    """
    return [RegisteredUser.from_db(row)
            for row in _query_db("SELECT * FROM registered_users")]
|
|
|
|
|
|
def get_registered_user(username: str) -> Optional[RegisteredUser]:
    """Fetch one registered user by username.

    :param username: Value of the ``username`` column to match.
    :return: The matching ``RegisteredUser``, or ``None`` if no row matches.
    """
    row: Optional[Tuple] = _query_db(
        "SELECT * FROM registered_users WHERE username = ?", [username], one=True)
    return None if row is None else RegisteredUser.from_db(row)
|
|
|
|
|
|
def add_registered_user(code: str, username: str,
                        matrix_user_id: str) -> RegisteredUser:
    """Record a user registration made with ``code``.

    Increments the code's usage count and inserts the ``registered_users`` row
    in a single transaction. If either statement fails, neither change is
    kept, so a failed insert no longer consumes a usage.

    :param code: Registration code that was used.
    :param username: Username of the newly registered user.
    :param matrix_user_id: Stored in the ``registeredUserId`` column.
    :return: The registered user as read back from the database.
    :raises ValueError: If ``code`` does not exist. Previously this surfaced
        as an ``AttributeError`` on ``None``.
    """
    registration_code = get_registration_code(code)
    if registration_code is None:
        raise ValueError("Unknown registration code: %s" % code)
    # NOTE(review): this is still a read-modify-write of `usages`; concurrent
    # registrations with the same code could race. An atomic
    # `SET usages = usages + 1` would close that gap if models allow it.
    registration_code.usages += 1
    db = get_db()
    # The sqlite3 connection context manager commits on success and rolls
    # back if either statement raises, keeping the count and user row in sync.
    with db:
        db.execute(REGISTRATION_CODE_UPDATE_SQL,
                   [registration_code.expiration_time,
                    registration_code.usages,
                    registration_code.max_usages,
                    registration_code.code])
        db.execute(REGISTERED_USER_INSERT_SQL,
                   [code, username, matrix_user_id, datetime.now()])
    return get_registered_user(username)
|
|
|
|
|