|
@ -10,7 +10,12 @@ from pyramid.httpexceptions import HTTPTooManyRequests |
|
|
from pyramid.request import Request |
|
|
from pyramid.request import Request |
|
|
from redis import Redis |
|
|
from redis import Redis |
|
|
|
|
|
|
|
|
from tildes.lib.ratelimit import RATE_LIMITED_ACTIONS, RateLimitResult |
|
|
|
|
|
|
|
|
from tildes.lib.ratelimit import ( |
|
|
|
|
|
RATE_LIMITED_ACTIONS, |
|
|
|
|
|
RateLimitedAction, |
|
|
|
|
|
RateLimitResult, |
|
|
|
|
|
) |
|
|
|
|
|
from tildes.models.user import UserRateLimit |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_redis_connection(request: Request) -> Redis: |
|
|
def get_redis_connection(request: Request) -> Redis: |
|
@ -50,10 +55,33 @@ def is_safe_request_method(request: Request) -> bool: |
|
|
|
|
|
|
|
|
def check_rate_limit(request: Request, action_name: str) -> RateLimitResult: |
|
|
def check_rate_limit(request: Request, action_name: str) -> RateLimitResult: |
|
|
"""Check the rate limit for a particular action on a request.""" |
|
|
"""Check the rate limit for a particular action on a request.""" |
|
|
try: |
|
|
|
|
|
action = RATE_LIMITED_ACTIONS[action_name] |
|
|
|
|
|
except KeyError: |
|
|
|
|
|
raise ValueError("Invalid action name: %s" % action_name) |
|
|
|
|
|
|
|
|
action = None |
|
|
|
|
|
|
|
|
|
|
|
# check for a custom rate-limit for the user |
|
|
|
|
|
if request.user: |
|
|
|
|
|
user_limit = ( |
|
|
|
|
|
request.query(UserRateLimit) |
|
|
|
|
|
.filter( |
|
|
|
|
|
UserRateLimit.user == request.user, UserRateLimit.action == action_name |
|
|
|
|
|
) |
|
|
|
|
|
.one_or_none() |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
if user_limit: |
|
|
|
|
|
action = RateLimitedAction( |
|
|
|
|
|
action_name, |
|
|
|
|
|
user_limit.period, |
|
|
|
|
|
user_limit.limit, |
|
|
|
|
|
by_user=True, |
|
|
|
|
|
by_ip=False, |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
# if a custom rate-limit wasn't found, use the default, global rate-limit |
|
|
|
|
|
if not action: |
|
|
|
|
|
try: |
|
|
|
|
|
action = RATE_LIMITED_ACTIONS[action_name] |
|
|
|
|
|
except KeyError: |
|
|
|
|
|
raise ValueError("Invalid action name: %s" % action_name) |
|
|
|
|
|
|
|
|
action.redis = request.redis |
|
|
action.redis = request.redis |
|
|
|
|
|
|
|
|