Browse Source

Migrate to Marshmallow 4.0 way of passing context to schemas

https://marshmallow.readthedocs.io/en/latest/upgrading.html#new-context-api
merge-requests/171/head
Andrew Shu 1 month ago
parent
commit
aededd4cd3
  1. 6
      tildes/tildes/json.py
  2. 4
      tildes/tildes/resources/group.py
  3. 30
      tildes/tildes/schemas/context.py
  4. 5
      tildes/tildes/schemas/group.py
  5. 11
      tildes/tildes/schemas/user.py
  6. 6
      tildes/tildes/views/api/web/user.py
  7. 26
      tildes/tildes/views/decorators.py
  8. 6
      tildes/tildes/views/login.py
  9. 6
      tildes/tildes/views/register.py
  10. 6
      tildes/tildes/views/settings.py

6
tildes/tildes/json.py

@ -11,6 +11,7 @@ from tildes.models import DatabaseModel
from tildes.models.group import Group
from tildes.models.topic import Topic
from tildes.models.user import User
from tildes.schemas.context import TildesSchemaContext, TildesContext
def serialize_model(model_item: DatabaseModel, request: Request) -> dict:
@ -25,11 +26,12 @@ def serialize_model(model_item: DatabaseModel, request: Request) -> dict:
def serialize_topic(topic: Topic, request: Request) -> dict:
"""Return serializable data for a Topic."""
context = {}
context: TildesContext = {}
if not request.has_permission("view_author", topic):
context["hide_username"] = True
return topic.schema_class(context=context).dump(topic)
with TildesSchemaContext(context):
return topic.schema_class().dump(topic)
def includeme(config: Configurator) -> None:

4
tildes/tildes/resources/group.py

@ -10,12 +10,14 @@ from sqlalchemy_utils import Ltree
from tildes.models.group import Group, GroupWikiPage
from tildes.resources import get_resource
from tildes.schemas.context import TildesSchemaContext, TildesContext
from tildes.schemas.group import GroupSchema
from tildes.views.decorators import use_kwargs
@use_kwargs(
GroupSchema(only=("path",), context={"fix_path_capitalization": True}),
GroupSchema(only=("path",)),
context=TildesSchemaContext(TildesContext(fix_path_capitalization=True)),
location="matchdict",
)
def group_by_path(request: Request, path: str) -> Group:

30
tildes/tildes/schemas/context.py

@ -0,0 +1,30 @@
# Copyright (c) 2018 Tildes contributors <code@tildes.net>
# SPDX-License-Identifier: AGPL-3.0-or-later
"""Context variables that can be used with Marshmallow schemas."""
import typing
from marshmallow.experimental.context import Context
class TildesContext(typing.TypedDict, total=False):
"""Context for Tildes Marshmallow schemas.
For convenience, we use one unified class instead of one per schema,
so it can be passed down through different schemas in a subgraph.
For example, if a Topic contains a reference to a User,
one instance of TildesContext can configure both the Topic and User.
"""
# Applies to UserSchema
hide_username: bool
# Applies to UserSchema
check_breached_passwords: bool
# Applies to UserSchema
username_trim_whitespace: bool
# Applies to GroupSchema
fix_path_capitalization: bool
TildesSchemaContext = Context[TildesContext]

5
tildes/tildes/schemas/group.py

@ -11,6 +11,7 @@ from marshmallow import pre_load, Schema, validates
from marshmallow.exceptions import ValidationError
from marshmallow.fields import DateTime
from tildes.schemas.context import TildesSchemaContext, TildesContext
from tildes.schemas.fields import Ltree, Markdown, SimpleString
@ -44,7 +45,9 @@ class GroupSchema(Schema):
def prepare_path(self, data: dict, many: bool, partial: Any) -> dict:
"""Prepare the path value before it's validated."""
# pylint: disable=unused-argument
if not self.context.get("fix_path_capitalization"):
if not TildesSchemaContext.get(default=TildesContext()).get(
"fix_path_capitalization"
):
return data
if "path" not in data or not isinstance(data["path"], str):

11
tildes/tildes/schemas/user.py

@ -12,6 +12,7 @@ from marshmallow.fields import DateTime, Email, String
from marshmallow.validate import Length, Regexp
from tildes.lib.password import is_breached_password
from tildes.schemas.context import TildesSchemaContext, TildesContext
from tildes.schemas.fields import Markdown
@ -63,7 +64,7 @@ class UserSchema(Schema):
def anonymize_username(self, data: dict, many: bool) -> dict:
"""Hide the username if the dumping context specifies to do so."""
# pylint: disable=unused-argument
if not self.context.get("hide_username"):
if not TildesSchemaContext.get(default=TildesContext()).get("hide_username"):
return data
if "username" not in data:
@ -101,7 +102,9 @@ class UserSchema(Schema):
Requires check_breached_passwords be True in the schema's context.
"""
if not self.context.get("check_breached_passwords"):
if not TildesSchemaContext.get(default=TildesContext()).get(
"check_breached_passwords"
):
return
if is_breached_password(value):
@ -117,7 +120,9 @@ class UserSchema(Schema):
Requires username_trim_whitespace be True in the schema's context.
"""
# pylint: disable=unused-argument
if not self.context.get("username_trim_whitespace"):
if not TildesSchemaContext.get(default=TildesContext()).get(
"username_trim_whitespace"
):
return data
if "username" not in data:

6
tildes/tildes/views/api/web/user.py

