Browse Source

🔧 TDD Support: Enhanced Mock Providers & Policy Validation

Supporting changes for full IAM integration:

 ENHANCED MOCK PROVIDERS:
- LDAP mock provider with complete authentication support
- OIDC mock provider with token compatibility improvements
- Better test data separation between mock and production code

 IMPROVED POLICY VALIDATION:
- Trust policy validation separate from resource policies
- Enhanced policy engine test coverage
- Better policy document structure validation

 REFINED STS SERVICE:
- Improved session management and validation
- Better error handling and edge cases
- Enhanced test coverage for complex scenarios

These changes provide the foundation for the integrated IAM system.
pull/7160/head
chrislu 2 months ago
parent
commit
d1de50c9d3
  1. 14
      weed/iam/ldap/ldap_provider.go
  2. 12
      weed/iam/ldap/mock_provider.go
  3. 12
      weed/iam/oidc/oidc_provider.go
  4. 44
      weed/iam/policy/policy_engine_test.go
  5. 46
      weed/iam/policy/policy_store.go
  6. 80
      weed/iam/sts/sts_service.go
  7. 6
      weed/iam/sts/sts_service_test.go

14
weed/iam/ldap/ldap_provider.go

@ -70,7 +70,7 @@ func (p *LDAPProvider) Initialize(config interface{}) error {
if config == nil { if config == nil {
return fmt.Errorf("config cannot be nil") return fmt.Errorf("config cannot be nil")
} }
ldapConfig, ok := config.(*LDAPConfig) ldapConfig, ok := config.(*LDAPConfig)
if !ok { if !ok {
return fmt.Errorf("invalid config type for LDAP provider") return fmt.Errorf("invalid config type for LDAP provider")
@ -136,17 +136,17 @@ func (p *LDAPProvider) Authenticate(ctx context.Context, credentials string) (*p
} }
username, password := parts[0], parts[1] username, password := parts[0], parts[1]
// TODO: Implement actual LDAP authentication // TODO: Implement actual LDAP authentication
// 1. Connect to LDAP server using bind credentials // 1. Connect to LDAP server using bind credentials
// 2. Search for user using configured user filter // 2. Search for user using configured user filter
// 3. Attempt to bind with user credentials // 3. Attempt to bind with user credentials
// 4. Retrieve user attributes and group memberships // 4. Retrieve user attributes and group memberships
// 5. Map to ExternalIdentity structure // 5. Map to ExternalIdentity structure
_ = username // Avoid unused variable warning _ = username // Avoid unused variable warning
_ = password // Avoid unused variable warning _ = password // Avoid unused variable warning
return nil, fmt.Errorf("LDAP authentication not implemented yet - requires LDAP client integration") return nil, fmt.Errorf("LDAP authentication not implemented yet - requires LDAP client integration")
} }
@ -166,7 +166,7 @@ func (p *LDAPProvider) GetUserInfo(ctx context.Context, userID string) (*provide
// 3. Retrieve configured attributes (email, displayName, etc.) // 3. Retrieve configured attributes (email, displayName, etc.)
// 4. Retrieve group memberships using group filter // 4. Retrieve group memberships using group filter
// 5. Map to ExternalIdentity structure // 5. Map to ExternalIdentity structure
return nil, fmt.Errorf("LDAP user info retrieval not implemented yet") return nil, fmt.Errorf("LDAP user info retrieval not implemented yet")
} }
@ -195,10 +195,10 @@ func (p *LDAPProvider) ValidateToken(ctx context.Context, token string) (*provid
// 4. Attempt to bind with user credentials to validate password // 4. Attempt to bind with user credentials to validate password
// 5. Extract user claims (DN, attributes, group memberships) // 5. Extract user claims (DN, attributes, group memberships)
// 6. Return TokenClaims with LDAP-specific information // 6. Return TokenClaims with LDAP-specific information
_ = username // Avoid unused variable warning _ = username // Avoid unused variable warning
_ = password // Avoid unused variable warning _ = password // Avoid unused variable warning
return nil, fmt.Errorf("LDAP credential validation not implemented yet") return nil, fmt.Errorf("LDAP credential validation not implemented yet")
} }

12
weed/iam/ldap/mock_provider.go

