diff --git a/server/atheneum/api/user_api.py b/server/atheneum/api/user_api.py index 9067319..af0fe61 100644 --- a/server/atheneum/api/user_api.py +++ b/server/atheneum/api/user_api.py @@ -5,7 +5,7 @@ from atheneum.api.decorators import return_json from atheneum.api.model import APIResponse from atheneum.middleware import authentication_middleware from atheneum.model import User -from atheneum.service import user_service, transformation_service +from atheneum.service import user_service, transformation_service, patch_service from atheneum.service.role_service import Role USER_BLUEPRINT = Blueprint( @@ -15,6 +15,7 @@ USER_BLUEPRINT = Blueprint( @USER_BLUEPRINT.route('/', methods=['GET']) @return_json @authentication_middleware.require_token_auth +@authentication_middleware.require_role(required_role=Role.USER) def get_user(name: str) -> APIResponse: """ Get a user. @@ -27,6 +28,28 @@ def get_user(name: str) -> APIResponse: return abort(404) +@USER_BLUEPRINT.route('/', methods=['PUT']) +@return_json +@authentication_middleware.require_token_auth +@authentication_middleware.require_role(required_role=Role.USER) +def patch_user(name: str) -> APIResponse: + """ + Patch a user. + + :return: user if patched, 4xx error on patching issue, 404 on nonexistent + """ + user = user_service.find_by_name(name) + if user is not None: + user_patch: User = transformation_service.deserialize_model( + User, request.json) + try: + user = patch_service.patch(user, user_patch) + except ValueError: + return abort(400) + return APIResponse(user, 200) + return abort(404) + + @USER_BLUEPRINT.route('/', methods=['POST']) @return_json @authentication_middleware.require_token_auth @@ -38,7 +61,7 @@ def register_user() -> APIResponse: :return: The newly registered User """ new_user: User = transformation_service.deserialize_model( - User.__name__, request.json) + User, request.json) registered_user = user_service.register( name=new_user.name, password=None, diff --git a/server/atheneum/service/patch_service.py b/server/atheneum/service/patch_service.py new file mode 100644 index 0000000..49af7f6 --- /dev/null +++ b/server/atheneum/service/patch_service.py @@ -0,0 +1,152 @@ +from typing import Type, List, Any, Dict + +from sqlalchemy import orm + +from atheneum import db + +_patchable_attribute_names = {} +_restricted_attribute_names = {} + + +def get_patchable_attribute_names(model: Type[db.Model]) -> List[str]: + """ + Retrieve columns from a SQLAlchemy model. + Caches already seen models to improve performance. + + :param model: + :return: A list of patchable model attribute names + """ + class_name = model.__class__.__name__ + if class_name in _patchable_attribute_names: + return _patchable_attribute_names[class_name] + else: + model_attributes = [prop.key for prop in + orm.class_mapper(model.__class__).iterate_properties + if isinstance(prop, orm.ColumnProperty)] + _patchable_attribute_names[class_name] = model_attributes + return model_attributes + + +def is_restricted_attribute(column_property: orm.ColumnProperty) -> bool: + """ + A primary_key or unique columns are not patchable + + :param column_property: The SQLAlchemy column element + :return: A boolean indicating if a field is restricted + """ + column = column_property.columns[0] + return column.primary_key or column.unique is True + + +def get_restricted_attribute_names(model: Type[db.Model]) -> List[str]: + """ + Retrieve primary_key or unique columns from a SQLAlchemy model. + Caches already seen models to improve performance. + + :param model: + :return: A list of patchable model attribute names + """ + class_name = model.__class__.__name__ + if class_name in _restricted_attribute_names: + return _restricted_attribute_names[class_name] + else: + model_attributes = [prop.key for prop in + orm.class_mapper(model.__class__).iterate_properties + if isinstance(prop, orm.ColumnProperty) + and is_restricted_attribute(prop)] + _restricted_attribute_names[class_name] = model_attributes + return model_attributes + + +def determine_patch(original_model: Type[db.Model], + patch_model: Type[db.Model], + model_attributes: List[str] + ) -> Dict[str, Any]: + """ + Determine the patch set for two models + + :param original_model: + :param patch_model: + :param model_attributes: + :return: + """ + patch_set = {} + for attribute in model_attributes: + original_attribute = getattr(original_model, attribute) + patch_attribute = getattr(patch_model, attribute) + if original_attribute != patch_attribute: + patch_set[attribute] = patch_attribute + return patch_set + + +def perform_patch(original_model: Type[db.Model], + patch_model: Type[db.Model], + model_attributes: List[str]) -> Type[db.Model]: + """ + Patch the attributes from the patch_model onto the original_model when + the attribute values differ. + + :param original_model: The model to apply the patches to + :param patch_model: The model to pull the patch information from + :param model_attributes: The attributes that are valid for patching + :return: Thd patched original_model + """ + patch_set = determine_patch(original_model, patch_model, model_attributes) + restricted_attributes = get_restricted_attribute_names(original_model) + if set(patch_set.keys()).isdisjoint(restricted_attributes): + for attribute, value in patch_set.items(): + setattr(original_model, attribute, value) + else: + raise ValueError('Restricted attributes modified. Invalid Patch Set.') + return original_model + + +def versioning_aware_patch(original_model: Type[db.Model], + patch_model: Type[db.Model], + model_attributes: List[str]) -> Type[db.Model]: + """ + Account for version numbers in the model. + + Versions must match to perform the patching. Otherwise a simultaneous edit + error has occurred. If the versions match and the patch moves forward, bump + the version on the model by 1 to prevent other reads from performing a + simultaneous edit. + + :param original_model: The model to apply the patches to + :param patch_model: The model to pull the patch information from + :param model_attributes: The attributes that are valid for patching + :return: Thd patched original_model + """ + if original_model.version == patch_model.version: + patch_model.version = patch_model.version + 1 + return perform_patch(original_model, patch_model, model_attributes) + else: + raise ValueError() + + +def patch( + original_model: Type[db.Model], + patch_model: Type[db.Model]) -> Type[db.Model]: + """ + Given two matching models, patch the original model + with the patch model data. + + :param original_model: The model to apply the patches to + :param patch_model: The model to pull the patch information from + :return: The patched original_model + """ + if type(original_model) is type(patch_model): + model_attributes = get_patchable_attribute_names(original_model) + if original_model.id != patch_model.id: + raise ValueError('Cannot change ids through patching') + if 'version' in model_attributes: + return versioning_aware_patch( + original_model, patch_model, model_attributes) + else: + return perform_patch(original_model, patch_model, model_attributes) + else: + raise ValueError( + 'Model types "{}" and "{}" do not match'.format( + original_model.__class__.__name__, + patch_model.__class__.__name__ + )) diff --git a/server/tests/service/test_patch_service.py b/server/tests/service/test_patch_service.py new file mode 100644 index 0000000..21f7c97 --- /dev/null +++ b/server/tests/service/test_patch_service.py @@ -0,0 +1,65 @@ +from datetime import datetime, timedelta + +import pytest + +from atheneum.model import UserToken, User +from atheneum.service import patch_service + + +def test_patch_models(): + user = User() + user.name = 'TestUser' + user.version = 1 + user.last_login_time = datetime.now() - timedelta(days=1) + + user_patch = User() + user_patch.name = 'TestUser' + user_patch.version = 1 + user_patch.last_login_time = datetime.now() + + patched_user = patch_service.patch(user, user_patch) + assert patched_user.version > 1 + assert patched_user.last_login_time == user_patch.last_login_time + + +def test_patch_of_different_types(): + user = User() + user_token = UserToken() + + with pytest.raises(ValueError) as error_info: + patch_service.patch(user, user_token) + + +def test_patch_different_ids(): + user1 = User() + user1.id = 1 + + user2 = User() + user2.id = 2 + + with pytest.raises(ValueError) as error_info: + patch_service.patch(user1, user2) + + +def test_patch_different_versions(): + user1 = User() + user1.version = 1 + + user2 = User() + user2.version = 2 + + with pytest.raises(ValueError) as error_info: + patch_service.patch(user1, user2) + + +def test_patch_restricted_attributes(): + user1 = User() + user1.version = 1 + user1.name = 'Bob' + + user2 = User() + user2.version = 1 + user2.name = 'Chris' + + with pytest.raises(ValueError) as error_info: + patch_service.patch(user1, user2)