Browse Source

fix tests

pull/6987/head
chrislu 3 months ago
parent
commit
71aface936
  1. 75
      weed/s3api/cors/cors.go
  2. 7
      weed/s3api/cors/cors_test.go

75
weed/s3api/cors/cors.go

@ -181,10 +181,18 @@ func EvaluateRequest(config *CORSConfiguration, corsReq *CORSRequest) (*CORSResp
return nil, fmt.Errorf("origin header is required for CORS requests")
}
// Find the first matching rule
// Find the first rule that matches the origin
for _, rule := range config.CORSRules {
if matchesRule(&rule, corsReq) {
return buildResponse(&rule, corsReq), nil
if matchesOrigin(rule.AllowedOrigins, corsReq.Origin) {
// For preflight requests, we need more detailed validation
if corsReq.IsPreflightRequest {
return buildPreflightResponse(&rule, corsReq), nil
} else {
// For actual requests, check method
if contains(rule.AllowedMethods, corsReq.Method) {
return buildResponse(&rule, corsReq), nil
}
}
}
}
@ -328,6 +336,67 @@ func matchesHeader(allowedHeaders []string, header string) bool {
return false
}
// buildPreflightResponse builds a CORS response for preflight requests
// This function allows partial matches - origin can match while methods/headers may not
func buildPreflightResponse(rule *CORSRule, corsReq *CORSRequest) *CORSResponse {
response := &CORSResponse{
AllowOrigin: corsReq.Origin,
}
// Check if the requested method is allowed
methodAllowed := corsReq.AccessControlRequestMethod == "" || contains(rule.AllowedMethods, corsReq.AccessControlRequestMethod)
// Check requested headers
var allowedRequestHeaders []string
allHeadersAllowed := true
if len(corsReq.AccessControlRequestHeaders) > 0 {
// Check if wildcard is allowed
hasWildcard := false
for _, header := range rule.AllowedHeaders {
if header == "*" {
hasWildcard = true
break
}
}
if hasWildcard {
// All requested headers are allowed with wildcard
allowedRequestHeaders = corsReq.AccessControlRequestHeaders
} else {
// Check each requested header individually
for _, requestedHeader := range corsReq.AccessControlRequestHeaders {
if matchesHeader(rule.AllowedHeaders, requestedHeader) {
allowedRequestHeaders = append(allowedRequestHeaders, requestedHeader)
} else {
allHeadersAllowed = false
}
}
}
}
// Only set method and header info if both method and ALL headers are allowed
if methodAllowed && allHeadersAllowed {
response.AllowMethods = strings.Join(rule.AllowedMethods, ", ")
if len(allowedRequestHeaders) > 0 {
response.AllowHeaders = strings.Join(allowedRequestHeaders, ", ")
}
// Set exposed headers
if len(rule.ExposeHeaders) > 0 {
response.ExposeHeaders = strings.Join(rule.ExposeHeaders, ", ")
}
// Set max age
if rule.MaxAgeSeconds != nil {
response.MaxAge = strconv.Itoa(*rule.MaxAgeSeconds)
}
}
return response
}
// buildResponse builds a CORS response from a matching rule
func buildResponse(rule *CORSRule, corsReq *CORSRequest) *CORSResponse {
response := &CORSResponse{

7
weed/s3api/cors/cors_test.go

@ -422,8 +422,11 @@ func TestEvaluateRequest(t *testing.T) {
AccessControlRequestMethod: "POST",
AccessControlRequestHeaders: []string{"Authorization"},
},
want: nil,
wantErr: true,
want: &CORSResponse{
AllowOrigin: "http://example.com",
// No AllowMethods or AllowHeaders because the requested header is forbidden
},
wantErr: false,
},
{
name: "request without origin",

Loading…
Cancel
Save