diff --git a/.gitignore b/.gitignore index 65752d4..ef19105 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ instance/ .idea +*.iml .admin_credentials *__pycache__/ .pytest_cache/ diff --git a/server/atheneum/__init__.py b/server/atheneum/__init__.py index cc01025..1b4acf0 100644 --- a/server/atheneum/__init__.py +++ b/server/atheneum/__init__.py @@ -77,8 +77,9 @@ def register_blueprints(app: Flask) -> None: :param app: :return: """ - from atheneum.api import AUTH_BLUEPRINT + from atheneum.api import AUTH_BLUEPRINT, USER_BLUEPRINT app.register_blueprint(AUTH_BLUEPRINT) + app.register_blueprint(USER_BLUEPRINT) atheneum = create_app() # pylint: disable=C0103 diff --git a/server/atheneum/api/__init__.py b/server/atheneum/api/__init__.py index 6777d4c..b6e114f 100644 --- a/server/atheneum/api/__init__.py +++ b/server/atheneum/api/__init__.py @@ -1,2 +1,3 @@ """API blueprint exports.""" from atheneum.api.authentication_api import AUTH_BLUEPRINT +from atheneum.api.user_api import USER_BLUEPRINT diff --git a/server/atheneum/api/user_api.py b/server/atheneum/api/user_api.py new file mode 100644 index 0000000..e3aa8c5 --- /dev/null +++ b/server/atheneum/api/user_api.py @@ -0,0 +1,25 @@ +"""User API blueprint and endpoint definitions.""" +from flask import Blueprint, abort + +from atheneum.api.decorators import return_json +from atheneum.api.model import APIResponse +from atheneum.middleware import authentication_middleware +from atheneum.service import user_service + +USER_BLUEPRINT = Blueprint( + name='user', import_name=__name__, url_prefix='/user') + + +@USER_BLUEPRINT.route('/', methods=['GET']) +@return_json +@authentication_middleware.require_token_auth +def get_user(name: str) -> APIResponse: + """ + Get a token for continued authentication. + + :return: A login token for continued authentication + """ + user = user_service.find_by_name(name) + if user is not None: + return APIResponse(user, 200) + return abort(404) diff --git a/server/atheneum/errors.py b/server/atheneum/errors.py new file mode 100644 index 0000000..e5a89f3 --- /dev/null +++ b/server/atheneum/errors.py @@ -0,0 +1,20 @@ +"""Error definitions for Atheneum.""" +from typing import Dict + + +class BaseError(RuntimeError): + """Atheneum Base Error Class.""" + + def __init__( + self, + message: str = 'Unknown error', + extra_fields: Dict[str, str] = None) -> None: + """Populate The Error Definition.""" + super().__init__(message) + self.extra_fields = extra_fields + + +class ValidationError(BaseError): + """Atheneum Validation Error.""" + + pass diff --git a/server/atheneum/service/serialization_service.py b/server/atheneum/service/serialization_service.py new file mode 100644 index 0000000..9986d66 --- /dev/null +++ b/server/atheneum/service/serialization_service.py @@ -0,0 +1,52 @@ +"""Handle Model Serialization.""" +from typing import Dict, Callable, Any, List, Optional, Type + +from atheneum import errors +from atheneum.db import db + + +class BaseSerializer: # pylint: disable=too-few-public-methods + """Base Model serializer.""" + + def __init__(self, model: Type[db.Model]) -> None: + """Initialize the base serializer.""" + self._fields: Dict[str, Callable[[db.Model], Any]] = {} + self.model = model + + def serialize(self, options: Optional[List[str]]) -> Any: + """Convert Model field and factories to dicts.""" + field_factories = self._serializers() + if not options: + options = list(field_factories.keys()) + ret = {} + for key in options: + if key not in field_factories: + raise errors.ValidationError( + 'Invalid key: %r. Valid keys: %r.' % ( + key, list(sorted(field_factories.keys())))) + factory = field_factories[key] + ret[key] = factory() + return ret + + def _serializers(self) -> Dict[str, Callable[[], Any]]: + """Field definitions.""" + raise NotImplementedError() + + +_model_serializers: Dict[str, Type[BaseSerializer]] = {} + + +def register_serializer( + model_name: str, model_serializer: Type[BaseSerializer]) -> None: + """Add a model to the serializer mapping.""" + _model_serializers[model_name] = model_serializer + + +def serialize_model(model_obj: db.Model) -> Any: + """Lookup a Model and hand off to the serializer.""" + try: + return _model_serializers[ + type(model_obj).__name__](model_obj).serialize(None) + except KeyError: + raise NotImplementedError( + '{} has no registered serializers'.format(model_obj.__name__)) diff --git a/server/atheneum/service/user_service.py b/server/atheneum/service/user_service.py index c95fd88..4751c9d 100644 --- a/server/atheneum/service/user_service.py +++ b/server/atheneum/service/user_service.py @@ -1,15 +1,56 @@ """Service to handle user operations.""" import logging from datetime import datetime -from typing import Optional +from typing import Optional, Dict, Callable, Any from atheneum.db import db from atheneum.model import User from atheneum.utility import authentication_utility +from atheneum.service.serialization_service import ( + BaseSerializer, + register_serializer +) LOGGER = logging.getLogger(__name__) +class UserSerializer(BaseSerializer): + """Serialize User model.""" + + def _serializers(self) -> Dict[str, Callable[[], Any]]: + """Define the fields and the accompanying serializer factory.""" + return { + 'name': self.serialize_name, + 'creationTime': self.serialize_creation_time, + 'lastLoginTime': self.serialize_last_login_time, + 'version': self.serialize_version, + 'role': self.serialize_role, + } + + def serialize_name(self) -> str: + """User name.""" + return self.model.name + + def serialize_creation_time(self) -> datetime: + """User creation time.""" + return self.model.creation_time + + def serialize_last_login_time(self) -> datetime: + """User last login time.""" + return self.model.last_login_time + + def serialize_version(self) -> int: + """User version.""" + return self.model.version + + def serialize_role(self) -> str: + """User role.""" + return self.model.role + + +register_serializer(User.__name__, UserSerializer) + + def register(name: str, password: str, role: str) -> User: """ Register a new user. diff --git a/server/atheneum/utility/json_utility.py b/server/atheneum/utility/json_utility.py index 6953528..0b99bda 100644 --- a/server/atheneum/utility/json_utility.py +++ b/server/atheneum/utility/json_utility.py @@ -5,6 +5,9 @@ from typing import Any import rfc3339 from flask.json import JSONEncoder +from atheneum.db import db +from atheneum.service.serialization_service import serialize_model + class CustomJSONEncoder(JSONEncoder): """Ensure that datetime values are serialized correctly.""" @@ -14,6 +17,8 @@ class CustomJSONEncoder(JSONEncoder): try: if isinstance(o, date): return rfc3339.format(o) + if isinstance(o, db.Model): + return serialize_model(o) iterable = iter(o) except TypeError: pass diff --git a/server/tests/api/test_user_api.py b/server/tests/api/test_user_api.py new file mode 100644 index 0000000..135f896 --- /dev/null +++ b/server/tests/api/test_user_api.py @@ -0,0 +1,16 @@ +from flask.testing import FlaskClient + +from tests.conftest import AuthActions + + +def test_get_user_happy_path(auth: AuthActions, client: FlaskClient): + auth.login() + auth_header = auth.get_authorization_header_token() + result = client.get( + '/user/{}'.format(client.application.config['test_username']), + headers={ + auth_header[0]: auth_header[1] + }) + assert result.status_code == 200 + assert result.json is not None + assert result.json['name'] == client.application.config['test_username']