You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
152 lines
5.6 KiB
152 lines
5.6 KiB
from typing import Type, List, Any, Dict
|
|
|
|
from sqlalchemy import orm
|
|
|
|
from atheneum import db
|
|
|
|
# Cache of column attribute names per model, filled lazily by
# get_patchable_attribute_names so the mapper is only walked once per model.
_patchable_attribute_names = {}
# Cache of primary-key/unique attribute names per model, filled lazily by
# get_restricted_attribute_names.
_restricted_attribute_names = {}
|
|
|
|
|
|
def get_patchable_attribute_names(model: Type[db.Model]) -> List[str]:
    """
    Retrieve the column attribute names from a SQLAlchemy model.

    Caches already seen models to improve performance. The cache is keyed on
    the model class object itself rather than on its bare name, so two
    distinct mapped classes that happen to share a ``__name__`` never share a
    cache entry.

    :param model: A mapped model class, or an instance of one
    :return: A list of patchable model attribute names (the cached list
        object is returned, callers should not mutate it)
    """
    # Callers pass instances, while the annotation advertises the class;
    # normalise so that both forms resolve to the mapped class.
    model_class = model if isinstance(model, type) else model.__class__
    try:
        return _patchable_attribute_names[model_class]
    except KeyError:
        pass

    # Only column-backed properties are considered; relationships and other
    # mapper properties are skipped.
    model_attributes = [
        prop.key
        for prop in orm.class_mapper(model_class).iterate_properties
        if isinstance(prop, orm.ColumnProperty)
    ]
    _patchable_attribute_names[model_class] = model_attributes
    return model_attributes
|
|
|
|
|
|
def is_restricted_attribute(column_property: orm.ColumnProperty) -> bool:
    """
    Decide whether a column property must not be changed by a patch.

    Only the first column of the property is inspected. The property is
    restricted when that column is a primary key, or when its ``unique``
    flag is exactly ``True``.

    :param column_property: The SQLAlchemy column element
    :return: A boolean indicating if a field is restricted
    """
    first_column = column_property.columns[0]
    primary_key_flag = first_column.primary_key
    if primary_key_flag:
        return primary_key_flag
    # ``unique`` may be unset (None); only an explicit True counts.
    return first_column.unique is True
|
|
|
|
|
|
def get_restricted_attribute_names(model: Type[db.Model]) -> List[str]:
    """
    Retrieve primary_key or unique columns from a SQLAlchemy model.

    Caches already seen models to improve performance. The cache is keyed on
    the model class object itself rather than on its bare name, so two
    distinct mapped classes that happen to share a ``__name__`` never share a
    cache entry.

    :param model: A mapped model class, or an instance of one
    :return: A list of restricted (non-patchable) model attribute names (the
        cached list object is returned, callers should not mutate it)
    """
    # Callers pass instances, while the annotation advertises the class;
    # normalise so that both forms resolve to the mapped class.
    model_class = model if isinstance(model, type) else model.__class__
    try:
        return _restricted_attribute_names[model_class]
    except KeyError:
        pass

    # Keep only column-backed properties that is_restricted_attribute flags
    # (primary key or unique columns).
    model_attributes = [
        prop.key
        for prop in orm.class_mapper(model_class).iterate_properties
        if isinstance(prop, orm.ColumnProperty)
        and is_restricted_attribute(prop)
    ]
    _restricted_attribute_names[model_class] = model_attributes
    return model_attributes
|
|
|
|
|
|
def determine_patch(original_model: Type[db.Model],
                    patch_model: Type[db.Model],
                    model_attributes: List[str]
                    ) -> Dict[str, Any]:
    """
    Determine the patch set for two models

    Each listed attribute is read from both models; attributes whose values
    compare unequal end up in the result, mapped to the patch model's value.

    :param original_model: The model holding the current values
    :param patch_model: The model holding the proposed values
    :param model_attributes: The attribute names to compare
    :return: A dict of attribute name to new value for every differing
        attribute, in the order the names were given
    """
    # Read each attribute once per model, original first, then compare.
    value_pairs = (
        (name, getattr(original_model, name), getattr(patch_model, name))
        for name in model_attributes
    )
    return {
        name: proposed
        for name, current, proposed in value_pairs
        if current != proposed
    }
|
|
|
|
|
|
def perform_patch(original_model: Type[db.Model],
                  patch_model: Type[db.Model],
                  model_attributes: List[str]) -> Type[db.Model]:
    """
    Copy differing attribute values from the patch_model onto the
    original_model.

    The patch is all-or-nothing: if any differing attribute is a restricted
    one, nothing is written and an error is raised.

    :param original_model: The model to apply the patches to
    :param patch_model: The model to pull the patch information from
    :param model_attributes: The attributes that are valid for patching
    :return: The patched original_model
    :raises ValueError: If the patch would modify a restricted attribute
    """
    changes = determine_patch(original_model, patch_model, model_attributes)
    restricted_names = get_restricted_attribute_names(original_model)

    # Reject the whole patch before touching the original model.
    if not set(changes.keys()).isdisjoint(restricted_names):
        raise ValueError('Restricted attributes modified. Invalid Patch Set.')

    for name, new_value in changes.items():
        setattr(original_model, name, new_value)
    return original_model
|
|
|
|
|
|
def versioning_aware_patch(original_model: Type[db.Model],
                           patch_model: Type[db.Model],
                           model_attributes: List[str]) -> Type[db.Model]:
    """
    Account for version numbers in the model.

    Versions must match to perform the patching. Otherwise a simultaneous edit
    error has occurred. If the versions match and the patch moves forward, bump
    the version on the model by 1 to prevent other reads from performing a
    simultaneous edit.

    :param original_model: The model to apply the patches to
    :param patch_model: The model to pull the patch information from
    :param model_attributes: The attributes that are valid for patching
    :return: The patched original_model
    :raises ValueError: If the versions differ, or if perform_patch rejects
        the patch set
    """
    if original_model.version != patch_model.version:
        # Give the failure a message so a simultaneous edit can be told apart
        # from the other ValueErrors raised while patching.
        raise ValueError(
            'Version mismatch: original model is at version "{}" but the '
            'patch targets version "{}"'.format(
                original_model.version,
                patch_model.version
            ))

    # Bumping the version on the patch model makes it differ from the
    # original, so perform_patch carries the new version across together
    # with the other changes.
    patch_model.version = patch_model.version + 1
    return perform_patch(original_model, patch_model, model_attributes)
|
|
|
|
|
|
def patch(
        original_model: Type[db.Model],
        patch_model: Type[db.Model]) -> Type[db.Model]:
    """
    Apply the data of one model onto another model of the same type.

    Models whose patchable attributes include ``version`` are patched through
    versioning_aware_patch; all others go straight through perform_patch.

    :param original_model: The model to apply the patches to
    :param patch_model: The model to pull the patch information from
    :return: The patched original_model
    :raises ValueError: If the two models are of different types, or if
        their ids differ
    """
    if type(original_model) is not type(patch_model):
        mismatch_message = 'Model types "{}" and "{}" do not match'.format(
            original_model.__class__.__name__,
            patch_model.__class__.__name__
        )
        raise ValueError(mismatch_message)

    patchable_names = get_patchable_attribute_names(original_model)
    if original_model.id != patch_model.id:
        raise ValueError('Cannot change ids through patching')

    # Versioned models need the simultaneous-edit check before patching.
    if 'version' in patchable_names:
        apply_patch = versioning_aware_patch
    else:
        apply_patch = perform_patch
    return apply_patch(original_model, patch_model, patchable_names)
|