You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
283 lines
8.4 KiB
283 lines
8.4 KiB
import hmac
import logging
import os
import uuid
from logging.config import dictConfig
from urllib.parse import urlparse, urljoin

import flask
from flask import Flask, redirect, render_template, request, g, flash, url_for
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
from flask_login import (LoginManager,
                         login_required,
                         login_user,
                         logout_user,
                         UserMixin)
from flask_wtf import CSRFProtect
from werkzeug.contrib.fixers import ProxyFix

from db import (
    get_db,
    get_registration_codes,
    add_registration_code,
    expire_registration_code,
    delete_registration_code,
    get_registered_users,
    add_registered_user)
from forms import (RegistrationForm,
                   LoginForm,
                   RegistrationCodeForm,
                   ExpireRegistrationCodeForm)
from register_new_matrix_user import register_new_user
|
|
|
|
# Extension objects are created unbound here and attached to the Flask app
# inside create_app() via their init_app() methods.
csrf = CSRFProtect()
login_manager = LoginManager()

# Logging: the root logger (INFO and above) writes through one StreamHandler
# bound to Flask's WSGI error stream, prefixing each record with time, level
# and module. Loggers that already exist keep working
# (disable_existing_loggers=False).
dictConfig(dict(
    version=1,
    disable_existing_loggers=False,
    formatters=dict(
        default=dict(
            format='[%(asctime)s] %(levelname)s in %(module)s: %(message)s',
        ),
    ),
    handlers=dict(
        wsgi={
            'class': 'logging.StreamHandler',
            'stream': 'ext://flask.logging.wsgi_errors_stream',
            'formatter': 'default',
        },
    ),
    root=dict(level='INFO', handlers=['wsgi']),
))

log = logging.getLogger(__name__)
|
|
|
|
|
|
def init_db(flask_app):
    """Run the bundled ``schema.sql`` script against the application database.

    Works inside ``flask_app``'s application context so that ``get_db()`` can
    be used. The script is executed with ``executescript`` (presumably
    ``get_db`` returns a ``sqlite3`` connection; confirm in ``db.py``), then
    the connection is committed. Success is logged only after the commit.

    :param flask_app: the Flask application whose resources and app context
        are used.
    """
    with flask_app.app_context():
        log.info("Initializing DB")
        db = get_db()
        with flask_app.open_resource('schema.sql', mode='r') as schema_file:
            schema_sql = schema_file.read()

        cursor = db.cursor()
        try:
            cursor.executescript(schema_sql)
        finally:
            # Release the cursor even if the schema script fails part-way.
            cursor.close()

        db.commit()
        # Report success only once the changes are actually committed.
        log.info("Initialized DB")
|
|
|
|
|
|
def create_app():
    """Create and configure the Flask application.

    Settings are read from environment variables: APPLICATION_ROOT,
    ADMIN_TOKEN, SECRET_KEY, CSRF_SECRET_KEY (stored as WTF_CSRF_SECRET_KEY),
    MATRIX_HOMESERVER, MATRIX_SHARED_SECRET, REGISTRATION_SUCCESS_REDIRECT
    and RIOT_APP_TARGET.

    When ADMIN_TOKEN is unset, a random UUID4 token is generated and logged
    once, because that is the only way the operator can learn it. A token
    supplied through the environment is never written to the log. A warning
    is logged when SECRET_KEY or WTF_CSRF_SECRET_KEY still holds its built-in
    default.

    CSRF protection and Flask-Login are attached, and the database schema is
    initialised before the app is returned.

    :returns: the configured ``Flask`` instance.
    """
    flask_app = Flask(__name__)

    admin_token = os.getenv("ADMIN_TOKEN")
    admin_token_generated = admin_token is None
    if admin_token_generated:
        admin_token = str(uuid.uuid4())

    flask_app.config.update(dict(
        APPLICATION_ROOT=os.getenv("APPLICATION_ROOT", "/"),
        ADMIN_TOKEN=admin_token,
        SECRET_KEY=os.getenv("SECRET_KEY", "changeme"),
        WTF_CSRF_SECRET_KEY=os.getenv("CSRF_SECRET_KEY", "csrf_changeme"),
        MATRIX_HOMESERVER=os.getenv("MATRIX_HOMESERVER"),
        MATRIX_SHARED_SECRET=os.getenv("MATRIX_SHARED_SECRET"),
        REGISTRATION_SUCCESS_REDIRECT=os.getenv(
            "REGISTRATION_SUCCESS_REDIRECT"),
        RIOT_APP_TARGET=os.getenv("RIOT_APP_TARGET", "https://riot.im/app/")
    ))

    if admin_token_generated:
        # Nothing else reveals a generated token, so the operator needs it here.
        log.info("Admin Token: %s", admin_token)
    else:
        # Never echo an operator-supplied secret into the logs.
        log.info("Admin Token: supplied via ADMIN_TOKEN environment variable")

    for config_key, insecure_default in (
            ("SECRET_KEY", "changeme"),
            ("WTF_CSRF_SECRET_KEY", "csrf_changeme")):
        if flask_app.config.get(config_key) == insecure_default:
            log.warning(
                "%s is using its insecure default value; "
                "set it via the environment", config_key)

    csrf.init_app(flask_app)

    login_manager.init_app(flask_app)
    # Endpoint Flask-Login redirects to when @login_required rejects a request.
    login_manager.login_view = "admin_login"

    init_db(flask_app)

    log.info("Application ready")

    return flask_app
