Browse Source

Factor out startup process to `setup()` and add top-level test

- Pass in an `envVars` struct rather than grabbing directly from `os`.
- Accept a `ServeMux` rather than always using the default mux.
- Add a `_test.go` file to instrument Go-NEB using `TestMain` to call `setup()`
kegan/tests-prep
Kegan Dougal 8 years ago
parent
commit
dc2fadd450
  1. 89
      src/github.com/matrix-org/go-neb/goneb.go
  2. 31
      src/github.com/matrix-org/go-neb/goneb_services_test.go

89
src/github.com/matrix-org/go-neb/goneb.go

@ -139,42 +139,22 @@ func loadDatabase(databaseType, databaseURL, configYAML string) (*database.Servi
return db, err
}
func main() {
bindAddress := os.Getenv("BIND_ADDRESS")
databaseType := os.Getenv("DATABASE_TYPE")
databaseURL := os.Getenv("DATABASE_URL")
baseURL := os.Getenv("BASE_URL")
logDir := os.Getenv("LOG_DIR")
configYAML := os.Getenv("CONFIG_FILE")
if logDir != "" {
log.AddHook(dugong.NewFSHook(
filepath.Join(logDir, "info.log"),
filepath.Join(logDir, "warn.log"),
filepath.Join(logDir, "error.log"),
))
}
log.Infof(
"Go-NEB (BIND_ADDRESS=%s DATABASE_TYPE=%s DATABASE_URL=%s BASE_URL=%s LOG_DIR=%s CONFIG_FILE=%s)",
bindAddress, databaseType, databaseURL, baseURL, logDir, configYAML,
)
err := types.BaseURL(baseURL)
func setup(e envVars, mux *http.ServeMux) {
err := types.BaseURL(e.BaseURL)
if err != nil {
log.WithError(err).Panic("Failed to get base url")
}
db, err := loadDatabase(databaseType, databaseURL, configYAML)
db, err := loadDatabase(e.DatabaseType, e.DatabaseURL, e.ConfigFile)
if err != nil {
log.WithError(err).Panic("Failed to open database")
}
// Populate the database from the config file if one was supplied.
var cfg *api.ConfigFile
if configYAML != "" {
if cfg, err = loadFromConfig(db, configYAML); err != nil {
log.WithError(err).WithField("config_file", configYAML).Panic("Failed to load config file")
if e.ConfigFile != "" {
if cfg, err = loadFromConfig(db, e.ConfigFile); err != nil {
log.WithError(err).WithField("config_file", e.ConfigFile).Panic("Failed to load config file")
}
if err := db.InsertFromConfig(cfg); err != nil {
log.WithError(err).Panic("Failed to persist config data into in-memory DB")
@ -190,34 +170,65 @@ func main() {
}
// Handle non-admin paths for normal NEB functioning
http.Handle("/metrics", prometheus.Handler())
http.Handle("/test", prometheus.InstrumentHandler("test", server.MakeJSONAPI(&heartbeatHandler{})))
mux.Handle("/metrics", prometheus.Handler())
mux.Handle("/test", prometheus.InstrumentHandler("test", server.MakeJSONAPI(&heartbeatHandler{})))
wh := &webhookHandler{db: db, clients: clients}
http.HandleFunc("/services/hooks/", prometheus.InstrumentHandlerFunc("webhookHandler", server.Protect(wh.handle)))
mux.HandleFunc("/services/hooks/", prometheus.InstrumentHandlerFunc("webhookHandler", server.Protect(wh.handle)))
rh := &realmRedirectHandler{db: db}
http.HandleFunc("/realms/redirects/", prometheus.InstrumentHandlerFunc("realmRedirectHandler", server.Protect(rh.handle)))
mux.HandleFunc("/realms/redirects/", prometheus.InstrumentHandlerFunc("realmRedirectHandler", server.Protect(rh.handle)))
// Read exclusively from the config file if one was supplied.
// Otherwise, add HTTP listeners for new Services/Sessions/Clients/etc.
if configYAML != "" {
if e.ConfigFile != "" {
if err := insertServicesFromConfig(clients, cfg.Services); err != nil {
log.WithError(err).Panic("Failed to insert services")
}
log.Info("Inserted ", len(cfg.Services), " services")
} else {
http.Handle("/admin/getService", prometheus.InstrumentHandler("getService", server.MakeJSONAPI(&getServiceHandler{db: db})))
http.Handle("/admin/getSession", prometheus.InstrumentHandler("getSession", server.MakeJSONAPI(&getSessionHandler{db: db})))
http.Handle("/admin/configureClient", prometheus.InstrumentHandler("configureClient", server.MakeJSONAPI(&configureClientHandler{db: db, clients: clients})))
http.Handle("/admin/configureService", prometheus.InstrumentHandler("configureService", server.MakeJSONAPI(newConfigureServiceHandler(db, clients))))
http.Handle("/admin/configureAuthRealm", prometheus.InstrumentHandler("configureAuthRealm", server.MakeJSONAPI(&configureAuthRealmHandler{db: db})))
http.Handle("/admin/requestAuthSession", prometheus.InstrumentHandler("requestAuthSession", server.MakeJSONAPI(&requestAuthSessionHandler{db: db})))
http.Handle("/admin/removeAuthSession", prometheus.InstrumentHandler("removeAuthSession", server.MakeJSONAPI(&removeAuthSessionHandler{db: db})))
mux.Handle("/admin/getService", prometheus.InstrumentHandler("getService", server.MakeJSONAPI(&getServiceHandler{db: db})))
mux.Handle("/admin/getSession", prometheus.InstrumentHandler("getSession", server.MakeJSONAPI(&getSessionHandler{db: db})))
mux.Handle("/admin/configureClient", prometheus.InstrumentHandler("configureClient", server.MakeJSONAPI(&configureClientHandler{db: db, clients: clients})))
mux.Handle("/admin/configureService", prometheus.InstrumentHandler("configureService", server.MakeJSONAPI(newConfigureServiceHandler(db, clients))))
mux.Handle("/admin/configureAuthRealm", prometheus.InstrumentHandler("configureAuthRealm", server.MakeJSONAPI(&configureAuthRealmHandler{db: db})))
mux.Handle("/admin/requestAuthSession", prometheus.InstrumentHandler("requestAuthSession", server.MakeJSONAPI(&requestAuthSessionHandler{db: db})))
mux.Handle("/admin/removeAuthSession", prometheus.InstrumentHandler("removeAuthSession", server.MakeJSONAPI(&removeAuthSessionHandler{db: db})))
}
polling.SetClients(clients)
if err := polling.Start(); err != nil {
log.WithError(err).Panic("Failed to start polling")
}
}
type envVars struct {
BindAddress string
DatabaseType string
DatabaseURL string
BaseURL string
LogDir string
ConfigFile string
}
func main() {
e := envVars{
BindAddress: os.Getenv("BIND_ADDRESS"),
DatabaseType: os.Getenv("DATABASE_TYPE"),
DatabaseURL: os.Getenv("DATABASE_URL"),
BaseURL: os.Getenv("BASE_URL"),
LogDir: os.Getenv("LOG_DIR"),
ConfigFile: os.Getenv("CONFIG_FILE"),
}
if e.LogDir != "" {
log.AddHook(dugong.NewFSHook(
filepath.Join(e.LogDir, "info.log"),
filepath.Join(e.LogDir, "warn.log"),
filepath.Join(e.LogDir, "error.log"),
))
}
log.Infof("Go-NEB (%+v)", e)
log.Fatal(http.ListenAndServe(bindAddress, nil))
setup(e, http.DefaultServeMux)
log.Fatal(http.ListenAndServe(e.BindAddress, nil))
}

31
src/github.com/matrix-org/go-neb/goneb_services_test.go

@ -0,0 +1,31 @@
package main
import (
"net/http"
"net/http/httptest"
"os"
"testing"
)
var mux = http.NewServeMux()
func TestMain(m *testing.M) {
setup(envVars{
BaseURL: "http://go.neb",
DatabaseType: "sqlite3",
DatabaseURL: ":memory:",
}, mux)
exitCode := m.Run()
os.Exit(exitCode)
}
func TestNotFound(t *testing.T) {
mockWriter := httptest.NewRecorder()
mockReq, _ := http.NewRequest("GET", "http://go.neb/foo", nil)
mux.ServeHTTP(mockWriter, mockReq)
expectCode := 404
if mockWriter.Code != expectCode {
t.Errorf("TestNotFound wanted HTTP status %d, got %d", expectCode, mockWriter.Code)
}
}
Loading…
Cancel
Save