Browse Source

Move request methods into a module

merge-requests/85/head
Deimos 5 years ago
parent
commit
e006a6f8f5
  1. 136
      tildes/tildes/__init__.py
  2. 144
      tildes/tildes/request_methods.py

136
tildes/tildes/__init__.py

@ -3,20 +3,15 @@
"""Configure and initialize the Pyramid app.""" """Configure and initialize the Pyramid app."""
from typing import Any, Dict, Optional, Tuple
from typing import Dict
import sentry_sdk import sentry_sdk
from marshmallow.exceptions import ValidationError from marshmallow.exceptions import ValidationError
from paste.deploy.config import PrefixMiddleware from paste.deploy.config import PrefixMiddleware
from pyramid.config import Configurator from pyramid.config import Configurator
from pyramid.httpexceptions import HTTPTooManyRequests
from pyramid.request import Request
from redis import Redis
from sentry_sdk.integrations.pyramid import PyramidIntegration from sentry_sdk.integrations.pyramid import PyramidIntegration
from webassets import Bundle from webassets import Bundle
from tildes.lib.ratelimit import RATE_LIMITED_ACTIONS, RateLimitResult
def main(global_config: Dict[str, str], **settings: str) -> PrefixMiddleware: def main(global_config: Dict[str, str], **settings: str) -> PrefixMiddleware:
"""Configure and return a Pyramid WSGI application.""" """Configure and return a Pyramid WSGI application."""
@ -31,6 +26,7 @@ def main(global_config: Dict[str, str], **settings: str) -> PrefixMiddleware:
config.include("tildes.auth") config.include("tildes.auth")
config.include("tildes.jinja") config.include("tildes.jinja")
config.include("tildes.json") config.include("tildes.json")
config.include("tildes.request_methods")
config.include("tildes.routes") config.include("tildes.routes")
config.include("tildes.tweens") config.include("tildes.tweens")
@ -42,25 +38,6 @@ def main(global_config: Dict[str, str], **settings: str) -> PrefixMiddleware:
config.add_static_view("images", "/images") config.add_static_view("images", "/images")
config.add_request_method(is_safe_request_method, "is_safe_method", reify=True)
# Add the request.redis request method to access a redis connection. This is done in
# a bit of a strange way to support being overridden in tests.
config.registry["redis_connection_factory"] = get_redis_connection
# pylint: disable=unnecessary-lambda
config.add_request_method(
lambda request: config.registry["redis_connection_factory"](request),
"redis",
reify=True,
)
# pylint: enable=unnecessary-lambda
config.add_request_method(check_rate_limit, "check_rate_limit")
config.add_request_method(apply_rate_limit, "apply_rate_limit")
config.add_request_method(current_listing_base_url, "current_listing_base_url")
config.add_request_method(current_listing_normal_url, "current_listing_normal_url")
if settings.get("sentry_dsn"): if settings.get("sentry_dsn"):
sentry_sdk.init( sentry_sdk.init(
dsn=settings["sentry_dsn"], dsn=settings["sentry_dsn"],
@ -77,112 +54,3 @@ def main(global_config: Dict[str, str], **settings: str) -> PrefixMiddleware:
prefixed_app = PrefixMiddleware(app) prefixed_app = PrefixMiddleware(app)
return prefixed_app return prefixed_app
def get_redis_connection(request: Request) -> Redis:
"""Return a connection to the Redis server."""
socket = request.registry.settings["redis.unix_socket_path"]
return Redis(unix_socket_path=socket)
def is_safe_request_method(request: Request) -> bool:
"""Return whether the request method is "safe" (is GET or HEAD)."""
return request.method in {"GET", "HEAD"}
def check_rate_limit(request: Request, action_name: str) -> RateLimitResult:
"""Check the rate limit for a particular action on a request."""
try:
action = RATE_LIMITED_ACTIONS[action_name]
except KeyError:
raise ValueError("Invalid action name: %s" % action_name)
action.redis = request.redis
results = []
if action.by_user and request.user:
results.append(action.check_for_user_id(request.user.user_id))
if action.by_ip and request.remote_addr:
results.append(action.check_for_ip(request.remote_addr))
# no checks were done, return the "not limited" result
if not results:
return RateLimitResult.unlimited_result()
return RateLimitResult.merged_result(results)
def apply_rate_limit(request: Request, action_name: str) -> None:
"""Check the rate limit for an action, and raise HTTP 429 if it's exceeded."""
result = request.check_rate_limit(action_name)
if not result.is_allowed:
raise result.add_headers_to_response(HTTPTooManyRequests())
def current_listing_base_url(
request: Request, query: Optional[Dict[str, Any]] = None
) -> str:
"""Return the "base" url for the current listing route.
The "base" url represents the current listing, including any filtering options (or
the fact that filters are disabled).
The `query` argument allows adding query variables to the generated url.
"""
base_vars_by_route: Dict[str, Tuple[str, ...]] = {
"bookmarks": ("per_page", "type"),
"group": ("order", "period", "per_page", "tag", "unfiltered"),
"group_search": ("order", "period", "per_page", "q"),
"home": ("order", "period", "per_page", "tag", "unfiltered"),
"search": ("order", "period", "per_page", "q"),
"user": ("order", "per_page", "type"),
}
try:
base_view_vars = base_vars_by_route[request.matched_route.name]
except KeyError:
raise AttributeError("Current route is not supported.")
query_vars = {
key: val for key, val in request.GET.copy().items() if key in base_view_vars
}
if query:
query_vars.update(query)
return request.current_route_url(_query=query_vars)
def current_listing_normal_url(
request: Request, query: Optional[Dict[str, Any]] = None
) -> str:
"""Return the "normal" url for the current listing route.
The "normal" url represents the current listing without any additional
filtering-related changes (the user's standard view of that listing).
The `query` argument allows adding query variables to the generated url.
"""
normal_vars_by_route: Dict[str, Tuple[str, ...]] = {
"bookmarks": ("order", "period", "per_page"),
"group": ("order", "period", "per_page"),
"group_search": ("order", "period", "per_page", "q"),
"home": ("order", "period", "per_page"),
"notifications": ("per_page",),
"search": ("order", "period", "per_page", "q"),
"user": ("order", "per_page"),
}
try:
normal_view_vars = normal_vars_by_route[request.matched_route.name]
except KeyError:
raise AttributeError("Current route is not supported.")
query_vars = {
key: val for key, val in request.GET.copy().items() if key in normal_view_vars
}
if query:
query_vars.update(query)
return request.current_route_url(_query=query_vars)

144
tildes/tildes/request_methods.py

@ -0,0 +1,144 @@
# Copyright (c) 2019 Tildes contributors <code@tildes.net>
# SPDX-License-Identifier: AGPL-3.0-or-later
"""Define and attach request methods to the Pyramid request object."""
from typing import Any, Dict, Optional, Tuple
from pyramid.config import Configurator
from pyramid.httpexceptions import HTTPTooManyRequests
from pyramid.request import Request
from redis import Redis
from tildes.lib.ratelimit import RATE_LIMITED_ACTIONS, RateLimitResult
def get_redis_connection(request: Request) -> Redis:
"""Return a connection to the Redis server."""
socket = request.registry.settings["redis.unix_socket_path"]
return Redis(unix_socket_path=socket)
def is_safe_request_method(request: Request) -> bool:
"""Return whether the request method is "safe" (is GET or HEAD)."""
return request.method in {"GET", "HEAD"}
def check_rate_limit(request: Request, action_name: str) -> RateLimitResult:
"""Check the rate limit for a particular action on a request."""
try:
action = RATE_LIMITED_ACTIONS[action_name]
except KeyError:
raise ValueError("Invalid action name: %s" % action_name)
action.redis = request.redis
results = []
if action.by_user and request.user:
results.append(action.check_for_user_id(request.user.user_id))
if action.by_ip and request.remote_addr:
results.append(action.check_for_ip(request.remote_addr))
# no checks were done, return the "not limited" result
if not results:
return RateLimitResult.unlimited_result()
return RateLimitResult.merged_result(results)
def apply_rate_limit(request: Request, action_name: str) -> None:
"""Check the rate limit for an action, and raise HTTP 429 if it's exceeded."""
result = request.check_rate_limit(action_name)
if not result.is_allowed:
raise result.add_headers_to_response(HTTPTooManyRequests())
def current_listing_base_url(
request: Request, query: Optional[Dict[str, Any]] = None
) -> str:
"""Return the "base" url for the current listing route.
The "base" url represents the current listing, including any filtering options (or
the fact that filters are disabled).
The `query` argument allows adding query variables to the generated url.
"""
base_vars_by_route: Dict[str, Tuple[str, ...]] = {
"bookmarks": ("per_page", "type"),
"group": ("order", "period", "per_page", "tag", "unfiltered"),
"group_search": ("order", "period", "per_page", "q"),
"home": ("order", "period", "per_page", "tag", "unfiltered"),
"search": ("order", "period", "per_page", "q"),
"user": ("order", "per_page", "type"),
}
try:
base_view_vars = base_vars_by_route[request.matched_route.name]
except KeyError:
raise AttributeError("Current route is not supported.")
query_vars = {
key: val for key, val in request.GET.copy().items() if key in base_view_vars
}
if query:
query_vars.update(query)
return request.current_route_url(_query=query_vars)
def current_listing_normal_url(
request: Request, query: Optional[Dict[str, Any]] = None
) -> str:
"""Return the "normal" url for the current listing route.
The "normal" url represents the current listing without any additional
filtering-related changes (the user's standard view of that listing).
The `query` argument allows adding query variables to the generated url.
"""
normal_vars_by_route: Dict[str, Tuple[str, ...]] = {
"bookmarks": ("order", "period", "per_page"),
"group": ("order", "period", "per_page"),
"group_search": ("order", "period", "per_page", "q"),
"home": ("order", "period", "per_page"),
"notifications": ("per_page",),
"search": ("order", "period", "per_page", "q"),
"user": ("order", "per_page"),
}
try:
normal_view_vars = normal_vars_by_route[request.matched_route.name]
except KeyError:
raise AttributeError("Current route is not supported.")
query_vars = {
key: val for key, val in request.GET.copy().items() if key in normal_view_vars
}
if query:
query_vars.update(query)
return request.current_route_url(_query=query_vars)
def includeme(config: Configurator) -> None:
"""Attach the request methods to the Pyramid request object."""
config.add_request_method(is_safe_request_method, "is_safe_method", reify=True)
# Add the request.redis request method to access a redis connection. This is done in
# a bit of a strange way to support being overridden in tests.
config.registry["redis_connection_factory"] = get_redis_connection
# pylint: disable=unnecessary-lambda
config.add_request_method(
lambda request: config.registry["redis_connection_factory"](request),
"redis",
reify=True,
)
# pylint: enable=unnecessary-lambda
config.add_request_method(check_rate_limit, "check_rate_limit")
config.add_request_method(apply_rate_limit, "apply_rate_limit")
config.add_request_method(current_listing_base_url, "current_listing_base_url")
config.add_request_method(current_listing_normal_url, "current_listing_normal_url")
Loading…
Cancel
Save