You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
96 lines
2.6 KiB
96 lines
2.6 KiB
package dash
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"crypto/subtle"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"net/http"
|
|
|
|
"github.com/gorilla/sessions"
|
|
)
|
|
|
|
// sessionCSRFTokenKey is the key in session.Values under which the
// per-session CSRF token string is stored and looked up.
const sessionCSRFTokenKey = "csrf_token"
|
|
|
|
// generateCSRFToken returns a fresh random CSRF token. The token is 32 bytes
// read from crypto/rand, hex-encoded into a 64-character string. If the
// random source fails, the error is returned together with an empty token.
func generateCSRFToken() (string, error) {
	var raw [32]byte
	_, err := rand.Read(raw[:])
	if err != nil {
		return "", err
	}
	return hex.EncodeToString(raw[:]), nil
}
|
|
|
|
func getOrCreateSessionCSRFToken(session *sessions.Session, r *http.Request, w http.ResponseWriter) (string, error) {
|
|
if existing, ok := session.Values[sessionCSRFTokenKey].(string); ok && existing != "" {
|
|
return existing, nil
|
|
}
|
|
token, err := generateCSRFToken()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
session.Values[sessionCSRFTokenKey] = token
|
|
if err := session.Save(r, w); err != nil {
|
|
return "", err
|
|
}
|
|
return token, nil
|
|
}
|
|
|
|
func requireSessionCSRFToken(w http.ResponseWriter, r *http.Request) bool {
|
|
expectedToken := CSRFTokenFromContext(r.Context())
|
|
username := UsernameFromContext(r.Context())
|
|
if expectedToken == "" {
|
|
// Admin UI can run without auth; in that mode CSRF token checks are not applicable.
|
|
if username == "" {
|
|
return true
|
|
}
|
|
writeJSONError(w, http.StatusForbidden, "missing CSRF session token")
|
|
return false
|
|
}
|
|
|
|
providedToken, err := getProvidedCSRFToken(r)
|
|
if err != nil {
|
|
writeJSONError(w, http.StatusBadRequest, "Failed to parse form: "+err.Error())
|
|
return false
|
|
}
|
|
if providedToken == "" || subtle.ConstantTimeCompare([]byte(expectedToken), []byte(providedToken)) != 1 {
|
|
writeJSONError(w, http.StatusForbidden, "invalid CSRF token")
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
// getProvidedCSRFToken returns the CSRF token the client submitted with r.
//
// The X-CSRF-Token header takes precedence. Otherwise, the "csrf_token" form
// field is read with r.FormValue, which consults both the URL query and the
// request body. Two body encodings are supported:
//   - application/x-www-form-urlencoded;
//   - multipart/form-data.
//
// r.ParseForm alone skips multipart bodies. Because it leaves r.Form non-nil,
// a later r.FormValue would not parse them either, so multipart is parsed
// explicitly here.
//
// An empty string with a nil error means no token was supplied. A non-nil
// error means the request form could not be parsed.
func getProvidedCSRFToken(r *http.Request) (string, error) {
	if token := r.Header.Get("X-CSRF-Token"); token != "" {
		return token, nil
	}

	if err := r.ParseForm(); err != nil {
		return "", err
	}

	// Same in-memory limit net/http uses for the implicit parse in FormValue.
	const maxMultipartMemory = 32 << 20
	// Without this, tokens posted via multipart/form-data (browser FormData,
	// file-upload forms) are never seen. ParseMultipartForm returns the
	// http.ErrNotMultipart sentinel itself (unwrapped) when the body simply
	// isn't multipart; that case is not a failure.
	if err := r.ParseMultipartForm(maxMultipartMemory); err != nil && err != http.ErrNotMultipart {
		return "", err
	}

	return r.FormValue("csrf_token"), nil
}
|
|
|
|
func EnsureSessionCSRFToken(session *sessions.Session, r *http.Request, w http.ResponseWriter) (string, error) {
|
|
if session == nil {
|
|
return "", fmt.Errorf("session is nil")
|
|
}
|
|
return getOrCreateSessionCSRFToken(session, r, w)
|
|
}
|
|
|
|
func ValidateSessionCSRFToken(session *sessions.Session, r *http.Request) error {
|
|
if session == nil {
|
|
return fmt.Errorf("session is nil")
|
|
}
|
|
expectedToken, _ := session.Values[sessionCSRFTokenKey].(string)
|
|
providedToken, err := getProvidedCSRFToken(r)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to read CSRF token: %w", err)
|
|
}
|
|
if expectedToken == "" {
|
|
return fmt.Errorf("missing session CSRF token")
|
|
}
|
|
if providedToken == "" || subtle.ConstantTimeCompare([]byte(expectedToken), []byte(providedToken)) != 1 {
|
|
return fmt.Errorf("invalid CSRF token")
|
|
}
|
|
return nil
|
|
}
|