Browse Source

Split up the enormous handlers.go into more manageable files

Split them up based on the HTTP API they are implementing.
kegan/move-handlers
Kegan Dougal 8 years ago
parent
commit
4ab61e4a5f
  1. 198
      src/github.com/matrix-org/go-neb/api/handlers/auth.go
  2. 40
      src/github.com/matrix-org/go-neb/api/handlers/client.go
  3. 26
      src/github.com/matrix-org/go-neb/api/handlers/heartbeat.go
  4. 158
      src/github.com/matrix-org/go-neb/api/handlers/service.go
  5. 63
      src/github.com/matrix-org/go-neb/api/handlers/webhook.go
  6. 38
      src/github.com/matrix-org/go-neb/goneb.go
  7. 419
      src/github.com/matrix-org/go-neb/handlers.go

198
src/github.com/matrix-org/go-neb/api/handlers/auth.go

@ -0,0 +1,198 @@
package handlers
import (
"database/sql"
"encoding/base64"
"encoding/json"
"net/http"
"strings"
log "github.com/Sirupsen/logrus"
"github.com/matrix-org/go-neb/api"
"github.com/matrix-org/go-neb/database"
"github.com/matrix-org/go-neb/errors"
"github.com/matrix-org/go-neb/metrics"
"github.com/matrix-org/go-neb/types"
)
type RequestAuthSession struct {
Db *database.ServiceDB
}
func (h *RequestAuthSession) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
if req.Method != "POST" {
return nil, &errors.HTTPError{nil, "Unsupported Method", 405}
}
var body struct {
RealmID string
UserID string
Config json.RawMessage
}
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400}
}
log.WithFields(log.Fields{
"realm_id": body.RealmID,
"user_id": body.UserID,
}).Print("Incoming auth session request")
if body.UserID == "" || body.RealmID == "" || body.Config == nil {
return nil, &errors.HTTPError{nil, `Must supply a "UserID", a "RealmID" and a "Config"`, 400}
}
realm, err := h.Db.LoadAuthRealm(body.RealmID)
if err != nil {
return nil, &errors.HTTPError{err, "Unknown RealmID", 400}
}
response := realm.RequestAuthSession(body.UserID, body.Config)
if response == nil {
return nil, &errors.HTTPError{nil, "Failed to request auth session", 500}
}
metrics.IncrementAuthSession(realm.Type())
return response, nil
}
type RemoveAuthSession struct {
Db *database.ServiceDB
}
func (h *RemoveAuthSession) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
if req.Method != "POST" {
return nil, &errors.HTTPError{nil, "Unsupported Method", 405}
}
var body struct {
RealmID string
UserID string
}
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400}
}
log.WithFields(log.Fields{
"realm_id": body.RealmID,
"user_id": body.UserID,
}).Print("Incoming remove auth session request")
if body.UserID == "" || body.RealmID == "" {
return nil, &errors.HTTPError{nil, `Must supply a "UserID", a "RealmID"`, 400}
}
_, err := h.Db.LoadAuthRealm(body.RealmID)
if err != nil {
return nil, &errors.HTTPError{err, "Unknown RealmID", 400}
}
if err := h.Db.RemoveAuthSession(body.RealmID, body.UserID); err != nil {
return nil, &errors.HTTPError{err, "Failed to remove auth session", 500}
}
return []byte(`{}`), nil
}
type RealmRedirect struct {
Db *database.ServiceDB
}
func (rh *RealmRedirect) Handle(w http.ResponseWriter, req *http.Request) {
segments := strings.Split(req.URL.Path, "/")
// last path segment is the base64d realm ID which we will pass the incoming request to
base64realmID := segments[len(segments)-1]
bytesRealmID, err := base64.RawURLEncoding.DecodeString(base64realmID)
realmID := string(bytesRealmID)
if err != nil {
log.WithError(err).WithField("base64_realm_id", base64realmID).Print(
"Not a b64 encoded string",
)
w.WriteHeader(400)
return
}
realm, err := rh.Db.LoadAuthRealm(realmID)
if err != nil {
log.WithError(err).WithField("realm_id", realmID).Print("Failed to load realm")
w.WriteHeader(404)
return
}
log.WithFields(log.Fields{
"realm_id": realmID,
}).Print("Incoming realm redirect request")
realm.OnReceiveRedirect(w, req)
}
type ConfigureAuthRealm struct {
Db *database.ServiceDB
}
func (h *ConfigureAuthRealm) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
if req.Method != "POST" {
return nil, &errors.HTTPError{nil, "Unsupported Method", 405}
}
var body api.ConfigureAuthRealmRequest
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400}
}
if err := body.Check(); err != nil {
return nil, &errors.HTTPError{err, err.Error(), 400}
}
realm, err := types.CreateAuthRealm(body.ID, body.Type, body.Config)
if err != nil {
return nil, &errors.HTTPError{err, "Error parsing config JSON", 400}
}
if err = realm.Register(); err != nil {
return nil, &errors.HTTPError{err, "Error registering auth realm", 400}
}
oldRealm, err := h.Db.StoreAuthRealm(realm)
if err != nil {
return nil, &errors.HTTPError{err, "Error storing realm", 500}
}
return &struct {
ID string
Type string
OldConfig types.AuthRealm
NewConfig types.AuthRealm
}{body.ID, body.Type, oldRealm, realm}, nil
}
type GetSession struct {
Db *database.ServiceDB
}
func (h *GetSession) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
if req.Method != "POST" {
return nil, &errors.HTTPError{nil, "Unsupported Method", 405}
}
var body struct {
RealmID string
UserID string
}
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400}
}
if body.RealmID == "" || body.UserID == "" {
return nil, &errors.HTTPError{nil, `Must supply a "RealmID" and "UserID"`, 400}
}
session, err := h.Db.LoadAuthSessionByUser(body.RealmID, body.UserID)
if err != nil && err != sql.ErrNoRows {
return nil, &errors.HTTPError{err, `Failed to load session`, 500}
}
if err == sql.ErrNoRows {
return &struct {
Authenticated bool
}{false}, nil
}
return &struct {
ID string
Authenticated bool
Info interface{}
}{session.ID(), session.Authenticated(), session.Info()}, nil
}

