|
|
@ -3,16 +3,24 @@ |
|
|
|
|
|
|
|
"""Configuration and functionality related to authentication/authorization.""" |
|
|
|
|
|
|
|
import jwt |
|
|
|
from collections.abc import Sequence |
|
|
|
from typing import Any, Optional |
|
|
|
from datetime import datetime, timedelta, timezone |
|
|
|
from typing import Any, Callable, Optional |
|
|
|
|
|
|
|
from pyramid.authentication import SessionAuthenticationPolicy |
|
|
|
from pyramid.authentication import ( |
|
|
|
SessionAuthenticationPolicy, |
|
|
|
CallbackAuthenticationPolicy, |
|
|
|
) |
|
|
|
from pyramid.authorization import ACLAuthorizationPolicy |
|
|
|
from pyramid.config import Configurator |
|
|
|
from pyramid.httpexceptions import HTTPFound |
|
|
|
from pyramid.interfaces import IAuthenticationPolicy |
|
|
|
from pyramid_multiauth import MultiAuthenticationPolicy |
|
|
|
from pyramid.request import Request |
|
|
|
from pyramid.security import Allow, Everyone |
|
|
|
from sqlalchemy.orm import joinedload |
|
|
|
from zope.interface import implementer |
|
|
|
|
|
|
|
from tildes.models.user import User |
|
|
|
|
|
|
@ -80,12 +88,20 @@ def includeme(config: Configurator) -> None: |
|
|
|
|
|
|
|
config.set_authorization_policy(ACLAuthorizationPolicy()) |
|
|
|
|
|
|
|
config.set_authentication_policy( |
|
|
|
SessionAuthenticationPolicy(callback=auth_callback) |
|
|
|
) |
|
|
|
# Get the JWT secret from settings |
|
|
|
jwt_secret = config.registry.settings["jwt.secret"] |
|
|
|
|
|
|
|
# Configure both session and JWT authentication |
|
|
|
policies = [ |
|
|
|
SessionAuthenticationPolicy(callback=auth_callback), |
|
|
|
JWTAuthenticationPolicy(secret=jwt_secret, callback=auth_callback), |
|
|
|
] |
|
|
|
config.set_authentication_policy(MultiAuthenticationPolicy(policies)) |
|
|
|
|
|
|
|
# enable CSRF checking globally by default |
|
|
|
config.set_default_csrf_options(require_csrf=True) |
|
|
|
# enable CSRF checking globally by default, but exclude API endpoints |
|
|
|
config.set_default_csrf_options( |
|
|
|
require_csrf=True, callback=lambda request: not request.path.startswith("/api/") |
|
|
|
) |
|
|
|
|
|
|
|
# make the logged-in User object available as request.user |
|
|
|
config.add_request_method(get_authenticated_user, "user", reify=True) |
|
|
@ -101,3 +117,78 @@ def has_any_permission( |
|
|
|
return any( |
|
|
|
request.has_permission(permission, context) for permission in permissions |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
@implementer(IAuthenticationPolicy) |
|
|
|
class JWTAuthenticationPolicy(CallbackAuthenticationPolicy): |
|
|
|
"""Authentication policy for JWT tokens. |
|
|
|
|
|
|
|
This policy checks for an Authorization header with a Bearer token. |
|
|
|
The token is expected to be a JWT signed with the application's secret key. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
secret: str, |
|
|
|
callback: None | Callable[[int, Request], Optional[Sequence[str]]] = None, |
|
|
|
): |
|
|
|
"""Initialize the policy with a secret key for JWT validation.""" |
|
|
|
self.secret = secret |
|
|
|
self.callback = callback |
|
|
|
|
|
|
|
def create_jwt_token(self, user_id: int, expiry: int = 86400) -> str: |
|
|
|
"""Create a new JWT token for a user.""" |
|
|
|
payload = { |
|
|
|
"sub": str(user_id), # JWT subjects must be strings |
|
|
|
"iat": datetime.now(timezone.utc), |
|
|
|
"exp": datetime.now(timezone.utc) + timedelta(seconds=expiry), |
|
|
|
} |
|
|
|
return jwt.encode(payload, self.secret, algorithm="HS256") |
|
|
|
|
|
|
|
def validate_jwt_token(self, token: str) -> Optional[dict[str, Any]]: |
|
|
|
"""Validate a JWT token and return its payload if valid.""" |
|
|
|
try: |
|
|
|
# jwt.decode() WILL verify the expiration as well. |
|
|
|
# This does not need to be checked separately. |
|
|
|
return jwt.decode(token, self.secret, algorithms=["HS256"]) |
|
|
|
except jwt.InvalidTokenError: |
|
|
|
return None |
|
|
|
|
|
|
|
def unauthenticated_userid(self, request: Request) -> Optional[int]: |
|
|
|
"""Return the user id from the token without validating it.""" |
|
|
|
|
|
|
|
if not request.path.startswith("/api/"): |
|
|
|
return None |
|
|
|
|
|
|
|
auth_header = request.headers.get("Authorization") |
|
|
|
if not auth_header: |
|
|
|
return None |
|
|
|
|
|
|
|
try: |
|
|
|
auth_type, token = auth_header.split(" ", 1) |
|
|
|
except ValueError: |
|
|
|
return None |
|
|
|
|
|
|
|
if auth_type.lower() != "bearer": |
|
|
|
return None |
|
|
|
|
|
|
|
payload = self.validate_jwt_token(token) |
|
|
|
if not payload: |
|
|
|
return None |
|
|
|
|
|
|
|
try: |
|
|
|
user_id = int( |
|
|
|
payload["sub"] |
|
|
|
) # JWT subjects must be strings, convert to int |
|
|
|
except (KeyError, ValueError): |
|
|
|
return None |
|
|
|
|
|
|
|
return user_id |
|
|
|
|
|
|
|
def remember(self, _request: Request, _userid: int, **_kw: dict) -> Sequence[tuple]: |
|
|
|
"""This should not be used for JWT authentication as it is stateless.""" |
|
|
|
return [] |
|
|
|
|
|
|
|
def forget(self, _request: Request) -> Sequence[tuple]: |
|
|
|
"""This should not be used for JWT authentication as it is stateless.""" |
|
|
|
return [] |