import logging
import os
import secrets
import uuid
from urllib.parse import urlparse, urljoin

import flask
from flask import Flask, redirect, render_template, request, g
from flask_login import LoginManager, login_required, login_user, logout_user, UserMixin, \
    current_user
from flask_wtf import CSRFProtect

from db import get_db, get_registration_codes, add_registration_code, \
    expire_registration_code, delete_registration_code, get_registered_users
from forms import RegistrationForm, LoginForm, RegistrationCodeForm, \
    ExpireRegistrationCodeForm

csrf = CSRFProtect()
login_manager = LoginManager()
log = logging.getLogger(__name__)


def init_db(flask_app):
    """Execute the bundled ``schema.sql`` against the app's database and commit.

    Called from :func:`create_app`, so it runs on every application start-up.
    NOTE(review): whether re-running is harmless depends on ``schema.sql``
    (e.g. ``CREATE TABLE IF NOT EXISTS`` vs. ``DROP TABLE``) — confirm.

    :param flask_app: the Flask application whose resources/context to use.
    """
    with flask_app.app_context():
        db = get_db()
        with flask_app.open_resource('schema.sql', mode='r') as f:
            db.cursor().executescript(f.read())
        db.commit()
        # Log only after the commit, so the message is true when it is emitted.
        log.info("Initialized DB")


def _secret_from_env(name, purpose):
    """Return the environment variable *name*, or a fresh random secret if unset/empty.

    A random per-process secret replaces the old well-known defaults: with a
    public default key anyone can sign a session cookie that Flask-Login will
    accept as the ``admin`` user.

    :param name: environment variable to read.
    :param purpose: human-readable description used in the warning.
    :returns: the configured value, or a 64-hex-character random string.
    """
    value = os.getenv(name)
    if value:
        return value
    log.warning("%s is not set; using a random %s for this process. "
                "Sessions will not survive a restart.", name, purpose)
    return secrets.token_hex(32)


def create_app():
    """Build and configure the Flask application.

    Configuration read from the environment:

    * ``ADMIN_TOKEN`` - token the ``admin`` user logs in with. If unset, a
      random UUID4 is generated and printed so the operator can log in.
    * ``SECRET_KEY`` - signs the session cookie; random per process if unset.
    * ``CSRF_SECRET_KEY`` - signs CSRF tokens; random per process if unset.

    Also registers CSRF protection and Flask-Login (login view
    ``admin_login``) and runs :func:`init_db`.

    :returns: the configured :class:`flask.Flask` instance.
    """
    flask_app = Flask(__name__)

    admin_token = os.getenv("ADMIN_TOKEN")
    if not admin_token:
        admin_token = str(uuid.uuid4())
        # Only a generated token needs announcing; echoing an operator-supplied
        # secret would just leak it into stdout / log files.
        print("Admin Token: %s" % admin_token, flush=True)

    flask_app.config.update(
        ADMIN_TOKEN=admin_token,
        SECRET_KEY=_secret_from_env("SECRET_KEY", "session signing key"),
        WTF_CSRF_SECRET_KEY=_secret_from_env("CSRF_SECRET_KEY", "CSRF signing key"),
    )

    csrf.init_app(flask_app)
    login_manager.init_app(flask_app)
    login_manager.login_view = "admin_login"
    init_db(flask_app)
    return flask_app


app = create_app()


class User(UserMixin):
    """Flask-Login user object; instances are produced by :func:`load_user`."""

    username: str
    token: str
    # True once the user is known to have valid credentials: set by load_user
    # for session-restored users and by the login view after a token match.
    authenticated: bool = False

    def __init__(self, username: str, token: str):
        self.username = username
        self.token = token

    @property
    def is_authenticated(self):
        """Whether this user has presented valid credentials.

        Must be a property, as in ``UserMixin``: Flask-Login (and templates)
        read ``is_authenticated`` without calling it, and a bound method would
        always be truthy.
        """
        return self.authenticated

    def get_id(self):
        """Return the identifier Flask-Login stores in the session."""
        return self.username


def is_safe_url(target):
    """Return True if *target* is an http(s) redirect destination on this host.

    Relative paths and absolute URLs whose netloc equals ``request.host_url``'s
    are accepted. ``None`` or an empty string resolves to the host root and is
    therefore safe.

    Backslashes are rejected outright. Browsers treat ``\\`` like ``/``, so
    ``/\\evil.example`` becomes a protocol-relative off-site redirect, even
    though ``urljoin`` keeps it on this host.
    """
    if target and '\\' in target:
        return False
    ref_url = urlparse(request.host_url)
    test_url = urlparse(urljoin(request.host_url, target))
    return test_url.scheme in ('http', 'https') and ref_url.netloc == test_url.netloc


