diff --git a/tildes/tests/test_markdown_field.py b/tildes/tests/test_markdown_field.py index 753f84c..16b6c09 100644 --- a/tildes/tests/test_markdown_field.py +++ b/tildes/tests/test_markdown_field.py @@ -1,13 +1,14 @@ # Copyright (c) 2018 Tildes contributors # SPDX-License-Identifier: AGPL-3.0-or-later -from marshmallow import Schema, ValidationError +from marshmallow import ValidationError from pytest import raises +from tildes.schemas.base import BaseTildesSchema from tildes.schemas.fields import Markdown -class MarkdownFieldTestSchema(Schema): +class MarkdownFieldTestSchema(BaseTildesSchema): """Simple schema class with a standard Markdown field.""" markdown = Markdown() diff --git a/tildes/tests/test_simplestring_field.py b/tildes/tests/test_simplestring_field.py index 651392e..0a29fe6 100644 --- a/tildes/tests/test_simplestring_field.py +++ b/tildes/tests/test_simplestring_field.py @@ -1,13 +1,14 @@ # Copyright (c) 2018 Tildes contributors # SPDX-License-Identifier: AGPL-3.0-or-later -from marshmallow import Schema, ValidationError +from marshmallow import ValidationError from pytest import raises +from tildes.schemas.base import BaseTildesSchema from tildes.schemas.fields import SimpleString -class SimpleStringTestSchema(Schema): +class SimpleStringTestSchema(BaseTildesSchema): """Simple schema class with a standard SimpleString field.""" subject = SimpleString() diff --git a/tildes/tildes/resources/group.py b/tildes/tildes/resources/group.py index e39447b..3177b47 100644 --- a/tildes/tildes/resources/group.py +++ b/tildes/tildes/resources/group.py @@ -16,8 +16,7 @@ from tildes.views.decorators import use_kwargs @use_kwargs( - GroupSchema(only=("path",)), - context=TildesSchemaContext(TildesSchemaContextDict(fix_path_capitalization=True)), + GroupSchema(only=("path",), context={"fix_path_capitalization": True}), location="matchdict", ) def group_by_path(request: Request, path: str) -> Group: diff --git a/tildes/tildes/schemas/base.py b/tildes/tildes/schemas/base.py new file mode 100644 index 0000000..48c967f --- /dev/null +++ b/tildes/tildes/schemas/base.py @@ -0,0 +1,22 @@ +# Copyright (c) 2018 Tildes contributors +# SPDX-License-Identifier: AGPL-3.0-or-later + +"""Base Marshmallow schema.""" + + +from typing import Any +from marshmallow import Schema +from tildes.schemas.context import TildesSchemaContextDict + + +class BaseTildesSchema(Schema): + """Base Marshmallow schema for Tildes schemas. + + Adds common code like the context dict. + """ + + context: TildesSchemaContextDict + + def __init__(self, context: TildesSchemaContextDict = {}, **kwargs: Any): + super().__init__(**kwargs) + self.context = context diff --git a/tildes/tildes/schemas/comment.py b/tildes/tildes/schemas/comment.py index 307d897..550b506 100644 --- a/tildes/tildes/schemas/comment.py +++ b/tildes/tildes/schemas/comment.py @@ -3,13 +3,12 @@ """Validation/dumping schema for comments.""" -from marshmallow import Schema - from tildes.enums import CommentLabelOption +from tildes.schemas.base import BaseTildesSchema from tildes.schemas.fields import Enum, ID36, Markdown, SimpleString -class CommentSchema(Schema): +class CommentSchema(BaseTildesSchema): """Marshmallow schema for comments.""" comment_id36 = ID36() @@ -17,7 +16,7 @@ class CommentSchema(Schema): parent_comment_id36 = ID36() -class CommentLabelSchema(Schema): +class CommentLabelSchema(BaseTildesSchema): """Marshmallow schema for comment labels.""" name = Enum(CommentLabelOption) diff --git a/tildes/tildes/schemas/group.py b/tildes/tildes/schemas/group.py index cecfd91..9e36cbf 100644 --- a/tildes/tildes/schemas/group.py +++ b/tildes/tildes/schemas/group.py @@ -7,11 +7,12 @@ import re from typing import Any import sqlalchemy_utils -from marshmallow import pre_load, Schema, validates +from marshmallow import pre_load, validates from marshmallow.exceptions import ValidationError from marshmallow.fields import DateTime from marshmallow.types import UnknownOption +from tildes.schemas.base import BaseTildesSchema from tildes.schemas.context import TildesSchemaContext, TildesSchemaContextDict from tildes.schemas.fields import Ltree, Markdown, SimpleString @@ -32,7 +33,7 @@ GROUP_PATH_ELEMENT_VALID_REGEX = re.compile( SHORT_DESCRIPTION_MAX_LENGTH = 200 -class GroupSchema(Schema): +class GroupSchema(BaseTildesSchema): """Marshmallow schema for groups.""" path = Ltree(required=True) @@ -48,7 +49,7 @@ class GroupSchema(Schema): ) -> dict: """Prepare the path value before it's validated.""" # pylint: disable=unused-argument - if not TildesSchemaContext.get(default=TildesSchemaContextDict()).get( + if not TildesSchemaContext.get(default=self.context).get( "fix_path_capitalization" ): return data diff --git a/tildes/tildes/schemas/group_wiki_page.py b/tildes/tildes/schemas/group_wiki_page.py index ea4229f..31a849d 100644 --- a/tildes/tildes/schemas/group_wiki_page.py +++ b/tildes/tildes/schemas/group_wiki_page.py @@ -3,15 +3,14 @@ """Validation/dumping schema for group wiki pages.""" -from marshmallow import Schema - +from tildes.schemas.base import BaseTildesSchema from tildes.schemas.fields import Markdown, SimpleString PAGE_NAME_MAX_LENGTH = 40 -class GroupWikiPageSchema(Schema): +class GroupWikiPageSchema(BaseTildesSchema): """Marshmallow schema for group wiki pages.""" page_name = SimpleString(max_length=PAGE_NAME_MAX_LENGTH) diff --git a/tildes/tildes/schemas/listing.py b/tildes/tildes/schemas/listing.py index 1c2a8e7..aeb1305 100644 --- a/tildes/tildes/schemas/listing.py +++ b/tildes/tildes/schemas/listing.py @@ -5,16 +5,17 @@ from typing import Any -from marshmallow import pre_load, Schema, validates_schema, ValidationError +from marshmallow import pre_load, validates_schema, ValidationError from marshmallow.fields import Boolean, Integer from marshmallow.types import UnknownOption from marshmallow.validate import Range from tildes.enums import TopicSortOption +from tildes.schemas.base import BaseTildesSchema from tildes.schemas.fields import Enum, ID36, Ltree, PostType, ShortTimePeriod -class PaginatedListingSchema(Schema): +class PaginatedListingSchema(BaseTildesSchema): """Marshmallow schema to validate arguments for a paginated listing page.""" after = ID36(load_default=None) diff --git a/tildes/tildes/schemas/message.py b/tildes/tildes/schemas/message.py index 2f16075..c82ccaa 100644 --- a/tildes/tildes/schemas/message.py +++ b/tildes/tildes/schemas/message.py @@ -3,16 +3,16 @@ """Validation/dumping schemas for messages.""" -from marshmallow import Schema from marshmallow.fields import DateTime, String +from tildes.schemas.base import BaseTildesSchema from tildes.schemas.fields import ID36, Markdown, SimpleString SUBJECT_MAX_LENGTH = 200 -class MessageConversationSchema(Schema): +class MessageConversationSchema(BaseTildesSchema): """Marshmallow schema for message conversations.""" conversation_id36 = ID36() @@ -22,7 +22,7 @@ class MessageConversationSchema(Schema): created_time = DateTime(dump_only=True) -class MessageReplySchema(Schema): +class MessageReplySchema(BaseTildesSchema): """Marshmallow schema for message replies.""" reply_id36 = ID36() diff --git a/tildes/tildes/schemas/topic.py b/tildes/tildes/schemas/topic.py index f133fb7..1cd52ab 100644 --- a/tildes/tildes/schemas/topic.py +++ b/tildes/tildes/schemas/topic.py @@ -7,11 +7,12 @@ import re from typing import Any from urllib.parse import urlparse -from marshmallow import pre_load, Schema, validates, validates_schema, ValidationError +from marshmallow import pre_load, validates, validates_schema, ValidationError from marshmallow.fields import DateTime, List, Nested, String, URL from marshmallow.types import UnknownOption from tildes.lib.url_transform import apply_url_transformations +from tildes.schemas.base import BaseTildesSchema from tildes.schemas.fields import Enum, ID36, Markdown, SimpleString from tildes.schemas.group import GroupSchema from tildes.schemas.user import UserSchema @@ -21,7 +22,7 @@ TITLE_MAX_LENGTH = 200 TAG_SYNONYMS = {"spoiler": ["spoilers"]} -class TopicSchema(Schema): +class TopicSchema(BaseTildesSchema): """Marshmallow schema for topics.""" topic_id36 = ID36() diff --git a/tildes/tildes/schemas/user.py b/tildes/tildes/schemas/user.py index 31bdc6f..7811afb 100644 --- a/tildes/tildes/schemas/user.py +++ b/tildes/tildes/schemas/user.py @@ -6,13 +6,14 @@ import re from typing import Any -from marshmallow import post_dump, pre_load, Schema, validates, validates_schema +from marshmallow import post_dump, pre_load, validates, validates_schema from marshmallow.exceptions import ValidationError from marshmallow.fields import DateTime, Email, String from marshmallow.types import UnknownOption from marshmallow.validate import Length, Regexp from tildes.lib.password import is_breached_password +from tildes.schemas.base import BaseTildesSchema from tildes.schemas.context import TildesSchemaContext, TildesSchemaContextDict from tildes.schemas.fields import Markdown @@ -43,7 +44,7 @@ EMAIL_ADDRESS_NOTE_MAX_LENGTH = 100 BIO_MAX_LENGTH = 2000 -class UserSchema(Schema): +class UserSchema(BaseTildesSchema): """Marshmallow schema for users.""" username = String( @@ -65,9 +66,7 @@ class UserSchema(Schema): def anonymize_username(self, data: dict, many: bool) -> dict: """Hide the username if the dumping context specifies to do so.""" # pylint: disable=unused-argument - if not TildesSchemaContext.get(default=TildesSchemaContextDict()).get( - "hide_username" - ): + if not TildesSchemaContext.get(default=self.context).get("hide_username"): return data if "username" not in data: @@ -106,7 +105,7 @@ class UserSchema(Schema): Requires check_breached_passwords be True in the schema's context. """ # pylint: disable=unused-argument - if not TildesSchemaContext.get(default=TildesSchemaContextDict()).get( + if not TildesSchemaContext.get(default=self.context).get( "check_breached_passwords" ): return @@ -126,7 +125,7 @@ class UserSchema(Schema): Requires username_trim_whitespace be True in the schema's context. """ # pylint: disable=unused-argument - if not TildesSchemaContext.get(default=TildesSchemaContextDict()).get( + if not TildesSchemaContext.get(default=self.context).get( "username_trim_whitespace" ): return data diff --git a/tildes/tildes/views/decorators.py b/tildes/tildes/views/decorators.py index bd4d880..15055ac 100644 --- a/tildes/tildes/views/decorators.py +++ b/tildes/tildes/views/decorators.py @@ -19,10 +19,7 @@ from tildes.schemas.context import TildesSchemaContext, TildesSchemaContextDict def use_kwargs( - argmap: Schema | dict[str, Field], - location: str = "query", - context: Context[Any] | None = None, - **kwargs: Any + argmap: Schema | dict[str, Field], location: str = "query", **kwargs: Any ) -> Callable: """Wrap the webargs @use_kwargs decorator with preferred default modifications. @@ -34,21 +31,15 @@ def use_kwargs( it just ignores them, instead of erroring when there's unexpected data (as there almost always is, especially because of Intercooler). """ - if context is None: - context = TildesSchemaContext(TildesSchemaContextDict()) + # convert a dict argmap to a Schema (the same way webargs would on its own) + if isinstance(argmap, dict): + argmap = Schema.from_dict(argmap)() - with context: - # convert a dict argmap to a Schema (the same way webargs would on its own) - if isinstance(argmap, dict): - argmap = Schema.from_dict(argmap)() + assert isinstance(argmap, Schema) # tell mypy the type is more restricted now - assert isinstance(argmap, Schema) # tell mypy the type is more restricted now + argmap.unknown = EXCLUDE - argmap.unknown = EXCLUDE - - return pyramidparser.use_kwargs( - argmap, location=location, unknown=None, **kwargs - ) + return pyramidparser.use_kwargs(argmap, location=location, unknown=None, **kwargs) def ic_view_config(**kwargs: Any) -> Callable: diff --git a/tildes/tildes/views/login.py b/tildes/tildes/views/login.py index e7ce532..a926fc5 100644 --- a/tildes/tildes/views/login.py +++ b/tildes/tildes/views/login.py @@ -61,8 +61,9 @@ def finish_login(request: Request, user: User, redirect_url: str) -> HTTPFound: route_name="login", request_method="POST", permission=NO_PERMISSION_REQUIRED ) @use_kwargs( - UserSchema(only=("username", "password")), - context=TildesSchemaContext(TildesSchemaContextDict(username_trim_whitespace=True)), + UserSchema( + only=("username", "password"), context={"username_trim_whitespace": True} + ), location="form", ) @use_kwargs({"from_url": String(load_default="")}, location="form") diff --git a/tildes/tildes/views/register.py b/tildes/tildes/views/register.py index d1d4e3b..4a44d88 100644 --- a/tildes/tildes/views/register.py +++ b/tildes/tildes/views/register.py @@ -15,7 +15,6 @@ from tildes.metrics import incr_counter from tildes.models.group import Group, GroupSubscription from tildes.models.log import Log from tildes.models.user import User, UserInviteCode -from tildes.schemas.context import TildesSchemaContext, TildesSchemaContextDict from tildes.schemas.user import UserSchema from tildes.views.decorators import not_logged_in, rate_limit_view, use_kwargs @@ -35,8 +34,9 @@ def get_register(request: Request, code: str) -> dict: route_name="register", request_method="POST", permission=NO_PERMISSION_REQUIRED ) @use_kwargs( - UserSchema(only=("username", "password")), - context=TildesSchemaContext(TildesSchemaContextDict(check_breached_passwords=True)), + UserSchema( + only=("username", "password"), context={"check_breached_passwords": True} + ), location="form", ) @use_kwargs(