mirror of https://github.com/matrix-org/go-neb.git
Browse Source
Merge pull request #106 from matrix-org/kegan/move-handlers
Merge pull request #106 from matrix-org/kegan/move-handlers
Split up the enormous handlers.go into more manageable fileskegan/docs-services
Kegsay
8 years ago
committed by
GitHub
7 changed files with 505 additions and 437 deletions
-
198src/github.com/matrix-org/go-neb/api/handlers/auth.go
-
40src/github.com/matrix-org/go-neb/api/handlers/client.go
-
26src/github.com/matrix-org/go-neb/api/handlers/heartbeat.go
-
158src/github.com/matrix-org/go-neb/api/handlers/service.go
-
63src/github.com/matrix-org/go-neb/api/handlers/webhook.go
-
38src/github.com/matrix-org/go-neb/goneb.go
-
419src/github.com/matrix-org/go-neb/handlers.go
@ -0,0 +1,198 @@ |
|||
package handlers |
|||
|
|||
import ( |
|||
"database/sql" |
|||
"encoding/base64" |
|||
"encoding/json" |
|||
"net/http" |
|||
"strings" |
|||
|
|||
log "github.com/Sirupsen/logrus" |
|||
"github.com/matrix-org/go-neb/api" |
|||
"github.com/matrix-org/go-neb/database" |
|||
"github.com/matrix-org/go-neb/errors" |
|||
"github.com/matrix-org/go-neb/metrics" |
|||
"github.com/matrix-org/go-neb/types" |
|||
) |
|||
|
|||
type RequestAuthSession struct { |
|||
Db *database.ServiceDB |
|||
} |
|||
|
|||
func (h *RequestAuthSession) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
if req.Method != "POST" { |
|||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|||
} |
|||
var body struct { |
|||
RealmID string |
|||
UserID string |
|||
Config json.RawMessage |
|||
} |
|||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|||
} |
|||
log.WithFields(log.Fields{ |
|||
"realm_id": body.RealmID, |
|||
"user_id": body.UserID, |
|||
}).Print("Incoming auth session request") |
|||
|
|||
if body.UserID == "" || body.RealmID == "" || body.Config == nil { |
|||
return nil, &errors.HTTPError{nil, `Must supply a "UserID", a "RealmID" and a "Config"`, 400} |
|||
} |
|||
|
|||
realm, err := h.Db.LoadAuthRealm(body.RealmID) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Unknown RealmID", 400} |
|||
} |
|||
|
|||
response := realm.RequestAuthSession(body.UserID, body.Config) |
|||
if response == nil { |
|||
return nil, &errors.HTTPError{nil, "Failed to request auth session", 500} |
|||
} |
|||
|
|||
metrics.IncrementAuthSession(realm.Type()) |
|||
|
|||
return response, nil |
|||
} |
|||
|
|||
type RemoveAuthSession struct { |
|||
Db *database.ServiceDB |
|||
} |
|||
|
|||
func (h *RemoveAuthSession) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
if req.Method != "POST" { |
|||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|||
} |
|||
var body struct { |
|||
RealmID string |
|||
UserID string |
|||
} |
|||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|||
} |
|||
log.WithFields(log.Fields{ |
|||
"realm_id": body.RealmID, |
|||
"user_id": body.UserID, |
|||
}).Print("Incoming remove auth session request") |
|||
|
|||
if body.UserID == "" || body.RealmID == "" { |
|||
return nil, &errors.HTTPError{nil, `Must supply a "UserID", a "RealmID"`, 400} |
|||
} |
|||
|
|||
_, err := h.Db.LoadAuthRealm(body.RealmID) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Unknown RealmID", 400} |
|||
} |
|||
|
|||
if err := h.Db.RemoveAuthSession(body.RealmID, body.UserID); err != nil { |
|||
return nil, &errors.HTTPError{err, "Failed to remove auth session", 500} |
|||
} |
|||
|
|||
return []byte(`{}`), nil |
|||
} |
|||
|
|||
type RealmRedirect struct { |
|||
Db *database.ServiceDB |
|||
} |
|||
|
|||
func (rh *RealmRedirect) Handle(w http.ResponseWriter, req *http.Request) { |
|||
segments := strings.Split(req.URL.Path, "/") |
|||
// last path segment is the base64d realm ID which we will pass the incoming request to
|
|||
base64realmID := segments[len(segments)-1] |
|||
bytesRealmID, err := base64.RawURLEncoding.DecodeString(base64realmID) |
|||
realmID := string(bytesRealmID) |
|||
if err != nil { |
|||
log.WithError(err).WithField("base64_realm_id", base64realmID).Print( |
|||
"Not a b64 encoded string", |
|||
) |
|||
w.WriteHeader(400) |
|||
return |
|||
} |
|||
|
|||
realm, err := rh.Db.LoadAuthRealm(realmID) |
|||
if err != nil { |
|||
log.WithError(err).WithField("realm_id", realmID).Print("Failed to load realm") |
|||
w.WriteHeader(404) |
|||
return |
|||
} |
|||
log.WithFields(log.Fields{ |
|||
"realm_id": realmID, |
|||
}).Print("Incoming realm redirect request") |
|||
realm.OnReceiveRedirect(w, req) |
|||
} |
|||
|
|||
type ConfigureAuthRealm struct { |
|||
Db *database.ServiceDB |
|||
} |
|||
|
|||
func (h *ConfigureAuthRealm) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
if req.Method != "POST" { |
|||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|||
} |
|||
var body api.ConfigureAuthRealmRequest |
|||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|||
} |
|||
|
|||
if err := body.Check(); err != nil { |
|||
return nil, &errors.HTTPError{err, err.Error(), 400} |
|||
} |
|||
|
|||
realm, err := types.CreateAuthRealm(body.ID, body.Type, body.Config) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing config JSON", 400} |
|||
} |
|||
|
|||
if err = realm.Register(); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error registering auth realm", 400} |
|||
} |
|||
|
|||
oldRealm, err := h.Db.StoreAuthRealm(realm) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Error storing realm", 500} |
|||
} |
|||
|
|||
return &struct { |
|||
ID string |
|||
Type string |
|||
OldConfig types.AuthRealm |
|||
NewConfig types.AuthRealm |
|||
}{body.ID, body.Type, oldRealm, realm}, nil |
|||
} |
|||
|
|||
type GetSession struct { |
|||
Db *database.ServiceDB |
|||
} |
|||
|
|||
func (h *GetSession) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
if req.Method != "POST" { |
|||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|||
} |
|||
var body struct { |
|||
RealmID string |
|||
UserID string |
|||
} |
|||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|||
} |
|||
|
|||
if body.RealmID == "" || body.UserID == "" { |
|||
return nil, &errors.HTTPError{nil, `Must supply a "RealmID" and "UserID"`, 400} |
|||
} |
|||
|
|||
session, err := h.Db.LoadAuthSessionByUser(body.RealmID, body.UserID) |
|||
if err != nil && err != sql.ErrNoRows { |
|||
return nil, &errors.HTTPError{err, `Failed to load session`, 500} |
|||
} |
|||
if err == sql.ErrNoRows { |
|||
return &struct { |
|||
Authenticated bool |
|||
}{false}, nil |
|||
} |
|||
|
|||
return &struct { |
|||
ID string |
|||
Authenticated bool |
|||
Info interface{} |
|||
}{session.ID(), session.Authenticated(), session.Info()}, nil |
|||
} |
@ -0,0 +1,40 @@ |
|||
package handlers |
|||
|
|||
import ( |
|||
"encoding/json" |
|||
"net/http" |
|||
|
|||
"github.com/matrix-org/go-neb/api" |
|||
"github.com/matrix-org/go-neb/clients" |
|||
"github.com/matrix-org/go-neb/errors" |
|||
) |
|||
|
|||
// ConfigureClient represents an HTTP handler capable of processing /configureClient requests
|
|||
type ConfigureClient struct { |
|||
Clients *clients.Clients |
|||
} |
|||
|
|||
func (s *ConfigureClient) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
if req.Method != "POST" { |
|||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|||
} |
|||
|
|||
var body api.ClientConfig |
|||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|||
} |
|||
|
|||
if err := body.Check(); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing client config", 400} |
|||
} |
|||
|
|||
oldClient, err := s.Clients.Update(body) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Error storing token", 500} |
|||
} |
|||
|
|||
return &struct { |
|||
OldClient api.ClientConfig |
|||
NewClient api.ClientConfig |
|||
}{oldClient, body}, nil |
|||
} |
@ -0,0 +1,26 @@ |
|||
package handlers |
|||
|
|||
import ( |
|||
"github.com/matrix-org/go-neb/errors" |
|||
"net/http" |
|||
) |
|||
|
|||
// Heartbeat implements the heartbeat API
|
|||
type Heartbeat struct{} |
|||
|
|||
// OnIncomingRequest returns an empty JSON object which can be used to detect liveness of Go-NEB.
|
|||
//
|
|||
// Request:
|
|||
// ```
|
|||
// GET /test
|
|||
// ```
|
|||
//
|
|||
// Response:
|
|||
// ```
|
|||
// HTTP/1.1 200 OK
|
|||
// Content-Type: applicatoin/json
|
|||
// {}
|
|||
// ```
|
|||
func (*Heartbeat) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
return &struct{}{}, nil |
|||
} |
@ -0,0 +1,158 @@ |
|||
package handlers |
|||
|
|||
import ( |
|||
"database/sql" |
|||
"encoding/json" |
|||
"net/http" |
|||
"sync" |
|||
|
|||
log "github.com/Sirupsen/logrus" |
|||
"github.com/matrix-org/go-neb/api" |
|||
"github.com/matrix-org/go-neb/clients" |
|||
"github.com/matrix-org/go-neb/database" |
|||
"github.com/matrix-org/go-neb/errors" |
|||
"github.com/matrix-org/go-neb/metrics" |
|||
"github.com/matrix-org/go-neb/polling" |
|||
"github.com/matrix-org/go-neb/types" |
|||
) |
|||
|
|||
type ConfigureService struct { |
|||
db *database.ServiceDB |
|||
clients *clients.Clients |
|||
mapMutex sync.Mutex |
|||
mutexByServiceID map[string]*sync.Mutex |
|||
} |
|||
|
|||
func NewConfigureService(db *database.ServiceDB, clients *clients.Clients) *ConfigureService { |
|||
return &ConfigureService{ |
|||
db: db, |
|||
clients: clients, |
|||
mutexByServiceID: make(map[string]*sync.Mutex), |
|||
} |
|||
} |
|||
|
|||
func (s *ConfigureService) getMutexForServiceID(serviceID string) *sync.Mutex { |
|||
s.mapMutex.Lock() |
|||
defer s.mapMutex.Unlock() |
|||
m := s.mutexByServiceID[serviceID] |
|||
if m == nil { |
|||
// XXX TODO: There's a memory leak here. The amount of mutexes created is unbounded, as there will be 1 per service which are never deleted.
|
|||
// A better solution would be to have a striped hash map with a bounded pool of mutexes. We can't live with a single global mutex because the Register()
|
|||
// function this is protecting does many many HTTP requests which can take a long time on bad networks and will head of line block other services.
|
|||
m = &sync.Mutex{} |
|||
s.mutexByServiceID[serviceID] = m |
|||
} |
|||
return m |
|||
} |
|||
|
|||
func (s *ConfigureService) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
if req.Method != "POST" { |
|||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|||
} |
|||
|
|||
service, httpErr := s.createService(req) |
|||
if httpErr != nil { |
|||
return nil, httpErr |
|||
} |
|||
log.WithFields(log.Fields{ |
|||
"service_id": service.ServiceID(), |
|||
"service_type": service.ServiceType(), |
|||
"service_user_id": service.ServiceUserID(), |
|||
}).Print("Incoming configure service request") |
|||
|
|||
// Have mutexes around each service to queue up multiple requests for the same service ID
|
|||
mut := s.getMutexForServiceID(service.ServiceID()) |
|||
mut.Lock() |
|||
defer mut.Unlock() |
|||
|
|||
old, err := s.db.LoadService(service.ServiceID()) |
|||
if err != nil && err != sql.ErrNoRows { |
|||
return nil, &errors.HTTPError{err, "Error loading old service", 500} |
|||
} |
|||
|
|||
client, err := s.clients.Client(service.ServiceUserID()) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Unknown matrix client", 400} |
|||
} |
|||
|
|||
if err = service.Register(old, client); err != nil { |
|||
return nil, &errors.HTTPError{err, "Failed to register service: " + err.Error(), 500} |
|||
} |
|||
|
|||
oldService, err := s.db.StoreService(service) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Error storing service", 500} |
|||
} |
|||
|
|||
// Start any polling NOW because they may decide to stop it in PostRegister, and we want to make
|
|||
// sure we'll actually stop.
|
|||
if _, ok := service.(types.Poller); ok { |
|||
if err := polling.StartPolling(service); err != nil { |
|||
log.WithFields(log.Fields{ |
|||
"service_id": service.ServiceID(), |
|||
log.ErrorKey: err, |
|||
}).Error("Failed to start poll loop.") |
|||
} |
|||
} |
|||
|
|||
service.PostRegister(old) |
|||
metrics.IncrementConfigureService(service.ServiceType()) |
|||
|
|||
return &struct { |
|||
ID string |
|||
Type string |
|||
OldConfig types.Service |
|||
NewConfig types.Service |
|||
}{service.ServiceID(), service.ServiceType(), oldService, service}, nil |
|||
} |
|||
|
|||
func (s *ConfigureService) createService(req *http.Request) (types.Service, *errors.HTTPError) { |
|||
var body api.ConfigureServiceRequest |
|||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|||
} |
|||
|
|||
if err := body.Check(); err != nil { |
|||
return nil, &errors.HTTPError{err, err.Error(), 400} |
|||
} |
|||
|
|||
service, err := types.CreateService(body.ID, body.Type, body.UserID, body.Config) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing config JSON", 400} |
|||
} |
|||
return service, nil |
|||
} |
|||
|
|||
type GetService struct { |
|||
Db *database.ServiceDB |
|||
} |
|||
|
|||
func (h *GetService) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
if req.Method != "POST" { |
|||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|||
} |
|||
var body struct { |
|||
ID string |
|||
} |
|||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|||
} |
|||
|
|||
if body.ID == "" { |
|||
return nil, &errors.HTTPError{nil, `Must supply a "ID"`, 400} |
|||
} |
|||
|
|||
srv, err := h.Db.LoadService(body.ID) |
|||
if err != nil { |
|||
if err == sql.ErrNoRows { |
|||
return nil, &errors.HTTPError{err, `Service not found`, 404} |
|||
} |
|||
return nil, &errors.HTTPError{err, `Failed to load service`, 500} |
|||
} |
|||
|
|||
return &struct { |
|||
ID string |
|||
Type string |
|||
Config types.Service |
|||
}{srv.ServiceID(), srv.ServiceType(), srv}, nil |
|||
} |
@ -0,0 +1,63 @@ |
|||
package handlers |
|||
|
|||
import ( |
|||
"encoding/base64" |
|||
log "github.com/Sirupsen/logrus" |
|||
"github.com/matrix-org/go-neb/clients" |
|||
"github.com/matrix-org/go-neb/database" |
|||
"github.com/matrix-org/go-neb/metrics" |
|||
"net/http" |
|||
"strings" |
|||
) |
|||
|
|||
// Webhook represents an HTTP handler capable of accepting webhook requests on behalf of services.
|
|||
type Webhook struct { |
|||
db *database.ServiceDB |
|||
clients *clients.Clients |
|||
} |
|||
|
|||
// NewWebhook returns a new webhook HTTP handler
|
|||
func NewWebhook(db *database.ServiceDB, cli *clients.Clients) *Webhook { |
|||
return &Webhook{db, cli} |
|||
} |
|||
|
|||
// Handle an incoming webhook HTTP request.
|
|||
//
|
|||
// The webhook MUST have a known base64 encoded service ID as the last path segment
|
|||
// in order for this request to be passed to the correct service.
|
|||
func (wh *Webhook) Handle(w http.ResponseWriter, req *http.Request) { |
|||
log.WithField("path", req.URL.Path).Print("Incoming webhook request") |
|||
segments := strings.Split(req.URL.Path, "/") |
|||
// last path segment is the service ID which we will pass the incoming request to,
|
|||
// but we've base64d it.
|
|||
base64srvID := segments[len(segments)-1] |
|||
bytesSrvID, err := base64.RawURLEncoding.DecodeString(base64srvID) |
|||
if err != nil { |
|||
log.WithError(err).WithField("base64_service_id", base64srvID).Print( |
|||
"Not a b64 encoded string", |
|||
) |
|||
w.WriteHeader(400) |
|||
return |
|||
} |
|||
srvID := string(bytesSrvID) |
|||
|
|||
service, err := wh.db.LoadService(srvID) |
|||
if err != nil { |
|||
log.WithError(err).WithField("service_id", srvID).Print("Failed to load service") |
|||
w.WriteHeader(404) |
|||
return |
|||
} |
|||
cli, err := wh.clients.Client(service.ServiceUserID()) |
|||
if err != nil { |
|||
log.WithError(err).WithField("user_id", service.ServiceUserID()).Print( |
|||
"Failed to retrieve matrix client instance") |
|||
w.WriteHeader(500) |
|||
return |
|||
} |
|||
log.WithFields(log.Fields{ |
|||
"service_id": service.ServiceID(), |
|||
"service_type": service.ServiceType(), |
|||
}).Print("Incoming webhook for service") |
|||
metrics.IncrementWebhook(service.ServiceType()) |
|||
service.OnReceiveWebhook(w, req, cli) |
|||
} |
@ -1,419 +0,0 @@ |
|||
package main |
|||
|
|||
import ( |
|||
"database/sql" |
|||
"encoding/base64" |
|||
"encoding/json" |
|||
log "github.com/Sirupsen/logrus" |
|||
"github.com/matrix-org/go-neb/api" |
|||
"github.com/matrix-org/go-neb/clients" |
|||
"github.com/matrix-org/go-neb/database" |
|||
"github.com/matrix-org/go-neb/errors" |
|||
"github.com/matrix-org/go-neb/metrics" |
|||
"github.com/matrix-org/go-neb/polling" |
|||
"github.com/matrix-org/go-neb/types" |
|||
"net/http" |
|||
"strings" |
|||
"sync" |
|||
) |
|||
|
|||
type heartbeatHandler struct{} |
|||
|
|||
func (*heartbeatHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
return &struct{}{}, nil |
|||
} |
|||
|
|||
type requestAuthSessionHandler struct { |
|||
db *database.ServiceDB |
|||
} |
|||
|
|||
func (h *requestAuthSessionHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
if req.Method != "POST" { |
|||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|||
} |
|||
var body struct { |
|||
RealmID string |
|||
UserID string |
|||
Config json.RawMessage |
|||
} |
|||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|||
} |
|||
log.WithFields(log.Fields{ |
|||
"realm_id": body.RealmID, |
|||
"user_id": body.UserID, |
|||
}).Print("Incoming auth session request") |
|||
|
|||
if body.UserID == "" || body.RealmID == "" || body.Config == nil { |
|||
return nil, &errors.HTTPError{nil, `Must supply a "UserID", a "RealmID" and a "Config"`, 400} |
|||
} |
|||
|
|||
realm, err := h.db.LoadAuthRealm(body.RealmID) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Unknown RealmID", 400} |
|||
} |
|||
|
|||
response := realm.RequestAuthSession(body.UserID, body.Config) |
|||
if response == nil { |
|||
return nil, &errors.HTTPError{nil, "Failed to request auth session", 500} |
|||
} |
|||
|
|||
metrics.IncrementAuthSession(realm.Type()) |
|||
|
|||
return response, nil |
|||
} |
|||
|
|||
type removeAuthSessionHandler struct { |
|||
db *database.ServiceDB |
|||
} |
|||
|
|||
func (h *removeAuthSessionHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
if req.Method != "POST" { |
|||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|||
} |
|||
var body struct { |
|||
RealmID string |
|||
UserID string |
|||
} |
|||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|||
} |
|||
log.WithFields(log.Fields{ |
|||
"realm_id": body.RealmID, |
|||
"user_id": body.UserID, |
|||
}).Print("Incoming remove auth session request") |
|||
|
|||
if body.UserID == "" || body.RealmID == "" { |
|||
return nil, &errors.HTTPError{nil, `Must supply a "UserID", a "RealmID"`, 400} |
|||
} |
|||
|
|||
_, err := h.db.LoadAuthRealm(body.RealmID) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Unknown RealmID", 400} |
|||
} |
|||
|
|||
if err := h.db.RemoveAuthSession(body.RealmID, body.UserID); err != nil { |
|||
return nil, &errors.HTTPError{err, "Failed to remove auth session", 500} |
|||
} |
|||
|
|||
return []byte(`{}`), nil |
|||
} |
|||
|
|||
type realmRedirectHandler struct { |
|||
db *database.ServiceDB |
|||
} |
|||
|
|||
func (rh *realmRedirectHandler) handle(w http.ResponseWriter, req *http.Request) { |
|||
segments := strings.Split(req.URL.Path, "/") |
|||
// last path segment is the base64d realm ID which we will pass the incoming request to
|
|||
base64realmID := segments[len(segments)-1] |
|||
bytesRealmID, err := base64.RawURLEncoding.DecodeString(base64realmID) |
|||
realmID := string(bytesRealmID) |
|||
if err != nil { |
|||
log.WithError(err).WithField("base64_realm_id", base64realmID).Print( |
|||
"Not a b64 encoded string", |
|||
) |
|||
w.WriteHeader(400) |
|||
return |
|||
} |
|||
|
|||
realm, err := rh.db.LoadAuthRealm(realmID) |
|||
if err != nil { |
|||
log.WithError(err).WithField("realm_id", realmID).Print("Failed to load realm") |
|||
w.WriteHeader(404) |
|||
return |
|||
} |
|||
log.WithFields(log.Fields{ |
|||
"realm_id": realmID, |
|||
}).Print("Incoming realm redirect request") |
|||
realm.OnReceiveRedirect(w, req) |
|||
} |
|||
|
|||
type configureAuthRealmHandler struct { |
|||
db *database.ServiceDB |
|||
} |
|||
|
|||
func (h *configureAuthRealmHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
if req.Method != "POST" { |
|||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|||
} |
|||
var body api.ConfigureAuthRealmRequest |
|||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|||
} |
|||
|
|||
if err := body.Check(); err != nil { |
|||
return nil, &errors.HTTPError{err, err.Error(), 400} |
|||
} |
|||
|
|||
realm, err := types.CreateAuthRealm(body.ID, body.Type, body.Config) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing config JSON", 400} |
|||
} |
|||
|
|||
if err = realm.Register(); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error registering auth realm", 400} |
|||
} |
|||
|
|||
oldRealm, err := h.db.StoreAuthRealm(realm) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Error storing realm", 500} |
|||
} |
|||
|
|||
return &struct { |
|||
ID string |
|||
Type string |
|||
OldConfig types.AuthRealm |
|||
NewConfig types.AuthRealm |
|||
}{body.ID, body.Type, oldRealm, realm}, nil |
|||
} |
|||
|
|||
type webhookHandler struct { |
|||
db *database.ServiceDB |
|||
clients *clients.Clients |
|||
} |
|||
|
|||
func (wh *webhookHandler) handle(w http.ResponseWriter, req *http.Request) { |
|||
log.WithField("path", req.URL.Path).Print("Incoming webhook request") |
|||
segments := strings.Split(req.URL.Path, "/") |
|||
// last path segment is the service ID which we will pass the incoming request to,
|
|||
// but we've base64d it.
|
|||
base64srvID := segments[len(segments)-1] |
|||
bytesSrvID, err := base64.RawURLEncoding.DecodeString(base64srvID) |
|||
if err != nil { |
|||
log.WithError(err).WithField("base64_service_id", base64srvID).Print( |
|||
"Not a b64 encoded string", |
|||
) |
|||
w.WriteHeader(400) |
|||
return |
|||
} |
|||
srvID := string(bytesSrvID) |
|||
|
|||
service, err := wh.db.LoadService(srvID) |
|||
if err != nil { |
|||
log.WithError(err).WithField("service_id", srvID).Print("Failed to load service") |
|||
w.WriteHeader(404) |
|||
return |
|||
} |
|||
cli, err := wh.clients.Client(service.ServiceUserID()) |
|||
if err != nil { |
|||
log.WithError(err).WithField("user_id", service.ServiceUserID()).Print( |
|||
"Failed to retrieve matrix client instance") |
|||
w.WriteHeader(500) |
|||
return |
|||
} |
|||
log.WithFields(log.Fields{ |
|||
"service_id": service.ServiceID(), |
|||
"service_type": service.ServiceType(), |
|||
}).Print("Incoming webhook for service") |
|||
metrics.IncrementWebhook(service.ServiceType()) |
|||
service.OnReceiveWebhook(w, req, cli) |
|||
} |
|||
|
|||
type configureClientHandler struct { |
|||
db *database.ServiceDB |
|||
clients *clients.Clients |
|||
} |
|||
|
|||
func (s *configureClientHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
if req.Method != "POST" { |
|||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|||
} |
|||
|
|||
var body api.ClientConfig |
|||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|||
} |
|||
|
|||
if err := body.Check(); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing client config", 400} |
|||
} |
|||
|
|||
oldClient, err := s.clients.Update(body) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Error storing token", 500} |
|||
} |
|||
|
|||
return &struct { |
|||
OldClient api.ClientConfig |
|||
NewClient api.ClientConfig |
|||
}{oldClient, body}, nil |
|||
} |
|||
|
|||
type configureServiceHandler struct { |
|||
db *database.ServiceDB |
|||
clients *clients.Clients |
|||
mapMutex sync.Mutex |
|||
mutexByServiceID map[string]*sync.Mutex |
|||
} |
|||
|
|||
func newConfigureServiceHandler(db *database.ServiceDB, clients *clients.Clients) *configureServiceHandler { |
|||
return &configureServiceHandler{ |
|||
db: db, |
|||
clients: clients, |
|||
mutexByServiceID: make(map[string]*sync.Mutex), |
|||
} |
|||
} |
|||
|
|||
func (s *configureServiceHandler) getMutexForServiceID(serviceID string) *sync.Mutex { |
|||
s.mapMutex.Lock() |
|||
defer s.mapMutex.Unlock() |
|||
m := s.mutexByServiceID[serviceID] |
|||
if m == nil { |
|||
// XXX TODO: There's a memory leak here. The amount of mutexes created is unbounded, as there will be 1 per service which are never deleted.
|
|||
// A better solution would be to have a striped hash map with a bounded pool of mutexes. We can't live with a single global mutex because the Register()
|
|||
// function this is protecting does many many HTTP requests which can take a long time on bad networks and will head of line block other services.
|
|||
m = &sync.Mutex{} |
|||
s.mutexByServiceID[serviceID] = m |
|||
} |
|||
return m |
|||
} |
|||
|
|||
func (s *configureServiceHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
if req.Method != "POST" { |
|||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|||
} |
|||
|
|||
service, httpErr := s.createService(req) |
|||
if httpErr != nil { |
|||
return nil, httpErr |
|||
} |
|||
log.WithFields(log.Fields{ |
|||
"service_id": service.ServiceID(), |
|||
"service_type": service.ServiceType(), |
|||
"service_user_id": service.ServiceUserID(), |
|||
}).Print("Incoming configure service request") |
|||
|
|||
// Have mutexes around each service to queue up multiple requests for the same service ID
|
|||
mut := s.getMutexForServiceID(service.ServiceID()) |
|||
mut.Lock() |
|||
defer mut.Unlock() |
|||
|
|||
old, err := s.db.LoadService(service.ServiceID()) |
|||
if err != nil && err != sql.ErrNoRows { |
|||
return nil, &errors.HTTPError{err, "Error loading old service", 500} |
|||
} |
|||
|
|||
client, err := s.clients.Client(service.ServiceUserID()) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Unknown matrix client", 400} |
|||
} |
|||
|
|||
if err = service.Register(old, client); err != nil { |
|||
return nil, &errors.HTTPError{err, "Failed to register service: " + err.Error(), 500} |
|||
} |
|||
|
|||
oldService, err := s.db.StoreService(service) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Error storing service", 500} |
|||
} |
|||
|
|||
// Start any polling NOW because they may decide to stop it in PostRegister, and we want to make
|
|||
// sure we'll actually stop.
|
|||
if _, ok := service.(types.Poller); ok { |
|||
if err := polling.StartPolling(service); err != nil { |
|||
log.WithFields(log.Fields{ |
|||
"service_id": service.ServiceID(), |
|||
log.ErrorKey: err, |
|||
}).Error("Failed to start poll loop.") |
|||
} |
|||
} |
|||
|
|||
service.PostRegister(old) |
|||
metrics.IncrementConfigureService(service.ServiceType()) |
|||
|
|||
return &struct { |
|||
ID string |
|||
Type string |
|||
OldConfig types.Service |
|||
NewConfig types.Service |
|||
}{service.ServiceID(), service.ServiceType(), oldService, service}, nil |
|||
} |
|||
|
|||
func (s *configureServiceHandler) createService(req *http.Request) (types.Service, *errors.HTTPError) { |
|||
var body api.ConfigureServiceRequest |
|||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|||
} |
|||
|
|||
if err := body.Check(); err != nil { |
|||
return nil, &errors.HTTPError{err, err.Error(), 400} |
|||
} |
|||
|
|||
service, err := types.CreateService(body.ID, body.Type, body.UserID, body.Config) |
|||
if err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing config JSON", 400} |
|||
} |
|||
return service, nil |
|||
} |
|||
|
|||
type getServiceHandler struct { |
|||
db *database.ServiceDB |
|||
} |
|||
|
|||
func (h *getServiceHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
if req.Method != "POST" { |
|||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|||
} |
|||
var body struct { |
|||
ID string |
|||
} |
|||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|||
} |
|||
|
|||
if body.ID == "" { |
|||
return nil, &errors.HTTPError{nil, `Must supply a "ID"`, 400} |
|||
} |
|||
|
|||
srv, err := h.db.LoadService(body.ID) |
|||
if err != nil { |
|||
if err == sql.ErrNoRows { |
|||
return nil, &errors.HTTPError{err, `Service not found`, 404} |
|||
} |
|||
return nil, &errors.HTTPError{err, `Failed to load service`, 500} |
|||
} |
|||
|
|||
return &struct { |
|||
ID string |
|||
Type string |
|||
Config types.Service |
|||
}{srv.ServiceID(), srv.ServiceType(), srv}, nil |
|||
} |
|||
|
|||
type getSessionHandler struct { |
|||
db *database.ServiceDB |
|||
} |
|||
|
|||
func (h *getSessionHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|||
if req.Method != "POST" { |
|||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|||
} |
|||
var body struct { |
|||
RealmID string |
|||
UserID string |
|||
} |
|||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|||
} |
|||
|
|||
if body.RealmID == "" || body.UserID == "" { |
|||
return nil, &errors.HTTPError{nil, `Must supply a "RealmID" and "UserID"`, 400} |
|||
} |
|||
|
|||
session, err := h.db.LoadAuthSessionByUser(body.RealmID, body.UserID) |
|||
if err != nil && err != sql.ErrNoRows { |
|||
return nil, &errors.HTTPError{err, `Failed to load session`, 500} |
|||
} |
|||
if err == sql.ErrNoRows { |
|||
return &struct { |
|||
Authenticated bool |
|||
}{false}, nil |
|||
} |
|||
|
|||
return &struct { |
|||
ID string |
|||
Authenticated bool |
|||
Info interface{} |
|||
}{session.ID(), session.Authenticated(), session.Info()}, nil |
|||
} |
Write
Preview
Loading…
Cancel
Save
Reference in new issue