Browse Source

pass in filerAddressProvider

pull/7160/head
chrislu 1 month ago
parent
commit
db0a9bd031
  1. 4
      weed/iam/integration/iam_integration_test.go
  2. 26
      weed/iam/integration/iam_manager.go
  3. 4
      weed/iam/integration/role_store_test.go
  4. 4
      weed/s3api/s3_end_to_end_test.go
  5. 4
      weed/s3api/s3_iam_simple_test.go
  6. 8
      weed/s3api/s3api_server.go

4
weed/iam/integration/iam_integration_test.go

@ -364,7 +364,9 @@ func setupIntegratedIAMSystem(t *testing.T) *IAMManager {
}, },
} }
err := manager.Initialize(config)
err := manager.Initialize(config, func() string {
return "localhost:8888" // Mock filer address for testing
})
require.NoError(t, err) require.NoError(t, err)
// Set up test providers // Set up test providers

26
weed/iam/integration/iam_manager.go

@ -18,6 +18,7 @@ type IAMManager struct {
stsService *sts.STSService stsService *sts.STSService
policyEngine *policy.PolicyEngine policyEngine *policy.PolicyEngine
roleStore RoleStore roleStore RoleStore
filerAddressProvider func() string // Function to get current filer address
initialized bool initialized bool
} }
@ -84,11 +85,14 @@ func NewIAMManager() *IAMManager {
} }
// Initialize initializes the IAM manager with all components // Initialize initializes the IAM manager with all components
func (m *IAMManager) Initialize(config *IAMConfig) error {
func (m *IAMManager) Initialize(config *IAMConfig, filerAddressProvider func() string) error {
if config == nil { if config == nil {
return fmt.Errorf("config cannot be nil") return fmt.Errorf("config cannot be nil")
} }
// Store the filer address provider function
m.filerAddressProvider = filerAddressProvider
// Initialize STS service // Initialize STS service
m.stsService = sts.NewSTSService() m.stsService = sts.NewSTSService()
if err := m.stsService.Initialize(config.STS); err != nil { if err := m.stsService.Initialize(config.STS); err != nil {
@ -115,6 +119,14 @@ func (m *IAMManager) Initialize(config *IAMConfig) error {
return nil return nil
} }
// getFilerAddress returns the current filer address using the provider function
func (m *IAMManager) getFilerAddress() string {
if m.filerAddressProvider != nil {
return m.filerAddressProvider()
}
return "" // Fallback to empty string if no provider is set
}
// createRoleStore creates a role store based on configuration // createRoleStore creates a role store based on configuration
func (m *IAMManager) createRoleStore(config *RoleStoreConfig) (RoleStore, error) { func (m *IAMManager) createRoleStore(config *RoleStoreConfig) (RoleStore, error) {
if config == nil { if config == nil {
@ -190,7 +202,7 @@ func (m *IAMManager) AssumeRoleWithWebIdentity(ctx context.Context, request *sts
roleName := extractRoleNameFromArn(request.RoleArn) roleName := extractRoleNameFromArn(request.RoleArn)
// Get role definition // Get role definition
roleDef, err := m.roleStore.GetRole(ctx, "", roleName)
roleDef, err := m.roleStore.GetRole(ctx, m.getFilerAddress(), roleName)
if err != nil { if err != nil {
return nil, fmt.Errorf("role not found: %s", roleName) return nil, fmt.Errorf("role not found: %s", roleName)
} }
@ -214,7 +226,7 @@ func (m *IAMManager) AssumeRoleWithCredentials(ctx context.Context, request *sts
roleName := extractRoleNameFromArn(request.RoleArn) roleName := extractRoleNameFromArn(request.RoleArn)
// Get role definition // Get role definition
roleDef, err := m.roleStore.GetRole(ctx, "", roleName)
roleDef, err := m.roleStore.GetRole(ctx, m.getFilerAddress(), roleName)
if err != nil { if err != nil {
return nil, fmt.Errorf("role not found: %s", roleName) return nil, fmt.Errorf("role not found: %s", roleName)
} }
@ -249,7 +261,7 @@ func (m *IAMManager) IsActionAllowed(ctx context.Context, request *ActionRequest
} }
// Get role definition // Get role definition
roleDef, err := m.roleStore.GetRole(ctx, "", roleName)
roleDef, err := m.roleStore.GetRole(ctx, m.getFilerAddress(), roleName)
if err != nil { if err != nil {
return false, fmt.Errorf("role not found: %s", roleName) return false, fmt.Errorf("role not found: %s", roleName)
} }
@ -274,7 +286,7 @@ func (m *IAMManager) IsActionAllowed(ctx context.Context, request *ActionRequest
// ValidateTrustPolicy validates if a principal can assume a role (for testing) // ValidateTrustPolicy validates if a principal can assume a role (for testing)
func (m *IAMManager) ValidateTrustPolicy(ctx context.Context, roleArn, provider, userID string) bool { func (m *IAMManager) ValidateTrustPolicy(ctx context.Context, roleArn, provider, userID string) bool {
roleName := extractRoleNameFromArn(roleArn) roleName := extractRoleNameFromArn(roleArn)
roleDef, err := m.roleStore.GetRole(ctx, "", roleName)
roleDef, err := m.roleStore.GetRole(ctx, m.getFilerAddress(), roleName)
if err != nil { if err != nil {
return false return false
} }
@ -542,7 +554,7 @@ func (m *IAMManager) ValidateTrustPolicyForWebIdentity(ctx context.Context, role
roleName := extractRoleNameFromArn(roleArn) roleName := extractRoleNameFromArn(roleArn)
// Get role definition // Get role definition
roleDef, err := m.roleStore.GetRole(ctx, "", roleName)
roleDef, err := m.roleStore.GetRole(ctx, m.getFilerAddress(), roleName)
if err != nil { if err != nil {
return fmt.Errorf("role not found: %s", roleName) return fmt.Errorf("role not found: %s", roleName)
} }
@ -561,7 +573,7 @@ func (m *IAMManager) ValidateTrustPolicyForCredentials(ctx context.Context, role
roleName := extractRoleNameFromArn(roleArn) roleName := extractRoleNameFromArn(roleArn)
// Get role definition // Get role definition
roleDef, err := m.roleStore.GetRole(ctx, "", roleName)
roleDef, err := m.roleStore.GetRole(ctx, m.getFilerAddress(), roleName)
if err != nil { if err != nil {
return fmt.Errorf("role not found: %s", roleName) return fmt.Errorf("role not found: %s", roleName)
} }

4
weed/iam/integration/role_store_test.go

@ -103,7 +103,9 @@ func TestDistributedIAMManagerWithRoleStore(t *testing.T) {
} }
iamManager := NewIAMManager() iamManager := NewIAMManager()
err := iamManager.Initialize(config)
err := iamManager.Initialize(config, func() string {
return "localhost:8888" // Mock filer address for testing
})
require.NoError(t, err) require.NoError(t, err)
// Test creating a role // Test creating a role

4
weed/s3api/s3_end_to_end_test.go

@ -296,7 +296,9 @@ func setupCompleteS3IAMSystem(t *testing.T) (http.Handler, *integration.IAMManag
}, },
} }
err := iamManager.Initialize(config)
err := iamManager.Initialize(config, func() string {
return "localhost:8888" // Mock filer address for testing
})
require.NoError(t, err) require.NoError(t, err)
// Set up test identity providers // Set up test identity providers

4
weed/s3api/s3_iam_simple_test.go

@ -39,7 +39,9 @@ func TestS3IAMMiddleware(t *testing.T) {
}, },
} }
err := iamManager.Initialize(config)
err := iamManager.Initialize(config, func() string {
return "localhost:8888" // Mock filer address for testing
})
require.NoError(t, err) require.NoError(t, err)
// Create S3 IAM integration // Create S3 IAM integration

8
weed/s3api/s3api_server.go

@ -102,7 +102,9 @@ func NewS3ApiServerWithStore(router *mux.Router, option *S3ApiServerOption, expl
if option.IamConfig != "" { if option.IamConfig != "" {
glog.V(0).Infof("Loading advanced IAM configuration from: %s", option.IamConfig) glog.V(0).Infof("Loading advanced IAM configuration from: %s", option.IamConfig)
iamManager, err := loadIAMManagerFromConfig(option.IamConfig)
iamManager, err := loadIAMManagerFromConfig(option.IamConfig, func() string {
return string(option.Filer)
})
if err != nil { if err != nil {
glog.Errorf("Failed to load IAM configuration: %v", err) glog.Errorf("Failed to load IAM configuration: %v", err)
} else { } else {
@ -412,7 +414,7 @@ func (s3a *S3ApiServer) registerRouter(router *mux.Router) {
} }
// loadIAMManagerFromConfig loads the advanced IAM manager from configuration file // loadIAMManagerFromConfig loads the advanced IAM manager from configuration file
func loadIAMManagerFromConfig(configPath string) (*integration.IAMManager, error) {
func loadIAMManagerFromConfig(configPath string, filerAddressProvider func() string) (*integration.IAMManager, error) {
// Read configuration file // Read configuration file
configData, err := os.ReadFile(configPath) configData, err := os.ReadFile(configPath)
if err != nil { if err != nil {
@ -446,7 +448,7 @@ func loadIAMManagerFromConfig(configPath string) (*integration.IAMManager, error
// Initialize IAM manager // Initialize IAM manager
iamManager := integration.NewIAMManager() iamManager := integration.NewIAMManager()
if err := iamManager.Initialize(iamConfig); err != nil {
if err := iamManager.Initialize(iamConfig, filerAddressProvider); err != nil {
return nil, fmt.Errorf("failed to initialize IAM manager: %w", err) return nil, fmt.Errorf("failed to initialize IAM manager: %w", err)
} }

Loading…
Cancel
Save