diff --git a/src/github.com/matrix-org/go-neb/services/guggy/guggy_test.go b/src/github.com/matrix-org/go-neb/services/guggy/guggy_test.go index b7d5a05..480b8b8 100644 --- a/src/github.com/matrix-org/go-neb/services/guggy/guggy_test.go +++ b/src/github.com/matrix-org/go-neb/services/guggy/guggy_test.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/matrix-org/go-neb/database" "github.com/matrix-org/go-neb/matrix" + "github.com/matrix-org/go-neb/testutils" "github.com/matrix-org/go-neb/types" "io/ioutil" "net/http" @@ -14,14 +15,6 @@ import ( "testing" ) -type MockTransport struct { - roundTrip func(*http.Request) (*http.Response, error) -} - -func (t MockTransport) RoundTrip(req *http.Request) (*http.Response, error) { - return t.roundTrip(req) -} - // TODO: It would be nice to tabularise this test so we can try failing different combinations of responses to make // sure all cases are handled, rather than just the general case as is here. func TestCommand(t *testing.T) { @@ -30,8 +23,7 @@ func TestCommand(t *testing.T) { guggyImageURL := "https://guggy.com/gifs/23ryf872fg" // Mock the response from Guggy - guggyTrans := struct{ MockTransport }{} - guggyTrans.roundTrip = func(req *http.Request) (*http.Response, error) { + guggyTrans := testutils.NewRoundTripper(func(req *http.Request) (*http.Response, error) { guggyURL := "https://text2gif.guggy.com/guggify" if req.URL.String() != guggyURL { t.Fatalf("Bad URL: got %s want %s", req.URL.String(), guggyURL) @@ -65,7 +57,7 @@ func TestCommand(t *testing.T) { StatusCode: 200, Body: ioutil.NopCloser(bytes.NewBuffer(b)), }, nil - } + }) // clobber the guggy service http client instance httpClient = &http.Client{Transport: guggyTrans} @@ -79,8 +71,8 @@ func TestCommand(t *testing.T) { guggy := srv.(*Service) // Mock the response from Matrix - matrixTrans := struct{ MockTransport }{} - matrixTrans.roundTrip = func(req *http.Request) (*http.Response, error) { + matrixTrans := struct{ testutils.MockTransport }{} + matrixTrans.RT = func(req *http.Request) (*http.Response, error) { if req.URL.String() == guggyImageURL { // getting the guggy image return &http.Response{ StatusCode: 200, diff --git a/src/github.com/matrix-org/go-neb/services/rssbot/rssbot_test.go b/src/github.com/matrix-org/go-neb/services/rssbot/rssbot_test.go index cc6b260..4912d8c 100644 --- a/src/github.com/matrix-org/go-neb/services/rssbot/rssbot_test.go +++ b/src/github.com/matrix-org/go-neb/services/rssbot/rssbot_test.go @@ -14,6 +14,7 @@ import ( "github.com/matrix-org/go-neb/database" "github.com/matrix-org/go-neb/matrix" + "github.com/matrix-org/go-neb/testutils" "github.com/matrix-org/go-neb/types" ) @@ -36,20 +37,11 @@ const rssFeedXML = ` ` -type MockTransport struct { - roundTrip func(*http.Request) (*http.Response, error) -} - -func (t MockTransport) RoundTrip(req *http.Request) (*http.Response, error) { - return t.roundTrip(req) -} - func TestHTMLEntities(t *testing.T) { database.SetServiceDB(&database.NopStorage{}) feedURL := "https://thehappymaskshop.hyrule" // Replace the cachingClient with a mock so we can intercept RSS requests - rssTrans := struct{ MockTransport }{} - rssTrans.roundTrip = func(req *http.Request) (*http.Response, error) { + rssTrans := testutils.NewRoundTripper(func(req *http.Request) (*http.Response, error) { if req.URL.String() != feedURL { return nil, errors.New("Unknown test URL") } @@ -57,7 +49,7 @@ func TestHTMLEntities(t *testing.T) { StatusCode: 200, Body: ioutil.NopCloser(bytes.NewBufferString(rssFeedXML)), }, nil - } + }) cachingClient = &http.Client{Transport: rssTrans} // Create the RSS service @@ -79,8 +71,8 @@ func TestHTMLEntities(t *testing.T) { // Create the Matrix client which will send the notification wg := sync.WaitGroup{} wg.Add(1) - matrixTrans := struct{ MockTransport }{} - matrixTrans.roundTrip = func(req *http.Request) (*http.Response, error) { + matrixTrans := struct{ testutils.MockTransport }{} + matrixTrans.RT = func(req *http.Request) (*http.Response, error) { if strings.HasPrefix(req.URL.Path, "/_matrix/client/r0/rooms/!linksroom:hyrule/send/m.room.message") { // Check content body to make sure it is decoded var msg matrix.HTMLMessage diff --git a/src/github.com/matrix-org/go-neb/services/travisci/travisci_test.go b/src/github.com/matrix-org/go-neb/services/travisci/travisci_test.go index 5a8e2e5..3c1d695 100644 --- a/src/github.com/matrix-org/go-neb/services/travisci/travisci_test.go +++ b/src/github.com/matrix-org/go-neb/services/travisci/travisci_test.go @@ -13,6 +13,7 @@ import ( "github.com/matrix-org/go-neb/database" "github.com/matrix-org/go-neb/matrix" + "github.com/matrix-org/go-neb/testutils" "github.com/matrix-org/go-neb/types" ) @@ -91,14 +92,6 @@ var travisTests = []struct { }, } -type MockTransport struct { - roundTrip func(*http.Request) (*http.Response, error) -} - -func (t MockTransport) RoundTrip(req *http.Request) (*http.Response, error) { - return t.roundTrip(req) -} - func TestTravisCI(t *testing.T) { database.SetServiceDB(&database.NopStorage{}) @@ -106,8 +99,7 @@ func TestTravisCI(t *testing.T) { urlToKey := make(map[string]string) urlToKey["https://api.travis-ci.org/config"] = travisOrgPEMPublicKey urlToKey["https://api.travis-ci.com/config"] = travisComPEMPublicKey - travisTransport := struct{ MockTransport }{} - travisTransport.roundTrip = func(req *http.Request) (*http.Response, error) { + travisTransport := testutils.NewRoundTripper(func(req *http.Request) (*http.Response, error) { if key := urlToKey[req.URL.String()]; key != "" { escKey, _ := json.Marshal(key) return &http.Response{ @@ -118,14 +110,14 @@ func TestTravisCI(t *testing.T) { }, nil } return nil, fmt.Errorf("Unhandled URL %s", req.URL.String()) - } + }) // clobber the http client that the service uses to talk to Travis httpClient = &http.Client{Transport: travisTransport} // Intercept message sending to Matrix and mock responses msgs := []matrix.TextMessage{} - matrixTrans := struct{ MockTransport }{} - matrixTrans.roundTrip = func(req *http.Request) (*http.Response, error) { + matrixTrans := struct{ testutils.MockTransport }{} + matrixTrans.RT = func(req *http.Request) (*http.Response, error) { if !strings.Contains(req.URL.String(), "/send/m.room.message") { return nil, fmt.Errorf("Unhandled URL: %s", req.URL.String()) } diff --git a/src/github.com/matrix-org/go-neb/testutils/testutils.go b/src/github.com/matrix-org/go-neb/testutils/testutils.go new file mode 100644 index 0000000..ee678fa --- /dev/null +++ b/src/github.com/matrix-org/go-neb/testutils/testutils.go @@ -0,0 +1,28 @@ +package testutils + +import ( + "net/http" +) + +// MockTransport implements RoundTripper +type MockTransport struct { + // RT is the RoundTrip function. Replace this function with your test function. + // For example: + // t := MockTransport{} + // t.RT = func(req *http.Request) (*http.Response, error) { + // // assert req args, return res or error + // } + RT func(*http.Request) (*http.Response, error) +} + +// RoundTrip is a RoundTripper +func (t MockTransport) RoundTrip(req *http.Request) (*http.Response, error) { + return t.RT(req) +} + +// NewRoundTripper returns a new RoundTripper which will call the provided function. +func NewRoundTripper(roundTrip func(*http.Request) (*http.Response, error)) http.RoundTripper { + rt := MockTransport{} + rt.RT = roundTrip + return rt +}