|
|
|
@ -560,7 +560,7 @@ func (p *OIDCProvider) mapClaimsToRolesWithConfig(claims *providers.TokenClaims) |
|
|
|
func (p *OIDCProvider) getPublicKey(ctx context.Context, kid string) (interface{}, error) { |
|
|
|
// Fast path: Try to find key in cache with read lock
|
|
|
|
p.jwksMutex.RLock() |
|
|
|
cacheValid := p.jwksCache != nil && (p.jwksFetchedAt.IsZero() || time.Since(p.jwksFetchedAt) <= p.jwksTTL) |
|
|
|
cacheValid := p.jwksCache != nil && !p.jwksFetchedAt.IsZero() && time.Since(p.jwksFetchedAt) <= p.jwksTTL |
|
|
|
if cacheValid { |
|
|
|
// Make a copy of the JWK to avoid use-after-unlock
|
|
|
|
jwkCopy := p.findAndCopyKey(kid) |
|
|
|
@ -570,17 +570,38 @@ func (p *OIDCProvider) getPublicKey(ctx context.Context, kid string) (interface{ |
|
|
|
// Parse key outside lock for better concurrency
|
|
|
|
return p.parseJWK(jwkCopy) |
|
|
|
} |
|
|
|
// Key not found in valid cache - need refresh
|
|
|
|
|
|
|
|
// Key not found in valid cache - force refresh for potential key rotation
|
|
|
|
// Use distinct singleflight key to ensure refresh happens regardless of TTL
|
|
|
|
// This handles mid-TTL key rotation where IdP adds new keys
|
|
|
|
_, err, _ := p.jwksFetchGroup.Do("jwks-force-refresh", func() (interface{}, error) { |
|
|
|
return nil, p.fetchJWKS(ctx) |
|
|
|
}) |
|
|
|
|
|
|
|
if err != nil { |
|
|
|
return nil, fmt.Errorf("failed to refresh JWKS for missing kid: %v", err) |
|
|
|
} |
|
|
|
|
|
|
|
// Search in refreshed JWKS
|
|
|
|
p.jwksMutex.RLock() |
|
|
|
jwkCopy = p.findAndCopyKey(kid) |
|
|
|
p.jwksMutex.RUnlock() |
|
|
|
|
|
|
|
if jwkCopy != nil { |
|
|
|
return p.parseJWK(jwkCopy) |
|
|
|
} |
|
|
|
|
|
|
|
return nil, fmt.Errorf("key with ID %s not found in JWKS after forced refresh", kid) |
|
|
|
} else { |
|
|
|
p.jwksMutex.RUnlock() |
|
|
|
} |
|
|
|
|
|
|
|
// Slow path: Need to fetch/refresh JWKS
|
|
|
|
// Slow path: Cache expired or nil - need to fetch/refresh JWKS
|
|
|
|
// Use singleflight to ensure only one fetch happens even with many concurrent requests
|
|
|
|
_, err, _ := p.jwksFetchGroup.Do("jwks", func() (interface{}, error) { |
|
|
|
// Double-check: another goroutine may have just fetched
|
|
|
|
p.jwksMutex.RLock() |
|
|
|
stillNeedFetch := p.jwksCache == nil || time.Since(p.jwksFetchedAt) > p.jwksTTL |
|
|
|
stillNeedFetch := p.jwksCache == nil || p.jwksFetchedAt.IsZero() || time.Since(p.jwksFetchedAt) > p.jwksTTL |
|
|
|
p.jwksMutex.RUnlock() |
|
|
|
|
|
|
|
if stillNeedFetch { |
|
|
|
@ -603,8 +624,8 @@ func (p *OIDCProvider) getPublicKey(ctx context.Context, kid string) (interface{ |
|
|
|
return p.parseJWK(jwkCopy) |
|
|
|
} |
|
|
|
|
|
|
|
// Key not found even after refresh - this could be key rotation
|
|
|
|
return nil, fmt.Errorf("key with ID %s not found in JWKS after refresh", kid) |
|
|
|
// Key not found even after refresh
|
|
|
|
return nil, fmt.Errorf("key with ID %s not found in JWKS", kid) |
|
|
|
} |
|
|
|
|
|
|
|
// findAndCopyKey searches for a key and returns a copy (not a pointer)
|
|
|
|
|