170 lines
5.9 KiB
170 lines
5.9 KiB
"""Validation service for Atheneum models."""
|
|
|
|
from typing import Type, Dict, Callable, Any, Set, Optional, Tuple
|
|
|
|
from sqlalchemy import orm
|
|
|
|
from atheneum import db, errors
|
|
from atheneum.model import User
|
|
|
|
# Cache of mapped column attribute names, keyed by model class (not by class
# name, so same-named models from different modules cannot collide).
_changable_attribute_names: Dict[type, Set[str]] = {}


def get_changable_attribute_names(model: db.Model) -> Set[str]:
    """
    Retrieve the column-backed attribute names of a SQLAlchemy model.

    Results are cached per model class, so the mapper is only inspected once
    for each class. The returned set is the cached object itself; callers
    must not mutate it.

    :param model: A model instance (a model class is also accepted)
    :return: The set of column attribute names mapped on the model's class
    """
    # Callers in this module pass instances; also accept a class directly
    # instead of resolving it to its metaclass.
    model_class = model if isinstance(model, type) else model.__class__

    cached = _changable_attribute_names.get(model_class)
    if cached is not None:
        return cached

    # Only ColumnProperty entries count; relationships and other mapped
    # properties are excluded.
    model_attributes = {
        prop.key
        for prop in orm.class_mapper(model_class).iterate_properties
        if isinstance(prop, orm.ColumnProperty)
    }
    _changable_attribute_names[model_class] = model_attributes
    return model_attributes
|
|
|
|
|
|
def determine_change_set(original_model: db.Model,
                         update_model: db.Model,
                         model_attributes: Set[str],
                         options: Optional[Set[str]] = None) -> Dict[str, Any]:
    """
    Determine the change set between two model instances.

    :param original_model: The instance holding the current values
    :param update_model: The instance holding the proposed values
    :param model_attributes: Attribute names that may be compared
    :param options: Optional subset of attribute names to restrict the
        comparison to; names not present in ``model_attributes`` are
        ignored. ``None`` compares every name in ``model_attributes``.
    :return: Mapping of attribute name to the value on ``update_model`` for
        every compared attribute whose value differs (``!=``) from the value
        on ``original_model``
    """
    if options is None:
        candidates = model_attributes
    else:
        # Never compare attributes outside the allowed set.
        candidates = model_attributes.intersection(options)

    change_set: Dict[str, Any] = {}
    for attribute in candidates:
        original_value = getattr(original_model, attribute)
        new_value = getattr(update_model, attribute)
        if original_value != new_value:
            change_set[attribute] = new_value
    return change_set
|
|
|
|
|
|
class ModelValidationResult:  # pylint: disable=too-few-public-methods
    """Result from model validation.

    Attributes:
        field_results: Per-field ``(passed, reason)`` tuples as produced by
            the field validators.
        success: ``True`` unless some field's ``passed`` flag ``is False``.
        failed: Mapping of each failed field name to its failure reason;
            empty when validation succeeded.
    """

    field_results: Dict[str, Tuple[bool, str]]
    success: bool
    # Always assigned per instance in __init__; a class-level ``{}`` default
    # would be one dict shared by every successful result.
    failed: Dict[str, str]

    def __init__(self, field_results: Dict[str, Tuple[bool, str]]) -> None:
        """Initialize the validation results.

        :param field_results: Per-field ``(passed, reason)`` tuples
        """
        self.field_results = field_results
        # Only an explicit ``False`` flag marks a failure.
        self.failed = {
            field: reason
            for field, (passed, reason) in field_results.items()
            if passed is False
        }
        self.success = not self.failed
|
|
|
|
|
|
def get_change_set_value(
        change_set: Optional[Dict[str, Any]], field: str) -> Any:
    """Return ``change_set[field]``, or ``None`` when it is not present.

    A missing change set (``None``) is treated like one that does not
    contain *field*.
    """
    if change_set is None:
        return None
    return change_set.get(field)
