You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
217 lines
6.5 KiB
217 lines
6.5 KiB
package sts
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/seaweedfs/seaweedfs/weed/iam/utils"
|
|
)
|
|
|
|
// TokenGenerator issues and verifies signed JWT session tokens. It holds the
// key material and issuer name that its methods apply to every token.
type TokenGenerator struct {
	// signingKey is the key handed to the JWT library for signing and
	// for signature verification.
	signingKey []byte
	// issuer is written into tokens that lack an issuer and is the value
	// validated tokens must carry.
	issuer string
}
|
|
|
|
// NewTokenGenerator creates a new token generator
|
|
func NewTokenGenerator(signingKey []byte, issuer string) *TokenGenerator {
|
|
return &TokenGenerator{
|
|
signingKey: signingKey,
|
|
issuer: issuer,
|
|
}
|
|
}
|
|
|
|
// GenerateSessionToken creates a signed JWT session token (legacy method for compatibility)
|
|
func (t *TokenGenerator) GenerateSessionToken(sessionId string, expiresAt time.Time) (string, error) {
|
|
claims := NewSTSSessionClaims(sessionId, t.issuer, expiresAt)
|
|
return t.GenerateJWTWithClaims(claims)
|
|
}
|
|
|
|
// GenerateJWTWithClaims creates a signed JWT token with comprehensive session claims
|
|
func (t *TokenGenerator) GenerateJWTWithClaims(claims *STSSessionClaims) (string, error) {
|
|
if claims == nil {
|
|
return "", fmt.Errorf("claims cannot be nil")
|
|
}
|
|
|
|
// Ensure issuer is set from token generator
|
|
if claims.Issuer == "" {
|
|
claims.Issuer = t.issuer
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
return token.SignedString(t.signingKey)
|
|
}
|
|
|
|
// ValidateSessionToken validates and extracts claims from a session token
|
|
func (t *TokenGenerator) ValidateSessionToken(tokenString string) (*SessionTokenClaims, error) {
|
|
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
return t.signingKey, nil
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf(ErrInvalidToken, err)
|
|
}
|
|
|
|
if !token.Valid {
|
|
return nil, fmt.Errorf(ErrTokenNotValid)
|
|
}
|
|
|
|
claims, ok := token.Claims.(jwt.MapClaims)
|
|
if !ok {
|
|
return nil, fmt.Errorf(ErrInvalidTokenClaims)
|
|
}
|
|
|
|
// Verify issuer
|
|
if iss, ok := claims[JWTClaimIssuer].(string); !ok || iss != t.issuer {
|
|
return nil, fmt.Errorf(ErrInvalidIssuer)
|
|
}
|
|
|
|
// Extract session ID
|
|
sessionId, ok := claims[JWTClaimSubject].(string)
|
|
if !ok {
|
|
return nil, fmt.Errorf(ErrMissingSessionID)
|
|
}
|
|
|
|
return &SessionTokenClaims{
|
|
SessionId: sessionId,
|
|
ExpiresAt: time.Unix(int64(claims[JWTClaimExpiration].(float64)), 0),
|
|
IssuedAt: time.Unix(int64(claims[JWTClaimIssuedAt].(float64)), 0),
|
|
}, nil
|
|
}
|
|
|
|
// ValidateJWTWithClaims validates and extracts comprehensive session claims from a JWT token
|
|
func (t *TokenGenerator) ValidateJWTWithClaims(tokenString string) (*STSSessionClaims, error) {
|
|
token, err := jwt.ParseWithClaims(tokenString, &STSSessionClaims{}, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
return t.signingKey, nil
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf(ErrInvalidToken, err)
|
|
}
|
|
|
|
if !token.Valid {
|
|
return nil, fmt.Errorf(ErrTokenNotValid)
|
|
}
|
|
|
|
claims, ok := token.Claims.(*STSSessionClaims)
|
|
if !ok {
|
|
return nil, fmt.Errorf(ErrInvalidTokenClaims)
|
|
}
|
|
|
|
// Validate issuer
|
|
if claims.Issuer != t.issuer {
|
|
return nil, fmt.Errorf(ErrInvalidIssuer)
|
|
}
|
|
|
|
// Validate that required fields are present
|
|
if claims.SessionId == "" {
|
|
return nil, fmt.Errorf(ErrMissingSessionID)
|
|
}
|
|
|
|
// Additional validation using the claims' own validation method
|
|
if !claims.IsValid() {
|
|
return nil, fmt.Errorf(ErrTokenNotValid)
|
|
}
|
|
|
|
return claims, nil
|
|
}
|
|
|
|
// SessionTokenClaims is the reduced set of values extracted from a parsed
// session token.
type SessionTokenClaims struct {
	// SessionId identifies the session the token belongs to.
	SessionId string
	// ExpiresAt is the token's expiration time.
	ExpiresAt time.Time
	// IssuedAt is the time the token was issued.
	IssuedAt time.Time
}
|
|
|
|
// CredentialGenerator produces AWS-compatible temporary credentials. It has
// no fields, so a value carries no state of its own.
type CredentialGenerator struct{}
|
|
|
|
// NewCredentialGenerator creates a new credential generator
|
|
func NewCredentialGenerator() *CredentialGenerator {
|
|
return &CredentialGenerator{}
|
|
}
|
|
|
|
// GenerateTemporaryCredentials creates temporary AWS credentials
|
|
func (c *CredentialGenerator) GenerateTemporaryCredentials(sessionId string, expiration time.Time) (*Credentials, error) {
|
|
accessKeyId, err := c.generateAccessKeyId(sessionId)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate access key ID: %w", err)
|
|
}
|
|
|
|
secretAccessKey, err := c.generateSecretAccessKey()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate secret access key: %w", err)
|
|
}
|
|
|
|
sessionToken, err := c.generateSessionTokenId(sessionId)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to generate session token: %w", err)
|
|
}
|
|
|
|
return &Credentials{
|
|
AccessKeyId: accessKeyId,
|
|
SecretAccessKey: secretAccessKey,
|
|
SessionToken: sessionToken,
|
|
Expiration: expiration,
|
|
}, nil
|
|
}
|
|
|
|
// generateAccessKeyId generates an AWS-style access key ID
|
|
func (c *CredentialGenerator) generateAccessKeyId(sessionId string) (string, error) {
|
|
// Create a deterministic but unique access key ID based on session
|
|
hash := sha256.Sum256([]byte("access-key:" + sessionId))
|
|
return "AKIA" + hex.EncodeToString(hash[:8]), nil // AWS format: AKIA + 16 chars
|
|
}
|
|
|
|
// generateSecretAccessKey generates a random secret access key
|
|
func (c *CredentialGenerator) generateSecretAccessKey() (string, error) {
|
|
// Generate 32 random bytes for secret key
|
|
secretBytes := make([]byte, 32)
|
|
_, err := rand.Read(secretBytes)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return base64.StdEncoding.EncodeToString(secretBytes), nil
|
|
}
|
|
|
|
// generateSessionTokenId generates a session token identifier
|
|
func (c *CredentialGenerator) generateSessionTokenId(sessionId string) (string, error) {
|
|
// Create session token with session ID embedded
|
|
hash := sha256.Sum256([]byte("session-token:" + sessionId))
|
|
return "ST" + hex.EncodeToString(hash[:16]), nil // Custom format
|
|
}
|
|
|
|
// GenerateSessionId returns a new random session ID: 16 bytes read from
// crypto/rand, rendered as 32 lowercase hex characters. An error from the
// random source is returned unchanged with an empty string.
func GenerateSessionId() (string, error) {
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return hex.EncodeToString(raw[:]), nil
}
|
|
|
|
// generateAssumedRoleArn generates the ARN for an assumed role user
|
|
func GenerateAssumedRoleArn(roleArn, sessionName string) string {
|
|
// Convert role ARN to assumed role user ARN
|
|
// arn:seaweed:iam::role/RoleName -> arn:seaweed:sts::assumed-role/RoleName/SessionName
|
|
roleName := utils.ExtractRoleNameFromArn(roleArn)
|
|
if roleName == "" {
|
|
// This should not happen if validation is done properly upstream
|
|
return fmt.Sprintf("arn:seaweed:sts::assumed-role/INVALID-ARN/%s", sessionName)
|
|
}
|
|
return fmt.Sprintf("arn:seaweed:sts::assumed-role/%s/%s", roleName, sessionName)
|
|
}
|