An ebook/comic library service and web client
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

170 lines
5.9 KiB

  1. """Validation service for Atheneum models."""
  2. from typing import Type, Dict, Callable, Any, Set, Optional, Tuple
  3. from sqlalchemy import orm
  4. from atheneum import db, errors
  5. from atheneum.model import User
  6. _changable_attribute_names: Dict[str, Set[str]] = {}
  7. def get_changable_attribute_names(model: Type[db.Model]) -> Set[str]:
  8. """
  9. Retrieve columns from a SQLAlchemy model.
  10. Caches already seen models to improve performance.
  11. :param model:
  12. :return: A list of changeable model attribute names
  13. """
  14. class_name = model.__class__.__name__
  15. if class_name in _changable_attribute_names:
  16. return _changable_attribute_names[class_name]
  17. model_attributes = {prop.key for prop in
  18. orm.class_mapper(model.__class__).iterate_properties
  19. if isinstance(prop, orm.ColumnProperty)}
  20. _changable_attribute_names[class_name] = model_attributes
  21. return model_attributes
  22. def determine_change_set(original_model: Type[db.Model],
  23. update_model: Type[db.Model],
  24. model_attributes: Set[str],
  25. options: Optional[Set[str]]) -> Dict[str, Any]:
  26. """
  27. Determine the change set for two models.
  28. :param options:
  29. :param original_model:
  30. :param update_model:
  31. :param model_attributes:
  32. :return:
  33. """
  34. if options is None:
  35. options = model_attributes
  36. else:
  37. options = model_attributes.intersection(options)
  38. change_set = {}
  39. for attribute in options:
  40. original_attribute = getattr(original_model, attribute)
  41. changed_attribute = getattr(update_model, attribute)
  42. if original_attribute != changed_attribute:
  43. change_set[attribute] = changed_attribute
  44. return change_set
  45. class ModelValidationResult: # pylint: disable=too-few-public-methods
  46. """Result from model validation."""
  47. field_results: Dict[str, Tuple[bool, str]]
  48. success: bool
  49. failed: Dict[str, str] = {}
  50. def __init__(self, field_results: Dict[str, Tuple[bool, str]]) -> None:
  51. """Initialize the validation results."""
  52. self.field_results = field_results
  53. self.success = len(
  54. [result for (result, _) in self.field_results.values() if
  55. result is False]) == 0
  56. if not self.success:
  57. failed = [(field, rslt[1]) for (field, rslt) in
  58. self.field_results.items() if rslt[0] is False]
  59. self.failed = {}
  60. for field, reason in failed:
  61. self.failed[field] = reason
  62. def get_change_set_value(
  63. change_set: Optional[Dict[str, Any]], field: str) -> Any:
  64. """Read a value or default from changeset."""
  65. if change_set is not None and field in change_set.keys():
  66. return change_set[field]
  67. return None
  68. class BaseValidator:
  69. """Base Model validator."""
  70. type: Type[db.Model]
  71. def __init__(self, request_user: User, model: Type[db.Model]) -> None:
  72. """Initialize the base validator."""
  73. self.request_user = request_user
  74. self._fields: Set[str] = get_changable_attribute_names(model)
  75. self.model = model
  76. def validate(self,
  77. change_set: Optional[Dict[str, Any]] = None) \
  78. -> ModelValidationResult:
  79. """Validate Model fields."""
  80. field_validators = self._validators()
  81. fields_to_validate = self._fields
  82. if change_set:
  83. fields_to_validate = set(change_set.keys())
  84. validation_results: Dict[str, Tuple[bool, str]] = {}
  85. for field in fields_to_validate:
  86. if field not in field_validators:
  87. raise errors.ValidationError(
  88. 'Invalid key: %r. Valid keys: %r.' % (
  89. field, list(sorted(field_validators.keys()))))
  90. field_validator = field_validators[field]
  91. field_result = field_validator(
  92. get_change_set_value(change_set, field))
  93. validation_results[field] = field_result
  94. return ModelValidationResult(validation_results)
  95. def _validators(
  96. self) -> Dict[str, Callable[[Any], Tuple[bool, str]]]:
  97. """Field definitions."""
  98. raise NotImplementedError()
  99. @staticmethod
  100. def no_validation(_new_value: Any) -> Tuple[bool, str]:
  101. """Perform no validation."""
  102. return True, ''
  103. def validate_version(self, new_version: Any) -> Tuple[bool, str]:
  104. """Perform a standard version validation."""
  105. if new_version is not None:
  106. version_increasing = self.model.version <= new_version
  107. if version_increasing:
  108. return version_increasing, ''
  109. return version_increasing, 'Unacceptable version change'
  110. return True, ''
  111. _model_validators: Dict[str, Type[BaseValidator]] = {}
  112. def register_validator(
  113. model_validator: Type[BaseValidator]) -> Type[BaseValidator]:
  114. """Add a model to the serializer mapping."""
  115. model_name = model_validator.type.__name__
  116. if model_name not in _model_validators:
  117. _model_validators[model_name] = model_validator
  118. else:
  119. raise KeyError(
  120. ' '.join([
  121. 'A validator for type "{}" already exists with class "{}".',
  122. 'Cannot register a new validator with class "{}"'
  123. ]).format(
  124. model_name,
  125. _model_validators[model_name].__name__,
  126. model_validator.__name__))
  127. return model_validator
  128. def validate_model(request_user: User,
  129. model_obj: db.Model,
  130. change_set: Optional[Dict[str, Any]] = None) \
  131. -> ModelValidationResult:
  132. """Lookup a Model and hand off to the validator."""
  133. try:
  134. return _model_validators[type(model_obj).__name__](
  135. request_user, model_obj).validate(change_set)
  136. except KeyError:
  137. raise NotImplementedError(
  138. '{} has no registered validator'.format(model_obj.__name__))