Browse Source

Update util dep to pull in util.JSONResponse

pull/162/head
Kegan Dougal 8 years ago
parent
commit
9b349b12f4
  1. 2
      vendor/manifest
  2. 35
      vendor/src/github.com/matrix-org/util/context.go
  3. 18
      vendor/src/github.com/matrix-org/util/error.go
  4. 146
      vendor/src/github.com/matrix-org/util/json.go
  5. 170
      vendor/src/github.com/matrix-org/util/json_test.go

2
vendor/manifest

@ -150,7 +150,7 @@
{ {
"importpath": "github.com/matrix-org/util", "importpath": "github.com/matrix-org/util",
"repository": "https://github.com/matrix-org/util", "repository": "https://github.com/matrix-org/util",
"revision": "9b44af331cdd83d702e4f16433e47341d983c23b",
"revision": "ccef6dc7c24a7c896d96b433a9107b7c47ecf828",
"branch": "master" "branch": "master"
}, },
{ {

35
vendor/src/github.com/matrix-org/util/context.go

@ -0,0 +1,35 @@
package util
import (
"context"
log "github.com/Sirupsen/logrus"
)
// contextKeys is a type alias for string to namespace Context keys per-package.
type contextKeys string
// ctxValueRequestID is the key to extract the request ID for an HTTP request
const ctxValueRequestID = contextKeys("requestid")
// GetRequestID returns the request ID associated with this context, or the empty string
// if one is not associated with this context.
func GetRequestID(ctx context.Context) string {
id := ctx.Value(ctxValueRequestID)
if id == nil {
return ""
}
return id.(string)
}
// ctxValueLogger is the key to extract the logrus Logger.
const ctxValueLogger = contextKeys("logger")
// GetLogger retrieves the logrus logger from the supplied context. Returns nil if there is no logger.
func GetLogger(ctx context.Context) *log.Entry {
l := ctx.Value(ctxValueLogger)
if l == nil {
return nil
}
return l.(*log.Entry)
}

18
vendor/src/github.com/matrix-org/util/error.go

@ -1,18 +0,0 @@
package util
import "fmt"
// HTTPError An HTTP Error response, which may wrap an underlying native Go Error.
type HTTPError struct {
WrappedError error
Message string
Code int
}
func (e HTTPError) Error() string {
var wrappedErrMsg string
if e.WrappedError != nil {
wrappedErrMsg = e.WrappedError.Error()
}
return fmt.Sprintf("%s: %d: %s", e.Message, e.Code, wrappedErrMsg)
}

146
vendor/src/github.com/matrix-org/util/json.go

@ -3,50 +3,75 @@ package util
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"math/rand" "math/rand"
"net/http" "net/http"
"runtime/debug" "runtime/debug"
"time"
log "github.com/Sirupsen/logrus" log "github.com/Sirupsen/logrus"
) )
// ContextKeys is a type alias for string to namespace Context keys per-package.
type ContextKeys string
// JSONResponse represents an HTTP response which contains a JSON body.
type JSONResponse struct {
// HTTP status code.
Code int
// JSON represents the JSON that should be serialized and sent to the client
JSON interface{}
// Headers represent any headers that should be sent to the client
Headers map[string]string
}
// CtxValueLogger is the key to extract the logrus Logger.
const CtxValueLogger = ContextKeys("logger")
// Is2xx returns true if the Code is between 200 and 299.
func (r JSONResponse) Is2xx() bool {
return r.Code/100 == 2
}
// JSONRequestHandler represents an interface that must be satisfied in order to respond to incoming
// HTTP requests with JSON. The interface returned will be marshalled into JSON to be sent to the client,
// unless the interface is []byte in which case the bytes are sent to the client unchanged.
// If an error is returned, a JSON error response will also be returned, unless the error code
// is a 302 REDIRECT in which case a redirect is sent based on the Message field.
type JSONRequestHandler interface {
OnIncomingRequest(req *http.Request) (interface{}, *HTTPError)
// RedirectResponse returns a JSONResponse which 302s the client to the given location.
func RedirectResponse(location string) JSONResponse {
headers := make(map[string]string)
headers["Location"] = location
return JSONResponse{
Code: 302,
JSON: struct{}{},
Headers: headers,
}
} }
// JSONError represents a JSON API error response
type JSONError struct {
// MessageResponse returns a JSONResponse with a 'message' key containing the given text.
func MessageResponse(code int, msg string) JSONResponse {
return JSONResponse{
Code: code,
JSON: struct {
Message string `json:"message"` Message string `json:"message"`
}{msg},
}
}
// ErrorResponse returns an HTTP 500 JSONResponse with the stringified form of the given error.
func ErrorResponse(err error) JSONResponse {
return MessageResponse(500, err.Error())
}
// JSONRequestHandler represents an interface that must be satisfied in order to respond to incoming
// HTTP requests with JSON.
type JSONRequestHandler interface {
OnIncomingRequest(req *http.Request) JSONResponse
} }
// Protect panicking HTTP requests from taking down the entire process, and log them using // Protect panicking HTTP requests from taking down the entire process, and log them using
// the correct logger, returning a 500 with a JSON response rather than abruptly closing the // the correct logger, returning a 500 with a JSON response rather than abruptly closing the
// connection. The http.Request MUST have a CtxValueLogger.
// connection. The http.Request MUST have a ctxValueLogger.
func Protect(handler http.HandlerFunc) http.HandlerFunc { func Protect(handler http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) { return func(w http.ResponseWriter, req *http.Request) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
logger := req.Context().Value(CtxValueLogger).(*log.Entry)
logger := req.Context().Value(ctxValueLogger).(*log.Entry)
logger.WithFields(log.Fields{ logger.WithFields(log.Fields{
"panic": r, "panic": r,
}).Errorf( }).Errorf(
"Request panicked!\n%s", debug.Stack(), "Request panicked!\n%s", debug.Stack(),
) )
jsonErrorResponse(
w, req, &HTTPError{nil, "Internal Server Error", 500},
)
respond(w, req, MessageResponse(500, "Internal Server Error"))
} }
}() }()
handler(w, req) handler(w, req)
@ -55,72 +80,67 @@ func Protect(handler http.HandlerFunc) http.HandlerFunc {
// MakeJSONAPI creates an HTTP handler which always responds to incoming requests with JSON responses. // MakeJSONAPI creates an HTTP handler which always responds to incoming requests with JSON responses.
// Incoming http.Requests will have a logger (with a request ID/method/path logged) attached to the Context. // Incoming http.Requests will have a logger (with a request ID/method/path logged) attached to the Context.
// This can be accessed via the const CtxValueLogger. The type of the logger is *log.Entry from github.com/Sirupsen/logrus
// This can be accessed via GetLogger(Context).
func MakeJSONAPI(handler JSONRequestHandler) http.HandlerFunc { func MakeJSONAPI(handler JSONRequestHandler) http.HandlerFunc {
return Protect(func(w http.ResponseWriter, req *http.Request) { return Protect(func(w http.ResponseWriter, req *http.Request) {
// Set a Logger on the context
ctx := context.WithValue(req.Context(), CtxValueLogger, log.WithFields(log.Fields{
reqID := RandomString(12)
// Set a Logger and request ID on the context
ctx := context.WithValue(req.Context(), ctxValueLogger, log.WithFields(log.Fields{
"req.method": req.Method, "req.method": req.Method,
"req.path": req.URL.Path, "req.path": req.URL.Path,
"req.id": RandomString(12),
"req.id": reqID,
})) }))
ctx = context.WithValue(ctx, ctxValueRequestID, reqID)
req = req.WithContext(ctx) req = req.WithContext(ctx)
logger := req.Context().Value(CtxValueLogger).(*log.Entry)
logger := req.Context().Value(ctxValueLogger).(*log.Entry)
logger.Print("Incoming request") logger.Print("Incoming request")
res, httpErr := handler.OnIncomingRequest(req)
res := handler.OnIncomingRequest(req)
// Set common headers returned regardless of the outcome of the request // Set common headers returned regardless of the outcome of the request
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
SetCORSHeaders(w) SetCORSHeaders(w)
if httpErr != nil {
jsonErrorResponse(w, req, httpErr)
return
respond(w, req, res)
})
} }
// if they've returned bytes as the response, then just return them rather than marshalling as JSON.
// This gives handlers an escape hatch if they want to return cached bytes.
var resBytes []byte
resBytes, ok := res.([]byte)
if !ok {
r, err := json.Marshal(res)
if err != nil {
jsonErrorResponse(w, req, &HTTPError{nil, "Failed to serialise response as JSON", 500})
return
func respond(w http.ResponseWriter, req *http.Request, res JSONResponse) {
logger := req.Context().Value(ctxValueLogger).(*log.Entry)
// Set custom headers
if res.Headers != nil {
for h, val := range res.Headers {
w.Header().Set(h, val)
} }
resBytes = r
}
logger.Print(fmt.Sprintf("Responding (%d bytes)", len(resBytes)))
w.Write(resBytes)
})
} }
func jsonErrorResponse(w http.ResponseWriter, req *http.Request, httpErr *HTTPError) {
logger := req.Context().Value(CtxValueLogger).(*log.Entry)
if httpErr.Code == 302 {
logger.WithField("err", httpErr.Error()).Print("Redirecting")
http.Redirect(w, req, httpErr.Message, 302)
return
// Marshal JSON response into raw bytes to send as the HTTP body
resBytes, err := json.Marshal(res.JSON)
if err != nil {
logger.WithError(err).Error("Failed to marshal JSONResponse")
// this should never fail to be marshalled so drop err to the floor
res = MessageResponse(500, "Internal Server Error")
resBytes, _ = json.Marshal(res.JSON)
} }
logger.WithFields(log.Fields{
log.ErrorKey: httpErr,
}).Print("Responding with error")
w.WriteHeader(httpErr.Code) // Set response code
// Set status code and write the body
w.WriteHeader(res.Code)
logger.WithField("code", res.Code).Infof("Responding (%d bytes)", len(resBytes))
w.Write(resBytes)
}
r, err := json.Marshal(&JSONError{
Message: httpErr.Message,
})
if err != nil {
// We should never fail to marshal the JSON error response, but in this event just skip
// marshalling altogether
logger.Warn("Failed to marshal error response")
w.Write([]byte(`{}`))
// WithCORSOptions intercepts all OPTIONS requests and responds with CORS headers. The request handler
// is not invoked when this happens.
func WithCORSOptions(handler http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
if req.Method == "OPTIONS" {
SetCORSHeaders(w)
return return
} }
w.Write(r)
handler(w, req)
}
} }
// SetCORSHeaders sets unrestricted origin Access-Control headers on the response writer // SetCORSHeaders sets unrestricted origin Access-Control headers on the response writer
@ -140,3 +160,7 @@ func RandomString(n int) string {
} }
return string(b) return string(b)
} }
func init() {
rand.Seed(time.Now().UTC().UnixNano())
}

