You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							385 lines
						
					
					
						
							10 KiB
						
					
					
				
			
		
		
		
			
			
			
		
		
	
	
							385 lines
						
					
					
						
							10 KiB
						
					
					
				
								package cors
							 | 
						|
								
							 | 
						|
								import (
							 | 
						|
									"fmt"
							 | 
						|
									"net/http"
							 | 
						|
									"strconv"
							 | 
						|
									"strings"
							 | 
						|
								)
							 | 
						|
								
							 | 
						|
								// CORSRule is a single rule of a bucket CORS configuration. It is
								// (de)serialized through the xml/json struct tags below, checked by
								// validateRule, and matched against requests by EvaluateRequest.
							 | 
						|
								type CORSRule struct {
							 | 
						|
									// AllowedHeaders lists header names a preflight may request. Entries
									// are compared case-insensitively and may be "*" or end in "*" for a
									// prefix match (e.g. "x-amz-*"). When the list is empty,
									// matchesHeader accepts any requested header.
									AllowedHeaders []string `xml:"AllowedHeader,omitempty" json:"AllowedHeaders,omitempty"`
							 | 
						|
									// AllowedMethods lists the permitted HTTP methods. validateRule
									// requires at least one and only accepts GET, PUT, POST, DELETE and
									// HEAD (case-sensitive).
									AllowedMethods []string `xml:"AllowedMethod" json:"AllowedMethods"`
							 | 
						|
									// AllowedOrigins lists permitted origins: "*", an exact origin, or a
									// pattern whose single "*" directly follows an http:// or https://
									// scheme (the only position matchWildcard expands), e.g.
									// "https://*.example.com". At least one entry is required.
									AllowedOrigins []string `xml:"AllowedOrigin" json:"AllowedOrigins"`
							 | 
						|
									// ExposeHeaders is joined with ", " into Access-Control-Expose-Headers.
									ExposeHeaders  []string `xml:"ExposeHeader,omitempty" json:"ExposeHeaders,omitempty"`
							 | 
						|
									// MaxAgeSeconds, when non-nil, is emitted as Access-Control-Max-Age and
									// must not be negative; nil means the header is omitted.
									MaxAgeSeconds  *int     `xml:"MaxAgeSeconds,omitempty" json:"MaxAgeSeconds,omitempty"`
							 | 
						|
									// ID is an optional rule identifier; no logic in this file reads it.
									ID             string   `xml:"ID,omitempty" json:"ID,omitempty"`
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// CORSConfiguration is the CORS configuration of a bucket: an ordered list
								// of rules. EvaluateRequest tries the rules in slice order, and
								// ValidateConfiguration requires between 1 and 100 of them.
							 | 
						|
								type CORSConfiguration struct {
							 | 
						|
									// CORSRules holds the rules; the xml tag encodes each one as a
									// repeated CORSRule element. NOTE(review): no XMLName field is
									// declared, so the XML root element name is taken from the Go type
									// name — confirm that matches the expected wire format.
									CORSRules []CORSRule `xml:"CORSRule" json:"CORSRules"`
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// CORSRequest holds the CORS-relevant parts of an incoming HTTP request, as
								// extracted by ParseRequest and consumed by EvaluateRequest.
							 | 
						|
								type CORSRequest struct {
							 | 
						|
									// Origin is the value of the request's Origin header; EvaluateRequest
									// returns an error when it is empty.
									Origin                      string
							 | 
						|
									// Method is the HTTP method of the request itself.
									Method                      string
							 | 
						|
									// NOTE(review): ParseRequest does not populate RequestHeaders and no
									// code in this file reads it.
									RequestHeaders              []string
							 | 
						|
									// IsPreflightRequest is set by ParseRequest for every OPTIONS request.
									IsPreflightRequest          bool
							 | 
						|
									// AccessControlRequestMethod is the Access-Control-Request-Method value
									// of a preflight; buildPreflightResponse treats "" as allowed.
									AccessControlRequestMethod  string
							 | 
						|
									// AccessControlRequestHeaders holds the comma-separated, whitespace-
									// trimmed names from Access-Control-Request-Headers of a preflight.
									AccessControlRequestHeaders []string
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// CORSResponse holds the computed CORS response header values. ApplyHeaders
								// writes each non-empty string field to its Access-Control-* header.
							 | 
						|
								type CORSResponse struct {
							 | 
						|
									// AllowOrigin becomes Access-Control-Allow-Origin; the builders in this
									// file set it to the request's Origin value.
									AllowOrigin      string
							 | 
						|
									// AllowMethods becomes Access-Control-Allow-Methods (", "-joined).
									AllowMethods     string
							 | 
						|
									// AllowHeaders becomes Access-Control-Allow-Headers (", "-joined).
									AllowHeaders     string
							 | 
						|
									// ExposeHeaders becomes Access-Control-Expose-Headers (", "-joined).
									ExposeHeaders    string
							 | 
						|
									// MaxAge becomes Access-Control-Max-Age: the rule's seconds as a
									// decimal string.
									MaxAge           string
							 | 
						|
									// AllowCredentials, when true, emits
									// Access-Control-Allow-Credentials: true. NOTE(review): nothing in
									// this file sets it.
									AllowCredentials bool
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// ValidateConfiguration validates a CORS configuration
							 | 
						|
								func ValidateConfiguration(config *CORSConfiguration) error {
							 | 
						|
									if config == nil {
							 | 
						|
										return fmt.Errorf("CORS configuration cannot be nil")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if len(config.CORSRules) == 0 {
							 | 
						|
										return fmt.Errorf("CORS configuration must have at least one rule")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if len(config.CORSRules) > 100 {
							 | 
						|
										return fmt.Errorf("CORS configuration cannot have more than 100 rules")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									for i, rule := range config.CORSRules {
							 | 
						|
										if err := validateRule(&rule); err != nil {
							 | 
						|
											return fmt.Errorf("invalid CORS rule at index %d: %v", i, err)
							 | 
						|
										}
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									return nil
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// ParseRequest parses an HTTP request to extract CORS information
							 | 
						|
								func ParseRequest(r *http.Request) *CORSRequest {
							 | 
						|
									corsReq := &CORSRequest{
							 | 
						|
										Origin: r.Header.Get("Origin"),
							 | 
						|
										Method: r.Method,
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Check if this is a preflight request
							 | 
						|
									if r.Method == "OPTIONS" {
							 | 
						|
										corsReq.IsPreflightRequest = true
							 | 
						|
										corsReq.AccessControlRequestMethod = r.Header.Get("Access-Control-Request-Method")
							 | 
						|
								
							 | 
						|
										if headers := r.Header.Get("Access-Control-Request-Headers"); headers != "" {
							 | 
						|
											corsReq.AccessControlRequestHeaders = strings.Split(headers, ",")
							 | 
						|
											for i := range corsReq.AccessControlRequestHeaders {
							 | 
						|
												corsReq.AccessControlRequestHeaders[i] = strings.TrimSpace(corsReq.AccessControlRequestHeaders[i])
							 | 
						|
											}
							 | 
						|
										}
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									return corsReq
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// validateRule validates a single CORS rule
							 | 
						|
								func validateRule(rule *CORSRule) error {
							 | 
						|
									if len(rule.AllowedMethods) == 0 {
							 | 
						|
										return fmt.Errorf("AllowedMethods cannot be empty")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if len(rule.AllowedOrigins) == 0 {
							 | 
						|
										return fmt.Errorf("AllowedOrigins cannot be empty")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Validate allowed methods
							 | 
						|
									validMethods := map[string]bool{
							 | 
						|
										"GET":    true,
							 | 
						|
										"PUT":    true,
							 | 
						|
										"POST":   true,
							 | 
						|
										"DELETE": true,
							 | 
						|
										"HEAD":   true,
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									for _, method := range rule.AllowedMethods {
							 | 
						|
										if !validMethods[method] {
							 | 
						|
											return fmt.Errorf("invalid HTTP method: %s", method)
							 | 
						|
										}
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Validate origins
							 | 
						|
									for _, origin := range rule.AllowedOrigins {
							 | 
						|
										if origin == "*" {
							 | 
						|
											continue
							 | 
						|
										}
							 | 
						|
										if err := validateOrigin(origin); err != nil {
							 | 
						|
											return fmt.Errorf("invalid origin %s: %v", origin, err)
							 | 
						|
										}
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Validate MaxAgeSeconds
							 | 
						|
									if rule.MaxAgeSeconds != nil && *rule.MaxAgeSeconds < 0 {
							 | 
						|
										return fmt.Errorf("MaxAgeSeconds cannot be negative")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									return nil
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// validateOrigin checks one AllowedOrigin entry.
								//
								// Accepted forms are:
								//   - the bare wildcard "*";
								//   - any non-empty origin without a wildcard;
								//   - one wildcard placed immediately after an http:// or https:// scheme,
								//     e.g. "https://*.example.com".
								//
								// The position restriction mirrors matchWildcard, which only expands a "*"
								// directly after the scheme. A wildcard anywhere else would pass validation
								// yet silently never match a request.
								func validateOrigin(origin string) error {
									if origin == "" {
										return fmt.Errorf("origin cannot be empty")
									}

									// Special case: "*" is always valid
									if origin == "*" {
										return nil
									}

									switch strings.Count(origin, "*") {
									case 0:
										return nil
									case 1:
										// Validated below.
									default:
										return fmt.Errorf("origin can contain at most one wildcard")
									}

									if !strings.HasPrefix(origin, "http://") && !strings.HasPrefix(origin, "https://") {
										return fmt.Errorf("origin with wildcard must start with http:// or https://")
									}
									if !strings.HasPrefix(origin, "http://*") && !strings.HasPrefix(origin, "https://*") {
										return fmt.Errorf("origin wildcard must immediately follow the scheme, e.g. https://*.example.com")
									}

									return nil
								}
							 | 
						|
								
							 | 
						|
								// EvaluateRequest evaluates a CORS request against a CORS configuration
							 | 
						|
								func EvaluateRequest(config *CORSConfiguration, corsReq *CORSRequest) (*CORSResponse, error) {
							 | 
						|
									if config == nil || corsReq == nil {
							 | 
						|
										return nil, fmt.Errorf("config and corsReq cannot be nil")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if corsReq.Origin == "" {
							 | 
						|
										return nil, fmt.Errorf("origin header is required for CORS requests")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Find the first rule that matches the origin
							 | 
						|
									for _, rule := range config.CORSRules {
							 | 
						|
										if matchesOrigin(rule.AllowedOrigins, corsReq.Origin) {
							 | 
						|
											// For preflight requests, we need more detailed validation
							 | 
						|
											if corsReq.IsPreflightRequest {
							 | 
						|
												return buildPreflightResponse(&rule, corsReq), nil
							 | 
						|
											} else {
							 | 
						|
												// For actual requests, check method
							 | 
						|
												if containsString(rule.AllowedMethods, corsReq.Method) {
							 | 
						|
													return buildResponse(&rule, corsReq), nil
							 | 
						|
												}
							 | 
						|
											}
							 | 
						|
										}
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									return nil, fmt.Errorf("no matching CORS rule found")
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// buildPreflightResponse builds a CORS response for preflight requests
							 | 
						|
								func buildPreflightResponse(rule *CORSRule, corsReq *CORSRequest) *CORSResponse {
							 | 
						|
									response := &CORSResponse{
							 | 
						|
										AllowOrigin: corsReq.Origin,
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Check if the requested method is allowed
							 | 
						|
									methodAllowed := corsReq.AccessControlRequestMethod == "" || containsString(rule.AllowedMethods, corsReq.AccessControlRequestMethod)
							 | 
						|
								
							 | 
						|
									// Check requested headers
							 | 
						|
									var allowedRequestHeaders []string
							 | 
						|
									allHeadersAllowed := true
							 | 
						|
								
							 | 
						|
									if len(corsReq.AccessControlRequestHeaders) > 0 {
							 | 
						|
										// Check if wildcard is allowed
							 | 
						|
										hasWildcard := false
							 | 
						|
										for _, header := range rule.AllowedHeaders {
							 | 
						|
											if header == "*" {
							 | 
						|
												hasWildcard = true
							 | 
						|
												break
							 | 
						|
											}
							 | 
						|
										}
							 | 
						|
								
							 | 
						|
										if hasWildcard {
							 | 
						|
											// All requested headers are allowed with wildcard
							 | 
						|
											allowedRequestHeaders = corsReq.AccessControlRequestHeaders
							 | 
						|
										} else {
							 | 
						|
											// Check each requested header individually
							 | 
						|
											for _, requestedHeader := range corsReq.AccessControlRequestHeaders {
							 | 
						|
												if matchesHeader(rule.AllowedHeaders, requestedHeader) {
							 | 
						|
													allowedRequestHeaders = append(allowedRequestHeaders, requestedHeader)
							 | 
						|
												} else {
							 | 
						|
													allHeadersAllowed = false
							 | 
						|
												}
							 | 
						|
											}
							 | 
						|
										}
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Only set method and header info if both method and ALL headers are allowed
							 | 
						|
									if methodAllowed && allHeadersAllowed {
							 | 
						|
										response.AllowMethods = strings.Join(rule.AllowedMethods, ", ")
							 | 
						|
								
							 | 
						|
										if len(allowedRequestHeaders) > 0 {
							 | 
						|
											response.AllowHeaders = strings.Join(allowedRequestHeaders, ", ")
							 | 
						|
										}
							 | 
						|
								
							 | 
						|
										// Set exposed headers
							 | 
						|
										if len(rule.ExposeHeaders) > 0 {
							 | 
						|
											response.ExposeHeaders = strings.Join(rule.ExposeHeaders, ", ")
							 | 
						|
										}
							 | 
						|
								
							 | 
						|
										// Set max age
							 | 
						|
										if rule.MaxAgeSeconds != nil {
							 | 
						|
											response.MaxAge = strconv.Itoa(*rule.MaxAgeSeconds)
							 | 
						|
										}
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									return response
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// buildResponse builds a CORS response from a matching rule
							 | 
						|
								func buildResponse(rule *CORSRule, corsReq *CORSRequest) *CORSResponse {
							 | 
						|
									response := &CORSResponse{
							 | 
						|
										AllowOrigin: corsReq.Origin,
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Set allowed methods
							 | 
						|
									response.AllowMethods = strings.Join(rule.AllowedMethods, ", ")
							 | 
						|
								
							 | 
						|
									// Set allowed headers
							 | 
						|
									if len(rule.AllowedHeaders) > 0 {
							 | 
						|
										response.AllowHeaders = strings.Join(rule.AllowedHeaders, ", ")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Set expose headers
							 | 
						|
									if len(rule.ExposeHeaders) > 0 {
							 | 
						|
										response.ExposeHeaders = strings.Join(rule.ExposeHeaders, ", ")
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									// Set max age
							 | 
						|
									if rule.MaxAgeSeconds != nil {
							 | 
						|
										response.MaxAge = strconv.Itoa(*rule.MaxAgeSeconds)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									return response
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// Helper functions
							 | 
						|
								
							 | 
						|
								// matchesOrigin checks if the request origin matches any allowed origin
							 | 
						|
								func matchesOrigin(allowedOrigins []string, origin string) bool {
							 | 
						|
									for _, allowedOrigin := range allowedOrigins {
							 | 
						|
										if allowedOrigin == "*" {
							 | 
						|
											return true
							 | 
						|
										}
							 | 
						|
										if allowedOrigin == origin {
							 | 
						|
											return true
							 | 
						|
										}
							 | 
						|
										// Handle wildcard patterns like https://*.example.com
							 | 
						|
										if strings.Contains(allowedOrigin, "*") {
							 | 
						|
											if matchWildcard(allowedOrigin, origin) {
							 | 
						|
												return true
							 | 
						|
											}
							 | 
						|
										}
							 | 
						|
									}
							 | 
						|
									return false
							 | 
						|
								}
							 | 
						|
								
							 | 
						|
								// matchWildcard reports whether text matches an origin pattern whose "*"
								// directly follows an http:// or https:// scheme, e.g.
								// "https://*.example.com". The "*" stands for any (possibly empty) run of
								// characters: text must use the same scheme and end with the part of the
								// pattern after the "*". Patterns with the wildcard anywhere else never match.
								//
								// The length check keeps the scheme and the suffix from overlapping. Without
								// it, "http://*/" would match the bare "http://", because that string both
								// starts with "http://" and ends with "/".
								func matchWildcard(pattern, text string) bool {
									for _, scheme := range [...]string{"http://", "https://"} {
										if !strings.HasPrefix(pattern, scheme+"*") {
											continue
										}
										suffix := pattern[len(scheme)+1:]
										return len(text) >= len(scheme)+len(suffix) &&
											strings.HasPrefix(text, scheme) &&
											strings.HasSuffix(text, suffix)
									}
									return false
								}
							 | 
						|
								
							 | 
						|
								// matchesHeader reports whether the requested header name is permitted by
								// allowedHeaders. An empty allow-list permits every header.
								//
								// Otherwise both sides are lowercased before comparison, and an entry
								// matches when it is "*", equals the name, or ends in "*" and the name starts
								// with the part before it (e.g. "x-amz-*" matches "x-amz-date").
								func matchesHeader(allowedHeaders []string, header string) bool {
									if len(allowedHeaders) == 0 {
										return true
									}

									name := strings.ToLower(header)
									for _, entry := range allowedHeaders {
										pattern := strings.ToLower(entry)
										switch {
										case pattern == "*", pattern == name:
											return true
										case strings.HasSuffix(pattern, "*") && strings.HasPrefix(name, pattern[:len(pattern)-1]):
											return true
										}
									}
									return false
								}
							 | 
						|
								
							 | 
						|
								// containsString reports whether item occurs in slice, compared with ==
								// (case-sensitive, no normalization). A nil or empty slice contains nothing.
								func containsString(slice []string, item string) bool {
									for i := range slice {
										if slice[i] == item {
											return true
										}
									}
									return false
								}
							 | 
						|
								
							 | 
						|
								// ApplyHeaders applies CORS headers to an HTTP response
							 | 
						|
								func ApplyHeaders(w http.ResponseWriter, corsResp *CORSResponse) {
							 | 
						|
									if corsResp == nil {
							 | 
						|
										return
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if corsResp.AllowOrigin != "" {
							 | 
						|
										w.Header().Set("Access-Control-Allow-Origin", corsResp.AllowOrigin)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if corsResp.AllowMethods != "" {
							 | 
						|
										w.Header().Set("Access-Control-Allow-Methods", corsResp.AllowMethods)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if corsResp.AllowHeaders != "" {
							 | 
						|
										w.Header().Set("Access-Control-Allow-Headers", corsResp.AllowHeaders)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if corsResp.ExposeHeaders != "" {
							 | 
						|
										w.Header().Set("Access-Control-Expose-Headers", corsResp.ExposeHeaders)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if corsResp.MaxAge != "" {
							 | 
						|
										w.Header().Set("Access-Control-Max-Age", corsResp.MaxAge)
							 | 
						|
									}
							 | 
						|
								
							 | 
						|
									if corsResp.AllowCredentials {
							 | 
						|
										w.Header().Set("Access-Control-Allow-Credentials", "true")
							 | 
						|
									}
							 | 
						|
								}
							 |