diff --git a/tildes/tildes/resources/group.py b/tildes/tildes/resources/group.py index 3177b47..b95f964 100644 --- a/tildes/tildes/resources/group.py +++ b/tildes/tildes/resources/group.py @@ -10,7 +10,6 @@ from sqlalchemy_utils import Ltree from tildes.models.group import Group, GroupWikiPage from tildes.resources import get_resource -from tildes.schemas.context import TildesSchemaContext, TildesSchemaContextDict from tildes.schemas.group import GroupSchema from tildes.views.decorators import use_kwargs diff --git a/tildes/tildes/schemas/base.py b/tildes/tildes/schemas/base.py index 4c741c4..1ab6b3c 100644 --- a/tildes/tildes/schemas/base.py +++ b/tildes/tildes/schemas/base.py @@ -4,7 +4,7 @@ """Base Marshmallow schema.""" -from typing import Any +from typing import Any, Optional from marshmallow import Schema from tildes.schemas.context import TildesSchemaContext, TildesSchemaContextDict @@ -17,9 +17,21 @@ class BaseTildesSchema(Schema): context: TildesSchemaContextDict - def __init__(self, context: TildesSchemaContextDict = {}, **kwargs: Any): + def __init__( + self, context: Optional[TildesSchemaContextDict] = None, **kwargs: Any + ): + """Pass an optional context, and forward Schema arguments to superclass.""" super().__init__(**kwargs) - self.context = context + self.context = context if context else {} def get_context_value(self, key: str) -> Any: - return TildesSchemaContext.get(default=self.context).get(key) + """Get a value from the context dict. + + Any active TildesSchemaContext, e.g. set using a "with" statement, + takes precedence. If there is no active TildesSchemaContext, then + it takes the value from the dict passed in __init__ instead. + """ + result = TildesSchemaContext.get(default=self.context).get(key) + if result: + return result + return self.context.get(key) diff --git a/tildes/tildes/schemas/group.py b/tildes/tildes/schemas/group.py index 1669595..53575b8 100644 --- a/tildes/tildes/schemas/group.py +++ b/tildes/tildes/schemas/group.py @@ -13,7 +13,6 @@ from marshmallow.fields import DateTime from marshmallow.types import UnknownOption from tildes.schemas.base import BaseTildesSchema -from tildes.schemas.context import TildesSchemaContext, TildesSchemaContextDict from tildes.schemas.fields import Ltree, Markdown, SimpleString diff --git a/tildes/tildes/schemas/user.py b/tildes/tildes/schemas/user.py index 133c57e..f73b7a8 100644 --- a/tildes/tildes/schemas/user.py +++ b/tildes/tildes/schemas/user.py @@ -14,7 +14,6 @@ from marshmallow.validate import Length, Regexp from tildes.lib.password import is_breached_password from tildes.schemas.base import BaseTildesSchema -from tildes.schemas.context import TildesSchemaContext, TildesSchemaContextDict from tildes.schemas.fields import Markdown diff --git a/tildes/tildes/views/decorators.py b/tildes/tildes/views/decorators.py index 15055ac..b3637fa 100644 --- a/tildes/tildes/views/decorators.py +++ b/tildes/tildes/views/decorators.py @@ -7,7 +7,6 @@ from collections.abc import Callable from typing import Any from marshmallow import EXCLUDE -from marshmallow.experimental.context import Context from marshmallow.fields import Field from marshmallow.schema import Schema from pyramid.httpexceptions import HTTPFound @@ -15,8 +14,6 @@ from pyramid.request import Request from pyramid.view import view_config from webargs import pyramidparser -from tildes.schemas.context import TildesSchemaContext, TildesSchemaContextDict - def use_kwargs( argmap: Schema | dict[str, Field], location: str = "query", **kwargs: Any diff --git a/tildes/tildes/views/login.py b/tildes/tildes/views/login.py index a926fc5..ef68091 100644 --- a/tildes/tildes/views/login.py +++ b/tildes/tildes/views/login.py @@ -19,7 +19,6 @@ from tildes.enums import LogEventType from tildes.metrics import incr_counter from tildes.models.log import Log from tildes.models.user import User -from tildes.schemas.context import TildesSchemaContext, TildesSchemaContextDict from tildes.schemas.user import UserSchema from tildes.views.decorators import not_logged_in, rate_limit_view, use_kwargs