diff --git a/server/atheneum/__init__.py b/server/atheneum/__init__.py index f07f898..a9f68a6 100644 --- a/server/atheneum/__init__.py +++ b/server/atheneum/__init__.py @@ -26,17 +26,17 @@ dictConfig({ }) -def create_app(test_config=None): +def create_app(test_config: dict = None) -> Flask: app = Flask(__name__, instance_relative_config=True) app.logger.debug('Creating Atheneum Server') data_directory = os.getenv('ATHENEUM_DATA_DIRECTORY', '/tmp') app.logger.debug('Atheneum Data Directory: %s', data_directory) + default_database_uri = 'sqlite:///{}/atheneum.db'.format(data_directory) app.config.from_mapping( SECRET_KEY='dev', - SQLALCHEMY_DATABASE_URI='sqlite:///{}/atheneum.db' - .format(data_directory), + SQLALCHEMY_DATABASE_URI=default_database_uri, SQLALCHEMY_TRACK_MODIFICATIONS=False ) @@ -65,7 +65,7 @@ def create_app(test_config=None): return app -def register_blueprints(app): +def register_blueprints(app: Flask) -> None: from atheneum.api import auth_blueprint app.register_blueprint(auth_blueprint) @@ -74,7 +74,7 @@ app = create_app() register_blueprints(app) -def init_db(): +def init_db() -> None: """Clear existing data and create new tables.""" upgrade('migrations') diff --git a/server/atheneum/api/decorators.py b/server/atheneum/api/decorators.py index 731c7a2..8d3e9b2 100644 --- a/server/atheneum/api/decorators.py +++ b/server/atheneum/api/decorators.py @@ -1,5 +1,5 @@ from functools import wraps -from typing import Callable +from typing import Callable, Any from flask import jsonify, Response @@ -14,7 +14,7 @@ def return_json(func: Callable) -> Callable: """ @wraps(func) - def decorate(*args, **kwargs): + def decorate(*args: list, **kwargs: dict) -> Any: result = func(*args, **kwargs) if isinstance(result, Response): return result diff --git a/server/atheneum/middleware/authentication_middleware.py b/server/atheneum/middleware/authentication_middleware.py index afd5d8a..3113a13 100644 --- a/server/atheneum/middleware/authentication_middleware.py +++ b/server/atheneum/middleware/authentication_middleware.py @@ -1,6 +1,6 @@ import base64 from functools import wraps -from typing import Optional, Callable +from typing import Optional, Callable, Any from flask import request, Response, g from werkzeug.datastructures import Authorization @@ -24,12 +24,13 @@ def authenticate_with_password(name: str, password: str) -> bool: def authenticate_with_token(name: str, token: str) -> bool: user = user_service.find_by_name(name) - user_token = user_token_service.find_by_user_and_token(user, token) - if user is not None \ - and authentication_service.is_valid_token(user_token): - g.user = user - g.user_token = user_token - return True + if user is not None: + user_token = user_token_service.find_by_user_and_token(user, token) + if user is not None \ + and authentication_service.is_valid_token(user_token): + g.user = user + g.user_token = user_token + return True return False @@ -41,27 +42,29 @@ def authentication_failed(auth_type: str) -> Response: }) -def parse_token_authorization_header(header_value) -> Optional[Authorization]: +def parse_token_authorization_header( + header_value: str) -> Optional[Authorization]: if not header_value: - return + return None value = wsgi_to_bytes(header_value) try: auth_type, auth_info = value.split(None, 1) auth_type = auth_type.lower() except ValueError: - return + return None if auth_type == b'token': try: username, token = base64.b64decode(auth_info).split(b':', 1) except Exception: - return + return None return Authorization('token', {'username': bytes_to_wsgi(username), 'password': bytes_to_wsgi(token)}) + return None def require_basic_auth(func: Callable) -> Callable: @wraps(func) - def decorate(*args, **kwargs): + def decorate(*args: list, **kwargs: dict) -> Any: auth = request.authorization if auth and authenticate_with_password(auth.username, auth.password): return func(*args, **kwargs) @@ -73,7 +76,7 @@ def require_basic_auth(func: Callable) -> Callable: def require_token_auth(func: Callable) -> Callable: @wraps(func) - def decorate(*args, **kwargs): + def decorate(*args: list, **kwargs: dict) -> Any: token = parse_token_authorization_header( request.headers.get('Authorization', None)) if token and authenticate_with_token(token.username, token.password): diff --git a/server/atheneum/service/authentication_service.py b/server/atheneum/service/authentication_service.py index de528a0..d859be1 100644 --- a/server/atheneum/service/authentication_service.py +++ b/server/atheneum/service/authentication_service.py @@ -52,7 +52,7 @@ def is_valid_token(user_token: Optional[UserToken]) -> bool: return True -def bump_login(user: Optional[User]): +def bump_login(user: Optional[User]) -> None: """ Update the last login time for the user @@ -63,7 +63,7 @@ def bump_login(user: Optional[User]): user_service.update_last_login_time(user) -def logout(user_token: Optional[UserToken] = None): +def logout(user_token: Optional[UserToken] = None) -> None: """ Remove a user_token associated with a client session diff --git a/server/atheneum/service/user_service.py b/server/atheneum/service/user_service.py index 4c2a4a1..dd96b81 100644 --- a/server/atheneum/service/user_service.py +++ b/server/atheneum/service/user_service.py @@ -31,12 +31,12 @@ def delete(user: User) -> bool: return False -def update_last_login_time(user: User): +def update_last_login_time(user: User) -> None: user.last_login_time = datetime.now() db.session.commit() -def update_password(user: User, password: str): +def update_password(user: User, password: str) -> None: pw_hash, pw_revision = authentication_service.get_password_hash( password) user.password_hash = pw_hash diff --git a/server/atheneum/utility.py b/server/atheneum/utility.py index 53f8abd..2277381 100644 --- a/server/atheneum/utility.py +++ b/server/atheneum/utility.py @@ -1,11 +1,12 @@ from datetime import date +from typing import Any import rfc3339 from flask.json import JSONEncoder class CustomJSONEncoder(JSONEncoder): - def default(self, obj): + def default(self, obj: Any) -> Any: try: if isinstance(obj, date): return rfc3339.format(obj) diff --git a/server/mypy.ini b/server/mypy.ini index 82f7c48..c440ed6 100644 --- a/server/mypy.ini +++ b/server/mypy.ini @@ -1,7 +1,7 @@ [mypy] ignore_missing_imports = True follow_imports = skip -disallow_untyped_calls = True +disallow_untyped_calls = False disallow_untyped_defs = True check_untyped_defs = True disallow_subclassing_any = False diff --git a/server/tests/conftest.py b/server/tests/conftest.py index 34565da..c14305d 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -6,6 +6,8 @@ import tempfile from typing import Tuple, Any import pytest +from flask import Flask +from flask.testing import FlaskClient, FlaskCliRunner from werkzeug.test import Client from atheneum import create_app, init_db, register_blueprints @@ -23,7 +25,7 @@ def add_test_user() -> Tuple[str, str]: @pytest.fixture -def app(): +def app() -> Flask: """Create and configure a new app instance for each test.""" # create a temporary file to isolate the database for each test db_fd, db_path = tempfile.mkstemp() @@ -50,30 +52,33 @@ def app(): @pytest.fixture -def client(app): +def client(app: Flask) -> FlaskClient: """A test client for the app.""" return app.test_client() @pytest.fixture -def runner(app): +def runner(app: Flask) -> FlaskCliRunner: """A test runner for the app's Click commands.""" return app.test_cli_runner() class AuthActions(object): - def __init__(self, client: Client, username: str = "", password: str = ""): + def __init__(self, + client: Client, + username: str = "", + password: str = "") -> None: self._client = client self.username: str = username self.password: str = password self.token: str = "" - def configure(self, username, password) -> Any: + def configure(self, username: str, password: str) -> Any: self.username = username self.password = password return self - def login(self): + def login(self) -> Any: auth_header = self.get_authorization_header_basic() result = self._client.post( '/auth/login', @@ -84,7 +89,7 @@ class AuthActions(object): self.token = result.json['token'] return result - def bump(self): + def bump(self) -> Any: auth_header = self.get_authorization_header_token() return self._client.post( '/auth/bump', @@ -93,7 +98,7 @@ class AuthActions(object): } ) - def logout(self): + def logout(self) -> Any: auth_header = self.get_authorization_header_token() return self._client.post( '/auth/logout', @@ -116,7 +121,7 @@ class AuthActions(object): @pytest.fixture -def auth(client: Client): +def auth(client: Client) -> AuthActions: return AuthActions(client, client.application.config.get('test_username'), client.application.config.get('test_password'))