40
src/github.com/matrix-org/go-neb/api/handlers/client.go

@ -0,0 +1,40 @@
package handlers
import (
"encoding/json"
"net/http"
"github.com/matrix-org/go-neb/api"
"github.com/matrix-org/go-neb/clients"
"github.com/matrix-org/go-neb/errors"
)
// ConfigureClient represents an HTTP handler capable of processing /configureClient requests
type ConfigureClient struct {
Clients *clients.Clients
}
func (s *ConfigureClient) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
if req.Method != "POST" {
return nil, &errors.HTTPError{nil, "Unsupported Method", 405}
}
var body api.ClientConfig
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400}
}
if err := body.Check(); err != nil {
return nil, &errors.HTTPError{err, "Error parsing client config", 400}
}
oldClient, err := s.Clients.Update(body)
if err != nil {
return nil, &errors.HTTPError{err, "Error storing token", 500}
}
return &struct {
OldClient api.ClientConfig
NewClient api.ClientConfig
}{oldClient, body}, nil
}

26
src/github.com/matrix-org/go-neb/api/handlers/heartbeat.go

@ -0,0 +1,26 @@
package handlers
import (
"github.com/matrix-org/go-neb/errors"
"net/http"
)
// Heartbeat implements the heartbeat API
type Heartbeat struct{}
// OnIncomingRequest returns an empty JSON object which can be used to detect liveness of Go-NEB.
//
// Request:
// ```
// GET /test
// ```
//
// Response:
// ```
// HTTP/1.1 200 OK
// Content-Type: applicatoin/json
// {}
// ```
func (*Heartbeat) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
return &struct{}{}, nil
}

