@ -1,16 +1,25 @@
""" Patching support for db.Model objects. """
from typing import Type , Set
from typing import Type , Set , Optional , Any , Dict
from atheneum import db
from atheneum.model import User
from atheneum.service import transformation_service
from atheneum.service import validation_service
def get_patch_fields ( patch_json : Dict [ str , Any ] ) - > Set [ str ] :
""" Convert json fields to python fields. """
return set ( [
transformation_service . convert_key_from_json ( key ) for key in
patch_json . keys ( ) ] )
def perform_patch ( request_user : User ,
original_model : Type [ db . Model ] ,
patch_model : Type [ db . Model ] ,
model_attributes : Set [ str ] ) - > Type [ db . Model ] :
model_attributes : Set [ str ] ,
patched_fields : Optional [ Set [ str ] ] ) \
- > Type [ db . Model ] :
"""
Patch changed attributes onto original model .
@ -18,10 +27,11 @@ def perform_patch(request_user: User,
: param original_model : The model to apply the patches to
: param patch_model : The model to pull the patch information from
: param model_attributes : The attributes that are valid for patching
: param patched_fields : The explicitly passed fields for patching
: return : Thd patched original_model
"""
change_set = validation_service . determine_change_set (
original_model , patch_model , model_attributes )
original_model , patch_model , model_attributes , patched_fields )
model_validation = validation_service . validate_model (
request_user , original_model , change_set )
if model_validation . success :
@ -37,7 +47,9 @@ def perform_patch(request_user: User,
def versioning_aware_patch ( request_user : User ,
original_model : Type [ db . Model ] ,
patch_model : Type [ db . Model ] ,
model_attributes : Set [ str ] ) - > Type [ db . Model ] :
model_attributes : Set [ str ] ,
patched_fields : Optional [ Set [ str ] ] ) \
- > Type [ db . Model ] :
"""
Account for version numbers in the model .
@ -46,6 +58,7 @@ def versioning_aware_patch(request_user: User,
the version on the model by 1 to prevent other reads from performing a
simultaneous edit .
: param patched_fields :
: param request_user :
: param original_model : The model to apply the patches to
: param patch_model : The model to pull the patch information from
@ -55,32 +68,46 @@ def versioning_aware_patch(request_user: User,
if original_model . version == patch_model . version :
patch_model . version = patch_model . version + 1
return perform_patch (
request_user , original_model , patch_model , model_attributes )
request_user ,
original_model ,
patch_model ,
model_attributes ,
patched_fields )
raise ValueError ( ' Versions do not match. Concurrent edit in progress. ' )
def patch (
request_user : User ,
original_model : Type [ db . Model ] ,
patch_model : Type [ db . Model ] ) - > Type [ db . Model ] :
patch_model : Type [ db . Model ] ,
patched_fields : Optional [ Set [ str ] ] = None ) - > Type [ db . Model ] :
"""
Patch the original model with the patch model data .
: param request_user :
: param original_model : The model to apply the patches to
: param patch_model : The model to pull the patch information from
: param patched_fields :
: return : The patched original_model
"""
if type ( original_model ) is type ( patch_model ) :
model_attributes = validation_service . get_changable_attribute_names (
original_model )
if original_model . id != patch_model . id :
if patch_model . id is not None and original_model . id != patch_model . id :
raise ValueError ( ' Cannot change ids through patching ' )
if ' version ' in model_attributes :
return versioning_aware_patch (
request_user , original_model , patch_model , model_attributes )
request_user ,
original_model ,
patch_model ,
model_attributes ,
patched_fields )
return perform_patch (
request_user , original_model , patch_model , model_attributes )
request_user ,
original_model ,
patch_model ,
model_attributes ,
patched_fields )
else :
raise ValueError (
' Model types " {} " and " {} " do not match ' . format (