|
|
|
|
|
|
app = create_app()

# Trust exactly one reverse proxy in front of the app, so the client address
# seen by the rate limiter comes from the proxy's forwarding headers.
# NOTE(review): werkzeug.contrib (and ProxyFix's ``num_proxies``) was removed
# in Werkzeug 1.0; the replacement is
# ``werkzeug.middleware.proxy_fix.ProxyFix(app.wsgi_app, x_for=1)``. Confirm
# the pinned Werkzeug version before upgrading.
app.wsgi_app = ProxyFix(app.wsgi_app, num_proxies=1)

# Pass both arguments by keyword. Flask-Limiter < 3 takes ``app`` as its first
# positional parameter, while >= 3 takes ``key_func`` first, so the
# positional form only works on one of them.
limiter = Limiter(key_func=get_remote_address, app=app)

if __name__ != '__main__':
    # When imported by a server such as gunicorn, reuse gunicorn's error-log
    # handlers and level so the app's log output ends up in the same place.
    gunicorn_logger = logging.getLogger('gunicorn.error')
    app.logger.handlers = gunicorn_logger.handlers
    app.logger.setLevel(gunicorn_logger.level)
|
|
|
|
|
|
def flash_form_errors(form):
    """Flash every validation error on ``form`` with the ``'error'`` category.

    Each field's messages are joined into a single flash entry prefixed with
    the field's label, e.g. ``"Username: This field is required."``.

    Errors not tied to a field are flashed without a label. WTForms >= 3
    reports such form-level errors under the ``None`` key, and ``form[None]``
    would raise. Objects without an ``errors`` attribute, or with no errors,
    are ignored.

    :param form: a WTForms form, typically after ``validate_on_submit()``.
    """
    errors = getattr(form, 'errors', None)
    if not errors:
        return

    for field_name, messages in errors.items():
        if isinstance(messages, (list, tuple)):
            text = "; ".join(str(message) for message in messages)
        else:
            text = str(messages)

        if field_name is None:
            flash(text, 'error')
        else:
            flash("%s: %s" % (form[field_name].label.text, text), 'error')
|
|
|
|
|
|
class User(UserMixin):
    """Flask-Login user object.

    In this module it is only ever built for the ``"admin"`` account (see
    ``load_user``). ``token`` holds the secret that the login form is checked
    against, and ``get_id`` returns the username that Flask-Login stores in
    the session.
    """

    username: str
    token: str
    # Set to True by admin_login once the submitted token matches.
    authenticated: bool = False

    def __init__(self, username: str, token: str):
        self.username, self.token = username, token

    # NOTE(review): this overrides UserMixin's ``is_authenticated`` with a
    # plain method. If the installed Flask-Login is >= 0.3, ``is_authenticated``
    # is meant to be a property, and ``login_required`` tests it without
    # calling it. The bound method below is then always truthy, so
    # ``authenticated`` does not actually gate access. load_user builds users
    # with authenticated=False, so turning this into a property alone would
    # lock the admin out after login. Fix it together with load_user and
    # admin_login.
    def is_authenticated(self):
        return self.authenticated

    def get_id(self):
        """Return the identifier Flask-Login keeps in the session (the username)."""
        return self.username
