|
|
@ -2,6 +2,7 @@ package util |
|
|
|
|
|
|
|
import ( |
|
|
|
"context" |
|
|
|
"errors" |
|
|
|
"net/http" |
|
|
|
"net/http/httptest" |
|
|
|
"testing" |
|
|
@ -10,10 +11,10 @@ import ( |
|
|
|
) |
|
|
|
|
|
|
|
type MockJSONRequestHandler struct { |
|
|
|
handler func(req *http.Request) (interface{}, *HTTPError) |
|
|
|
handler func(req *http.Request) JSONResponse |
|
|
|
} |
|
|
|
|
|
|
|
func (h *MockJSONRequestHandler) OnIncomingRequest(req *http.Request) (interface{}, *HTTPError) { |
|
|
|
func (h *MockJSONRequestHandler) OnIncomingRequest(req *http.Request) JSONResponse { |
|
|
|
return h.handler(req) |
|
|
|
} |
|
|
|
|
|
|
@ -24,22 +25,27 @@ type MockResponse struct { |
|
|
|
func TestMakeJSONAPI(t *testing.T) { |
|
|
|
log.SetLevel(log.PanicLevel) // suppress logs in test output
|
|
|
|
tests := []struct { |
|
|
|
Return interface{} |
|
|
|
Err *HTTPError |
|
|
|
Return JSONResponse |
|
|
|
ExpectCode int |
|
|
|
ExpectJSON string |
|
|
|
}{ |
|
|
|
{nil, &HTTPError{nil, "Everything is broken", 500}, 500, `{"message":"Everything is broken"}`}, // Error return values
|
|
|
|
{nil, &HTTPError{nil, "Not here", 404}, 404, `{"message":"Not here"}`}, // With different status codes
|
|
|
|
{&MockResponse{"yep"}, nil, 200, `{"foo":"yep"}`}, // Success return values
|
|
|
|
{[]MockResponse{{"yep"}, {"narp"}}, nil, 200, `[{"foo":"yep"},{"foo":"narp"}]`}, // Top-level array success values
|
|
|
|
{[]byte(`actually bytes`), nil, 200, `actually bytes`}, // raw []byte escape hatch
|
|
|
|
{func(cannotBe, marshalled string) {}, nil, 500, `{"message":"Failed to serialise response as JSON"}`}, // impossible marshal
|
|
|
|
// MessageResponse return values
|
|
|
|
{MessageResponse(500, "Everything is broken"), 500, `{"message":"Everything is broken"}`}, |
|
|
|
// interface return values
|
|
|
|
{JSONResponse{500, MockResponse{"yep"}, nil}, 500, `{"foo":"yep"}`}, |
|
|
|
// Error JSON return values which fail to be marshalled should fallback to text
|
|
|
|
{JSONResponse{500, struct { |
|
|
|
Foo interface{} `json:"foo"` |
|
|
|
}{func(cannotBe, marshalled string) {}}, nil}, 500, `{"message":"Internal Server Error"}`}, |
|
|
|
// With different status codes
|
|
|
|
{JSONResponse{201, MockResponse{"narp"}, nil}, 201, `{"foo":"narp"}`}, |
|
|
|
// Top-level array success values
|
|
|
|
{JSONResponse{200, []MockResponse{{"yep"}, {"narp"}}, nil}, 200, `[{"foo":"yep"},{"foo":"narp"}]`}, |
|
|
|
} |
|
|
|
|
|
|
|
for _, tst := range tests { |
|
|
|
mock := MockJSONRequestHandler{func(req *http.Request) (interface{}, *HTTPError) { |
|
|
|
return tst.Return, tst.Err |
|
|
|
mock := MockJSONRequestHandler{func(req *http.Request) JSONResponse { |
|
|
|
return tst.Return |
|
|
|
}} |
|
|
|
mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) |
|
|
|
mockWriter := httptest.NewRecorder() |
|
|
@ -55,10 +61,38 @@ func TestMakeJSONAPI(t *testing.T) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
func TestMakeJSONAPICustomHeaders(t *testing.T) { |
|
|
|
mock := MockJSONRequestHandler{func(req *http.Request) JSONResponse { |
|
|
|
headers := make(map[string]string) |
|
|
|
headers["Custom"] = "Thing" |
|
|
|
headers["X-Custom"] = "Things" |
|
|
|
return JSONResponse{ |
|
|
|
Code: 200, |
|
|
|
JSON: MockResponse{"yep"}, |
|
|
|
Headers: headers, |
|
|
|
} |
|
|
|
}} |
|
|
|
mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) |
|
|
|
mockWriter := httptest.NewRecorder() |
|
|
|
handlerFunc := MakeJSONAPI(&mock) |
|
|
|
handlerFunc(mockWriter, mockReq) |
|
|
|
if mockWriter.Code != 200 { |
|
|
|
t.Errorf("TestMakeJSONAPICustomHeaders wanted HTTP status 200, got %d", mockWriter.Code) |
|
|
|
} |
|
|
|
h := mockWriter.Header().Get("Custom") |
|
|
|
if h != "Thing" { |
|
|
|
t.Errorf("TestMakeJSONAPICustomHeaders wanted header 'Custom: Thing' , got 'Custom: %s'", h) |
|
|
|
} |
|
|
|
h = mockWriter.Header().Get("X-Custom") |
|
|
|
if h != "Things" { |
|
|
|
t.Errorf("TestMakeJSONAPICustomHeaders wanted header 'X-Custom: Things' , got 'X-Custom: %s'", h) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
func TestMakeJSONAPIRedirect(t *testing.T) { |
|
|
|
log.SetLevel(log.PanicLevel) // suppress logs in test output
|
|
|
|
mock := MockJSONRequestHandler{func(req *http.Request) (interface{}, *HTTPError) { |
|
|
|
return nil, &HTTPError{nil, "https://matrix.org", 302} |
|
|
|
mock := MockJSONRequestHandler{func(req *http.Request) JSONResponse { |
|
|
|
return RedirectResponse("https://matrix.org") |
|
|
|
}} |
|
|
|
mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) |
|
|
|
mockWriter := httptest.NewRecorder() |
|
|
@ -73,12 +107,74 @@ func TestMakeJSONAPIRedirect(t *testing.T) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
func TestMakeJSONAPIError(t *testing.T) { |
|
|
|
log.SetLevel(log.PanicLevel) // suppress logs in test output
|
|
|
|
mock := MockJSONRequestHandler{func(req *http.Request) JSONResponse { |
|
|
|
err := errors.New("oops") |
|
|
|
return ErrorResponse(err) |
|
|
|
}} |
|
|
|
mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) |
|
|
|
mockWriter := httptest.NewRecorder() |
|
|
|
handlerFunc := MakeJSONAPI(&mock) |
|
|
|
handlerFunc(mockWriter, mockReq) |
|
|
|
if mockWriter.Code != 500 { |
|
|
|
t.Errorf("TestMakeJSONAPIError wanted HTTP status 500, got %d", mockWriter.Code) |
|
|
|
} |
|
|
|
actualBody := mockWriter.Body.String() |
|
|
|
expect := `{"message":"oops"}` |
|
|
|
if actualBody != expect { |
|
|
|
t.Errorf("TestMakeJSONAPIError wanted body '%s', got '%s'", expect, actualBody) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
func TestIs2xx(t *testing.T) { |
|
|
|
tests := []struct { |
|
|
|
Code int |
|
|
|
Expect bool |
|
|
|
}{ |
|
|
|
{200, true}, |
|
|
|
{201, true}, |
|
|
|
{299, true}, |
|
|
|
{300, false}, |
|
|
|
{199, false}, |
|
|
|
{0, false}, |
|
|
|
{500, false}, |
|
|
|
} |
|
|
|
for _, test := range tests { |
|
|
|
j := JSONResponse{ |
|
|
|
Code: test.Code, |
|
|
|
} |
|
|
|
actual := j.Is2xx() |
|
|
|
if actual != test.Expect { |
|
|
|
t.Errorf("TestIs2xx wanted %t, got %t", test.Expect, actual) |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
func TestGetLogger(t *testing.T) { |
|
|
|
log.SetLevel(log.PanicLevel) // suppress logs in test output
|
|
|
|
entry := log.WithField("test", "yep") |
|
|
|
mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) |
|
|
|
ctx := context.WithValue(mockReq.Context(), ctxValueLogger, entry) |
|
|
|
mockReq = mockReq.WithContext(ctx) |
|
|
|
ctxLogger := GetLogger(mockReq.Context()) |
|
|
|
if ctxLogger != entry { |
|
|
|
t.Errorf("TestGetLogger wanted logger '%v', got '%v'", entry, ctxLogger) |
|
|
|
} |
|
|
|
|
|
|
|
noLoggerInReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) |
|
|
|
ctxLogger = GetLogger(noLoggerInReq.Context()) |
|
|
|
if ctxLogger != nil { |
|
|
|
t.Errorf("TestGetLogger wanted nil logger, got '%v'", ctxLogger) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
func TestProtect(t *testing.T) { |
|
|
|
log.SetLevel(log.PanicLevel) // suppress logs in test output
|
|
|
|
mockWriter := httptest.NewRecorder() |
|
|
|
mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) |
|
|
|
mockReq = mockReq.WithContext( |
|
|
|
context.WithValue(mockReq.Context(), CtxValueLogger, log.WithField("test", "yep")), |
|
|
|
context.WithValue(mockReq.Context(), ctxValueLogger, log.WithField("test", "yep")), |
|
|
|
) |
|
|
|
h := Protect(func(w http.ResponseWriter, req *http.Request) { |
|
|
|
panic("oh noes!") |
|
|
@ -97,3 +193,47 @@ func TestProtect(t *testing.T) { |
|
|
|
t.Errorf("TestProtect wanted body %s, got %s", expectBody, actualBody) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
func TestWithCORSOptions(t *testing.T) { |
|
|
|
log.SetLevel(log.PanicLevel) // suppress logs in test output
|
|
|
|
mockWriter := httptest.NewRecorder() |
|
|
|
mockReq, _ := http.NewRequest("OPTIONS", "http://example.com/foo", nil) |
|
|
|
h := WithCORSOptions(func(w http.ResponseWriter, req *http.Request) { |
|
|
|
w.WriteHeader(200) |
|
|
|
w.Write([]byte("yep")) |
|
|
|
}) |
|
|
|
h(mockWriter, mockReq) |
|
|
|
if mockWriter.Code != 200 { |
|
|
|
t.Errorf("TestWithCORSOptions wanted HTTP status 200, got %d", mockWriter.Code) |
|
|
|
} |
|
|
|
|
|
|
|
origin := mockWriter.Header().Get("Access-Control-Allow-Origin") |
|
|
|
if origin != "*" { |
|
|
|
t.Errorf("TestWithCORSOptions wanted Access-Control-Allow-Origin header '*', got '%s'", origin) |
|
|
|
} |
|
|
|
|
|
|
|
// OPTIONS request shouldn't hit the handler func
|
|
|
|
expectBody := "" |
|
|
|
actualBody := mockWriter.Body.String() |
|
|
|
if actualBody != expectBody { |
|
|
|
t.Errorf("TestWithCORSOptions wanted body %s, got %s", expectBody, actualBody) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
func TestGetRequestID(t *testing.T) { |
|
|
|
log.SetLevel(log.PanicLevel) // suppress logs in test output
|
|
|
|
reqID := "alphabetsoup" |
|
|
|
mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) |
|
|
|
ctx := context.WithValue(mockReq.Context(), ctxValueRequestID, reqID) |
|
|
|
mockReq = mockReq.WithContext(ctx) |
|
|
|
ctxReqID := GetRequestID(mockReq.Context()) |
|
|
|
if reqID != ctxReqID { |
|
|
|
t.Errorf("TestGetRequestID wanted request ID '%s', got '%s'", reqID, ctxReqID) |
|
|
|
} |
|
|
|
|
|
|
|
noReqIDInReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) |
|
|
|
ctxReqID = GetRequestID(noReqIDInReq.Context()) |
|
|
|
if ctxReqID != "" { |
|
|
|
t.Errorf("TestGetRequestID wanted empty request ID, got '%s'", ctxReqID) |
|
|
|
} |
|
|
|
} |