An ebook/comic library service and web client
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

152 lines
5.6 KiB

  1. from typing import Type, List, Any, Dict
  2. from sqlalchemy import orm
  3. from atheneum import db
  4. _patchable_attribute_names = {}
  5. _restricted_attribute_names = {}
  6. def get_patchable_attribute_names(model: Type[db.Model]) -> List[str]:
  7. """
  8. Retrieve columns from a SQLAlchemy model.
  9. Caches already seen models to improve performance.
  10. :param model:
  11. :return: A list of patchable model attribute names
  12. """
  13. class_name = model.__class__.__name__
  14. if class_name in _patchable_attribute_names:
  15. return _patchable_attribute_names[class_name]
  16. else:
  17. model_attributes = [prop.key for prop in
  18. orm.class_mapper(model.__class__).iterate_properties
  19. if isinstance(prop, orm.ColumnProperty)]
  20. _patchable_attribute_names[class_name] = model_attributes
  21. return model_attributes
  22. def is_restricted_attribute(column_property: orm.ColumnProperty) -> bool:
  23. """
  24. A primary_key or unique columns are not patchable
  25. :param column_property: The SQLAlchemy column element
  26. :return: A boolean indicating if a field is restricted
  27. """
  28. column = column_property.columns[0]
  29. return column.primary_key or column.unique is True
  30. def get_restricted_attribute_names(model: Type[db.Model]) -> List[str]:
  31. """
  32. Retrieve primary_key or unique columns from a SQLAlchemy model.
  33. Caches already seen models to improve performance.
  34. :param model:
  35. :return: A list of patchable model attribute names
  36. """
  37. class_name = model.__class__.__name__
  38. if class_name in _restricted_attribute_names:
  39. return _restricted_attribute_names[class_name]
  40. else:
  41. model_attributes = [prop.key for prop in
  42. orm.class_mapper(model.__class__).iterate_properties
  43. if isinstance(prop, orm.ColumnProperty)
  44. and is_restricted_attribute(prop)]
  45. _restricted_attribute_names[class_name] = model_attributes
  46. return model_attributes
  47. def determine_patch(original_model: Type[db.Model],
  48. patch_model: Type[db.Model],
  49. model_attributes: List[str]
  50. ) -> Dict[str, Any]:
  51. """
  52. Determine the patch set for two models
  53. :param original_model:
  54. :param patch_model:
  55. :param model_attributes:
  56. :return:
  57. """
  58. patch_set = {}
  59. for attribute in model_attributes:
  60. original_attribute = getattr(original_model, attribute)
  61. patch_attribute = getattr(patch_model, attribute)
  62. if original_attribute != patch_attribute:
  63. patch_set[attribute] = patch_attribute
  64. return patch_set
  65. def perform_patch(original_model: Type[db.Model],
  66. patch_model: Type[db.Model],
  67. model_attributes: List[str]) -> Type[db.Model]:
  68. """
  69. Patch the attributes from the patch_model onto the original_model when
  70. the attribute values differ.
  71. :param original_model: The model to apply the patches to
  72. :param patch_model: The model to pull the patch information from
  73. :param model_attributes: The attributes that are valid for patching
  74. :return: Thd patched original_model
  75. """
  76. patch_set = determine_patch(original_model, patch_model, model_attributes)
  77. restricted_attributes = get_restricted_attribute_names(original_model)
  78. if set(patch_set.keys()).isdisjoint(restricted_attributes):
  79. for attribute, value in patch_set.items():
  80. setattr(original_model, attribute, value)
  81. else:
  82. raise ValueError('Restricted attributes modified. Invalid Patch Set.')
  83. return original_model
  84. def versioning_aware_patch(original_model: Type[db.Model],
  85. patch_model: Type[db.Model],
  86. model_attributes: List[str]) -> Type[db.Model]:
  87. """
  88. Account for version numbers in the model.
  89. Versions must match to perform the patching. Otherwise a simultaneous edit
  90. error has occurred. If the versions match and the patch moves forward, bump
  91. the version on the model by 1 to prevent other reads from performing a
  92. simultaneous edit.
  93. :param original_model: The model to apply the patches to
  94. :param patch_model: The model to pull the patch information from
  95. :param model_attributes: The attributes that are valid for patching
  96. :return: Thd patched original_model
  97. """
  98. if original_model.version == patch_model.version:
  99. patch_model.version = patch_model.version + 1
  100. return perform_patch(original_model, patch_model, model_attributes)
  101. else:
  102. raise ValueError()
  103. def patch(
  104. original_model: Type[db.Model],
  105. patch_model: Type[db.Model]) -> Type[db.Model]:
  106. """
  107. Given two matching models, patch the original model
  108. with the patch model data.
  109. :param original_model: The model to apply the patches to
  110. :param patch_model: The model to pull the patch information from
  111. :return: The patched original_model
  112. """
  113. if type(original_model) is type(patch_model):
  114. model_attributes = get_patchable_attribute_names(original_model)
  115. if original_model.id != patch_model.id:
  116. raise ValueError('Cannot change ids through patching')
  117. if 'version' in model_attributes:
  118. return versioning_aware_patch(
  119. original_model, patch_model, model_attributes)
  120. else:
  121. return perform_patch(original_model, patch_model, model_attributes)
  122. else:
  123. raise ValueError(
  124. 'Model types "{}" and "{}" do not match'.format(
  125. original_model.__class__.__name__,
  126. patch_model.__class__.__name__
  127. ))