Browse Source

add fallback for cors

pull/7404/head
chrislu 1 month ago
parent
commit
09e7ca5097
  1. 40
      weed/s3api/cors/middleware.go
  2. 328
      weed/s3api/cors/middleware_test.go
  3. 28
      weed/s3api/s3api_bucket_cors_handlers.go

40
weed/s3api/cors/middleware.go

@ -22,13 +22,15 @@ type CORSConfigGetter interface {
type Middleware struct {
bucketChecker BucketChecker
corsConfigGetter CORSConfigGetter
fallbackConfig *CORSConfiguration // Global CORS configuration as fallback
}
// NewMiddleware creates a new CORS middleware instance
func NewMiddleware(bucketChecker BucketChecker, corsConfigGetter CORSConfigGetter) *Middleware {
// NewMiddleware creates a new CORS middleware instance with optional global fallback config
func NewMiddleware(bucketChecker BucketChecker, corsConfigGetter CORSConfigGetter, fallbackConfig *CORSConfiguration) *Middleware {
return &Middleware{
bucketChecker: bucketChecker,
corsConfigGetter: corsConfigGetter,
fallbackConfig: fallbackConfig,
}
}
@ -61,15 +63,20 @@ func (m *Middleware) Handler(next http.Handler) http.Handler {
// Load CORS configuration from cache
config, errCode := m.corsConfigGetter.GetCORSConfiguration(bucket)
if errCode != s3err.ErrNone || config == nil {
// No CORS configuration, handle based on request type
if corsReq.IsPreflightRequest {
// Preflight request without CORS config should fail
s3err.WriteErrorResponse(w, r, s3err.ErrAccessDenied)
// No bucket-level CORS configuration, try fallback (global config)
if m.fallbackConfig != nil {
config = m.fallbackConfig
} else {
// No CORS configuration at all, handle based on request type
if corsReq.IsPreflightRequest {
// Preflight request without CORS config should fail
s3err.WriteErrorResponse(w, r, s3err.ErrAccessDenied)
return
}
// Non-preflight request, continue normally
next.ServeHTTP(w, r)
return
}
// Non-preflight request, continue normally
next.ServeHTTP(w, r)
return
}
// Evaluate CORS request
@ -129,13 +136,18 @@ func (m *Middleware) HandleOptionsRequest(w http.ResponseWriter, r *http.Request
// Load CORS configuration from cache
config, errCode := m.corsConfigGetter.GetCORSConfiguration(bucket)
if errCode != s3err.ErrNone || config == nil {
// No CORS configuration for OPTIONS request should return access denied
if corsReq.IsPreflightRequest {
s3err.WriteErrorResponse(w, r, s3err.ErrAccessDenied)
// No bucket-level CORS configuration, try fallback (global config)
if m.fallbackConfig != nil {
config = m.fallbackConfig
} else {
// No CORS configuration at all for OPTIONS request should return access denied
if corsReq.IsPreflightRequest {
s3err.WriteErrorResponse(w, r, s3err.ErrAccessDenied)
return
}
w.WriteHeader(http.StatusOK)
return
}
w.WriteHeader(http.StatusOK)
return
}
// Evaluate CORS request

328
weed/s3api/cors/middleware_test.go

@ -0,0 +1,328 @@
package cors
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gorilla/mux"
"github.com/seaweedfs/seaweedfs/weed/s3api/s3err"
)
// Mock implementations for testing
type mockBucketChecker struct {
bucketExists bool
}
func (m *mockBucketChecker) CheckBucket(r *http.Request, bucket string) s3err.ErrorCode {
if m.bucketExists {
return s3err.ErrNone
}
return s3err.ErrNoSuchBucket
}
type mockCORSConfigGetter struct {
config *CORSConfiguration
errCode s3err.ErrorCode
}
func (m *mockCORSConfigGetter) GetCORSConfiguration(bucket string) (*CORSConfiguration, s3err.ErrorCode) {
return m.config, m.errCode
}
// TestMiddlewareFallbackConfig tests that the middleware uses fallback config when bucket-level config is not available
func TestMiddlewareFallbackConfig(t *testing.T) {
tests := []struct {
name string
bucketConfig *CORSConfiguration
fallbackConfig *CORSConfiguration
requestOrigin string
requestMethod string
isOptions bool
expectedStatus int
expectedOriginHeader string
description string
}{
{
name: "No bucket config, fallback to global config with wildcard",
bucketConfig: nil,
fallbackConfig: &CORSConfiguration{
CORSRules: []CORSRule{
{
AllowedOrigins: []string{"*"},
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "HEAD"},
AllowedHeaders: []string{"*"},
},
},
},
requestOrigin: "https://example.com",
requestMethod: "GET",
isOptions: false,
expectedStatus: http.StatusOK,
expectedOriginHeader: "https://example.com",
description: "Should use fallback global config when no bucket config exists",
},
{
name: "No bucket config, fallback to global config with specific origin",
bucketConfig: nil,
fallbackConfig: &CORSConfiguration{
CORSRules: []CORSRule{
{
AllowedOrigins: []string{"https://example.com"},
AllowedMethods: []string{"GET", "POST"},
AllowedHeaders: []string{"*"},
},
},
},
requestOrigin: "https://example.com",
requestMethod: "GET",
isOptions: false,
expectedStatus: http.StatusOK,
expectedOriginHeader: "https://example.com",
description: "Should use fallback config with specific origin match",
},
{
name: "No bucket config, fallback rejects non-matching origin",
bucketConfig: nil,
fallbackConfig: &CORSConfiguration{
CORSRules: []CORSRule{
{
AllowedOrigins: []string{"https://allowed.com"},
AllowedMethods: []string{"GET"},
AllowedHeaders: []string{"*"},
},
},
},
requestOrigin: "https://notallowed.com",
requestMethod: "GET",
isOptions: false,
expectedStatus: http.StatusOK,
expectedOriginHeader: "",
description: "Should not apply CORS headers when origin doesn't match fallback config",
},
{
name: "Bucket config takes precedence over fallback",
bucketConfig: &CORSConfiguration{
CORSRules: []CORSRule{
{
AllowedOrigins: []string{"https://bucket-specific.com"},
AllowedMethods: []string{"GET"},
AllowedHeaders: []string{"*"},
},
},
},
fallbackConfig: &CORSConfiguration{
CORSRules: []CORSRule{
{
AllowedOrigins: []string{"*"},
AllowedMethods: []string{"GET", "POST"},
AllowedHeaders: []string{"*"},
},
},
},
requestOrigin: "https://bucket-specific.com",
requestMethod: "GET",
isOptions: false,
expectedStatus: http.StatusOK,
expectedOriginHeader: "https://bucket-specific.com",
description: "Bucket-level config should be used instead of fallback",
},
{
name: "Bucket config rejects, even though fallback would allow",
bucketConfig: &CORSConfiguration{
CORSRules: []CORSRule{
{
AllowedOrigins: []string{"https://restricted.com"},
AllowedMethods: []string{"GET"},
AllowedHeaders: []string{"*"},
},
},
},
fallbackConfig: &CORSConfiguration{
CORSRules: []CORSRule{
{
AllowedOrigins: []string{"*"},
AllowedMethods: []string{"GET", "POST"},
AllowedHeaders: []string{"*"},
},
},
},
requestOrigin: "https://example.com",
requestMethod: "GET",
isOptions: false,
expectedStatus: http.StatusOK,
expectedOriginHeader: "",
description: "Bucket-level config is authoritative, fallback should not apply",
},
{
name: "No config at all, no CORS headers",
bucketConfig: nil,
fallbackConfig: nil,
requestOrigin: "https://example.com",
requestMethod: "GET",
isOptions: false,
expectedStatus: http.StatusOK,
expectedOriginHeader: "",
description: "Without any config, no CORS headers should be applied",
},
{
name: "OPTIONS preflight with fallback config",
bucketConfig: nil,
fallbackConfig: &CORSConfiguration{
CORSRules: []CORSRule{
{
AllowedOrigins: []string{"https://example.com"},
AllowedMethods: []string{"GET", "POST"},
AllowedHeaders: []string{"*"},
},
},
},
requestOrigin: "https://example.com",
requestMethod: "OPTIONS",
isOptions: true,
expectedStatus: http.StatusOK,
expectedOriginHeader: "https://example.com",
description: "OPTIONS preflight should work with fallback config",
},
{
name: "OPTIONS preflight without any config should fail",
bucketConfig: nil,
fallbackConfig: nil,
requestOrigin: "https://example.com",
requestMethod: "OPTIONS",
isOptions: true,
expectedStatus: http.StatusForbidden,
expectedOriginHeader: "",
description: "OPTIONS preflight without config should return 403",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Setup mocks
bucketChecker := &mockBucketChecker{bucketExists: true}
configGetter := &mockCORSConfigGetter{
config: tt.bucketConfig,
errCode: s3err.ErrNone,
}
// Create middleware with optional fallback
middleware := NewMiddleware(bucketChecker, configGetter, tt.fallbackConfig)
// Create request with mux variables
req := httptest.NewRequest(tt.requestMethod, "/testbucket/testobject", nil)
req = mux.SetURLVars(req, map[string]string{
"bucket": "testbucket",
"object": "testobject",
})
if tt.requestOrigin != "" {
req.Header.Set("Origin", tt.requestOrigin)
}
if tt.isOptions {
req.Header.Set("Access-Control-Request-Method", "GET")
}
// Create response recorder
w := httptest.NewRecorder()
// Create a simple handler that returns 200 OK
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
// Execute middleware
if tt.isOptions {
middleware.HandleOptionsRequest(w, req)
} else {
middleware.Handler(nextHandler).ServeHTTP(w, req)
}
// Check status code
if w.Code != tt.expectedStatus {
t.Errorf("%s: expected status %d, got %d", tt.description, tt.expectedStatus, w.Code)
}
// Check CORS header
actualOrigin := w.Header().Get("Access-Control-Allow-Origin")
if actualOrigin != tt.expectedOriginHeader {
t.Errorf("%s: expected Access-Control-Allow-Origin='%s', got '%s'",
tt.description, tt.expectedOriginHeader, actualOrigin)
}
})
}
}
// TestMiddlewareFallbackConfigWithMultipleOrigins tests fallback with multiple allowed origins
func TestMiddlewareFallbackConfigWithMultipleOrigins(t *testing.T) {
fallbackConfig := &CORSConfiguration{
CORSRules: []CORSRule{
{
AllowedOrigins: []string{"https://example1.com", "https://example2.com"},
AllowedMethods: []string{"GET", "POST"},
AllowedHeaders: []string{"*"},
},
},
}
bucketChecker := &mockBucketChecker{bucketExists: true}
configGetter := &mockCORSConfigGetter{
config: nil, // No bucket config
errCode: s3err.ErrNone,
}
middleware := NewMiddleware(bucketChecker, configGetter, fallbackConfig)
tests := []struct {
origin string
shouldMatch bool
description string
}{
{
origin: "https://example1.com",
shouldMatch: true,
description: "First allowed origin should match",
},
{
origin: "https://example2.com",
shouldMatch: true,
description: "Second allowed origin should match",
},
{
origin: "https://example3.com",
shouldMatch: false,
description: "Non-allowed origin should not match",
},
}
for _, tt := range tests {
t.Run(tt.description, func(t *testing.T) {
req := httptest.NewRequest("GET", "/testbucket/testobject", nil)
req = mux.SetURLVars(req, map[string]string{
"bucket": "testbucket",
"object": "testobject",
})
req.Header.Set("Origin", tt.origin)
w := httptest.NewRecorder()
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
middleware.Handler(nextHandler).ServeHTTP(w, req)
actualOrigin := w.Header().Get("Access-Control-Allow-Origin")
if tt.shouldMatch {
if actualOrigin != tt.origin {
t.Errorf("%s: expected Access-Control-Allow-Origin='%s', got '%s'",
tt.description, tt.origin, actualOrigin)
}
} else {
if actualOrigin != "" {
t.Errorf("%s: expected no Access-Control-Allow-Origin header, got '%s'",
tt.description, actualOrigin)
}
}
})
}
}

28
weed/s3api/s3api_bucket_cors_handlers.go

@ -28,12 +28,36 @@ func (g *S3CORSConfigGetter) GetCORSConfiguration(bucket string) (*cors.CORSConf
return g.server.getCORSConfiguration(bucket)
}
// getCORSMiddleware returns a CORS middleware instance with caching
// getCORSMiddleware returns a CORS middleware instance with global fallback config
func (s3a *S3ApiServer) getCORSMiddleware() *cors.Middleware {
bucketChecker := &S3BucketChecker{server: s3a}
corsConfigGetter := &S3CORSConfigGetter{server: s3a}
return cors.NewMiddleware(bucketChecker, corsConfigGetter)
// Create fallback CORS configuration from global AllowedOrigins setting
fallbackConfig := s3a.createFallbackCORSConfig()
return cors.NewMiddleware(bucketChecker, corsConfigGetter, fallbackConfig)
}
// createFallbackCORSConfig creates a CORS configuration from global AllowedOrigins
func (s3a *S3ApiServer) createFallbackCORSConfig() *cors.CORSConfiguration {
if len(s3a.option.AllowedOrigins) == 0 {
return nil
}
// Create a permissive CORS rule based on global allowed origins
// This matches the behavior of handleCORSOriginValidation
rule := cors.CORSRule{
AllowedOrigins: s3a.option.AllowedOrigins,
AllowedMethods: []string{"GET", "PUT", "POST", "DELETE", "HEAD"},
AllowedHeaders: []string{"*"},
ExposeHeaders: []string{"ETag", "Content-Length", "Content-Type"},
MaxAgeSeconds: nil, // No max age by default
}
return &cors.CORSConfiguration{
CORSRules: []cors.CORSRule{rule},
}
}
// GetBucketCorsHandler handles Get bucket CORS configuration

Loading…
Cancel
Save