Drew Short
6 years ago
3 changed files with 242 additions and 2 deletions
-
27server/atheneum/api/user_api.py
-
152server/atheneum/service/patch_service.py
-
65server/tests/service/test_patch_service.py
@ -0,0 +1,152 @@ |
|||||
|
from typing import Type, List, Any, Dict |
||||
|
|
||||
|
from sqlalchemy import orm |
||||
|
|
||||
|
from atheneum import db |
||||
|
|
||||
|
_patchable_attribute_names = {} |
||||
|
_restricted_attribute_names = {} |
||||
|
|
||||
|
|
||||
|
def get_patchable_attribute_names(model: Type[db.Model]) -> List[str]: |
||||
|
""" |
||||
|
Retrieve columns from a SQLAlchemy model. |
||||
|
Caches already seen models to improve performance. |
||||
|
|
||||
|
:param model: |
||||
|
:return: A list of patchable model attribute names |
||||
|
""" |
||||
|
class_name = model.__class__.__name__ |
||||
|
if class_name in _patchable_attribute_names: |
||||
|
return _patchable_attribute_names[class_name] |
||||
|
else: |
||||
|
model_attributes = [prop.key for prop in |
||||
|
orm.class_mapper(model.__class__).iterate_properties |
||||
|
if isinstance(prop, orm.ColumnProperty)] |
||||
|
_patchable_attribute_names[class_name] = model_attributes |
||||
|
return model_attributes |
||||
|
|
||||
|
|
||||
|
def is_restricted_attribute(column_property: orm.ColumnProperty) -> bool: |
||||
|
""" |
||||
|
A primary_key or unique columns are not patchable |
||||
|
|
||||
|
:param column_property: The SQLAlchemy column element |
||||
|
:return: A boolean indicating if a field is restricted |
||||
|
""" |
||||
|
column = column_property.columns[0] |
||||
|
return column.primary_key or column.unique is True |
||||
|
|
||||
|
|
||||
|
def get_restricted_attribute_names(model: Type[db.Model]) -> List[str]: |
||||
|
""" |
||||
|
Retrieve primary_key or unique columns from a SQLAlchemy model. |
||||
|
Caches already seen models to improve performance. |
||||
|
|
||||
|
:param model: |
||||
|
:return: A list of patchable model attribute names |
||||
|
""" |
||||
|
class_name = model.__class__.__name__ |
||||
|
if class_name in _restricted_attribute_names: |
||||
|
return _restricted_attribute_names[class_name] |
||||
|
else: |
||||
|
model_attributes = [prop.key for prop in |
||||
|
orm.class_mapper(model.__class__).iterate_properties |
||||
|
if isinstance(prop, orm.ColumnProperty) |
||||
|
and is_restricted_attribute(prop)] |
||||
|
_restricted_attribute_names[class_name] = model_attributes |
||||
|
return model_attributes |
||||
|
|
||||
|
|
||||
|
def determine_patch(original_model: Type[db.Model], |
||||
|
patch_model: Type[db.Model], |
||||
|
model_attributes: List[str] |
||||
|
) -> Dict[str, Any]: |
||||
|
""" |
||||
|
Determine the patch set for two models |
||||
|
|
||||
|
:param original_model: |
||||
|
:param patch_model: |
||||
|
:param model_attributes: |
||||
|
:return: |
||||
|
""" |
||||
|
patch_set = {} |
||||
|
for attribute in model_attributes: |
||||
|
original_attribute = getattr(original_model, attribute) |
||||
|
patch_attribute = getattr(patch_model, attribute) |
||||
|
if original_attribute != patch_attribute: |
||||
|
patch_set[attribute] = patch_attribute |
||||
|
return patch_set |
||||
|
|
||||
|
|
||||
|
def perform_patch(original_model: Type[db.Model], |
||||
|
patch_model: Type[db.Model], |
||||
|
model_attributes: List[str]) -> Type[db.Model]: |
||||
|
""" |
||||
|
Patch the attributes from the patch_model onto the original_model when |
||||
|
the attribute values differ. |
||||
|
|
||||
|
:param original_model: The model to apply the patches to |
||||
|
:param patch_model: The model to pull the patch information from |
||||
|
:param model_attributes: The attributes that are valid for patching |
||||
|
:return: Thd patched original_model |
||||
|
""" |
||||
|
patch_set = determine_patch(original_model, patch_model, model_attributes) |
||||
|
restricted_attributes = get_restricted_attribute_names(original_model) |
||||
|
if set(patch_set.keys()).isdisjoint(restricted_attributes): |
||||
|
for attribute, value in patch_set.items(): |
||||
|
setattr(original_model, attribute, value) |
||||
|
else: |
||||
|
raise ValueError('Restricted attributes modified. Invalid Patch Set.') |
||||
|
return original_model |
||||
|
|
||||
|
|
||||
|
def versioning_aware_patch(original_model: Type[db.Model], |
||||
|
patch_model: Type[db.Model], |
||||
|
model_attributes: List[str]) -> Type[db.Model]: |
||||
|
""" |
||||
|
Account for version numbers in the model. |
||||
|
|
||||
|
Versions must match to perform the patching. Otherwise a simultaneous edit |
||||
|
error has occurred. If the versions match and the patch moves forward, bump |
||||
|
the version on the model by 1 to prevent other reads from performing a |
||||
|
simultaneous edit. |
||||
|
|
||||
|
:param original_model: The model to apply the patches to |
||||
|
:param patch_model: The model to pull the patch information from |
||||
|
:param model_attributes: The attributes that are valid for patching |
||||
|
:return: Thd patched original_model |
||||
|
""" |
||||
|
if original_model.version == patch_model.version: |
||||
|
patch_model.version = patch_model.version + 1 |
||||
|
return perform_patch(original_model, patch_model, model_attributes) |
||||
|
else: |
||||
|
raise ValueError() |
||||
|
|
||||
|
|
||||
|
def patch( |
||||
|
original_model: Type[db.Model], |
||||
|
patch_model: Type[db.Model]) -> Type[db.Model]: |
||||
|
""" |
||||
|
Given two matching models, patch the original model |
||||
|
with the patch model data. |
||||
|
|
||||
|
:param original_model: The model to apply the patches to |
||||
|
:param patch_model: The model to pull the patch information from |
||||
|
:return: The patched original_model |
||||
|
""" |
||||
|
if type(original_model) is type(patch_model): |
||||
|
model_attributes = get_patchable_attribute_names(original_model) |
||||
|
if original_model.id != patch_model.id: |
||||
|
raise ValueError('Cannot change ids through patching') |
||||
|
if 'version' in model_attributes: |
||||
|
return versioning_aware_patch( |
||||
|
original_model, patch_model, model_attributes) |
||||
|
else: |
||||
|
return perform_patch(original_model, patch_model, model_attributes) |
||||
|
else: |
||||
|
raise ValueError( |
||||
|
'Model types "{}" and "{}" do not match'.format( |
||||
|
original_model.__class__.__name__, |
||||
|
patch_model.__class__.__name__ |
||||
|
)) |
@ -0,0 +1,65 @@ |
|||||
|
from datetime import datetime, timedelta |
||||
|
|
||||
|
import pytest |
||||
|
|
||||
|
from atheneum.model import UserToken, User |
||||
|
from atheneum.service import patch_service |
||||
|
|
||||
|
|
||||
|
def test_patch_models(): |
||||
|
user = User() |
||||
|
user.name = 'TestUser' |
||||
|
user.version = 1 |
||||
|
user.last_login_time = datetime.now() - timedelta(days=1) |
||||
|
|
||||
|
user_patch = User() |
||||
|
user_patch.name = 'TestUser' |
||||
|
user_patch.version = 1 |
||||
|
user_patch.last_login_time = datetime.now() |
||||
|
|
||||
|
patched_user = patch_service.patch(user, user_patch) |
||||
|
assert patched_user.version > 1 |
||||
|
assert patched_user.last_login_time == user_patch.last_login_time |
||||
|
|
||||
|
|
||||
|
def test_patch_of_different_types(): |
||||
|
user = User() |
||||
|
user_token = UserToken() |
||||
|
|
||||
|
with pytest.raises(ValueError) as error_info: |
||||
|
patch_service.patch(user, user_token) |
||||
|
|
||||
|
|
||||
|
def test_patch_different_ids(): |
||||
|
user1 = User() |
||||
|
user1.id = 1 |
||||
|
|
||||
|
user2 = User() |
||||
|
user2.id = 2 |
||||
|
|
||||
|
with pytest.raises(ValueError) as error_info: |
||||
|
patch_service.patch(user1, user2) |
||||
|
|
||||
|
|
||||
|
def test_patch_different_versions(): |
||||
|
user1 = User() |
||||
|
user1.version = 1 |
||||
|
|
||||
|
user2 = User() |
||||
|
user2.version = 2 |
||||
|
|
||||
|
with pytest.raises(ValueError) as error_info: |
||||
|
patch_service.patch(user1, user2) |
||||
|
|
||||
|
|
||||
|
def test_patch_restricted_attributes(): |
||||
|
user1 = User() |
||||
|
user1.version = 1 |
||||
|
user1.name = 'Bob' |
||||
|
|
||||
|
user2 = User() |
||||
|
user2.version = 1 |
||||
|
user2.name = 'Chris' |
||||
|
|
||||
|
with pytest.raises(ValueError) as error_info: |
||||
|
patch_service.patch(user1, user2) |
Write
Preview
Loading…
Cancel
Save
Reference in new issue