@login_manager.user_loader
def load_user(user_id):
    """Return the :class:`User` for *user_id*, or None if it is unknown.

    ``admin`` is the only known user; its token is ``ADMIN_TOKEN``.

    As the Flask-Login user loader, this is called with the ID taken from the
    signed session, i.e. for a user who already logged in, so the returned user
    is marked authenticated. The login view also uses it as a lookup and checks
    the token itself before calling ``login_user``.
    """
    if user_id != "admin":
        return None
    user = User("admin", app.config.get("ADMIN_TOKEN"))
    user.authenticated = True
    return user


@app.route('/')
def index():
    """Placeholder landing page."""
    return 'Hello World!'
@app.route('/register', methods=('GET', 'POST'))
def registration():
    """Show the registration form; redirect to ``/success`` on a valid submission.

    A ``registrationCode`` value in the query string or form data pre-fills the
    code field, so a link like ``/register?registrationCode=...`` can carry it.
    NOTE(review): nothing is persisted here; presumably RegistrationForm's
    validators do the registration work — confirm.
    """
    form = RegistrationForm()
    if form.validate_on_submit():
        return redirect('/success')
    if 'registrationCode' in request.values:
        form.registration_code.data = request.values['registrationCode']
    return render_template('register.html', form=form)


@app.route('/admin')
@login_required
def admin_index():
    """Render the admin dashboard: add-code form, registration codes, users."""
    context = {
        'add_registration_code_form': RegistrationCodeForm(),
        'registration_codes': get_registration_codes(),
        'registered_users': get_registered_users(),
    }
    return render_template('admin.html', **context)


@app.route('/admin/add_registration_code', methods=['POST'])
@login_required
def admin_add_registration_code():
    """Create a registration code from the submitted form, then go back to /admin.

    Invalid submissions are ignored and also redirect to /admin.
    """
    form = RegistrationCodeForm()
    if form.validate_on_submit():
        add_registration_code(form.expiration_time.data, form.max_usages.data)
    return redirect('/admin')


@app.route('/admin/expire_registration_code', methods=['POST'])
@login_required
def admin_expire_registration_code():
    """Expire or delete a registration code, then go back to /admin.

    The action depends on which of the form's ``expire`` / ``delete`` fields
    is set (presumably the submit button that was pressed — confirm in forms).
    Invalid submissions are ignored.
    """
    form = ExpireRegistrationCodeForm()
    if form.validate_on_submit():
        if form.expire.data:
            expire_registration_code(form.registration_code.data)
        elif form.delete.data:
            delete_registration_code(form.registration_code.data)
    return redirect('/admin')


def _token_matches(provided, expected):
    """Compare two tokens in constant time; a missing token never matches.

    Both sides are UTF-8 encoded because ``hmac.compare_digest`` rejects
    non-ASCII ``str`` arguments with TypeError.
    """
    import hmac
    if provided is None or expected is None:
        return False
    return hmac.compare_digest(provided.encode('utf-8'), expected.encode('utf-8'))


@app.route('/admin/login', methods=('GET', 'POST'))
def admin_login():
    """Log a user in with username + token.

    On success, redirect to the ``next`` query parameter if it is a safe
    on-host URL, otherwise to ``/admin``. An unsafe ``next`` aborts with 400.
    On GET or failed login, (re-)render the login form.
    """
    form = LoginForm()
    if form.validate_on_submit():
        user = load_user(form.username.data)
        if user is not None and _token_matches(form.token.data, user.token):
            user.authenticated = True
            login_user(user)
            flask.flash('Logged in successfully.')
            next_loc = flask.request.args.get('next')
            if not is_safe_url(next_loc):
                flask.abort(400)  # raises; never returns
            # An absent or empty ``next`` falls back to the dashboard.
            return redirect(next_loc or '/admin')
    return render_template('login.html', form=form)


@app.route("/admin/logout")
@login_required
def admin_logout():
    """Log the current user out and return to the landing page."""
    logout_user()
    flask.flash('Logged out successfully.')
    return redirect('/')
@app.teardown_appcontext
def close_connection(exception):
    """Close the database connection cached on ``g`` when the app context ends.

    :param exception: the exception that ended the context, or None; unused.
    """
    # presumably ``g._database`` is populated lazily by db.get_db — verify there
    connection = getattr(g, '_database', None)
    if connection is None:
        return
    connection.close()


if __name__ == '__main__':
    app.run()