|
|
|
|
|
|
def is_safe_url(target, host_url=None):
    """Return True if redirecting to ``target`` keeps the browser on this host.

    The target is resolved against ``host_url`` (defaults to the current
    request's ``request.host_url``). It is accepted only when the result uses
    http/https and has the same network location as the host.

    A ``None`` or empty target resolves to the host itself and is therefore
    safe; admin_login passes ``None`` when no ``next`` parameter is present.

    Browsers ignore surrounding whitespace and treat ``\\`` like ``/``, so
    ``"/\\evil.com"`` or ``" //evil.com"`` would leave the site. The target is
    normalised the same way before checking, and any control character
    causes rejection.

    :param target: redirect target, usually from untrusted query input.
    :param host_url: base URL to compare against; optional, for use outside
        a request context.
    :returns: ``True`` if the redirect is same-host http(s), else ``False``.
    """
    if host_url is None:
        host_url = request.host_url
    if target is None:
        return True

    candidate = target.strip().replace('\\', '/')
    if any(ord(ch) < 32 or ord(ch) == 127 for ch in candidate):
        return False

    ref_url = urlparse(host_url)
    test_url = urlparse(urljoin(host_url, candidate))
    return bool(test_url.scheme in ('http', 'https')
                and ref_url.netloc == test_url.netloc)
|
|
|
|
|
|
def get_successful_registration_redirect():
    """Pick where to send a user after a successful registration.

    Uses the REGISTRATION_SUCCESS_REDIRECT setting when it is set and begins
    with ``'http'``; otherwise falls back to the absolute URL of ``index``.
    """
    configured = app.config.get('REGISTRATION_SUCCESS_REDIRECT')
    if configured is not None and configured.startswith('http'):
        return configured
    return url_for('index', _external=True)
|
|
|
|
|
|
@login_manager.user_loader
def load_user(user_id):
    """Flask-Login user loader: rebuild a ``User`` from its session id.

    Only the id ``"admin"`` is recognised. It yields a ``User`` carrying the
    configured ADMIN_TOKEN, which is also stored on ``g.user``. Any other id
    yields ``None``.

    NOTE(review): admin_login also calls this with the submitted username
    before the token is checked, so ``g.user`` is set even when that login
    attempt later fails.
    """
    if user_id != "admin":
        return None
    admin = User("admin", app.config.get("ADMIN_TOKEN"))
    g.user = admin
    return admin
|
|
|
|
|
|
@app.route('/')
def index():
    """Landing page: send visitors straight to the registration form."""
    target = url_for('registration')
    return redirect(target)
|
|
|
|
|
|
@app.route('/register', methods=('GET', 'POST'))
@limiter.limit("5/minute")
def registration():
    """Render the Matrix registration form and handle its submission.

    On a valid POST:
      * If MATRIX_HOMESERVER or MATRIX_SHARED_SECRET is unset, an ``'error'``
        flash is shown and the form is re-rendered.
      * Otherwise the account is created via ``register_new_user``.
      * On success it is recorded with ``add_registered_user`` against the
        submitted registration code, and the browser is redirected to
        ``get_successful_registration_redirect()``.
      * If ``register_new_user`` returns ``None``, a failure is flashed and
        the form is re-rendered.

    Otherwise (GET, or a POST that fails validation), form errors are flashed
    and the form's registration code is pre-filled from the
    ``registrationCode`` request value. A missing ``registrationCode`` is
    flashed as an error.
    """
    form = RegistrationForm()

    if form.validate_on_submit():
        homeserver = app.config.get("MATRIX_HOMESERVER")
        shared_secret = app.config.get("MATRIX_SHARED_SECRET")

        if homeserver is None:
            flash("Matrix Homeserver Currently Unavailable. Please Try Again Later!",
                  "error")
            return render_template('register.html', form=form)

        if shared_secret is None:
            flash("Registration Configuration Is Invalid. Contact Administrator!",
                  "error")
            return render_template('register.html', form=form)

        response = register_new_user(
            form.username.data,
            form.password.data,
            homeserver,
            shared_secret,
        )
        if response is None:
            flash("Registration Failure. Contact Administrator!", "error")
            return render_template('register.html', form=form)

        # Presumably ``response`` is the homeserver's JSON reply, carrying the
        # fully-qualified Matrix id; confirm in register_new_matrix_user.
        add_registered_user(
            form.registration_code.data,
            form.username.data,
            response['user_id'])
        return redirect(get_successful_registration_redirect())

    flash_form_errors(form)

    # The code normally arrives as a query parameter on the invitation link.
    if 'registrationCode' in request.values:
        form.registration_code.data = request.values['registrationCode']
    else:
        flash("Missing registration code. Use provided link.", "error")

    return render_template('register.html', form=form)
