From dc2fadd4509e8d2f06a50c2d470796bd5b585ce2 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Tue, 25 Oct 2016 11:27:42 +0100 Subject: [PATCH] Factor out startup process to `setup()` and add top-level test - Pass in an `envVars` struct rather than grabbing directly from `os`. - Accept a `ServeMux` rather than always using the default mux. - Add a `_test.go` file to instrument Go-NEB using `TestMain` to call `setup()` --- src/github.com/matrix-org/go-neb/goneb.go | 89 +++++++++++-------- .../matrix-org/go-neb/goneb_services_test.go | 31 +++++++ 2 files changed, 81 insertions(+), 39 deletions(-) create mode 100644 src/github.com/matrix-org/go-neb/goneb_services_test.go diff --git a/src/github.com/matrix-org/go-neb/goneb.go b/src/github.com/matrix-org/go-neb/goneb.go index 7b73ca2..329c8ee 100644 --- a/src/github.com/matrix-org/go-neb/goneb.go +++ b/src/github.com/matrix-org/go-neb/goneb.go @@ -139,42 +139,22 @@ func loadDatabase(databaseType, databaseURL, configYAML string) (*database.Servi return db, err } -func main() { - bindAddress := os.Getenv("BIND_ADDRESS") - databaseType := os.Getenv("DATABASE_TYPE") - databaseURL := os.Getenv("DATABASE_URL") - baseURL := os.Getenv("BASE_URL") - logDir := os.Getenv("LOG_DIR") - configYAML := os.Getenv("CONFIG_FILE") - - if logDir != "" { - log.AddHook(dugong.NewFSHook( - filepath.Join(logDir, "info.log"), - filepath.Join(logDir, "warn.log"), - filepath.Join(logDir, "error.log"), - )) - } - - log.Infof( - "Go-NEB (BIND_ADDRESS=%s DATABASE_TYPE=%s DATABASE_URL=%s BASE_URL=%s LOG_DIR=%s CONFIG_FILE=%s)", - bindAddress, databaseType, databaseURL, baseURL, logDir, configYAML, - ) - - err := types.BaseURL(baseURL) +func setup(e envVars, mux *http.ServeMux) { + err := types.BaseURL(e.BaseURL) if err != nil { log.WithError(err).Panic("Failed to get base url") } - db, err := loadDatabase(databaseType, databaseURL, configYAML) + db, err := loadDatabase(e.DatabaseType, e.DatabaseURL, e.ConfigFile) if err != nil { log.WithError(err).Panic("Failed to open database") } // Populate the database from the config file if one was supplied. var cfg *api.ConfigFile - if configYAML != "" { - if cfg, err = loadFromConfig(db, configYAML); err != nil { - log.WithError(err).WithField("config_file", configYAML).Panic("Failed to load config file") + if e.ConfigFile != "" { + if cfg, err = loadFromConfig(db, e.ConfigFile); err != nil { + log.WithError(err).WithField("config_file", e.ConfigFile).Panic("Failed to load config file") } if err := db.InsertFromConfig(cfg); err != nil { log.WithError(err).Panic("Failed to persist config data into in-memory DB") @@ -190,34 +170,65 @@ func main() { } // Handle non-admin paths for normal NEB functioning - http.Handle("/metrics", prometheus.Handler()) - http.Handle("/test", prometheus.InstrumentHandler("test", server.MakeJSONAPI(&heartbeatHandler{}))) + mux.Handle("/metrics", prometheus.Handler()) + mux.Handle("/test", prometheus.InstrumentHandler("test", server.MakeJSONAPI(&heartbeatHandler{}))) wh := &webhookHandler{db: db, clients: clients} - http.HandleFunc("/services/hooks/", prometheus.InstrumentHandlerFunc("webhookHandler", server.Protect(wh.handle))) + mux.HandleFunc("/services/hooks/", prometheus.InstrumentHandlerFunc("webhookHandler", server.Protect(wh.handle))) rh := &realmRedirectHandler{db: db} - http.HandleFunc("/realms/redirects/", prometheus.InstrumentHandlerFunc("realmRedirectHandler", server.Protect(rh.handle))) + mux.HandleFunc("/realms/redirects/", prometheus.InstrumentHandlerFunc("realmRedirectHandler", server.Protect(rh.handle))) // Read exclusively from the config file if one was supplied. // Otherwise, add HTTP listeners for new Services/Sessions/Clients/etc. - if configYAML != "" { + if e.ConfigFile != "" { if err := insertServicesFromConfig(clients, cfg.Services); err != nil { log.WithError(err).Panic("Failed to insert services") } log.Info("Inserted ", len(cfg.Services), " services") } else { - http.Handle("/admin/getService", prometheus.InstrumentHandler("getService", server.MakeJSONAPI(&getServiceHandler{db: db}))) - http.Handle("/admin/getSession", prometheus.InstrumentHandler("getSession", server.MakeJSONAPI(&getSessionHandler{db: db}))) - http.Handle("/admin/configureClient", prometheus.InstrumentHandler("configureClient", server.MakeJSONAPI(&configureClientHandler{db: db, clients: clients}))) - http.Handle("/admin/configureService", prometheus.InstrumentHandler("configureService", server.MakeJSONAPI(newConfigureServiceHandler(db, clients)))) - http.Handle("/admin/configureAuthRealm", prometheus.InstrumentHandler("configureAuthRealm", server.MakeJSONAPI(&configureAuthRealmHandler{db: db}))) - http.Handle("/admin/requestAuthSession", prometheus.InstrumentHandler("requestAuthSession", server.MakeJSONAPI(&requestAuthSessionHandler{db: db}))) - http.Handle("/admin/removeAuthSession", prometheus.InstrumentHandler("removeAuthSession", server.MakeJSONAPI(&removeAuthSessionHandler{db: db}))) + mux.Handle("/admin/getService", prometheus.InstrumentHandler("getService", server.MakeJSONAPI(&getServiceHandler{db: db}))) + mux.Handle("/admin/getSession", prometheus.InstrumentHandler("getSession", server.MakeJSONAPI(&getSessionHandler{db: db}))) + mux.Handle("/admin/configureClient", prometheus.InstrumentHandler("configureClient", server.MakeJSONAPI(&configureClientHandler{db: db, clients: clients}))) + mux.Handle("/admin/configureService", prometheus.InstrumentHandler("configureService", server.MakeJSONAPI(newConfigureServiceHandler(db, clients)))) + mux.Handle("/admin/configureAuthRealm", prometheus.InstrumentHandler("configureAuthRealm", server.MakeJSONAPI(&configureAuthRealmHandler{db: db}))) + mux.Handle("/admin/requestAuthSession", prometheus.InstrumentHandler("requestAuthSession", server.MakeJSONAPI(&requestAuthSessionHandler{db: db}))) + mux.Handle("/admin/removeAuthSession", prometheus.InstrumentHandler("removeAuthSession", server.MakeJSONAPI(&removeAuthSessionHandler{db: db}))) } polling.SetClients(clients) if err := polling.Start(); err != nil { log.WithError(err).Panic("Failed to start polling") } +} + +type envVars struct { + BindAddress string + DatabaseType string + DatabaseURL string + BaseURL string + LogDir string + ConfigFile string +} + +func main() { + e := envVars{ + BindAddress: os.Getenv("BIND_ADDRESS"), + DatabaseType: os.Getenv("DATABASE_TYPE"), + DatabaseURL: os.Getenv("DATABASE_URL"), + BaseURL: os.Getenv("BASE_URL"), + LogDir: os.Getenv("LOG_DIR"), + ConfigFile: os.Getenv("CONFIG_FILE"), + } + + if e.LogDir != "" { + log.AddHook(dugong.NewFSHook( + filepath.Join(e.LogDir, "info.log"), + filepath.Join(e.LogDir, "warn.log"), + filepath.Join(e.LogDir, "error.log"), + )) + } + + log.Infof("Go-NEB (%+v)", e) - log.Fatal(http.ListenAndServe(bindAddress, nil)) + setup(e, http.DefaultServeMux) + log.Fatal(http.ListenAndServe(e.BindAddress, nil)) } diff --git a/src/github.com/matrix-org/go-neb/goneb_services_test.go b/src/github.com/matrix-org/go-neb/goneb_services_test.go new file mode 100644 index 0000000..381998a --- /dev/null +++ b/src/github.com/matrix-org/go-neb/goneb_services_test.go @@ -0,0 +1,31 @@ +package main + +import ( + "net/http" + "net/http/httptest" + "os" + "testing" +) + +var mux = http.NewServeMux() + +func TestMain(m *testing.M) { + setup(envVars{ + BaseURL: "http://go.neb", + DatabaseType: "sqlite3", + DatabaseURL: ":memory:", + }, mux) + exitCode := m.Run() + os.Exit(exitCode) +} + +func TestNotFound(t *testing.T) { + mockWriter := httptest.NewRecorder() + mockReq, _ := http.NewRequest("GET", "http://go.neb/foo", nil) + mux.ServeHTTP(mockWriter, mockReq) + + expectCode := 404 + if mockWriter.Code != expectCode { + t.Errorf("TestNotFound wanted HTTP status %d, got %d", expectCode, mockWriter.Code) + } +}