diff --git a/server/atheneum/service/transformation_service.py b/server/atheneum/service/transformation_service.py index b86908a..774c0ad 100644 --- a/server/atheneum/service/transformation_service.py +++ b/server/atheneum/service/transformation_service.py @@ -71,9 +71,21 @@ _model_transformers: Dict[str, Type[BaseTransformer]] = {} def register_transformer( - model_name: str, model_serializer: Type[BaseTransformer]) -> None: + model_serializer: Type[BaseTransformer]) -> Type[BaseTransformer]: """Add a model to the serializer mapping.""" - _model_transformers[model_name] = model_serializer + model_name = model_serializer.type.__name__ + if model_name not in _model_transformers: + _model_transformers[model_name] = model_serializer + else: + raise ValueError( + ' '.join([ + 'A transformer for type "{}" already exists with class "{}".', + 'Cannot register a new transformer with class "{}"' + ]).format( + model_name, + _model_transformers[model_name].__name__, + model_serializer.__name__)) + return model_serializer def serialize_model(model_obj: db.Model, diff --git a/server/atheneum/service/user_service.py b/server/atheneum/service/user_service.py index b0a56a7..7fbc042 100644 --- a/server/atheneum/service/user_service.py +++ b/server/atheneum/service/user_service.py @@ -18,6 +18,7 @@ from atheneum.utility import authentication_utility LOGGER = logging.getLogger(__name__) +@register_transformer class UserTransformer(BaseTransformer): """Serialize User model.""" @@ -92,9 +93,6 @@ class UserTransformer(BaseTransformer): model.role = Role(role_value) -register_transformer(User.__name__, UserTransformer) - - def find_by_name(name: str) -> Optional[User]: """ Find a user by name. diff --git a/server/atheneum/service/user_token_service.py b/server/atheneum/service/user_token_service.py index 6b4b77f..6063479 100644 --- a/server/atheneum/service/user_token_service.py +++ b/server/atheneum/service/user_token_service.py @@ -11,6 +11,7 @@ from atheneum.service.transformation_service import ( ) +@register_transformer class UserTokenTransformer(BaseTransformer): """Serialize User model.""" @@ -108,9 +109,6 @@ class UserTokenTransformer(BaseTransformer): model.version = version -register_transformer(UserToken.__name__, UserTokenTransformer) - - def generate_token() -> uuid.UUID: """ Generate a unique token. diff --git a/server/tests/service/test_transformation_service.py b/server/tests/service/test_transformation_service.py index 28e2c1f..297cc17 100644 --- a/server/tests/service/test_transformation_service.py +++ b/server/tests/service/test_transformation_service.py @@ -1,8 +1,15 @@ +from typing import Dict, Callable, Any + +import pytest + +from atheneum import db from atheneum.model import UserToken from atheneum.service.transformation_service import ( - serialize_model, - deserialize_model -) + BaseTransformer, + deserialize_model, + register_transformer, + serialize_model) +from atheneum.service.user_token_service import UserTokenTransformer def test_serialize_model(): @@ -20,3 +27,26 @@ def test_deserialize_model(): result = deserialize_model('UserToken', user_token_json) assert result is not None assert result.token == 'test' + + +def test_registering_two_transformers_of_the_same_type(): + class BadTransformer(BaseTransformer): + type = UserToken + + def _serializers(self) -> Dict[str, Callable[[], Any]]: + pass + + def _deserializers(self) -> Dict[str, Callable[[db.Model, Any], None]]: + pass + + error = None + with pytest.raises(ValueError) as e_info: + error = e_info + register_transformer(BadTransformer) + + assert error is not None + error_msg = error.value.args[0] + assert error_msg is not None + assert UserToken.__name__ in error_msg + assert UserTokenTransformer.__name__ in error_msg + assert BadTransformer.__name__ in error_msg