From 8b6b069617e7485bae230946b3eb1cfaccb8fe53 Mon Sep 17 00:00:00 2001 From: Drew Short Date: Wed, 30 Jan 2019 22:23:12 -0600 Subject: [PATCH] Added column to record the user matrix_id on successful registration --- app.py | 13 +++++++++---- db.py | 10 ++++++---- models.py | 18 ++++++++++++++---- register_new_matrix_user.py | 33 ++++++++++++++++++--------------- schema.sql | 11 +++++++---- 5 files changed, 54 insertions(+), 31 deletions(-) diff --git a/app.py b/app.py index 4cca343..9081486 100644 --- a/app.py +++ b/app.py @@ -6,7 +6,8 @@ from urllib.parse import urlparse, urljoin import flask from flask import Flask, redirect, render_template, request, g, flash, url_for -from flask_login import LoginManager, login_required, login_user, logout_user, UserMixin +from flask_login import LoginManager, login_required, login_user, logout_user, \ + UserMixin from flask_wtf import CSRFProtect from db import get_db, get_registration_codes, add_registration_code, \ @@ -145,14 +146,18 @@ def registration(): flash("Registration Configuration Is Invalid. Contact Administrator!") return render_template('register.html', form=form) else: - successful = register_new_user( + response = register_new_user( form.username.data, form.password.data, app.config.get("MATRIX_HOMESERVER"), app.config.get("MATRIX_SHARED_SECRET") ) - if successful: - add_registered_user(form.registration_code.data, form.username.data) + if response is not None: + registered_user_id = response['user_id'] + add_registered_user( + form.registration_code.data, + form.username.data, + registered_user_id) return redirect(get_successful_registration_redirect()) else: flash("Registration Failure. Contact Administrator!") diff --git a/db.py b/db.py index c0dc682..c26903e 100644 --- a/db.py +++ b/db.py @@ -21,8 +21,8 @@ SET expirationTime = ?, usages = ?, maxUsages = ? WHERE code = ?""" REGISTERED_USER_INSERT_SQL = """INSERT INTO -registered_users(registrationCode, username, registeredTime) -VALUES(?, ?, ?) +registered_users(registrationCode, username, registeredUserId, registeredTime) +VALUES(?, ?, ?, ?) """ log = logging.getLogger(__name__) @@ -118,12 +118,14 @@ def get_registered_user(username: str) -> Optional[RegisteredUser]: return None -def add_registered_user(code: str, username: str) -> RegisteredUser: +def add_registered_user(code: str, username: str, + matrix_user_id: str) -> RegisteredUser: registration_code = get_registration_code(code) registration_code.usages += 1 update_registration_code(registration_code) db = get_db() - db.execute(REGISTERED_USER_INSERT_SQL, [code, username, datetime.now()]) + db.execute(REGISTERED_USER_INSERT_SQL, + [code, username, matrix_user_id, datetime.now()]) db.commit() return get_registered_user(username) diff --git a/models.py b/models.py index 1018846..b56ea39 100644 --- a/models.py +++ b/models.py @@ -2,6 +2,14 @@ from datetime import datetime from typing import Tuple +def get_time_or_now(time): + return time if time is not None else datetime.now() + + +def parse_optional_time(time): + return None if time is None else datetime.fromisoformat(time) + + class RegistrationCode: def __init__(self, code: str, @@ -10,18 +18,17 @@ class RegistrationCode: usages: int = 0, max_usages: int = 1): self.code = code - self.creation_time = creation_time if creation_time is not None else datetime.now() + self.creation_time = get_time_or_now(creation_time) self.expiration_time = expiration_time self.usages = usages self.max_usages = max_usages @staticmethod def from_db(db_registration_code: Tuple) -> 'RegistrationCode': - expiration_time = None if db_registration_code[2] is None else datetime.fromisoformat(db_registration_code[2]) return RegistrationCode( db_registration_code[0], datetime.fromisoformat(db_registration_code[1]), - expiration_time, + parse_optional_time(db_registration_code[2]), db_registration_code[3], db_registration_code[4] ) @@ -37,15 +44,18 @@ class RegisteredUser: def __init__(self, registration_code: str, username: str, + registered_user_id: str, registered_time: datetime = None): self.registration_code = registration_code self.username = username - self.registered_time = registered_time if registered_time is not None else datetime.now() + self.registered_user_id = registered_user_id + self.registered_time = get_time_or_now(registered_time) @staticmethod def from_db(db_registered_user: Tuple) -> 'RegisteredUser': return RegisteredUser( db_registered_user[0], db_registered_user[1], + db_registered_user[3], datetime.fromisoformat(db_registered_user[2]) ) diff --git a/register_new_matrix_user.py b/register_new_matrix_user.py index 02fee2c..5e3b01e 100644 --- a/register_new_matrix_user.py +++ b/register_new_matrix_user.py @@ -25,23 +25,21 @@ import hashlib import hmac import logging +from typing import Dict, Optional import requests as _requests - log = logging.getLogger(__name__) def request_registration( - user, - password, - server_location, - shared_secret, - admin=False, - user_type=None, - requests=_requests -): - + user, + password, + server_location, + shared_secret, + admin=False, + user_type=None, + requests=_requests) -> Optional[Dict]: url = "%s/_matrix/client/r0/admin/register" % (server_location,) # Get the nonce @@ -54,7 +52,7 @@ def request_registration( log.error(r.json()["error"]) except Exception as e: log.error(e) - return False + return None nonce = r.json()["nonce"] @@ -92,9 +90,14 @@ def request_registration( log.error(r.json()["error"]) except Exception as e: log.error(e) - return False - return True + return None + return r.json() -def register_new_user(user, password, server_location, shared_secret): - return request_registration(user, password, server_location, shared_secret, False, None) +def register_new_user( + user: str, + password: str, + server_location: str, + shared_secret: str) -> Optional[Dict]: + return request_registration( + user, password, server_location, shared_secret, False, None) diff --git a/schema.sql b/schema.sql index b6521d6..7a28105 100644 --- a/schema.sql +++ b/schema.sql @@ -7,7 +7,10 @@ CREATE TABLE IF NOT EXISTS registration_codes ( ); CREATE TABLE IF NOT EXISTS registered_users ( - registrationCode VARCHAR(60) NOT NULL UNIQUE, - username VARCHAR(30) NOT NULL UNIQUE , - registeredTime TIMESTAMP NOT NULL -); \ No newline at end of file + registrationCode VARCHAR(60) NOT NULL UNIQUE, + username VARCHAR(30) NOT NULL UNIQUE, + registeredTime TIMESTAMP NOT NULL +); + +ALTER TABLE registered_users + ADD COLUMN registeredUserId VARCHAR(60); \ No newline at end of file