Browse Source

Address mypy complaints

merge-requests/1/head
Drew Short 6 years ago
parent
commit
ffed367e2a
  1. 10
      server/atheneum/__init__.py
  2. 4
      server/atheneum/api/decorators.py
  3. 29
      server/atheneum/middleware/authentication_middleware.py
  4. 4
      server/atheneum/service/authentication_service.py
  5. 4
      server/atheneum/service/user_service.py
  6. 3
      server/atheneum/utility.py
  7. 2
      server/mypy.ini
  8. 23
      server/tests/conftest.py

10
server/atheneum/__init__.py

@ -26,17 +26,17 @@ dictConfig({
}) })
def create_app(test_config=None):
def create_app(test_config: dict = None) -> Flask:
app = Flask(__name__, instance_relative_config=True) app = Flask(__name__, instance_relative_config=True)
app.logger.debug('Creating Atheneum Server') app.logger.debug('Creating Atheneum Server')
data_directory = os.getenv('ATHENEUM_DATA_DIRECTORY', '/tmp') data_directory = os.getenv('ATHENEUM_DATA_DIRECTORY', '/tmp')
app.logger.debug('Atheneum Data Directory: %s', data_directory) app.logger.debug('Atheneum Data Directory: %s', data_directory)
default_database_uri = 'sqlite:///{}/atheneum.db'.format(data_directory)
app.config.from_mapping( app.config.from_mapping(
SECRET_KEY='dev', SECRET_KEY='dev',
SQLALCHEMY_DATABASE_URI='sqlite:///{}/atheneum.db'
.format(data_directory),
SQLALCHEMY_DATABASE_URI=default_database_uri,
SQLALCHEMY_TRACK_MODIFICATIONS=False SQLALCHEMY_TRACK_MODIFICATIONS=False
) )
@ -65,7 +65,7 @@ def create_app(test_config=None):
return app return app
def register_blueprints(app):
def register_blueprints(app: Flask) -> None:
from atheneum.api import auth_blueprint from atheneum.api import auth_blueprint
app.register_blueprint(auth_blueprint) app.register_blueprint(auth_blueprint)
@ -74,7 +74,7 @@ app = create_app()
register_blueprints(app) register_blueprints(app)
def init_db():
def init_db() -> None:
"""Clear existing data and create new tables.""" """Clear existing data and create new tables."""
upgrade('migrations') upgrade('migrations')

4
server/atheneum/api/decorators.py

@ -1,5 +1,5 @@
from functools import wraps from functools import wraps
from typing import Callable
from typing import Callable, Any
from flask import jsonify, Response from flask import jsonify, Response
@ -14,7 +14,7 @@ def return_json(func: Callable) -> Callable:
""" """
@wraps(func) @wraps(func)
def decorate(*args, **kwargs):
def decorate(*args: list, **kwargs: dict) -> Any:
result = func(*args, **kwargs) result = func(*args, **kwargs)
if isinstance(result, Response): if isinstance(result, Response):
return result return result

29
server/atheneum/middleware/authentication_middleware.py

@ -1,6 +1,6 @@
import base64 import base64
from functools import wraps from functools import wraps
from typing import Optional, Callable
from typing import Optional, Callable, Any
from flask import request, Response, g from flask import request, Response, g
from werkzeug.datastructures import Authorization from werkzeug.datastructures import Authorization
@ -24,12 +24,13 @@ def authenticate_with_password(name: str, password: str) -> bool:
def authenticate_with_token(name: str, token: str) -> bool: def authenticate_with_token(name: str, token: str) -> bool:
user = user_service.find_by_name(name) user = user_service.find_by_name(name)
user_token = user_token_service.find_by_user_and_token(user, token)
if user is not None \
and authentication_service.is_valid_token(user_token):
g.user = user
g.user_token = user_token
return True
if user is not None:
user_token = user_token_service.find_by_user_and_token(user, token)
if user is not None \
and authentication_service.is_valid_token(user_token):
g.user = user
g.user_token = user_token
return True
return False return False
@ -41,27 +42,29 @@ def authentication_failed(auth_type: str) -> Response:
}) })
def parse_token_authorization_header(header_value) -> Optional[Authorization]:
def parse_token_authorization_header(
header_value: str) -> Optional[Authorization]:
if not header_value: if not header_value:
return
return None
value = wsgi_to_bytes(header_value) value = wsgi_to_bytes(header_value)
try: try:
auth_type, auth_info = value.split(None, 1) auth_type, auth_info = value.split(None, 1)
auth_type = auth_type.lower() auth_type = auth_type.lower()
except ValueError: except ValueError:
return
return None
if auth_type == b'token': if auth_type == b'token':
try: try:
username, token = base64.b64decode(auth_info).split(b':', 1) username, token = base64.b64decode(auth_info).split(b':', 1)
except Exception: except Exception:
return
return None
return Authorization('token', {'username': bytes_to_wsgi(username), return Authorization('token', {'username': bytes_to_wsgi(username),
'password': bytes_to_wsgi(token)}) 'password': bytes_to_wsgi(token)})
return None
def require_basic_auth(func: Callable) -> Callable: def require_basic_auth(func: Callable) -> Callable:
@wraps(func) @wraps(func)
def decorate(*args, **kwargs):
def decorate(*args: list, **kwargs: dict) -> Any:
auth = request.authorization auth = request.authorization
if auth and authenticate_with_password(auth.username, auth.password): if auth and authenticate_with_password(auth.username, auth.password):
return func(*args, **kwargs) return func(*args, **kwargs)
@ -73,7 +76,7 @@ def require_basic_auth(func: Callable) -> Callable:
def require_token_auth(func: Callable) -> Callable: def require_token_auth(func: Callable) -> Callable:
@wraps(func) @wraps(func)
def decorate(*args, **kwargs):
def decorate(*args: list, **kwargs: dict) -> Any:
token = parse_token_authorization_header( token = parse_token_authorization_header(
request.headers.get('Authorization', None)) request.headers.get('Authorization', None))
if token and authenticate_with_token(token.username, token.password): if token and authenticate_with_token(token.username, token.password):