158
src/github.com/matrix-org/go-neb/api/handlers/service.go

@ -0,0 +1,158 @@
package handlers
import (
"database/sql"
"encoding/json"
"net/http"
"sync"
log "github.com/Sirupsen/logrus"
"github.com/matrix-org/go-neb/api"
"github.com/matrix-org/go-neb/clients"
"github.com/matrix-org/go-neb/database"
"github.com/matrix-org/go-neb/errors"
"github.com/matrix-org/go-neb/metrics"
"github.com/matrix-org/go-neb/polling"
"github.com/matrix-org/go-neb/types"
)
type ConfigureService struct {
db *database.ServiceDB
clients *clients.Clients
mapMutex sync.Mutex
mutexByServiceID map[string]*sync.Mutex
}
func NewConfigureService(db *database.ServiceDB, clients *clients.Clients) *ConfigureService {
return &ConfigureService{
db: db,
clients: clients,
mutexByServiceID: make(map[string]*sync.Mutex),
}
}
func (s *ConfigureService) getMutexForServiceID(serviceID string) *sync.Mutex {
s.mapMutex.Lock()
defer s.mapMutex.Unlock()
m := s.mutexByServiceID[serviceID]
if m == nil {
// XXX TODO: There's a memory leak here. The amount of mutexes created is unbounded, as there will be 1 per service which are never deleted.
// A better solution would be to have a striped hash map with a bounded pool of mutexes. We can't live with a single global mutex because the Register()
// function this is protecting does many many HTTP requests which can take a long time on bad networks and will head of line block other services.
m = &sync.Mutex{}
s.mutexByServiceID[serviceID] = m
}
return m
}
func (s *ConfigureService) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
if req.Method != "POST" {
return nil, &errors.HTTPError{nil, "Unsupported Method", 405}
}
service, httpErr := s.createService(req)
if httpErr != nil {
return nil, httpErr
}
log.WithFields(log.Fields{
"service_id": service.ServiceID(),
"service_type": service.ServiceType(),
"service_user_id": service.ServiceUserID(),
}).Print("Incoming configure service request")
// Have mutexes around each service to queue up multiple requests for the same service ID
mut := s.getMutexForServiceID(service.ServiceID())
mut.Lock()
defer mut.Unlock()
old, err := s.db.LoadService(service.ServiceID())
if err != nil && err != sql.ErrNoRows {
return nil, &errors.HTTPError{err, "Error loading old service", 500}
}
client, err := s.clients.Client(service.ServiceUserID())
if err != nil {
return nil, &errors.HTTPError{err, "Unknown matrix client", 400}
}
if err = service.Register(old, client); err != nil {
return nil, &errors.HTTPError{err, "Failed to register service: " + err.Error(), 500}
}
oldService, err := s.db.StoreService(service)
if err != nil {
return nil, &errors.HTTPError{err, "Error storing service", 500}
}
// Start any polling NOW because they may decide to stop it in PostRegister, and we want to make
// sure we'll actually stop.
if _, ok := service.(types.Poller); ok {
if err := polling.StartPolling(service); err != nil {
log.WithFields(log.Fields{
"service_id": service.ServiceID(),
log.ErrorKey: err,
}).Error("Failed to start poll loop.")
}
}
service.PostRegister(old)
metrics.IncrementConfigureService(service.ServiceType())
return &struct {
ID string
Type string
OldConfig types.Service
NewConfig types.Service
}{service.ServiceID(), service.ServiceType(), oldService, service}, nil
}
func (s *ConfigureService) createService(req *http.Request) (types.Service, *errors.HTTPError) {
var body api.ConfigureServiceRequest
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400}
}
if err := body.Check(); err != nil {
return nil, &errors.HTTPError{err, err.Error(), 400}
}
service, err := types.CreateService(body.ID, body.Type, body.UserID, body.Config)
if err != nil {
return nil, &errors.HTTPError{err, "Error parsing config JSON", 400}
}
return service, nil
}
type GetService struct {
Db *database.ServiceDB
}
func (h *GetService) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
if req.Method != "POST" {
return nil, &errors.HTTPError{nil, "Unsupported Method", 405}
}
var body struct {
ID string
}
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400}
}
if body.ID == "" {
return nil, &errors.HTTPError{nil, `Must supply a "ID"`, 400}
}
srv, err := h.Db.LoadService(body.ID)
if err != nil {
if err == sql.ErrNoRows {
return nil, &errors.HTTPError{err, `Service not found`, 404}
}
return nil, &errors.HTTPError{err, `Failed to load service`, 500}
}
return &struct {
ID string
Type string
Config types.Service
}{srv.ServiceID(), srv.ServiceType(), srv}, nil
}

