From ba9c6aa3f6f24daa81bedcd2098d0066b6aa433a Mon Sep 17 00:00:00 2001 From: Drew Short Date: Thu, 26 Jul 2018 17:27:18 -0500 Subject: [PATCH] Patching service backed by validation service. --- server/atheneum/api/user_api.py | 12 +- server/atheneum/service/patch_service.py | 129 ++++---------- server/atheneum/service/role_service.py | 22 ++- .../service/transformation_service.py | 1 - server/atheneum/service/user_service.py | 58 +++++- server/atheneum/service/validation_service.py | 165 ++++++++++++++++++ server/tests/service/test_patch_service.py | 35 +++- server/tests/service/test_role_service.py | 7 + .../tests/service/test_validation_service.py | 25 +++ 9 files changed, 342 insertions(+), 112 deletions(-) create mode 100644 server/atheneum/service/validation_service.py create mode 100644 server/tests/service/test_validation_service.py diff --git a/server/atheneum/api/user_api.py b/server/atheneum/api/user_api.py index af0fe61..ef382f2 100644 --- a/server/atheneum/api/user_api.py +++ b/server/atheneum/api/user_api.py @@ -1,11 +1,15 @@ """User API blueprint and endpoint definitions.""" -from flask import Blueprint, abort, request +from flask import Blueprint, abort, request, g from atheneum.api.decorators import return_json from atheneum.api.model import APIResponse from atheneum.middleware import authentication_middleware from atheneum.model import User -from atheneum.service import user_service, transformation_service, patch_service +from atheneum.service import ( + patch_service, + transformation_service, + user_service +) from atheneum.service.role_service import Role USER_BLUEPRINT = Blueprint( @@ -43,10 +47,10 @@ def patch_user(name: str) -> APIResponse: user_patch: User = transformation_service.deserialize_model( User, request.json) try: - user = patch_service.patch(user, user_patch) + patched_user = patch_service.patch(g.user, user, user_patch) + return APIResponse(patched_user, 200) except ValueError: return abort(400) - return APIResponse(user, 200) return abort(404) diff --git a/server/atheneum/service/patch_service.py b/server/atheneum/service/patch_service.py index 49af7f6..1728427 100644 --- a/server/atheneum/service/patch_service.py +++ b/server/atheneum/service/patch_service.py @@ -1,109 +1,43 @@ -from typing import Type, List, Any, Dict +"""Patching support for db.Model objects.""" -from sqlalchemy import orm +from typing import Type, Set from atheneum import db +from atheneum.model import User +from atheneum.service import validation_service -_patchable_attribute_names = {} -_restricted_attribute_names = {} - -def get_patchable_attribute_names(model: Type[db.Model]) -> List[str]: - """ - Retrieve columns from a SQLAlchemy model. - Caches already seen models to improve performance. - - :param model: - :return: A list of patchable model attribute names - """ - class_name = model.__class__.__name__ - if class_name in _patchable_attribute_names: - return _patchable_attribute_names[class_name] - else: - model_attributes = [prop.key for prop in - orm.class_mapper(model.__class__).iterate_properties - if isinstance(prop, orm.ColumnProperty)] - _patchable_attribute_names[class_name] = model_attributes - return model_attributes - - -def is_restricted_attribute(column_property: orm.ColumnProperty) -> bool: - """ - A primary_key or unique columns are not patchable - - :param column_property: The SQLAlchemy column element - :return: A boolean indicating if a field is restricted - """ - column = column_property.columns[0] - return column.primary_key or column.unique is True - - -def get_restricted_attribute_names(model: Type[db.Model]) -> List[str]: - """ - Retrieve primary_key or unique columns from a SQLAlchemy model. - Caches already seen models to improve performance. - - :param model: - :return: A list of patchable model attribute names - """ - class_name = model.__class__.__name__ - if class_name in _restricted_attribute_names: - return _restricted_attribute_names[class_name] - else: - model_attributes = [prop.key for prop in - orm.class_mapper(model.__class__).iterate_properties - if isinstance(prop, orm.ColumnProperty) - and is_restricted_attribute(prop)] - _restricted_attribute_names[class_name] = model_attributes - return model_attributes - - -def determine_patch(original_model: Type[db.Model], - patch_model: Type[db.Model], - model_attributes: List[str] - ) -> Dict[str, Any]: - """ - Determine the patch set for two models - - :param original_model: - :param patch_model: - :param model_attributes: - :return: - """ - patch_set = {} - for attribute in model_attributes: - original_attribute = getattr(original_model, attribute) - patch_attribute = getattr(patch_model, attribute) - if original_attribute != patch_attribute: - patch_set[attribute] = patch_attribute - return patch_set - - -def perform_patch(original_model: Type[db.Model], +def perform_patch(request_user: User, + original_model: Type[db.Model], patch_model: Type[db.Model], - model_attributes: List[str]) -> Type[db.Model]: + model_attributes: Set[str]) -> Type[db.Model]: """ - Patch the attributes from the patch_model onto the original_model when - the attribute values differ. + Patch changed attributes onto original model. + :param request_user: :param original_model: The model to apply the patches to :param patch_model: The model to pull the patch information from :param model_attributes: The attributes that are valid for patching :return: Thd patched original_model """ - patch_set = determine_patch(original_model, patch_model, model_attributes) - restricted_attributes = get_restricted_attribute_names(original_model) - if set(patch_set.keys()).isdisjoint(restricted_attributes): - for attribute, value in patch_set.items(): + change_set = validation_service.determine_change_set( + original_model, patch_model, model_attributes) + model_validation = validation_service.validate_model( + request_user, original_model, change_set) + if model_validation.success: + for attribute, value in change_set.items(): setattr(original_model, attribute, value) + db.session.commit() else: - raise ValueError('Restricted attributes modified. Invalid Patch Set.') + raise ValueError( + 'Restricted attributes modified. Invalid Patch Set.') return original_model -def versioning_aware_patch(original_model: Type[db.Model], +def versioning_aware_patch(request_user: User, + original_model: Type[db.Model], patch_model: Type[db.Model], - model_attributes: List[str]) -> Type[db.Model]: + model_attributes: Set[str]) -> Type[db.Model]: """ Account for version numbers in the model. @@ -112,6 +46,7 @@ def versioning_aware_patch(original_model: Type[db.Model], the version on the model by 1 to prevent other reads from performing a simultaneous edit. + :param request_user: :param original_model: The model to apply the patches to :param patch_model: The model to pull the patch information from :param model_attributes: The attributes that are valid for patching @@ -119,31 +54,33 @@ def versioning_aware_patch(original_model: Type[db.Model], """ if original_model.version == patch_model.version: patch_model.version = patch_model.version + 1 - return perform_patch(original_model, patch_model, model_attributes) - else: - raise ValueError() + return perform_patch( + request_user, original_model, patch_model, model_attributes) + raise ValueError('Versions do not match. Concurrent edit in progress.') def patch( + request_user: User, original_model: Type[db.Model], patch_model: Type[db.Model]) -> Type[db.Model]: """ - Given two matching models, patch the original model - with the patch model data. + Patch the original model with the patch model data. + :param request_user: :param original_model: The model to apply the patches to :param patch_model: The model to pull the patch information from :return: The patched original_model """ if type(original_model) is type(patch_model): - model_attributes = get_patchable_attribute_names(original_model) + model_attributes = validation_service.get_changable_attribute_names( + original_model) if original_model.id != patch_model.id: raise ValueError('Cannot change ids through patching') if 'version' in model_attributes: return versioning_aware_patch( - original_model, patch_model, model_attributes) - else: - return perform_patch(original_model, patch_model, model_attributes) + request_user, original_model, patch_model, model_attributes) + return perform_patch( + request_user, original_model, patch_model, model_attributes) else: raise ValueError( 'Model types "{}" and "{}" do not match'.format( diff --git a/server/atheneum/service/role_service.py b/server/atheneum/service/role_service.py index 6ea309d..95d3d01 100644 --- a/server/atheneum/service/role_service.py +++ b/server/atheneum/service/role_service.py @@ -66,13 +66,33 @@ class RoleTree(defaultdict): return [self.data] + self.parent.get_parent_roles() return [self.data] + def get_children_roles(self) -> List[Role]: + """Return all the roles from self to the lowest child.""" + if self.roles and ( + len(self.roles.keys()) > 1 or len(self.roles[self.data]) > 1): + child_roles = [self.data] + for role in self.roles.keys(): + for role_tree in self.roles[role]: + if role_tree.data != self.data: + child_roles.extend(role_tree.get_children_roles()) + return child_roles + return [self.data] + def find_roles_in_hierarchy(self, request_role: Role) -> Set[Role]: """Find a set of all roles that fall within the hierarchy.""" roles: List[Role] = [] role_trees = self.find_role(request_role) for role_tree in role_trees: roles.extend(role_tree.get_parent_roles()) - return set(role for role in roles) + return set(roles) + + def find_children_roles(self, request_role: Role) -> Set[Role]: + """Find all children roles, including this role.""" + roles: List[Role] = [] + role_trees = self.find_role(request_role) + for role_tree in role_trees: + roles.extend(role_tree.get_children_roles()) + return set(roles) ROLES = RoleTree(None, Role.ADMIN) diff --git a/server/atheneum/service/transformation_service.py b/server/atheneum/service/transformation_service.py index 20eaee6..f9781f1 100644 --- a/server/atheneum/service/transformation_service.py +++ b/server/atheneum/service/transformation_service.py @@ -15,7 +15,6 @@ class BaseTransformer: def __init__(self, model: Type[db.Model]) -> None: """Initialize the base serializer.""" - self._fields: Dict[str, Callable[[db.Model], Any]] = {} self.model = model def serialize(self, options: Optional[List[str]]) -> Any: diff --git a/server/atheneum/service/user_service.py b/server/atheneum/service/user_service.py index 7fbc042..9d8cde6 100644 --- a/server/atheneum/service/user_service.py +++ b/server/atheneum/service/user_service.py @@ -3,16 +3,20 @@ import logging import random import string from datetime import datetime -from typing import Optional, Dict, Callable, Any +from typing import Optional, Dict, Callable, Any, Tuple from atheneum import errors from atheneum.db import db from atheneum.model import User -from atheneum.service.role_service import Role +from atheneum.service import role_service from atheneum.service.transformation_service import ( BaseTransformer, register_transformer ) +from atheneum.service.validation_service import ( + BaseValidator, + register_validator +) from atheneum.utility import authentication_utility LOGGER = logging.getLogger(__name__) @@ -90,7 +94,55 @@ class UserTransformer(BaseTransformer): @staticmethod def deserialize_role(model: User, role_value: str) -> None: """User role.""" - model.role = Role(role_value) + model.role = role_service.Role(role_value) + + +@register_validator +class UserValidator(BaseValidator): + """Validate User model.""" + + type = User + + def _validators( + self) -> Dict[str, Callable[[Any], Tuple[bool, str]]]: + return { + 'id': self.no_validation, + 'name': self.validate_name, + 'role': self.validate_role, + 'password_hash': self.no_validation, + 'password_revision': self.no_validation, + 'creation_time': self.no_validation, + 'last_login_time': self.no_validation, + 'version': self.validate_version + } + + def validate_name(self, new_name: Any) -> Tuple[bool, str]: + """ + Name changes are only allowed to be performed by an Admin. + + :param new_name: + :return: + """ + validation_result = (self.request_user.role == role_service.Role.ADMIN + or new_name is None) + if validation_result: + return validation_result, '' + return (validation_result, + 'Names can only be changed by an administrator') + + def validate_role(self, new_role: Any) -> Tuple[bool, str]: + """ + Roles can only be increased to the level of the request_user. + + :param new_role: + :return: + """ + acceptable_roles = role_service.ROLES.find_children_roles( + self.request_user.role) + role = new_role if new_role is not None else self.model.role + if role in acceptable_roles: + return True, '' + return False, 'Role escalation is not permitted' def find_by_name(name: str) -> Optional[User]: diff --git a/server/atheneum/service/validation_service.py b/server/atheneum/service/validation_service.py new file mode 100644 index 0000000..b225333 --- /dev/null +++ b/server/atheneum/service/validation_service.py @@ -0,0 +1,165 @@ +"""Validation service for Atheneum models.""" + +from typing import Type, Dict, Callable, Any, Set, Optional, Tuple + +from sqlalchemy import orm + +from atheneum import db, errors +from atheneum.model import User + +_changable_attribute_names: Dict[str, Set[str]] = {} + + +def get_changable_attribute_names(model: Type[db.Model]) -> Set[str]: + """ + Retrieve columns from a SQLAlchemy model. + + Caches already seen models to improve performance. + + :param model: + :return: A list of changeable model attribute names + """ + class_name = model.__class__.__name__ + if class_name in _changable_attribute_names: + return _changable_attribute_names[class_name] + + model_attributes = set([prop.key for prop in + orm.class_mapper( + model.__class__).iterate_properties + if isinstance(prop, orm.ColumnProperty)]) + _changable_attribute_names[class_name] = model_attributes + return model_attributes + + +def determine_change_set(original_model: Type[db.Model], + update_model: Type[db.Model], + model_attributes: Set[str]) -> Dict[str, Any]: + """ + Determine the change set for two models. + + :param original_model: + :param update_model: + :param model_attributes: + :return: + """ + change_set = {} + for attribute in model_attributes: + original_attribute = getattr(original_model, attribute) + changed_attribute = getattr(update_model, attribute) + if original_attribute != changed_attribute: + change_set[attribute] = changed_attribute + return change_set + + +class ModelValidationResult: # pylint: disable=too-few-public-methods + """Result from model validation.""" + + field_results: Dict[str, Tuple[bool, str]] + success: bool + failed: Dict[str, str] = {} + + def __init__(self, field_results: Dict[str, Tuple[bool, str]]) -> None: + """Initialize the validation results.""" + self.field_results = field_results + self.success = len( + [result for (result, _) in self.field_results.values() if + result is False]) == 0 + if not self.success: + failed = [(field, rslt[1]) for (field, rslt) in + self.field_results.items() if rslt[0] is False] + self.failed = {} + for field, reason in failed: + self.failed[field] = reason + + +def get_change_set_value( + change_set: Optional[Dict[str, Any]], field: str) -> Any: + """Read a value or default from changeset.""" + if change_set is not None and field in change_set.keys(): + return change_set[field] + return None + + +class BaseValidator: + """Base Model validator.""" + + type: Type[db.Model] + + def __init__(self, request_user: User, model: Type[db.Model]) -> None: + """Initialize the base validator.""" + self.request_user = request_user + self._fields: Set[str] = get_changable_attribute_names(model) + self.model = model + + def validate(self, + change_set: Optional[Dict[str, Any]] = None) \ + -> ModelValidationResult: + """Validate Model fields.""" + field_validators = self._validators() + fields_to_validate = self._fields + if change_set: + fields_to_validate = set(change_set.keys()) + validation_results: Dict[str, Tuple[bool, str]] = {} + for field in fields_to_validate: + if field not in field_validators: + raise errors.ValidationError( + 'Invalid key: %r. Valid keys: %r.' % ( + field, list(sorted(field_validators.keys())))) + field_validator = field_validators[field] + field_result = field_validator( + get_change_set_value(change_set, field)) + validation_results[field] = field_result + return ModelValidationResult(validation_results) + + def _validators( + self) -> Dict[str, Callable[[Any], Tuple[bool, str]]]: + """Field definitions.""" + raise NotImplementedError() + + @staticmethod + def no_validation(_new_value: Any) -> Tuple[bool, str]: + """Perform no validation.""" + return True, '' + + def validate_version(self, new_version: Any) -> Tuple[bool, str]: + """Perform a standard version validation.""" + if new_version is not None: + version_increasing = self.model.version <= new_version + if version_increasing: + return version_increasing, '' + return version_increasing, 'Unacceptable version change' + return True, '' + + +_model_validators: Dict[str, Type[BaseValidator]] = {} + + +def register_validator( + model_validator: Type[BaseValidator]) -> Type[BaseValidator]: + """Add a model to the serializer mapping.""" + model_name = model_validator.type.__name__ + if model_name not in _model_validators: + _model_validators[model_name] = model_validator + else: + raise KeyError( + ' '.join([ + 'A validator for type "{}" already exists with class "{}".', + 'Cannot register a new validator with class "{}"' + ]).format( + model_name, + _model_validators[model_name].__name__, + model_validator.__name__)) + return model_validator + + +def validate_model(request_user: User, + model_obj: db.Model, + change_set: Optional[Dict[str, Any]] = None) \ + -> ModelValidationResult: + """Lookup a Model and hand off to the validator.""" + try: + return _model_validators[type(model_obj).__name__]( + request_user, model_obj).validate(change_set) + except KeyError: + raise NotImplementedError( + '{} has no registered validator'.format(model_obj.__name__)) diff --git a/server/tests/service/test_patch_service.py b/server/tests/service/test_patch_service.py index 21f7c97..63adda7 100644 --- a/server/tests/service/test_patch_service.py +++ b/server/tests/service/test_patch_service.py @@ -1,12 +1,20 @@ from datetime import datetime, timedelta import pytest +from mock import MagicMock, patch from atheneum.model import UserToken, User -from atheneum.service import patch_service +from atheneum.service import patch_service, role_service +service_module = 'atheneum.service.patch_service' + + +@patch(service_module + '.db.session.commit') +def test_patch_models( + mock_db_session_commit: MagicMock): + request_user = User() + request_user.role = role_service.Role.ADMIN -def test_patch_models(): user = User() user.name = 'TestUser' user.version = 1 @@ -17,20 +25,27 @@ def test_patch_models(): user_patch.version = 1 user_patch.last_login_time = datetime.now() - patched_user = patch_service.patch(user, user_patch) + patched_user = patch_service.patch(request_user, user, user_patch) assert patched_user.version > 1 assert patched_user.last_login_time == user_patch.last_login_time + mock_db_session_commit.assert_called_once() def test_patch_of_different_types(): + request_user = User() + request_user.role = role_service.Role.ADMIN + user = User() user_token = UserToken() with pytest.raises(ValueError) as error_info: - patch_service.patch(user, user_token) + patch_service.patch(request_user, user, user_token) def test_patch_different_ids(): + request_user = User() + request_user.role = role_service.Role.ADMIN + user1 = User() user1.id = 1 @@ -38,10 +53,13 @@ def test_patch_different_ids(): user2.id = 2 with pytest.raises(ValueError) as error_info: - patch_service.patch(user1, user2) + patch_service.patch(request_user, user1, user2) def test_patch_different_versions(): + request_user = User() + request_user.role = role_service.Role.ADMIN + user1 = User() user1.version = 1 @@ -49,10 +67,13 @@ def test_patch_different_versions(): user2.version = 2 with pytest.raises(ValueError) as error_info: - patch_service.patch(user1, user2) + patch_service.patch(request_user, user1, user2) def test_patch_restricted_attributes(): + request_user = User() + request_user.role = role_service.Role.USER + user1 = User() user1.version = 1 user1.name = 'Bob' @@ -62,4 +83,4 @@ def test_patch_restricted_attributes(): user2.name = 'Chris' with pytest.raises(ValueError) as error_info: - patch_service.patch(user1, user2) + patch_service.patch(request_user, user1, user2) diff --git a/server/tests/service/test_role_service.py b/server/tests/service/test_role_service.py index 3d3fa9b..48d2c51 100644 --- a/server/tests/service/test_role_service.py +++ b/server/tests/service/test_role_service.py @@ -15,5 +15,12 @@ def test_role_tree_find_roles_in_hierarchy(): assert Role.ADMIN in roles +def test_role_tree_find_children_roles(): + roles = ROLES.find_children_roles(Role.USER) + assert len(roles) == 2 + assert Role.USER in roles + assert Role.ANONYMOUS in roles + + def test_role_tree_find_role_key_error(): assert len(ROLES.find_role(Role.NONE)) == 0 diff --git a/server/tests/service/test_validation_service.py b/server/tests/service/test_validation_service.py new file mode 100644 index 0000000..1b3f22a --- /dev/null +++ b/server/tests/service/test_validation_service.py @@ -0,0 +1,25 @@ +from atheneum.model import User +from atheneum.service import role_service, validation_service + + +def test_successful_validation(): + request_user = User() + request_user.role = role_service.Role.USER + + user = User() + user.role = role_service.Role.USER + + validation_result = validation_service.validate_model(request_user, user) + assert validation_result.success + + +def test_failed_validation(): + request_user = User() + request_user.role = role_service.Role.ANONYMOUS + + user = User() + user.role = role_service.Role.USER + + validation_result = validation_service.validate_model(request_user, user) + assert validation_result.success is False + assert 'role' in validation_result.failed