|
@ -2,14 +2,19 @@ package main |
|
|
|
|
|
|
|
|
import ( |
|
|
import ( |
|
|
"net/http" |
|
|
"net/http" |
|
|
"strings" |
|
|
|
|
|
|
|
|
"net/url" |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
// Do a strict referrer check, matching against both the Origin header (if
|
|
|
|
|
|
// present) and the Referrer header. If a list of headers is specified, then
|
|
|
|
|
|
// Referrer checking will be skipped if any of those headers are present.
|
|
|
func strictReferrerCheck(r *http.Request, prefix string, whitelistHeaders []string) bool { |
|
|
func strictReferrerCheck(r *http.Request, prefix string, whitelistHeaders []string) bool { |
|
|
p := strings.TrimSuffix(prefix, "/") |
|
|
|
|
|
|
|
|
p, _ := url.Parse(prefix) |
|
|
|
|
|
|
|
|
|
|
|
// if there's an Origin header, check it and skip other checks
|
|
|
if origin := r.Header.Get("Origin"); origin != "" { |
|
|
if origin := r.Header.Get("Origin"); origin != "" { |
|
|
// if there's an Origin header, check it and ignore the rest
|
|
|
|
|
|
return strings.HasPrefix(origin, p) |
|
|
|
|
|
|
|
|
u, _ := url.Parse(origin) |
|
|
|
|
|
return sameOrigin(u, p) |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
for _, header := range whitelistHeaders { |
|
|
for _, header := range whitelistHeaders { |
|
@ -18,9 +23,13 @@ func strictReferrerCheck(r *http.Request, prefix string, whitelistHeaders []stri |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if referrer := r.Header.Get("Referer"); !strings.HasPrefix(referrer, p) { |
|
|
|
|
|
return false |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
referrer := r.Header.Get("Referer") |
|
|
|
|
|
u, _ := url.Parse(referrer) |
|
|
|
|
|
return sameOrigin(u, p) |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
return true |
|
|
|
|
|
|
|
|
// Check if two URLs have the same origin
|
|
|
|
|
|
func sameOrigin(u1, u2 *url.URL) bool { |
|
|
|
|
|
// host also contains the port if one was specified
|
|
|
|
|
|
return (u1.Scheme == u2.Scheme && u1.Host == u2.Host) |
|
|
} |
|
|
} |