mirror of https://gitlab.com/tildes/tildes.git
Deimos
5 years ago
2 changed files with 146 additions and 134 deletions
@ -0,0 +1,144 @@ |
|||
# Copyright (c) 2019 Tildes contributors <code@tildes.net> |
|||
# SPDX-License-Identifier: AGPL-3.0-or-later |
|||
|
|||
"""Define and attach request methods to the Pyramid request object.""" |
|||
|
|||
from typing import Any, Dict, Optional, Tuple |
|||
|
|||
from pyramid.config import Configurator |
|||
from pyramid.httpexceptions import HTTPTooManyRequests |
|||
from pyramid.request import Request |
|||
from redis import Redis |
|||
|
|||
from tildes.lib.ratelimit import RATE_LIMITED_ACTIONS, RateLimitResult |
|||
|
|||
|
|||
def get_redis_connection(request: Request) -> Redis: |
|||
"""Return a connection to the Redis server.""" |
|||
socket = request.registry.settings["redis.unix_socket_path"] |
|||
return Redis(unix_socket_path=socket) |
|||
|
|||
|
|||
def is_safe_request_method(request: Request) -> bool: |
|||
"""Return whether the request method is "safe" (is GET or HEAD).""" |
|||
return request.method in {"GET", "HEAD"} |
|||
|
|||
|
|||
def check_rate_limit(request: Request, action_name: str) -> RateLimitResult: |
|||
"""Check the rate limit for a particular action on a request.""" |
|||
try: |
|||
action = RATE_LIMITED_ACTIONS[action_name] |
|||
except KeyError: |
|||
raise ValueError("Invalid action name: %s" % action_name) |
|||
|
|||
action.redis = request.redis |
|||
|
|||
results = [] |
|||
|
|||
if action.by_user and request.user: |
|||
results.append(action.check_for_user_id(request.user.user_id)) |
|||
|
|||
if action.by_ip and request.remote_addr: |
|||
results.append(action.check_for_ip(request.remote_addr)) |
|||
|
|||
# no checks were done, return the "not limited" result |
|||
if not results: |
|||
return RateLimitResult.unlimited_result() |
|||
|
|||
return RateLimitResult.merged_result(results) |
|||
|
|||
|
|||
def apply_rate_limit(request: Request, action_name: str) -> None: |
|||
"""Check the rate limit for an action, and raise HTTP 429 if it's exceeded.""" |
|||
result = request.check_rate_limit(action_name) |
|||
if not result.is_allowed: |
|||
raise result.add_headers_to_response(HTTPTooManyRequests()) |
|||
|
|||
|
|||
def current_listing_base_url( |
|||
request: Request, query: Optional[Dict[str, Any]] = None |
|||
) -> str: |
|||
"""Return the "base" url for the current listing route. |
|||
|
|||
The "base" url represents the current listing, including any filtering options (or |
|||
the fact that filters are disabled). |
|||
|
|||
The `query` argument allows adding query variables to the generated url. |
|||
""" |
|||
base_vars_by_route: Dict[str, Tuple[str, ...]] = { |
|||
"bookmarks": ("per_page", "type"), |
|||
"group": ("order", "period", "per_page", "tag", "unfiltered"), |
|||
"group_search": ("order", "period", "per_page", "q"), |
|||
"home": ("order", "period", "per_page", "tag", "unfiltered"), |
|||
"search": ("order", "period", "per_page", "q"), |
|||
"user": ("order", "per_page", "type"), |
|||
} |
|||
|
|||
try: |
|||
base_view_vars = base_vars_by_route[request.matched_route.name] |
|||
except KeyError: |
|||
raise AttributeError("Current route is not supported.") |
|||
|
|||
query_vars = { |
|||
key: val for key, val in request.GET.copy().items() if key in base_view_vars |
|||
} |
|||
if query: |
|||
query_vars.update(query) |
|||
|
|||
return request.current_route_url(_query=query_vars) |
|||
|
|||
|
|||
def current_listing_normal_url( |
|||
request: Request, query: Optional[Dict[str, Any]] = None |
|||
) -> str: |
|||
"""Return the "normal" url for the current listing route. |
|||
|
|||
The "normal" url represents the current listing without any additional |
|||
filtering-related changes (the user's standard view of that listing). |
|||
|
|||
The `query` argument allows adding query variables to the generated url. |
|||
""" |
|||
normal_vars_by_route: Dict[str, Tuple[str, ...]] = { |
|||
"bookmarks": ("order", "period", "per_page"), |
|||
"group": ("order", "period", "per_page"), |
|||
"group_search": ("order", "period", "per_page", "q"), |
|||
"home": ("order", "period", "per_page"), |
|||
"notifications": ("per_page",), |
|||
"search": ("order", "period", "per_page", "q"), |
|||
"user": ("order", "per_page"), |
|||
} |
|||
|
|||
try: |
|||
normal_view_vars = normal_vars_by_route[request.matched_route.name] |
|||
except KeyError: |
|||
raise AttributeError("Current route is not supported.") |
|||
|
|||
query_vars = { |
|||
key: val for key, val in request.GET.copy().items() if key in normal_view_vars |
|||
} |
|||
if query: |
|||
query_vars.update(query) |
|||
|
|||
return request.current_route_url(_query=query_vars) |
|||
|
|||
|
|||
def includeme(config: Configurator) -> None: |
|||
"""Attach the request methods to the Pyramid request object.""" |
|||
config.add_request_method(is_safe_request_method, "is_safe_method", reify=True) |
|||
|
|||
# Add the request.redis request method to access a redis connection. This is done in |
|||
# a bit of a strange way to support being overridden in tests. |
|||
config.registry["redis_connection_factory"] = get_redis_connection |
|||
# pylint: disable=unnecessary-lambda |
|||
config.add_request_method( |
|||
lambda request: config.registry["redis_connection_factory"](request), |
|||
"redis", |
|||
reify=True, |
|||
) |
|||
# pylint: enable=unnecessary-lambda |
|||
|
|||
config.add_request_method(check_rate_limit, "check_rate_limit") |
|||
config.add_request_method(apply_rate_limit, "apply_rate_limit") |
|||
|
|||
config.add_request_method(current_listing_base_url, "current_listing_base_url") |
|||
config.add_request_method(current_listing_normal_url, "current_listing_normal_url") |
Write
Preview
Loading…
Cancel
Save
Reference in new issue