You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							217 lines
						
					
					
						
							6.5 KiB
						
					
					
				
			
		
		
		
			
			
			
		
		
	
	
							217 lines
						
					
					
						
							6.5 KiB
						
					
					
				
								package sts
							 | 
						|
								
							 | 
						|
								import (
							 | 
						|
									"crypto/rand"
							 | 
						|
									"crypto/sha256"
							 | 
						|
									"encoding/base64"
							 | 
						|
									"encoding/hex"
							 | 
						|
									"fmt"
							 | 
						|
									"time"
							 | 
						|
								
							 | 
						|
									"github.com/golang-jwt/jwt/v5"
							 | 
						|
									"github.com/seaweedfs/seaweedfs/weed/iam/utils"
							 | 
						|
								)
							 | 
						|
								
							 | 
						|
								// TokenGenerator signs and validates JWT session tokens using a single
								// signing key and a fixed issuer name.
								type TokenGenerator struct {
									// signingKey is the key passed to the JWT library when signing
									// tokens and when verifying their signatures.
									signingKey []byte

									// issuer is stamped onto generated tokens that lack an issuer and
									// is the value validated tokens must carry.
									issuer string
								}
							 | 
						|
								
							 | 
						|
								// NewTokenGenerator creates a new token generator
							 | 
						|
								func NewTokenGenerator(signingKey []byte, issuer string) *TokenGenerator {
							 | 
						|
									return &TokenGenerator{
							 | 
						|
										signingKey: signingKey,
							 | 
						|
										issuer:     issuer,
							 | 
						|
									}
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// GenerateSessionToken creates a signed JWT session token (legacy method for compatibility)
							 | 
						|
								func (t *TokenGenerator) GenerateSessionToken(sessionId string, expiresAt time.Time) (string, error) {
							 | 
						|
									claims := NewSTSSessionClaims(sessionId, t.issuer, expiresAt)
							 | 
						|
									return t.GenerateJWTWithClaims(claims)
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// GenerateJWTWithClaims creates a signed JWT token with comprehensive session claims
							 | 
						|
								func (t *TokenGenerator) GenerateJWTWithClaims(claims *STSSessionClaims) (string, error) {
							 | 
						|
									if claims == nil {
							 | 
						|
										return "", fmt.Errorf("claims cannot be nil")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Ensure issuer is set from token generator
							 | 
						|
									if claims.Issuer == "" {
							 | 
						|
										claims.Issuer = t.issuer
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
							 | 
						|
									return token.SignedString(t.signingKey)
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// ValidateSessionToken validates and extracts claims from a session token
							 | 
						|
								func (t *TokenGenerator) ValidateSessionToken(tokenString string) (*SessionTokenClaims, error) {
							 | 
						|
									token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
							 | 
						|
										if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
							 | 
						|
											return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
							 | 
						|
										}
							 | 
						|
										return t.signingKey, nil
							 | 
						|
									})
							 | 
						|
								
							 | 
						|
									if err != nil {
							 | 
						|
										return nil, fmt.Errorf(ErrInvalidToken, err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if !token.Valid {
							 | 
						|
										return nil, fmt.Errorf(ErrTokenNotValid)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									claims, ok := token.Claims.(jwt.MapClaims)
							 | 
						|
									if !ok {
							 | 
						|
										return nil, fmt.Errorf(ErrInvalidTokenClaims)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Verify issuer
							 | 
						|
									if iss, ok := claims[JWTClaimIssuer].(string); !ok || iss != t.issuer {
							 | 
						|
										return nil, fmt.Errorf(ErrInvalidIssuer)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Extract session ID
							 | 
						|
									sessionId, ok := claims[JWTClaimSubject].(string)
							 | 
						|
									if !ok {
							 | 
						|
										return nil, fmt.Errorf(ErrMissingSessionID)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									return &SessionTokenClaims{
							 | 
						|
										SessionId: sessionId,
							 | 
						|
										ExpiresAt: time.Unix(int64(claims[JWTClaimExpiration].(float64)), 0),
							 | 
						|
										IssuedAt:  time.Unix(int64(claims[JWTClaimIssuedAt].(float64)), 0),
							 | 
						|
									}, nil
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// ValidateJWTWithClaims validates and extracts comprehensive session claims from a JWT token
							 | 
						|
								func (t *TokenGenerator) ValidateJWTWithClaims(tokenString string) (*STSSessionClaims, error) {
							 | 
						|
									token, err := jwt.ParseWithClaims(tokenString, &STSSessionClaims{}, func(token *jwt.Token) (interface{}, error) {
							 | 
						|
										if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
							 | 
						|
											return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
							 | 
						|
										}
							 | 
						|
										return t.signingKey, nil
							 | 
						|
									})
							 | 
						|
								
							 | 
						|
									if err != nil {
							 | 
						|
										return nil, fmt.Errorf(ErrInvalidToken, err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if !token.Valid {
							 | 
						|
										return nil, fmt.Errorf(ErrTokenNotValid)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									claims, ok := token.Claims.(*STSSessionClaims)
							 | 
						|
									if !ok {
							 | 
						|
										return nil, fmt.Errorf(ErrInvalidTokenClaims)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Validate issuer
							 | 
						|
									if claims.Issuer != t.issuer {
							 | 
						|
										return nil, fmt.Errorf(ErrInvalidIssuer)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Validate that required fields are present
							 | 
						|
									if claims.SessionId == "" {
							 | 
						|
										return nil, fmt.Errorf(ErrMissingSessionID)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Additional validation using the claims' own validation method
							 | 
						|
									if !claims.IsValid() {
							 | 
						|
										return nil, fmt.Errorf(ErrTokenNotValid)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									return claims, nil
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// SessionTokenClaims is the reduced set of claims returned by
								// ValidateSessionToken.
								type SessionTokenClaims struct {
									// SessionId is the session identifier taken from the token.
									SessionId string

									// ExpiresAt is the token's expiration time.
									ExpiresAt time.Time

									// IssuedAt is the time the token was issued.
									IssuedAt time.Time
								}
							 | 
						|
								
							 | 
						|
								// CredentialGenerator produces AWS-compatible temporary credentials. It
								// holds no state of its own.
								type CredentialGenerator struct{}
							 | 
						|
								
							 | 
						|
								// NewCredentialGenerator creates a new credential generator
							 | 
						|
								func NewCredentialGenerator() *CredentialGenerator {
							 | 
						|
									return &CredentialGenerator{}
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// GenerateTemporaryCredentials creates temporary AWS credentials
							 | 
						|
								func (c *CredentialGenerator) GenerateTemporaryCredentials(sessionId string, expiration time.Time) (*Credentials, error) {
							 | 
						|
									accessKeyId, err := c.generateAccessKeyId(sessionId)
							 | 
						|
									if err != nil {
							 | 
						|
										return nil, fmt.Errorf("failed to generate access key ID: %w", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									secretAccessKey, err := c.generateSecretAccessKey()
							 | 
						|
									if err != nil {
							 | 
						|
										return nil, fmt.Errorf("failed to generate secret access key: %w", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									sessionToken, err := c.generateSessionTokenId(sessionId)
							 | 
						|
									if err != nil {
							 | 
						|
										return nil, fmt.Errorf("failed to generate session token: %w", err)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									return &Credentials{
							 | 
						|
										AccessKeyId:     accessKeyId,
							 | 
						|
										SecretAccessKey: secretAccessKey,
							 | 
						|
										SessionToken:    sessionToken,
							 | 
						|
										Expiration:      expiration,
							 | 
						|
									}, nil
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// generateAccessKeyId generates an AWS-style access key ID
							 | 
						|
								func (c *CredentialGenerator) generateAccessKeyId(sessionId string) (string, error) {
							 | 
						|
									// Create a deterministic but unique access key ID based on session
							 | 
						|
									hash := sha256.Sum256([]byte("access-key:" + sessionId))
							 | 
						|
									return "AKIA" + hex.EncodeToString(hash[:8]), nil // AWS format: AKIA + 16 chars
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// generateSecretAccessKey generates a random secret access key
							 | 
						|
								func (c *CredentialGenerator) generateSecretAccessKey() (string, error) {
							 | 
						|
									// Generate 32 random bytes for secret key
							 | 
						|
									secretBytes := make([]byte, 32)
							 | 
						|
									_, err := rand.Read(secretBytes)
							 | 
						|
									if err != nil {
							 | 
						|
										return "", err
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									return base64.StdEncoding.EncodeToString(secretBytes), nil
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// generateSessionTokenId generates a session token identifier
							 | 
						|
								func (c *CredentialGenerator) generateSessionTokenId(sessionId string) (string, error) {
							 | 
						|
									// Create session token with session ID embedded
							 | 
						|
									hash := sha256.Sum256([]byte("session-token:" + sessionId))
							 | 
						|
									return "ST" + hex.EncodeToString(hash[:16]), nil // Custom format
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// GenerateSessionId returns a new session ID: 16 bytes read from
								// crypto/rand, hex-encoded into a 32-character string. A failure to read
								// random bytes is returned as is.
								func GenerateSessionId() (string, error) {
									var raw [16]byte
									if _, err := rand.Read(raw[:]); err != nil {
										return "", err
									}
									return hex.EncodeToString(raw[:]), nil
								}
							 | 
						|
								
							 | 
						|
								// generateAssumedRoleArn generates the ARN for an assumed role user
							 | 
						|
								func GenerateAssumedRoleArn(roleArn, sessionName string) string {
							 | 
						|
									// Convert role ARN to assumed role user ARN
							 | 
						|
									// arn:seaweed:iam::role/RoleName -> arn:seaweed:sts::assumed-role/RoleName/SessionName
							 | 
						|
									roleName := utils.ExtractRoleNameFromArn(roleArn)
							 | 
						|
									if roleName == "" {
							 | 
						|
										// This should not happen if validation is done properly upstream
							 | 
						|
										return fmt.Sprintf("arn:seaweed:sts::assumed-role/INVALID-ARN/%s", sessionName)
							 | 
						|
									}
							 | 
						|
									return fmt.Sprintf("arn:seaweed:sts::assumed-role/%s/%s", roleName, sessionName)
							 | 
						|
								}
							 |