Drew Short
6 years ago
9 changed files with 342 additions and 112 deletions
-
12server/atheneum/api/user_api.py
-
129server/atheneum/service/patch_service.py
-
22server/atheneum/service/role_service.py
-
1server/atheneum/service/transformation_service.py
-
58server/atheneum/service/user_service.py
-
165server/atheneum/service/validation_service.py
-
35server/tests/service/test_patch_service.py
-
7server/tests/service/test_role_service.py
-
25server/tests/service/test_validation_service.py
@ -0,0 +1,165 @@ |
|||||
|
"""Validation service for Atheneum models.""" |
||||
|
|
||||
|
from typing import Type, Dict, Callable, Any, Set, Optional, Tuple |
||||
|
|
||||
|
from sqlalchemy import orm |
||||
|
|
||||
|
from atheneum import db, errors |
||||
|
from atheneum.model import User |
||||
|
|
||||
|
_changable_attribute_names: Dict[str, Set[str]] = {} |
||||
|
|
||||
|
|
||||
|
def get_changable_attribute_names(model: Type[db.Model]) -> Set[str]: |
||||
|
""" |
||||
|
Retrieve columns from a SQLAlchemy model. |
||||
|
|
||||
|
Caches already seen models to improve performance. |
||||
|
|
||||
|
:param model: |
||||
|
:return: A list of changeable model attribute names |
||||
|
""" |
||||
|
class_name = model.__class__.__name__ |
||||
|
if class_name in _changable_attribute_names: |
||||
|
return _changable_attribute_names[class_name] |
||||
|
|
||||
|
model_attributes = set([prop.key for prop in |
||||
|
orm.class_mapper( |
||||
|
model.__class__).iterate_properties |
||||
|
if isinstance(prop, orm.ColumnProperty)]) |
||||
|
_changable_attribute_names[class_name] = model_attributes |
||||
|
return model_attributes |
||||
|
|
||||
|
|
||||
|
def determine_change_set(original_model: Type[db.Model], |
||||
|
update_model: Type[db.Model], |
||||
|
model_attributes: Set[str]) -> Dict[str, Any]: |
||||
|
""" |
||||
|
Determine the change set for two models. |
||||
|
|
||||
|
:param original_model: |
||||
|
:param update_model: |
||||
|
:param model_attributes: |
||||
|
:return: |
||||
|
""" |
||||
|
change_set = {} |
||||
|
for attribute in model_attributes: |
||||
|
original_attribute = getattr(original_model, attribute) |
||||
|
changed_attribute = getattr(update_model, attribute) |
||||
|
if original_attribute != changed_attribute: |
||||
|
change_set[attribute] = changed_attribute |
||||
|
return change_set |
||||
|
|
||||
|
|
||||
|
class ModelValidationResult: # pylint: disable=too-few-public-methods |
||||
|
"""Result from model validation.""" |
||||
|
|
||||
|
field_results: Dict[str, Tuple[bool, str]] |
||||
|
success: bool |
||||
|
failed: Dict[str, str] = {} |
||||
|
|
||||
|
def __init__(self, field_results: Dict[str, Tuple[bool, str]]) -> None: |
||||
|
"""Initialize the validation results.""" |
||||
|
self.field_results = field_results |
||||
|
self.success = len( |
||||
|
[result for (result, _) in self.field_results.values() if |
||||
|
result is False]) == 0 |
||||
|
if not self.success: |
||||
|
failed = [(field, rslt[1]) for (field, rslt) in |
||||
|
self.field_results.items() if rslt[0] is False] |
||||
|
self.failed = {} |
||||
|
for field, reason in failed: |
||||
|
self.failed[field] = reason |
||||
|
|
||||
|
|
||||
|
def get_change_set_value( |
||||
|
change_set: Optional[Dict[str, Any]], field: str) -> Any: |
||||
|
"""Read a value or default from changeset.""" |
||||
|
if change_set is not None and field in change_set.keys(): |
||||
|
return change_set[field] |
||||
|
return None |
||||
|
|
||||
|
|
||||
|
class BaseValidator: |
||||
|
"""Base Model validator.""" |
||||
|
|
||||
|
type: Type[db.Model] |
||||
|
|
||||
|
def __init__(self, request_user: User, model: Type[db.Model]) -> None: |
||||
|
"""Initialize the base validator.""" |
||||
|
self.request_user = request_user |
||||
|
self._fields: Set[str] = get_changable_attribute_names(model) |
||||
|
self.model = model |
||||
|
|
||||
|
def validate(self, |
||||
|
change_set: Optional[Dict[str, Any]] = None) \ |
||||
|
-> ModelValidationResult: |
||||
|
"""Validate Model fields.""" |
||||
|
field_validators = self._validators() |
||||
|
fields_to_validate = self._fields |
||||
|
if change_set: |
||||
|
fields_to_validate = set(change_set.keys()) |
||||
|
validation_results: Dict[str, Tuple[bool, str]] = {} |
||||
|
for field in fields_to_validate: |
||||
|
if field not in field_validators: |
||||
|
raise errors.ValidationError( |
||||
|
'Invalid key: %r. Valid keys: %r.' % ( |
||||
|
field, list(sorted(field_validators.keys())))) |
||||
|
field_validator = field_validators[field] |
||||
|
field_result = field_validator( |
||||
|
get_change_set_value(change_set, field)) |
||||
|
validation_results[field] = field_result |
||||
|
return ModelValidationResult(validation_results) |
||||
|
|
||||
|
def _validators( |
||||
|
self) -> Dict[str, Callable[[Any], Tuple[bool, str]]]: |
||||
|
"""Field definitions.""" |
||||
|
raise NotImplementedError() |
||||
|
|
||||
|
@staticmethod |
||||
|
def no_validation(_new_value: Any) -> Tuple[bool, str]: |
||||
|
"""Perform no validation.""" |
||||
|
return True, '' |
||||
|
|
||||
|
def validate_version(self, new_version: Any) -> Tuple[bool, str]: |
||||
|
"""Perform a standard version validation.""" |
||||
|
if new_version is not None: |
||||
|
version_increasing = self.model.version <= new_version |
||||
|
if version_increasing: |
||||
|
return version_increasing, '' |
||||
|
return version_increasing, 'Unacceptable version change' |
||||
|
return True, '' |
||||
|
|
||||
|
|
||||
|
_model_validators: Dict[str, Type[BaseValidator]] = {} |
||||
|
|
||||
|
|
||||
|
def register_validator( |
||||
|
model_validator: Type[BaseValidator]) -> Type[BaseValidator]: |
||||
|
"""Add a model to the serializer mapping.""" |
||||
|
model_name = model_validator.type.__name__ |
||||
|
if model_name not in _model_validators: |
||||
|
_model_validators[model_name] = model_validator |
||||
|
else: |
||||
|
raise KeyError( |
||||
|
' '.join([ |
||||
|
'A validator for type "{}" already exists with class "{}".', |
||||
|
'Cannot register a new validator with class "{}"' |
||||
|
]).format( |
||||
|
model_name, |
||||
|
_model_validators[model_name].__name__, |
||||
|
model_validator.__name__)) |
||||
|
return model_validator |
||||
|
|
||||
|
|
||||
|
def validate_model(request_user: User, |
||||
|
model_obj: db.Model, |
||||
|
change_set: Optional[Dict[str, Any]] = None) \ |
||||
|
-> ModelValidationResult: |
||||
|
"""Lookup a Model and hand off to the validator.""" |
||||
|
try: |
||||
|
return _model_validators[type(model_obj).__name__]( |
||||
|
request_user, model_obj).validate(change_set) |
||||
|
except KeyError: |
||||
|
raise NotImplementedError( |
||||
|
'{} has no registered validator'.format(model_obj.__name__)) |
@ -0,0 +1,25 @@ |
|||||
|
from atheneum.model import User |
||||
|
from atheneum.service import role_service, validation_service |
||||
|
|
||||
|
|
||||
|
def test_successful_validation(): |
||||
|
request_user = User() |
||||
|
request_user.role = role_service.Role.USER |
||||
|
|
||||
|
user = User() |
||||
|
user.role = role_service.Role.USER |
||||
|
|
||||
|
validation_result = validation_service.validate_model(request_user, user) |
||||
|
assert validation_result.success |
||||
|
|
||||
|
|
||||
|
def test_failed_validation(): |
||||
|
request_user = User() |
||||
|
request_user.role = role_service.Role.ANONYMOUS |
||||
|
|
||||
|
user = User() |
||||
|
user.role = role_service.Role.USER |
||||
|
|
||||
|
validation_result = validation_service.validate_model(request_user, user) |
||||
|
assert validation_result.success is False |
||||
|
assert 'role' in validation_result.failed |
Write
Preview
Loading…
Cancel
Save
Reference in new issue