@ -23,6 +23,7 @@ from tildes.lib.datetime import SimpleHoursPeriod
from tildes.lib.string import separate_string
from tildes.models.log import Log
from tildes.models.user import User, UserInviteCode
from tildes.schemas.context import TildesSchemaContext, TildesContext
from tildes.schemas.fields import Enum, ShortTimePeriod
from tildes.schemas.topic import TopicSchema
from tildes.schemas.user import UserSchema
@ -54,12 +55,13 @@ def patch_change_password(
user = request.context
# enable checking the new password against the breached-passwords list
user.schema.context["check_breached_passwords"] = True
context: TildesContext = {"check_breached_passwords": True}
if new_password != new_password_confirm:
raise HTTPUnprocessableEntity("New password and confirmation do not match.")
user.change_password(old_password, new_password)
with TildesSchemaContext(context):
user.change_password(old_password, new_password)
return Response("Your password has been updated")

26
tildes/tildes/views/decorators.py

@ -4,9 +4,10 @@
"""Contains decorators for view functions."""
from collections.abc import Callable
from typing import Any, Union
from typing import Any
from marshmallow import EXCLUDE
from marshmallow.experimental.context import Context
from marshmallow.fields import Field
from marshmallow.schema import Schema
from pyramid.httpexceptions import HTTPFound
@ -14,9 +15,14 @@ from pyramid.request import Request
from pyramid.view import view_config
from webargs import pyramidparser
from tildes.schemas.context import TildesSchemaContext, TildesContext
def use_kwargs(
argmap: Union[Schema, dict[str, Field]], location: str = "query", **kwargs: Any
argmap: Schema | dict[str, Field],
location: str = "query",
context: Context[Any] | None = None,
**kwargs: Any
) -> Callable:
"""Wrap the webargs @use_kwargs decorator with preferred default modifications.
@ -28,15 +34,19 @@ def use_kwargs(
it just ignores them, instead of erroring when there's unexpected data (as there
almost always is, especially because of Intercooler).
"""
# convert a dict argmap to a Schema (the same way webargs would on its own)
if isinstance(argmap, dict):
argmap = Schema.from_dict(argmap)()
if context is None:
context = TildesSchemaContext(TildesContext())
with context:
# convert a dict argmap to a Schema (the same way webargs would on its own)
if isinstance(argmap, dict):
argmap = Schema.from_dict(argmap)()
assert isinstance(argmap, Schema) # tell mypy the type is more restricted now
assert isinstance(argmap, Schema) # tell mypy the type is more restricted now
argmap.unknown = EXCLUDE
argmap.unknown = EXCLUDE
return pyramidparser.use_kwargs(argmap, location=location, **kwargs)
return pyramidparser.use_kwargs(argmap, location=location, **kwargs)
def ic_view_config(**kwargs: Any) -> Callable:

6
tildes/tildes/views/login.py

@ -19,6 +19,7 @@ from tildes.enums import LogEventType
from tildes.metrics import incr_counter
from tildes.models.log import Log
from tildes.models.user import User
from tildes.schemas.context import TildesSchemaContext, TildesContext
from tildes.schemas.user import UserSchema
from tildes.views.decorators import not_logged_in, rate_limit_view, use_kwargs
@ -60,9 +61,8 @@ def finish_login(request: Request, user: User, redirect_url: str) -> HTTPFound:
route_name="login", request_method="POST", permission=NO_PERMISSION_REQUIRED
)
@use_kwargs(
UserSchema(
only=("username", "password"), context={"username_trim_whitespace": True}
),
UserSchema(only=("username", "password")),
context=TildesSchemaContext(TildesContext(username_trim_whitespace=True)),
location="form",
)
@use_kwargs({"from_url": String(load_default="")}, location="form")

6
tildes/tildes/views/register.py

@ -15,6 +15,7 @@ from tildes.metrics import incr_counter
from tildes.models.group import Group, GroupSubscription
from tildes.models.log import Log
from tildes.models.user import User, UserInviteCode
from tildes.schemas.context import TildesSchemaContext, TildesContext
from tildes.schemas.user import UserSchema
from tildes.views.decorators import not_logged_in, rate_limit_view, use_kwargs
@ -34,9 +35,8 @@ def get_register(request: Request, code: str) -> dict:
route_name="register", request_method="POST", permission=NO_PERMISSION_REQUIRED
)
@use_kwargs(
UserSchema(
only=("username", "password"), context={"check_breached_passwords": True}
),
UserSchema(only=("username", "password")),
context=TildesSchemaContext(TildesContext(check_breached_passwords=True)),
location="form",
)
@use_kwargs(

6
tildes/tildes/views/settings.py

@ -22,6 +22,7 @@ from tildes.models.comment import Comment, CommentLabel, CommentTree
from tildes.models.group import Group
from tildes.models.topic import Topic
from tildes.models.user import User
from tildes.schemas.context import TildesContext, TildesSchemaContext
from tildes.schemas.user import (
BIO_MAX_LENGTH,
EMAIL_ADDRESS_NOTE_MAX_LENGTH,
@ -151,12 +152,13 @@ def post_settings_password_change(
) -> Response:
"""Change the logged-in user's password."""
# enable checking the new password against the breached-passwords list
request.user.schema.context["check_breached_passwords"] = True
context: TildesContext = {"check_breached_passwords": True}
if new_password != new_password_confirm:
raise HTTPUnprocessableEntity("New password and confirmation do not match.")
request.user.change_password(old_password, new_password)
with TildesSchemaContext(context):
request.user.change_password(old_password, new_password)
return Response("Your password has been updated")

Loading…
Cancel
Save