Browse Source

refactor(sts): replace hardcoded strings with constants

- Add comprehensive constants.go with all string literals
- Replace hardcoded strings in sts_service.go, provider_factory.go, token_utils.go
- Update error messages to use consistent constants
- Standardize configuration field names and store types
- Add JWT claim constants for token handling
- Update tests to use test constants
- Improve maintainability and reduce typos
- Enhance distributed deployment consistency
- Add CONSTANTS.md documentation

All existing functionality preserved with improved type safety.
pull/7160/head
chrislu 1 month ago
parent
commit
2dee3e2d52
  1. 137
      weed/iam/sts/constants.go
  2. 28
      weed/iam/sts/cross_instance_token_test.go
  3. 44
      weed/iam/sts/provider_factory.go
  4. 34
      weed/iam/sts/sts_service.go
  5. 28
      weed/iam/sts/token_utils.go

137
weed/iam/sts/constants.go

@ -0,0 +1,137 @@
package sts
// Store Types
const (
StoreTypeMemory = "memory"
StoreTypeFiler = "filer"
StoreTypeRedis = "redis"
)
// Provider Types
const (
ProviderTypeOIDC = "oidc"
ProviderTypeLDAP = "ldap"
ProviderTypeSAML = "saml"
ProviderTypeMock = "mock"
)
// Policy Effects
const (
EffectAllow = "Allow"
EffectDeny = "Deny"
)
// Default Paths
const (
DefaultSessionBasePath = "/seaweedfs/iam/sessions"
DefaultPolicyBasePath = "/seaweedfs/iam/policies"
DefaultRoleBasePath = "/seaweedfs/iam/roles"
)
// Default Values
const (
DefaultTokenDuration = 3600 // 1 hour in seconds
DefaultMaxSessionLength = 43200 // 12 hours in seconds
DefaultIssuer = "seaweedfs-sts"
MinSigningKeyLength = 16 // Minimum signing key length in bytes
)
// Configuration Field Names
const (
ConfigFieldFilerAddress = "filerAddress"
ConfigFieldBasePath = "basePath"
ConfigFieldIssuer = "issuer"
ConfigFieldClientID = "clientId"
ConfigFieldClientSecret = "clientSecret"
ConfigFieldJWKSUri = "jwksUri"
ConfigFieldScopes = "scopes"
ConfigFieldUserInfoUri = "userInfoUri"
ConfigFieldRedirectUri = "redirectUri"
)
// Error Messages
const (
ErrConfigCannotBeNil = "config cannot be nil"
ErrProviderCannotBeNil = "provider cannot be nil"
ErrProviderNameEmpty = "provider name cannot be empty"
ErrProviderTypeEmpty = "provider type cannot be empty"
ErrTokenCannotBeEmpty = "token cannot be empty"
ErrSessionTokenCannotBeEmpty = "session token cannot be empty"
ErrSessionIDCannotBeEmpty = "session ID cannot be empty"
ErrSTSServiceNotInitialized = "STS service not initialized"
ErrProviderNotInitialized = "provider not initialized"
ErrInvalidTokenDuration = "token duration must be positive"
ErrInvalidMaxSessionLength = "max session length must be positive"
ErrIssuerRequired = "issuer is required"
ErrSigningKeyTooShort = "signing key must be at least %d bytes"
ErrFilerAddressRequired = "filer address is required"
ErrClientIDRequired = "clientId is required for OIDC provider"
ErrUnsupportedStoreType = "unsupported store type: %s"
ErrUnsupportedProviderType = "unsupported provider type: %s"
ErrInvalidTokenFormat = "invalid session token format: %w"
ErrSessionValidationFailed = "session validation failed: %w"
ErrInvalidToken = "invalid token: %w"
ErrTokenNotValid = "token is not valid"
ErrInvalidTokenClaims = "invalid token claims"
ErrInvalidIssuer = "invalid issuer"
ErrMissingSessionID = "missing session ID"
)
// JWT Claims
const (
JWTClaimIssuer = "iss"
JWTClaimSubject = "sub"
JWTClaimAudience = "aud"
JWTClaimExpiration = "exp"
JWTClaimIssuedAt = "iat"
JWTClaimTokenType = "token_type"
)
// Token Types
const (
TokenTypeSession = "session"
TokenTypeAccess = "access"
TokenTypeRefresh = "refresh"
)
// AWS STS Actions
const (
ActionAssumeRole = "sts:AssumeRole"
ActionAssumeRoleWithWebIdentity = "sts:AssumeRoleWithWebIdentity"
ActionAssumeRoleWithCredentials = "sts:AssumeRoleWithCredentials"
ActionValidateSession = "sts:ValidateSession"
ActionRevokeSession = "sts:RevokeSession"
)
// Session File Prefixes
const (
SessionFilePrefix = "session_"
SessionFileExt = ".json"
PolicyFilePrefix = "policy_"
PolicyFileExt = ".json"
RoleFileExt = ".json"
)
// HTTP Headers
const (
HeaderAuthorization = "Authorization"
HeaderContentType = "Content-Type"
HeaderUserAgent = "User-Agent"
)
// Content Types
const (
ContentTypeJSON = "application/json"
ContentTypeFormURLEncoded = "application/x-www-form-urlencoded"
)
// Default Test Values
const (
TestSigningKey32Chars = "test-signing-key-32-characters-long"
TestIssuer = "test-sts"
TestClientID = "test-client"
TestSessionID = "test-session-123"
TestValidToken = "valid_test_token"
TestInvalidToken = "invalid_token"
TestExpiredToken = "expired_token"
)

