From 399b89a9c71093cffd0da817275302adb6156301 Mon Sep 17 00:00:00 2001 From: Andrew Shu Date: Mon, 15 Sep 2025 11:39:24 -0700 Subject: [PATCH] Add type specification to Marshmallow Field classes Marshmallow 4.0 makes Field generic taking a type argument --- tildes/tildes/schemas/fields.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/tildes/tildes/schemas/fields.py b/tildes/tildes/schemas/fields.py index 48ff185..6510106 100644 --- a/tildes/tildes/schemas/fields.py +++ b/tildes/tildes/schemas/fields.py @@ -22,7 +22,7 @@ from tildes.lib.string import simplify_string DataType = Optional[Mapping[str, Any]] -class Enum(Field): +class Enum(Field[enum.Enum]): """Field for a native Python Enum (or subclasses).""" def __init__( @@ -34,9 +34,12 @@ class Enum(Field): self._enum_class = enum_class def _serialize( - self, value: enum.Enum, attr: str | None, obj: object, **kwargs: Any - ) -> str: + self, value: enum.Enum | None, attr: str | None, obj: object, **kwargs: Any + ) -> str | None: """Serialize the enum value - lowercase version of its name.""" + if value is None: + return None + return value.name.lower() def _deserialize( @@ -64,7 +67,7 @@ class ID36(String): super().__init__(validate=Regexp(ID36_REGEX), **kwargs) -class ShortTimePeriod(Field): +class ShortTimePeriod(Field[Optional[SimpleHoursPeriod]]): """Field for short time period strings like "4h" and "2d". Also supports the string "all" which will be converted to None. @@ -100,7 +103,7 @@ class ShortTimePeriod(Field): return value.as_short_form() -class Markdown(Field): +class Markdown(Field[str]): """Field for markdown strings (comments, text topic, messages, etc.).""" DEFAULT_MAX_LENGTH = 50000 @@ -132,13 +135,13 @@ class Markdown(Field): return value def _serialize( - self, value: str, attr: str | None, obj: object, **kwargs: Any - ) -> str: + self, value: str | None, attr: str | None, obj: object, **kwargs: Any + ) -> str | None: """Serialize the value (no-op in this case).""" return value -class SimpleString(Field): +class SimpleString(Field[str]): """Field for "simple" strings, suitable for uses like subject, title, etc. These strings should generally not contain any special formatting (such as @@ -169,13 +172,13 @@ class SimpleString(Field): return simplify_string(value) def _serialize( - self, value: str, attr: str | None, obj: object, **kwargs: Any - ) -> str: + self, value: str | None, attr: str | None, obj: object, **kwargs: Any + ) -> str | None: """Serialize the value (no-op in this case).""" return value -class Ltree(Field): +class Ltree(Field[sqlalchemy_utils.Ltree]): """Field for postgresql ltree type.""" # note that this regex only checks whether all of the chars are individually valid,