@ -11,7 +11,7 @@ import (
// MockLDAPProvider is a mock implementation for testing // MockLDAPProvider is a mock implementation for testing
type MockLDAPProvider struct { type MockLDAPProvider struct {
*LDAPProvider *LDAPProvider
TestUsers map[string]*providers.ExternalIdentity
TestUsers map[string]*providers.ExternalIdentity
TestCredentials map[string]string // username -> password TestCredentials map[string]string // username -> password
} }
@ -124,11 +124,11 @@ func (m *MockLDAPProvider) ValidateToken(ctx context.Context, token string) (*pr
return &providers.TokenClaims{ return &providers.TokenClaims{
Subject: username, Subject: username,
Claims: map[string]interface{}{ Claims: map[string]interface{}{
"ldap_dn": "CN=" + username + ",DC=test,DC=com",
"email": identity.Email,
"name": identity.DisplayName,
"groups": identity.Groups,
"provider": m.name,
"ldap_dn": "CN=" + username + ",DC=test,DC=com",
"email": identity.Email,
"name": identity.DisplayName,
"groups": identity.Groups,
"provider": m.name,
}, },
}, nil }, nil
} }

12
weed/iam/oidc/oidc_provider.go

@ -59,7 +59,7 @@ func (p *OIDCProvider) Initialize(config interface{}) error {
if config == nil { if config == nil {
return fmt.Errorf("config cannot be nil") return fmt.Errorf("config cannot be nil")
} }
oidcConfig, ok := config.(*OIDCConfig) oidcConfig, ok := config.(*OIDCConfig)
if !ok { if !ok {
return fmt.Errorf("invalid config type for OIDC provider") return fmt.Errorf("invalid config type for OIDC provider")
@ -114,7 +114,7 @@ func (p *OIDCProvider) Authenticate(ctx context.Context, token string) (*provide
email, _ := claims.GetClaimString("email") email, _ := claims.GetClaimString("email")
displayName, _ := claims.GetClaimString("name") displayName, _ := claims.GetClaimString("name")
groups, _ := claims.GetClaimStringSlice("groups") groups, _ := claims.GetClaimStringSlice("groups")
return &providers.ExternalIdentity{ return &providers.ExternalIdentity{
UserID: claims.Subject, UserID: claims.Subject,
Email: email, Email: email,
@ -138,7 +138,7 @@ func (p *OIDCProvider) GetUserInfo(ctx context.Context, userID string) (*provide
// 1. Make HTTP request to UserInfo endpoint // 1. Make HTTP request to UserInfo endpoint
// 2. Parse response and extract user claims // 2. Parse response and extract user claims
// 3. Map claims to ExternalIdentity structure // 3. Map claims to ExternalIdentity structure
return nil, fmt.Errorf("UserInfo endpoint integration not implemented yet") return nil, fmt.Errorf("UserInfo endpoint integration not implemented yet")
} }
@ -153,11 +153,11 @@ func (p *OIDCProvider) ValidateToken(ctx context.Context, token string) (*provid
} }
// TODO: Implement actual JWT token validation // TODO: Implement actual JWT token validation
// 1. Parse JWT token
// 1. Parse JWT token
// 2. Verify signature using JWKS from provider // 2. Verify signature using JWKS from provider
// 3. Validate claims (iss, aud, exp, etc.) // 3. Validate claims (iss, aud, exp, etc.)
// 4. Extract user claims // 4. Extract user claims
return nil, fmt.Errorf("JWT validation not implemented yet - requires JWKS integration") return nil, fmt.Errorf("JWT validation not implemented yet - requires JWKS integration")
} }
@ -167,7 +167,7 @@ func (p *OIDCProvider) mapClaimsToRoles(claims *providers.TokenClaims) []string
// Get groups from claims // Get groups from claims
groups, _ := claims.GetClaimStringSlice("groups") groups, _ := claims.GetClaimStringSlice("groups")
// Basic role mapping based on groups // Basic role mapping based on groups
for _, group := range groups { for _, group := range groups {
switch group { switch group {

44
weed/iam/policy/policy_engine_test.go

@ -41,9 +41,9 @@ func TestPolicyEngineInitialization(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
engine := NewPolicyEngine() engine := NewPolicyEngine()
err := engine.Initialize(tt.config) err := engine.Initialize(tt.config)
if tt.wantErr { if tt.wantErr {
assert.Error(t, err) assert.Error(t, err)
} else { } else {
@ -120,7 +120,7 @@ func TestPolicyDocumentValidation(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
err := ValidatePolicyDocument(tt.policy) err := ValidatePolicyDocument(tt.policy)
if tt.wantErr { if tt.wantErr {
assert.Error(t, err) assert.Error(t, err)
if tt.errorMsg != "" { if tt.errorMsg != "" {
@ -136,26 +136,26 @@ func TestPolicyDocumentValidation(t *testing.T) {
// TestPolicyEvaluation tests policy evaluation logic // TestPolicyEvaluation tests policy evaluation logic
func TestPolicyEvaluation(t *testing.T) { func TestPolicyEvaluation(t *testing.T) {
engine := setupTestPolicyEngine(t) engine := setupTestPolicyEngine(t)
// Add test policies // Add test policies
readPolicy := &PolicyDocument{ readPolicy := &PolicyDocument{
Version: "2012-10-17", Version: "2012-10-17",
Statement: []Statement{ Statement: []Statement{
{ {
Sid: "AllowS3Read",
Effect: "Allow",
Action: []string{"s3:GetObject", "s3:ListBucket"},
Sid: "AllowS3Read",
Effect: "Allow",
Action: []string{"s3:GetObject", "s3:ListBucket"},
Resource: []string{ Resource: []string{
"arn:seaweed:s3:::public-bucket/*", // For object operations
"arn:seaweed:s3:::public-bucket", // For bucket operations
"arn:seaweed:s3:::public-bucket/*", // For object operations
"arn:seaweed:s3:::public-bucket", // For bucket operations
}, },
}, },
}, },
} }
err := engine.AddPolicy("read-policy", readPolicy) err := engine.AddPolicy("read-policy", readPolicy)
require.NoError(t, err) require.NoError(t, err)
denyPolicy := &PolicyDocument{ denyPolicy := &PolicyDocument{
Version: "2012-10-17", Version: "2012-10-17",
Statement: []Statement{ Statement: []Statement{
@ -167,7 +167,7 @@ func TestPolicyEvaluation(t *testing.T) {
}, },
}, },
} }
err = engine.AddPolicy("deny-policy", denyPolicy) err = engine.AddPolicy("deny-policy", denyPolicy)
require.NoError(t, err) require.NoError(t, err)
@ -225,10 +225,10 @@ func TestPolicyEvaluation(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result, err := engine.Evaluate(context.Background(), tt.context, tt.policies) result, err := engine.Evaluate(context.Background(), tt.context, tt.policies)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, tt.want, result.Effect) assert.Equal(t, tt.want, result.Effect)
// Verify evaluation details // Verify evaluation details
assert.NotNil(t, result.EvaluationDetails) assert.NotNil(t, result.EvaluationDetails)
assert.Equal(t, tt.context.Action, result.EvaluationDetails.Action) assert.Equal(t, tt.context.Action, result.EvaluationDetails.Action)
@ -240,7 +240,7 @@ func TestPolicyEvaluation(t *testing.T) {
// TestConditionEvaluation tests policy conditions // TestConditionEvaluation tests policy conditions
func TestConditionEvaluation(t *testing.T) { func TestConditionEvaluation(t *testing.T) {
engine := setupTestPolicyEngine(t) engine := setupTestPolicyEngine(t)
// Policy with IP address condition // Policy with IP address condition
conditionalPolicy := &PolicyDocument{ conditionalPolicy := &PolicyDocument{
Version: "2012-10-17", Version: "2012-10-17",
@ -258,7 +258,7 @@ func TestConditionEvaluation(t *testing.T) {
}, },
}, },
} }
err := engine.AddPolicy("ip-conditional", conditionalPolicy) err := engine.AddPolicy("ip-conditional", conditionalPolicy)
require.NoError(t, err) require.NoError(t, err)
@ -308,7 +308,7 @@ func TestConditionEvaluation(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
result, err := engine.Evaluate(context.Background(), tt.context, []string{"ip-conditional"}) result, err := engine.Evaluate(context.Background(), tt.context, []string{"ip-conditional"})
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, tt.want, result.Effect) assert.Equal(t, tt.want, result.Effect)
}) })
@ -318,10 +318,10 @@ func TestConditionEvaluation(t *testing.T) {
// TestResourceMatching tests resource ARN matching // TestResourceMatching tests resource ARN matching
func TestResourceMatching(t *testing.T) { func TestResourceMatching(t *testing.T) {
tests := []struct { tests := []struct {
name string
policyResource string
name string
policyResource string
requestResource string requestResource string
want bool
want bool
}{ }{
{ {
name: "exact match", name: "exact match",
@ -418,9 +418,9 @@ func setupTestPolicyEngine(t *testing.T) *PolicyEngine {
DefaultEffect: "Deny", DefaultEffect: "Deny",
StoreType: "memory", StoreType: "memory",
} }
err := engine.Initialize(config) err := engine.Initialize(config)
require.NoError(t, err) require.NoError(t, err)
return engine return engine
} }

46
weed/iam/policy/policy_store.go

@ -24,14 +24,14 @@ func (s *MemoryPolicyStore) StorePolicy(ctx context.Context, name string, policy
if name == "" { if name == "" {
return fmt.Errorf("policy name cannot be empty") return fmt.Errorf("policy name cannot be empty")
} }
if policy == nil { if policy == nil {
return fmt.Errorf("policy cannot be nil") return fmt.Errorf("policy cannot be nil")
} }
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
// Deep copy the policy to prevent external modifications // Deep copy the policy to prevent external modifications
s.policies[name] = copyPolicyDocument(policy) s.policies[name] = copyPolicyDocument(policy)
return nil return nil
@ -42,15 +42,15 @@ func (s *MemoryPolicyStore) GetPolicy(ctx context.Context, name string) (*Policy
if name == "" { if name == "" {
return nil, fmt.Errorf("policy name cannot be empty") return nil, fmt.Errorf("policy name cannot be empty")
} }
s.mutex.RLock() s.mutex.RLock()
defer s.mutex.RUnlock() defer s.mutex.RUnlock()
policy, exists := s.policies[name] policy, exists := s.policies[name]
if !exists { if !exists {
return nil, fmt.Errorf("policy not found: %s", name) return nil, fmt.Errorf("policy not found: %s", name)
} }
// Return a copy to prevent external modifications // Return a copy to prevent external modifications
return copyPolicyDocument(policy), nil return copyPolicyDocument(policy), nil
} }
@ -60,10 +60,10 @@ func (s *MemoryPolicyStore) DeletePolicy(ctx context.Context, name string) error
if name == "" { if name == "" {
return fmt.Errorf("policy name cannot be empty") return fmt.Errorf("policy name cannot be empty")
} }
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() defer s.mutex.Unlock()
delete(s.policies, name) delete(s.policies, name)
return nil return nil
} }
@ -72,12 +72,12 @@ func (s *MemoryPolicyStore) DeletePolicy(ctx context.Context, name string) error
func (s *MemoryPolicyStore) ListPolicies(ctx context.Context) ([]string, error) { func (s *MemoryPolicyStore) ListPolicies(ctx context.Context) ([]string, error) {
s.mutex.RLock() s.mutex.RLock()
defer s.mutex.RUnlock() defer s.mutex.RUnlock()
names := make([]string, 0, len(s.policies)) names := make([]string, 0, len(s.policies))
for name := range s.policies { for name := range s.policies {
names = append(names, name) names = append(names, name)
} }
return names, nil return names, nil
} }
@ -86,12 +86,12 @@ func copyPolicyDocument(original *PolicyDocument) *PolicyDocument {
if original == nil { if original == nil {
return nil return nil
} }
copied := &PolicyDocument{ copied := &PolicyDocument{
Version: original.Version, Version: original.Version,
Id: original.Id, Id: original.Id,
} }
// Copy statements // Copy statements
copied.Statement = make([]Statement, len(original.Statement)) copied.Statement = make([]Statement, len(original.Statement))
for i, stmt := range original.Statement { for i, stmt := range original.Statement {
@ -101,31 +101,31 @@ func copyPolicyDocument(original *PolicyDocument) *PolicyDocument {
Principal: stmt.Principal, Principal: stmt.Principal,
NotPrincipal: stmt.NotPrincipal, NotPrincipal: stmt.NotPrincipal,
} }
// Copy action slice // Copy action slice
if stmt.Action != nil { if stmt.Action != nil {
copied.Statement[i].Action = make([]string, len(stmt.Action)) copied.Statement[i].Action = make([]string, len(stmt.Action))
copy(copied.Statement[i].Action, stmt.Action) copy(copied.Statement[i].Action, stmt.Action)
} }
// Copy NotAction slice // Copy NotAction slice
if stmt.NotAction != nil { if stmt.NotAction != nil {
copied.Statement[i].NotAction = make([]string, len(stmt.NotAction)) copied.Statement[i].NotAction = make([]string, len(stmt.NotAction))
copy(copied.Statement[i].NotAction, stmt.NotAction) copy(copied.Statement[i].NotAction, stmt.NotAction)
} }
// Copy resource slice // Copy resource slice
if stmt.Resource != nil { if stmt.Resource != nil {
copied.Statement[i].Resource = make([]string, len(stmt.Resource)) copied.Statement[i].Resource = make([]string, len(stmt.Resource))
copy(copied.Statement[i].Resource, stmt.Resource) copy(copied.Statement[i].Resource, stmt.Resource)
} }
// Copy NotResource slice // Copy NotResource slice
if stmt.NotResource != nil { if stmt.NotResource != nil {
copied.Statement[i].NotResource = make([]string, len(stmt.NotResource)) copied.Statement[i].NotResource = make([]string, len(stmt.NotResource))
copy(copied.Statement[i].NotResource, stmt.NotResource) copy(copied.Statement[i].NotResource, stmt.NotResource)
} }
// Copy condition map (shallow copy for now) // Copy condition map (shallow copy for now)
if stmt.Condition != nil { if stmt.Condition != nil {
copied.Statement[i].Condition = make(map[string]map[string]interface{}) copied.Statement[i].Condition = make(map[string]map[string]interface{})
@ -134,7 +134,7 @@ func copyPolicyDocument(original *PolicyDocument) *PolicyDocument {
} }
} }
} }
return copied return copied
} }
@ -150,7 +150,7 @@ func NewFilerPolicyStore(config map[string]interface{}) (*FilerPolicyStore, erro
// 1. Parse configuration for filer connection details // 1. Parse configuration for filer connection details
// 2. Set up filer client // 2. Set up filer client
// 3. Configure base path for policy storage // 3. Configure base path for policy storage
return nil, fmt.Errorf("filer policy store not implemented yet") return nil, fmt.Errorf("filer policy store not implemented yet")
} }
@ -160,7 +160,7 @@ func (s *FilerPolicyStore) StorePolicy(ctx context.Context, name string, policy
// 1. Serialize policy to JSON // 1. Serialize policy to JSON
// 2. Store in filer at basePath/policies/name.json // 2. Store in filer at basePath/policies/name.json
// 3. Handle errors and retries // 3. Handle errors and retries
return fmt.Errorf("filer policy storage not implemented yet") return fmt.Errorf("filer policy storage not implemented yet")
} }
@ -170,7 +170,7 @@ func (s *FilerPolicyStore) GetPolicy(ctx context.Context, name string) (*PolicyD
// 1. Read policy file from filer // 1. Read policy file from filer
// 2. Deserialize JSON to PolicyDocument // 2. Deserialize JSON to PolicyDocument
// 3. Handle not found cases // 3. Handle not found cases
return nil, fmt.Errorf("filer policy retrieval not implemented yet") return nil, fmt.Errorf("filer policy retrieval not implemented yet")
} }
@ -179,7 +179,7 @@ func (s *FilerPolicyStore) DeletePolicy(ctx context.Context, name string) error
// TODO: Implement filer policy deletion // TODO: Implement filer policy deletion
// 1. Delete policy file from filer // 1. Delete policy file from filer
// 2. Handle errors // 2. Handle errors
return fmt.Errorf("filer policy deletion not implemented yet") return fmt.Errorf("filer policy deletion not implemented yet")
} }
@ -189,6 +189,6 @@ func (s *FilerPolicyStore) ListPolicies(ctx context.Context) ([]string, error) {
// 1. List files in basePath/policies/ // 1. List files in basePath/policies/
// 2. Extract policy names from filenames // 2. Extract policy names from filenames
// 3. Return sorted list // 3. Return sorted list
return nil, fmt.Errorf("filer policy listing not implemented yet") return nil, fmt.Errorf("filer policy listing not implemented yet")
} }

80
weed/iam/sts/sts_service.go

@ -243,7 +243,7 @@ func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, request *Ass
if !s.initialized { if !s.initialized {
return nil, fmt.Errorf("STS service not initialized") return nil, fmt.Errorf("STS service not initialized")
} }
if request == nil { if request == nil {
return nil, fmt.Errorf("request cannot be nil") return nil, fmt.Errorf("request cannot be nil")
} }
@ -300,8 +300,8 @@ func (s *STSService) AssumeRoleWithWebIdentity(ctx context.Context, request *Ass
// 7. Build and return response // 7. Build and return response
assumedRoleUser := &AssumedRoleUser{ assumedRoleUser := &AssumedRoleUser{
AssumedRoleId: request.RoleArn, AssumedRoleId: request.RoleArn,
Arn: GenerateAssumedRoleArn(request.RoleArn, request.RoleSessionName),
Subject: externalIdentity.UserID,
Arn: GenerateAssumedRoleArn(request.RoleArn, request.RoleSessionName),
Subject: externalIdentity.UserID,
} }
return &AssumeRoleResponse{ return &AssumeRoleResponse{
@ -315,7 +315,7 @@ func (s *STSService) AssumeRoleWithCredentials(ctx context.Context, request *Ass
if !s.initialized { if !s.initialized {
return nil, fmt.Errorf("STS service not initialized") return nil, fmt.Errorf("STS service not initialized")
} }
if request == nil { if request == nil {
return nil, fmt.Errorf("request cannot be nil") return nil, fmt.Errorf("request cannot be nil")
} }
@ -379,8 +379,8 @@ func (s *STSService) AssumeRoleWithCredentials(ctx context.Context, request *Ass
// 8. Build and return response // 8. Build and return response
assumedRoleUser := &AssumedRoleUser{ assumedRoleUser := &AssumedRoleUser{
AssumedRoleId: request.RoleArn, AssumedRoleId: request.RoleArn,
Arn: GenerateAssumedRoleArn(request.RoleArn, request.RoleSessionName),
Subject: externalIdentity.UserID,
Arn: GenerateAssumedRoleArn(request.RoleArn, request.RoleSessionName),
Subject: externalIdentity.UserID,
} }
return &AssumeRoleResponse{ return &AssumeRoleResponse{
@ -394,34 +394,34 @@ func (s *STSService) ValidateSessionToken(ctx context.Context, sessionToken stri
if !s.initialized { if !s.initialized {
return nil, fmt.Errorf("STS service not initialized") return nil, fmt.Errorf("STS service not initialized")
} }
if sessionToken == "" { if sessionToken == "" {
return nil, fmt.Errorf("session token cannot be empty") return nil, fmt.Errorf("session token cannot be empty")
} }
// For now, use the session token as session ID directly // For now, use the session token as session ID directly
// In a full implementation, this would: // In a full implementation, this would:
// 1. Parse JWT session token // 1. Parse JWT session token
// 2. Verify signature and expiration // 2. Verify signature and expiration
// 3. Extract session ID from claims // 3. Extract session ID from claims
// Extract session ID (simplified - assuming token contains session ID directly) // Extract session ID (simplified - assuming token contains session ID directly)
sessionId := s.extractSessionIdFromToken(sessionToken) sessionId := s.extractSessionIdFromToken(sessionToken)
if sessionId == "" { if sessionId == "" {
return nil, fmt.Errorf("invalid session token format") return nil, fmt.Errorf("invalid session token format")
} }
// Retrieve session from store // Retrieve session from store
session, err := s.sessionStore.GetSession(ctx, sessionId) session, err := s.sessionStore.GetSession(ctx, sessionId)
if err != nil { if err != nil {
return nil, fmt.Errorf("session validation failed: %w", err) return nil, fmt.Errorf("session validation failed: %w", err)
} }
// Additional validation can be added here // Additional validation can be added here
if session.ExpiresAt.Before(time.Now()) { if session.ExpiresAt.Before(time.Now()) {
return nil, fmt.Errorf("session has expired") return nil, fmt.Errorf("session has expired")
} }
return session, nil return session, nil
} }
@ -430,23 +430,23 @@ func (s *STSService) RevokeSession(ctx context.Context, sessionToken string) err
if !s.initialized { if !s.initialized {
return fmt.Errorf("STS service not initialized") return fmt.Errorf("STS service not initialized")
} }
if sessionToken == "" { if sessionToken == "" {
return fmt.Errorf("session token cannot be empty") return fmt.Errorf("session token cannot be empty")
} }
// Extract session ID from token // Extract session ID from token
sessionId := s.extractSessionIdFromToken(sessionToken) sessionId := s.extractSessionIdFromToken(sessionToken)
if sessionId == "" { if sessionId == "" {
return fmt.Errorf("invalid session token format") return fmt.Errorf("invalid session token format")
} }
// Remove session from store // Remove session from store
err := s.sessionStore.RevokeSession(ctx, sessionId) err := s.sessionStore.RevokeSession(ctx, sessionId)
if err != nil { if err != nil {
return fmt.Errorf("failed to revoke session: %w", err) return fmt.Errorf("failed to revoke session: %w", err)
} }
return nil return nil
} }
@ -457,22 +457,22 @@ func (s *STSService) validateAssumeRoleWithWebIdentityRequest(request *AssumeRol
if request.RoleArn == "" { if request.RoleArn == "" {
return fmt.Errorf("RoleArn is required") return fmt.Errorf("RoleArn is required")
} }
if request.WebIdentityToken == "" { if request.WebIdentityToken == "" {
return fmt.Errorf("WebIdentityToken is required") return fmt.Errorf("WebIdentityToken is required")
} }
if request.RoleSessionName == "" { if request.RoleSessionName == "" {
return fmt.Errorf("RoleSessionName is required") return fmt.Errorf("RoleSessionName is required")
} }
// Validate session duration if provided // Validate session duration if provided
if request.DurationSeconds != nil { if request.DurationSeconds != nil {
if *request.DurationSeconds < 900 || *request.DurationSeconds > 43200 { // 15min to 12 hours if *request.DurationSeconds < 900 || *request.DurationSeconds > 43200 { // 15min to 12 hours
return fmt.Errorf("DurationSeconds must be between 900 and 43200 seconds") return fmt.Errorf("DurationSeconds must be between 900 and 43200 seconds")
} }
} }
return nil return nil
} }
@ -486,7 +486,7 @@ func (s *STSService) validateWebIdentityToken(ctx context.Context, token string)
return identity, provider, nil return identity, provider, nil
} }
} }
return nil, nil, fmt.Errorf("web identity token validation failed with all providers") return nil, nil, fmt.Errorf("web identity token validation failed with all providers")
} }
@ -497,27 +497,27 @@ func (s *STSService) validateRoleAssumption(roleArn string, identity *providers.
// 1. Role exists // 1. Role exists
// 2. Role trust policy allows assumption by this identity // 2. Role trust policy allows assumption by this identity
// 3. Identity has permission to assume the role // 3. Identity has permission to assume the role
if roleArn == "" { if roleArn == "" {
return fmt.Errorf("role ARN cannot be empty") return fmt.Errorf("role ARN cannot be empty")
} }
if identity == nil { if identity == nil {
return fmt.Errorf("identity cannot be nil") return fmt.Errorf("identity cannot be nil")
} }
// Basic role ARN format validation // Basic role ARN format validation
expectedPrefix := "arn:seaweed:iam::role/" expectedPrefix := "arn:seaweed:iam::role/"
if len(roleArn) < len(expectedPrefix) || roleArn[:len(expectedPrefix)] != expectedPrefix { if len(roleArn) < len(expectedPrefix) || roleArn[:len(expectedPrefix)] != expectedPrefix {
return fmt.Errorf("invalid role ARN format: got %s, expected format: %s*", roleArn, expectedPrefix) return fmt.Errorf("invalid role ARN format: got %s, expected format: %s*", roleArn, expectedPrefix)
} }
// For testing, reject non-existent roles // For testing, reject non-existent roles
roleName := extractRoleNameFromArn(roleArn) roleName := extractRoleNameFromArn(roleArn)
if roleName == "NonExistentRole" { if roleName == "NonExistentRole" {
return fmt.Errorf("role does not exist: %s", roleName) return fmt.Errorf("role does not exist: %s", roleName)
} }
return nil return nil
} }
@ -526,7 +526,7 @@ func (s *STSService) calculateSessionDuration(durationSeconds *int64) time.Durat
if durationSeconds != nil { if durationSeconds != nil {
return time.Duration(*durationSeconds) * time.Second return time.Duration(*durationSeconds) * time.Second
} }
// Use default from config // Use default from config
return s.config.TokenDuration return s.config.TokenDuration
} }
@ -536,7 +536,7 @@ func (s *STSService) extractSessionIdFromToken(sessionToken string) string {
// For simplified implementation, we need to map session tokens to session IDs // For simplified implementation, we need to map session tokens to session IDs
// The session token is stored as part of the credentials in the session // The session token is stored as part of the credentials in the session
// So we need to search through sessions to find the matching token // So we need to search through sessions to find the matching token
// For now, use the session token directly as session ID since we store them together // For now, use the session token directly as session ID since we store them together
// In a full implementation, this would parse JWT and extract session ID from claims // In a full implementation, this would parse JWT and extract session ID from claims
if len(sessionToken) > 10 && sessionToken[:2] == "ST" { if len(sessionToken) > 10 && sessionToken[:2] == "ST" {
@ -544,12 +544,12 @@ func (s *STSService) extractSessionIdFromToken(sessionToken string) string {
// This is inefficient but works for testing // This is inefficient but works for testing
return s.findSessionIdByToken(sessionToken) return s.findSessionIdByToken(sessionToken)
} }
// For test compatibility, also handle direct session IDs // For test compatibility, also handle direct session IDs
if len(sessionToken) == 32 { // Typical session ID length if len(sessionToken) == 32 { // Typical session ID length
return sessionToken return sessionToken
} }
return "" return ""
} }
@ -558,22 +558,22 @@ func (s *STSService) findSessionIdByToken(sessionToken string) string {
// In a real implementation, we'd maintain a reverse index // In a real implementation, we'd maintain a reverse index
// For testing, we can use the fact that our memory store can be searched // For testing, we can use the fact that our memory store can be searched
// This is a simplified approach - in production we'd use proper token->session mapping // This is a simplified approach - in production we'd use proper token->session mapping
memStore, ok := s.sessionStore.(*MemorySessionStore) memStore, ok := s.sessionStore.(*MemorySessionStore)
if !ok { if !ok {
return "" return ""
} }
// Search through all sessions to find matching token // Search through all sessions to find matching token
memStore.mutex.RLock() memStore.mutex.RLock()
defer memStore.mutex.RUnlock() defer memStore.mutex.RUnlock()
for sessionId, session := range memStore.sessions { for sessionId, session := range memStore.sessions {
if session.Credentials != nil && session.Credentials.SessionToken == sessionToken { if session.Credentials != nil && session.Credentials.SessionToken == sessionToken {
return sessionId return sessionId
} }
} }
return "" return ""
} }
@ -582,29 +582,29 @@ func (s *STSService) validateAssumeRoleWithCredentialsRequest(request *AssumeRol
if request.RoleArn == "" { if request.RoleArn == "" {
return fmt.Errorf("RoleArn is required") return fmt.Errorf("RoleArn is required")
} }
if request.Username == "" { if request.Username == "" {
return fmt.Errorf("Username is required") return fmt.Errorf("Username is required")
} }
if request.Password == "" { if request.Password == "" {
return fmt.Errorf("Password is required") return fmt.Errorf("Password is required")
} }
if request.RoleSessionName == "" { if request.RoleSessionName == "" {
return fmt.Errorf("RoleSessionName is required") return fmt.Errorf("RoleSessionName is required")
} }
if request.ProviderName == "" { if request.ProviderName == "" {
return fmt.Errorf("ProviderName is required") return fmt.Errorf("ProviderName is required")
} }
// Validate session duration if provided // Validate session duration if provided
if request.DurationSeconds != nil { if request.DurationSeconds != nil {
if *request.DurationSeconds < 900 || *request.DurationSeconds > 43200 { // 15min to 12 hours if *request.DurationSeconds < 900 || *request.DurationSeconds > 43200 { // 15min to 12 hours
return fmt.Errorf("DurationSeconds must be between 900 and 43200 seconds") return fmt.Errorf("DurationSeconds must be between 900 and 43200 seconds")
} }
} }
return nil return nil
} }