|
|
|
|
|
|
class BaseValidator:
    """Base Model validator.

    Subclasses set ``type`` and implement ``_validators`` to map each field
    name to a callable returning a ``(passed, reason)`` tuple.
    """

    # The model class handled by this validator. The attribute name is part
    # of the public interface, so it keeps shadowing the builtin.
    type: Type[db.Model]

    def __init__(self, request_user: User, model: db.Model) -> None:
        """Initialize the base validator.

        :param request_user: The user on whose behalf validation runs
        :param model: The model instance being validated
        """
        self.request_user = request_user
        self._fields: Set[str] = get_changable_attribute_names(model)
        self.model = model

    def validate(self,
                 change_set: Optional[Dict[str, Any]] = None) \
            -> ModelValidationResult:
        """Validate Model fields.

        :param change_set: Proposed new values keyed by field name. When it
            is ``None`` or empty, every column field of the model is
            validated, each against a value of ``None``.
        :return: The aggregated per-field validation result
        :raises errors.ValidationError: If a field to validate has no entry
            in ``_validators()``
        """
        field_validators = self._validators()
        fields_to_validate = set(change_set) if change_set else self._fields

        validation_results: Dict[str, Tuple[bool, str]] = {}
        for field in fields_to_validate:
            if field not in field_validators:
                raise errors.ValidationError(
                    'Invalid key: %r. Valid keys: %r.' % (
                        field, sorted(field_validators.keys())))
            validation_results[field] = field_validators[field](
                get_change_set_value(change_set, field))
        return ModelValidationResult(validation_results)

    def _validators(
            self) -> Dict[str, Callable[[Any], Tuple[bool, str]]]:
        """Return the field-name to field-validator mapping.

        Subclasses must override this.
        """
        raise NotImplementedError()

    @staticmethod
    def no_validation(_new_value: Any) -> Tuple[bool, str]:
        """Perform no validation; always passes."""
        return True, ''

    def validate_version(self, new_version: Any) -> Tuple[bool, str]:
        """Perform a standard version validation.

        Passes when no version is supplied (``None``) or when the new version
        is not lower than the model's current ``version``.
        """
        if new_version is None:
            return True, ''
        # Return bool literals: ModelValidationResult tests ``is False``, so a
        # non-builtin falsy comparison result would otherwise slip through.
        if self.model.version <= new_version:
            return True, ''
        return False, 'Unacceptable version change'
|
|
|
|
|
|
# Registry of validator classes, keyed by the ``__name__`` of the model class
# each one validates.
_model_validators: Dict[str, Type[BaseValidator]] = {}


def register_validator(
        model_validator: Type[BaseValidator]) -> Type[BaseValidator]:
    """Register a validator class under its model class name.

    The validator class is returned unchanged, so this can also be applied
    as a class decorator.

    :param model_validator: Validator class whose ``type`` attribute is the
        model class it validates
    :return: ``model_validator`` itself
    :raises KeyError: If a validator is already registered for a model class
        with the same name
    """
    key = model_validator.type.__name__
    if key in _model_validators:
        raise KeyError(
            'A validator for type "{}" already exists with class "{}". '
            'Cannot register a new validator with class "{}"'.format(
                key,
                _model_validators[key].__name__,
                model_validator.__name__))
    _model_validators[key] = model_validator
    return model_validator
|
|
|
|
|
|
def validate_model(request_user: User,
                   model_obj: db.Model,
                   change_set: Optional[Dict[str, Any]] = None) \
        -> ModelValidationResult:
    """Look up the registered validator for a model and hand off to it.

    :param request_user: The user on whose behalf validation runs
    :param model_obj: The model instance to validate
    :param change_set: Optional proposed new values keyed by field name
    :return: The validator's validation result
    :raises NotImplementedError: If no validator is registered for the
        model's class name
    """
    model_name = type(model_obj).__name__
    try:
        validator_class = _model_validators[model_name]
    except KeyError:
        # model_obj is an instance and has no ``__name__`` of its own; report
        # its class name instead.
        raise NotImplementedError(
            '{} has no registered validator'.format(model_name)) from None
    # Called outside the try so a KeyError raised during validation is not
    # misreported as a missing validator.
    return validator_class(request_user, model_obj).validate(change_set)
|