diff --git a/weed/iam/oidc/oidc_provider.go b/weed/iam/oidc/oidc_provider.go index 5f947fce5..f6a81eb90 100644 --- a/weed/iam/oidc/oidc_provider.go +++ b/weed/iam/oidc/oidc_provider.go @@ -163,12 +163,75 @@ func (p *OIDCProvider) GetUserInfo(ctx context.Context, userID string) (*provide return nil, fmt.Errorf("user ID cannot be empty") } - // TODO: Implement UserInfo endpoint call - // 1. Make HTTP request to UserInfo endpoint - // 2. Parse response and extract user claims - // 3. Map claims to ExternalIdentity structure + // For now, we'll use a token-based approach since OIDC UserInfo typically requires a token + // In a real implementation, this would need an access token from the authentication flow + return p.getUserInfoWithToken(ctx, userID, "") +} + +// GetUserInfoWithToken retrieves user information using an access token +func (p *OIDCProvider) GetUserInfoWithToken(ctx context.Context, accessToken string) (*providers.ExternalIdentity, error) { + if !p.initialized { + return nil, fmt.Errorf("provider not initialized") + } + + if accessToken == "" { + return nil, fmt.Errorf("access token cannot be empty") + } - return nil, fmt.Errorf("UserInfo endpoint integration not implemented yet") + return p.getUserInfoWithToken(ctx, "", accessToken) +} + +// getUserInfoWithToken is the internal implementation for UserInfo endpoint calls +func (p *OIDCProvider) getUserInfoWithToken(ctx context.Context, userID, accessToken string) (*providers.ExternalIdentity, error) { + // Determine UserInfo endpoint URL + userInfoUri := p.config.UserInfoUri + if userInfoUri == "" { + // Use standard OIDC discovery endpoint convention + userInfoUri = strings.TrimSuffix(p.config.Issuer, "/") + "/userinfo" + } + + // Create HTTP request + req, err := http.NewRequestWithContext(ctx, "GET", userInfoUri, nil) + if err != nil { + return nil, fmt.Errorf("failed to create UserInfo request: %v", err) + } + + // Set authorization header if access token is provided + if accessToken != "" { + req.Header.Set("Authorization", "Bearer "+accessToken) + } + req.Header.Set("Accept", "application/json") + + // Make HTTP request + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to call UserInfo endpoint: %v", err) + } + defer resp.Body.Close() + + // Check response status + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("UserInfo endpoint returned status %d", resp.StatusCode) + } + + // Parse JSON response + var userInfo map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil { + return nil, fmt.Errorf("failed to decode UserInfo response: %v", err) + } + + glog.V(4).Infof("Received UserInfo response: %+v", userInfo) + + // Map UserInfo claims to ExternalIdentity + identity := p.mapUserInfoToIdentity(userInfo) + + // If userID was provided but not found in claims, use it + if userID != "" && identity.UserID == "" { + identity.UserID = userID + } + + glog.V(3).Infof("Retrieved user info from OIDC provider: %s", identity.UserID) + return identity, nil } // ValidateToken validates an OIDC JWT token @@ -365,3 +428,82 @@ func (p *OIDCProvider) parseRSAKey(key *JWK) (*rsa.PublicKey, error) { return pubKey, nil } + +// mapUserInfoToIdentity maps UserInfo response to ExternalIdentity +func (p *OIDCProvider) mapUserInfoToIdentity(userInfo map[string]interface{}) *providers.ExternalIdentity { + identity := &providers.ExternalIdentity{ + Provider: p.name, + Attributes: make(map[string]string), + } + + // Map standard OIDC claims + if sub, ok := userInfo["sub"].(string); ok { + identity.UserID = sub + } + + if email, ok := userInfo["email"].(string); ok { + identity.Email = email + } + + if name, ok := userInfo["name"].(string); ok { + identity.DisplayName = name + } + + // Handle groups claim (can be array of strings or single string) + if groupsData, exists := userInfo["groups"]; exists { + switch groups := groupsData.(type) { + case []interface{}: + // Array of groups + for _, group := range groups { + if groupStr, ok := group.(string); ok { + identity.Groups = append(identity.Groups, groupStr) + } + } + case []string: + // Direct string array + identity.Groups = groups + case string: + // Single group as string + identity.Groups = []string{groups} + } + } + + // Map configured custom claims + if p.config.ClaimsMapping != nil { + for identityField, oidcClaim := range p.config.ClaimsMapping { + if value, exists := userInfo[oidcClaim]; exists { + if strValue, ok := value.(string); ok { + switch identityField { + case "email": + if identity.Email == "" { + identity.Email = strValue + } + case "displayName": + if identity.DisplayName == "" { + identity.DisplayName = strValue + } + case "userID": + if identity.UserID == "" { + identity.UserID = strValue + } + default: + identity.Attributes[identityField] = strValue + } + } + } + } + } + + // Store all additional claims as attributes + for key, value := range userInfo { + if key != "sub" && key != "email" && key != "name" && key != "groups" { + if strValue, ok := value.(string); ok { + identity.Attributes[key] = strValue + } else if jsonValue, err := json.Marshal(value); err == nil { + identity.Attributes[key] = string(jsonValue) + } + } + } + + return identity +} diff --git a/weed/iam/oidc/oidc_provider_test.go b/weed/iam/oidc/oidc_provider_test.go index adbfe9115..8f8de7bc0 100644 --- a/weed/iam/oidc/oidc_provider_test.go +++ b/weed/iam/oidc/oidc_provider_test.go @@ -236,11 +236,37 @@ func TestOIDCProviderUserInfo(t *testing.T) { // Set up test server with UserInfo endpoint server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/userinfo" { + // Check for Authorization header + authHeader := r.Header.Get("Authorization") + if !strings.HasPrefix(authHeader, "Bearer ") { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"error": "unauthorized"}`)) + return + } + + accessToken := strings.TrimPrefix(authHeader, "Bearer ") + + // Return 401 for explicitly invalid tokens + if accessToken == "invalid-token" { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"error": "invalid_token"}`)) + return + } + + // Mock user info response userInfo := map[string]interface{}{ - "sub": r.URL.Query().Get("user_id"), + "sub": "user123", "email": "user@example.com", "name": "Test User", + "groups": []string{"users", "developers"}, + } + + // Customize response based on token + if strings.Contains(accessToken, "admin") { + userInfo["groups"] = []string{"admins"} } + + w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(userInfo) } })) @@ -256,16 +282,64 @@ func TestOIDCProviderUserInfo(t *testing.T) { err := provider.Initialize(config) require.NoError(t, err) - t.Run("get user info", func(t *testing.T) { - identity, err := provider.GetUserInfo(context.Background(), "user123") - if err != nil && strings.Contains(err.Error(), "not implemented yet") { - t.Skip("UserInfo endpoint integration not yet implemented - skipping user info test") - return - } + t.Run("get user info with access token", func(t *testing.T) { + // Test using access token (real UserInfo endpoint call) + identity, err := provider.GetUserInfoWithToken(context.Background(), "valid-access-token") + require.NoError(t, err) + require.NotNil(t, identity) + assert.Equal(t, "user123", identity.UserID) + assert.Equal(t, "user@example.com", identity.Email) + assert.Equal(t, "Test User", identity.DisplayName) + assert.Contains(t, identity.Groups, "users") + assert.Contains(t, identity.Groups, "developers") + assert.Equal(t, "test-oidc", identity.Provider) + }) + t.Run("get admin user info", func(t *testing.T) { + // Test admin token response + identity, err := provider.GetUserInfoWithToken(context.Background(), "admin-access-token") require.NoError(t, err) require.NotNil(t, identity) assert.Equal(t, "user123", identity.UserID) + assert.Contains(t, identity.Groups, "admins") + }) + + t.Run("get user info without token", func(t *testing.T) { + // Test without access token (should fail) + _, err := provider.GetUserInfoWithToken(context.Background(), "") + assert.Error(t, err) + assert.Contains(t, err.Error(), "access token cannot be empty") + }) + + t.Run("get user info with invalid token", func(t *testing.T) { + // Test with invalid access token (should get 401) + _, err := provider.GetUserInfoWithToken(context.Background(), "invalid-token") + assert.Error(t, err) + assert.Contains(t, err.Error(), "UserInfo endpoint returned status 401") + }) + + t.Run("get user info with custom claims mapping", func(t *testing.T) { + // Create provider with custom claims mapping + customProvider := NewOIDCProvider("test-custom-oidc") + customConfig := &OIDCConfig{ + Issuer: server.URL, + ClientID: "test-client", + UserInfoUri: server.URL + "/userinfo", + ClaimsMapping: map[string]string{ + "customEmail": "email", + "customName": "name", + }, + } + + err := customProvider.Initialize(customConfig) + require.NoError(t, err) + + identity, err := customProvider.GetUserInfoWithToken(context.Background(), "valid-access-token") + require.NoError(t, err) + require.NotNil(t, identity) + + // Standard claims should still work + assert.Equal(t, "user123", identity.UserID) assert.Equal(t, "user@example.com", identity.Email) assert.Equal(t, "Test User", identity.DisplayName) }) @@ -316,12 +390,46 @@ func setupOIDCTestServer(t *testing.T, publicKey *rsa.PublicKey) *httptest.Serve switch r.URL.Path { case "/.well-known/openid_configuration": config := map[string]interface{}{ - "issuer": "http://" + r.Host, - "jwks_uri": "http://" + r.Host + "/jwks", + "issuer": "http://" + r.Host, + "jwks_uri": "http://" + r.Host + "/jwks", + "userinfo_endpoint": "http://" + r.Host + "/userinfo", } json.NewEncoder(w).Encode(config) case "/jwks": json.NewEncoder(w).Encode(jwks) + case "/userinfo": + // Mock UserInfo endpoint + authHeader := r.Header.Get("Authorization") + if !strings.HasPrefix(authHeader, "Bearer ") { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"error": "unauthorized"}`)) + return + } + + accessToken := strings.TrimPrefix(authHeader, "Bearer ") + + // Return 401 for explicitly invalid tokens + if accessToken == "invalid-token" { + w.WriteHeader(http.StatusUnauthorized) + w.Write([]byte(`{"error": "invalid_token"}`)) + return + } + + // Mock user info response based on access token + userInfo := map[string]interface{}{ + "sub": "user123", + "email": "user@example.com", + "name": "Test User", + "groups": []string{"users", "developers"}, + } + + // Customize response based on token + if strings.Contains(accessToken, "admin") { + userInfo["groups"] = []string{"admins"} + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(userInfo) default: http.NotFound(w, r) } diff --git a/weed/iam/policy/policy_store.go b/weed/iam/policy/policy_store.go index 67cceda9b..5e11af25d 100644 --- a/weed/iam/policy/policy_store.go +++ b/weed/iam/policy/policy_store.go @@ -174,7 +174,7 @@ func NewFilerPolicyStore(config map[string]interface{}) (*FilerPolicyStore, erro return nil, fmt.Errorf("filer address is required for FilerPolicyStore") } - glog.V(2).Infof("Initialized FilerPolicyStore with filer %s, basePath %s", + glog.V(2).Infof("Initialized FilerPolicyStore with filer %s, basePath %s", store.filerGrpcAddress, store.basePath) return store, nil