diff --git a/weed/iam/sts/sts_service.go b/weed/iam/sts/sts_service.go index d706c50f2..213177ef2 100644 --- a/weed/iam/sts/sts_service.go +++ b/weed/iam/sts/sts_service.go @@ -10,10 +10,11 @@ import ( // STSService provides Security Token Service functionality type STSService struct { - config *STSConfig - initialized bool - providers map[string]providers.IdentityProvider - sessionStore SessionStore + config *STSConfig + initialized bool + providers map[string]providers.IdentityProvider + sessionStore SessionStore + tokenGenerator *TokenGenerator } // STSConfig holds STS service configuration @@ -181,6 +182,9 @@ func (s *STSService) Initialize(config *STSConfig) error { } s.sessionStore = sessionStore + // Initialize token generator for JWT validation + s.tokenGenerator = NewTokenGenerator(config.SigningKey, config.Issuer) + s.initialized = true return nil } @@ -280,6 +284,13 @@ func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, request *Ass return nil, fmt.Errorf("failed to generate credentials: %w", err) } + // Generate proper JWT session token using our TokenGenerator + jwtToken, err := s.tokenGenerator.GenerateSessionToken(sessionId, expiresAt) + if err != nil { + return nil, fmt.Errorf("failed to generate JWT session token: %w", err) + } + credentials.SessionToken = jwtToken + // 5. Create session information session := &SessionInfo{ SessionId: sessionId, @@ -359,6 +370,13 @@ func (s *STSService) AssumeRoleWithCredentials(ctx context.Context, request *Ass return nil, fmt.Errorf("failed to generate credentials: %w", err) } + // Generate proper JWT session token using our TokenGenerator + jwtToken, err := s.tokenGenerator.GenerateSessionToken(sessionId, expiresAt) + if err != nil { + return nil, fmt.Errorf("failed to generate JWT session token: %w", err) + } + tempCredentials.SessionToken = jwtToken + // 6. Create session information session := &SessionInfo{ SessionId: sessionId, @@ -399,25 +417,19 @@ func (s *STSService) ValidateSessionToken(ctx context.Context, sessionToken stri return nil, fmt.Errorf("session token cannot be empty") } - // For now, use the session token as session ID directly - // In a full implementation, this would: - // 1. Parse JWT session token - // 2. Verify signature and expiration - // 3. Extract session ID from claims - - // Extract session ID (simplified - assuming token contains session ID directly) - sessionId := s.extractSessionIdFromToken(sessionToken) - if sessionId == "" { - return nil, fmt.Errorf("invalid session token format") + // Use token generator for proper JWT validation + claims, err := s.tokenGenerator.ValidateSessionToken(sessionToken) + if err != nil { + return nil, fmt.Errorf("invalid session token format: %w", err) } - // Retrieve session from store - session, err := s.sessionStore.GetSession(ctx, sessionId) + // Retrieve session from store using session ID from claims + session, err := s.sessionStore.GetSession(ctx, claims.SessionId) if err != nil { return nil, fmt.Errorf("session validation failed: %w", err) } - // Additional validation can be added here + // Additional validation - check expiration if session.ExpiresAt.Before(time.Now()) { return nil, fmt.Errorf("session has expired") } @@ -435,14 +447,14 @@ func (s *STSService) RevokeSession(ctx context.Context, sessionToken string) err return fmt.Errorf("session token cannot be empty") } - // Extract session ID from token - sessionId := s.extractSessionIdFromToken(sessionToken) - if sessionId == "" { - return fmt.Errorf("invalid session token format") + // Use token generator for proper JWT validation + claims, err := s.tokenGenerator.ValidateSessionToken(sessionToken) + if err != nil { + return fmt.Errorf("invalid session token format: %w", err) } - // Remove session from store - err := s.sessionStore.RevokeSession(ctx, sessionId) + // Remove session from store using session ID from claims + err = s.sessionStore.RevokeSession(ctx, claims.SessionId) if err != nil { return fmt.Errorf("failed to revoke session: %w", err) }