63
src/github.com/matrix-org/go-neb/api/handlers/webhook.go

@ -0,0 +1,63 @@
package handlers
import (
"encoding/base64"
log "github.com/Sirupsen/logrus"
"github.com/matrix-org/go-neb/clients"
"github.com/matrix-org/go-neb/database"
"github.com/matrix-org/go-neb/metrics"
"net/http"
"strings"
)
// Webhook represents an HTTP handler capable of accepting webhook requests on behalf of services.
type Webhook struct {
db *database.ServiceDB
clients *clients.Clients
}
// NewWebhook returns a new webhook HTTP handler
func NewWebhook(db *database.ServiceDB, cli *clients.Clients) *Webhook {
return &Webhook{db, cli}
}
// Handle an incoming webhook HTTP request.
//
// The webhook MUST have a known base64 encoded service ID as the last path segment
// in order for this request to be passed to the correct service.
func (wh *Webhook) Handle(w http.ResponseWriter, req *http.Request) {
log.WithField("path", req.URL.Path).Print("Incoming webhook request")
segments := strings.Split(req.URL.Path, "/")
// last path segment is the service ID which we will pass the incoming request to,
// but we've base64d it.
base64srvID := segments[len(segments)-1]
bytesSrvID, err := base64.RawURLEncoding.DecodeString(base64srvID)
if err != nil {
log.WithError(err).WithField("base64_service_id", base64srvID).Print(
"Not a b64 encoded string",
)
w.WriteHeader(400)
return
}
srvID := string(bytesSrvID)
service, err := wh.db.LoadService(srvID)
if err != nil {
log.WithError(err).WithField("service_id", srvID).Print("Failed to load service")
w.WriteHeader(404)
return
}
cli, err := wh.clients.Client(service.ServiceUserID())
if err != nil {
log.WithError(err).WithField("user_id", service.ServiceUserID()).Print(
"Failed to retrieve matrix client instance")
w.WriteHeader(500)
return
}
log.WithFields(log.Fields{
"service_id": service.ServiceID(),
"service_type": service.ServiceType(),
}).Print("Incoming webhook for service")
metrics.IncrementWebhook(service.ServiceType())
service.OnReceiveWebhook(w, req, cli)
}

38
src/github.com/matrix-org/go-neb/goneb.go