28
weed/iam/sts/cross_instance_token_test.go

@ -19,30 +19,30 @@ func TestCrossInstanceTokenUsage(t *testing.T) {
TokenDuration: time.Hour, TokenDuration: time.Hour,
MaxSessionLength: 12 * time.Hour, MaxSessionLength: 12 * time.Hour,
Issuer: "distributed-sts-cluster", // SAME across all instances Issuer: "distributed-sts-cluster", // SAME across all instances
SigningKey: []byte("shared-signing-key-32-characters-long"), // SAME across all instances
SessionStoreType: "memory", // In production, this would be "filer" for true sharing
SigningKey: []byte(TestSigningKey32Chars), // SAME across all instances
SessionStoreType: StoreTypeMemory, // In production, this would be "filer" for true sharing
SessionStoreConfig: map[string]interface{}{ SessionStoreConfig: map[string]interface{}{
"filerAddress": "shared-filer:8888",
"basePath": "/seaweedfs/iam/sessions",
ConfigFieldFilerAddress: "shared-filer:8888",
ConfigFieldBasePath: DefaultSessionBasePath,
}, },
Providers: []*ProviderConfig{ Providers: []*ProviderConfig{
{ {
Name: "company-oidc", Name: "company-oidc",
Type: "oidc",
Type: ProviderTypeOIDC,
Enabled: true, Enabled: true,
Config: map[string]interface{}{ Config: map[string]interface{}{
"issuer": "https://sso.company.com/realms/production",
"clientId": "seaweedfs-cluster",
"jwksUri": "https://sso.company.com/realms/production/protocol/openid-connect/certs",
ConfigFieldIssuer: "https://sso.company.com/realms/production",
ConfigFieldClientID: "seaweedfs-cluster",
ConfigFieldJWKSUri: "https://sso.company.com/realms/production/protocol/openid-connect/certs",
}, },
}, },
{ {
Name: "test-mock", Name: "test-mock",
Type: "mock",
Type: ProviderTypeMock,
Enabled: true, Enabled: true,
Config: map[string]interface{}{ Config: map[string]interface{}{
"issuer": "http://test-mock:9999",
"clientId": "test-client",
ConfigFieldIssuer: "http://test-mock:9999",
ConfigFieldClientID: TestClientID,
}, },
}, },
}, },
@ -65,10 +65,10 @@ func TestCrossInstanceTokenUsage(t *testing.T) {
// Test 1: Token generated on Instance A can be validated on Instance B & C // Test 1: Token generated on Instance A can be validated on Instance B & C
t.Run("cross_instance_token_validation", func(t *testing.T) { t.Run("cross_instance_token_validation", func(t *testing.T) {
// Generate session token on Instance A
sessionId := "cross-instance-session-123"
// Generate session token on Instance A
sessionId := TestSessionID
expiresAt := time.Now().Add(time.Hour) expiresAt := time.Now().Add(time.Hour)
tokenFromA, err := instanceA.tokenGenerator.GenerateSessionToken(sessionId, expiresAt) tokenFromA, err := instanceA.tokenGenerator.GenerateSessionToken(sessionId, expiresAt)
require.NoError(t, err, "Instance A should generate token") require.NoError(t, err, "Instance A should generate token")

44
weed/iam/sts/provider_factory.go

@ -19,15 +19,15 @@ func NewProviderFactory() *ProviderFactory {
// CreateProvider creates an identity provider from configuration // CreateProvider creates an identity provider from configuration
func (f *ProviderFactory) CreateProvider(config *ProviderConfig) (providers.IdentityProvider, error) { func (f *ProviderFactory) CreateProvider(config *ProviderConfig) (providers.IdentityProvider, error) {
if config == nil { if config == nil {
return nil, fmt.Errorf("provider config cannot be nil")
return nil, fmt.Errorf(ErrConfigCannotBeNil)
} }
if config.Name == "" { if config.Name == "" {
return nil, fmt.Errorf("provider name cannot be empty")
return nil, fmt.Errorf(ErrProviderNameEmpty)
} }
if config.Type == "" { if config.Type == "" {
return nil, fmt.Errorf("provider type cannot be empty")
return nil, fmt.Errorf(ErrProviderTypeEmpty)
} }
if !config.Enabled { if !config.Enabled {
@ -38,16 +38,16 @@ func (f *ProviderFactory) CreateProvider(config *ProviderConfig) (providers.Iden
glog.V(2).Infof("Creating provider: name=%s, type=%s", config.Name, config.Type) glog.V(2).Infof("Creating provider: name=%s, type=%s", config.Name, config.Type)
switch config.Type { switch config.Type {
case "oidc":
case ProviderTypeOIDC:
return f.createOIDCProvider(config) return f.createOIDCProvider(config)
case "ldap":
case ProviderTypeLDAP:
return f.createLDAPProvider(config) return f.createLDAPProvider(config)
case "saml":
case ProviderTypeSAML:
return f.createSAMLProvider(config) return f.createSAMLProvider(config)
case "mock":
case ProviderTypeMock:
return f.createMockProvider(config) return f.createMockProvider(config)
default: default:
return nil, fmt.Errorf("unsupported provider type: %s", config.Type)
return nil, fmt.Errorf(ErrUnsupportedProviderType, config.Type)
} }
} }
@ -106,33 +106,33 @@ func (f *ProviderFactory) convertToOIDCConfig(configMap map[string]interface{})
config := &oidc.OIDCConfig{} config := &oidc.OIDCConfig{}
// Required fields // Required fields
if issuer, ok := configMap["issuer"].(string); ok {
if issuer, ok := configMap[ConfigFieldIssuer].(string); ok {
config.Issuer = issuer config.Issuer = issuer
} else { } else {
return nil, fmt.Errorf("issuer is required for OIDC provider")
return nil, fmt.Errorf(ErrIssuerRequired)
} }
if clientID, ok := configMap["clientId"].(string); ok {
if clientID, ok := configMap[ConfigFieldClientID].(string); ok {
config.ClientID = clientID config.ClientID = clientID
} else { } else {
return nil, fmt.Errorf("clientId is required for OIDC provider")
return nil, fmt.Errorf(ErrClientIDRequired)
} }
// Optional fields // Optional fields
if clientSecret, ok := configMap["clientSecret"].(string); ok {
if clientSecret, ok := configMap[ConfigFieldClientSecret].(string); ok {
config.ClientSecret = clientSecret config.ClientSecret = clientSecret
} }
if jwksUri, ok := configMap["jwksUri"].(string); ok {
if jwksUri, ok := configMap[ConfigFieldJWKSUri].(string); ok {
config.JWKSUri = jwksUri config.JWKSUri = jwksUri
} }
if userInfoUri, ok := configMap["userInfoUri"].(string); ok {
if userInfoUri, ok := configMap[ConfigFieldUserInfoUri].(string); ok {
config.UserInfoUri = userInfoUri config.UserInfoUri = userInfoUri
} }
// Convert scopes array // Convert scopes array
if scopesInterface, ok := configMap["scopes"]; ok {
if scopesInterface, ok := configMap[ConfigFieldScopes]; ok {
if scopes, err := f.convertToStringSlice(scopesInterface); err == nil { if scopes, err := f.convertToStringSlice(scopesInterface); err == nil {
config.Scopes = scopes config.Scopes = scopes
} }
@ -260,12 +260,12 @@ func (f *ProviderFactory) ValidateProviderConfig(config *ProviderConfig) error {
// validateOIDCConfig validates OIDC provider configuration // validateOIDCConfig validates OIDC provider configuration
func (f *ProviderFactory) validateOIDCConfig(config map[string]interface{}) error { func (f *ProviderFactory) validateOIDCConfig(config map[string]interface{}) error {
if _, ok := config["issuer"]; !ok {
return fmt.Errorf("OIDC provider requires 'issuer' field")
if _, ok := config[ConfigFieldIssuer]; !ok {
return fmt.Errorf("OIDC provider requires '%s' field", ConfigFieldIssuer)
} }
if _, ok := config["clientId"]; !ok {
return fmt.Errorf("OIDC provider requires 'clientId' field")
if _, ok := config[ConfigFieldClientID]; !ok {
return fmt.Errorf("OIDC provider requires '%s' field", ConfigFieldClientID)
} }
return nil return nil
@ -291,5 +291,5 @@ func (f *ProviderFactory) validateMockConfig(config map[string]interface{}) erro
// GetSupportedProviderTypes returns list of supported provider types // GetSupportedProviderTypes returns list of supported provider types
func (f *ProviderFactory) GetSupportedProviderTypes() []string { func (f *ProviderFactory) GetSupportedProviderTypes() []string {
return []string{"oidc", "mock"}
return []string{ProviderTypeOIDC, ProviderTypeMock}
} }

34
weed/iam/sts/sts_service.go

@ -185,7 +185,7 @@ func NewSTSService() *STSService {
// Initialize initializes the STS service with configuration // Initialize initializes the STS service with configuration
func (s *STSService) Initialize(config *STSConfig) error { func (s *STSService) Initialize(config *STSConfig) error {
if config == nil { if config == nil {
return fmt.Errorf("config cannot be nil")
return fmt.Errorf(ErrConfigCannotBeNil)
} }
if err := s.validateConfig(config); err != nil { if err := s.validateConfig(config); err != nil {
@ -216,19 +216,19 @@ func (s *STSService) Initialize(config *STSConfig) error {
// validateConfig validates the STS configuration // validateConfig validates the STS configuration
func (s *STSService) validateConfig(config *STSConfig) error { func (s *STSService) validateConfig(config *STSConfig) error {
if config.TokenDuration <= 0 { if config.TokenDuration <= 0 {
return fmt.Errorf("token duration must be positive")
return fmt.Errorf(ErrInvalidTokenDuration)
} }
if config.MaxSessionLength <= 0 { if config.MaxSessionLength <= 0 {
return fmt.Errorf("max session length must be positive")
return fmt.Errorf(ErrInvalidMaxSessionLength)
} }
if config.Issuer == "" { if config.Issuer == "" {
return fmt.Errorf("issuer is required")
return fmt.Errorf(ErrIssuerRequired)
} }
if len(config.SigningKey) < 16 {
return fmt.Errorf("signing key must be at least 16 bytes")
if len(config.SigningKey) < MinSigningKeyLength {
return fmt.Errorf(ErrSigningKeyTooShort, MinSigningKeyLength)
} }
return nil return nil
@ -237,18 +237,18 @@ func (s *STSService) validateConfig(config *STSConfig) error {
// createSessionStore creates a session store based on configuration // createSessionStore creates a session store based on configuration
func (s *STSService) createSessionStore(config *STSConfig) (SessionStore, error) { func (s *STSService) createSessionStore(config *STSConfig) (SessionStore, error) {
switch config.SessionStoreType { switch config.SessionStoreType {
case "", "memory":
case "", StoreTypeMemory:
return NewMemorySessionStore(), nil return NewMemorySessionStore(), nil
case "filer":
case StoreTypeFiler:
return NewFilerSessionStore(config.SessionStoreConfig) return NewFilerSessionStore(config.SessionStoreConfig)
default: default:
return nil, fmt.Errorf("unsupported session store type: %s", config.SessionStoreType)
return nil, fmt.Errorf(ErrUnsupportedStoreType, config.SessionStoreType)
} }
} }
// loadProvidersFromConfig loads identity providers from configuration // loadProvidersFromConfig loads identity providers from configuration
func (s *STSService) loadProvidersFromConfig(config *STSConfig) error { func (s *STSService) loadProvidersFromConfig(config *STSConfig) error {
if config.Providers == nil || len(config.Providers) == 0 {
if len(config.Providers) == 0 {
glog.V(2).Infof("No providers configured in STS config") glog.V(2).Infof("No providers configured in STS config")
return nil return nil
} }
@ -287,12 +287,12 @@ func (s *STSService) IsInitialized() bool {
// RegisterProvider registers an identity provider // RegisterProvider registers an identity provider
func (s *STSService) RegisterProvider(provider providers.IdentityProvider) error { func (s *STSService) RegisterProvider(provider providers.IdentityProvider) error {
if provider == nil { if provider == nil {
return fmt.Errorf("provider cannot be nil")
return fmt.Errorf(ErrProviderCannotBeNil)
} }
name := provider.Name() name := provider.Name()
if name == "" { if name == "" {
return fmt.Errorf("provider name cannot be empty")
return fmt.Errorf(ErrProviderNameEmpty)
} }
s.providers[name] = provider s.providers[name] = provider
@ -302,7 +302,7 @@ func (s *STSService) RegisterProvider(provider providers.IdentityProvider) error
// AssumeRoleWithWebIdentity assumes a role using a web identity token (OIDC) // AssumeRoleWithWebIdentity assumes a role using a web identity token (OIDC)
func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, request *AssumeRoleWithWebIdentityRequest) (*AssumeRoleResponse, error) { func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, request *AssumeRoleWithWebIdentityRequest) (*AssumeRoleResponse, error) {
if !s.initialized { if !s.initialized {
return nil, fmt.Errorf("STS service not initialized")
return nil, fmt.Errorf(ErrSTSServiceNotInitialized)
} }
if request == nil { if request == nil {
@ -467,23 +467,23 @@ func (s *STSService) AssumeRoleWithCredentials(ctx context.Context, request *Ass
// ValidateSessionToken validates a session token and returns session information // ValidateSessionToken validates a session token and returns session information
func (s *STSService) ValidateSessionToken(ctx context.Context, sessionToken string) (*SessionInfo, error) { func (s *STSService) ValidateSessionToken(ctx context.Context, sessionToken string) (*SessionInfo, error) {
if !s.initialized { if !s.initialized {
return nil, fmt.Errorf("STS service not initialized")
return nil, fmt.Errorf(ErrSTSServiceNotInitialized)
} }
if sessionToken == "" { if sessionToken == "" {
return nil, fmt.Errorf("session token cannot be empty")
return nil, fmt.Errorf(ErrSessionTokenCannotBeEmpty)
} }
// Use token generator for proper JWT validation // Use token generator for proper JWT validation
claims, err := s.tokenGenerator.ValidateSessionToken(sessionToken) claims, err := s.tokenGenerator.ValidateSessionToken(sessionToken)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid session token format: %w", err)
return nil, fmt.Errorf(ErrInvalidTokenFormat, err)
} }
// Retrieve session from store using session ID from claims // Retrieve session from store using session ID from claims
session, err := s.sessionStore.GetSession(ctx, claims.SessionId) session, err := s.sessionStore.GetSession(ctx, claims.SessionId)
if err != nil { if err != nil {
return nil, fmt.Errorf("session validation failed: %w", err)
return nil, fmt.Errorf(ErrSessionValidationFailed, err)
} }
// Additional validation - check expiration // Additional validation - check expiration

28
weed/iam/sts/token_utils.go

@ -28,11 +28,11 @@ func NewTokenGenerator(signingKey []byte, issuer string) *TokenGenerator {
// GenerateSessionToken creates a signed JWT session token // GenerateSessionToken creates a signed JWT session token
func (t *TokenGenerator) GenerateSessionToken(sessionId string, expiresAt time.Time) (string, error) { func (t *TokenGenerator) GenerateSessionToken(sessionId string, expiresAt time.Time) (string, error) {
claims := jwt.MapClaims{ claims := jwt.MapClaims{
"iss": t.issuer,
"sub": sessionId,
"iat": time.Now().Unix(),
"exp": expiresAt.Unix(),
"token_type": "session",
JWTClaimIssuer: t.issuer,
JWTClaimSubject: sessionId,
JWTClaimIssuedAt: time.Now().Unix(),
JWTClaimExpiration: expiresAt.Unix(),
JWTClaimTokenType: TokenTypeSession,
} }
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
@ -49,33 +49,33 @@ func (t *TokenGenerator) ValidateSessionToken(tokenString string) (*SessionToken
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid token: %w", err)
return nil, fmt.Errorf(ErrInvalidToken, err)
} }
if !token.Valid { if !token.Valid {
return nil, fmt.Errorf("token is not valid")
return nil, fmt.Errorf(ErrTokenNotValid)
} }
claims, ok := token.Claims.(jwt.MapClaims) claims, ok := token.Claims.(jwt.MapClaims)
if !ok { if !ok {
return nil, fmt.Errorf("invalid token claims")
return nil, fmt.Errorf(ErrInvalidTokenClaims)
} }
// Verify issuer // Verify issuer
if iss, ok := claims["iss"].(string); !ok || iss != t.issuer {
return nil, fmt.Errorf("invalid issuer")
if iss, ok := claims[JWTClaimIssuer].(string); !ok || iss != t.issuer {
return nil, fmt.Errorf(ErrInvalidIssuer)
} }
// Extract session ID // Extract session ID
sessionId, ok := claims["sub"].(string)
sessionId, ok := claims[JWTClaimSubject].(string)
if !ok { if !ok {
return nil, fmt.Errorf("missing session ID")
return nil, fmt.Errorf(ErrMissingSessionID)
} }
return &SessionTokenClaims{ return &SessionTokenClaims{
SessionId: sessionId, SessionId: sessionId,
ExpiresAt: time.Unix(int64(claims["exp"].(float64)), 0),
IssuedAt: time.Unix(int64(claims["iat"].(float64)), 0),
ExpiresAt: time.Unix(int64(claims[JWTClaimExpiration].(float64)), 0),
IssuedAt: time.Unix(int64(claims[JWTClaimIssuedAt].(float64)), 0),
}, nil }, nil
} }

Loading…
Cancel
Save