|
|
|
|
|
|
@app.route('/admin')
@login_required
def admin_index():
    """Admin dashboard.

    Shows a blank registration-code form, all registration codes, and all
    registered users.
    """
    return render_template(
        'admin.html',
        add_registration_code_form=RegistrationCodeForm(),
        registration_codes=get_registration_codes(),
        registered_users=get_registered_users(),
    )
|
|
|
|
|
|
@app.route('/admin/add_registration_code', methods=['POST'])
@login_required
def admin_add_registration_code():
    """Create a registration code from the submitted ``RegistrationCodeForm``.

    On success the code is stored via ``add_registration_code`` using the
    form's expiration time and maximum usage count. On validation failure the
    errors are flashed. Either way the browser is redirected back to the
    admin dashboard.
    """
    form = RegistrationCodeForm()
    if form.validate_on_submit():
        add_registration_code(form.expiration_time.data, form.max_usages.data)
        # Return the redirect; building it without returning it does nothing.
        return redirect(url_for('admin_index', _external=True))

    flash_form_errors(form)
    return redirect(url_for('admin_index', _external=True))
|
|
|
|
|
|
@app.route('/admin/expire_registration_code', methods=['POST'])
@login_required
def admin_expire_registration_code():
    """Expire or delete the registration code named in the submitted form.

    If the form's ``expire`` field is truthy the code is expired. Otherwise,
    if ``delete`` is truthy, it is deleted. If neither is set, nothing
    changes. Both fields are presumably submit buttons; confirm in forms.py.

    Validation errors are flashed, and the browser is always redirected back
    to the admin dashboard.
    """
    form = ExpireRegistrationCodeForm()
    if form.validate_on_submit():
        code = form.registration_code.data
        if form.expire.data:
            expire_registration_code(code)
        elif form.delete.data:
            delete_registration_code(code)
        # Return the redirect; building it without returning it does nothing.
        return redirect(url_for('admin_index', _external=True))

    flash_form_errors(form)
    return redirect(url_for('admin_index', _external=True))
|
|
|
|
|
|
@app.route('/admin/login', methods=('GET', 'POST'))
@limiter.limit("2/minute")
def admin_login():
    """Log the administrator in with a username and the admin token.

    On a valid POST whose username resolves via ``load_user`` and whose token
    matches (compared in constant time), the user is logged in. The browser
    is then sent to the ``next`` query parameter, or to the admin dashboard
    if there is none. An off-site ``next`` aborts with 400.

    Wrong credentials flash a generic error that does not reveal which part
    was wrong. Form errors are flashed and the login page is rendered.
    """
    form = LoginForm()
    if form.validate_on_submit():
        user = load_user(form.username.data)
        submitted_token = form.token.data or ""
        # compare_digest avoids leaking how much of the secret token matched
        # through response timing; bytes keep it valid for non-ASCII input.
        if user is not None and hmac.compare_digest(
                submitted_token.encode("utf-8"), user.token.encode("utf-8")):
            user.authenticated = True
            login_user(user)
            flash('Logged in successfully.')

            next_loc = request.args.get('next')
            if not is_safe_url(next_loc):
                return flask.abort(400)
            if next_loc is not None:
                return redirect(next_loc)
            return redirect(url_for('admin_index', _external=True))

        flash("Invalid username or token.", "error")

    flash_form_errors(form)
    return render_template('login.html', form=form)
|
|
|
|
|
|
@app.route("/admin/logout")
@login_required
def admin_logout():
    """End the admin session and return the browser to the landing page."""
    logout_user()
    flash('Logged out successfully.')
    landing = url_for('index', _external=True)
    return redirect(landing)
|
|
|
|
|
|
@app.teardown_appcontext
def close_connection(exception):
    """Close the database connection cached on ``g._database``, if any.

    Registered as an app-context teardown. ``exception`` is supplied by Flask
    and is not used. The ``_database`` attribute is presumably populated by
    ``db.get_db()``; confirm there.
    """
    connection = getattr(g, '_database', None)
    if connection is None:
        return
    connection.close()


# Development entry point: start Flask's built-in server only when this file
# is executed directly. Nothing runs here when the module is imported.
if __name__ == '__main__':
    app.run()
|