mirror of https://github.com/matrix-org/go-neb.git
Browse Source
Split up the enormous handlers.go into more manageable files
Split up the enormous handlers.go into more manageable files
Split them up based on the HTTP API they are implementing.kegan/move-handlers
Kegan Dougal
8 years ago
7 changed files with 505 additions and 437 deletions
-
198src/github.com/matrix-org/go-neb/api/handlers/auth.go
-
40src/github.com/matrix-org/go-neb/api/handlers/client.go
-
26src/github.com/matrix-org/go-neb/api/handlers/heartbeat.go
-
158src/github.com/matrix-org/go-neb/api/handlers/service.go
-
63src/github.com/matrix-org/go-neb/api/handlers/webhook.go
-
38src/github.com/matrix-org/go-neb/goneb.go
-
419src/github.com/matrix-org/go-neb/handlers.go
@ -0,0 +1,198 @@ |
|||
package handlers |
|||
|
|||
import ( |
|||
"database/sql" |
|||
"encoding/base64" |
|||
"encoding/json" |
|||
"net/http" |
|||
"strings" |
|||
|
|||
log "github.com/Sirupsen/logrus" |
|||
"github.com/matrix-org/go-neb/api" |
|||
"github.com/matrix-org/go-neb/database" |
|||
"github.com/matrix-org/go-neb/errors" |
|||
"github.com/matrix-org/go-neb/metrics" |
|||
"github.com/matrix-org/go-neb/types" |
|||
) |
|||
|
|||
type RequestAuthSession struct { |
|||
Db *database.ServiceDB |
|||
} |
|||
|
|||
func (h *RequestAuthSession) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
if req.Method != "POST" { |
|||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|||
} |
|||
var body struct { |
|||
RealmID string |
|||
UserID string |
|||
Config json.RawMessage |
|||
} |
|||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|||
} |
|||
log.WithFields(log.Fields{ |
|||
"realm_id": body.RealmID, |
|||
"user_id": body.UserID, |
|||
}).Print("Incoming auth session request") |
|||
|
|||
if body.UserID == "" || body.RealmID == "" || body.Config == nil { |
|||
return nil, &errors.HTTPError{nil, `Must supply a "UserID", a "RealmID" and a "Config"`, 400} |
|||
} |
|||
|
|||
realm, err := h.Db.LoadAuthRealm(body.RealmID) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Unknown RealmID", 400} |
|||
} |
|||
|
|||
response := realm.RequestAuthSession(body.UserID, body.Config) |
|||
if response == nil { |
|||
return nil, &errors.HTTPError{nil, "Failed to request auth session", 500} |
|||
} |
|||
|
|||
metrics.IncrementAuthSession(realm.Type()) |
|||
|
|||
return response, nil |
|||
} |
|||
|
|||
type RemoveAuthSession struct { |
|||
Db *database.ServiceDB |
|||
} |
|||
|
|||
func (h *RemoveAuthSession) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
if req.Method != "POST" { |
|||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|||
} |
|||
var body struct { |
|||
RealmID string |
|||
UserID string |
|||
} |
|||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|||
} |
|||
log.WithFields(log.Fields{ |
|||
"realm_id": body.RealmID, |
|||
"user_id": body.UserID, |
|||
}).Print("Incoming remove auth session request") |
|||
|
|||
if body.UserID == "" || body.RealmID == "" { |
|||
return nil, &errors.HTTPError{nil, `Must supply a "UserID", a "RealmID"`, 400} |
|||
} |
|||
|
|||
_, err := h.Db.LoadAuthRealm(body.RealmID) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Unknown RealmID", 400} |
|||
} |
|||
|
|||
if err := h.Db.RemoveAuthSession(body.RealmID, body.UserID); err != nil { |
|||
return nil, &errors.HTTPError{err, "Failed to remove auth session", 500} |
|||
} |
|||
|
|||
return []byte(`{}`), nil |
|||
} |
|||
|
|||
type RealmRedirect struct { |
|||
Db *database.ServiceDB |
|||
} |
|||
|
|||
func (rh *RealmRedirect) Handle(w http.ResponseWriter, req *http.Request) { |
|||
segments := strings.Split(req.URL.Path, "/") |
|||
// last path segment is the base64d realm ID which we will pass the incoming request to
|
|||
base64realmID := segments[len(segments)-1] |
|||
bytesRealmID, err := base64.RawURLEncoding.DecodeString(base64realmID) |
|||
realmID := string(bytesRealmID) |
|||
if err != nil { |
|||
log.WithError(err).WithField("base64_realm_id", base64realmID).Print( |
|||
"Not a b64 encoded string", |
|||
) |
|||
w.WriteHeader(400) |
|||
return |
|||
} |
|||
|
|||
realm, err := rh.Db.LoadAuthRealm(realmID) |
|||
if err != nil { |
|||
log.WithError(err).WithField("realm_id", realmID).Print("Failed to load realm") |
|||
w.WriteHeader(404) |
|||
return |
|||
} |
|||
log.WithFields(log.Fields{ |
|||
"realm_id": realmID, |
|||
}).Print("Incoming realm redirect request") |
|||
realm.OnReceiveRedirect(w, req) |
|||
} |
|||
|
|||
type ConfigureAuthRealm struct { |
|||
Db *database.ServiceDB |
|||
} |
|||
|
|||
func (h *ConfigureAuthRealm) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
if req.Method != "POST" { |
|||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|||
} |
|||
var body api.ConfigureAuthRealmRequest |
|||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|||
} |
|||
|
|||
if err := body.Check(); err != nil { |
|||
return nil, &errors.HTTPError{err, err.Error(), 400} |
|||
} |
|||
|
|||
realm, err := types.CreateAuthRealm(body.ID, body.Type, body.Config) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing config JSON", 400} |
|||
} |
|||
|
|||
if err = realm.Register(); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error registering auth realm", 400} |
|||
} |
|||
|
|||
oldRealm, err := h.Db.StoreAuthRealm(realm) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Error storing realm", 500} |
|||
} |
|||
|
|||
return &struct { |
|||
ID string |
|||
Type string |
|||
OldConfig types.AuthRealm |
|||
NewConfig types.AuthRealm |
|||
}{body.ID, body.Type, oldRealm, realm}, nil |
|||
} |
|||
|
|||
type GetSession struct { |
|||
Db *database.ServiceDB |
|||
} |
|||
|
|||
func (h *GetSession) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
if req.Method != "POST" { |
|||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|||
} |
|||
var body struct { |
|||
RealmID string |
|||
UserID string |
|||
} |
|||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|||
} |
|||
|
|||
if body.RealmID == "" || body.UserID == "" { |
|||
return nil, &errors.HTTPError{nil, `Must supply a "RealmID" and "UserID"`, 400} |
|||
} |
|||
|
|||
session, err := h.Db.LoadAuthSessionByUser(body.RealmID, body.UserID) |
|||
if err != nil && err != sql.ErrNoRows { |
|||
return nil, &errors.HTTPError{err, `Failed to load session`, 500} |
|||
} |
|||
if err == sql.ErrNoRows { |
|||
return &struct { |
|||
Authenticated bool |
|||
}{false}, nil |
|||
} |
|||
|
|||
return &struct { |
|||
ID string |
|||
Authenticated bool |
|||
Info interface{} |
|||
}{session.ID(), session.Authenticated(), session.Info()}, nil |
|||
} |
@ -0,0 +1,40 @@ |
|||
package handlers |
|||
|
|||
import ( |
|||
"encoding/json" |
|||
"net/http" |
|||
|
|||
"github.com/matrix-org/go-neb/api" |
|||
"github.com/matrix-org/go-neb/clients" |
|||
"github.com/matrix-org/go-neb/errors" |
|||
) |
|||
|
|||
// ConfigureClient represents an HTTP handler capable of processing /configureClient requests
|
|||
type ConfigureClient struct { |
|||
Clients *clients.Clients |
|||
} |
|||
|
|||
func (s *ConfigureClient) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
if req.Method != "POST" { |
|||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|||
} |
|||
|
|||
var body api.ClientConfig |
|||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|||
} |
|||
|
|||
if err := body.Check(); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing client config", 400} |
|||
} |
|||
|
|||
oldClient, err := s.Clients.Update(body) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Error storing token", 500} |
|||
} |
|||
|
|||
return &struct { |
|||
OldClient api.ClientConfig |
|||
NewClient api.ClientConfig |
|||
}{oldClient, body}, nil |
|||
} |
@ -0,0 +1,26 @@ |
|||
package handlers |
|||
|
|||
import ( |
|||
"github.com/matrix-org/go-neb/errors" |
|||
"net/http" |
|||
) |
|||
|
|||
// Heartbeat implements the heartbeat API
|
|||
type Heartbeat struct{} |
|||
|
|||
// OnIncomingRequest returns an empty JSON object which can be used to detect liveness of Go-NEB.
|
|||
//
|
|||
// Request:
|
|||
// ```
|
|||
// GET /test
|
|||
// ```
|
|||
//
|
|||
// Response:
|
|||
// ```
|
|||
// HTTP/1.1 200 OK
|
|||
// Content-Type: applicatoin/json
|
|||
// {}
|
|||
// ```
|
|||
func (*Heartbeat) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
return &struct{}{}, nil |
|||
} |
@ -0,0 +1,158 @@ |
|||
package handlers |
|||
|
|||
import ( |
|||
"database/sql" |
|||
"encoding/json" |
|||
"net/http" |
|||
"sync" |
|||
|
|||
log "github.com/Sirupsen/logrus" |
|||
"github.com/matrix-org/go-neb/api" |
|||
"github.com/matrix-org/go-neb/clients" |
|||
"github.com/matrix-org/go-neb/database" |
|||
"github.com/matrix-org/go-neb/errors" |
|||
"github.com/matrix-org/go-neb/metrics" |
|||
"github.com/matrix-org/go-neb/polling" |
|||
"github.com/matrix-org/go-neb/types" |
|||
) |
|||
|
|||
type ConfigureService struct { |
|||
db *database.ServiceDB |
|||
clients *clients.Clients |
|||
mapMutex sync.Mutex |
|||
mutexByServiceID map[string]*sync.Mutex |
|||
} |
|||
|
|||
func NewConfigureService(db *database.ServiceDB, clients *clients.Clients) *ConfigureService { |
|||
return &ConfigureService{ |
|||
db: db, |
|||
clients: clients, |
|||
mutexByServiceID: make(map[string]*sync.Mutex), |
|||
} |
|||
} |
|||
|
|||
func (s *ConfigureService) getMutexForServiceID(serviceID string) *sync.Mutex { |
|||
s.mapMutex.Lock() |
|||
defer s.mapMutex.Unlock() |
|||
m := s.mutexByServiceID[serviceID] |
|||
if m == nil { |
|||
// XXX TODO: There's a memory leak here. The amount of mutexes created is unbounded, as there will be 1 per service which are never deleted.
|
|||
// A better solution would be to have a striped hash map with a bounded pool of mutexes. We can't live with a single global mutex because the Register()
|
|||
// function this is protecting does many many HTTP requests which can take a long time on bad networks and will head of line block other services.
|
|||
m = &sync.Mutex{} |
|||
s.mutexByServiceID[serviceID] = m |
|||
} |
|||
return m |
|||
} |
|||
|
|||
func (s *ConfigureService) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
if req.Method != "POST" { |
|||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|||
} |
|||
|
|||
service, httpErr := s.createService(req) |
|||
if httpErr != nil { |
|||
return nil, httpErr |
|||
} |
|||
log.WithFields(log.Fields{ |
|||
"service_id": service.ServiceID(), |
|||
"service_type": service.ServiceType(), |
|||
"service_user_id": service.ServiceUserID(), |
|||
}).Print("Incoming configure service request") |
|||
|
|||
// Have mutexes around each service to queue up multiple requests for the same service ID
|
|||
mut := s.getMutexForServiceID(service.ServiceID()) |
|||
mut.Lock() |
|||
defer mut.Unlock() |
|||
|
|||
old, err := s.db.LoadService(service.ServiceID()) |
|||
if err != nil && err != sql.ErrNoRows { |
|||
return nil, &errors.HTTPError{err, "Error loading old service", 500} |
|||
} |
|||
|
|||
client, err := s.clients.Client(service.ServiceUserID()) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Unknown matrix client", 400} |
|||
} |
|||
|
|||
if err = service.Register(old, client); err != nil { |
|||
return nil, &errors.HTTPError{err, "Failed to register service: " + err.Error(), 500} |
|||
} |
|||
|
|||
oldService, err := s.db.StoreService(service) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Error storing service", 500} |
|||
} |
|||
|
|||
// Start any polling NOW because they may decide to stop it in PostRegister, and we want to make
|
|||
// sure we'll actually stop.
|
|||
if _, ok := service.(types.Poller); ok { |
|||
if err := polling.StartPolling(service); err != nil { |
|||
log.WithFields(log.Fields{ |
|||
"service_id": service.ServiceID(), |
|||
log.ErrorKey: err, |
|||
}).Error("Failed to start poll loop.") |
|||
} |
|||
} |
|||
|
|||
service.PostRegister(old) |
|||
metrics.IncrementConfigureService(service.ServiceType()) |
|||
|
|||
return &struct { |
|||
ID string |
|||
Type string |
|||
OldConfig types.Service |
|||
NewConfig types.Service |
|||
}{service.ServiceID(), service.ServiceType(), oldService, service}, nil |
|||
} |
|||
|
|||
func (s *ConfigureService) createService(req *http.Request) (types.Service, *errors.HTTPError) { |
|||
var body api.ConfigureServiceRequest |
|||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|||
} |
|||
|
|||
if err := body.Check(); err != nil { |
|||
return nil, &errors.HTTPError{err, err.Error(), 400} |
|||
} |
|||
|
|||
service, err := types.CreateService(body.ID, body.Type, body.UserID, body.Config) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing config JSON", 400} |
|||
} |
|||
return service, nil |
|||
} |
|||
|
|||
type GetService struct { |
|||
Db *database.ServiceDB |
|||
} |
|||
|
|||
func (h *GetService) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
if req.Method != "POST" { |
|||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|||
} |
|||
var body struct { |
|||
ID string |
|||
} |
|||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|||
} |
|||
|
|||
if body.ID == "" { |
|||
return nil, &errors.HTTPError{nil, `Must supply a "ID"`, 400} |
|||
} |
|||
|
|||
srv, err := h.Db.LoadService(body.ID) |
|||
if err != nil { |
|||
if err == sql.ErrNoRows { |
|||
return nil, &errors.HTTPError{err, `Service not found`, 404} |
|||
} |
|||
return nil, &errors.HTTPError{err, `Failed to load service`, 500} |
|||
} |
|||
|
|||
return &struct { |
|||
ID string |
|||
Type string |
|||
Config types.Service |
|||
}{srv.ServiceID(), srv.ServiceType(), srv}, nil |
|||
} |
@ -0,0 +1,63 @@ |
|||
package handlers |
|||
|
|||
import ( |
|||
"encoding/base64" |
|||
log "github.com/Sirupsen/logrus" |
|||
"github.com/matrix-org/go-neb/clients" |
|||
"github.com/matrix-org/go-neb/database" |
|||
"github.com/matrix-org/go-neb/metrics" |
|||
"net/http" |
|||
"strings" |
|||
) |
|||
|
|||
// Webhook represents an HTTP handler capable of accepting webhook requests on behalf of services.
|
|||
type Webhook struct { |
|||
db *database.ServiceDB |
|||
clients *clients.Clients |
|||
} |
|||
|
|||
// NewWebhook returns a new webhook HTTP handler
|
|||
func NewWebhook(db *database.ServiceDB, cli *clients.Clients) *Webhook { |
|||
return &Webhook{db, cli} |
|||
} |
|||
|
|||
// Handle an incoming webhook HTTP request.
|
|||
//
|
|||
// The webhook MUST have a known base64 encoded service ID as the last path segment
|
|||
// in order for this request to be passed to the correct service.
|
|||
func (wh *Webhook) Handle(w http.ResponseWriter, req *http.Request) { |
|||
log.WithField("path", req.URL.Path).Print("Incoming webhook request") |
|||
segments := strings.Split(req.URL.Path, "/") |
|||
// last path segment is the service ID which we will pass the incoming request to,
|
|||
// but we've base64d it.
|
|||
base64srvID := segments[len(segments)-1] |
|||
bytesSrvID, err := base64.RawURLEncoding.DecodeString(base64srvID) |
|||
if err != nil { |
|||
log.WithError(err).WithField("base64_service_id", base64srvID).Print( |
|||
"Not a b64 encoded string", |
|||
) |
|||
w.WriteHeader(400) |
|||
return |
|||
} |
|||
srvID := string(bytesSrvID) |
|||
|
|||
service, err := wh.db.LoadService(srvID) |
|||
if err != nil { |
|||
log.WithError(err).WithField("service_id", srvID).Print("Failed to load service") |
|||
w.WriteHeader(404) |
|||
return |
|||
} |
|||
cli, err := wh.clients.Client(service.ServiceUserID()) |
|||
if err != nil { |
|||
log.WithError(err).WithField("user_id", service.ServiceUserID()).Print( |
|||
"Failed to retrieve matrix client instance") |
|||
w.WriteHeader(500) |
|||
return |
|||
} |
|||
log.WithFields(log.Fields{ |
|||
"service_id": service.ServiceID(), |
|||
"service_type": service.ServiceType(), |
|||
}).Print("Incoming webhook for service") |
|||
metrics.IncrementWebhook(service.ServiceType()) |
|||
service.OnReceiveWebhook(w, req, cli) |
|||
} |
@ -1,419 +0,0 @@ |
|||
package main |
|||
|
|||
import ( |
|||
"database/sql" |
|||
"encoding/base64" |
|||
"encoding/json" |
|||
log "github.com/Sirupsen/logrus" |
|||
"github.com/matrix-org/go-neb/api" |
|||
"github.com/matrix-org/go-neb/clients" |
|||
"github.com/matrix-org/go-neb/database" |
|||
"github.com/matrix-org/go-neb/errors" |
|||
"github.com/matrix-org/go-neb/metrics" |
|||
"github.com/matrix-org/go-neb/polling" |
|||
"github.com/matrix-org/go-neb/types" |
|||
"net/http" |
|||
"strings" |
|||
"sync" |
|||
) |
|||
|
|||
type heartbeatHandler struct{} |
|||
|
|||
func (*heartbeatHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
return &struct{}{}, nil |
|||
} |
|||
|
|||
type requestAuthSessionHandler struct { |
|||
db *database.ServiceDB |
|||
} |
|||
|
|||
func (h *requestAuthSessionHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
if req.Method != "POST" { |
|||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|||
} |
|||
var body struct { |
|||
RealmID string |
|||
UserID string |
|||
Config json.RawMessage |
|||
} |
|||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|||
} |
|||
log.WithFields(log.Fields{ |
|||
"realm_id": body.RealmID, |
|||
"user_id": body.UserID, |
|||
}).Print("Incoming auth session request") |
|||
|
|||
if body.UserID == "" || body.RealmID == "" || body.Config == nil { |
|||
return nil, &errors.HTTPError{nil, `Must supply a "UserID", a "RealmID" and a "Config"`, 400} |
|||
} |
|||
|
|||
realm, err := h.db.LoadAuthRealm(body.RealmID) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Unknown RealmID", 400} |
|||
} |
|||
|
|||
response := realm.RequestAuthSession(body.UserID, body.Config) |
|||
if response == nil { |
|||
return nil, &errors.HTTPError{nil, "Failed to request auth session", 500} |
|||
} |
|||
|
|||
metrics.IncrementAuthSession(realm.Type()) |
|||
|
|||
return response, nil |
|||
} |
|||
|
|||
type removeAuthSessionHandler struct { |
|||
db *database.ServiceDB |
|||
} |
|||
|
|||
func (h *removeAuthSessionHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
if req.Method != "POST" { |
|||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|||
} |
|||
var body struct { |
|||
RealmID string |
|||
UserID string |
|||
} |
|||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|||
} |
|||
log.WithFields(log.Fields{ |
|||
"realm_id": body.RealmID, |
|||
"user_id": body.UserID, |
|||
}).Print("Incoming remove auth session request") |
|||
|
|||
if body.UserID == "" || body.RealmID == "" { |
|||
return nil, &errors.HTTPError{nil, `Must supply a "UserID", a "RealmID"`, 400} |
|||
} |
|||
|
|||
_, err := h.db.LoadAuthRealm(body.RealmID) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Unknown RealmID", 400} |
|||
} |
|||
|
|||
if err := h.db.RemoveAuthSession(body.RealmID, body.UserID); err != nil { |
|||
return nil, &errors.HTTPError{err, "Failed to remove auth session", 500} |
|||
} |
|||
|
|||
return []byte(`{}`), nil |
|||
} |
|||
|
|||
type realmRedirectHandler struct { |
|||
db *database.ServiceDB |
|||
} |
|||
|
|||
func (rh *realmRedirectHandler) handle(w http.ResponseWriter, req *http.Request) { |
|||
segments := strings.Split(req.URL.Path, "/") |
|||
// last path segment is the base64d realm ID which we will pass the incoming request to
|
|||
base64realmID := segments[len(segments)-1] |
|||
bytesRealmID, err := base64.RawURLEncoding.DecodeString(base64realmID) |
|||
realmID := string(bytesRealmID) |
|||
if err != nil { |
|||
log.WithError(err).WithField("base64_realm_id", base64realmID).Print( |
|||
"Not a b64 encoded string", |
|||
) |
|||
w.WriteHeader(400) |
|||
return |
|||
} |
|||
|
|||
realm, err := rh.db.LoadAuthRealm(realmID) |
|||
if err != nil { |
|||
log.WithError(err).WithField("realm_id", realmID).Print("Failed to load realm") |
|||
w.WriteHeader(404) |
|||
return |
|||
} |
|||
log.WithFields(log.Fields{ |
|||
"realm_id": realmID, |
|||
}).Print("Incoming realm redirect request") |
|||
realm.OnReceiveRedirect(w, req) |
|||
} |
|||
|
|||
type configureAuthRealmHandler struct { |
|||
db *database.ServiceDB |
|||
} |
|||
|
|||
func (h *configureAuthRealmHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
if req.Method != "POST" { |
|||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|||
} |
|||
var body api.ConfigureAuthRealmRequest |
|||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|||
} |
|||
|
|||
if err := body.Check(); err != nil { |
|||
return nil, &errors.HTTPError{err, err.Error(), 400} |
|||
} |
|||
|
|||
realm, err := types.CreateAuthRealm(body.ID, body.Type, body.Config) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing config JSON", 400} |
|||
} |
|||
|
|||
if err = realm.Register(); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error registering auth realm", 400} |
|||
} |
|||
|
|||
oldRealm, err := h.db.StoreAuthRealm(realm) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Error storing realm", 500} |
|||
} |
|||
|
|||
return &struct { |
|||
ID string |
|||
Type string |
|||
OldConfig types.AuthRealm |
|||
NewConfig types.AuthRealm |
|||
}{body.ID, body.Type, oldRealm, realm}, nil |
|||
} |
|||
|
|||
type webhookHandler struct { |
|||
db *database.ServiceDB |
|||
clients *clients.Clients |
|||
} |
|||
|
|||
func (wh *webhookHandler) handle(w http.ResponseWriter, req *http.Request) { |
|||
log.WithField("path", req.URL.Path).Print("Incoming webhook request") |
|||
segments := strings.Split(req.URL.Path, "/") |
|||
// last path segment is the service ID which we will pass the incoming request to,
|
|||
// but we've base64d it.
|
|||
base64srvID := segments[len(segments)-1] |
|||
bytesSrvID, err := base64.RawURLEncoding.DecodeString(base64srvID) |
|||
if err != nil { |
|||
log.WithError(err).WithField("base64_service_id", base64srvID).Print( |
|||
"Not a b64 encoded string", |
|||
) |
|||
w.WriteHeader(400) |
|||
return |
|||
} |
|||
srvID := string(bytesSrvID) |
|||
|
|||
service, err := wh.db.LoadService(srvID) |
|||
if err != nil { |
|||
log.WithError(err).WithField("service_id", srvID).Print("Failed to load service") |
|||
w.WriteHeader(404) |
|||
return |
|||
} |
|||
cli, err := wh.clients.Client(service.ServiceUserID()) |
|||
if err != nil { |
|||
log.WithError(err).WithField("user_id", service.ServiceUserID()).Print( |
|||
"Failed to retrieve matrix client instance") |
|||
w.WriteHeader(500) |
|||
return |
|||
} |
|||
log.WithFields(log.Fields{ |
|||
"service_id": service.ServiceID(), |
|||
"service_type": service.ServiceType(), |
|||
}).Print("Incoming webhook for service") |
|||
metrics.IncrementWebhook(service.ServiceType()) |
|||
service.OnReceiveWebhook(w, req, cli) |
|||
} |
|||
|
|||
type configureClientHandler struct { |
|||
db *database.ServiceDB |
|||
clients *clients.Clients |
|||
} |
|||
|
|||
func (s *configureClientHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
if req.Method != "POST" { |
|||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|||
} |
|||
|
|||
var body api.ClientConfig |
|||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|||
} |
|||
|
|||
if err := body.Check(); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing client config", 400} |
|||
} |
|||
|
|||
oldClient, err := s.clients.Update(body) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Error storing token", 500} |
|||
} |
|||
|
|||
return &struct { |
|||
OldClient api.ClientConfig |
|||
NewClient api.ClientConfig |
|||
}{oldClient, body}, nil |
|||
} |
|||
|
|||
type configureServiceHandler struct { |
|||
db *database.ServiceDB |
|||
clients *clients.Clients |
|||
mapMutex sync.Mutex |
|||
mutexByServiceID map[string]*sync.Mutex |
|||
} |
|||
|
|||
func newConfigureServiceHandler(db *database.ServiceDB, clients *clients.Clients) *configureServiceHandler { |
|||
return &configureServiceHandler{ |
|||
db: db, |
|||
clients: clients, |
|||
mutexByServiceID: make(map[string]*sync.Mutex), |
|||
} |
|||
} |
|||
|
|||
func (s *configureServiceHandler) getMutexForServiceID(serviceID string) *sync.Mutex { |
|||
s.mapMutex.Lock() |
|||
defer s.mapMutex.Unlock() |
|||
m := s.mutexByServiceID[serviceID] |
|||
if m == nil { |
|||
// XXX TODO: There's a memory leak here. The amount of mutexes created is unbounded, as there will be 1 per service which are never deleted.
|
|||
// A better solution would be to have a striped hash map with a bounded pool of mutexes. We can't live with a single global mutex because the Register()
|
|||
// function this is protecting does many many HTTP requests which can take a long time on bad networks and will head of line block other services.
|
|||
m = &sync.Mutex{} |
|||
s.mutexByServiceID[serviceID] = m |
|||
} |
|||
return m |
|||
} |
|||
|
|||
func (s *configureServiceHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
if req.Method != "POST" { |
|||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|||
} |
|||
|
|||
service, httpErr := s.createService(req) |
|||
if httpErr != nil { |
|||
return nil, httpErr |
|||
} |
|||
log.WithFields(log.Fields{ |
|||
"service_id": service.ServiceID(), |
|||
"service_type": service.ServiceType(), |
|||
"service_user_id": service.ServiceUserID(), |
|||
}).Print("Incoming configure service request") |
|||
|
|||
// Have mutexes around each service to queue up multiple requests for the same service ID
|
|||
mut := s.getMutexForServiceID(service.ServiceID()) |
|||
mut.Lock() |
|||
defer mut.Unlock() |
|||
|
|||
old, err := s.db.LoadService(service.ServiceID()) |
|||
if err != nil && err != sql.ErrNoRows { |
|||
return nil, &errors.HTTPError{err, "Error loading old service", 500} |
|||
} |
|||
|
|||
client, err := s.clients.Client(service.ServiceUserID()) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Unknown matrix client", 400} |
|||
} |
|||
|
|||
if err = service.Register(old, client); err != nil { |
|||
return nil, &errors.HTTPError{err, "Failed to register service: " + err.Error(), 500} |
|||
} |
|||
|
|||
oldService, err := s.db.StoreService(service) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Error storing service", 500} |
|||
} |
|||
|
|||
// Start any polling NOW because they may decide to stop it in PostRegister, and we want to make
|
|||
// sure we'll actually stop.
|
|||
if _, ok := service.(types.Poller); ok { |
|||
if err := polling.StartPolling(service); err != nil { |
|||
log.WithFields(log.Fields{ |
|||
"service_id": service.ServiceID(), |
|||
log.ErrorKey: err, |
|||
}).Error("Failed to start poll loop.") |
|||
} |
|||
} |
|||
|
|||
service.PostRegister(old) |
|||
metrics.IncrementConfigureService(service.ServiceType()) |
|||
|
|||
return &struct { |
|||
ID string |
|||
Type string |
|||
OldConfig types.Service |
|||
NewConfig types.Service |
|||
}{service.ServiceID(), service.ServiceType(), oldService, service}, nil |
|||
} |
|||
|
|||
func (s *configureServiceHandler) createService(req *http.Request) (types.Service, *errors.HTTPError) { |
|||
var body api.ConfigureServiceRequest |
|||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|||
} |
|||
|
|||
if err := body.Check(); err != nil { |
|||
return nil, &errors.HTTPError{err, err.Error(), 400} |
|||
} |
|||
|
|||
service, err := types.CreateService(body.ID, body.Type, body.UserID, body.Config) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing config JSON", 400} |
|||
} |
|||
return service, nil |
|||
} |
|||
|
|||
type getServiceHandler struct { |
|||
db *database.ServiceDB |
|||
} |
|||
|
|||
func (h *getServiceHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
if req.Method != "POST" { |
|||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|||
} |
|||
var body struct { |
|||
ID string |
|||
} |
|||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|||
} |
|||
|
|||
if body.ID == "" { |
|||
return nil, &errors.HTTPError{nil, `Must supply a "ID"`, 400} |
|||
} |
|||
|
|||
srv, err := h.db.LoadService(body.ID) |
|||
if err != nil { |
|||
if err == sql.ErrNoRows { |
|||
return nil, &errors.HTTPError{err, `Service not found`, 404} |
|||
} |
|||
return nil, &errors.HTTPError{err, `Failed to load service`, 500} |
|||
} |
|||
|
|||
return &struct { |
|||
ID string |
|||
Type string |
|||
Config types.Service |
|||
}{srv.ServiceID(), srv.ServiceType(), srv}, nil |
|||
} |
|||
|
|||
type getSessionHandler struct { |
|||
db *database.ServiceDB |
|||
} |
|||
|
|||
func (h *getSessionHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
if req.Method != "POST" { |
|||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|||
} |
|||
var body struct { |
|||
RealmID string |
|||
UserID string |
|||
} |
|||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|||
} |
|||
|
|||
if body.RealmID == "" || body.UserID == "" { |
|||
return nil, &errors.HTTPError{nil, `Must supply a "RealmID" and "UserID"`, 400} |
|||
} |
|||
|
|||
session, err := h.db.LoadAuthSessionByUser(body.RealmID, body.UserID) |
|||
if err != nil && err != sql.ErrNoRows { |
|||
return nil, &errors.HTTPError{err, `Failed to load session`, 500} |
|||
} |
|||
if err == sql.ErrNoRows { |
|||
return &struct { |
|||
Authenticated bool |
|||
}{false}, nil |
|||
} |
|||
|
|||
return &struct { |
|||
ID string |
|||
Authenticated bool |
|||
Info interface{} |
|||
}{session.ID(), session.Authenticated(), session.Info()}, nil |
|||
} |
Write
Preview
Loading…
Cancel
Save
Reference in new issue