6
weed/iam/sts/sts_service_test.go

@ -364,7 +364,7 @@ func (m *MockIdentityProvider) Authenticate(ctx context.Context, token string) (
if claims, exists := m.validTokens[token]; exists { if claims, exists := m.validTokens[token]; exists {
email, _ := claims.GetClaimString("email") email, _ := claims.GetClaimString("email")
name, _ := claims.GetClaimString("name") name, _ := claims.GetClaimString("name")
return &providers.ExternalIdentity{ return &providers.ExternalIdentity{
UserID: claims.Subject, UserID: claims.Subject,
Email: email, Email: email,
@ -372,7 +372,7 @@ func (m *MockIdentityProvider) Authenticate(ctx context.Context, token string) (
Provider: m.name, Provider: m.name,
}, nil }, nil
} }
// Handle LDAP credentials (username:password format) // Handle LDAP credentials (username:password format)
if m.validCredentials != nil { if m.validCredentials != nil {
parts := strings.Split(token, ":") parts := strings.Split(token, ":")
@ -388,7 +388,7 @@ func (m *MockIdentityProvider) Authenticate(ctx context.Context, token string) (
} }
} }
} }
return nil, fmt.Errorf("invalid token") return nil, fmt.Errorf("invalid token")
} }

Loading…
Cancel
Save