diff --git a/tildes/tests/test_ratelimit.py b/tildes/tests/test_ratelimit.py index 0dffc5e..a606032 100644 --- a/tildes/tests/test_ratelimit.py +++ b/tildes/tests/test_ratelimit.py @@ -24,10 +24,15 @@ def test_all_rate_limited_action_names_unique(): seen_names.add(action.name) -def test_action_with_all_types_disabled(): - """Ensure RateLimitedAction can't have both by_user and by_ip disabled.""" - with raises(ValueError): - RateLimitedAction("test", timedelta(hours=1), 5, by_user=False, by_ip=False) +def test_check_global_disabled(): + """Ensure global check is disabled if action is by_user or by_ip.""" + action = RateLimitedAction("test", timedelta(hours=1), 5, by_user=True, by_ip=False) + with raises(RateLimitError): + action.check_global() + + action = RateLimitedAction("test", timedelta(hours=1), 5, by_user=False, by_ip=True) + with raises(RateLimitError): + action.check_global() def test_check_by_user_id_disabled(): @@ -53,6 +58,31 @@ def test_max_burst_with_limit_1(): assert action.max_burst == 1 +def test_simple_global_rate_limiting(redis): + """Ensure simple global rate-limiting is working.""" + limit = 5 + + # define an action with max_burst equal to the full limit + action = RateLimitedAction( + "testaction", + timedelta(hours=1), + limit, + max_burst=limit, + by_user=False, + by_ip=False, + redis=redis, + ) + + # run the action the full number of times, should all be allowed + for _ in range(limit): + result = action.check_global() + assert result.is_allowed + + # try one more time, should be rejected + result = action.check_global() + assert not result.is_allowed + + def test_simple_rate_limiting_by_user_id(redis): """Ensure simple rate-limiting by user_id is working.""" limit = 5 diff --git a/tildes/tildes/lib/ratelimit.py b/tildes/tildes/lib/ratelimit.py index abdfcdd..8efa256 100644 --- a/tildes/tildes/lib/ratelimit.py +++ b/tildes/tildes/lib/ratelimit.py @@ -185,9 +185,6 @@ class RateLimitedAction: if max_burst and not 1 <= max_burst <= limit: raise ValueError("max_burst must be at least 1 and <= limit") - if not (by_user or by_ip): - raise ValueError("At least one of by_user or by_ip must be True") - self.name = name self.period = period self.limit = limit @@ -218,9 +215,16 @@ class RateLimitedAction: """Set the redis connection.""" self._redis = redis_connection - def _build_redis_key(self, by_type: str, value: Any) -> str: + @property + def is_global(self) -> bool: + """Whether the rate limit applies globally, not to particular users or IPs.""" + return not (self.by_user or self.by_ip) + + def _build_redis_key(self, by_type: str, value: Any = None) -> str: """Build the Redis key where this rate limit is maintained.""" - parts = ["ratelimit", self.name, by_type, str(value)] + parts = ["ratelimit", self.name, by_type] + if value: + parts.append(str(value)) return ":".join(parts) @@ -234,6 +238,24 @@ class RateLimitedAction: int(self.period.total_seconds()), ) + def check_global(self) -> RateLimitResult: + """Check a global rate limit to see if anyone can perform this action.""" + if not self.is_global: + raise RateLimitError("check_global called on non-global-limited action") + + key = self._build_redis_key("global") + result = self._call_redis_command(key) + + return RateLimitResult.from_redis_cell_result(result) + + def reset_global(self) -> None: + """Reset the global ratelimit on this action.""" + if not self.is_global: + raise RateLimitError("reset_global called on non-global-limited action") + + key = self._build_redis_key("global") + self.redis.delete(key) + def check_for_user_id(self, user_id: int) -> RateLimitResult: """Check whether a particular user_id can perform this action.""" if not self.by_user: diff --git a/tildes/tildes/request_methods.py b/tildes/tildes/request_methods.py index 2a78444..7274894 100644 --- a/tildes/tildes/request_methods.py +++ b/tildes/tildes/request_methods.py @@ -90,6 +90,9 @@ def check_rate_limit(request: Request, action_name: str) -> RateLimitResult: results = [] + if action.is_global: + results.append(action.check_global()) + if action.by_user and request.user: results.append(action.check_for_user_id(request.user.user_id))