Browse Source

fix: Forces refresh on kid-miss

pull/8311/head
Yannick Goetschel 4 weeks ago
parent
commit
ae3b61f486
  1. 33
      weed/iam/oidc/oidc_provider.go

33
weed/iam/oidc/oidc_provider.go

@ -560,7 +560,7 @@ func (p *OIDCProvider) mapClaimsToRolesWithConfig(claims *providers.TokenClaims)
func (p *OIDCProvider) getPublicKey(ctx context.Context, kid string) (interface{}, error) {
// Fast path: Try to find key in cache with read lock
p.jwksMutex.RLock()
cacheValid := p.jwksCache != nil && (p.jwksFetchedAt.IsZero() || time.Since(p.jwksFetchedAt) <= p.jwksTTL)
cacheValid := p.jwksCache != nil && !p.jwksFetchedAt.IsZero() && time.Since(p.jwksFetchedAt) <= p.jwksTTL
if cacheValid {
// Make a copy of the JWK to avoid use-after-unlock
jwkCopy := p.findAndCopyKey(kid)
@ -570,17 +570,38 @@ func (p *OIDCProvider) getPublicKey(ctx context.Context, kid string) (interface{
// Parse key outside lock for better concurrency
return p.parseJWK(jwkCopy)
}
// Key not found in valid cache - need refresh
// Key not found in valid cache - force refresh for potential key rotation
// Use distinct singleflight key to ensure refresh happens regardless of TTL
// This handles mid-TTL key rotation where IdP adds new keys
_, err, _ := p.jwksFetchGroup.Do("jwks-force-refresh", func() (interface{}, error) {
return nil, p.fetchJWKS(ctx)
})
if err != nil {
return nil, fmt.Errorf("failed to refresh JWKS for missing kid: %v", err)
}
// Search in refreshed JWKS
p.jwksMutex.RLock()
jwkCopy = p.findAndCopyKey(kid)
p.jwksMutex.RUnlock()
if jwkCopy != nil {
return p.parseJWK(jwkCopy)
}
return nil, fmt.Errorf("key with ID %s not found in JWKS after forced refresh", kid)
} else {
p.jwksMutex.RUnlock()
}
// Slow path: Need to fetch/refresh JWKS
// Slow path: Cache expired or nil - need to fetch/refresh JWKS
// Use singleflight to ensure only one fetch happens even with many concurrent requests
_, err, _ := p.jwksFetchGroup.Do("jwks", func() (interface{}, error) {
// Double-check: another goroutine may have just fetched
p.jwksMutex.RLock()
stillNeedFetch := p.jwksCache == nil || time.Since(p.jwksFetchedAt) > p.jwksTTL
stillNeedFetch := p.jwksCache == nil || p.jwksFetchedAt.IsZero() || time.Since(p.jwksFetchedAt) > p.jwksTTL
p.jwksMutex.RUnlock()
if stillNeedFetch {
@ -603,8 +624,8 @@ func (p *OIDCProvider) getPublicKey(ctx context.Context, kid string) (interface{
return p.parseJWK(jwkCopy)
}
// Key not found even after refresh - this could be key rotation
return nil, fmt.Errorf("key with ID %s not found in JWKS after refresh", kid)
// Key not found even after refresh
return nil, fmt.Errorf("key with ID %s not found in JWKS", kid)
}
// findAndCopyKey searches for a key and returns a copy (not a pointer)

Loading…
Cancel
Save