Browse Source

Patching service backed by validation service.

merge-requests/1/head
Drew Short 6 years ago
parent
commit
ba9c6aa3f6
  1. 12
      server/atheneum/api/user_api.py
  2. 129
      server/atheneum/service/patch_service.py
  3. 22
      server/atheneum/service/role_service.py
  4. 1
      server/atheneum/service/transformation_service.py
  5. 58
      server/atheneum/service/user_service.py
  6. 165
      server/atheneum/service/validation_service.py
  7. 35
      server/tests/service/test_patch_service.py
  8. 7
      server/tests/service/test_role_service.py
  9. 25
      server/tests/service/test_validation_service.py

12
server/atheneum/api/user_api.py

@ -1,11 +1,15 @@
"""User API blueprint and endpoint definitions."""
from flask import Blueprint, abort, request
from flask import Blueprint, abort, request, g
from atheneum.api.decorators import return_json
from atheneum.api.model import APIResponse
from atheneum.middleware import authentication_middleware
from atheneum.model import User
from atheneum.service import user_service, transformation_service, patch_service
from atheneum.service import (
patch_service,
transformation_service,
user_service
)
from atheneum.service.role_service import Role
USER_BLUEPRINT = Blueprint(
@ -43,10 +47,10 @@ def patch_user(name: str) -> APIResponse:
user_patch: User = transformation_service.deserialize_model(
User, request.json)
try:
user = patch_service.patch(user, user_patch)
patched_user = patch_service.patch(g.user, user, user_patch)
return APIResponse(patched_user, 200)
except ValueError:
return abort(400)
return APIResponse(user, 200)
return abort(404)

129
server/atheneum/service/patch_service.py

@ -1,109 +1,43 @@
from typing import Type, List, Any, Dict
"""Patching support for db.Model objects."""
from sqlalchemy import orm
from typing import Type, Set
from atheneum import db
from atheneum.model import User
from atheneum.service import validation_service
_patchable_attribute_names = {}
_restricted_attribute_names = {}
def get_patchable_attribute_names(model: Type[db.Model]) -> List[str]:
"""
Retrieve columns from a SQLAlchemy model.
Caches already seen models to improve performance.
:param model:
:return: A list of patchable model attribute names
"""
class_name = model.__class__.__name__
if class_name in _patchable_attribute_names:
return _patchable_attribute_names[class_name]
else:
model_attributes = [prop.key for prop in
orm.class_mapper(model.__class__).iterate_properties
if isinstance(prop, orm.ColumnProperty)]
_patchable_attribute_names[class_name] = model_attributes
return model_attributes
def is_restricted_attribute(column_property: orm.ColumnProperty) -> bool:
"""
A primary_key or unique columns are not patchable
:param column_property: The SQLAlchemy column element
:return: A boolean indicating if a field is restricted
"""
column = column_property.columns[0]
return column.primary_key or column.unique is True
def get_restricted_attribute_names(model: Type[db.Model]) -> List[str]:
"""
Retrieve primary_key or unique columns from a SQLAlchemy model.
Caches already seen models to improve performance.
:param model:
:return: A list of patchable model attribute names
"""
class_name = model.__class__.__name__
if class_name in _restricted_attribute_names:
return _restricted_attribute_names[class_name]
else:
model_attributes = [prop.key for prop in
orm.class_mapper(model.__class__).iterate_properties
if isinstance(prop, orm.ColumnProperty)
and is_restricted_attribute(prop)]
_restricted_attribute_names[class_name] = model_attributes
return model_attributes
def determine_patch(original_model: Type[db.Model],
patch_model: Type[db.Model],
model_attributes: List[str]
) -> Dict[str, Any]:
"""
Determine the patch set for two models
:param original_model:
:param patch_model:
:param model_attributes:
:return:
"""
patch_set = {}
for attribute in model_attributes:
original_attribute = getattr(original_model, attribute)
patch_attribute = getattr(patch_model, attribute)
if original_attribute != patch_attribute:
patch_set[attribute] = patch_attribute
return patch_set
def perform_patch(original_model: Type[db.Model],
def perform_patch(request_user: User,
original_model: Type[db.Model],
patch_model: Type[db.Model],
model_attributes: List[str]) -> Type[db.Model]:
model_attributes: Set[str]) -> Type[db.Model]:
"""
Patch the attributes from the patch_model onto the original_model when
the attribute values differ.
Patch changed attributes onto original model.
:param request_user:
:param original_model: The model to apply the patches to
:param patch_model: The model to pull the patch information from
:param model_attributes: The attributes that are valid for patching
:return: Thd patched original_model
"""
patch_set = determine_patch(original_model, patch_model, model_attributes)
restricted_attributes = get_restricted_attribute_names(original_model)
if set(patch_set.keys()).isdisjoint(restricted_attributes):
for attribute, value in patch_set.items():
change_set = validation_service.determine_change_set(
original_model, patch_model, model_attributes)
model_validation = validation_service.validate_model(
request_user, original_model, change_set)
if model_validation.success:
for attribute, value in change_set.items():
setattr(original_model, attribute, value)
db.session.commit()
else:
raise ValueError('Restricted attributes modified. Invalid Patch Set.')
raise ValueError(
'Restricted attributes modified. Invalid Patch Set.')
return original_model
def versioning_aware_patch(original_model: Type[db.Model],
def versioning_aware_patch(request_user: User,
original_model: Type[db.Model],
patch_model: Type[db.Model],
model_attributes: List[str]) -> Type[db.Model]:
model_attributes: Set[str]) -> Type[db.Model]:
"""
Account for version numbers in the model.
@ -112,6 +46,7 @@ def versioning_aware_patch(original_model: Type[db.Model],
the version on the model by 1 to prevent other reads from performing a
simultaneous edit.
:param request_user:
:param original_model: The model to apply the patches to
:param patch_model: The model to pull the patch information from
:param model_attributes: The attributes that are valid for patching
@ -119,31 +54,33 @@ def versioning_aware_patch(original_model: Type[db.Model],
"""
if original_model.version == patch_model.version:
patch_model.version = patch_model.version + 1
return perform_patch(original_model, patch_model, model_attributes)
else:
raise ValueError()
return perform_patch(
request_user, original_model, patch_model, model_attributes)
raise ValueError('Versions do not match. Concurrent edit in progress.')
def patch(
request_user: User,
original_model: Type[db.Model],
patch_model: Type[db.Model]) -> Type[db.Model]:
"""
Given two matching models, patch the original model
with the patch model data.
Patch the original model with the patch model data.
:param request_user:
:param original_model: The model to apply the patches to
:param patch_model: The model to pull the patch information from
:return: The patched original_model
"""
if type(original_model) is type(patch_model):
model_attributes = get_patchable_attribute_names(original_model)
model_attributes = validation_service.get_changable_attribute_names(
original_model)
if original_model.id != patch_model.id:
raise ValueError('Cannot change ids through patching')
if 'version' in model_attributes:
return versioning_aware_patch(
original_model, patch_model, model_attributes)
else:
return perform_patch(original_model, patch_model, model_attributes)
request_user, original_model, patch_model, model_attributes)
return perform_patch(
request_user, original_model, patch_model, model_attributes)
else:
raise ValueError(
'Model types "{}" and "{}" do not match'.format(

22
server/atheneum/service/role_service.py

@ -66,13 +66,33 @@ class RoleTree(defaultdict):
return [self.data] + self.parent.get_parent_roles()
return [self.data]
def get_children_roles(self) -> List[Role]:
"""Return all the roles from self to the lowest child."""
if self.roles and (
len(self.roles.keys()) > 1 or len(self.roles[self.data]) > 1):
child_roles = [self.data]
for role in self.roles.keys():
for role_tree in self.roles[role]:
if role_tree.data != self.data:
child_roles.extend(role_tree.get_children_roles())
return child_roles
return [self.data]
def find_roles_in_hierarchy(self, request_role: Role) -> Set[Role]:
"""Find a set of all roles that fall within the hierarchy."""
roles: List[Role] = []
role_trees = self.find_role(request_role)
for role_tree in role_trees:
roles.extend(role_tree.get_parent_roles())
return set(role for role in roles)
return set(roles)
def find_children_roles(self, request_role: Role) -> Set[Role]:
"""Find all children roles, including this role."""
roles: List[Role] = []
role_trees = self.find_role(request_role)
for role_tree in role_trees:
roles.extend(role_tree.get_children_roles())
return set(roles)
ROLES = RoleTree(None, Role.ADMIN)

1
server/atheneum/service/transformation_service.py

@ -15,7 +15,6 @@ class BaseTransformer:
def __init__(self, model: Type[db.Model]) -> None:
"""Initialize the base serializer."""
self._fields: Dict[str, Callable[[db.Model], Any]] = {}
self.model = model
def serialize(self, options: Optional[List[str]]) -> Any:

58
server/atheneum/service/user_service.py

@ -3,16 +3,20 @@ import logging
import random
import string
from datetime import datetime
from typing import Optional, Dict, Callable, Any
from typing import Optional, Dict, Callable, Any, Tuple
from atheneum import errors
from atheneum.db import db
from atheneum.model import User
from atheneum.service.role_service import Role
from atheneum.service import role_service
from atheneum.service.transformation_service import (
BaseTransformer,
register_transformer
)
from atheneum.service.validation_service import (
BaseValidator,
register_validator
)
from atheneum.utility import authentication_utility
LOGGER = logging.getLogger(__name__)
@ -90,7 +94,55 @@ class UserTransformer(BaseTransformer):
@staticmethod
def deserialize_role(model: User, role_value: str) -> None:
"""User role."""
model.role = Role(role_value)
model.role = role_service.Role(role_value)
@register_validator
class UserValidator(BaseValidator):
"""Validate User model."""
type = User
def _validators(
self) -> Dict[str, Callable[[Any], Tuple[bool, str]]]:
return {
'id': self.no_validation,
'name': self.validate_name,
'role': self.validate_role,
'password_hash': self.no_validation,
'password_revision': self.no_validation,
'creation_time': self.no_validation,
'last_login_time': self.no_validation,
'version': self.validate_version
}
def validate_name(self, new_name: Any) -> Tuple[bool, str]:
"""
Name changes are only allowed to be performed by an Admin.
:param new_name:
:return:
"""
validation_result = (self.request_user.role == role_service.Role.ADMIN
or new_name is None)
if validation_result:
return validation_result, ''
return (validation_result,
'Names can only be changed by an administrator')
def validate_role(self, new_role: Any) -> Tuple[bool, str]:
"""
Roles can only be increased to the level of the request_user.
:param new_role:
:return:
"""
acceptable_roles = role_service.ROLES.find_children_roles(
self.request_user.role)
role = new_role if new_role is not None else self.model.role
if role in acceptable_roles:
return True, ''
return False, 'Role escalation is not permitted'
def find_by_name(name: str) -> Optional[User]:

165
server/atheneum/service/validation_service.py

@ -0,0 +1,165 @@
"""Validation service for Atheneum models."""
from typing import Type, Dict, Callable, Any, Set, Optional, Tuple
from sqlalchemy import orm
from atheneum import db, errors
from atheneum.model import User
_changable_attribute_names: Dict[str, Set[str]] = {}
def get_changable_attribute_names(model: Type[db.Model]) -> Set[str]:
"""
Retrieve columns from a SQLAlchemy model.
Caches already seen models to improve performance.
:param model:
:return: A list of changeable model attribute names
"""
class_name = model.__class__.__name__
if class_name in _changable_attribute_names:
return _changable_attribute_names[class_name]
model_attributes = set([prop.key for prop in
orm.class_mapper(
model.__class__).iterate_properties
if isinstance(prop, orm.ColumnProperty)])
_changable_attribute_names[class_name] = model_attributes
return model_attributes
def determine_change_set(original_model: Type[db.Model],
update_model: Type[db.Model],
model_attributes: Set[str]) -> Dict[str, Any]:
"""
Determine the change set for two models.
:param original_model:
:param update_model:
:param model_attributes:
:return:
"""
change_set = {}
for attribute in model_attributes:
original_attribute = getattr(original_model, attribute)
changed_attribute = getattr(update_model, attribute)
if original_attribute != changed_attribute:
change_set[attribute] = changed_attribute
return change_set
class ModelValidationResult: # pylint: disable=too-few-public-methods
"""Result from model validation."""
field_results: Dict[str, Tuple[bool, str]]
success: bool
failed: Dict[str, str] = {}
def __init__(self, field_results: Dict[str, Tuple[bool, str]]) -> None:
"""Initialize the validation results."""
self.field_results = field_results
self.success = len(
[result for (result, _) in self.field_results.values() if
result is False]) == 0
if not self.success:
failed = [(field, rslt[1]) for (field, rslt) in
self.field_results.items() if rslt[0] is False]
self.failed = {}
for field, reason in failed:
self.failed[field] = reason
def get_change_set_value(
change_set: Optional[Dict[str, Any]], field: str) -> Any:
"""Read a value or default from changeset."""
if change_set is not None and field in change_set.keys():
return change_set[field]
return None
class BaseValidator:
"""Base Model validator."""
type: Type[db.Model]
def __init__(self, request_user: User, model: Type[db.Model]) -> None:
"""Initialize the base validator."""
self.request_user = request_user
self._fields: Set[str] = get_changable_attribute_names(model)
self.model = model
def validate(self,
change_set: Optional[Dict[str, Any]] = None) \
-> ModelValidationResult:
"""Validate Model fields."""
field_validators = self._validators()
fields_to_validate = self._fields
if change_set:
fields_to_validate = set(change_set.keys())
validation_results: Dict[str, Tuple[bool, str]] = {}
for field in fields_to_validate:
if field not in field_validators:
raise errors.ValidationError(
'Invalid key: %r. Valid keys: %r.' % (
field, list(sorted(field_validators.keys()))))
field_validator = field_validators[field]
field_result = field_validator(
get_change_set_value(change_set, field))
validation_results[field] = field_result
return ModelValidationResult(validation_results)
def _validators(
self) -> Dict[str, Callable[[Any], Tuple[bool, str]]]:
"""Field definitions."""
raise NotImplementedError()
@staticmethod
def no_validation(_new_value: Any) -> Tuple[bool, str]:
"""Perform no validation."""
return True, ''
def validate_version(self, new_version: Any) -> Tuple[bool, str]:
"""Perform a standard version validation."""
if new_version is not None:
version_increasing = self.model.version <= new_version
if version_increasing:
return version_increasing, ''
return version_increasing, 'Unacceptable version change'
return True, ''
_model_validators: Dict[str, Type[BaseValidator]] = {}
def register_validator(
model_validator: Type[BaseValidator]) -> Type[BaseValidator]:
"""Add a model to the serializer mapping."""
model_name = model_validator.type.__name__
if model_name not in _model_validators:
_model_validators[model_name] = model_validator
else:
raise KeyError(
' '.join([
'A validator for type "{}" already exists with class "{}".',
'Cannot register a new validator with class "{}"'
]).format(
model_name,
_model_validators[model_name].__name__,
model_validator.__name__))
return model_validator
def validate_model(request_user: User,
model_obj: db.Model,
change_set: Optional[Dict[str, Any]] = None) \
-> ModelValidationResult:
"""Lookup a Model and hand off to the validator."""
try:
return _model_validators[type(model_obj).__name__](
request_user, model_obj).validate(change_set)
except KeyError:
raise NotImplementedError(
'{} has no registered validator'.format(model_obj.__name__))

35
server/tests/service/test_patch_service.py

@ -1,12 +1,20 @@
from datetime import datetime, timedelta
import pytest
from mock import MagicMock, patch
from atheneum.model import UserToken, User
from atheneum.service import patch_service
from atheneum.service import patch_service, role_service
service_module = 'atheneum.service.patch_service'
@patch(service_module + '.db.session.commit')
def test_patch_models(
mock_db_session_commit: MagicMock):
request_user = User()
request_user.role = role_service.Role.ADMIN
def test_patch_models():
user = User()
user.name = 'TestUser'
user.version = 1
@ -17,20 +25,27 @@ def test_patch_models():
user_patch.version = 1
user_patch.last_login_time = datetime.now()
patched_user = patch_service.patch(user, user_patch)
patched_user = patch_service.patch(request_user, user, user_patch)
assert patched_user.version > 1
assert patched_user.last_login_time == user_patch.last_login_time
mock_db_session_commit.assert_called_once()
def test_patch_of_different_types():
request_user = User()
request_user.role = role_service.Role.ADMIN
user = User()
user_token = UserToken()
with pytest.raises(ValueError) as error_info:
patch_service.patch(user, user_token)
patch_service.patch(request_user, user, user_token)
def test_patch_different_ids():
request_user = User()
request_user.role = role_service.Role.ADMIN
user1 = User()
user1.id = 1
@ -38,10 +53,13 @@ def test_patch_different_ids():
user2.id = 2
with pytest.raises(ValueError) as error_info:
patch_service.patch(user1, user2)
patch_service.patch(request_user, user1, user2)
def test_patch_different_versions():
request_user = User()
request_user.role = role_service.Role.ADMIN
user1 = User()
user1.version = 1
@ -49,10 +67,13 @@ def test_patch_different_versions():
user2.version = 2
with pytest.raises(ValueError) as error_info:
patch_service.patch(user1, user2)
patch_service.patch(request_user, user1, user2)
def test_patch_restricted_attributes():
request_user = User()
request_user.role = role_service.Role.USER
user1 = User()
user1.version = 1
user1.name = 'Bob'
@ -62,4 +83,4 @@ def test_patch_restricted_attributes():
user2.name = 'Chris'
with pytest.raises(ValueError) as error_info:
patch_service.patch(user1, user2)
patch_service.patch(request_user, user1, user2)

7
server/tests/service/test_role_service.py

@ -15,5 +15,12 @@ def test_role_tree_find_roles_in_hierarchy():
assert Role.ADMIN in roles
def test_role_tree_find_children_roles():
roles = ROLES.find_children_roles(Role.USER)
assert len(roles) == 2
assert Role.USER in roles
assert Role.ANONYMOUS in roles
def test_role_tree_find_role_key_error():
assert len(ROLES.find_role(Role.NONE)) == 0

25
server/tests/service/test_validation_service.py

@ -0,0 +1,25 @@
from atheneum.model import User
from atheneum.service import role_service, validation_service
def test_successful_validation():
request_user = User()
request_user.role = role_service.Role.USER
user = User()
user.role = role_service.Role.USER
validation_result = validation_service.validate_model(request_user, user)
assert validation_result.success
def test_failed_validation():
request_user = User()
request_user.role = role_service.Role.ANONYMOUS
user = User()
user.role = role_service.Role.USER
validation_result = validation_service.validate_model(request_user, user)
assert validation_result.success is False
assert 'role' in validation_result.failed
Loading…
Cancel
Save