diff --git a/tildes/development.ini b/tildes/development.ini index f463e17..809bcb6 100644 --- a/tildes/development.ini +++ b/tildes/development.ini @@ -45,3 +45,6 @@ webassets.base_dir = %(here)s/static webassets.base_url = / webassets.cache = false webassets.manifest = json + +# JWT settings for the API authentication +jwt.secret = completely_insecure_jwt_secret_that_is_at_least_256_bits_long \ No newline at end of file diff --git a/tildes/production.ini.example b/tildes/production.ini.example index 3c7ad68..21ad9a1 100644 --- a/tildes/production.ini.example +++ b/tildes/production.ini.example @@ -42,6 +42,9 @@ webassets.base_url = / webassets.cache = false webassets.manifest = json +# JWT settings for the API authentication +jwt.secret = SomeReallyLongSecretDifferentFromTheSessionSecretAtLeast256BitsLong + # API keys for external APIs api_keys.embedly = embedlykeygoeshere api_keys.stripe.publishable = pk_live_ActualKeyShouldGoHere diff --git a/tildes/tildes/auth.py b/tildes/tildes/auth.py index 6b72f63..23109c2 100644 --- a/tildes/tildes/auth.py +++ b/tildes/tildes/auth.py @@ -3,16 +3,24 @@ """Configuration and functionality related to authentication/authorization.""" +import jwt from collections.abc import Sequence -from typing import Any, Optional +from datetime import datetime, timedelta, timezone +from typing import Any, Callable, Optional -from pyramid.authentication import SessionAuthenticationPolicy +from pyramid.authentication import ( + SessionAuthenticationPolicy, + CallbackAuthenticationPolicy, +) from pyramid.authorization import ACLAuthorizationPolicy from pyramid.config import Configurator from pyramid.httpexceptions import HTTPFound +from pyramid.interfaces import IAuthenticationPolicy +from pyramid_multiauth import MultiAuthenticationPolicy from pyramid.request import Request from pyramid.security import Allow, Everyone from sqlalchemy.orm import joinedload +from zope.interface import implementer from tildes.models.user import User @@ -80,12 +88,20 @@ def includeme(config: Configurator) -> None: config.set_authorization_policy(ACLAuthorizationPolicy()) - config.set_authentication_policy( - SessionAuthenticationPolicy(callback=auth_callback) - ) + # Get the JWT secret from settings + jwt_secret = config.registry.settings["jwt.secret"] + + # Configure both session and JWT authentication + policies = [ + SessionAuthenticationPolicy(callback=auth_callback), + JWTAuthenticationPolicy(secret=jwt_secret, callback=auth_callback), + ] + config.set_authentication_policy(MultiAuthenticationPolicy(policies)) - # enable CSRF checking globally by default - config.set_default_csrf_options(require_csrf=True) + # enable CSRF checking globally by default, but exclude API endpoints + config.set_default_csrf_options( + require_csrf=True, callback=lambda request: not request.path.startswith("/api/") + ) # make the logged-in User object available as request.user config.add_request_method(get_authenticated_user, "user", reify=True) @@ -101,3 +117,78 @@ def has_any_permission( return any( request.has_permission(permission, context) for permission in permissions ) + + +@implementer(IAuthenticationPolicy) +class JWTAuthenticationPolicy(CallbackAuthenticationPolicy): + """Authentication policy for JWT tokens. + + This policy checks for an Authorization header with a Bearer token. + The token is expected to be a JWT signed with the application's secret key. + """ + + def __init__( + self, + secret: str, + callback: None | Callable[[int, Request], Optional[Sequence[str]]] = None, + ): + """Initialize the policy with a secret key for JWT validation.""" + self.secret = secret + self.callback = callback + + def create_jwt_token(self, user_id: int, expiry: int = 86400) -> str: + """Create a new JWT token for a user.""" + payload = { + "sub": str(user_id), # JWT subjects must be strings + "iat": datetime.now(timezone.utc), + "exp": datetime.now(timezone.utc) + timedelta(seconds=expiry), + } + return jwt.encode(payload, self.secret, algorithm="HS256") + + def validate_jwt_token(self, token: str) -> Optional[dict[str, Any]]: + """Validate a JWT token and return its payload if valid.""" + try: + # jwt.decode() WILL verify the expiration as well. + # This does not need to be checked separately. + return jwt.decode(token, self.secret, algorithms=["HS256"]) + except jwt.InvalidTokenError: + return None + + def unauthenticated_userid(self, request: Request) -> Optional[int]: + """Return the user id from the token without validating it.""" + + if not request.path.startswith("/api/"): + return None + + auth_header = request.headers.get("Authorization") + if not auth_header: + return None + + try: + auth_type, token = auth_header.split(" ", 1) + except ValueError: + return None + + if auth_type.lower() != "bearer": + return None + + payload = self.validate_jwt_token(token) + if not payload: + return None + + try: + user_id = int( + payload["sub"] + ) # JWT subjects must be strings, convert to int + except (KeyError, ValueError): + return None + + return user_id + + def remember(self, _request: Request, _userid: int, **_kw: dict) -> Sequence[tuple]: + """This should not be used for JWT authentication as it is stateless.""" + return [] + + def forget(self, _request: Request) -> Sequence[tuple]: + """This should not be used for JWT authentication as it is stateless.""" + return []