diff --git a/tildes/alembic/versions/fa14e9f5ebe5_add_user_rate_limit_table.py b/tildes/alembic/versions/fa14e9f5ebe5_add_user_rate_limit_table.py new file mode 100644 index 0000000..22f7b5b --- /dev/null +++ b/tildes/alembic/versions/fa14e9f5ebe5_add_user_rate_limit_table.py @@ -0,0 +1,36 @@ +"""Add user_rate_limit table + +Revision ID: fa14e9f5ebe5 +Revises: b761d0185ca0 +Create Date: 2019-11-05 18:11:34.303355 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "fa14e9f5ebe5" +down_revision = "b761d0185ca0" +branch_labels = None +depends_on = None + + +def upgrade(): + op.create_table( + "user_rate_limit", + sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("action", sa.Text(), nullable=False), + sa.Column("period", sa.Interval(), nullable=False), + sa.Column("limit", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + name=op.f("fk_user_rate_limit_user_id_users"), + ), + sa.PrimaryKeyConstraint("user_id", "action", name=op.f("pk_user_rate_limit")), + ) + + +def downgrade(): + op.drop_table("user_rate_limit") diff --git a/tildes/tests/test_ratelimit.py b/tildes/tests/test_ratelimit.py index fd6faaf..0dffc5e 100644 --- a/tildes/tests/test_ratelimit.py +++ b/tildes/tests/test_ratelimit.py @@ -46,6 +46,13 @@ def test_check_by_ip_disabled(): action.check_for_ip("123.123.123.123") +def test_max_burst_with_limit_1(): + """Ensure an action with limit 1 also has its max_burst set to 1.""" + action = RateLimitedAction("test", timedelta(hours=1), 1) + + assert action.max_burst == 1 + + def test_simple_rate_limiting_by_user_id(redis): """Ensure simple rate-limiting by user_id is working.""" limit = 5 diff --git a/tildes/tildes/database_models.py b/tildes/tildes/database_models.py index 553632e..6a69829 100644 --- a/tildes/tildes/database_models.py +++ b/tildes/tildes/database_models.py @@ -23,4 +23,4 @@ from tildes.models.topic import ( TopicVisit, TopicVote, ) -from tildes.models.user import User, UserGroupSettings, UserInviteCode +from tildes.models.user import User, UserGroupSettings, UserInviteCode, UserRateLimit diff --git a/tildes/tildes/lib/ratelimit.py b/tildes/tildes/lib/ratelimit.py index 1c1c4d6..c4548f2 100644 --- a/tildes/tildes/lib/ratelimit.py +++ b/tildes/tildes/lib/ratelimit.py @@ -194,8 +194,8 @@ class RateLimitedAction: if max_burst: self.max_burst = max_burst else: - # if a max burst wasn't specified, set it to half the limit - self.max_burst = limit // 2 + # if max burst wasn't specified, set it to half the limit (no lower than 1) + self.max_burst = max(limit // 2, 1) self.by_user = by_user self.by_ip = by_ip diff --git a/tildes/tildes/models/user/__init__.py b/tildes/tildes/models/user/__init__.py index e0b22d2..8ab4788 100644 --- a/tildes/tildes/models/user/__init__.py +++ b/tildes/tildes/models/user/__init__.py @@ -3,3 +3,4 @@ from .user import User from .user_group_settings import UserGroupSettings from .user_invite_code import UserInviteCode +from .user_rate_limit import UserRateLimit diff --git a/tildes/tildes/models/user/user_rate_limit.py b/tildes/tildes/models/user/user_rate_limit.py new file mode 100644 index 0000000..23d1928 --- /dev/null +++ b/tildes/tildes/models/user/user_rate_limit.py @@ -0,0 +1,33 @@ +# Copyright (c) 2019 Tildes contributors +# SPDX-License-Identifier: AGPL-3.0-or-later + +"""Contains the UserRateLimit class.""" + +from datetime import timedelta + +from sqlalchemy import Column, ForeignKey, Integer, Interval, Text +from sqlalchemy.orm import relationship + +from tildes.models import DatabaseModel + +from .user import User + + +class UserRateLimit(DatabaseModel): + """Model for custom rate-limits on actions for individual users.""" + + __tablename__ = "user_rate_limit" + + user_id: int = Column(Integer, ForeignKey("users.user_id"), primary_key=True) + action: str = Column(Text, primary_key=True) + period: timedelta = Column(Interval, nullable=False) + limit: int = Column(Integer, nullable=False) + + user: User = relationship("User", innerjoin=True) + + def __init__(self, user: User, action: str, period: timedelta, limit: int): + """Set a new custom rate-limit for a particular user and action.""" + self.user = user + self.action = action + self.period = period + self.limit = limit diff --git a/tildes/tildes/request_methods.py b/tildes/tildes/request_methods.py index ba57622..19f672d 100644 --- a/tildes/tildes/request_methods.py +++ b/tildes/tildes/request_methods.py @@ -10,7 +10,12 @@ from pyramid.httpexceptions import HTTPTooManyRequests from pyramid.request import Request from redis import Redis -from tildes.lib.ratelimit import RATE_LIMITED_ACTIONS, RateLimitResult +from tildes.lib.ratelimit import ( + RATE_LIMITED_ACTIONS, + RateLimitedAction, + RateLimitResult, +) +from tildes.models.user import UserRateLimit def get_redis_connection(request: Request) -> Redis: @@ -50,10 +55,33 @@ def is_safe_request_method(request: Request) -> bool: def check_rate_limit(request: Request, action_name: str) -> RateLimitResult: """Check the rate limit for a particular action on a request.""" - try: - action = RATE_LIMITED_ACTIONS[action_name] - except KeyError: - raise ValueError("Invalid action name: %s" % action_name) + action = None + + # check for a custom rate-limit for the user + if request.user: + user_limit = ( + request.query(UserRateLimit) + .filter( + UserRateLimit.user == request.user, UserRateLimit.action == action_name + ) + .one_or_none() + ) + + if user_limit: + action = RateLimitedAction( + action_name, + user_limit.period, + user_limit.limit, + by_user=True, + by_ip=False, + ) + + # if a custom rate-limit wasn't found, use the default, global rate-limit + if not action: + try: + action = RATE_LIMITED_ACTIONS[action_name] + except KeyError: + raise ValueError("Invalid action name: %s" % action_name) action.redis = request.redis diff --git a/tildes/tildes/views/api/web/exceptions.py b/tildes/tildes/views/api/web/exceptions.py index 5ea78bf..40b4ac1 100644 --- a/tildes/tildes/views/api/web/exceptions.py +++ b/tildes/tildes/views/api/web/exceptions.py @@ -86,10 +86,14 @@ def httptoomanyrequests(request: Request) -> Response: """Update a 429 error to show wait time info in the response text.""" response = request.exception - retry_seconds = request.exception.headers["Retry-After"] - response.text = ( - f"Rate limit exceeded. Please wait {retry_seconds} seconds before retrying." - ) + retry_seconds = int(request.exception.headers["Retry-After"]) + + if retry_seconds >= 60: + retry_wait = f"{retry_seconds // 60} minutes" + else: + retry_wait = f"{retry_seconds} seconds" + + response.text = f"Rate limit exceeded. Please wait {retry_wait} before retrying." return response