diff --git a/tildes/tildes/json.py b/tildes/tildes/json.py index 57313f3..7ec8a0b 100644 --- a/tildes/tildes/json.py +++ b/tildes/tildes/json.py @@ -11,6 +11,7 @@ from tildes.models import DatabaseModel from tildes.models.group import Group from tildes.models.topic import Topic from tildes.models.user import User +from tildes.schemas.context import TildesSchemaContext, TildesContext def serialize_model(model_item: DatabaseModel, request: Request) -> dict: @@ -25,11 +26,12 @@ def serialize_model(model_item: DatabaseModel, request: Request) -> dict: def serialize_topic(topic: Topic, request: Request) -> dict: """Return serializable data for a Topic.""" - context = {} + context: TildesContext = {} if not request.has_permission("view_author", topic): context["hide_username"] = True - return topic.schema_class(context=context).dump(topic) + with TildesSchemaContext(context): + return topic.schema_class().dump(topic) def includeme(config: Configurator) -> None: diff --git a/tildes/tildes/resources/group.py b/tildes/tildes/resources/group.py index b95f964..5a4d8f6 100644 --- a/tildes/tildes/resources/group.py +++ b/tildes/tildes/resources/group.py @@ -10,12 +10,14 @@ from sqlalchemy_utils import Ltree from tildes.models.group import Group, GroupWikiPage from tildes.resources import get_resource +from tildes.schemas.context import TildesSchemaContext, TildesContext from tildes.schemas.group import GroupSchema from tildes.views.decorators import use_kwargs @use_kwargs( - GroupSchema(only=("path",), context={"fix_path_capitalization": True}), + GroupSchema(only=("path",)), + context=TildesSchemaContext(TildesContext(fix_path_capitalization=True)), location="matchdict", ) def group_by_path(request: Request, path: str) -> Group: diff --git a/tildes/tildes/schemas/context.py b/tildes/tildes/schemas/context.py new file mode 100644 index 0000000..495d900 --- /dev/null +++ b/tildes/tildes/schemas/context.py @@ -0,0 +1,30 @@ +# Copyright (c) 2018 Tildes contributors +# SPDX-License-Identifier: AGPL-3.0-or-later + +"""Context variables that can be used with Marshmallow schemas.""" +import typing + +from marshmallow.experimental.context import Context + + +class TildesContext(typing.TypedDict, total=False): + """Context for Tildes Marshmallow schemas. + + For convenience, we use one unified class instead of one per schema, + so it can be passed down through different schemas in a subgraph. + For example, if a Topic contains a reference to a User, + one instance of TildesContext can configure both the Topic and User. + """ + + # Applies to UserSchema + hide_username: bool + # Applies to UserSchema + check_breached_passwords: bool + # Applies to UserSchema + username_trim_whitespace: bool + + # Applies to GroupSchema + fix_path_capitalization: bool + + +TildesSchemaContext = Context[TildesContext] diff --git a/tildes/tildes/schemas/group.py b/tildes/tildes/schemas/group.py index 1659cb8..3d513b3 100644 --- a/tildes/tildes/schemas/group.py +++ b/tildes/tildes/schemas/group.py @@ -11,6 +11,7 @@ from marshmallow import pre_load, Schema, validates from marshmallow.exceptions import ValidationError from marshmallow.fields import DateTime +from tildes.schemas.context import TildesSchemaContext, TildesContext from tildes.schemas.fields import Ltree, Markdown, SimpleString @@ -44,7 +45,9 @@ class GroupSchema(Schema): def prepare_path(self, data: dict, many: bool, partial: Any) -> dict: """Prepare the path value before it's validated.""" # pylint: disable=unused-argument - if not self.context.get("fix_path_capitalization"): + if not TildesSchemaContext.get(default=TildesContext()).get( + "fix_path_capitalization" + ): return data if "path" not in data or not isinstance(data["path"], str): diff --git a/tildes/tildes/schemas/user.py b/tildes/tildes/schemas/user.py index 44bf55f..239cd75 100644 --- a/tildes/tildes/schemas/user.py +++ b/tildes/tildes/schemas/user.py @@ -12,6 +12,7 @@ from marshmallow.fields import DateTime, Email, String from marshmallow.validate import Length, Regexp from tildes.lib.password import is_breached_password +from tildes.schemas.context import TildesSchemaContext, TildesContext from tildes.schemas.fields import Markdown @@ -63,7 +64,7 @@ class UserSchema(Schema): def anonymize_username(self, data: dict, many: bool) -> dict: """Hide the username if the dumping context specifies to do so.""" # pylint: disable=unused-argument - if not self.context.get("hide_username"): + if not TildesSchemaContext.get(default=TildesContext()).get("hide_username"): return data if "username" not in data: @@ -101,7 +102,9 @@ class UserSchema(Schema): Requires check_breached_passwords be True in the schema's context. """ - if not self.context.get("check_breached_passwords"): + if not TildesSchemaContext.get(default=TildesContext()).get( + "check_breached_passwords" + ): return if is_breached_password(value): @@ -117,7 +120,9 @@ class UserSchema(Schema): Requires username_trim_whitespace be True in the schema's context. """ # pylint: disable=unused-argument - if not self.context.get("username_trim_whitespace"): + if not TildesSchemaContext.get(default=TildesContext()).get( + "username_trim_whitespace" + ): return data if "username" not in data: diff --git a/tildes/tildes/views/api/web/user.py b/tildes/tildes/views/api/web/user.py index 3ff8d87..a16a91a 100644 --- a/tildes/tildes/views/api/web/user.py +++ b/tildes/tildes/views/api/web/user.py @@ -23,6 +23,7 @@ from tildes.lib.datetime import SimpleHoursPeriod from tildes.lib.string import separate_string from tildes.models.log import Log from tildes.models.user import User, UserInviteCode +from tildes.schemas.context import TildesSchemaContext, TildesContext from tildes.schemas.fields import Enum, ShortTimePeriod from tildes.schemas.topic import TopicSchema from tildes.schemas.user import UserSchema @@ -54,12 +55,13 @@ def patch_change_password( user = request.context # enable checking the new password against the breached-passwords list - user.schema.context["check_breached_passwords"] = True + context: TildesContext = {"check_breached_passwords": True} if new_password != new_password_confirm: raise HTTPUnprocessableEntity("New password and confirmation do not match.") - user.change_password(old_password, new_password) + with TildesSchemaContext(context): + user.change_password(old_password, new_password) return Response("Your password has been updated") diff --git a/tildes/tildes/views/decorators.py b/tildes/tildes/views/decorators.py index 677d3b8..c6ad292 100644 --- a/tildes/tildes/views/decorators.py +++ b/tildes/tildes/views/decorators.py @@ -4,9 +4,10 @@ """Contains decorators for view functions.""" from collections.abc import Callable -from typing import Any, Union +from typing import Any from marshmallow import EXCLUDE +from marshmallow.experimental.context import Context from marshmallow.fields import Field from marshmallow.schema import Schema from pyramid.httpexceptions import HTTPFound @@ -14,9 +15,14 @@ from pyramid.request import Request from pyramid.view import view_config from webargs import pyramidparser +from tildes.schemas.context import TildesSchemaContext, TildesContext + def use_kwargs( - argmap: Union[Schema, dict[str, Field]], location: str = "query", **kwargs: Any + argmap: Schema | dict[str, Field], + location: str = "query", + context: Context[Any] | None = None, + **kwargs: Any ) -> Callable: """Wrap the webargs @use_kwargs decorator with preferred default modifications. @@ -28,15 +34,19 @@ def use_kwargs( it just ignores them, instead of erroring when there's unexpected data (as there almost always is, especially because of Intercooler). """ - # convert a dict argmap to a Schema (the same way webargs would on its own) - if isinstance(argmap, dict): - argmap = Schema.from_dict(argmap)() + if context is None: + context = TildesSchemaContext(TildesContext()) + + with context: + # convert a dict argmap to a Schema (the same way webargs would on its own) + if isinstance(argmap, dict): + argmap = Schema.from_dict(argmap)() - assert isinstance(argmap, Schema) # tell mypy the type is more restricted now + assert isinstance(argmap, Schema) # tell mypy the type is more restricted now - argmap.unknown = EXCLUDE + argmap.unknown = EXCLUDE - return pyramidparser.use_kwargs(argmap, location=location, **kwargs) + return pyramidparser.use_kwargs(argmap, location=location, **kwargs) def ic_view_config(**kwargs: Any) -> Callable: diff --git a/tildes/tildes/views/login.py b/tildes/tildes/views/login.py index ef68091..ba42336 100644 --- a/tildes/tildes/views/login.py +++ b/tildes/tildes/views/login.py @@ -19,6 +19,7 @@ from tildes.enums import LogEventType from tildes.metrics import incr_counter from tildes.models.log import Log from tildes.models.user import User +from tildes.schemas.context import TildesSchemaContext, TildesContext from tildes.schemas.user import UserSchema from tildes.views.decorators import not_logged_in, rate_limit_view, use_kwargs @@ -60,9 +61,8 @@ def finish_login(request: Request, user: User, redirect_url: str) -> HTTPFound: route_name="login", request_method="POST", permission=NO_PERMISSION_REQUIRED ) @use_kwargs( - UserSchema( - only=("username", "password"), context={"username_trim_whitespace": True} - ), + UserSchema(only=("username", "password")), + context=TildesSchemaContext(TildesContext(username_trim_whitespace=True)), location="form", ) @use_kwargs({"from_url": String(load_default="")}, location="form") diff --git a/tildes/tildes/views/register.py b/tildes/tildes/views/register.py index 4a44d88..527da50 100644 --- a/tildes/tildes/views/register.py +++ b/tildes/tildes/views/register.py @@ -15,6 +15,7 @@ from tildes.metrics import incr_counter from tildes.models.group import Group, GroupSubscription from tildes.models.log import Log from tildes.models.user import User, UserInviteCode +from tildes.schemas.context import TildesSchemaContext, TildesContext from tildes.schemas.user import UserSchema from tildes.views.decorators import not_logged_in, rate_limit_view, use_kwargs @@ -34,9 +35,8 @@ def get_register(request: Request, code: str) -> dict: route_name="register", request_method="POST", permission=NO_PERMISSION_REQUIRED ) @use_kwargs( - UserSchema( - only=("username", "password"), context={"check_breached_passwords": True} - ), + UserSchema(only=("username", "password")), + context=TildesSchemaContext(TildesContext(check_breached_passwords=True)), location="form", ) @use_kwargs( diff --git a/tildes/tildes/views/settings.py b/tildes/tildes/views/settings.py index 0af7577..42e0cdd 100644 --- a/tildes/tildes/views/settings.py +++ b/tildes/tildes/views/settings.py @@ -22,6 +22,7 @@ from tildes.models.comment import Comment, CommentLabel, CommentTree from tildes.models.group import Group from tildes.models.topic import Topic from tildes.models.user import User +from tildes.schemas.context import TildesContext, TildesSchemaContext from tildes.schemas.user import ( BIO_MAX_LENGTH, EMAIL_ADDRESS_NOTE_MAX_LENGTH, @@ -151,12 +152,13 @@ def post_settings_password_change( ) -> Response: """Change the logged-in user's password.""" # enable checking the new password against the breached-passwords list - request.user.schema.context["check_breached_passwords"] = True + context: TildesContext = {"check_breached_passwords": True} if new_password != new_password_confirm: raise HTTPUnprocessableEntity("New password and confirmation do not match.") - request.user.change_password(old_password, new_password) + with TildesSchemaContext(context): + request.user.change_password(old_password, new_password) return Response("Your password has been updated")