You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
113 lines
3.9 KiB
113 lines
3.9 KiB
"""Handle Model Serialization."""
|
|
import logging
|
|
from typing import Dict, Callable, Any, List, Optional, Type
|
|
|
|
from atheneum import errors
|
|
from atheneum.db import db
|
|
|
|
# Module-level logger; BaseTransformer.deserialize uses it to report fields
# it could not transform.
LOGGER = logging.getLogger(__name__)
|
|
|
|
|
|
class BaseTransformer:
    """Base Model serializer.

    Subclasses set ``type`` to the model class they handle and implement
    ``_serializers`` / ``_deserializers``, which map field names to the
    callables that read or write that field.
    """

    # Model class this transformer handles; register_transformer uses its
    # ``__name__`` as the registry key.
    type: Type[db.Model]

    def __init__(self, model: db.Model) -> None:
        """Initialize the base serializer.

        :param model: The model *instance* to serialize from or deserialize
            into (callers in this module pass instances, not classes).
        """
        # Not read by this class itself; presumably for subclasses — verify.
        self._fields: Dict[str, Callable[[db.Model], Any]] = {}
        self.model = model

    @staticmethod
    def _selected_keys(options: Optional[List[str]],
                       field_factories: Dict[str, Any]) -> List[str]:
        """Return the field names to process, validating every one first.

        A falsy ``options`` (``None`` or an empty list) selects every field.

        :raises errors.ValidationError: For the first name in ``options``
            that is not a key of ``field_factories``.
        """
        if not options:
            return list(field_factories.keys())
        for key in options:
            if key not in field_factories:
                raise errors.ValidationError(
                    'Invalid key: %r. Valid keys: %r.' % (
                        key, sorted(field_factories.keys())))
        return options

    def serialize(self, options: Optional[List[str]]) -> Dict[str, Any]:
        """Convert Model field and factories to dicts.

        :param options: Field names to include; ``None`` or empty means all
            fields returned by ``_serializers``.
        :return: Mapping of field name to serialized value. Fields whose
            factory returns ``None`` are omitted.
        :raises errors.ValidationError: If ``options`` names an unknown
            field. Validation happens before any factory is called.
        """
        field_factories = self._serializers()
        keys = self._selected_keys(options, field_factories)
        ret = {}
        for key in keys:
            val = field_factories[key]()
            if val is not None:
                ret[key] = val
        return ret

    def deserialize(self,
                    json_model: dict,
                    options: Optional[List[str]]) -> db.Model:
        """Convert dict to Model.

        Every requested key is validated before any field is written, so an
        invalid key leaves ``self.model`` untouched.

        :param json_model: Source mapping of field name to raw value.
        :param options: Field names to apply; ``None`` or empty means all
            fields returned by ``_deserializers``.
        :return: ``self.model`` after the selected fields were applied.
        :raises errors.ValidationError: If ``options`` names an unknown
            field.
        """
        field_factories = self._deserializers()
        keys = self._selected_keys(options, field_factories)
        for key in keys:
            factory = field_factories[key]
            try:
                value = json_model[key]
                # None values are skipped rather than written to the model.
                if value is not None:
                    factory(self.model, value)
            except KeyError as key_error:
                # A key absent from json_model (or a KeyError raised by the
                # factory) is logged and skipped instead of aborting.
                LOGGER.error(
                    'Unable to transform field: %s %s', key, key_error)
        return self.model

    def _serializers(self) -> Dict[str, Callable[[], Any]]:
        """Field definitions.

        Subclasses return a mapping of field name to a zero-argument
        callable producing that field's serialized value.
        """
        raise NotImplementedError()

    def _deserializers(
            self) -> Dict[str, Callable[[db.Model, Any], None]]:
        """Field definitions.

        Subclasses return a mapping of field name to a callable that
        writes a raw value onto the given model.
        """
        raise NotImplementedError()
|
|
|
|
|
|
# Registry of transformer classes keyed by the model class's ``__name__``.
# register_transformer populates it; serialize_model and deserialize_model
# read from it. Because the key is the bare class name, two models in
# different modules that share a name would collide here.
_model_transformers: Dict[str, Type[BaseTransformer]] = {}
|
|
|
|
|
|
def register_transformer(
        model_serializer: Type[BaseTransformer]) -> Type[BaseTransformer]:
    """Record ``model_serializer`` as the transformer for its model type.

    The registry key is ``model_serializer.type.__name__``. The class is
    returned unchanged, so this function can be applied as a class
    decorator.

    :param model_serializer: ``BaseTransformer`` subclass whose ``type``
        attribute names the model it handles.
    :return: ``model_serializer`` itself.
    :raises KeyError: If a transformer is already registered under the same
        model class name.
    """
    model_name = model_serializer.type.__name__

    # Guard clause: refuse to silently replace an existing registration.
    if model_name in _model_transformers:
        template = (
            'A transformer for type "{}" already exists with class "{}". '
            'Cannot register a new transformer with class "{}"')
        previous = _model_transformers[model_name]
        raise KeyError(template.format(
            model_name, previous.__name__, model_serializer.__name__))

    _model_transformers[model_name] = model_serializer
    return model_serializer
|
|
|
|
|
|
def serialize_model(model_obj: db.Model,
                    options: Optional[List[str]] = None) -> Any:
    """Lookup a Model and hand off to the serializer.

    :param model_obj: Model instance to serialize; its class name selects
        the registered transformer.
    :param options: Field names to include; ``None`` or empty means all.
    :return: Whatever the transformer's ``serialize`` returns.
    :raises NotImplementedError: If no transformer is registered for the
        model's class name.
    """
    # Instances generally have no ``__name__``; take it from the class.
    model_name = type(model_obj).__name__
    # Only the registry lookup is guarded, so KeyErrors raised during
    # serialization are not misreported as a missing registration.
    try:
        transformer = _model_transformers[model_name]
    except KeyError:
        raise NotImplementedError(
            '{} has no registered serializers'.format(model_name)) from None
    return transformer(model_obj).serialize(options)
|
|
|
|
|
|
def deserialize_model(
        model_type: Type[db.Model],
        json_model_object: dict,
        options: Optional[List[str]] = None) -> db.Model:
    """Lookup a Model and hand it off to the deserializer.

    A fresh instance is built via the registered transformer's ``type()``
    and populated from ``json_model_object``.

    :param model_type: Model class; its ``__name__`` selects the
        registered transformer.
    :param json_model_object: Mapping of field name to raw value.
    :param options: Field names to apply; ``None`` or empty means all.
    :return: The populated model instance.
    :raises NotImplementedError: If no transformer is registered for
        ``model_type``.
    """
    # Only the registry lookup is guarded; KeyErrors from building or
    # populating the model are not reported as a missing registration.
    try:
        transformer = _model_transformers[model_type.__name__]
    except KeyError:
        raise NotImplementedError(
            '{} has no registered serializers'.format(model_type)) from None
    return transformer(
        transformer.type()).deserialize(json_model_object, options)
|