You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
385 lines
10 KiB
385 lines
10 KiB
package cors
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
)
|
|
|
|
// CORSRule is a single entry in a CORSConfiguration. Slice fields use
// singular XML element names (one element per value) and plural JSON keys,
// as given by the struct tags.
type CORSRule struct {
	// AllowedHeaders lists request header names a preflight may ask for.
	// Matching is case-insensitive; "*" allows any header and a trailing
	// "*" (e.g. "x-amz-*") allows any header with that prefix. When the
	// list is empty, matchesHeader places no restriction on headers.
	AllowedHeaders []string `xml:"AllowedHeader,omitempty" json:"AllowedHeaders,omitempty"`

	// AllowedMethods lists the HTTP methods this rule permits. validateRule
	// requires at least one, each being GET, PUT, POST, DELETE or HEAD.
	AllowedMethods []string `xml:"AllowedMethod" json:"AllowedMethods"`

	// AllowedOrigins lists the origins this rule applies to: "*", an exact
	// origin, or a pattern containing a single "*" wildcard.
	AllowedOrigins []string `xml:"AllowedOrigin" json:"AllowedOrigins"`

	// ExposeHeaders is joined with ", " into Access-Control-Expose-Headers.
	ExposeHeaders []string `xml:"ExposeHeader,omitempty" json:"ExposeHeaders,omitempty"`

	// MaxAgeSeconds, when non-nil, is sent as Access-Control-Max-Age;
	// validateRule rejects negative values.
	MaxAgeSeconds *int `xml:"MaxAgeSeconds,omitempty" json:"MaxAgeSeconds,omitempty"`

	// ID is an optional identifier for the rule.
	ID string `xml:"ID,omitempty" json:"ID,omitempty"`
}
|
|
|
|
// CORSConfiguration represents the CORS configuration for a bucket
|
|
type CORSConfiguration struct {
|
|
CORSRules []CORSRule `xml:"CORSRule" json:"CORSRules"`
|
|
}
|
|
|
|
// CORSRequest holds the CORS-relevant parts of an incoming HTTP request,
// typically produced by ParseRequest.
type CORSRequest struct {
	// Origin and Method are copied from the Origin header and the HTTP
	// method of the request.
	Origin, Method string

	// RequestHeaders is not populated by ParseRequest.
	// NOTE(review): confirm whether any caller fills it in.
	RequestHeaders []string

	// IsPreflightRequest is set for OPTIONS requests. ParseRequest fills
	// the two Access-Control-* fields below only in that case.
	IsPreflightRequest bool

	// AccessControlRequestMethod is the Access-Control-Request-Method value.
	AccessControlRequestMethod string

	// AccessControlRequestHeaders holds the comma-separated names from
	// Access-Control-Request-Headers, whitespace-trimmed.
	AccessControlRequestHeaders []string
}
|
|
|
|
// CORSResponse carries the already-formatted values that ApplyHeaders
// writes as Access-Control-* response headers. An empty string (or a false
// AllowCredentials) means the corresponding header is not written.
type CORSResponse struct {
	// AllowOrigin, AllowMethods, AllowHeaders, ExposeHeaders and MaxAge map
	// to Access-Control-Allow-Origin, -Allow-Methods, -Allow-Headers,
	// -Expose-Headers and -Max-Age respectively.
	AllowOrigin, AllowMethods, AllowHeaders, ExposeHeaders, MaxAge string

	// AllowCredentials, when true, emits Access-Control-Allow-Credentials.
	// NOTE(review): no response builder visible here ever sets it.
	AllowCredentials bool
}
|
|
|
|
// ValidateConfiguration validates a CORS configuration
|
|
func ValidateConfiguration(config *CORSConfiguration) error {
|
|
if config == nil {
|
|
return fmt.Errorf("CORS configuration cannot be nil")
|
|
}
|
|
|
|
if len(config.CORSRules) == 0 {
|
|
return fmt.Errorf("CORS configuration must have at least one rule")
|
|
}
|
|
|
|
if len(config.CORSRules) > 100 {
|
|
return fmt.Errorf("CORS configuration cannot have more than 100 rules")
|
|
}
|
|
|
|
for i, rule := range config.CORSRules {
|
|
if err := validateRule(&rule); err != nil {
|
|
return fmt.Errorf("invalid CORS rule at index %d: %v", i, err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// ParseRequest parses an HTTP request to extract CORS information
|
|
func ParseRequest(r *http.Request) *CORSRequest {
|
|
corsReq := &CORSRequest{
|
|
Origin: r.Header.Get("Origin"),
|
|
Method: r.Method,
|
|
}
|
|
|
|
// Check if this is a preflight request
|
|
if r.Method == "OPTIONS" {
|
|
corsReq.IsPreflightRequest = true
|
|
corsReq.AccessControlRequestMethod = r.Header.Get("Access-Control-Request-Method")
|
|
|
|
if headers := r.Header.Get("Access-Control-Request-Headers"); headers != "" {
|
|
corsReq.AccessControlRequestHeaders = strings.Split(headers, ",")
|
|
for i := range corsReq.AccessControlRequestHeaders {
|
|
corsReq.AccessControlRequestHeaders[i] = strings.TrimSpace(corsReq.AccessControlRequestHeaders[i])
|
|
}
|
|
}
|
|
}
|
|
|
|
return corsReq
|
|
}
|
|
|
|
// validateRule validates a single CORS rule
|
|
func validateRule(rule *CORSRule) error {
|
|
if len(rule.AllowedMethods) == 0 {
|
|
return fmt.Errorf("AllowedMethods cannot be empty")
|
|
}
|
|
|
|
if len(rule.AllowedOrigins) == 0 {
|
|
return fmt.Errorf("AllowedOrigins cannot be empty")
|
|
}
|
|
|
|
// Validate allowed methods
|
|
validMethods := map[string]bool{
|
|
"GET": true,
|
|
"PUT": true,
|
|
"POST": true,
|
|
"DELETE": true,
|
|
"HEAD": true,
|
|
}
|
|
|
|
for _, method := range rule.AllowedMethods {
|
|
if !validMethods[method] {
|
|
return fmt.Errorf("invalid HTTP method: %s", method)
|
|
}
|
|
}
|
|
|
|
// Validate origins
|
|
for _, origin := range rule.AllowedOrigins {
|
|
if origin == "*" {
|
|
continue
|
|
}
|
|
if err := validateOrigin(origin); err != nil {
|
|
return fmt.Errorf("invalid origin %s: %v", origin, err)
|
|
}
|
|
}
|
|
|
|
// Validate MaxAgeSeconds
|
|
if rule.MaxAgeSeconds != nil && *rule.MaxAgeSeconds < 0 {
|
|
return fmt.Errorf("MaxAgeSeconds cannot be negative")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// validateOrigin checks a single AllowedOrigin entry.
//
// The entry must be non-empty. The bare "*" is always accepted. Any other
// entry may contain at most one "*". An entry that contains a "*" must
// start with "http://" or "https://", but where the "*" sits after the
// scheme is not checked. Entries without a wildcard are not inspected
// further.
func validateOrigin(origin string) error {
	switch {
	case origin == "":
		return fmt.Errorf("origin cannot be empty")
	case origin == "*":
		return nil
	}

	switch strings.Count(origin, "*") {
	case 0:
		return nil
	case 1:
		hasHTTPScheme := strings.HasPrefix(origin, "http://") || strings.HasPrefix(origin, "https://")
		if !hasHTTPScheme {
			return fmt.Errorf("origin with wildcard must start with http:// or https://")
		}
		return nil
	default:
		return fmt.Errorf("origin can contain at most one wildcard")
	}
}
|
|
|
|
// EvaluateRequest evaluates a CORS request against a CORS configuration
|
|
func EvaluateRequest(config *CORSConfiguration, corsReq *CORSRequest) (*CORSResponse, error) {
|
|
if config == nil || corsReq == nil {
|
|
return nil, fmt.Errorf("config and corsReq cannot be nil")
|
|
}
|
|
|
|
if corsReq.Origin == "" {
|
|
return nil, fmt.Errorf("origin header is required for CORS requests")
|
|
}
|
|
|
|
// Find the first rule that matches the origin
|
|
for _, rule := range config.CORSRules {
|
|
if matchesOrigin(rule.AllowedOrigins, corsReq.Origin) {
|
|
// For preflight requests, we need more detailed validation
|
|
if corsReq.IsPreflightRequest {
|
|
return buildPreflightResponse(&rule, corsReq), nil
|
|
} else {
|
|
// For actual requests, check method
|
|
if containsString(rule.AllowedMethods, corsReq.Method) {
|
|
return buildResponse(&rule, corsReq), nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil, fmt.Errorf("no matching CORS rule found")
|
|
}
|
|
|
|
// buildPreflightResponse builds a CORS response for preflight requests
|
|
func buildPreflightResponse(rule *CORSRule, corsReq *CORSRequest) *CORSResponse {
|
|
response := &CORSResponse{
|
|
AllowOrigin: corsReq.Origin,
|
|
}
|
|
|
|
// Check if the requested method is allowed
|
|
methodAllowed := corsReq.AccessControlRequestMethod == "" || containsString(rule.AllowedMethods, corsReq.AccessControlRequestMethod)
|
|
|
|
// Check requested headers
|
|
var allowedRequestHeaders []string
|
|
allHeadersAllowed := true
|
|
|
|
if len(corsReq.AccessControlRequestHeaders) > 0 {
|
|
// Check if wildcard is allowed
|
|
hasWildcard := false
|
|
for _, header := range rule.AllowedHeaders {
|
|
if header == "*" {
|
|
hasWildcard = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if hasWildcard {
|
|
// All requested headers are allowed with wildcard
|
|
allowedRequestHeaders = corsReq.AccessControlRequestHeaders
|
|
} else {
|
|
// Check each requested header individually
|
|
for _, requestedHeader := range corsReq.AccessControlRequestHeaders {
|
|
if matchesHeader(rule.AllowedHeaders, requestedHeader) {
|
|
allowedRequestHeaders = append(allowedRequestHeaders, requestedHeader)
|
|
} else {
|
|
allHeadersAllowed = false
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Only set method and header info if both method and ALL headers are allowed
|
|
if methodAllowed && allHeadersAllowed {
|
|
response.AllowMethods = strings.Join(rule.AllowedMethods, ", ")
|
|
|
|
if len(allowedRequestHeaders) > 0 {
|
|
response.AllowHeaders = strings.Join(allowedRequestHeaders, ", ")
|
|
}
|
|
|
|
// Set exposed headers
|
|
if len(rule.ExposeHeaders) > 0 {
|
|
response.ExposeHeaders = strings.Join(rule.ExposeHeaders, ", ")
|
|
}
|
|
|
|
// Set max age
|
|
if rule.MaxAgeSeconds != nil {
|
|
response.MaxAge = strconv.Itoa(*rule.MaxAgeSeconds)
|
|
}
|
|
}
|
|
|
|
return response
|
|
}
|
|
|
|
// buildResponse builds a CORS response from a matching rule
|
|
func buildResponse(rule *CORSRule, corsReq *CORSRequest) *CORSResponse {
|
|
response := &CORSResponse{
|
|
AllowOrigin: corsReq.Origin,
|
|
}
|
|
|
|
// Set allowed methods
|
|
response.AllowMethods = strings.Join(rule.AllowedMethods, ", ")
|
|
|
|
// Set allowed headers
|
|
if len(rule.AllowedHeaders) > 0 {
|
|
response.AllowHeaders = strings.Join(rule.AllowedHeaders, ", ")
|
|
}
|
|
|
|
// Set expose headers
|
|
if len(rule.ExposeHeaders) > 0 {
|
|
response.ExposeHeaders = strings.Join(rule.ExposeHeaders, ", ")
|
|
}
|
|
|
|
// Set max age
|
|
if rule.MaxAgeSeconds != nil {
|
|
response.MaxAge = strconv.Itoa(*rule.MaxAgeSeconds)
|
|
}
|
|
|
|
return response
|
|
}
|
|
|
|
// Helper functions
|
|
|
|
// matchesOrigin checks if the request origin matches any allowed origin
|
|
func matchesOrigin(allowedOrigins []string, origin string) bool {
|
|
for _, allowedOrigin := range allowedOrigins {
|
|
if allowedOrigin == "*" {
|
|
return true
|
|
}
|
|
if allowedOrigin == origin {
|
|
return true
|
|
}
|
|
// Handle wildcard patterns like https://*.example.com
|
|
if strings.Contains(allowedOrigin, "*") {
|
|
if matchWildcard(allowedOrigin, origin) {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// matchWildcard reports whether text matches pattern, where the first "*" in
// pattern stands for any (possibly empty) run of characters. Examples:
// "https://*.example.com" and "https://app-*.example.com". Any later "*"
// must appear literally in text. A pattern with no "*" never matches here,
// because exact matches are handled by the caller.
//
// Patterns of the form "http://*..." / "https://*..." behave as before.
// Beyond that:
//   - the wildcard may now sit anywhere, which validateOrigin already
//     permits;
//   - prefix and suffix may no longer overlap inside text.
func matchWildcard(pattern, text string) bool {
	star := strings.Index(pattern, "*")
	if star < 0 {
		return false
	}
	prefix, suffix := pattern[:star], pattern[star+1:]

	// Without the length check, prefix and suffix could share characters,
	// e.g. "https://*s://a.com" would wrongly match "https://a.com".
	return len(text) >= len(prefix)+len(suffix) &&
		strings.HasPrefix(text, prefix) &&
		strings.HasSuffix(text, suffix)
}
|
|
|
|
// matchesHeader reports whether the request header name is permitted by
// allowedHeaders. Comparison is case-insensitive. Each entry may be:
//   - "*", which permits any header;
//   - an exact name;
//   - a name ending in "*" (e.g. "x-amz-*"), which permits any header
//     starting with the text before the "*".
//
// An empty allowedHeaders list permits every header.
func matchesHeader(allowedHeaders []string, header string) bool {
	if len(allowedHeaders) == 0 {
		return true
	}

	want := strings.ToLower(header)
	for _, entry := range allowedHeaders {
		pattern := strings.ToLower(entry)
		switch {
		case pattern == "*", pattern == want:
			return true
		case strings.HasSuffix(pattern, "*") && strings.HasPrefix(want, pattern[:len(pattern)-1]):
			return true
		}
	}
	return false
}
|
|
|
|
// containsString reports whether item occurs in slice, using exact
// (case-sensitive) string equality. A nil or empty slice contains nothing.
func containsString(slice []string, item string) bool {
	for i := range slice {
		if slice[i] == item {
			return true
		}
	}
	return false
}
|
|
|
|
// ApplyHeaders applies CORS headers to an HTTP response
|
|
func ApplyHeaders(w http.ResponseWriter, corsResp *CORSResponse) {
|
|
if corsResp == nil {
|
|
return
|
|
}
|
|
|
|
if corsResp.AllowOrigin != "" {
|
|
w.Header().Set("Access-Control-Allow-Origin", corsResp.AllowOrigin)
|
|
}
|
|
|
|
if corsResp.AllowMethods != "" {
|
|
w.Header().Set("Access-Control-Allow-Methods", corsResp.AllowMethods)
|
|
}
|
|
|
|
if corsResp.AllowHeaders != "" {
|
|
w.Header().Set("Access-Control-Allow-Headers", corsResp.AllowHeaders)
|
|
}
|
|
|
|
if corsResp.ExposeHeaders != "" {
|
|
w.Header().Set("Access-Control-Expose-Headers", corsResp.ExposeHeaders)
|
|
}
|
|
|
|
if corsResp.MaxAge != "" {
|
|
w.Header().Set("Access-Control-Max-Age", corsResp.MaxAge)
|
|
}
|
|
|
|
if corsResp.AllowCredentials {
|
|
w.Header().Set("Access-Control-Allow-Credentials", "true")
|
|
}
|
|
}
|