Browse Source

Simple patching service. Missing full validation.

merge-requests/1/head
Drew Short 7 years ago
parent
commit
54ec24f327
  1. 27
      server/atheneum/api/user_api.py
  2. 152
      server/atheneum/service/patch_service.py
  3. 65
      server/tests/service/test_patch_service.py

27
server/atheneum/api/user_api.py

@ -5,7 +5,7 @@ from atheneum.api.decorators import return_json
from atheneum.api.model import APIResponse from atheneum.api.model import APIResponse
from atheneum.middleware import authentication_middleware from atheneum.middleware import authentication_middleware
from atheneum.model import User from atheneum.model import User
from atheneum.service import user_service, transformation_service
from atheneum.service import user_service, transformation_service, patch_service
from atheneum.service.role_service import Role from atheneum.service.role_service import Role
USER_BLUEPRINT = Blueprint( USER_BLUEPRINT = Blueprint(
@ -15,6 +15,7 @@ USER_BLUEPRINT = Blueprint(
@USER_BLUEPRINT.route('/<name>', methods=['GET']) @USER_BLUEPRINT.route('/<name>', methods=['GET'])
@return_json @return_json
@authentication_middleware.require_token_auth @authentication_middleware.require_token_auth
@authentication_middleware.require_role(required_role=Role.USER)
def get_user(name: str) -> APIResponse: def get_user(name: str) -> APIResponse:
""" """
Get a user. Get a user.
@ -27,6 +28,28 @@ def get_user(name: str) -> APIResponse:
return abort(404) return abort(404)
@USER_BLUEPRINT.route('/<name>', methods=['PUT'])
@return_json
@authentication_middleware.require_token_auth
@authentication_middleware.require_role(required_role=Role.USER)
def patch_user(name: str) -> APIResponse:
"""
Patch a user.
:return: user if patched, 4xx error on patching issue, 404 on nonexistent
"""
user = user_service.find_by_name(name)
if user is not None:
user_patch: User = transformation_service.deserialize_model(
User, request.json)
try:
user = patch_service.patch(user, user_patch)
except ValueError:
return abort(400)
return APIResponse(user, 200)
return abort(404)
@USER_BLUEPRINT.route('/', methods=['POST']) @USER_BLUEPRINT.route('/', methods=['POST'])
@return_json @return_json
@authentication_middleware.require_token_auth @authentication_middleware.require_token_auth
@ -38,7 +61,7 @@ def register_user() -> APIResponse:
:return: The newly registered User :return: The newly registered User
""" """
new_user: User = transformation_service.deserialize_model( new_user: User = transformation_service.deserialize_model(
User.__name__, request.json)
User, request.json)
registered_user = user_service.register( registered_user = user_service.register(
name=new_user.name, name=new_user.name,
password=None, password=None,

152
server/atheneum/service/patch_service.py

@ -0,0 +1,152 @@
from typing import Type, List, Any, Dict
from sqlalchemy import orm
from atheneum import db
_patchable_attribute_names = {}
_restricted_attribute_names = {}
def get_patchable_attribute_names(model: Type[db.Model]) -> List[str]:
"""
Retrieve columns from a SQLAlchemy model.
Caches already seen models to improve performance.
:param model:
:return: A list of patchable model attribute names
"""
class_name = model.__class__.__name__
if class_name in _patchable_attribute_names:
return _patchable_attribute_names[class_name]
else:
model_attributes = [prop.key for prop in
orm.class_mapper(model.__class__).iterate_properties
if isinstance(prop, orm.ColumnProperty)]
_patchable_attribute_names[class_name] = model_attributes
return model_attributes
def is_restricted_attribute(column_property: orm.ColumnProperty) -> bool:
"""
A primary_key or unique columns are not patchable
:param column_property: The SQLAlchemy column element
:return: A boolean indicating if a field is restricted
"""
column = column_property.columns[0]
return column.primary_key or column.unique is True
def get_restricted_attribute_names(model: Type[db.Model]) -> List[str]:
"""
Retrieve primary_key or unique columns from a SQLAlchemy model.
Caches already seen models to improve performance.
:param model:
:return: A list of patchable model attribute names
"""
class_name = model.__class__.__name__
if class_name in _restricted_attribute_names:
return _restricted_attribute_names[class_name]
else:
model_attributes = [prop.key for prop in
orm.class_mapper(model.__class__).iterate_properties
if isinstance(prop, orm.ColumnProperty)
and is_restricted_attribute(prop)]
_restricted_attribute_names[class_name] = model_attributes
return model_attributes
def determine_patch(original_model: Type[db.Model],
patch_model: Type[db.Model],
model_attributes: List[str]
) -> Dict[str, Any]:
"""
Determine the patch set for two models
:param original_model:
:param patch_model:
:param model_attributes:
:return:
"""
patch_set = {}
for attribute in model_attributes:
original_attribute = getattr(original_model, attribute)
patch_attribute = getattr(patch_model, attribute)
if original_attribute != patch_attribute:
patch_set[attribute] = patch_attribute
return patch_set
def perform_patch(original_model: Type[db.Model],
patch_model: Type[db.Model],
model_attributes: List[str]) -> Type[db.Model]:
"""
Patch the attributes from the patch_model onto the original_model when
the attribute values differ.
:param original_model: The model to apply the patches to
:param patch_model: The model to pull the patch information from
:param model_attributes: The attributes that are valid for patching
:return: Thd patched original_model
"""
patch_set = determine_patch(original_model, patch_model, model_attributes)
restricted_attributes = get_restricted_attribute_names(original_model)
if set(patch_set.keys()).isdisjoint(restricted_attributes):
for attribute, value in patch_set.items():
setattr(original_model, attribute, value)
else:
raise ValueError('Restricted attributes modified. Invalid Patch Set.')
return original_model
def versioning_aware_patch(original_model: Type[db.Model],
patch_model: Type[db.Model],
model_attributes: List[str]) -> Type[db.Model]:
"""
Account for version numbers in the model.
Versions must match to perform the patching. Otherwise a simultaneous edit
error has occurred. If the versions match and the patch moves forward, bump
the version on the model by 1 to prevent other reads from performing a
simultaneous edit.
:param original_model: The model to apply the patches to
:param patch_model: The model to pull the patch information from
:param model_attributes: The attributes that are valid for patching
:return: Thd patched original_model
"""
if original_model.version == patch_model.version:
patch_model.version = patch_model.version + 1
return perform_patch(original_model, patch_model, model_attributes)
else:
raise ValueError()
def patch(
original_model: Type[db.Model],
patch_model: Type[db.Model]) -> Type[db.Model]:
"""
Given two matching models, patch the original model
with the patch model data.
:param original_model: The model to apply the patches to
:param patch_model: The model to pull the patch information from
:return: The patched original_model
"""
if type(original_model) is type(patch_model):
model_attributes = get_patchable_attribute_names(original_model)
if original_model.id != patch_model.id:
raise ValueError('Cannot change ids through patching')
if 'version' in model_attributes:
return versioning_aware_patch(
original_model, patch_model, model_attributes)
else:
return perform_patch(original_model, patch_model, model_attributes)
else:
raise ValueError(
'Model types "{}" and "{}" do not match'.format(
original_model.__class__.__name__,
patch_model.__class__.__name__
))

65
server/tests/service/test_patch_service.py

@ -0,0 +1,65 @@
from datetime import datetime, timedelta
import pytest
from atheneum.model import UserToken, User
from atheneum.service import patch_service
def test_patch_models():
user = User()
user.name = 'TestUser'
user.version = 1
user.last_login_time = datetime.now() - timedelta(days=1)
user_patch = User()
user_patch.name = 'TestUser'
user_patch.version = 1
user_patch.last_login_time = datetime.now()
patched_user = patch_service.patch(user, user_patch)
assert patched_user.version > 1
assert patched_user.last_login_time == user_patch.last_login_time
def test_patch_of_different_types():
user = User()
user_token = UserToken()
with pytest.raises(ValueError) as error_info:
patch_service.patch(user, user_token)
def test_patch_different_ids():
user1 = User()
user1.id = 1
user2 = User()
user2.id = 2
with pytest.raises(ValueError) as error_info:
patch_service.patch(user1, user2)
def test_patch_different_versions():
user1 = User()
user1.version = 1
user2 = User()
user2.version = 2
with pytest.raises(ValueError) as error_info:
patch_service.patch(user1, user2)
def test_patch_restricted_attributes():
user1 = User()
user1.version = 1
user1.name = 'Bob'
user2 = User()
user2.version = 1
user2.name = 'Chris'
with pytest.raises(ValueError) as error_info:
patch_service.patch(user1, user2)
Loading…
Cancel
Save