From 4ab61e4a5ffc28bb6526c816382a3e756b1d60a3 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Tue, 1 Nov 2016 14:57:19 +0000 Subject: [PATCH] Split up the enormous handlers.go into more manageable files Split them up based on the HTTP API they are implementing. --- .../matrix-org/go-neb/api/handlers/auth.go | 198 +++++++++ .../matrix-org/go-neb/api/handlers/client.go | 40 ++ .../go-neb/api/handlers/heartbeat.go | 26 ++ .../matrix-org/go-neb/api/handlers/service.go | 158 +++++++ .../matrix-org/go-neb/api/handlers/webhook.go | 63 +++ src/github.com/matrix-org/go-neb/goneb.go | 38 +- src/github.com/matrix-org/go-neb/handlers.go | 419 ------------------ 7 files changed, 505 insertions(+), 437 deletions(-) create mode 100644 src/github.com/matrix-org/go-neb/api/handlers/auth.go create mode 100644 src/github.com/matrix-org/go-neb/api/handlers/client.go create mode 100644 src/github.com/matrix-org/go-neb/api/handlers/heartbeat.go create mode 100644 src/github.com/matrix-org/go-neb/api/handlers/service.go create mode 100644 src/github.com/matrix-org/go-neb/api/handlers/webhook.go delete mode 100644 src/github.com/matrix-org/go-neb/handlers.go diff --git a/src/github.com/matrix-org/go-neb/api/handlers/auth.go b/src/github.com/matrix-org/go-neb/api/handlers/auth.go new file mode 100644 index 0000000..7c8a7cf --- /dev/null +++ b/src/github.com/matrix-org/go-neb/api/handlers/auth.go @@ -0,0 +1,198 @@ +package handlers + +import ( + "database/sql" + "encoding/base64" + "encoding/json" + "net/http" + "strings" + + log "github.com/Sirupsen/logrus" + "github.com/matrix-org/go-neb/api" + "github.com/matrix-org/go-neb/database" + "github.com/matrix-org/go-neb/errors" + "github.com/matrix-org/go-neb/metrics" + "github.com/matrix-org/go-neb/types" +) + +type RequestAuthSession struct { + Db *database.ServiceDB +} + +func (h *RequestAuthSession) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { + if req.Method != "POST" { + return nil, &errors.HTTPError{nil, "Unsupported Method", 405} + } + var body struct { + RealmID string + UserID string + Config json.RawMessage + } + if err := json.NewDecoder(req.Body).Decode(&body); err != nil { + return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} + } + log.WithFields(log.Fields{ + "realm_id": body.RealmID, + "user_id": body.UserID, + }).Print("Incoming auth session request") + + if body.UserID == "" || body.RealmID == "" || body.Config == nil { + return nil, &errors.HTTPError{nil, `Must supply a "UserID", a "RealmID" and a "Config"`, 400} + } + + realm, err := h.Db.LoadAuthRealm(body.RealmID) + if err != nil { + return nil, &errors.HTTPError{err, "Unknown RealmID", 400} + } + + response := realm.RequestAuthSession(body.UserID, body.Config) + if response == nil { + return nil, &errors.HTTPError{nil, "Failed to request auth session", 500} + } + + metrics.IncrementAuthSession(realm.Type()) + + return response, nil +} + +type RemoveAuthSession struct { + Db *database.ServiceDB +} + +func (h *RemoveAuthSession) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { + if req.Method != "POST" { + return nil, &errors.HTTPError{nil, "Unsupported Method", 405} + } + var body struct { + RealmID string + UserID string + } + if err := json.NewDecoder(req.Body).Decode(&body); err != nil { + return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} + } + log.WithFields(log.Fields{ + "realm_id": body.RealmID, + "user_id": body.UserID, + }).Print("Incoming remove auth session request") + + if body.UserID == "" || body.RealmID == "" { + return nil, &errors.HTTPError{nil, `Must supply a "UserID", a "RealmID"`, 400} + } + + _, err := h.Db.LoadAuthRealm(body.RealmID) + if err != nil { + return nil, &errors.HTTPError{err, "Unknown RealmID", 400} + } + + if err := h.Db.RemoveAuthSession(body.RealmID, body.UserID); err != nil { + return nil, &errors.HTTPError{err, "Failed to remove auth session", 500} + } + + return []byte(`{}`), nil +} + +type RealmRedirect struct { + Db *database.ServiceDB +} + +func (rh *RealmRedirect) Handle(w http.ResponseWriter, req *http.Request) { + segments := strings.Split(req.URL.Path, "/") + // last path segment is the base64d realm ID which we will pass the incoming request to + base64realmID := segments[len(segments)-1] + bytesRealmID, err := base64.RawURLEncoding.DecodeString(base64realmID) + realmID := string(bytesRealmID) + if err != nil { + log.WithError(err).WithField("base64_realm_id", base64realmID).Print( + "Not a b64 encoded string", + ) + w.WriteHeader(400) + return + } + + realm, err := rh.Db.LoadAuthRealm(realmID) + if err != nil { + log.WithError(err).WithField("realm_id", realmID).Print("Failed to load realm") + w.WriteHeader(404) + return + } + log.WithFields(log.Fields{ + "realm_id": realmID, + }).Print("Incoming realm redirect request") + realm.OnReceiveRedirect(w, req) +} + +type ConfigureAuthRealm struct { + Db *database.ServiceDB +} + +func (h *ConfigureAuthRealm) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { + if req.Method != "POST" { + return nil, &errors.HTTPError{nil, "Unsupported Method", 405} + } + var body api.ConfigureAuthRealmRequest + if err := json.NewDecoder(req.Body).Decode(&body); err != nil { + return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} + } + + if err := body.Check(); err != nil { + return nil, &errors.HTTPError{err, err.Error(), 400} + } + + realm, err := types.CreateAuthRealm(body.ID, body.Type, body.Config) + if err != nil { + return nil, &errors.HTTPError{err, "Error parsing config JSON", 400} + } + + if err = realm.Register(); err != nil { + return nil, &errors.HTTPError{err, "Error registering auth realm", 400} + } + + oldRealm, err := h.Db.StoreAuthRealm(realm) + if err != nil { + return nil, &errors.HTTPError{err, "Error storing realm", 500} + } + + return &struct { + ID string + Type string + OldConfig types.AuthRealm + NewConfig types.AuthRealm + }{body.ID, body.Type, oldRealm, realm}, nil +} + +type GetSession struct { + Db *database.ServiceDB +} + +func (h *GetSession) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { + if req.Method != "POST" { + return nil, &errors.HTTPError{nil, "Unsupported Method", 405} + } + var body struct { + RealmID string + UserID string + } + if err := json.NewDecoder(req.Body).Decode(&body); err != nil { + return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} + } + + if body.RealmID == "" || body.UserID == "" { + return nil, &errors.HTTPError{nil, `Must supply a "RealmID" and "UserID"`, 400} + } + + session, err := h.Db.LoadAuthSessionByUser(body.RealmID, body.UserID) + if err != nil && err != sql.ErrNoRows { + return nil, &errors.HTTPError{err, `Failed to load session`, 500} + } + if err == sql.ErrNoRows { + return &struct { + Authenticated bool + }{false}, nil + } + + return &struct { + ID string + Authenticated bool + Info interface{} + }{session.ID(), session.Authenticated(), session.Info()}, nil +} diff --git a/src/github.com/matrix-org/go-neb/api/handlers/client.go b/src/github.com/matrix-org/go-neb/api/handlers/client.go new file mode 100644 index 0000000..036fe2f --- /dev/null +++ b/src/github.com/matrix-org/go-neb/api/handlers/client.go @@ -0,0 +1,40 @@ +package handlers + +import ( + "encoding/json" + "net/http" + + "github.com/matrix-org/go-neb/api" + "github.com/matrix-org/go-neb/clients" + "github.com/matrix-org/go-neb/errors" +) + +// ConfigureClient represents an HTTP handler capable of processing /configureClient requests +type ConfigureClient struct { + Clients *clients.Clients +} + +func (s *ConfigureClient) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { + if req.Method != "POST" { + return nil, &errors.HTTPError{nil, "Unsupported Method", 405} + } + + var body api.ClientConfig + if err := json.NewDecoder(req.Body).Decode(&body); err != nil { + return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} + } + + if err := body.Check(); err != nil { + return nil, &errors.HTTPError{err, "Error parsing client config", 400} + } + + oldClient, err := s.Clients.Update(body) + if err != nil { + return nil, &errors.HTTPError{err, "Error storing token", 500} + } + + return &struct { + OldClient api.ClientConfig + NewClient api.ClientConfig + }{oldClient, body}, nil +} diff --git a/src/github.com/matrix-org/go-neb/api/handlers/heartbeat.go b/src/github.com/matrix-org/go-neb/api/handlers/heartbeat.go new file mode 100644 index 0000000..817311a --- /dev/null +++ b/src/github.com/matrix-org/go-neb/api/handlers/heartbeat.go @@ -0,0 +1,26 @@ +package handlers + +import ( + "github.com/matrix-org/go-neb/errors" + "net/http" +) + +// Heartbeat implements the heartbeat API +type Heartbeat struct{} + +// OnIncomingRequest returns an empty JSON object which can be used to detect liveness of Go-NEB. +// +// Request: +// ``` +// GET /test +// ``` +// +// Response: +// ``` +// HTTP/1.1 200 OK +// Content-Type: applicatoin/json +// {} +// ``` +func (*Heartbeat) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { + return &struct{}{}, nil +} diff --git a/src/github.com/matrix-org/go-neb/api/handlers/service.go b/src/github.com/matrix-org/go-neb/api/handlers/service.go new file mode 100644 index 0000000..d6478bd --- /dev/null +++ b/src/github.com/matrix-org/go-neb/api/handlers/service.go @@ -0,0 +1,158 @@ +package handlers + +import ( + "database/sql" + "encoding/json" + "net/http" + "sync" + + log "github.com/Sirupsen/logrus" + "github.com/matrix-org/go-neb/api" + "github.com/matrix-org/go-neb/clients" + "github.com/matrix-org/go-neb/database" + "github.com/matrix-org/go-neb/errors" + "github.com/matrix-org/go-neb/metrics" + "github.com/matrix-org/go-neb/polling" + "github.com/matrix-org/go-neb/types" +) + +type ConfigureService struct { + db *database.ServiceDB + clients *clients.Clients + mapMutex sync.Mutex + mutexByServiceID map[string]*sync.Mutex +} + +func NewConfigureService(db *database.ServiceDB, clients *clients.Clients) *ConfigureService { + return &ConfigureService{ + db: db, + clients: clients, + mutexByServiceID: make(map[string]*sync.Mutex), + } +} + +func (s *ConfigureService) getMutexForServiceID(serviceID string) *sync.Mutex { + s.mapMutex.Lock() + defer s.mapMutex.Unlock() + m := s.mutexByServiceID[serviceID] + if m == nil { + // XXX TODO: There's a memory leak here. The amount of mutexes created is unbounded, as there will be 1 per service which are never deleted. + // A better solution would be to have a striped hash map with a bounded pool of mutexes. We can't live with a single global mutex because the Register() + // function this is protecting does many many HTTP requests which can take a long time on bad networks and will head of line block other services. + m = &sync.Mutex{} + s.mutexByServiceID[serviceID] = m + } + return m +} + +func (s *ConfigureService) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { + if req.Method != "POST" { + return nil, &errors.HTTPError{nil, "Unsupported Method", 405} + } + + service, httpErr := s.createService(req) + if httpErr != nil { + return nil, httpErr + } + log.WithFields(log.Fields{ + "service_id": service.ServiceID(), + "service_type": service.ServiceType(), + "service_user_id": service.ServiceUserID(), + }).Print("Incoming configure service request") + + // Have mutexes around each service to queue up multiple requests for the same service ID + mut := s.getMutexForServiceID(service.ServiceID()) + mut.Lock() + defer mut.Unlock() + + old, err := s.db.LoadService(service.ServiceID()) + if err != nil && err != sql.ErrNoRows { + return nil, &errors.HTTPError{err, "Error loading old service", 500} + } + + client, err := s.clients.Client(service.ServiceUserID()) + if err != nil { + return nil, &errors.HTTPError{err, "Unknown matrix client", 400} + } + + if err = service.Register(old, client); err != nil { + return nil, &errors.HTTPError{err, "Failed to register service: " + err.Error(), 500} + } + + oldService, err := s.db.StoreService(service) + if err != nil { + return nil, &errors.HTTPError{err, "Error storing service", 500} + } + + // Start any polling NOW because they may decide to stop it in PostRegister, and we want to make + // sure we'll actually stop. + if _, ok := service.(types.Poller); ok { + if err := polling.StartPolling(service); err != nil { + log.WithFields(log.Fields{ + "service_id": service.ServiceID(), + log.ErrorKey: err, + }).Error("Failed to start poll loop.") + } + } + + service.PostRegister(old) + metrics.IncrementConfigureService(service.ServiceType()) + + return &struct { + ID string + Type string + OldConfig types.Service + NewConfig types.Service + }{service.ServiceID(), service.ServiceType(), oldService, service}, nil +} + +func (s *ConfigureService) createService(req *http.Request) (types.Service, *errors.HTTPError) { + var body api.ConfigureServiceRequest + if err := json.NewDecoder(req.Body).Decode(&body); err != nil { + return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} + } + + if err := body.Check(); err != nil { + return nil, &errors.HTTPError{err, err.Error(), 400} + } + + service, err := types.CreateService(body.ID, body.Type, body.UserID, body.Config) + if err != nil { + return nil, &errors.HTTPError{err, "Error parsing config JSON", 400} + } + return service, nil +} + +type GetService struct { + Db *database.ServiceDB +} + +func (h *GetService) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { + if req.Method != "POST" { + return nil, &errors.HTTPError{nil, "Unsupported Method", 405} + } + var body struct { + ID string + } + if err := json.NewDecoder(req.Body).Decode(&body); err != nil { + return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} + } + + if body.ID == "" { + return nil, &errors.HTTPError{nil, `Must supply a "ID"`, 400} + } + + srv, err := h.Db.LoadService(body.ID) + if err != nil { + if err == sql.ErrNoRows { + return nil, &errors.HTTPError{err, `Service not found`, 404} + } + return nil, &errors.HTTPError{err, `Failed to load service`, 500} + } + + return &struct { + ID string + Type string + Config types.Service + }{srv.ServiceID(), srv.ServiceType(), srv}, nil +} diff --git a/src/github.com/matrix-org/go-neb/api/handlers/webhook.go b/src/github.com/matrix-org/go-neb/api/handlers/webhook.go new file mode 100644 index 0000000..392a6ff --- /dev/null +++ b/src/github.com/matrix-org/go-neb/api/handlers/webhook.go @@ -0,0 +1,63 @@ +package handlers + +import ( + "encoding/base64" + log "github.com/Sirupsen/logrus" + "github.com/matrix-org/go-neb/clients" + "github.com/matrix-org/go-neb/database" + "github.com/matrix-org/go-neb/metrics" + "net/http" + "strings" +) + +// Webhook represents an HTTP handler capable of accepting webhook requests on behalf of services. +type Webhook struct { + db *database.ServiceDB + clients *clients.Clients +} + +// NewWebhook returns a new webhook HTTP handler +func NewWebhook(db *database.ServiceDB, cli *clients.Clients) *Webhook { + return &Webhook{db, cli} +} + +// Handle an incoming webhook HTTP request. +// +// The webhook MUST have a known base64 encoded service ID as the last path segment +// in order for this request to be passed to the correct service. +func (wh *Webhook) Handle(w http.ResponseWriter, req *http.Request) { + log.WithField("path", req.URL.Path).Print("Incoming webhook request") + segments := strings.Split(req.URL.Path, "/") + // last path segment is the service ID which we will pass the incoming request to, + // but we've base64d it. + base64srvID := segments[len(segments)-1] + bytesSrvID, err := base64.RawURLEncoding.DecodeString(base64srvID) + if err != nil { + log.WithError(err).WithField("base64_service_id", base64srvID).Print( + "Not a b64 encoded string", + ) + w.WriteHeader(400) + return + } + srvID := string(bytesSrvID) + + service, err := wh.db.LoadService(srvID) + if err != nil { + log.WithError(err).WithField("service_id", srvID).Print("Failed to load service") + w.WriteHeader(404) + return + } + cli, err := wh.clients.Client(service.ServiceUserID()) + if err != nil { + log.WithError(err).WithField("user_id", service.ServiceUserID()).Print( + "Failed to retrieve matrix client instance") + w.WriteHeader(500) + return + } + log.WithFields(log.Fields{ + "service_id": service.ServiceID(), + "service_type": service.ServiceType(), + }).Print("Incoming webhook for service") + metrics.IncrementWebhook(service.ServiceType()) + service.OnReceiveWebhook(w, req, cli) +} diff --git a/src/github.com/matrix-org/go-neb/goneb.go b/src/github.com/matrix-org/go-neb/goneb.go index eef917f..d881aca 100644 --- a/src/github.com/matrix-org/go-neb/goneb.go +++ b/src/github.com/matrix-org/go-neb/goneb.go @@ -3,9 +3,16 @@ package main import ( "encoding/json" "fmt" + "io/ioutil" + "net/http" + _ "net/http/pprof" + "os" + "path/filepath" + log "github.com/Sirupsen/logrus" "github.com/matrix-org/dugong" "github.com/matrix-org/go-neb/api" + "github.com/matrix-org/go-neb/api/handlers" "github.com/matrix-org/go-neb/clients" "github.com/matrix-org/go-neb/database" _ "github.com/matrix-org/go-neb/metrics" @@ -22,12 +29,7 @@ import ( "github.com/matrix-org/go-neb/types" _ "github.com/mattn/go-sqlite3" "github.com/prometheus/client_golang/prometheus" - "gopkg.in/yaml.v2" - "io/ioutil" - "net/http" - _ "net/http/pprof" - "os" - "path/filepath" + yaml "gopkg.in/yaml.v2" ) // loadFromConfig loads a config file and returns a ConfigFile @@ -171,11 +173,11 @@ func setup(e envVars, mux *http.ServeMux, matrixClient *http.Client) { // Handle non-admin paths for normal NEB functioning mux.Handle("/metrics", prometheus.Handler()) - mux.Handle("/test", prometheus.InstrumentHandler("test", server.MakeJSONAPI(&heartbeatHandler{}))) - wh := &webhookHandler{db: db, clients: clients} - mux.HandleFunc("/services/hooks/", prometheus.InstrumentHandlerFunc("webhookHandler", server.Protect(wh.handle))) - rh := &realmRedirectHandler{db: db} - mux.HandleFunc("/realms/redirects/", prometheus.InstrumentHandlerFunc("realmRedirectHandler", server.Protect(rh.handle))) + mux.Handle("/test", prometheus.InstrumentHandler("test", server.MakeJSONAPI(&handlers.Heartbeat{}))) + wh := handlers.NewWebhook(db, clients) + mux.HandleFunc("/services/hooks/", prometheus.InstrumentHandlerFunc("webhookHandler", server.Protect(wh.Handle))) + rh := &handlers.RealmRedirect{db} + mux.HandleFunc("/realms/redirects/", prometheus.InstrumentHandlerFunc("realmRedirectHandler", server.Protect(rh.Handle))) // Read exclusively from the config file if one was supplied. // Otherwise, add HTTP listeners for new Services/Sessions/Clients/etc. @@ -186,13 +188,13 @@ func setup(e envVars, mux *http.ServeMux, matrixClient *http.Client) { log.Info("Inserted ", len(cfg.Services), " services") } else { - mux.Handle("/admin/getService", prometheus.InstrumentHandler("getService", server.MakeJSONAPI(&getServiceHandler{db: db}))) - mux.Handle("/admin/getSession", prometheus.InstrumentHandler("getSession", server.MakeJSONAPI(&getSessionHandler{db: db}))) - mux.Handle("/admin/configureClient", prometheus.InstrumentHandler("configureClient", server.MakeJSONAPI(&configureClientHandler{db: db, clients: clients}))) - mux.Handle("/admin/configureService", prometheus.InstrumentHandler("configureService", server.MakeJSONAPI(newConfigureServiceHandler(db, clients)))) - mux.Handle("/admin/configureAuthRealm", prometheus.InstrumentHandler("configureAuthRealm", server.MakeJSONAPI(&configureAuthRealmHandler{db: db}))) - mux.Handle("/admin/requestAuthSession", prometheus.InstrumentHandler("requestAuthSession", server.MakeJSONAPI(&requestAuthSessionHandler{db: db}))) - mux.Handle("/admin/removeAuthSession", prometheus.InstrumentHandler("removeAuthSession", server.MakeJSONAPI(&removeAuthSessionHandler{db: db}))) + mux.Handle("/admin/getService", prometheus.InstrumentHandler("getService", server.MakeJSONAPI(&handlers.GetService{db}))) + mux.Handle("/admin/getSession", prometheus.InstrumentHandler("getSession", server.MakeJSONAPI(&handlers.GetSession{db}))) + mux.Handle("/admin/configureClient", prometheus.InstrumentHandler("configureClient", server.MakeJSONAPI(&handlers.ConfigureClient{clients}))) + mux.Handle("/admin/configureService", prometheus.InstrumentHandler("configureService", server.MakeJSONAPI(handlers.NewConfigureService(db, clients)))) + mux.Handle("/admin/configureAuthRealm", prometheus.InstrumentHandler("configureAuthRealm", server.MakeJSONAPI(&handlers.ConfigureAuthRealm{db}))) + mux.Handle("/admin/requestAuthSession", prometheus.InstrumentHandler("requestAuthSession", server.MakeJSONAPI(&handlers.RequestAuthSession{db}))) + mux.Handle("/admin/removeAuthSession", prometheus.InstrumentHandler("removeAuthSession", server.MakeJSONAPI(&handlers.RemoveAuthSession{db}))) } polling.SetClients(clients) if err := polling.Start(); err != nil { diff --git a/src/github.com/matrix-org/go-neb/handlers.go b/src/github.com/matrix-org/go-neb/handlers.go deleted file mode 100644 index 2cec684..0000000 --- a/src/github.com/matrix-org/go-neb/handlers.go +++ /dev/null @@ -1,419 +0,0 @@ -package main - -import ( - "database/sql" - "encoding/base64" - "encoding/json" - log "github.com/Sirupsen/logrus" - "github.com/matrix-org/go-neb/api" - "github.com/matrix-org/go-neb/clients" - "github.com/matrix-org/go-neb/database" - "github.com/matrix-org/go-neb/errors" - "github.com/matrix-org/go-neb/metrics" - "github.com/matrix-org/go-neb/polling" - "github.com/matrix-org/go-neb/types" - "net/http" - "strings" - "sync" -) - -type heartbeatHandler struct{} - -func (*heartbeatHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { - return &struct{}{}, nil -} - -type requestAuthSessionHandler struct { - db *database.ServiceDB -} - -func (h *requestAuthSessionHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { - if req.Method != "POST" { - return nil, &errors.HTTPError{nil, "Unsupported Method", 405} - } - var body struct { - RealmID string - UserID string - Config json.RawMessage - } - if err := json.NewDecoder(req.Body).Decode(&body); err != nil { - return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} - } - log.WithFields(log.Fields{ - "realm_id": body.RealmID, - "user_id": body.UserID, - }).Print("Incoming auth session request") - - if body.UserID == "" || body.RealmID == "" || body.Config == nil { - return nil, &errors.HTTPError{nil, `Must supply a "UserID", a "RealmID" and a "Config"`, 400} - } - - realm, err := h.db.LoadAuthRealm(body.RealmID) - if err != nil { - return nil, &errors.HTTPError{err, "Unknown RealmID", 400} - } - - response := realm.RequestAuthSession(body.UserID, body.Config) - if response == nil { - return nil, &errors.HTTPError{nil, "Failed to request auth session", 500} - } - - metrics.IncrementAuthSession(realm.Type()) - - return response, nil -} - -type removeAuthSessionHandler struct { - db *database.ServiceDB -} - -func (h *removeAuthSessionHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { - if req.Method != "POST" { - return nil, &errors.HTTPError{nil, "Unsupported Method", 405} - } - var body struct { - RealmID string - UserID string - } - if err := json.NewDecoder(req.Body).Decode(&body); err != nil { - return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} - } - log.WithFields(log.Fields{ - "realm_id": body.RealmID, - "user_id": body.UserID, - }).Print("Incoming remove auth session request") - - if body.UserID == "" || body.RealmID == "" { - return nil, &errors.HTTPError{nil, `Must supply a "UserID", a "RealmID"`, 400} - } - - _, err := h.db.LoadAuthRealm(body.RealmID) - if err != nil { - return nil, &errors.HTTPError{err, "Unknown RealmID", 400} - } - - if err := h.db.RemoveAuthSession(body.RealmID, body.UserID); err != nil { - return nil, &errors.HTTPError{err, "Failed to remove auth session", 500} - } - - return []byte(`{}`), nil -} - -type realmRedirectHandler struct { - db *database.ServiceDB -} - -func (rh *realmRedirectHandler) handle(w http.ResponseWriter, req *http.Request) { - segments := strings.Split(req.URL.Path, "/") - // last path segment is the base64d realm ID which we will pass the incoming request to - base64realmID := segments[len(segments)-1] - bytesRealmID, err := base64.RawURLEncoding.DecodeString(base64realmID) - realmID := string(bytesRealmID) - if err != nil { - log.WithError(err).WithField("base64_realm_id", base64realmID).Print( - "Not a b64 encoded string", - ) - w.WriteHeader(400) - return - } - - realm, err := rh.db.LoadAuthRealm(realmID) - if err != nil { - log.WithError(err).WithField("realm_id", realmID).Print("Failed to load realm") - w.WriteHeader(404) - return - } - log.WithFields(log.Fields{ - "realm_id": realmID, - }).Print("Incoming realm redirect request") - realm.OnReceiveRedirect(w, req) -} - -type configureAuthRealmHandler struct { - db *database.ServiceDB -} - -func (h *configureAuthRealmHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { - if req.Method != "POST" { - return nil, &errors.HTTPError{nil, "Unsupported Method", 405} - } - var body api.ConfigureAuthRealmRequest - if err := json.NewDecoder(req.Body).Decode(&body); err != nil { - return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} - } - - if err := body.Check(); err != nil { - return nil, &errors.HTTPError{err, err.Error(), 400} - } - - realm, err := types.CreateAuthRealm(body.ID, body.Type, body.Config) - if err != nil { - return nil, &errors.HTTPError{err, "Error parsing config JSON", 400} - } - - if err = realm.Register(); err != nil { - return nil, &errors.HTTPError{err, "Error registering auth realm", 400} - } - - oldRealm, err := h.db.StoreAuthRealm(realm) - if err != nil { - return nil, &errors.HTTPError{err, "Error storing realm", 500} - } - - return &struct { - ID string - Type string - OldConfig types.AuthRealm - NewConfig types.AuthRealm - }{body.ID, body.Type, oldRealm, realm}, nil -} - -type webhookHandler struct { - db *database.ServiceDB - clients *clients.Clients -} - -func (wh *webhookHandler) handle(w http.ResponseWriter, req *http.Request) { - log.WithField("path", req.URL.Path).Print("Incoming webhook request") - segments := strings.Split(req.URL.Path, "/") - // last path segment is the service ID which we will pass the incoming request to, - // but we've base64d it. - base64srvID := segments[len(segments)-1] - bytesSrvID, err := base64.RawURLEncoding.DecodeString(base64srvID) - if err != nil { - log.WithError(err).WithField("base64_service_id", base64srvID).Print( - "Not a b64 encoded string", - ) - w.WriteHeader(400) - return - } - srvID := string(bytesSrvID) - - service, err := wh.db.LoadService(srvID) - if err != nil { - log.WithError(err).WithField("service_id", srvID).Print("Failed to load service") - w.WriteHeader(404) - return - } - cli, err := wh.clients.Client(service.ServiceUserID()) - if err != nil { - log.WithError(err).WithField("user_id", service.ServiceUserID()).Print( - "Failed to retrieve matrix client instance") - w.WriteHeader(500) - return - } - log.WithFields(log.Fields{ - "service_id": service.ServiceID(), - "service_type": service.ServiceType(), - }).Print("Incoming webhook for service") - metrics.IncrementWebhook(service.ServiceType()) - service.OnReceiveWebhook(w, req, cli) -} - -type configureClientHandler struct { - db *database.ServiceDB - clients *clients.Clients -} - -func (s *configureClientHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { - if req.Method != "POST" { - return nil, &errors.HTTPError{nil, "Unsupported Method", 405} - } - - var body api.ClientConfig - if err := json.NewDecoder(req.Body).Decode(&body); err != nil { - return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} - } - - if err := body.Check(); err != nil { - return nil, &errors.HTTPError{err, "Error parsing client config", 400} - } - - oldClient, err := s.clients.Update(body) - if err != nil { - return nil, &errors.HTTPError{err, "Error storing token", 500} - } - - return &struct { - OldClient api.ClientConfig - NewClient api.ClientConfig - }{oldClient, body}, nil -} - -type configureServiceHandler struct { - db *database.ServiceDB - clients *clients.Clients - mapMutex sync.Mutex - mutexByServiceID map[string]*sync.Mutex -} - -func newConfigureServiceHandler(db *database.ServiceDB, clients *clients.Clients) *configureServiceHandler { - return &configureServiceHandler{ - db: db, - clients: clients, - mutexByServiceID: make(map[string]*sync.Mutex), - } -} - -func (s *configureServiceHandler) getMutexForServiceID(serviceID string) *sync.Mutex { - s.mapMutex.Lock() - defer s.mapMutex.Unlock() - m := s.mutexByServiceID[serviceID] - if m == nil { - // XXX TODO: There's a memory leak here. The amount of mutexes created is unbounded, as there will be 1 per service which are never deleted. - // A better solution would be to have a striped hash map with a bounded pool of mutexes. We can't live with a single global mutex because the Register() - // function this is protecting does many many HTTP requests which can take a long time on bad networks and will head of line block other services. - m = &sync.Mutex{} - s.mutexByServiceID[serviceID] = m - } - return m -} - -func (s *configureServiceHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { - if req.Method != "POST" { - return nil, &errors.HTTPError{nil, "Unsupported Method", 405} - } - - service, httpErr := s.createService(req) - if httpErr != nil { - return nil, httpErr - } - log.WithFields(log.Fields{ - "service_id": service.ServiceID(), - "service_type": service.ServiceType(), - "service_user_id": service.ServiceUserID(), - }).Print("Incoming configure service request") - - // Have mutexes around each service to queue up multiple requests for the same service ID - mut := s.getMutexForServiceID(service.ServiceID()) - mut.Lock() - defer mut.Unlock() - - old, err := s.db.LoadService(service.ServiceID()) - if err != nil && err != sql.ErrNoRows { - return nil, &errors.HTTPError{err, "Error loading old service", 500} - } - - client, err := s.clients.Client(service.ServiceUserID()) - if err != nil { - return nil, &errors.HTTPError{err, "Unknown matrix client", 400} - } - - if err = service.Register(old, client); err != nil { - return nil, &errors.HTTPError{err, "Failed to register service: " + err.Error(), 500} - } - - oldService, err := s.db.StoreService(service) - if err != nil { - return nil, &errors.HTTPError{err, "Error storing service", 500} - } - - // Start any polling NOW because they may decide to stop it in PostRegister, and we want to make - // sure we'll actually stop. - if _, ok := service.(types.Poller); ok { - if err := polling.StartPolling(service); err != nil { - log.WithFields(log.Fields{ - "service_id": service.ServiceID(), - log.ErrorKey: err, - }).Error("Failed to start poll loop.") - } - } - - service.PostRegister(old) - metrics.IncrementConfigureService(service.ServiceType()) - - return &struct { - ID string - Type string - OldConfig types.Service - NewConfig types.Service - }{service.ServiceID(), service.ServiceType(), oldService, service}, nil -} - -func (s *configureServiceHandler) createService(req *http.Request) (types.Service, *errors.HTTPError) { - var body api.ConfigureServiceRequest - if err := json.NewDecoder(req.Body).Decode(&body); err != nil { - return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} - } - - if err := body.Check(); err != nil { - return nil, &errors.HTTPError{err, err.Error(), 400} - } - - service, err := types.CreateService(body.ID, body.Type, body.UserID, body.Config) - if err != nil { - return nil, &errors.HTTPError{err, "Error parsing config JSON", 400} - } - return service, nil -} - -type getServiceHandler struct { - db *database.ServiceDB -} - -func (h *getServiceHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { - if req.Method != "POST" { - return nil, &errors.HTTPError{nil, "Unsupported Method", 405} - } - var body struct { - ID string - } - if err := json.NewDecoder(req.Body).Decode(&body); err != nil { - return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} - } - - if body.ID == "" { - return nil, &errors.HTTPError{nil, `Must supply a "ID"`, 400} - } - - srv, err := h.db.LoadService(body.ID) - if err != nil { - if err == sql.ErrNoRows { - return nil, &errors.HTTPError{err, `Service not found`, 404} - } - return nil, &errors.HTTPError{err, `Failed to load service`, 500} - } - - return &struct { - ID string - Type string - Config types.Service - }{srv.ServiceID(), srv.ServiceType(), srv}, nil -} - -type getSessionHandler struct { - db *database.ServiceDB -} - -func (h *getSessionHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { - if req.Method != "POST" { - return nil, &errors.HTTPError{nil, "Unsupported Method", 405} - } - var body struct { - RealmID string - UserID string - } - if err := json.NewDecoder(req.Body).Decode(&body); err != nil { - return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} - } - - if body.RealmID == "" || body.UserID == "" { - return nil, &errors.HTTPError{nil, `Must supply a "RealmID" and "UserID"`, 400} - } - - session, err := h.db.LoadAuthSessionByUser(body.RealmID, body.UserID) - if err != nil && err != sql.ErrNoRows { - return nil, &errors.HTTPError{err, `Failed to load session`, 500} - } - if err == sql.ErrNoRows { - return &struct { - Authenticated bool - }{false}, nil - } - - return &struct { - ID string - Authenticated bool - Info interface{} - }{session.ID(), session.Authenticated(), session.Info()}, nil -}