From 9b349b12f4aae315a67944b284e6e43e9afe5916 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Thu, 23 Feb 2017 16:17:45 +0000 Subject: [PATCH] Update util dep to pull in util.JSONResponse --- vendor/manifest | 2 +- .../src/github.com/matrix-org/util/context.go | 35 ++++ .../src/github.com/matrix-org/util/error.go | 18 -- vendor/src/github.com/matrix-org/util/json.go | 154 +++++++++------- .../github.com/matrix-org/util/json_test.go | 170 ++++++++++++++++-- 5 files changed, 280 insertions(+), 99 deletions(-) create mode 100644 vendor/src/github.com/matrix-org/util/context.go delete mode 100644 vendor/src/github.com/matrix-org/util/error.go diff --git a/vendor/manifest b/vendor/manifest index 1bf7484..93d78e9 100644 --- a/vendor/manifest +++ b/vendor/manifest @@ -150,7 +150,7 @@ { "importpath": "github.com/matrix-org/util", "repository": "https://github.com/matrix-org/util", - "revision": "9b44af331cdd83d702e4f16433e47341d983c23b", + "revision": "ccef6dc7c24a7c896d96b433a9107b7c47ecf828", "branch": "master" }, { diff --git a/vendor/src/github.com/matrix-org/util/context.go b/vendor/src/github.com/matrix-org/util/context.go new file mode 100644 index 0000000..d8def4f --- /dev/null +++ b/vendor/src/github.com/matrix-org/util/context.go @@ -0,0 +1,35 @@ +package util + +import ( + "context" + + log "github.com/Sirupsen/logrus" +) + +// contextKeys is a type alias for string to namespace Context keys per-package. +type contextKeys string + +// ctxValueRequestID is the key to extract the request ID for an HTTP request +const ctxValueRequestID = contextKeys("requestid") + +// GetRequestID returns the request ID associated with this context, or the empty string +// if one is not associated with this context. +func GetRequestID(ctx context.Context) string { + id := ctx.Value(ctxValueRequestID) + if id == nil { + return "" + } + return id.(string) +} + +// ctxValueLogger is the key to extract the logrus Logger. +const ctxValueLogger = contextKeys("logger") + +// GetLogger retrieves the logrus logger from the supplied context. Returns nil if there is no logger. +func GetLogger(ctx context.Context) *log.Entry { + l := ctx.Value(ctxValueLogger) + if l == nil { + return nil + } + return l.(*log.Entry) +} diff --git a/vendor/src/github.com/matrix-org/util/error.go b/vendor/src/github.com/matrix-org/util/error.go deleted file mode 100644 index 9d40c57..0000000 --- a/vendor/src/github.com/matrix-org/util/error.go +++ /dev/null @@ -1,18 +0,0 @@ -package util - -import "fmt" - -// HTTPError An HTTP Error response, which may wrap an underlying native Go Error. -type HTTPError struct { - WrappedError error - Message string - Code int -} - -func (e HTTPError) Error() string { - var wrappedErrMsg string - if e.WrappedError != nil { - wrappedErrMsg = e.WrappedError.Error() - } - return fmt.Sprintf("%s: %d: %s", e.Message, e.Code, wrappedErrMsg) -} diff --git a/vendor/src/github.com/matrix-org/util/json.go b/vendor/src/github.com/matrix-org/util/json.go index 4735bf5..b0834ea 100644 --- a/vendor/src/github.com/matrix-org/util/json.go +++ b/vendor/src/github.com/matrix-org/util/json.go @@ -3,50 +3,75 @@ package util import ( "context" "encoding/json" - "fmt" "math/rand" "net/http" "runtime/debug" + "time" log "github.com/Sirupsen/logrus" ) -// ContextKeys is a type alias for string to namespace Context keys per-package. -type ContextKeys string +// JSONResponse represents an HTTP response which contains a JSON body. +type JSONResponse struct { + // HTTP status code. + Code int + // JSON represents the JSON that should be serialized and sent to the client + JSON interface{} + // Headers represent any headers that should be sent to the client + Headers map[string]string +} -// CtxValueLogger is the key to extract the logrus Logger. -const CtxValueLogger = ContextKeys("logger") +// Is2xx returns true if the Code is between 200 and 299. +func (r JSONResponse) Is2xx() bool { + return r.Code/100 == 2 +} -// JSONRequestHandler represents an interface that must be satisfied in order to respond to incoming -// HTTP requests with JSON. The interface returned will be marshalled into JSON to be sent to the client, -// unless the interface is []byte in which case the bytes are sent to the client unchanged. -// If an error is returned, a JSON error response will also be returned, unless the error code -// is a 302 REDIRECT in which case a redirect is sent based on the Message field. -type JSONRequestHandler interface { - OnIncomingRequest(req *http.Request) (interface{}, *HTTPError) +// RedirectResponse returns a JSONResponse which 302s the client to the given location. +func RedirectResponse(location string) JSONResponse { + headers := make(map[string]string) + headers["Location"] = location + return JSONResponse{ + Code: 302, + JSON: struct{}{}, + Headers: headers, + } +} + +// MessageResponse returns a JSONResponse with a 'message' key containing the given text. +func MessageResponse(code int, msg string) JSONResponse { + return JSONResponse{ + Code: code, + JSON: struct { + Message string `json:"message"` + }{msg}, + } } -// JSONError represents a JSON API error response -type JSONError struct { - Message string `json:"message"` +// ErrorResponse returns an HTTP 500 JSONResponse with the stringified form of the given error. +func ErrorResponse(err error) JSONResponse { + return MessageResponse(500, err.Error()) +} + +// JSONRequestHandler represents an interface that must be satisfied in order to respond to incoming +// HTTP requests with JSON. +type JSONRequestHandler interface { + OnIncomingRequest(req *http.Request) JSONResponse } // Protect panicking HTTP requests from taking down the entire process, and log them using // the correct logger, returning a 500 with a JSON response rather than abruptly closing the -// connection. The http.Request MUST have a CtxValueLogger. +// connection. The http.Request MUST have a ctxValueLogger. func Protect(handler http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, req *http.Request) { defer func() { if r := recover(); r != nil { - logger := req.Context().Value(CtxValueLogger).(*log.Entry) + logger := req.Context().Value(ctxValueLogger).(*log.Entry) logger.WithFields(log.Fields{ "panic": r, }).Errorf( "Request panicked!\n%s", debug.Stack(), ) - jsonErrorResponse( - w, req, &HTTPError{nil, "Internal Server Error", 500}, - ) + respond(w, req, MessageResponse(500, "Internal Server Error")) } }() handler(w, req) @@ -55,72 +80,67 @@ func Protect(handler http.HandlerFunc) http.HandlerFunc { // MakeJSONAPI creates an HTTP handler which always responds to incoming requests with JSON responses. // Incoming http.Requests will have a logger (with a request ID/method/path logged) attached to the Context. -// This can be accessed via the const CtxValueLogger. The type of the logger is *log.Entry from github.com/Sirupsen/logrus +// This can be accessed via GetLogger(Context). func MakeJSONAPI(handler JSONRequestHandler) http.HandlerFunc { return Protect(func(w http.ResponseWriter, req *http.Request) { - // Set a Logger on the context - ctx := context.WithValue(req.Context(), CtxValueLogger, log.WithFields(log.Fields{ + reqID := RandomString(12) + // Set a Logger and request ID on the context + ctx := context.WithValue(req.Context(), ctxValueLogger, log.WithFields(log.Fields{ "req.method": req.Method, "req.path": req.URL.Path, - "req.id": RandomString(12), + "req.id": reqID, })) + ctx = context.WithValue(ctx, ctxValueRequestID, reqID) req = req.WithContext(ctx) - logger := req.Context().Value(CtxValueLogger).(*log.Entry) + logger := req.Context().Value(ctxValueLogger).(*log.Entry) logger.Print("Incoming request") - res, httpErr := handler.OnIncomingRequest(req) + res := handler.OnIncomingRequest(req) // Set common headers returned regardless of the outcome of the request w.Header().Set("Content-Type", "application/json") SetCORSHeaders(w) - if httpErr != nil { - jsonErrorResponse(w, req, httpErr) - return - } - - // if they've returned bytes as the response, then just return them rather than marshalling as JSON. - // This gives handlers an escape hatch if they want to return cached bytes. - var resBytes []byte - resBytes, ok := res.([]byte) - if !ok { - r, err := json.Marshal(res) - if err != nil { - jsonErrorResponse(w, req, &HTTPError{nil, "Failed to serialise response as JSON", 500}) - return - } - resBytes = r - } - logger.Print(fmt.Sprintf("Responding (%d bytes)", len(resBytes))) - w.Write(resBytes) + respond(w, req, res) }) } -func jsonErrorResponse(w http.ResponseWriter, req *http.Request, httpErr *HTTPError) { - logger := req.Context().Value(CtxValueLogger).(*log.Entry) - if httpErr.Code == 302 { - logger.WithField("err", httpErr.Error()).Print("Redirecting") - http.Redirect(w, req, httpErr.Message, 302) - return - } - logger.WithFields(log.Fields{ - log.ErrorKey: httpErr, - }).Print("Responding with error") +func respond(w http.ResponseWriter, req *http.Request, res JSONResponse) { + logger := req.Context().Value(ctxValueLogger).(*log.Entry) - w.WriteHeader(httpErr.Code) // Set response code + // Set custom headers + if res.Headers != nil { + for h, val := range res.Headers { + w.Header().Set(h, val) + } + } - r, err := json.Marshal(&JSONError{ - Message: httpErr.Message, - }) + // Marshal JSON response into raw bytes to send as the HTTP body + resBytes, err := json.Marshal(res.JSON) if err != nil { - // We should never fail to marshal the JSON error response, but in this event just skip - // marshalling altogether - logger.Warn("Failed to marshal error response") - w.Write([]byte(`{}`)) - return + logger.WithError(err).Error("Failed to marshal JSONResponse") + // this should never fail to be marshalled so drop err to the floor + res = MessageResponse(500, "Internal Server Error") + resBytes, _ = json.Marshal(res.JSON) + } + + // Set status code and write the body + w.WriteHeader(res.Code) + logger.WithField("code", res.Code).Infof("Responding (%d bytes)", len(resBytes)) + w.Write(resBytes) +} + +// WithCORSOptions intercepts all OPTIONS requests and responds with CORS headers. The request handler +// is not invoked when this happens. +func WithCORSOptions(handler http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + if req.Method == "OPTIONS" { + SetCORSHeaders(w) + return + } + handler(w, req) } - w.Write(r) } // SetCORSHeaders sets unrestricted origin Access-Control headers on the response writer @@ -140,3 +160,7 @@ func RandomString(n int) string { } return string(b) } + +func init() { + rand.Seed(time.Now().UTC().UnixNano()) +} diff --git a/vendor/src/github.com/matrix-org/util/json_test.go b/vendor/src/github.com/matrix-org/util/json_test.go index 203fa70..687db27 100644 --- a/vendor/src/github.com/matrix-org/util/json_test.go +++ b/vendor/src/github.com/matrix-org/util/json_test.go @@ -2,6 +2,7 @@ package util import ( "context" + "errors" "net/http" "net/http/httptest" "testing" @@ -10,10 +11,10 @@ import ( ) type MockJSONRequestHandler struct { - handler func(req *http.Request) (interface{}, *HTTPError) + handler func(req *http.Request) JSONResponse } -func (h *MockJSONRequestHandler) OnIncomingRequest(req *http.Request) (interface{}, *HTTPError) { +func (h *MockJSONRequestHandler) OnIncomingRequest(req *http.Request) JSONResponse { return h.handler(req) } @@ -24,22 +25,27 @@ type MockResponse struct { func TestMakeJSONAPI(t *testing.T) { log.SetLevel(log.PanicLevel) // suppress logs in test output tests := []struct { - Return interface{} - Err *HTTPError + Return JSONResponse ExpectCode int ExpectJSON string }{ - {nil, &HTTPError{nil, "Everything is broken", 500}, 500, `{"message":"Everything is broken"}`}, // Error return values - {nil, &HTTPError{nil, "Not here", 404}, 404, `{"message":"Not here"}`}, // With different status codes - {&MockResponse{"yep"}, nil, 200, `{"foo":"yep"}`}, // Success return values - {[]MockResponse{{"yep"}, {"narp"}}, nil, 200, `[{"foo":"yep"},{"foo":"narp"}]`}, // Top-level array success values - {[]byte(`actually bytes`), nil, 200, `actually bytes`}, // raw []byte escape hatch - {func(cannotBe, marshalled string) {}, nil, 500, `{"message":"Failed to serialise response as JSON"}`}, // impossible marshal + // MessageResponse return values + {MessageResponse(500, "Everything is broken"), 500, `{"message":"Everything is broken"}`}, + // interface return values + {JSONResponse{500, MockResponse{"yep"}, nil}, 500, `{"foo":"yep"}`}, + // Error JSON return values which fail to be marshalled should fallback to text + {JSONResponse{500, struct { + Foo interface{} `json:"foo"` + }{func(cannotBe, marshalled string) {}}, nil}, 500, `{"message":"Internal Server Error"}`}, + // With different status codes + {JSONResponse{201, MockResponse{"narp"}, nil}, 201, `{"foo":"narp"}`}, + // Top-level array success values + {JSONResponse{200, []MockResponse{{"yep"}, {"narp"}}, nil}, 200, `[{"foo":"yep"},{"foo":"narp"}]`}, } for _, tst := range tests { - mock := MockJSONRequestHandler{func(req *http.Request) (interface{}, *HTTPError) { - return tst.Return, tst.Err + mock := MockJSONRequestHandler{func(req *http.Request) JSONResponse { + return tst.Return }} mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) mockWriter := httptest.NewRecorder() @@ -55,10 +61,38 @@ func TestMakeJSONAPI(t *testing.T) { } } +func TestMakeJSONAPICustomHeaders(t *testing.T) { + mock := MockJSONRequestHandler{func(req *http.Request) JSONResponse { + headers := make(map[string]string) + headers["Custom"] = "Thing" + headers["X-Custom"] = "Things" + return JSONResponse{ + Code: 200, + JSON: MockResponse{"yep"}, + Headers: headers, + } + }} + mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) + mockWriter := httptest.NewRecorder() + handlerFunc := MakeJSONAPI(&mock) + handlerFunc(mockWriter, mockReq) + if mockWriter.Code != 200 { + t.Errorf("TestMakeJSONAPICustomHeaders wanted HTTP status 200, got %d", mockWriter.Code) + } + h := mockWriter.Header().Get("Custom") + if h != "Thing" { + t.Errorf("TestMakeJSONAPICustomHeaders wanted header 'Custom: Thing' , got 'Custom: %s'", h) + } + h = mockWriter.Header().Get("X-Custom") + if h != "Things" { + t.Errorf("TestMakeJSONAPICustomHeaders wanted header 'X-Custom: Things' , got 'X-Custom: %s'", h) + } +} + func TestMakeJSONAPIRedirect(t *testing.T) { log.SetLevel(log.PanicLevel) // suppress logs in test output - mock := MockJSONRequestHandler{func(req *http.Request) (interface{}, *HTTPError) { - return nil, &HTTPError{nil, "https://matrix.org", 302} + mock := MockJSONRequestHandler{func(req *http.Request) JSONResponse { + return RedirectResponse("https://matrix.org") }} mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) mockWriter := httptest.NewRecorder() @@ -73,12 +107,74 @@ func TestMakeJSONAPIRedirect(t *testing.T) { } } +func TestMakeJSONAPIError(t *testing.T) { + log.SetLevel(log.PanicLevel) // suppress logs in test output + mock := MockJSONRequestHandler{func(req *http.Request) JSONResponse { + err := errors.New("oops") + return ErrorResponse(err) + }} + mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) + mockWriter := httptest.NewRecorder() + handlerFunc := MakeJSONAPI(&mock) + handlerFunc(mockWriter, mockReq) + if mockWriter.Code != 500 { + t.Errorf("TestMakeJSONAPIError wanted HTTP status 500, got %d", mockWriter.Code) + } + actualBody := mockWriter.Body.String() + expect := `{"message":"oops"}` + if actualBody != expect { + t.Errorf("TestMakeJSONAPIError wanted body '%s', got '%s'", expect, actualBody) + } +} + +func TestIs2xx(t *testing.T) { + tests := []struct { + Code int + Expect bool + }{ + {200, true}, + {201, true}, + {299, true}, + {300, false}, + {199, false}, + {0, false}, + {500, false}, + } + for _, test := range tests { + j := JSONResponse{ + Code: test.Code, + } + actual := j.Is2xx() + if actual != test.Expect { + t.Errorf("TestIs2xx wanted %t, got %t", test.Expect, actual) + } + } +} + +func TestGetLogger(t *testing.T) { + log.SetLevel(log.PanicLevel) // suppress logs in test output + entry := log.WithField("test", "yep") + mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) + ctx := context.WithValue(mockReq.Context(), ctxValueLogger, entry) + mockReq = mockReq.WithContext(ctx) + ctxLogger := GetLogger(mockReq.Context()) + if ctxLogger != entry { + t.Errorf("TestGetLogger wanted logger '%v', got '%v'", entry, ctxLogger) + } + + noLoggerInReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) + ctxLogger = GetLogger(noLoggerInReq.Context()) + if ctxLogger != nil { + t.Errorf("TestGetLogger wanted nil logger, got '%v'", ctxLogger) + } +} + func TestProtect(t *testing.T) { log.SetLevel(log.PanicLevel) // suppress logs in test output mockWriter := httptest.NewRecorder() mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) mockReq = mockReq.WithContext( - context.WithValue(mockReq.Context(), CtxValueLogger, log.WithField("test", "yep")), + context.WithValue(mockReq.Context(), ctxValueLogger, log.WithField("test", "yep")), ) h := Protect(func(w http.ResponseWriter, req *http.Request) { panic("oh noes!") @@ -97,3 +193,47 @@ func TestProtect(t *testing.T) { t.Errorf("TestProtect wanted body %s, got %s", expectBody, actualBody) } } + +func TestWithCORSOptions(t *testing.T) { + log.SetLevel(log.PanicLevel) // suppress logs in test output + mockWriter := httptest.NewRecorder() + mockReq, _ := http.NewRequest("OPTIONS", "http://example.com/foo", nil) + h := WithCORSOptions(func(w http.ResponseWriter, req *http.Request) { + w.WriteHeader(200) + w.Write([]byte("yep")) + }) + h(mockWriter, mockReq) + if mockWriter.Code != 200 { + t.Errorf("TestWithCORSOptions wanted HTTP status 200, got %d", mockWriter.Code) + } + + origin := mockWriter.Header().Get("Access-Control-Allow-Origin") + if origin != "*" { + t.Errorf("TestWithCORSOptions wanted Access-Control-Allow-Origin header '*', got '%s'", origin) + } + + // OPTIONS request shouldn't hit the handler func + expectBody := "" + actualBody := mockWriter.Body.String() + if actualBody != expectBody { + t.Errorf("TestWithCORSOptions wanted body %s, got %s", expectBody, actualBody) + } +} + +func TestGetRequestID(t *testing.T) { + log.SetLevel(log.PanicLevel) // suppress logs in test output + reqID := "alphabetsoup" + mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) + ctx := context.WithValue(mockReq.Context(), ctxValueRequestID, reqID) + mockReq = mockReq.WithContext(ctx) + ctxReqID := GetRequestID(mockReq.Context()) + if reqID != ctxReqID { + t.Errorf("TestGetRequestID wanted request ID '%s', got '%s'", reqID, ctxReqID) + } + + noReqIDInReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) + ctxReqID = GetRequestID(noReqIDInReq.Context()) + if ctxReqID != "" { + t.Errorf("TestGetRequestID wanted empty request ID, got '%s'", ctxReqID) + } +}