From dc2fadd4509e8d2f06a50c2d470796bd5b585ce2 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Tue, 25 Oct 2016 11:27:42 +0100 Subject: [PATCH 1/3] Factor out startup process to `setup()` and add top-level test - Pass in an `envVars` struct rather than grabbing directly from `os`. - Accept a `ServeMux` rather than always using the default mux. - Add a `_test.go` file to instrument Go-NEB using `TestMain` to call `setup()` --- src/github.com/matrix-org/go-neb/goneb.go | 89 +++++++++++-------- .../matrix-org/go-neb/goneb_services_test.go | 31 +++++++ 2 files changed, 81 insertions(+), 39 deletions(-) create mode 100644 src/github.com/matrix-org/go-neb/goneb_services_test.go diff --git a/src/github.com/matrix-org/go-neb/goneb.go b/src/github.com/matrix-org/go-neb/goneb.go index 7b73ca2..329c8ee 100644 --- a/src/github.com/matrix-org/go-neb/goneb.go +++ b/src/github.com/matrix-org/go-neb/goneb.go @@ -139,42 +139,22 @@ func loadDatabase(databaseType, databaseURL, configYAML string) (*database.Servi return db, err } -func main() { - bindAddress := os.Getenv("BIND_ADDRESS") - databaseType := os.Getenv("DATABASE_TYPE") - databaseURL := os.Getenv("DATABASE_URL") - baseURL := os.Getenv("BASE_URL") - logDir := os.Getenv("LOG_DIR") - configYAML := os.Getenv("CONFIG_FILE") - - if logDir != "" { - log.AddHook(dugong.NewFSHook( - filepath.Join(logDir, "info.log"), - filepath.Join(logDir, "warn.log"), - filepath.Join(logDir, "error.log"), - )) - } - - log.Infof( - "Go-NEB (BIND_ADDRESS=%s DATABASE_TYPE=%s DATABASE_URL=%s BASE_URL=%s LOG_DIR=%s CONFIG_FILE=%s)", - bindAddress, databaseType, databaseURL, baseURL, logDir, configYAML, - ) - - err := types.BaseURL(baseURL) +func setup(e envVars, mux *http.ServeMux) { + err := types.BaseURL(e.BaseURL) if err != nil { log.WithError(err).Panic("Failed to get base url") } - db, err := loadDatabase(databaseType, databaseURL, configYAML) + db, err := loadDatabase(e.DatabaseType, e.DatabaseURL, e.ConfigFile) if err != nil { log.WithError(err).Panic("Failed to open database") } // Populate the database from the config file if one was supplied. var cfg *api.ConfigFile - if configYAML != "" { - if cfg, err = loadFromConfig(db, configYAML); err != nil { - log.WithError(err).WithField("config_file", configYAML).Panic("Failed to load config file") + if e.ConfigFile != "" { + if cfg, err = loadFromConfig(db, e.ConfigFile); err != nil { + log.WithError(err).WithField("config_file", e.ConfigFile).Panic("Failed to load config file") } if err := db.InsertFromConfig(cfg); err != nil { log.WithError(err).Panic("Failed to persist config data into in-memory DB") @@ -190,34 +170,65 @@ func main() { } // Handle non-admin paths for normal NEB functioning - http.Handle("/metrics", prometheus.Handler()) - http.Handle("/test", prometheus.InstrumentHandler("test", server.MakeJSONAPI(&heartbeatHandler{}))) + mux.Handle("/metrics", prometheus.Handler()) + mux.Handle("/test", prometheus.InstrumentHandler("test", server.MakeJSONAPI(&heartbeatHandler{}))) wh := &webhookHandler{db: db, clients: clients} - http.HandleFunc("/services/hooks/", prometheus.InstrumentHandlerFunc("webhookHandler", server.Protect(wh.handle))) + mux.HandleFunc("/services/hooks/", prometheus.InstrumentHandlerFunc("webhookHandler", server.Protect(wh.handle))) rh := &realmRedirectHandler{db: db} - http.HandleFunc("/realms/redirects/", prometheus.InstrumentHandlerFunc("realmRedirectHandler", server.Protect(rh.handle))) + mux.HandleFunc("/realms/redirects/", prometheus.InstrumentHandlerFunc("realmRedirectHandler", server.Protect(rh.handle))) // Read exclusively from the config file if one was supplied. // Otherwise, add HTTP listeners for new Services/Sessions/Clients/etc. - if configYAML != "" { + if e.ConfigFile != "" { if err := insertServicesFromConfig(clients, cfg.Services); err != nil { log.WithError(err).Panic("Failed to insert services") } log.Info("Inserted ", len(cfg.Services), " services") } else { - http.Handle("/admin/getService", prometheus.InstrumentHandler("getService", server.MakeJSONAPI(&getServiceHandler{db: db}))) - http.Handle("/admin/getSession", prometheus.InstrumentHandler("getSession", server.MakeJSONAPI(&getSessionHandler{db: db}))) - http.Handle("/admin/configureClient", prometheus.InstrumentHandler("configureClient", server.MakeJSONAPI(&configureClientHandler{db: db, clients: clients}))) - http.Handle("/admin/configureService", prometheus.InstrumentHandler("configureService", server.MakeJSONAPI(newConfigureServiceHandler(db, clients)))) - http.Handle("/admin/configureAuthRealm", prometheus.InstrumentHandler("configureAuthRealm", server.MakeJSONAPI(&configureAuthRealmHandler{db: db}))) - http.Handle("/admin/requestAuthSession", prometheus.InstrumentHandler("requestAuthSession", server.MakeJSONAPI(&requestAuthSessionHandler{db: db}))) - http.Handle("/admin/removeAuthSession", prometheus.InstrumentHandler("removeAuthSession", server.MakeJSONAPI(&removeAuthSessionHandler{db: db}))) + mux.Handle("/admin/getService", prometheus.InstrumentHandler("getService", server.MakeJSONAPI(&getServiceHandler{db: db}))) + mux.Handle("/admin/getSession", prometheus.InstrumentHandler("getSession", server.MakeJSONAPI(&getSessionHandler{db: db}))) + mux.Handle("/admin/configureClient", prometheus.InstrumentHandler("configureClient", server.MakeJSONAPI(&configureClientHandler{db: db, clients: clients}))) + mux.Handle("/admin/configureService", prometheus.InstrumentHandler("configureService", server.MakeJSONAPI(newConfigureServiceHandler(db, clients)))) + mux.Handle("/admin/configureAuthRealm", prometheus.InstrumentHandler("configureAuthRealm", server.MakeJSONAPI(&configureAuthRealmHandler{db: db}))) + mux.Handle("/admin/requestAuthSession", prometheus.InstrumentHandler("requestAuthSession", server.MakeJSONAPI(&requestAuthSessionHandler{db: db}))) + mux.Handle("/admin/removeAuthSession", prometheus.InstrumentHandler("removeAuthSession", server.MakeJSONAPI(&removeAuthSessionHandler{db: db}))) } polling.SetClients(clients) if err := polling.Start(); err != nil { log.WithError(err).Panic("Failed to start polling") } +} + +type envVars struct { + BindAddress string + DatabaseType string + DatabaseURL string + BaseURL string + LogDir string + ConfigFile string +} + +func main() { + e := envVars{ + BindAddress: os.Getenv("BIND_ADDRESS"), + DatabaseType: os.Getenv("DATABASE_TYPE"), + DatabaseURL: os.Getenv("DATABASE_URL"), + BaseURL: os.Getenv("BASE_URL"), + LogDir: os.Getenv("LOG_DIR"), + ConfigFile: os.Getenv("CONFIG_FILE"), + } + + if e.LogDir != "" { + log.AddHook(dugong.NewFSHook( + filepath.Join(e.LogDir, "info.log"), + filepath.Join(e.LogDir, "warn.log"), + filepath.Join(e.LogDir, "error.log"), + )) + } + + log.Infof("Go-NEB (%+v)", e) - log.Fatal(http.ListenAndServe(bindAddress, nil)) + setup(e, http.DefaultServeMux) + log.Fatal(http.ListenAndServe(e.BindAddress, nil)) } diff --git a/src/github.com/matrix-org/go-neb/goneb_services_test.go b/src/github.com/matrix-org/go-neb/goneb_services_test.go new file mode 100644 index 0000000..381998a --- /dev/null +++ b/src/github.com/matrix-org/go-neb/goneb_services_test.go @@ -0,0 +1,31 @@ +package main + +import ( + "net/http" + "net/http/httptest" + "os" + "testing" +) + +var mux = http.NewServeMux() + +func TestMain(m *testing.M) { + setup(envVars{ + BaseURL: "http://go.neb", + DatabaseType: "sqlite3", + DatabaseURL: ":memory:", + }, mux) + exitCode := m.Run() + os.Exit(exitCode) +} + +func TestNotFound(t *testing.T) { + mockWriter := httptest.NewRecorder() + mockReq, _ := http.NewRequest("GET", "http://go.neb/foo", nil) + mux.ServeHTTP(mockWriter, mockReq) + + expectCode := 404 + if mockWriter.Code != expectCode { + t.Errorf("TestNotFound wanted HTTP status %d, got %d", expectCode, mockWriter.Code) + } +} From 5c65d4cf958b9d43582d48240293a8ebd489a6bc Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Tue, 25 Oct 2016 15:48:04 +0100 Subject: [PATCH 2/3] Add a mock HTTP client for Matrix clients to use. Add example test. This will need rejigging at some point to make the test easier to set up. --- .../matrix-org/go-neb/clients/clients.go | 19 +++-- src/github.com/matrix-org/go-neb/goneb.go | 6 +- .../matrix-org/go-neb/goneb_services_test.go | 84 +++++++++++++++++-- .../matrix-org/go-neb/matrix/matrix.go | 10 ++- 4 files changed, 99 insertions(+), 20 deletions(-) diff --git a/src/github.com/matrix-org/go-neb/clients/clients.go b/src/github.com/matrix-org/go-neb/clients/clients.go index c6fa205..9192289 100644 --- a/src/github.com/matrix-org/go-neb/clients/clients.go +++ b/src/github.com/matrix-org/go-neb/clients/clients.go @@ -7,6 +7,7 @@ import ( "github.com/matrix-org/go-neb/matrix" "github.com/matrix-org/go-neb/plugin" "github.com/matrix-org/go-neb/types" + "net/http" "net/url" "strings" "sync" @@ -39,17 +40,19 @@ func (s nextBatchStore) Load(userID string) string { // A Clients is a collection of clients used for bot services. type Clients struct { - db *database.ServiceDB - dbMutex sync.Mutex - mapMutex sync.Mutex - clients map[string]clientEntry + db *database.ServiceDB + httpClient *http.Client + dbMutex sync.Mutex + mapMutex sync.Mutex + clients map[string]clientEntry } // New makes a new collection of matrix clients -func New(db *database.ServiceDB) *Clients { +func New(db *database.ServiceDB, cli *http.Client) *Clients { clients := &Clients{ - db: db, - clients: make(map[string]clientEntry), // user_id => clientEntry + db: db, + httpClient: cli, + clients: make(map[string]clientEntry), // user_id => clientEntry } return clients } @@ -238,7 +241,7 @@ func (c *Clients) newClient(config api.ClientConfig) (*matrix.Client, error) { return nil, err } - client := matrix.NewClient(homeserverURL, config.AccessToken, config.UserID) + client := matrix.NewClient(c.httpClient, homeserverURL, config.AccessToken, config.UserID) client.NextBatchStorer = nextBatchStore{c.db} // TODO: Check that the access token is valid for the userID by peforming diff --git a/src/github.com/matrix-org/go-neb/goneb.go b/src/github.com/matrix-org/go-neb/goneb.go index 329c8ee..ad70572 100644 --- a/src/github.com/matrix-org/go-neb/goneb.go +++ b/src/github.com/matrix-org/go-neb/goneb.go @@ -139,7 +139,7 @@ func loadDatabase(databaseType, databaseURL, configYAML string) (*database.Servi return db, err } -func setup(e envVars, mux *http.ServeMux) { +func setup(e envVars, mux *http.ServeMux, matrixClient *http.Client) { err := types.BaseURL(e.BaseURL) if err != nil { log.WithError(err).Panic("Failed to get base url") @@ -164,7 +164,7 @@ func setup(e envVars, mux *http.ServeMux) { log.Info("Inserted ", len(cfg.Sessions), " sessions") } - clients := clients.New(db) + clients := clients.New(db, matrixClient) if err := clients.Start(); err != nil { log.WithError(err).Panic("Failed to start up clients") } @@ -229,6 +229,6 @@ func main() { log.Infof("Go-NEB (%+v)", e) - setup(e, http.DefaultServeMux) + setup(e, http.DefaultServeMux, http.DefaultClient) log.Fatal(http.ListenAndServe(e.BindAddress, nil)) } diff --git a/src/github.com/matrix-org/go-neb/goneb_services_test.go b/src/github.com/matrix-org/go-neb/goneb_services_test.go index 381998a..6e500b7 100644 --- a/src/github.com/matrix-org/go-neb/goneb_services_test.go +++ b/src/github.com/matrix-org/go-neb/goneb_services_test.go @@ -1,31 +1,103 @@ package main import ( + "bytes" "net/http" "net/http/httptest" "os" + "strconv" "testing" ) var mux = http.NewServeMux() +type MockTripper struct { + handlers map[string]func(req *http.Request) (*http.Response, error) +} + +func (rt MockTripper) RoundTrip(req *http.Request) (*http.Response, error) { + key := req.Method + " " + req.URL.Path + h := rt.handlers[key] + if h == nil { + panic( + "Test RoundTrip: Unhandled request: " + key + "\n" + + "Handlers: " + strconv.Itoa(len(rt.handlers)), + ) + } + return h(req) +} + +func (rt MockTripper) Handle(method, path string, handler func(req *http.Request) (*http.Response, error)) { + key := method + " " + path + if _, exists := rt.handlers[key]; exists { + panic("Test handler with key " + key + " already exists") + } + rt.handlers[key] = handler +} + +var tripper = MockTripper{make(map[string]func(req *http.Request) (*http.Response, error))} + +type nopCloser struct { + *bytes.Buffer +} + +func (nopCloser) Close() error { return nil } + +func newResponse(statusCode int, body string) *http.Response { + return &http.Response{ + StatusCode: statusCode, + Body: nopCloser{bytes.NewBufferString(body)}, + } +} + func TestMain(m *testing.M) { setup(envVars{ BaseURL: "http://go.neb", DatabaseType: "sqlite3", DatabaseURL: ":memory:", - }, mux) + }, mux, &http.Client{ + Transport: tripper, + }) exitCode := m.Run() os.Exit(exitCode) } -func TestNotFound(t *testing.T) { +func TestConfigureClient(t *testing.T) { + for k := range tripper.handlers { + delete(tripper.handlers, k) + } mockWriter := httptest.NewRecorder() - mockReq, _ := http.NewRequest("GET", "http://go.neb/foo", nil) - mux.ServeHTTP(mockWriter, mockReq) + tripper.Handle("POST", "/_matrix/client/r0/user/@link:hyrule/filter", + func(req *http.Request) (*http.Response, error) { + return newResponse(200, `{ + "filter_id":"abcdef" + }`), nil + }, + ) + syncChan := make(chan string) + tripper.Handle("GET", "/_matrix/client/r0/sync", + func(req *http.Request) (*http.Response, error) { + syncChan <- "sync" + return newResponse(200, `{ + "next_batch":"11_22_33_44", + "rooms": {} + }`), nil + }, + ) - expectCode := 404 + mockReq, _ := http.NewRequest("POST", "http://go.neb/admin/configureClient", bytes.NewBufferString(` + { + "UserID":"@link:hyrule", + "HomeserverURL":"http://hyrule.loz", + "AccessToken":"dangeroustogoalone", + "Sync":true, + "AutoJoinRooms":true + }`)) + mux.ServeHTTP(mockWriter, mockReq) + expectCode := 200 if mockWriter.Code != expectCode { - t.Errorf("TestNotFound wanted HTTP status %d, got %d", expectCode, mockWriter.Code) + t.Errorf("TestConfigureClient wanted HTTP status %d, got %d", expectCode, mockWriter.Code) } + + <-syncChan } diff --git a/src/github.com/matrix-org/go-neb/matrix/matrix.go b/src/github.com/matrix-org/go-neb/matrix/matrix.go index 0214a18..bae4edb 100644 --- a/src/github.com/matrix-org/go-neb/matrix/matrix.go +++ b/src/github.com/matrix-org/go-neb/matrix/matrix.go @@ -404,7 +404,11 @@ func (cli *Client) doSync(timeout int, since string) ([]byte, error) { query["filter"] = cli.filterID } urlPath := cli.buildURLWithQuery([]string{"sync"}, query) - res, err := http.Get(urlPath) + req, err := http.NewRequest("GET", urlPath, nil) + if err != nil { + return nil, err + } + res, err := cli.httpClient.Do(req) if err != nil { return nil, err } @@ -417,7 +421,7 @@ func (cli *Client) doSync(timeout int, since string) ([]byte, error) { } // NewClient creates a new Matrix Client ready for syncing -func NewClient(homeserverURL *url.URL, accessToken string, userID string) *Client { +func NewClient(httpClient *http.Client, homeserverURL *url.URL, accessToken, userID string) *Client { cli := Client{ AccessToken: accessToken, HomeserverURL: homeserverURL, @@ -430,7 +434,7 @@ func NewClient(homeserverURL *url.URL, accessToken string, userID string) *Clien // remember the token across restarts. In practice, a database backend should be used. cli.NextBatchStorer = noopNextBatchStore{} cli.Rooms = make(map[string]*Room) - cli.httpClient = &http.Client{} + cli.httpClient = httpClient return &cli } From 352d7415582568f5131ee15579b04280cbbbc48f Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Tue, 25 Oct 2016 16:14:16 +0100 Subject: [PATCH 3/3] Review comments --- hooks/pre-commit | 2 +- src/github.com/matrix-org/go-neb/goneb_services_test.go | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/hooks/pre-commit b/hooks/pre-commit index d9ffbfb..09ff264 100755 --- a/hooks/pre-commit +++ b/hooks/pre-commit @@ -6,4 +6,4 @@ golint src/... go fmt ./src/... go tool vet --shadow ./src gocyclo -over 12 src/ -gb test +gb test -timeout 5s diff --git a/src/github.com/matrix-org/go-neb/goneb_services_test.go b/src/github.com/matrix-org/go-neb/goneb_services_test.go index 6e500b7..baa925a 100644 --- a/src/github.com/matrix-org/go-neb/goneb_services_test.go +++ b/src/github.com/matrix-org/go-neb/goneb_services_test.go @@ -2,10 +2,10 @@ package main import ( "bytes" + "fmt" "net/http" "net/http/httptest" "os" - "strconv" "testing" ) @@ -20,8 +20,8 @@ func (rt MockTripper) RoundTrip(req *http.Request) (*http.Response, error) { h := rt.handlers[key] if h == nil { panic( - "Test RoundTrip: Unhandled request: " + key + "\n" + - "Handlers: " + strconv.Itoa(len(rt.handlers)), + fmt.Sprintf("Test RoundTrip: Unhandled request: %s\nHandlers: %d", + key, len(rt.handlers)), ) } return h(req)