An ebook/comic library service and web client
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 

170 lines
5.9 KiB

"""Validation service for Atheneum models."""
from typing import Type, Dict, Callable, Any, Set, Optional, Tuple
from sqlalchemy import orm
from atheneum import db, errors
from atheneum.model import User
# Cache of column attribute names, keyed by model class name; filled lazily
# by get_changable_attribute_names().
_changable_attribute_names: Dict[str, Set[str]] = {}
def get_changable_attribute_names(model: Type[db.Model]) -> Set[str]:
    """
    Retrieve the column attribute names of a SQLAlchemy model.

    Accepts either a mapped model class or an instance of one. Results are
    cached by class name so the mapper is only inspected once per model.
    The cached set itself is returned, so callers must not mutate it.

    :param model: A mapped model class, or an instance of such a class
    :return: A set of changeable model attribute names
    """
    # The annotation asks for a class, but the validators in this module
    # hand over instances. Normalise both to the mapped class so that a
    # class argument is not mistaken for an instance of its metaclass.
    model_class = model if isinstance(model, type) else model.__class__
    class_name = model_class.__name__
    if class_name in _changable_attribute_names:
        return _changable_attribute_names[class_name]
    # Only plain column properties are considered changeable; other mapped
    # properties (e.g. relationships) are skipped.
    model_attributes = {
        prop.key
        for prop in orm.class_mapper(model_class).iterate_properties
        if isinstance(prop, orm.ColumnProperty)
    }
    _changable_attribute_names[class_name] = model_attributes
    return model_attributes
def determine_change_set(original_model: Type[db.Model],
                         update_model: Type[db.Model],
                         model_attributes: Set[str],
                         options: Optional[Set[str]]) -> Dict[str, Any]:
    """
    Collect the attributes whose values differ between two models.

    :param original_model: The model holding the current values
    :param update_model: The model holding the proposed values
    :param model_attributes: Attribute names that are eligible for comparison
    :param options: Optional subset of names to compare; ``None`` compares
        every name in ``model_attributes``
    :return: Mapping of attribute name to the value found on ``update_model``
        for each attribute that differs
    """
    candidates = (
        model_attributes
        if options is None
        else model_attributes.intersection(options)
    )
    # Read the original value first, then the updated one, for each name.
    snapshots = (
        (name, getattr(original_model, name), getattr(update_model, name))
        for name in candidates
    )
    return {
        name: after
        for name, before, after in snapshots
        if before != after
    }
class ModelValidationResult:  # pylint: disable=too-few-public-methods
    """
    Result from model validation.

    ``field_results`` maps each validated field to a ``(passed, reason)``
    tuple. ``success`` is true when no field reported ``False``, and
    ``failed`` maps every failing field to its reason.
    """

    field_results: Dict[str, Tuple[bool, str]]
    success: bool
    # Annotation only: a class-level ``{}`` default would be one dict shared
    # by every successful result.
    failed: Dict[str, str]

    def __init__(self, field_results: Dict[str, Tuple[bool, str]]) -> None:
        """
        Initialize the validation results.

        :param field_results: Mapping of field name to ``(passed, reason)``
        """
        self.field_results = field_results
        # Only an explicit ``False`` counts as a failure, as before.
        self.failed = {
            field: reason
            for field, (passed, reason) in field_results.items()
            if passed is False
        }
        self.success = not self.failed
def get_change_set_value(
        change_set: Optional[Dict[str, Any]], field: str) -> Any:
    """
    Look up a field in a change set.

    :param change_set: The change set to read from, or ``None``
    :param field: The key to look up
    :return: The stored value, or ``None`` when there is no change set or
        the field is not present in it
    """
    if change_set is None or field not in change_set:
        return None
    return change_set[field]
class BaseValidator:
    """
    Common behaviour for model validators.

    Subclasses supply the per-field validators through ``_validators`` and
    declare the model class they handle in ``type``.
    """

    type: Type[db.Model]

    def __init__(self, request_user: User, model: Type[db.Model]) -> None:
        """
        Remember the requesting user, the model and its column names.

        :param request_user: The user on whose behalf validation runs
        :param model: The model being validated
        """
        self.request_user = request_user
        self._fields: Set[str] = get_changable_attribute_names(model)
        self.model = model

    def validate(self,
                 change_set: Optional[Dict[str, Any]] = None) \
            -> ModelValidationResult:
        """
        Run the field validators and collect their results.

        :param change_set: Proposed field values; when missing or empty,
            every known column is validated with a value of ``None``
        :return: The combined validation result
        :raises errors.ValidationError: If a field has no validator
        """
        validators_by_field = self._validators()
        targets = set(change_set.keys()) if change_set else self._fields
        outcomes: Dict[str, Tuple[bool, str]] = {}
        for name in targets:
            if name not in validators_by_field:
                raise errors.ValidationError(
                    'Invalid key: %r. Valid keys: %r.' % (
                        name, sorted(validators_by_field.keys())))
            check = validators_by_field[name]
            outcomes[name] = check(get_change_set_value(change_set, name))
        return ModelValidationResult(outcomes)

    def _validators(
            self) -> Dict[str, Callable[[Any], Tuple[bool, str]]]:
        """Return the field-name to validator mapping; subclasses override."""
        raise NotImplementedError()

    @staticmethod
    def no_validation(_new_value: Any) -> Tuple[bool, str]:
        """Accept any value without checking it."""
        return True, ''

    def validate_version(self, new_version: Any) -> Tuple[bool, str]:
        """
        Check that a new version is not lower than the model's version.

        :param new_version: The proposed version, or ``None`` for no change
        :return: ``(passed, reason)``; the reason is empty on success
        """
        if new_version is None:
            return True, ''
        version_increasing = self.model.version <= new_version
        reason = '' if version_increasing else 'Unacceptable version change'
        return version_increasing, reason
# Registry of validator classes, keyed by the name of the model class each
# one declares in its ``type`` attribute; filled by register_validator().
_model_validators: Dict[str, Type[BaseValidator]] = {}
def register_validator(
        model_validator: Type[BaseValidator]) -> Type[BaseValidator]:
    """
    Record a validator class under the name of the model it handles.

    Returns the class unchanged, so it can be applied as a class decorator.

    :param model_validator: The validator class to register
    :return: The same validator class
    :raises KeyError: If a validator is already registered for that model
    """
    model_name = model_validator.type.__name__
    if model_name in _model_validators:
        message = (
            'A validator for type "{}" already exists with class "{}". '
            'Cannot register a new validator with class "{}"'
        )
        raise KeyError(message.format(
            model_name,
            _model_validators[model_name].__name__,
            model_validator.__name__))
    _model_validators[model_name] = model_validator
    return model_validator
def validate_model(request_user: User,
                   model_obj: db.Model,
                   change_set: Optional[Dict[str, Any]] = None) \
        -> ModelValidationResult:
    """
    Look up the validator registered for a model and run it.

    :param request_user: The user on whose behalf validation runs
    :param model_obj: The model instance to validate
    :param change_set: Optional proposed field values to validate
    :return: The result produced by the registered validator
    :raises NotImplementedError: If no validator is registered for the
        model's class
    """
    # model_obj is an instance, so the name must come from its class;
    # instances have no ``__name__`` attribute of their own.
    model_name = type(model_obj).__name__
    # Keep the try narrow: only a missing registry entry means "no
    # validator". A KeyError raised inside a validator must propagate as-is.
    try:
        validator_class = _model_validators[model_name]
    except KeyError:
        raise NotImplementedError(
            '{} has no registered validator'.format(model_name)) from None
    return validator_class(request_user, model_obj).validate(change_set)