@ -3,9 +3,16 @@ package main
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
_ "net/http/pprof"
"os"
"path/filepath"
log "github.com/Sirupsen/logrus"
"github.com/matrix-org/dugong"
"github.com/matrix-org/go-neb/api"
"github.com/matrix-org/go-neb/api/handlers"
"github.com/matrix-org/go-neb/clients"
"github.com/matrix-org/go-neb/database"
_ "github.com/matrix-org/go-neb/metrics"
@ -22,12 +29,7 @@ import (
"github.com/matrix-org/go-neb/types"
_ "github.com/mattn/go-sqlite3"
"github.com/prometheus/client_golang/prometheus"
"gopkg.in/yaml.v2"
"io/ioutil"
"net/http"
_ "net/http/pprof"
"os"
"path/filepath"
yaml "gopkg.in/yaml.v2"
)
// loadFromConfig loads a config file and returns a ConfigFile
@ -171,11 +173,11 @@ func setup(e envVars, mux *http.ServeMux, matrixClient *http.Client) {
// Handle non-admin paths for normal NEB functioning
mux.Handle("/metrics", prometheus.Handler())
mux.Handle("/test", prometheus.InstrumentHandler("test", server.MakeJSONAPI(&heartbeatHandler{})))
wh := &webhookHandler{db: db, clients: clients}
mux.HandleFunc("/services/hooks/", prometheus.InstrumentHandlerFunc("webhookHandler", server.Protect(wh.handle)))
rh := &realmRedirectHandler{db: db}
mux.HandleFunc("/realms/redirects/", prometheus.InstrumentHandlerFunc("realmRedirectHandler", server.Protect(rh.handle)))
mux.Handle("/test", prometheus.InstrumentHandler("test", server.MakeJSONAPI(&handlers.Heartbeat{})))
wh := handlers.NewWebhook(db, clients)
mux.HandleFunc("/services/hooks/", prometheus.InstrumentHandlerFunc("webhookHandler", server.Protect(wh.Handle)))
rh := &handlers.RealmRedirect{db}
mux.HandleFunc("/realms/redirects/", prometheus.InstrumentHandlerFunc("realmRedirectHandler", server.Protect(rh.Handle)))
// Read exclusively from the config file if one was supplied.
// Otherwise, add HTTP listeners for new Services/Sessions/Clients/etc.
@ -186,13 +188,13 @@ func setup(e envVars, mux *http.ServeMux, matrixClient *http.Client) {
log.Info("Inserted ", len(cfg.Services), " services")
} else {
mux.Handle("/admin/getService", prometheus.InstrumentHandler("getService", server.MakeJSONAPI(&getServiceHandler{db: db})))
mux.Handle("/admin/getSession", prometheus.InstrumentHandler("getSession", server.MakeJSONAPI(&getSessionHandler{db: db})))
mux.Handle("/admin/configureClient", prometheus.InstrumentHandler("configureClient", server.MakeJSONAPI(&configureClientHandler{db: db, clients: clients})))
mux.Handle("/admin/configureService", prometheus.InstrumentHandler("configureService", server.MakeJSONAPI(newConfigureServiceHandler(db, clients))))
mux.Handle("/admin/configureAuthRealm", prometheus.InstrumentHandler("configureAuthRealm", server.MakeJSONAPI(&configureAuthRealmHandler{db: db})))
mux.Handle("/admin/requestAuthSession", prometheus.InstrumentHandler("requestAuthSession", server.MakeJSONAPI(&requestAuthSessionHandler{db: db})))
mux.Handle("/admin/removeAuthSession", prometheus.InstrumentHandler("removeAuthSession", server.MakeJSONAPI(&removeAuthSessionHandler{db: db})))
mux.Handle("/admin/getService", prometheus.InstrumentHandler("getService", server.MakeJSONAPI(&handlers.GetService{db})))
mux.Handle("/admin/getSession", prometheus.InstrumentHandler("getSession", server.MakeJSONAPI(&handlers.GetSession{db})))
mux.Handle("/admin/configureClient", prometheus.InstrumentHandler("configureClient", server.MakeJSONAPI(&handlers.ConfigureClient{clients})))
mux.Handle("/admin/configureService", prometheus.InstrumentHandler("configureService", server.MakeJSONAPI(handlers.NewConfigureService(db, clients))))
mux.Handle("/admin/configureAuthRealm", prometheus.InstrumentHandler("configureAuthRealm", server.MakeJSONAPI(&handlers.ConfigureAuthRealm{db})))
mux.Handle("/admin/requestAuthSession", prometheus.InstrumentHandler("requestAuthSession", server.MakeJSONAPI(&handlers.RequestAuthSession{db})))
mux.Handle("/admin/removeAuthSession", prometheus.InstrumentHandler("removeAuthSession", server.MakeJSONAPI(&handlers.RemoveAuthSession{db})))
}
polling.SetClients(clients)
if err := polling.Start(); err != nil {

419
src/github.com/matrix-org/go-neb/handlers.go

@ -1,419 +0,0 @@
package main
import (
"database/sql"
"encoding/base64"
"encoding/json"
log "github.com/Sirupsen/logrus"
"github.com/matrix-org/go-neb/api"
"github.com/matrix-org/go-neb/clients"
"github.com/matrix-org/go-neb/database"
"github.com/matrix-org/go-neb/errors"
"github.com/matrix-org/go-neb/metrics"
"github.com/matrix-org/go-neb/polling"
"github.com/matrix-org/go-neb/types"
"net/http"
"strings"
"sync"
)
type heartbeatHandler struct{}
func (*heartbeatHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
return &struct{}{}, nil
}
type requestAuthSessionHandler struct {
db *database.ServiceDB
}
func (h *requestAuthSessionHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
if req.Method != "POST" {
return nil, &errors.HTTPError{nil, "Unsupported Method", 405}
}
var body struct {
RealmID string
UserID string
Config json.RawMessage
}
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400}
}
log.WithFields(log.Fields{
"realm_id": body.RealmID,
"user_id": body.UserID,
}).Print("Incoming auth session request")
if body.UserID == "" || body.RealmID == "" || body.Config == nil {
return nil, &errors.HTTPError{nil, `Must supply a "UserID", a "RealmID" and a "Config"`, 400}
}
realm, err := h.db.LoadAuthRealm(body.RealmID)
if err != nil {
return nil, &errors.HTTPError{err, "Unknown RealmID", 400}
}
response := realm.RequestAuthSession(body.UserID, body.Config)
if response == nil {
return nil, &errors.HTTPError{nil, "Failed to request auth session", 500}
}
metrics.IncrementAuthSession(realm.Type())
return response, nil
}
type removeAuthSessionHandler struct {
db *database.ServiceDB
}
func (h *removeAuthSessionHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
if req.Method != "POST" {
return nil, &errors.HTTPError{nil, "Unsupported Method", 405}
}
var body struct {
RealmID string
UserID string
}
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400}
}
log.WithFields(log.Fields{
"realm_id": body.RealmID,
"user_id": body.UserID,
}).Print("Incoming remove auth session request")
if body.UserID == "" || body.RealmID == "" {
return nil, &errors.HTTPError{nil, `Must supply a "UserID", a "RealmID"`, 400}
}
_, err := h.db.LoadAuthRealm(body.RealmID)
if err != nil {
return nil, &errors.HTTPError{err, "Unknown RealmID", 400}
}
if err := h.db.RemoveAuthSession(body.RealmID, body.UserID); err != nil {
return nil, &errors.HTTPError{err, "Failed to remove auth session", 500}
}
return []byte(`{}`), nil
}
type realmRedirectHandler struct {
db *database.ServiceDB
}
func (rh *realmRedirectHandler) handle(w http.ResponseWriter, req *http.Request) {
segments := strings.Split(req.URL.Path, "/")
// last path segment is the base64d realm ID which we will pass the incoming request to
base64realmID := segments[len(segments)-1]
bytesRealmID, err := base64.RawURLEncoding.DecodeString(base64realmID)
realmID := string(bytesRealmID)
if err != nil {
log.WithError(err).WithField("base64_realm_id", base64realmID).Print(
"Not a b64 encoded string",
)
w.WriteHeader(400)
return
}
realm, err := rh.db.LoadAuthRealm(realmID)
if err != nil {
log.WithError(err).WithField("realm_id", realmID).Print("Failed to load realm")
w.WriteHeader(404)
return
}
log.WithFields(log.Fields{
"realm_id": realmID,
}).Print("Incoming realm redirect request")
realm.OnReceiveRedirect(w, req)
}
type configureAuthRealmHandler struct {
db *database.ServiceDB
}
func (h *configureAuthRealmHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
if req.Method != "POST" {
return nil, &errors.HTTPError{nil, "Unsupported Method", 405}
}
var body api.ConfigureAuthRealmRequest
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400}
}
if err := body.Check(); err != nil {
return nil, &errors.HTTPError{err, err.Error(), 400}
}
realm, err := types.CreateAuthRealm(body.ID, body.Type, body.Config)
if err != nil {
return nil, &errors.HTTPError{err, "Error parsing config JSON", 400}
}
if err = realm.Register(); err != nil {
return nil, &errors.HTTPError{err, "Error registering auth realm", 400}
}
oldRealm, err := h.db.StoreAuthRealm(realm)
if err != nil {
return nil, &errors.HTTPError{err, "Error storing realm", 500}
}
return &struct {
ID string
Type string
OldConfig types.AuthRealm
NewConfig types.AuthRealm
}{body.ID, body.Type, oldRealm, realm}, nil
}
type webhookHandler struct {
db *database.ServiceDB
clients *clients.Clients
}
func (wh *webhookHandler) handle(w http.ResponseWriter, req *http.Request) {
log.WithField("path", req.URL.Path).Print("Incoming webhook request")
segments := strings.Split(req.URL.Path, "/")
// last path segment is the service ID which we will pass the incoming request to,
// but we've base64d it.
base64srvID := segments[len(segments)-1]
bytesSrvID, err := base64.RawURLEncoding.DecodeString(base64srvID)
if err != nil {
log.WithError(err).WithField("base64_service_id", base64srvID).Print(
"Not a b64 encoded string",
)
w.WriteHeader(400)
return
}
srvID := string(bytesSrvID)
service, err := wh.db.LoadService(srvID)
if err != nil {
log.WithError(err).WithField("service_id", srvID).Print("Failed to load service")
w.WriteHeader(404)
return
}
cli, err := wh.clients.Client(service.ServiceUserID())
if err != nil {
log.WithError(err).WithField("user_id", service.ServiceUserID()).Print(
"Failed to retrieve matrix client instance")
w.WriteHeader(500)
return
}
log.WithFields(log.Fields{
"service_id": service.ServiceID(),
"service_type": service.ServiceType(),
}).Print("Incoming webhook for service")
metrics.IncrementWebhook(service.ServiceType())
service.OnReceiveWebhook(w, req, cli)
}
type configureClientHandler struct {
db *database.ServiceDB
clients *clients.Clients
}
func (s *configureClientHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
if req.Method != "POST" {
return nil, &errors.HTTPError{nil, "Unsupported Method", 405}
}
var body api.ClientConfig
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400}
}
if err := body.Check(); err != nil {
return nil, &errors.HTTPError{err, "Error parsing client config", 400}
}
oldClient, err := s.clients.Update(body)
if err != nil {
return nil, &errors.HTTPError{err, "Error storing token", 500}
}
return &struct {
OldClient api.ClientConfig
NewClient api.ClientConfig
}{oldClient, body}, nil
}
type configureServiceHandler struct {
db *database.ServiceDB
clients *clients.Clients
mapMutex sync.Mutex
mutexByServiceID map[string]*sync.Mutex
}
func newConfigureServiceHandler(db *database.ServiceDB, clients *clients.Clients) *configureServiceHandler {
return &configureServiceHandler{
db: db,
clients: clients,
mutexByServiceID: make(map[string]*sync.Mutex),
}
}
func (s *configureServiceHandler) getMutexForServiceID(serviceID string) *sync.Mutex {
s.mapMutex.Lock()
defer s.mapMutex.Unlock()
m := s.mutexByServiceID[serviceID]
if m == nil {
// XXX TODO: There's a memory leak here. The amount of mutexes created is unbounded, as there will be 1 per service which are never deleted.
// A better solution would be to have a striped hash map with a bounded pool of mutexes. We can't live with a single global mutex because the Register()
// function this is protecting does many many HTTP requests which can take a long time on bad networks and will head of line block other services.
m = &sync.Mutex{}
s.mutexByServiceID[serviceID] = m
}
return m
}
func (s *configureServiceHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
if req.Method != "POST" {
return nil, &errors.HTTPError{nil, "Unsupported Method", 405}
}
service, httpErr := s.createService(req)
if httpErr != nil {
return nil, httpErr
}
log.WithFields(log.Fields{
"service_id": service.ServiceID(),
"service_type": service.ServiceType(),
"service_user_id": service.ServiceUserID(),
}).Print("Incoming configure service request")
// Have mutexes around each service to queue up multiple requests for the same service ID
mut := s.getMutexForServiceID(service.ServiceID())
mut.Lock()
defer mut.Unlock()
old, err := s.db.LoadService(service.ServiceID())
if err != nil && err != sql.ErrNoRows {
return nil, &errors.HTTPError{err, "Error loading old service", 500}
}
client, err := s.clients.Client(service.ServiceUserID())
if err != nil {
return nil, &errors.HTTPError{err, "Unknown matrix client", 400}
}
if err = service.Register(old, client); err != nil {
return nil, &errors.HTTPError{err, "Failed to register service: " + err.Error(), 500}
}
oldService, err := s.db.StoreService(service)
if err != nil {
return nil, &errors.HTTPError{err, "Error storing service", 500}
}
// Start any polling NOW because they may decide to stop it in PostRegister, and we want to make
// sure we'll actually stop.
if _, ok := service.(types.Poller); ok {
if err := polling.StartPolling(service); err != nil {
log.WithFields(log.Fields{
"service_id": service.ServiceID(),
log.ErrorKey: err,
}).Error("Failed to start poll loop.")
}
}
service.PostRegister(old)
metrics.IncrementConfigureService(service.ServiceType())
return &struct {
ID string
Type string
OldConfig types.Service
NewConfig types.Service
}{service.ServiceID(), service.ServiceType(), oldService, service}, nil
}
func (s *configureServiceHandler) createService(req *http.Request) (types.Service, *errors.HTTPError) {
var body api.ConfigureServiceRequest
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400}
}
if err := body.Check(); err != nil {
return nil, &errors.HTTPError{err, err.Error(), 400}
}
service, err := types.CreateService(body.ID, body.Type, body.UserID, body.Config)
if err != nil {
return nil, &errors.HTTPError{err, "Error parsing config JSON", 400}
}
return service, nil
}
type getServiceHandler struct {
db *database.ServiceDB
}
func (h *getServiceHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
if req.Method != "POST" {
return nil, &errors.HTTPError{nil, "Unsupported Method", 405}
}
var body struct {
ID string
}
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400}
}
if body.ID == "" {
return nil, &errors.HTTPError{nil, `Must supply a "ID"`, 400}
}
srv, err := h.db.LoadService(body.ID)
if err != nil {
if err == sql.ErrNoRows {
return nil, &errors.HTTPError{err, `Service not found`, 404}
}
return nil, &errors.HTTPError{err, `Failed to load service`, 500}
}
return &struct {
ID string
Type string
Config types.Service
}{srv.ServiceID(), srv.ServiceType(), srv}, nil
}
type getSessionHandler struct {
db *database.ServiceDB
}
func (h *getSessionHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
if req.Method != "POST" {
return nil, &errors.HTTPError{nil, "Unsupported Method", 405}
}
var body struct {
RealmID string
UserID string
}
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400}
}
if body.RealmID == "" || body.UserID == "" {
return nil, &errors.HTTPError{nil, `Must supply a "RealmID" and "UserID"`, 400}
}
session, err := h.db.LoadAuthSessionByUser(body.RealmID, body.UserID)
if err != nil && err != sql.ErrNoRows {
return nil, &errors.HTTPError{err, `Failed to load session`, 500}
}
if err == sql.ErrNoRows {
return &struct {
Authenticated bool
}{false}, nil
}
return &struct {
ID string
Authenticated bool
Info interface{}
}{session.ID(), session.Authenticated(), session.Info()}, nil
}
Loading…
Cancel
Save