diff --git a/tildes/tildes/__init__.py b/tildes/tildes/__init__.py index 7be2000..aaa7300 100644 --- a/tildes/tildes/__init__.py +++ b/tildes/tildes/__init__.py @@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, Optional from paste.deploy.config import PrefixMiddleware from pyramid.config import Configurator +from pyramid.httpexceptions import HTTPTooManyRequests from pyramid.registry import Registry from pyramid.request import Request from redis import StrictRedis @@ -50,6 +51,7 @@ def main(global_config: Dict[str, str], **settings: str) -> PrefixMiddleware: # pylint: enable=unnecessary-lambda config.add_request_method(check_rate_limit, "check_rate_limit") + config.add_request_method(apply_rate_limit, "apply_rate_limit") config.add_request_method(current_listing_base_url, "current_listing_base_url") config.add_request_method(current_listing_normal_url, "current_listing_normal_url") @@ -120,6 +122,13 @@ def check_rate_limit(request: Request, action_name: str) -> RateLimitResult: return RateLimitResult.merged_result(results) +def apply_rate_limit(request: Request, action_name: str) -> None: + """Check the rate limit for an action, and raise HTTP 429 if it's exceeded.""" + result = request.check_rate_limit(action_name) + if not result.is_allowed: + raise result.add_headers_to_response(HTTPTooManyRequests()) + + def current_listing_base_url( request: Request, query: Optional[Dict[str, Any]] = None ) -> str: diff --git a/tildes/tildes/lib/ratelimit.py b/tildes/tildes/lib/ratelimit.py index 96661b7..7046736 100644 --- a/tildes/tildes/lib/ratelimit.py +++ b/tildes/tildes/lib/ratelimit.py @@ -281,6 +281,8 @@ _RATE_LIMITED_ACTIONS = ( RateLimitedAction("login", timedelta(hours=1), 20), RateLimitedAction("login_two_factor", timedelta(hours=1), 20), RateLimitedAction("register", timedelta(hours=1), 50), + RateLimitedAction("topic_post", timedelta(hours=1), 6, max_burst=4), + RateLimitedAction("comment_post", timedelta(hours=1), 30, max_burst=20), ) # (public) dict to be able to look up the actions by name diff --git a/tildes/tildes/views/api/web/comment.py b/tildes/tildes/views/api/web/comment.py index 84ee5f5..e17fc41 100644 --- a/tildes/tildes/views/api/web/comment.py +++ b/tildes/tildes/views/api/web/comment.py @@ -16,7 +16,7 @@ from tildes.models.comment import Comment, CommentNotification, CommentTag, Comm from tildes.models.topic import TopicVisit from tildes.schemas.comment import CommentSchema, CommentTagSchema from tildes.views import IC_NOOP -from tildes.views.decorators import ic_view_config +from tildes.views.decorators import ic_view_config, rate_limit_view def _increment_topic_comments_seen(request: Request, comment: Comment) -> None: @@ -57,6 +57,7 @@ def _increment_topic_comments_seen(request: Request, comment: Comment) -> None: permission="comment", ) @use_kwargs(CommentSchema(only=("markdown",))) +@rate_limit_view("comment_post") def post_toplevel_comment(request: Request, markdown: str) -> dict: """Post a new top-level comment on a topic with Intercooler.""" topic = request.context @@ -90,6 +91,7 @@ def post_toplevel_comment(request: Request, markdown: str) -> dict: permission="reply", ) @use_kwargs(CommentSchema(only=("markdown",))) +@rate_limit_view("comment_post") def post_comment_reply(request: Request, markdown: str) -> dict: """Post a reply to a comment with Intercooler.""" parent_comment = request.context diff --git a/tildes/tildes/views/decorators.py b/tildes/tildes/views/decorators.py index e8b3042..c0b0a6f 100644 --- a/tildes/tildes/views/decorators.py +++ b/tildes/tildes/views/decorators.py @@ -2,7 +2,7 @@ from typing import Any, Callable -from pyramid.httpexceptions import HTTPFound, HTTPTooManyRequests +from pyramid.httpexceptions import HTTPFound from pyramid.request import Request from pyramid.view import view_config @@ -35,10 +35,8 @@ def rate_limit_view(action_name: str) -> Callable: def decorator(func: Callable) -> Callable: def wrapper(*args: Any, **kwargs: Any) -> Any: request = args[0] - result = request.check_rate_limit(action_name) - if not result.is_allowed: - raise result.add_headers_to_response(HTTPTooManyRequests()) + request.apply_rate_limit(action_name) return func(*args, **kwargs) diff --git a/tildes/tildes/views/topic.py b/tildes/tildes/views/topic.py index 46e4539..5156272 100644 --- a/tildes/tildes/views/topic.py +++ b/tildes/tildes/views/topic.py @@ -30,6 +30,7 @@ from tildes.schemas.comment import CommentSchema from tildes.schemas.fields import Enum, ShortTimePeriod from tildes.schemas.topic import TopicSchema from tildes.schemas.topic_listing import TopicListingSchema +from tildes.views.decorators import rate_limit_view DefaultSettings = namedtuple("DefaultSettings", ["order", "period"]) @@ -64,6 +65,8 @@ def post_group_topics( except ValidationError: raise ValidationError({"tags": ["Invalid tags"]}) + request.apply_rate_limit("topic_post") + request.db_session.add(new_topic) request.db_session.add(LogTopic(LogEventType.TOPIC_POST, request, new_topic)) @@ -225,6 +228,7 @@ def get_topic(request: Request, comment_order: CommentSortOption) -> dict: @view_config(route_name="topic", request_method="POST", permission="comment") @use_kwargs(CommentSchema(only=("markdown",))) +@rate_limit_view("comment_post") def post_comment_on_topic(request: Request, markdown: str) -> HTTPFound: """Post a new top-level comment on a topic.""" topic = request.context