diff --git a/weed/iam/oidc/oidc_provider.go b/weed/iam/oidc/oidc_provider.go index 367cd9c0f..70e69e3fd 100644 --- a/weed/iam/oidc/oidc_provider.go +++ b/weed/iam/oidc/oidc_provider.go @@ -560,7 +560,7 @@ func (p *OIDCProvider) mapClaimsToRolesWithConfig(claims *providers.TokenClaims) func (p *OIDCProvider) getPublicKey(ctx context.Context, kid string) (interface{}, error) { // Fast path: Try to find key in cache with read lock p.jwksMutex.RLock() - cacheValid := p.jwksCache != nil && (p.jwksFetchedAt.IsZero() || time.Since(p.jwksFetchedAt) <= p.jwksTTL) + cacheValid := p.jwksCache != nil && !p.jwksFetchedAt.IsZero() && time.Since(p.jwksFetchedAt) <= p.jwksTTL if cacheValid { // Make a copy of the JWK to avoid use-after-unlock jwkCopy := p.findAndCopyKey(kid) @@ -570,17 +570,38 @@ func (p *OIDCProvider) getPublicKey(ctx context.Context, kid string) (interface{ // Parse key outside lock for better concurrency return p.parseJWK(jwkCopy) } - // Key not found in valid cache - need refresh + + // Key not found in valid cache - force refresh for potential key rotation + // Use distinct singleflight key to ensure refresh happens regardless of TTL + // This handles mid-TTL key rotation where IdP adds new keys + _, err, _ := p.jwksFetchGroup.Do("jwks-force-refresh", func() (interface{}, error) { + return nil, p.fetchJWKS(ctx) + }) + + if err != nil { + return nil, fmt.Errorf("failed to refresh JWKS for missing kid: %v", err) + } + + // Search in refreshed JWKS + p.jwksMutex.RLock() + jwkCopy = p.findAndCopyKey(kid) + p.jwksMutex.RUnlock() + + if jwkCopy != nil { + return p.parseJWK(jwkCopy) + } + + return nil, fmt.Errorf("key with ID %s not found in JWKS after forced refresh", kid) } else { p.jwksMutex.RUnlock() } - // Slow path: Need to fetch/refresh JWKS + // Slow path: Cache expired or nil - need to fetch/refresh JWKS // Use singleflight to ensure only one fetch happens even with many concurrent requests _, err, _ := p.jwksFetchGroup.Do("jwks", func() (interface{}, error) { // Double-check: another goroutine may have just fetched p.jwksMutex.RLock() - stillNeedFetch := p.jwksCache == nil || time.Since(p.jwksFetchedAt) > p.jwksTTL + stillNeedFetch := p.jwksCache == nil || p.jwksFetchedAt.IsZero() || time.Since(p.jwksFetchedAt) > p.jwksTTL p.jwksMutex.RUnlock() if stillNeedFetch { @@ -603,8 +624,8 @@ func (p *OIDCProvider) getPublicKey(ctx context.Context, kid string) (interface{ return p.parseJWK(jwkCopy) } - // Key not found even after refresh - this could be key rotation - return nil, fmt.Errorf("key with ID %s not found in JWKS after refresh", kid) + // Key not found even after refresh + return nil, fmt.Errorf("key with ID %s not found in JWKS", kid) } // findAndCopyKey searches for a key and returns a copy (not a pointer)