"""Validation service for Corvus models.""" from typing import Type, Dict, Callable, Any, Set, Optional, Tuple from sqlalchemy import orm from corvus import db, errors from corvus.model import User _changable_attribute_names: Dict[str, Set[str]] = {} def get_changable_attribute_names(model: Type[db.Model]) -> Set[str]: """ Retrieve columns from a SQLAlchemy model. Caches already seen models to improve performance. :param model: :return: A list of changeable model attribute names """ class_name = model.__class__.__name__ if class_name in _changable_attribute_names: return _changable_attribute_names[class_name] model_attributes = {prop.key for prop in orm.class_mapper(model.__class__).iterate_properties if isinstance(prop, orm.ColumnProperty)} _changable_attribute_names[class_name] = model_attributes return model_attributes def determine_change_set(original_model: Type[db.Model], update_model: Type[db.Model], model_attributes: Set[str], options: Optional[Set[str]]) -> Dict[str, Any]: """ Determine the change set for two models. :param options: :param original_model: :param update_model: :param model_attributes: :return: """ if options is None: options = model_attributes else: options = model_attributes.intersection(options) change_set = {} for attribute in options: original_attribute = getattr(original_model, attribute) changed_attribute = getattr(update_model, attribute) if original_attribute != changed_attribute: change_set[attribute] = changed_attribute return change_set class ModelValidationResult: # pylint: disable=too-few-public-methods """Result from model validation.""" field_results: Dict[str, Tuple[bool, str]] success: bool failed: Dict[str, str] = {} def __init__(self, field_results: Dict[str, Tuple[bool, str]]) -> None: """Initialize the validation results.""" self.field_results = field_results self.success = len( [result for (result, _) in self.field_results.values() if result is False]) == 0 if not self.success: failed = [(field, rslt[1]) for (field, rslt) in self.field_results.items() if rslt[0] is False] self.failed = {} for field, reason in failed: self.failed[field] = reason def get_change_set_value( change_set: Optional[Dict[str, Any]], field: str) -> Any: """Read a value or default from changeset.""" if change_set is not None and field in change_set.keys(): return change_set[field] return None class BaseValidator: """Base Model validator.""" type: Type[db.Model] def __init__(self, request_user: User, model: Type[db.Model]) -> None: """Initialize the base validator.""" self.request_user = request_user self._fields: Set[str] = get_changable_attribute_names(model) self.model = model def validate(self, change_set: Optional[Dict[str, Any]] = None) \ -> ModelValidationResult: """Validate Model fields.""" field_validators = self._validators() fields_to_validate = self._fields if change_set: fields_to_validate = set(change_set.keys()) validation_results: Dict[str, Tuple[bool, str]] = {} for field in fields_to_validate: if field not in field_validators: raise errors.ValidationError( 'Invalid key: %r. Valid keys: %r.' % ( field, list(sorted(field_validators.keys())))) field_validator = field_validators[field] field_result = field_validator( get_change_set_value(change_set, field)) validation_results[field] = field_result return ModelValidationResult(validation_results) def _validators( self) -> Dict[str, Callable[[Any], Tuple[bool, str]]]: """Field definitions.""" raise NotImplementedError() @staticmethod def no_validation(_new_value: Any) -> Tuple[bool, str]: """Perform no validation.""" return True, '' def validate_version(self, new_version: Any) -> Tuple[bool, str]: """Perform a standard version validation.""" if new_version is not None: version_increasing = self.model.version <= new_version if version_increasing: return version_increasing, '' return version_increasing, 'Unacceptable version change' return True, '' _model_validators: Dict[str, Type[BaseValidator]] = {} def register_validator( model_validator: Type[BaseValidator]) -> Type[BaseValidator]: """Add a model to the serializer mapping.""" model_name = model_validator.type.__name__ if model_name not in _model_validators: _model_validators[model_name] = model_validator else: raise KeyError( ' '.join([ 'A validator for type "{}" already exists with class "{}".', 'Cannot register a new validator with class "{}"' ]).format( model_name, _model_validators[model_name].__name__, model_validator.__name__)) return model_validator def validate_model(request_user: User, model_obj: db.Model, change_set: Optional[Dict[str, Any]] = None) \ -> ModelValidationResult: """Lookup a Model and hand off to the validator.""" try: return _model_validators[type(model_obj).__name__]( request_user, model_obj).validate(change_set) except KeyError: raise NotImplementedError( '{} has no registered validator'.format(model_obj.__name__))