diff --git a/src/github.com/matrix-org/go-neb/clients/clients.go b/src/github.com/matrix-org/go-neb/clients/clients.go index c6fa205..9192289 100644 --- a/src/github.com/matrix-org/go-neb/clients/clients.go +++ b/src/github.com/matrix-org/go-neb/clients/clients.go @@ -7,6 +7,7 @@ import ( "github.com/matrix-org/go-neb/matrix" "github.com/matrix-org/go-neb/plugin" "github.com/matrix-org/go-neb/types" + "net/http" "net/url" "strings" "sync" @@ -39,17 +40,19 @@ func (s nextBatchStore) Load(userID string) string { // A Clients is a collection of clients used for bot services. type Clients struct { - db *database.ServiceDB - dbMutex sync.Mutex - mapMutex sync.Mutex - clients map[string]clientEntry + db *database.ServiceDB + httpClient *http.Client + dbMutex sync.Mutex + mapMutex sync.Mutex + clients map[string]clientEntry } // New makes a new collection of matrix clients -func New(db *database.ServiceDB) *Clients { +func New(db *database.ServiceDB, cli *http.Client) *Clients { clients := &Clients{ - db: db, - clients: make(map[string]clientEntry), // user_id => clientEntry + db: db, + httpClient: cli, + clients: make(map[string]clientEntry), // user_id => clientEntry } return clients } @@ -238,7 +241,7 @@ func (c *Clients) newClient(config api.ClientConfig) (*matrix.Client, error) { return nil, err } - client := matrix.NewClient(homeserverURL, config.AccessToken, config.UserID) + client := matrix.NewClient(c.httpClient, homeserverURL, config.AccessToken, config.UserID) client.NextBatchStorer = nextBatchStore{c.db} // TODO: Check that the access token is valid for the userID by peforming diff --git a/src/github.com/matrix-org/go-neb/goneb.go b/src/github.com/matrix-org/go-neb/goneb.go index 329c8ee..ad70572 100644 --- a/src/github.com/matrix-org/go-neb/goneb.go +++ b/src/github.com/matrix-org/go-neb/goneb.go @@ -139,7 +139,7 @@ func loadDatabase(databaseType, databaseURL, configYAML string) (*database.Servi return db, err } -func setup(e envVars, mux *http.ServeMux) { +func setup(e envVars, mux *http.ServeMux, matrixClient *http.Client) { err := types.BaseURL(e.BaseURL) if err != nil { log.WithError(err).Panic("Failed to get base url") @@ -164,7 +164,7 @@ func setup(e envVars, mux *http.ServeMux) { log.Info("Inserted ", len(cfg.Sessions), " sessions") } - clients := clients.New(db) + clients := clients.New(db, matrixClient) if err := clients.Start(); err != nil { log.WithError(err).Panic("Failed to start up clients") } @@ -229,6 +229,6 @@ func main() { log.Infof("Go-NEB (%+v)", e) - setup(e, http.DefaultServeMux) + setup(e, http.DefaultServeMux, http.DefaultClient) log.Fatal(http.ListenAndServe(e.BindAddress, nil)) } diff --git a/src/github.com/matrix-org/go-neb/goneb_services_test.go b/src/github.com/matrix-org/go-neb/goneb_services_test.go index 381998a..6e500b7 100644 --- a/src/github.com/matrix-org/go-neb/goneb_services_test.go +++ b/src/github.com/matrix-org/go-neb/goneb_services_test.go @@ -1,31 +1,103 @@ package main import ( + "bytes" "net/http" "net/http/httptest" "os" + "strconv" "testing" ) var mux = http.NewServeMux() +type MockTripper struct { + handlers map[string]func(req *http.Request) (*http.Response, error) +} + +func (rt MockTripper) RoundTrip(req *http.Request) (*http.Response, error) { + key := req.Method + " " + req.URL.Path + h := rt.handlers[key] + if h == nil { + panic( + "Test RoundTrip: Unhandled request: " + key + "\n" + + "Handlers: " + strconv.Itoa(len(rt.handlers)), + ) + } + return h(req) +} + +func (rt MockTripper) Handle(method, path string, handler func(req *http.Request) (*http.Response, error)) { + key := method + " " + path + if _, exists := rt.handlers[key]; exists { + panic("Test handler with key " + key + " already exists") + } + rt.handlers[key] = handler +} + +var tripper = MockTripper{make(map[string]func(req *http.Request) (*http.Response, error))} + +type nopCloser struct { + *bytes.Buffer +} + +func (nopCloser) Close() error { return nil } + +func newResponse(statusCode int, body string) *http.Response { + return &http.Response{ + StatusCode: statusCode, + Body: nopCloser{bytes.NewBufferString(body)}, + } +} + func TestMain(m *testing.M) { setup(envVars{ BaseURL: "http://go.neb", DatabaseType: "sqlite3", DatabaseURL: ":memory:", - }, mux) + }, mux, &http.Client{ + Transport: tripper, + }) exitCode := m.Run() os.Exit(exitCode) } -func TestNotFound(t *testing.T) { +func TestConfigureClient(t *testing.T) { + for k := range tripper.handlers { + delete(tripper.handlers, k) + } mockWriter := httptest.NewRecorder() - mockReq, _ := http.NewRequest("GET", "http://go.neb/foo", nil) - mux.ServeHTTP(mockWriter, mockReq) + tripper.Handle("POST", "/_matrix/client/r0/user/@link:hyrule/filter", + func(req *http.Request) (*http.Response, error) { + return newResponse(200, `{ + "filter_id":"abcdef" + }`), nil + }, + ) + syncChan := make(chan string) + tripper.Handle("GET", "/_matrix/client/r0/sync", + func(req *http.Request) (*http.Response, error) { + syncChan <- "sync" + return newResponse(200, `{ + "next_batch":"11_22_33_44", + "rooms": {} + }`), nil + }, + ) - expectCode := 404 + mockReq, _ := http.NewRequest("POST", "http://go.neb/admin/configureClient", bytes.NewBufferString(` + { + "UserID":"@link:hyrule", + "HomeserverURL":"http://hyrule.loz", + "AccessToken":"dangeroustogoalone", + "Sync":true, + "AutoJoinRooms":true + }`)) + mux.ServeHTTP(mockWriter, mockReq) + expectCode := 200 if mockWriter.Code != expectCode { - t.Errorf("TestNotFound wanted HTTP status %d, got %d", expectCode, mockWriter.Code) + t.Errorf("TestConfigureClient wanted HTTP status %d, got %d", expectCode, mockWriter.Code) } + + <-syncChan } diff --git a/src/github.com/matrix-org/go-neb/matrix/matrix.go b/src/github.com/matrix-org/go-neb/matrix/matrix.go index 0214a18..bae4edb 100644 --- a/src/github.com/matrix-org/go-neb/matrix/matrix.go +++ b/src/github.com/matrix-org/go-neb/matrix/matrix.go @@ -404,7 +404,11 @@ func (cli *Client) doSync(timeout int, since string) ([]byte, error) { query["filter"] = cli.filterID } urlPath := cli.buildURLWithQuery([]string{"sync"}, query) - res, err := http.Get(urlPath) + req, err := http.NewRequest("GET", urlPath, nil) + if err != nil { + return nil, err + } + res, err := cli.httpClient.Do(req) if err != nil { return nil, err } @@ -417,7 +421,7 @@ func (cli *Client) doSync(timeout int, since string) ([]byte, error) { } // NewClient creates a new Matrix Client ready for syncing -func NewClient(homeserverURL *url.URL, accessToken string, userID string) *Client { +func NewClient(httpClient *http.Client, homeserverURL *url.URL, accessToken, userID string) *Client { cli := Client{ AccessToken: accessToken, HomeserverURL: homeserverURL, @@ -430,7 +434,7 @@ func NewClient(homeserverURL *url.URL, accessToken string, userID string) *Clien // remember the token across restarts. In practice, a database backend should be used. cli.NextBatchStorer = noopNextBatchStore{} cli.Rooms = make(map[string]*Room) - cli.httpClient = &http.Client{} + cli.httpClient = httpClient return &cli }