Drew Short
6 years ago
9 changed files with 342 additions and 112 deletions
-
12server/atheneum/api/user_api.py
-
129server/atheneum/service/patch_service.py
-
22server/atheneum/service/role_service.py
-
1server/atheneum/service/transformation_service.py
-
58server/atheneum/service/user_service.py
-
165server/atheneum/service/validation_service.py
-
35server/tests/service/test_patch_service.py
-
7server/tests/service/test_role_service.py
-
25server/tests/service/test_validation_service.py
@ -0,0 +1,165 @@ |
|||
"""Validation service for Atheneum models.""" |
|||
|
|||
from typing import Type, Dict, Callable, Any, Set, Optional, Tuple |
|||
|
|||
from sqlalchemy import orm |
|||
|
|||
from atheneum import db, errors |
|||
from atheneum.model import User |
|||
|
|||
_changable_attribute_names: Dict[str, Set[str]] = {} |
|||
|
|||
|
|||
def get_changable_attribute_names(model: Type[db.Model]) -> Set[str]: |
|||
""" |
|||
Retrieve columns from a SQLAlchemy model. |
|||
|
|||
Caches already seen models to improve performance. |
|||
|
|||
:param model: |
|||
:return: A list of changeable model attribute names |
|||
""" |
|||
class_name = model.__class__.__name__ |
|||
if class_name in _changable_attribute_names: |
|||
return _changable_attribute_names[class_name] |
|||
|
|||
model_attributes = set([prop.key for prop in |
|||
orm.class_mapper( |
|||
model.__class__).iterate_properties |
|||
if isinstance(prop, orm.ColumnProperty)]) |
|||
_changable_attribute_names[class_name] = model_attributes |
|||
return model_attributes |
|||
|
|||
|
|||
def determine_change_set(original_model: Type[db.Model], |
|||
update_model: Type[db.Model], |
|||
model_attributes: Set[str]) -> Dict[str, Any]: |
|||
""" |
|||
Determine the change set for two models. |
|||
|
|||
:param original_model: |
|||
:param update_model: |
|||
:param model_attributes: |
|||
:return: |
|||
""" |
|||
change_set = {} |
|||
for attribute in model_attributes: |
|||
original_attribute = getattr(original_model, attribute) |
|||
changed_attribute = getattr(update_model, attribute) |
|||
if original_attribute != changed_attribute: |
|||
change_set[attribute] = changed_attribute |
|||
return change_set |
|||
|
|||
|
|||
class ModelValidationResult: # pylint: disable=too-few-public-methods |
|||
"""Result from model validation.""" |
|||
|
|||
field_results: Dict[str, Tuple[bool, str]] |
|||
success: bool |
|||
failed: Dict[str, str] = {} |
|||
|
|||
def __init__(self, field_results: Dict[str, Tuple[bool, str]]) -> None: |
|||
"""Initialize the validation results.""" |
|||
self.field_results = field_results |
|||
self.success = len( |
|||
[result for (result, _) in self.field_results.values() if |
|||
result is False]) == 0 |
|||
if not self.success: |
|||
failed = [(field, rslt[1]) for (field, rslt) in |
|||
self.field_results.items() if rslt[0] is False] |
|||
self.failed = {} |
|||
for field, reason in failed: |
|||
self.failed[field] = reason |
|||
|
|||
|
|||
def get_change_set_value( |
|||
change_set: Optional[Dict[str, Any]], field: str) -> Any: |
|||
"""Read a value or default from changeset.""" |
|||
if change_set is not None and field in change_set.keys(): |
|||
return change_set[field] |
|||
return None |
|||
|
|||
|
|||
class BaseValidator: |
|||
"""Base Model validator.""" |
|||
|
|||
type: Type[db.Model] |
|||
|
|||
def __init__(self, request_user: User, model: Type[db.Model]) -> None: |
|||
"""Initialize the base validator.""" |
|||
self.request_user = request_user |
|||
self._fields: Set[str] = get_changable_attribute_names(model) |
|||
self.model = model |
|||
|
|||
def validate(self, |
|||
change_set: Optional[Dict[str, Any]] = None) \ |
|||
-> ModelValidationResult: |
|||
"""Validate Model fields.""" |
|||
field_validators = self._validators() |
|||
fields_to_validate = self._fields |
|||
if change_set: |
|||
fields_to_validate = set(change_set.keys()) |
|||
validation_results: Dict[str, Tuple[bool, str]] = {} |
|||
for field in fields_to_validate: |
|||
if field not in field_validators: |
|||
raise errors.ValidationError( |
|||
'Invalid key: %r. Valid keys: %r.' % ( |
|||
field, list(sorted(field_validators.keys())))) |
|||
field_validator = field_validators[field] |
|||
field_result = field_validator( |
|||
get_change_set_value(change_set, field)) |
|||
validation_results[field] = field_result |
|||
return ModelValidationResult(validation_results) |
|||
|
|||
def _validators( |
|||
self) -> Dict[str, Callable[[Any], Tuple[bool, str]]]: |
|||
"""Field definitions.""" |
|||
raise NotImplementedError() |
|||
|
|||
@staticmethod |
|||
def no_validation(_new_value: Any) -> Tuple[bool, str]: |
|||
"""Perform no validation.""" |
|||
return True, '' |
|||
|
|||
def validate_version(self, new_version: Any) -> Tuple[bool, str]: |
|||
"""Perform a standard version validation.""" |
|||
if new_version is not None: |
|||
version_increasing = self.model.version <= new_version |
|||
if version_increasing: |
|||
return version_increasing, '' |
|||
return version_increasing, 'Unacceptable version change' |
|||
return True, '' |
|||
|
|||
|
|||
_model_validators: Dict[str, Type[BaseValidator]] = {} |
|||
|
|||
|
|||
def register_validator( |
|||
model_validator: Type[BaseValidator]) -> Type[BaseValidator]: |
|||
"""Add a model to the serializer mapping.""" |
|||
model_name = model_validator.type.__name__ |
|||
if model_name not in _model_validators: |
|||
_model_validators[model_name] = model_validator |
|||
else: |
|||
raise KeyError( |
|||
' '.join([ |
|||
'A validator for type "{}" already exists with class "{}".', |
|||
'Cannot register a new validator with class "{}"' |
|||
]).format( |
|||
model_name, |
|||
_model_validators[model_name].__name__, |
|||
model_validator.__name__)) |
|||
return model_validator |
|||
|
|||
|
|||
def validate_model(request_user: User, |
|||
model_obj: db.Model, |
|||
change_set: Optional[Dict[str, Any]] = None) \ |
|||
-> ModelValidationResult: |
|||
"""Lookup a Model and hand off to the validator.""" |
|||
try: |
|||
return _model_validators[type(model_obj).__name__]( |
|||
request_user, model_obj).validate(change_set) |
|||
except KeyError: |
|||
raise NotImplementedError( |
|||
'{} has no registered validator'.format(model_obj.__name__)) |
@ -0,0 +1,25 @@ |
|||
from atheneum.model import User |
|||
from atheneum.service import role_service, validation_service |
|||
|
|||
|
|||
def test_successful_validation(): |
|||
request_user = User() |
|||
request_user.role = role_service.Role.USER |
|||
|
|||
user = User() |
|||
user.role = role_service.Role.USER |
|||
|
|||
validation_result = validation_service.validate_model(request_user, user) |
|||
assert validation_result.success |
|||
|
|||
|
|||
def test_failed_validation(): |
|||
request_user = User() |
|||
request_user.role = role_service.Role.ANONYMOUS |
|||
|
|||
user = User() |
|||
user.role = role_service.Role.USER |
|||
|
|||
validation_result = validation_service.validate_model(request_user, user) |
|||
assert validation_result.success is False |
|||
assert 'role' in validation_result.failed |
Write
Preview
Loading…
Cancel
Save
Reference in new issue