170
vendor/src/github.com/matrix-org/util/json_test.go

@ -2,6 +2,7 @@ package util
import ( import (
"context" "context"
"errors"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -10,10 +11,10 @@ import (
) )
type MockJSONRequestHandler struct { type MockJSONRequestHandler struct {
handler func(req *http.Request) (interface{}, *HTTPError)
handler func(req *http.Request) JSONResponse
} }
func (h *MockJSONRequestHandler) OnIncomingRequest(req *http.Request) (interface{}, *HTTPError) {
func (h *MockJSONRequestHandler) OnIncomingRequest(req *http.Request) JSONResponse {
return h.handler(req) return h.handler(req)
} }
@ -24,22 +25,27 @@ type MockResponse struct {
func TestMakeJSONAPI(t *testing.T) { func TestMakeJSONAPI(t *testing.T) {
log.SetLevel(log.PanicLevel) // suppress logs in test output log.SetLevel(log.PanicLevel) // suppress logs in test output
tests := []struct { tests := []struct {
Return interface{}
Err *HTTPError
Return JSONResponse
ExpectCode int ExpectCode int
ExpectJSON string ExpectJSON string
}{ }{
{nil, &HTTPError{nil, "Everything is broken", 500}, 500, `{"message":"Everything is broken"}`}, // Error return values
{nil, &HTTPError{nil, "Not here", 404}, 404, `{"message":"Not here"}`}, // With different status codes
{&MockResponse{"yep"}, nil, 200, `{"foo":"yep"}`}, // Success return values
{[]MockResponse{{"yep"}, {"narp"}}, nil, 200, `[{"foo":"yep"},{"foo":"narp"}]`}, // Top-level array success values
{[]byte(`actually bytes`), nil, 200, `actually bytes`}, // raw []byte escape hatch
{func(cannotBe, marshalled string) {}, nil, 500, `{"message":"Failed to serialise response as JSON"}`}, // impossible marshal
// MessageResponse return values
{MessageResponse(500, "Everything is broken"), 500, `{"message":"Everything is broken"}`},
// interface return values
{JSONResponse{500, MockResponse{"yep"}, nil}, 500, `{"foo":"yep"}`},
// Error JSON return values which fail to be marshalled should fallback to text
{JSONResponse{500, struct {
Foo interface{} `json:"foo"`
}{func(cannotBe, marshalled string) {}}, nil}, 500, `{"message":"Internal Server Error"}`},
// With different status codes
{JSONResponse{201, MockResponse{"narp"}, nil}, 201, `{"foo":"narp"}`},
// Top-level array success values
{JSONResponse{200, []MockResponse{{"yep"}, {"narp"}}, nil}, 200, `[{"foo":"yep"},{"foo":"narp"}]`},
} }
for _, tst := range tests { for _, tst := range tests {
mock := MockJSONRequestHandler{func(req *http.Request) (interface{}, *HTTPError) {
return tst.Return, tst.Err
mock := MockJSONRequestHandler{func(req *http.Request) JSONResponse {
return tst.Return
}} }}
mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil)
mockWriter := httptest.NewRecorder() mockWriter := httptest.NewRecorder()
@ -55,10 +61,38 @@ func TestMakeJSONAPI(t *testing.T) {
} }
} }
func TestMakeJSONAPICustomHeaders(t *testing.T) {
mock := MockJSONRequestHandler{func(req *http.Request) JSONResponse {
headers := make(map[string]string)
headers["Custom"] = "Thing"
headers["X-Custom"] = "Things"
return JSONResponse{
Code: 200,
JSON: MockResponse{"yep"},
Headers: headers,
}
}}
mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil)
mockWriter := httptest.NewRecorder()
handlerFunc := MakeJSONAPI(&mock)
handlerFunc(mockWriter, mockReq)
if mockWriter.Code != 200 {
t.Errorf("TestMakeJSONAPICustomHeaders wanted HTTP status 200, got %d", mockWriter.Code)
}
h := mockWriter.Header().Get("Custom")
if h != "Thing" {
t.Errorf("TestMakeJSONAPICustomHeaders wanted header 'Custom: Thing' , got 'Custom: %s'", h)
}
h = mockWriter.Header().Get("X-Custom")
if h != "Things" {
t.Errorf("TestMakeJSONAPICustomHeaders wanted header 'X-Custom: Things' , got 'X-Custom: %s'", h)
}
}
func TestMakeJSONAPIRedirect(t *testing.T) { func TestMakeJSONAPIRedirect(t *testing.T) {
log.SetLevel(log.PanicLevel) // suppress logs in test output log.SetLevel(log.PanicLevel) // suppress logs in test output
mock := MockJSONRequestHandler{func(req *http.Request) (interface{}, *HTTPError) {
return nil, &HTTPError{nil, "https://matrix.org", 302}
mock := MockJSONRequestHandler{func(req *http.Request) JSONResponse {
return RedirectResponse("https://matrix.org")
}} }}
mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil)
mockWriter := httptest.NewRecorder() mockWriter := httptest.NewRecorder()
@ -73,12 +107,74 @@ func TestMakeJSONAPIRedirect(t *testing.T) {
} }
} }
func TestMakeJSONAPIError(t *testing.T) {
log.SetLevel(log.PanicLevel) // suppress logs in test output
mock := MockJSONRequestHandler{func(req *http.Request) JSONResponse {
err := errors.New("oops")
return ErrorResponse(err)
}}
mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil)
mockWriter := httptest.NewRecorder()
handlerFunc := MakeJSONAPI(&mock)
handlerFunc(mockWriter, mockReq)
if mockWriter.Code != 500 {
t.Errorf("TestMakeJSONAPIError wanted HTTP status 500, got %d", mockWriter.Code)
}
actualBody := mockWriter.Body.String()
expect := `{"message":"oops"}`
if actualBody != expect {
t.Errorf("TestMakeJSONAPIError wanted body '%s', got '%s'", expect, actualBody)
}
}
func TestIs2xx(t *testing.T) {
tests := []struct {
Code int
Expect bool
}{
{200, true},
{201, true},
{299, true},
{300, false},
{199, false},
{0, false},
{500, false},
}
for _, test := range tests {
j := JSONResponse{
Code: test.Code,
}
actual := j.Is2xx()
if actual != test.Expect {
t.Errorf("TestIs2xx wanted %t, got %t", test.Expect, actual)
}
}
}
func TestGetLogger(t *testing.T) {
log.SetLevel(log.PanicLevel) // suppress logs in test output
entry := log.WithField("test", "yep")
mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil)
ctx := context.WithValue(mockReq.Context(), ctxValueLogger, entry)
mockReq = mockReq.WithContext(ctx)
ctxLogger := GetLogger(mockReq.Context())
if ctxLogger != entry {
t.Errorf("TestGetLogger wanted logger '%v', got '%v'", entry, ctxLogger)
}
noLoggerInReq, _ := http.NewRequest("GET", "http://example.com/foo", nil)
ctxLogger = GetLogger(noLoggerInReq.Context())
if ctxLogger != nil {
t.Errorf("TestGetLogger wanted nil logger, got '%v'", ctxLogger)
}
}
func TestProtect(t *testing.T) { func TestProtect(t *testing.T) {
log.SetLevel(log.PanicLevel) // suppress logs in test output log.SetLevel(log.PanicLevel) // suppress logs in test output
mockWriter := httptest.NewRecorder() mockWriter := httptest.NewRecorder()
mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil) mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil)
mockReq = mockReq.WithContext( mockReq = mockReq.WithContext(
context.WithValue(mockReq.Context(), CtxValueLogger, log.WithField("test", "yep")),
context.WithValue(mockReq.Context(), ctxValueLogger, log.WithField("test", "yep")),
) )
h := Protect(func(w http.ResponseWriter, req *http.Request) { h := Protect(func(w http.ResponseWriter, req *http.Request) {
panic("oh noes!") panic("oh noes!")
@ -97,3 +193,47 @@ func TestProtect(t *testing.T) {
t.Errorf("TestProtect wanted body %s, got %s", expectBody, actualBody) t.Errorf("TestProtect wanted body %s, got %s", expectBody, actualBody)
} }
} }
func TestWithCORSOptions(t *testing.T) {
log.SetLevel(log.PanicLevel) // suppress logs in test output
mockWriter := httptest.NewRecorder()
mockReq, _ := http.NewRequest("OPTIONS", "http://example.com/foo", nil)
h := WithCORSOptions(func(w http.ResponseWriter, req *http.Request) {
w.WriteHeader(200)
w.Write([]byte("yep"))
})
h(mockWriter, mockReq)
if mockWriter.Code != 200 {
t.Errorf("TestWithCORSOptions wanted HTTP status 200, got %d", mockWriter.Code)
}
origin := mockWriter.Header().Get("Access-Control-Allow-Origin")
if origin != "*" {
t.Errorf("TestWithCORSOptions wanted Access-Control-Allow-Origin header '*', got '%s'", origin)
}
// OPTIONS request shouldn't hit the handler func
expectBody := ""
actualBody := mockWriter.Body.String()
if actualBody != expectBody {
t.Errorf("TestWithCORSOptions wanted body %s, got %s", expectBody, actualBody)
}
}
func TestGetRequestID(t *testing.T) {
log.SetLevel(log.PanicLevel) // suppress logs in test output
reqID := "alphabetsoup"
mockReq, _ := http.NewRequest("GET", "http://example.com/foo", nil)
ctx := context.WithValue(mockReq.Context(), ctxValueRequestID, reqID)
mockReq = mockReq.WithContext(ctx)
ctxReqID := GetRequestID(mockReq.Context())
if reqID != ctxReqID {
t.Errorf("TestGetRequestID wanted request ID '%s', got '%s'", reqID, ctxReqID)
}
noReqIDInReq, _ := http.NewRequest("GET", "http://example.com/foo", nil)
ctxReqID = GetRequestID(noReqIDInReq.Context())
if ctxReqID != "" {
t.Errorf("TestGetRequestID wanted empty request ID, got '%s'", ctxReqID)
}
}
Loading…
Cancel
Save