diff --git a/weed/s3api/cors/middleware.go b/weed/s3api/cors/middleware.go index 49cdaae5c..a0e5a362c 100644 --- a/weed/s3api/cors/middleware.go +++ b/weed/s3api/cors/middleware.go @@ -34,6 +34,45 @@ func NewMiddleware(storage *Storage, bucketChecker BucketChecker, corsConfigGett } } +// evaluateCORSRequest performs the common CORS request evaluation logic +func (m *Middleware) evaluateCORSRequest(w http.ResponseWriter, r *http.Request) (*CORSResponse, bool) { + // Parse CORS request + corsReq := ParseRequest(r) + if corsReq.Origin == "" { + // Not a CORS request + return nil, false + } + + // Extract bucket from request + bucket, _ := s3_constants.GetBucketAndObject(r) + if bucket == "" { + return nil, false + } + + // Check if bucket exists + if err := m.bucketChecker.CheckBucket(r, bucket); err != s3err.ErrNone { + s3err.WriteErrorResponse(w, r, err) + return nil, true // Return true to indicate response was written + } + + // Load CORS configuration from cache + config, errCode := m.corsConfigGetter.GetCORSConfiguration(bucket) + if errCode != s3err.ErrNone || config == nil { + s3err.WriteErrorResponse(w, r, s3err.ErrAccessDenied) + return nil, true // Return true to indicate response was written + } + + // Evaluate CORS request + corsResp, err := EvaluateRequest(config, corsReq) + if err != nil { + glog.V(3).Infof("CORS evaluation failed for bucket %s: %v", bucket, err) + s3err.WriteErrorResponse(w, r, s3err.ErrAccessDenied) + return nil, true // Return true to indicate response was written + } + + return corsResp, false +} + // Handler returns the CORS middleware handler func (m *Middleware) Handler(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -105,45 +144,19 @@ func (m *Middleware) Handler(next http.Handler) http.Handler { // HandleOptionsRequest handles OPTIONS requests for CORS preflight func (m *Middleware) HandleOptionsRequest(w http.ResponseWriter, r *http.Request) { - // This is handled by the CORS middleware, but we need a specific OPTIONS handler - // for the router to recognize OPTIONS requests - - // Parse CORS request - corsReq := ParseRequest(r) - if corsReq.Origin == "" { - // Not a CORS request - w.WriteHeader(http.StatusOK) + // Use the common evaluation logic + corsResp, responseWritten := m.evaluateCORSRequest(w, r) + if responseWritten { + // Response was already written (error case) return } - // Extract bucket from request - bucket, _ := s3_constants.GetBucketAndObject(r) - if bucket == "" { + if corsResp == nil { + // Not a CORS request w.WriteHeader(http.StatusOK) return } - // Check if bucket exists - if err := m.bucketChecker.CheckBucket(r, bucket); err != s3err.ErrNone { - s3err.WriteErrorResponse(w, r, err) - return - } - - // Load CORS configuration from cache - config, errCode := m.corsConfigGetter.GetCORSConfiguration(bucket) - if errCode != s3err.ErrNone || config == nil { - s3err.WriteErrorResponse(w, r, s3err.ErrAccessDenied) - return - } - - // Evaluate CORS request - corsResp, err := EvaluateRequest(config, corsReq) - if err != nil { - glog.V(3).Infof("CORS evaluation failed for bucket %s: %v", bucket, err) - s3err.WriteErrorResponse(w, r, s3err.ErrAccessDenied) - return - } - // Apply CORS headers and return success ApplyHeaders(w, corsResp) w.WriteHeader(http.StatusOK) diff --git a/weed/s3api/s3api_server.go b/weed/s3api/s3api_server.go index f83775480..9e9a72c9f 100644 --- a/weed/s3api/s3api_server.go +++ b/weed/s3api/s3api_server.go @@ -121,6 +121,34 @@ func NewS3ApiServerWithStore(router *mux.Router, option *S3ApiServerOption, expl return s3ApiServer, nil } +// handleCORSOriginValidation handles the common CORS origin validation logic +func (s3a *S3ApiServer) handleCORSOriginValidation(w http.ResponseWriter, r *http.Request) bool { + origin := r.Header.Get("Origin") + if origin != "" { + if len(s3a.option.AllowedOrigins) == 0 || s3a.option.AllowedOrigins[0] == "*" { + origin = "*" + } else { + originFound := false + for _, allowedOrigin := range s3a.option.AllowedOrigins { + if origin == allowedOrigin { + originFound = true + break + } + } + if !originFound { + writeFailureResponse(w, r, http.StatusForbidden) + return false + } + } + } + + w.Header().Set("Access-Control-Allow-Origin", origin) + w.Header().Set("Access-Control-Expose-Headers", "*") + w.Header().Set("Access-Control-Allow-Methods", "*") + w.Header().Set("Access-Control-Allow-Headers", "*") + return true +} + func (s3a *S3ApiServer) registerRouter(router *mux.Router) { // API Router apiRouter := router.PathPrefix("/").Subrouter() @@ -326,29 +354,9 @@ func (s3a *S3ApiServer) registerRouter(router *mux.Router) { return } - origin := r.Header.Get("Origin") - if origin != "" { - if len(s3a.option.AllowedOrigins) == 0 || s3a.option.AllowedOrigins[0] == "*" { - origin = "*" - } else { - originFound := false - for _, allowedOrigin := range s3a.option.AllowedOrigins { - if origin == allowedOrigin { - originFound = true - } - } - if !originFound { - writeFailureResponse(w, r, http.StatusForbidden) - return - } - } + if s3a.handleCORSOriginValidation(w, r) { + writeSuccessResponseEmpty(w, r) } - - w.Header().Set("Access-Control-Allow-Origin", origin) - w.Header().Set("Access-Control-Expose-Headers", "*") - w.Header().Set("Access-Control-Allow-Methods", "*") - w.Header().Set("Access-Control-Allow-Headers", "*") - writeSuccessResponseEmpty(w, r) }) // ListBuckets