|
|
|
@ -326,3 +326,80 @@ func TestMiddlewareFallbackConfigWithMultipleOrigins(t *testing.T) { |
|
|
|
}) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
// TestMiddlewareFallbackWithError tests that real errors (not "no config") don't trigger fallback
|
|
|
|
func TestMiddlewareFallbackWithError(t *testing.T) { |
|
|
|
fallbackConfig := &CORSConfiguration{ |
|
|
|
CORSRules: []CORSRule{ |
|
|
|
{ |
|
|
|
AllowedOrigins: []string{"*"}, |
|
|
|
AllowedMethods: []string{"GET", "POST"}, |
|
|
|
AllowedHeaders: []string{"*"}, |
|
|
|
}, |
|
|
|
}, |
|
|
|
} |
|
|
|
|
|
|
|
tests := []struct { |
|
|
|
name string |
|
|
|
errCode s3err.ErrorCode |
|
|
|
expectedOriginHeader string |
|
|
|
description string |
|
|
|
}{ |
|
|
|
{ |
|
|
|
name: "ErrAccessDenied should not trigger fallback", |
|
|
|
errCode: s3err.ErrAccessDenied, |
|
|
|
expectedOriginHeader: "", |
|
|
|
description: "Access denied errors should not expose CORS headers", |
|
|
|
}, |
|
|
|
{ |
|
|
|
name: "ErrInternalError should not trigger fallback", |
|
|
|
errCode: s3err.ErrInternalError, |
|
|
|
expectedOriginHeader: "", |
|
|
|
description: "Internal errors should not expose CORS headers", |
|
|
|
}, |
|
|
|
{ |
|
|
|
name: "ErrNoSuchBucket should not trigger fallback", |
|
|
|
errCode: s3err.ErrNoSuchBucket, |
|
|
|
expectedOriginHeader: "", |
|
|
|
description: "Bucket not found errors should not expose CORS headers", |
|
|
|
}, |
|
|
|
{ |
|
|
|
name: "ErrNoSuchCORSConfiguration should trigger fallback", |
|
|
|
errCode: s3err.ErrNoSuchCORSConfiguration, |
|
|
|
expectedOriginHeader: "https://example.com", |
|
|
|
description: "Explicit no CORS config should use fallback", |
|
|
|
}, |
|
|
|
} |
|
|
|
|
|
|
|
for _, tt := range tests { |
|
|
|
t.Run(tt.name, func(t *testing.T) { |
|
|
|
bucketChecker := &mockBucketChecker{bucketExists: true} |
|
|
|
configGetter := &mockCORSConfigGetter{ |
|
|
|
config: nil, |
|
|
|
errCode: tt.errCode, |
|
|
|
} |
|
|
|
|
|
|
|
middleware := NewMiddleware(bucketChecker, configGetter, fallbackConfig) |
|
|
|
|
|
|
|
req := httptest.NewRequest("GET", "/testbucket/testobject", nil) |
|
|
|
req = mux.SetURLVars(req, map[string]string{ |
|
|
|
"bucket": "testbucket", |
|
|
|
"object": "testobject", |
|
|
|
}) |
|
|
|
req.Header.Set("Origin", "https://example.com") |
|
|
|
|
|
|
|
w := httptest.NewRecorder() |
|
|
|
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|
|
|
w.WriteHeader(http.StatusOK) |
|
|
|
}) |
|
|
|
|
|
|
|
middleware.Handler(nextHandler).ServeHTTP(w, req) |
|
|
|
|
|
|
|
actualOrigin := w.Header().Get("Access-Control-Allow-Origin") |
|
|
|
if actualOrigin != tt.expectedOriginHeader { |
|
|
|
t.Errorf("%s: expected Access-Control-Allow-Origin='%s', got '%s'", |
|
|
|
tt.description, tt.expectedOriginHeader, actualOrigin) |
|
|
|
} |
|
|
|
}) |
|
|
|
} |
|
|
|
} |