Browse Source

Add support for globally rate-limiting actions

Previously, rate limits had to apply to a particular user or a
particular IP address, or both. This adds support for global
rate-limits, where the limit will apply to everyone trying to perform
the action. This probably won't be used much overall, but might be
necessary for certain cases where something abusive is happening and it
can't be easily blocked by user or IP.

This is a bit ugly and would probably be better implemented by having a
separate class that inherits from RateLimitedAction or something
similar, but it will do the job.
merge-requests/135/head
Deimos 4 years ago
parent
commit
06764e9bc5
  1. 38
      tildes/tests/test_ratelimit.py
  2. 32
      tildes/tildes/lib/ratelimit.py
  3. 3
      tildes/tildes/request_methods.py

38
tildes/tests/test_ratelimit.py

@ -24,10 +24,15 @@ def test_all_rate_limited_action_names_unique():
seen_names.add(action.name)
def test_action_with_all_types_disabled():
"""Ensure RateLimitedAction can't have both by_user and by_ip disabled."""
with raises(ValueError):
RateLimitedAction("test", timedelta(hours=1), 5, by_user=False, by_ip=False)
def test_check_global_disabled():
"""Ensure global check is disabled if action is by_user or by_ip."""
action = RateLimitedAction("test", timedelta(hours=1), 5, by_user=True, by_ip=False)
with raises(RateLimitError):
action.check_global()
action = RateLimitedAction("test", timedelta(hours=1), 5, by_user=False, by_ip=True)
with raises(RateLimitError):
action.check_global()
def test_check_by_user_id_disabled():
@ -53,6 +58,31 @@ def test_max_burst_with_limit_1():
assert action.max_burst == 1
def test_simple_global_rate_limiting(redis):
"""Ensure simple global rate-limiting is working."""
limit = 5
# define an action with max_burst equal to the full limit
action = RateLimitedAction(
"testaction",
timedelta(hours=1),
limit,
max_burst=limit,
by_user=False,
by_ip=False,
redis=redis,
)
# run the action the full number of times, should all be allowed
for _ in range(limit):
result = action.check_global()
assert result.is_allowed
# try one more time, should be rejected
result = action.check_global()
assert not result.is_allowed
def test_simple_rate_limiting_by_user_id(redis):
"""Ensure simple rate-limiting by user_id is working."""
limit = 5

32
tildes/tildes/lib/ratelimit.py

@ -185,9 +185,6 @@ class RateLimitedAction:
if max_burst and not 1 <= max_burst <= limit:
raise ValueError("max_burst must be at least 1 and <= limit")
if not (by_user or by_ip):
raise ValueError("At least one of by_user or by_ip must be True")
self.name = name
self.period = period
self.limit = limit
@ -218,9 +215,16 @@ class RateLimitedAction:
"""Set the redis connection."""
self._redis = redis_connection
def _build_redis_key(self, by_type: str, value: Any) -> str:
@property
def is_global(self) -> bool:
"""Whether the rate limit applies globally, not to particular users or IPs."""
return not (self.by_user or self.by_ip)
def _build_redis_key(self, by_type: str, value: Any = None) -> str:
"""Build the Redis key where this rate limit is maintained."""
parts = ["ratelimit", self.name, by_type, str(value)]
parts = ["ratelimit", self.name, by_type]
if value:
parts.append(str(value))
return ":".join(parts)
@ -234,6 +238,24 @@ class RateLimitedAction:
int(self.period.total_seconds()),
)
def check_global(self) -> RateLimitResult:
"""Check a global rate limit to see if anyone can perform this action."""
if not self.is_global:
raise RateLimitError("check_global called on non-global-limited action")
key = self._build_redis_key("global")
result = self._call_redis_command(key)
return RateLimitResult.from_redis_cell_result(result)
def reset_global(self) -> None:
"""Reset the global ratelimit on this action."""
if not self.is_global:
raise RateLimitError("reset_global called on non-global-limited action")
key = self._build_redis_key("global")
self.redis.delete(key)
def check_for_user_id(self, user_id: int) -> RateLimitResult:
"""Check whether a particular user_id can perform this action."""
if not self.by_user:

3
tildes/tildes/request_methods.py

@ -90,6 +90,9 @@ def check_rate_limit(request: Request, action_name: str) -> RateLimitResult:
results = []
if action.is_global:
results.append(action.check_global())
if action.by_user and request.user:
results.append(action.check_for_user_id(request.user.user_id))

Loading…
Cancel
Save