diff --git a/csrf.go b/csrf.go new file mode 100644 index 0000000..a743a04 --- /dev/null +++ b/csrf.go @@ -0,0 +1,24 @@ +package main + +import ( + "net/http" + "strings" +) + +func strictReferrerCheck(r *http.Request, prefix string, whitelistHeaders []string) bool { + for _, header := range whitelistHeaders { + if r.Header.Get(header) != "" { + return true + } + } + + if referrer := r.Header.Get("Referer"); !strings.HasPrefix(referrer, prefix) { + return false + } + + if origin := r.Header.Get("Origin"); origin != "" && !strings.HasPrefix(origin, prefix) { + return false + } + + return true +} diff --git a/pages.go b/pages.go index d3a2249..d16bd9c 100644 --- a/pages.go +++ b/pages.go @@ -84,6 +84,11 @@ func oopsHandler(c web.C, w http.ResponseWriter, r *http.Request, rt RespType, m } } +func badRequestHandler(c web.C, w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadRequest) + http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) +} + func unauthorizedHandler(c web.C, w http.ResponseWriter, r *http.Request) { w.WriteHeader(401) err := Templates["401.html"].ExecuteWriter(pongo2.Context{}, w) diff --git a/server_test.go b/server_test.go index 4bea964..ced20b3 100644 --- a/server_test.go +++ b/server_test.go @@ -120,6 +120,7 @@ func TestPostCodeUpload(t *testing.T) { } req.PostForm = form req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Referer", Config.siteURL) goji.DefaultMux.ServeHTTP(w, req) @@ -132,6 +133,84 @@ func TestPostCodeUpload(t *testing.T) { } } +func TestPostCodeUploadWhitelistedHeader(t *testing.T) { + w := httptest.NewRecorder() + + filename := generateBarename() + extension := "txt" + + form := url.Values{} + form.Add("content", "File content") + form.Add("filename", filename) + form.Add("extension", extension) + + req, err := http.NewRequest("POST", "/upload/", nil) + if err != nil { + t.Fatal(err) + } + req.PostForm = form + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Linx-Expiry", "0") + + goji.DefaultMux.ServeHTTP(w, req) + + if w.Code != 301 { + t.Fatalf("Status code is not 301, but %d", w.Code) + } +} + +func TestPostCodeUploadNoReferrer(t *testing.T) { + w := httptest.NewRecorder() + + filename := generateBarename() + extension := "txt" + + form := url.Values{} + form.Add("content", "File content") + form.Add("filename", filename) + form.Add("extension", extension) + + req, err := http.NewRequest("POST", "/upload/", nil) + if err != nil { + t.Fatal(err) + } + req.PostForm = form + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + goji.DefaultMux.ServeHTTP(w, req) + + if w.Code != 400 { + t.Fatalf("Status code is not 400, but %d", w.Code) + } +} + +func TestPostCodeUploadBadOrigin(t *testing.T) { + w := httptest.NewRecorder() + + filename := generateBarename() + extension := "txt" + + form := url.Values{} + form.Add("content", "File content") + form.Add("filename", filename) + form.Add("extension", extension) + + req, err := http.NewRequest("POST", "/upload/", nil) + if err != nil { + t.Fatal(err) + } + req.PostForm = form + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("Referer", Config.siteURL) + req.Header.Set("Origin", "http://example.com/") + + goji.DefaultMux.ServeHTTP(w, req) + + if w.Code != 400 { + t.Fatalf("Status code is not 400, but %d", w.Code) + } +} + func TestPostCodeExpiryJSONUpload(t *testing.T) { w := httptest.NewRecorder() @@ -147,6 +226,7 @@ func TestPostCodeExpiryJSONUpload(t *testing.T) { req.PostForm = form req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Accept", "application/json") + req.Header.Set("Referer", Config.siteURL) goji.DefaultMux.ServeHTTP(w, req) @@ -193,6 +273,7 @@ func TestPostUpload(t *testing.T) { req, err := http.NewRequest("POST", "/upload/", &b) req.Header.Set("Content-Type", mw.FormDataContentType()) + req.Header.Set("Referer", Config.siteURL) if err != nil { t.Fatal(err) } @@ -226,6 +307,7 @@ func TestPostJSONUpload(t *testing.T) { req, err := http.NewRequest("POST", "/upload/", &b) req.Header.Set("Content-Type", mw.FormDataContentType()) req.Header.Set("Accept", "application/json") + req.Header.Set("Referer", Config.siteURL) if err != nil { t.Fatal(err) } @@ -280,6 +362,7 @@ func TestPostExpiresJSONUpload(t *testing.T) { req, err := http.NewRequest("POST", "/upload/", &b) req.Header.Set("Content-Type", mw.FormDataContentType()) req.Header.Set("Accept", "application/json") + req.Header.Set("Referer", Config.siteURL) if err != nil { t.Fatal(err) } @@ -340,6 +423,7 @@ func TestPostRandomizeJSONUpload(t *testing.T) { req, err := http.NewRequest("POST", "/upload/", &b) req.Header.Set("Content-Type", mw.FormDataContentType()) req.Header.Set("Accept", "application/json") + req.Header.Set("Referer", Config.siteURL) if err != nil { t.Fatal(err) } @@ -383,6 +467,7 @@ func TestPostEmptyUpload(t *testing.T) { req, err := http.NewRequest("POST", "/upload/", &b) req.Header.Set("Content-Type", mw.FormDataContentType()) + req.Header.Set("Referer", Config.siteURL) if err != nil { t.Fatal(err) } @@ -417,6 +502,7 @@ func TestPostEmptyJSONUpload(t *testing.T) { req, err := http.NewRequest("POST", "/upload/", &b) req.Header.Set("Content-Type", mw.FormDataContentType()) req.Header.Set("Accept", "application/json") + req.Header.Set("Referer", Config.siteURL) if err != nil { t.Fatal(err) } diff --git a/upload.go b/upload.go index 5e7f43b..50694c3 100644 --- a/upload.go +++ b/upload.go @@ -45,6 +45,11 @@ type Upload struct { } func uploadPostHandler(c web.C, w http.ResponseWriter, r *http.Request) { + if !strictReferrerCheck(r, Config.siteURL, []string{"Linx-Delete-Key", "Linx-Expiry", "Linx-Randomize"}) { + badRequestHandler(c, w, r) + return + } + upReq := UploadRequest{} uploadHeaderProcess(r, &upReq)