Browse Source

Added column to record the user matrix_id on successful registration

master
Drew Short 6 years ago
parent
commit
8b6b069617
  1. 13
      app.py
  2. 10
      db.py
  3. 18
      models.py
  4. 21
      register_new_matrix_user.py
  5. 3
      schema.sql

13
app.py

@ -6,7 +6,8 @@ from urllib.parse import urlparse, urljoin
import flask import flask
from flask import Flask, redirect, render_template, request, g, flash, url_for from flask import Flask, redirect, render_template, request, g, flash, url_for
from flask_login import LoginManager, login_required, login_user, logout_user, UserMixin
from flask_login import LoginManager, login_required, login_user, logout_user, \
UserMixin
from flask_wtf import CSRFProtect from flask_wtf import CSRFProtect
from db import get_db, get_registration_codes, add_registration_code, \ from db import get_db, get_registration_codes, add_registration_code, \
@ -145,14 +146,18 @@ def registration():
flash("Registration Configuration Is Invalid. Contact Administrator!") flash("Registration Configuration Is Invalid. Contact Administrator!")
return render_template('register.html', form=form) return render_template('register.html', form=form)
else: else:
successful = register_new_user(
response = register_new_user(
form.username.data, form.username.data,
form.password.data, form.password.data,
app.config.get("MATRIX_HOMESERVER"), app.config.get("MATRIX_HOMESERVER"),
app.config.get("MATRIX_SHARED_SECRET") app.config.get("MATRIX_SHARED_SECRET")
) )
if successful:
add_registered_user(form.registration_code.data, form.username.data)
if response is not None:
registered_user_id = response['user_id']
add_registered_user(
form.registration_code.data,
form.username.data,
registered_user_id)
return redirect(get_successful_registration_redirect()) return redirect(get_successful_registration_redirect())
else: else:
flash("Registration Failure. Contact Administrator!") flash("Registration Failure. Contact Administrator!")

10
db.py

@ -21,8 +21,8 @@ SET expirationTime = ?, usages = ?, maxUsages = ?
WHERE code = ?""" WHERE code = ?"""
REGISTERED_USER_INSERT_SQL = """INSERT INTO REGISTERED_USER_INSERT_SQL = """INSERT INTO
registered_users(registrationCode, username, registeredTime)
VALUES(?, ?, ?)
registered_users(registrationCode, username, registeredUserId, registeredTime)
VALUES(?, ?, ?, ?)
""" """
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -118,12 +118,14 @@ def get_registered_user(username: str) -> Optional[RegisteredUser]:
return None return None
def add_registered_user(code: str, username: str) -> RegisteredUser:
def add_registered_user(code: str, username: str,
matrix_user_id: str) -> RegisteredUser:
registration_code = get_registration_code(code) registration_code = get_registration_code(code)
registration_code.usages += 1 registration_code.usages += 1
update_registration_code(registration_code) update_registration_code(registration_code)
db = get_db() db = get_db()
db.execute(REGISTERED_USER_INSERT_SQL, [code, username, datetime.now()])
db.execute(REGISTERED_USER_INSERT_SQL,
[code, username, matrix_user_id, datetime.now()])
db.commit() db.commit()
return get_registered_user(username) return get_registered_user(username)

18
models.py

@ -2,6 +2,14 @@ from datetime import datetime
from typing import Tuple from typing import Tuple
def get_time_or_now(time):
return time if time is not None else datetime.now()
def parse_optional_time(time):
return None if time is None else datetime.fromisoformat(time)
class RegistrationCode: class RegistrationCode:
def __init__(self, def __init__(self,
code: str, code: str,
@ -10,18 +18,17 @@ class RegistrationCode:
usages: int = 0, usages: int = 0,
max_usages: int = 1): max_usages: int = 1):
self.code = code self.code = code
self.creation_time = creation_time if creation_time is not None else datetime.now()
self.creation_time = get_time_or_now(creation_time)
self.expiration_time = expiration_time self.expiration_time = expiration_time
self.usages = usages self.usages = usages
self.max_usages = max_usages self.max_usages = max_usages
@staticmethod @staticmethod
def from_db(db_registration_code: Tuple) -> 'RegistrationCode': def from_db(db_registration_code: Tuple) -> 'RegistrationCode':
expiration_time = None if db_registration_code[2] is None else datetime.fromisoformat(db_registration_code[2])
return RegistrationCode( return RegistrationCode(
db_registration_code[0], db_registration_code[0],
datetime.fromisoformat(db_registration_code[1]), datetime.fromisoformat(db_registration_code[1]),
expiration_time,
parse_optional_time(db_registration_code[2]),
db_registration_code[3], db_registration_code[3],
db_registration_code[4] db_registration_code[4]
) )
@ -37,15 +44,18 @@ class RegisteredUser:
def __init__(self, def __init__(self,
registration_code: str, registration_code: str,
username: str, username: str,
registered_user_id: str,
registered_time: datetime = None): registered_time: datetime = None):
self.registration_code = registration_code self.registration_code = registration_code
self.username = username self.username = username
self.registered_time = registered_time if registered_time is not None else datetime.now()
self.registered_user_id = registered_user_id
self.registered_time = get_time_or_now(registered_time)
@staticmethod @staticmethod
def from_db(db_registered_user: Tuple) -> 'RegisteredUser': def from_db(db_registered_user: Tuple) -> 'RegisteredUser':
return RegisteredUser( return RegisteredUser(
db_registered_user[0], db_registered_user[0],
db_registered_user[1], db_registered_user[1],
db_registered_user[3],
datetime.fromisoformat(db_registered_user[2]) datetime.fromisoformat(db_registered_user[2])
) )

21
register_new_matrix_user.py

@ -25,10 +25,10 @@
import hashlib import hashlib
import hmac import hmac
import logging import logging
from typing import Dict, Optional
import requests as _requests import requests as _requests
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -39,9 +39,7 @@ def request_registration(
shared_secret, shared_secret,
admin=False, admin=False,
user_type=None, user_type=None,
requests=_requests
):
requests=_requests) -> Optional[Dict]:
url = "%s/_matrix/client/r0/admin/register" % (server_location,) url = "%s/_matrix/client/r0/admin/register" % (server_location,)
# Get the nonce # Get the nonce
@ -54,7 +52,7 @@ def request_registration(
log.error(r.json()["error"]) log.error(r.json()["error"])
except Exception as e: except Exception as e:
log.error(e) log.error(e)
return False
return None
nonce = r.json()["nonce"] nonce = r.json()["nonce"]
@ -92,9 +90,14 @@ def request_registration(
log.error(r.json()["error"]) log.error(r.json()["error"])
except Exception as e: except Exception as e:
log.error(e) log.error(e)
return False
return True
return None
return r.json()
def register_new_user(user, password, server_location, shared_secret):
return request_registration(user, password, server_location, shared_secret, False, None)
def register_new_user(
user: str,
password: str,
server_location: str,
shared_secret: str) -> Optional[Dict]:
return request_registration(
user, password, server_location, shared_secret, False, None)

3
schema.sql

@ -11,3 +11,6 @@ CREATE TABLE IF NOT EXISTS registered_users (
username VARCHAR(30) NOT NULL UNIQUE, username VARCHAR(30) NOT NULL UNIQUE,
registeredTime TIMESTAMP NOT NULL registeredTime TIMESTAMP NOT NULL
); );
ALTER TABLE registered_users
ADD COLUMN registeredUserId VARCHAR(60);
Loading…
Cancel
Save