Browse Source

fix: Implement proper JWT session token validation in STS service

- Add TokenGenerator to STSService for proper JWT validation
- Generate JWT session tokens in AssumeRole operations using TokenGenerator
- ValidateSessionToken now properly parses and validates JWT tokens
- RevokeSession uses JWT validation to extract session ID
- Fixes session token format mismatch between generation and validation
pull/7160/head
chrislu 1 month ago
parent
commit
9406898ab1
  1. 58
      weed/iam/sts/sts_service.go

58
weed/iam/sts/sts_service.go

@ -10,10 +10,11 @@ import (
// STSService provides Security Token Service functionality // STSService provides Security Token Service functionality
type STSService struct { type STSService struct {
config *STSConfig
initialized bool
providers map[string]providers.IdentityProvider
sessionStore SessionStore
config *STSConfig
initialized bool
providers map[string]providers.IdentityProvider
sessionStore SessionStore
tokenGenerator *TokenGenerator
} }
// STSConfig holds STS service configuration // STSConfig holds STS service configuration
@ -181,6 +182,9 @@ func (s *STSService) Initialize(config *STSConfig) error {
} }
s.sessionStore = sessionStore s.sessionStore = sessionStore
// Initialize token generator for JWT validation
s.tokenGenerator = NewTokenGenerator(config.SigningKey, config.Issuer)
s.initialized = true s.initialized = true
return nil return nil
} }
@ -280,6 +284,13 @@ func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, request *Ass
return nil, fmt.Errorf("failed to generate credentials: %w", err) return nil, fmt.Errorf("failed to generate credentials: %w", err)
} }
// Generate proper JWT session token using our TokenGenerator
jwtToken, err := s.tokenGenerator.GenerateSessionToken(sessionId, expiresAt)
if err != nil {
return nil, fmt.Errorf("failed to generate JWT session token: %w", err)
}
credentials.SessionToken = jwtToken
// 5. Create session information // 5. Create session information
session := &SessionInfo{ session := &SessionInfo{
SessionId: sessionId, SessionId: sessionId,
@ -359,6 +370,13 @@ func (s *STSService) AssumeRoleWithCredentials(ctx context.Context, request *Ass
return nil, fmt.Errorf("failed to generate credentials: %w", err) return nil, fmt.Errorf("failed to generate credentials: %w", err)
} }
// Generate proper JWT session token using our TokenGenerator
jwtToken, err := s.tokenGenerator.GenerateSessionToken(sessionId, expiresAt)
if err != nil {
return nil, fmt.Errorf("failed to generate JWT session token: %w", err)
}
tempCredentials.SessionToken = jwtToken
// 6. Create session information // 6. Create session information
session := &SessionInfo{ session := &SessionInfo{
SessionId: sessionId, SessionId: sessionId,
@ -399,25 +417,19 @@ func (s *STSService) ValidateSessionToken(ctx context.Context, sessionToken stri
return nil, fmt.Errorf("session token cannot be empty") return nil, fmt.Errorf("session token cannot be empty")
} }
// For now, use the session token as session ID directly
// In a full implementation, this would:
// 1. Parse JWT session token
// 2. Verify signature and expiration
// 3. Extract session ID from claims
// Extract session ID (simplified - assuming token contains session ID directly)
sessionId := s.extractSessionIdFromToken(sessionToken)
if sessionId == "" {
return nil, fmt.Errorf("invalid session token format")
// Use token generator for proper JWT validation
claims, err := s.tokenGenerator.ValidateSessionToken(sessionToken)
if err != nil {
return nil, fmt.Errorf("invalid session token format: %w", err)
} }
// Retrieve session from store
session, err := s.sessionStore.GetSession(ctx, sessionId)
// Retrieve session from store using session ID from claims
session, err := s.sessionStore.GetSession(ctx, claims.SessionId)
if err != nil { if err != nil {
return nil, fmt.Errorf("session validation failed: %w", err) return nil, fmt.Errorf("session validation failed: %w", err)
} }
// Additional validation can be added here
// Additional validation - check expiration
if session.ExpiresAt.Before(time.Now()) { if session.ExpiresAt.Before(time.Now()) {
return nil, fmt.Errorf("session has expired") return nil, fmt.Errorf("session has expired")
} }
@ -435,14 +447,14 @@ func (s *STSService) RevokeSession(ctx context.Context, sessionToken string) err
return fmt.Errorf("session token cannot be empty") return fmt.Errorf("session token cannot be empty")
} }
// Extract session ID from token
sessionId := s.extractSessionIdFromToken(sessionToken)
if sessionId == "" {
return fmt.Errorf("invalid session token format")
// Use token generator for proper JWT validation
claims, err := s.tokenGenerator.ValidateSessionToken(sessionToken)
if err != nil {
return fmt.Errorf("invalid session token format: %w", err)
} }
// Remove session from store
err := s.sessionStore.RevokeSession(ctx, sessionId)
// Remove session from store using session ID from claims
err = s.sessionStore.RevokeSession(ctx, claims.SessionId)
if err != nil { if err != nil {
return fmt.Errorf("failed to revoke session: %w", err) return fmt.Errorf("failed to revoke session: %w", err)
} }

Loading…
Cancel
Save