4
server/atheneum/service/authentication_service.py

@ -52,7 +52,7 @@ def is_valid_token(user_token: Optional[UserToken]) -> bool:
return True return True
def bump_login(user: Optional[User]):
def bump_login(user: Optional[User]) -> None:
""" """
Update the last login time for the user Update the last login time for the user
@ -63,7 +63,7 @@ def bump_login(user: Optional[User]):
user_service.update_last_login_time(user) user_service.update_last_login_time(user)
def logout(user_token: Optional[UserToken] = None):
def logout(user_token: Optional[UserToken] = None) -> None:
""" """
Remove a user_token associated with a client session Remove a user_token associated with a client session

4
server/atheneum/service/user_service.py

@ -31,12 +31,12 @@ def delete(user: User) -> bool:
return False return False
def update_last_login_time(user: User):
def update_last_login_time(user: User) -> None:
user.last_login_time = datetime.now() user.last_login_time = datetime.now()
db.session.commit() db.session.commit()
def update_password(user: User, password: str):
def update_password(user: User, password: str) -> None:
pw_hash, pw_revision = authentication_service.get_password_hash( pw_hash, pw_revision = authentication_service.get_password_hash(
password) password)
user.password_hash = pw_hash user.password_hash = pw_hash

3
server/atheneum/utility.py

@ -1,11 +1,12 @@
from datetime import date from datetime import date
from typing import Any
import rfc3339 import rfc3339
from flask.json import JSONEncoder from flask.json import JSONEncoder
class CustomJSONEncoder(JSONEncoder): class CustomJSONEncoder(JSONEncoder):
def default(self, obj):
def default(self, obj: Any) -> Any:
try: try:
if isinstance(obj, date): if isinstance(obj, date):
return rfc3339.format(obj) return rfc3339.format(obj)

2
server/mypy.ini

@ -1,7 +1,7 @@
[mypy] [mypy]
ignore_missing_imports = True ignore_missing_imports = True
follow_imports = skip follow_imports = skip
disallow_untyped_calls = True
disallow_untyped_calls = False
disallow_untyped_defs = True disallow_untyped_defs = True
check_untyped_defs = True check_untyped_defs = True
disallow_subclassing_any = False disallow_subclassing_any = False

23
server/tests/conftest.py

@ -6,6 +6,8 @@ import tempfile
from typing import Tuple, Any from typing import Tuple, Any
import pytest import pytest
from flask import Flask
from flask.testing import FlaskClient, FlaskCliRunner
from werkzeug.test import Client from werkzeug.test import Client
from atheneum import create_app, init_db, register_blueprints from atheneum import create_app, init_db, register_blueprints
@ -23,7 +25,7 @@ def add_test_user() -> Tuple[str, str]:
@pytest.fixture @pytest.fixture
def app():
def app() -> Flask:
"""Create and configure a new app instance for each test.""" """Create and configure a new app instance for each test."""
# create a temporary file to isolate the database for each test # create a temporary file to isolate the database for each test
db_fd, db_path = tempfile.mkstemp() db_fd, db_path = tempfile.mkstemp()
@ -50,30 +52,33 @@ def app():
@pytest.fixture @pytest.fixture
def client(app):
def client(app: Flask) -> FlaskClient:
"""A test client for the app.""" """A test client for the app."""
return app.test_client() return app.test_client()
@pytest.fixture @pytest.fixture
def runner(app):
def runner(app: Flask) -> FlaskCliRunner:
"""A test runner for the app's Click commands.""" """A test runner for the app's Click commands."""
return app.test_cli_runner() return app.test_cli_runner()
class AuthActions(object): class AuthActions(object):
def __init__(self, client: Client, username: str = "", password: str = ""):
def __init__(self,
client: Client,
username: str = "",
password: str = "") -> None:
self._client = client self._client = client
self.username: str = username self.username: str = username
self.password: str = password self.password: str = password
self.token: str = "" self.token: str = ""
def configure(self, username, password) -> Any:
def configure(self, username: str, password: str) -> Any:
self.username = username self.username = username
self.password = password self.password = password
return self return self
def login(self):
def login(self) -> Any:
auth_header = self.get_authorization_header_basic() auth_header = self.get_authorization_header_basic()
result = self._client.post( result = self._client.post(
'/auth/login', '/auth/login',
@ -84,7 +89,7 @@ class AuthActions(object):
self.token = result.json['token'] self.token = result.json['token']
return result return result
def bump(self):
def bump(self) -> Any:
auth_header = self.get_authorization_header_token() auth_header = self.get_authorization_header_token()
return self._client.post( return self._client.post(
'/auth/bump', '/auth/bump',
@ -93,7 +98,7 @@ class AuthActions(object):
} }
) )
def logout(self):
def logout(self) -> Any:
auth_header = self.get_authorization_header_token() auth_header = self.get_authorization_header_token()
return self._client.post( return self._client.post(
'/auth/logout', '/auth/logout',
@ -116,7 +121,7 @@ class AuthActions(object):
@pytest.fixture @pytest.fixture
def auth(client: Client):
def auth(client: Client) -> AuthActions:
return AuthActions(client, return AuthActions(client,
client.application.config.get('test_username'), client.application.config.get('test_username'),
client.application.config.get('test_password')) client.application.config.get('test_password'))
Loading…
Cancel
Save