# mirror of https://gitlab.com/tildes/tildes.git
# You cannot select more than 25 topics
# Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
# 129 lines, 4.7 KiB
"""Contains the base DatabaseModel class."""
|
|
|
|
from typing import Any, Optional, Type, TypeVar
|
|
|
|
from marshmallow import Schema
|
|
from sqlalchemy import event
|
|
from sqlalchemy.dialects.postgresql import JSONB
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|
from sqlalchemy.ext.mutable import MutableDict
|
|
from sqlalchemy.schema import MetaData
|
|
from sqlalchemy.sql.schema import Table
|
|
|
|
|
|
ModelType = TypeVar("ModelType")  # pylint: disable=invalid-name

# Name templates handed to SQLAlchemy's MetaData so that generated constraints and
# indexes get explicit names; keys are SQLAlchemy's constraint/index type codes
NAMING_CONVENTION = {
    # primary key
    "pk": "pk_%(table_name)s",
    # foreign key
    "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
    # index
    "ix": "ix_%(table_name)s_%(column_0_name)s",
    # check constraint
    "ck": "ck_%(table_name)s_%(constraint_name)s",
    # unique constraint
    "uq": "uq_%(table_name)s_%(column_0_name)s",
}
|
|
|
|
|
|
def attach_set_listener(
    class_: Type["DatabaseModelBase"], attribute: str, instance: "DatabaseModelBase"
) -> None:
    """Register a "set" listener on an instrumented model attribute.

    The listener forwards every value assigned to `attribute` to the target
    model's `_validate_new_value` method and hands that method's return value
    back to SQLAlchemy (the listener is registered with retval=True).
    """
    # pylint: disable=unused-argument

    def validate_on_set(
        target: "DatabaseModelBase", value: Any, oldvalue: Any, initiator: Any
    ) -> Any:
        """Run the value being assigned through the target model's validation."""
        # pylint: disable=protected-access
        validated = target._validate_new_value(attribute, value)
        return validated

    event.listen(instance, "set", validate_on_set, retval=True)
|
|
|
|
|
|
class DatabaseModelBase:
    """Base class for models that will persist to the database.

    Provides equality and hashing based on primary key values, lazy access to a
    "partial" instance of the model's marshmallow schema, and the validation hook
    that is applied to values assigned to the model's attributes.
    """

    # declare the type of __table__ so mypy understands it when checking __eq__
    __table__: Table

    # marshmallow schema class used to validate assigned values; while this is
    # left as None, no validation is performed and `schema` is unavailable
    schema_class: Optional[Type[Schema]] = None

    def __eq__(self, other: Any) -> bool:
        """Equality comparison method - check if primary key values match.

        Returns NotImplemented if `other` is not an instance of this object's
        class, so Python can try the reflected comparison instead.
        """
        if not isinstance(other, self.__class__):
            return NotImplemented

        # the objects are equal unless any column in the primary key differs
        return not any(
            getattr(self, column.name) != getattr(other, column.name)
            for column in self.__table__.primary_key
        )

    def __hash__(self) -> int:
        """Return the hash value of the model.

        This is implemented by mixing together the hash values of the primary key
        columns used in __eq__, as recommended in the Python documentation.
        """
        primary_key_values = tuple(
            getattr(self, column.name) for column in self.__table__.primary_key
        )
        return hash(primary_key_values)

    @property
    def schema(self) -> Schema:
        """Return a "partial" instance of the model's schema.

        The schema instance is created on first access and then cached on the
        model instance as `_schema`.

        Raises AttributeError if the model has no `schema_class` set.
        """
        if not self.schema_class:
            # include the model's class name so a misconfigured model is easy to
            # identify from the error alone
            raise AttributeError(
                f"{type(self).__name__} has no schema_class defined"
            )

        if not hasattr(self, "_schema"):
            self._schema = self.schema_class(partial=True)  # noqa

        return self._schema

    def _validate_new_value(self, attribute: str, value: Any) -> Any:
        """Validate the new value for a column.

        This function will be attached to the SQLAlchemy ORM attribute event for "set"
        and will be called whenever a new value is assigned to any of a model's column
        attributes. It works by deserializing/loading the new value through the
        marshmallow schema associated with the model class (by its `schema_class`
        class attribute).

        The deserialization process can modify the value if desired (for sanitization),
        or raise an exception which will prevent the assignment from happening at all.

        Note that if the model has no schema_class, the schema does not have a Field
        defined for the column, or the Field is declared dump_only, no
        validation/sanitization will be applied and the value is returned unchanged.
        """
        if not self.schema_class:
            return value

        # This is a bit "magic", but simplifies the interaction between this validation
        # and SQLAlchemy hybrid properties. If the attribute being set starts with an
        # underscore, assume that it's due to being set up as a hybrid property, and
        # remove the underscore prefix when looking for a field to validate against.
        if attribute.startswith("_"):
            attribute = attribute[1:]

        # go through the property only once and reuse the instance below
        schema = self.schema
        field = schema.fields.get(attribute)

        # compare against None explicitly: a missing field is the case being checked
        # for, not the truthiness of whatever field object the schema holds
        if field is None or field.dump_only:
            return value

        result = schema.load({attribute: value})
        return result.data[attribute]
|
|
|
|
|
|
# Declarative base class for the models: built on DatabaseModelBase, with metadata
# that applies NAMING_CONVENTION to generated constraint and index names
DatabaseModel = declarative_base(  # pylint: disable=invalid-name
    cls=DatabaseModelBase,
    name="DatabaseModel",
    metadata=MetaData(naming_convention=NAMING_CONVENTION),
)

# whenever an attribute is instrumented on a model, attach the "set" listener to it
# so that values assigned to the attribute go through the model's validation
event.listen(DatabaseModel, "attribute_instrument", attach_set_listener)

# associate JSONB columns with MutableDict so value changes are detected
MutableDict.associate_with(JSONB)
|