|
|
from typing import Type, List, Any, Dict
from sqlalchemy import orm
from atheneum import db
# Lazily populated caches used by get_patchable_attribute_names and
# get_restricted_attribute_names so each model's mapper is only walked once.
_patchable_attribute_names: Dict[Any, List[str]] = {}
_restricted_attribute_names: Dict[Any, List[str]] = {}
def get_patchable_attribute_names(model: db.Model) -> List[str]:
    """
    Retrieve the column-backed attribute names of a SQLAlchemy model.

    Results are cached per model class to improve performance.

    :param model: A model instance; its class's mapper is inspected
    :return: A list of patchable model attribute names (every attribute
        backed by an ``orm.ColumnProperty``)
    """
    # Key the cache on the class object itself rather than its __name__:
    # two distinct model classes may share a name (e.g. in different
    # modules) and must not share a cache entry.
    model_class = model.__class__
    cached = _patchable_attribute_names.get(model_class)
    if cached is not None:
        return cached
    model_attributes = [
        prop.key
        for prop in orm.class_mapper(model_class).iterate_properties
        if isinstance(prop, orm.ColumnProperty)
    ]
    _patchable_attribute_names[model_class] = model_attributes
    return model_attributes
def is_restricted_attribute(column_property: orm.ColumnProperty) -> bool:
    """
    Report whether a mapped column must not be changed by a patch.

    Primary-key columns and columns declared unique are not patchable.
    Only the first column backing the property is inspected.

    :param column_property: The SQLAlchemy column property to inspect
    :return: A truthy value when the column is a primary key or is
        explicitly declared ``unique=True``; otherwise False
    """
    first_column = column_property.columns[0]
    if first_column.primary_key:
        return first_column.primary_key
    # ``unique`` is not necessarily a bool, so only an explicit True counts.
    return first_column.unique is True
def get_restricted_attribute_names(model: db.Model) -> List[str]:
    """
    Retrieve the primary_key or unique column names of a SQLAlchemy model.

    Results are cached per model class to improve performance.

    :param model: A model instance; its class's mapper is inspected
    :return: A list of restricted (non-patchable) model attribute names,
        i.e. column attributes for which is_restricted_attribute is truthy
    """
    # Key the cache on the class object itself rather than its __name__:
    # two distinct model classes may share a name (e.g. in different
    # modules) and must not share a cache entry.
    model_class = model.__class__
    cached = _restricted_attribute_names.get(model_class)
    if cached is not None:
        return cached
    model_attributes = [
        prop.key
        for prop in orm.class_mapper(model_class).iterate_properties
        if isinstance(prop, orm.ColumnProperty) and is_restricted_attribute(prop)
    ]
    _restricted_attribute_names[model_class] = model_attributes
    return model_attributes
def determine_patch(original_model: db.Model,
                    patch_model: db.Model,
                    model_attributes: List[str]) -> Dict[str, Any]:
    """
    Determine the patch set for two models.

    Each named attribute is read from both model instances and compared
    with ``!=``.

    :param original_model: The model instance holding the current values
    :param patch_model: The model instance holding the proposed values
    :param model_attributes: The attribute names to compare
    :return: A dict mapping each differing attribute name to its value on
        ``patch_model``; empty when no listed attribute differs
    """
    patch_set: Dict[str, Any] = {}
    for attribute in model_attributes:
        original_value = getattr(original_model, attribute)
        patch_value = getattr(patch_model, attribute)
        if original_value != patch_value:
            patch_set[attribute] = patch_value
    return patch_set
def perform_patch(original_model: db.Model,
                  patch_model: db.Model,
                  model_attributes: List[str]) -> db.Model:
    """
    Patch the attributes from the patch_model onto the original_model when
    the attribute values differ.

    The whole patch set is validated before anything is written, so a
    rejected patch leaves ``original_model`` unchanged.

    :param original_model: The model to apply the patches to
    :param patch_model: The model to pull the patch information from
    :param model_attributes: The attributes that are valid for patching
    :return: The patched original_model (the same object, mutated in place)
    :raises ValueError: If the patch would change a primary_key or unique
        column
    """
    patch_set = determine_patch(original_model, patch_model, model_attributes)
    restricted_attributes = get_restricted_attribute_names(original_model)
    if not set(patch_set).isdisjoint(restricted_attributes):
        raise ValueError('Restricted attributes modified. Invalid Patch Set.')
    for attribute, value in patch_set.items():
        setattr(original_model, attribute, value)
    return original_model
def versioning_aware_patch(original_model: db.Model,
                           patch_model: db.Model,
                           model_attributes: List[str]) -> db.Model:
    """
    Account for version numbers in the model.

    Versions must match to perform the patching. Otherwise a simultaneous
    edit has occurred. If the versions match, the version on
    ``patch_model`` is bumped by 1. When ``'version'`` is among
    ``model_attributes`` (as ``patch`` ensures), that bump is then copied
    onto ``original_model`` with the rest of the patch, which prevents
    other reads of the old version from performing a simultaneous edit.

    :param original_model: The model to apply the patches to
    :param patch_model: The model to pull the patch information from
    :param model_attributes: The attributes that are valid for patching
    :return: The patched original_model
    :raises ValueError: If the versions differ, or if perform_patch rejects
        the patch set. In both cases ``patch_model.version`` is left as it
        was on entry.
    """
    current_version = patch_model.version
    if original_model.version != current_version:
        raise ValueError(
            'Version mismatch (original {}, patch {}): '
            'simultaneous edit detected.'.format(
                original_model.version, current_version))
    patch_model.version = current_version + 1
    try:
        return perform_patch(original_model, patch_model, model_attributes)
    except Exception:
        # Undo the bump so a rejected patch does not leave patch_model
        # mutated; the original error is re-raised unchanged.
        patch_model.version = current_version
        raise
def patch(original_model: db.Model, patch_model: db.Model) -> db.Model:
    """
    Given two matching models, patch the original model with the patch
    model data.

    Models whose column attributes include ``version`` are patched through
    versioning_aware_patch; all others through perform_patch.

    :param original_model: The model to apply the patches to
    :param patch_model: The model to pull the patch information from
    :return: The patched original_model
    :raises ValueError: If the model types differ, the ids differ, or the
        underlying patch function rejects the patch
    """
    if type(original_model) is not type(patch_model):
        raise ValueError(
            'Model types "{}" and "{}" do not match'.format(
                original_model.__class__.__name__,
                patch_model.__class__.__name__
            ))
    model_attributes = get_patchable_attribute_names(original_model)
    if original_model.id != patch_model.id:
        raise ValueError('Cannot change ids through patching')
    if 'version' in model_attributes:
        return versioning_aware_patch(
            original_model, patch_model, model_attributes)
    return perform_patch(original_model, patch_model, model_attributes)
|