mirror of https://github.com/matrix-org/go-neb.git
Browse Source
Merge pull request #106 from matrix-org/kegan/move-handlers
Merge pull request #106 from matrix-org/kegan/move-handlers
Split up the enormous handlers.go into more manageable fileskegan/docs-services
Kegsay
8 years ago
committed by
GitHub
7 changed files with 505 additions and 437 deletions
-
198src/github.com/matrix-org/go-neb/api/handlers/auth.go
-
40src/github.com/matrix-org/go-neb/api/handlers/client.go
-
26src/github.com/matrix-org/go-neb/api/handlers/heartbeat.go
-
158src/github.com/matrix-org/go-neb/api/handlers/service.go
-
63src/github.com/matrix-org/go-neb/api/handlers/webhook.go
-
38src/github.com/matrix-org/go-neb/goneb.go
-
419src/github.com/matrix-org/go-neb/handlers.go
@ -0,0 +1,198 @@ |
|||||
|
package handlers |
||||
|
|
||||
|
import ( |
||||
|
"database/sql" |
||||
|
"encoding/base64" |
||||
|
"encoding/json" |
||||
|
"net/http" |
||||
|
"strings" |
||||
|
|
||||
|
log "github.com/Sirupsen/logrus" |
||||
|
"github.com/matrix-org/go-neb/api" |
||||
|
"github.com/matrix-org/go-neb/database" |
||||
|
"github.com/matrix-org/go-neb/errors" |
||||
|
"github.com/matrix-org/go-neb/metrics" |
||||
|
"github.com/matrix-org/go-neb/types" |
||||
|
) |
||||
|
|
||||
|
type RequestAuthSession struct { |
||||
|
Db *database.ServiceDB |
||||
|
} |
||||
|
|
||||
|
func (h *RequestAuthSession) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
||||
|
if req.Method != "POST" { |
||||
|
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
||||
|
} |
||||
|
var body struct { |
||||
|
RealmID string |
||||
|
UserID string |
||||
|
Config json.RawMessage |
||||
|
} |
||||
|
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
||||
|
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
||||
|
} |
||||
|
log.WithFields(log.Fields{ |
||||
|
"realm_id": body.RealmID, |
||||
|
"user_id": body.UserID, |
||||
|
}).Print("Incoming auth session request") |
||||
|
|
||||
|
if body.UserID == "" || body.RealmID == "" || body.Config == nil { |
||||
|
return nil, &errors.HTTPError{nil, `Must supply a "UserID", a "RealmID" and a "Config"`, 400} |
||||
|
} |
||||
|
|
||||
|
realm, err := h.Db.LoadAuthRealm(body.RealmID) |
||||
|
if err != nil { |
||||
|
return nil, &errors.HTTPError{err, "Unknown RealmID", 400} |
||||
|
} |
||||
|
|
||||
|
response := realm.RequestAuthSession(body.UserID, body.Config) |
||||
|
if response == nil { |
||||
|
return nil, &errors.HTTPError{nil, "Failed to request auth session", 500} |
||||
|
} |
||||
|
|
||||
|
metrics.IncrementAuthSession(realm.Type()) |
||||
|
|
||||
|
return response, nil |
||||
|
} |
||||
|
|
||||
|
type RemoveAuthSession struct { |
||||
|
Db *database.ServiceDB |
||||
|
} |
||||
|
|
||||
|
func (h *RemoveAuthSession) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
||||
|
if req.Method != "POST" { |
||||
|
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
||||
|
} |
||||
|
var body struct { |
||||
|
RealmID string |
||||
|
UserID string |
||||
|
} |
||||
|
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
||||
|
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
||||
|
} |
||||
|
log.WithFields(log.Fields{ |
||||
|
"realm_id": body.RealmID, |
||||
|
"user_id": body.UserID, |
||||
|
}).Print("Incoming remove auth session request") |
||||
|
|
||||
|
if body.UserID == "" || body.RealmID == "" { |
||||
|
return nil, &errors.HTTPError{nil, `Must supply a "UserID", a "RealmID"`, 400} |
||||
|
} |
||||
|
|
||||
|
_, err := h.Db.LoadAuthRealm(body.RealmID) |
||||
|
if err != nil { |
||||
|
return nil, &errors.HTTPError{err, "Unknown RealmID", 400} |
||||
|
} |
||||
|
|
||||
|
if err := h.Db.RemoveAuthSession(body.RealmID, body.UserID); err != nil { |
||||
|
return nil, &errors.HTTPError{err, "Failed to remove auth session", 500} |
||||
|
} |
||||
|
|
||||
|
return []byte(`{}`), nil |
||||
|
} |
||||
|
|
||||
|
type RealmRedirect struct { |
||||
|
Db *database.ServiceDB |
||||
|
} |
||||
|
|
||||
|
func (rh *RealmRedirect) Handle(w http.ResponseWriter, req *http.Request) { |
||||
|
segments := strings.Split(req.URL.Path, "/") |
||||
|
// last path segment is the base64d realm ID which we will pass the incoming request to
|
||||
|
base64realmID := segments[len(segments)-1] |
||||
|
bytesRealmID, err := base64.RawURLEncoding.DecodeString(base64realmID) |
||||
|
realmID := string(bytesRealmID) |
||||
|
if err != nil { |
||||
|
log.WithError(err).WithField("base64_realm_id", base64realmID).Print( |
||||
|
"Not a b64 encoded string", |
||||
|
) |
||||
|
w.WriteHeader(400) |
||||
|
return |
||||
|
} |
||||
|
|
||||
|
realm, err := rh.Db.LoadAuthRealm(realmID) |
||||
|
if err != nil { |
||||
|
log.WithError(err).WithField("realm_id", realmID).Print("Failed to load realm") |
||||
|
w.WriteHeader(404) |
||||
|
return |
||||
|
} |
||||
|
log.WithFields(log.Fields{ |
||||
|
"realm_id": realmID, |
||||
|
}).Print("Incoming realm redirect request") |
||||
|
realm.OnReceiveRedirect(w, req) |
||||
|
} |
||||
|
|
||||
|
type ConfigureAuthRealm struct { |
||||
|
Db *database.ServiceDB |
||||
|
} |
||||
|
|
||||
|
func (h *ConfigureAuthRealm) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
||||
|
if req.Method != "POST" { |
||||
|
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
||||
|
} |
||||
|
var body api.ConfigureAuthRealmRequest |
||||
|
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
||||
|
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
||||
|
} |
||||
|
|
||||
|
if err := body.Check(); err != nil { |
||||
|
return nil, &errors.HTTPError{err, err.Error(), 400} |
||||
|
} |
||||
|
|
||||
|
realm, err := types.CreateAuthRealm(body.ID, body.Type, body.Config) |
||||
|
if err != nil { |
||||
|
return nil, &errors.HTTPError{err, "Error parsing config JSON", 400} |
||||
|
} |
||||
|
|
||||
|
if err = realm.Register(); err != nil { |
||||
|
return nil, &errors.HTTPError{err, "Error registering auth realm", 400} |
||||
|
} |
||||
|
|
||||
|
oldRealm, err := h.Db.StoreAuthRealm(realm) |
||||
|
if err != nil { |
||||
|
return nil, &errors.HTTPError{err, "Error storing realm", 500} |
||||
|
} |
||||
|
|
||||
|
return &struct { |
||||
|
ID string |
||||
|
Type string |
||||
|
OldConfig types.AuthRealm |
||||
|
NewConfig types.AuthRealm |
||||
|
}{body.ID, body.Type, oldRealm, realm}, nil |
||||
|
} |
||||
|
|
||||
|
type GetSession struct { |
||||
|
Db *database.ServiceDB |
||||
|
} |
||||
|
|
||||
|
func (h *GetSession) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
||||
|
if req.Method != "POST" { |
||||
|
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
||||
|
} |
||||
|
var body struct { |
||||
|
RealmID string |
||||
|
UserID string |
||||
|
} |
||||
|
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
||||
|
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
||||
|
} |
||||
|
|
||||
|
if body.RealmID == "" || body.UserID == "" { |
||||
|
return nil, &errors.HTTPError{nil, `Must supply a "RealmID" and "UserID"`, 400} |
||||
|
} |
||||
|
|
||||
|
session, err := h.Db.LoadAuthSessionByUser(body.RealmID, body.UserID) |
||||
|
if err != nil && err != sql.ErrNoRows { |
||||
|
return nil, &errors.HTTPError{err, `Failed to load session`, 500} |
||||
|
} |
||||
|
if err == sql.ErrNoRows { |
||||
|
return &struct { |
||||
|
Authenticated bool |
||||
|
}{false}, nil |
||||
|
} |
||||
|
|
||||
|
return &struct { |
||||
|
ID string |
||||
|
Authenticated bool |
||||
|
Info interface{} |
||||
|
}{session.ID(), session.Authenticated(), session.Info()}, nil |
||||
|
} |
@ -0,0 +1,40 @@ |
|||||
|
package handlers |
||||
|
|
||||
|
import ( |
||||
|
"encoding/json" |
||||
|
"net/http" |
||||
|
|
||||
|
"github.com/matrix-org/go-neb/api" |
||||
|
"github.com/matrix-org/go-neb/clients" |
||||
|
"github.com/matrix-org/go-neb/errors" |
||||
|
) |
||||
|
|
||||
|
// ConfigureClient represents an HTTP handler capable of processing /configureClient requests
|
||||
|
type ConfigureClient struct { |
||||
|
Clients *clients.Clients |
||||
|
} |
||||
|
|
||||
|
func (s *ConfigureClient) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
||||
|
if req.Method != "POST" { |
||||
|
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
||||
|
} |
||||
|
|
||||
|
var body api.ClientConfig |
||||
|
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
||||
|
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
||||
|
} |
||||
|
|
||||
|
if err := body.Check(); err != nil { |
||||
|
return nil, &errors.HTTPError{err, "Error parsing client config", 400} |
||||
|
} |
||||
|
|
||||
|
oldClient, err := s.Clients.Update(body) |
||||
|
if err != nil { |
||||
|
return nil, &errors.HTTPError{err, "Error storing token", 500} |
||||
|
} |
||||
|
|
||||
|
return &struct { |
||||
|
OldClient api.ClientConfig |
||||
|
NewClient api.ClientConfig |
||||
|
}{oldClient, body}, nil |
||||
|
} |
@ -0,0 +1,26 @@ |
|||||
|
package handlers |
||||
|
|
||||
|
import ( |
||||
|
"github.com/matrix-org/go-neb/errors" |
||||
|
"net/http" |
||||
|
) |
||||
|
|
||||
|
// Heartbeat implements the heartbeat API
|
||||
|
type Heartbeat struct{} |
||||
|
|
||||
|
// OnIncomingRequest returns an empty JSON object which can be used to detect liveness of Go-NEB.
|
||||
|
//
|
||||
|
// Request:
|
||||
|
// ```
|
||||
|
// GET /test
|
||||
|
// ```
|
||||
|
//
|
||||
|
// Response:
|
||||
|
// ```
|
||||
|
// HTTP/1.1 200 OK
|
||||
|
// Content-Type: applicatoin/json
|
||||
|
// {}
|
||||
|
// ```
|
||||
|
func (*Heartbeat) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
||||
|
return &struct{}{}, nil |
||||
|
} |
@ -0,0 +1,158 @@ |
|||||
|
package handlers |
||||
|
|
||||
|
import ( |
||||
|
"database/sql" |
||||
|
"encoding/json" |
||||
|
"net/http" |
||||
|
"sync" |
||||
|
|
||||
|
log "github.com/Sirupsen/logrus" |
||||
|
"github.com/matrix-org/go-neb/api" |
||||
|
"github.com/matrix-org/go-neb/clients" |
||||
|
"github.com/matrix-org/go-neb/database" |
||||
|
"github.com/matrix-org/go-neb/errors" |
||||
|
"github.com/matrix-org/go-neb/metrics" |
||||
|
"github.com/matrix-org/go-neb/polling" |
||||
|
"github.com/matrix-org/go-neb/types" |
||||
|
) |
||||
|
|
||||
|
type ConfigureService struct { |
||||
|
db *database.ServiceDB |
||||
|
clients *clients.Clients |
||||
|
mapMutex sync.Mutex |
||||
|
mutexByServiceID map[string]*sync.Mutex |
||||
|
} |
||||
|
|
||||
|
func NewConfigureService(db *database.ServiceDB, clients *clients.Clients) *ConfigureService { |
||||
|
return &ConfigureService{ |
||||
|
db: db, |
||||
|
clients: clients, |
||||
|
mutexByServiceID: make(map[string]*sync.Mutex), |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
func (s *ConfigureService) getMutexForServiceID(serviceID string) *sync.Mutex { |
||||
|
s.mapMutex.Lock() |
||||
|
defer s.mapMutex.Unlock() |
||||
|
m := s.mutexByServiceID[serviceID] |
||||
|
if m == nil { |
||||
|
// XXX TODO: There's a memory leak here. The amount of mutexes created is unbounded, as there will be 1 per service which are never deleted.
|
||||
|
// A better solution would be to have a striped hash map with a bounded pool of mutexes. We can't live with a single global mutex because the Register()
|
||||
|
// function this is protecting does many many HTTP requests which can take a long time on bad networks and will head of line block other services.
|
||||
|
m = &sync.Mutex{} |
||||
|
s.mutexByServiceID[serviceID] = m |
||||
|
} |
||||
|
return m |
||||
|
} |
||||
|
|
||||
|
func (s *ConfigureService) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
||||
|
if req.Method != "POST" { |
||||
|
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
||||
|
} |
||||
|
|
||||
|
service, httpErr := s.createService(req) |
||||
|
if httpErr != nil { |
||||
|
return nil, httpErr |
||||
|
} |
||||
|
log.WithFields(log.Fields{ |
||||
|
"service_id": service.ServiceID(), |
||||
|
"service_type": service.ServiceType(), |
||||
|
"service_user_id": service.ServiceUserID(), |
||||
|
}).Print("Incoming configure service request") |
||||
|
|
||||
|
// Have mutexes around each service to queue up multiple requests for the same service ID
|
||||
|
mut := s.getMutexForServiceID(service.ServiceID()) |
||||
|
mut.Lock() |
||||
|
defer mut.Unlock() |
||||
|
|
||||
|
old, err := s.db.LoadService(service.ServiceID()) |
||||
|
if err != nil && err != sql.ErrNoRows { |
||||
|
return nil, &errors.HTTPError{err, "Error loading old service", 500} |
||||
|
} |
||||
|
|
||||
|
client, err := s.clients.Client(service.ServiceUserID()) |
||||
|
if err != nil { |
||||
|
return nil, &errors.HTTPError{err, "Unknown matrix client", 400} |
||||
|
} |
||||
|
|
||||
|
if err = service.Register(old, client); err != nil { |
||||
|
return nil, &errors.HTTPError{err, "Failed to register service: " + err.Error(), 500} |
||||
|
} |
||||
|
|
||||
|
oldService, err := s.db.StoreService(service) |
||||
|
if err != nil { |
||||
|
return nil, &errors.HTTPError{err, "Error storing service", 500} |
||||
|
} |
||||
|
|
||||
|
// Start any polling NOW because they may decide to stop it in PostRegister, and we want to make
|
||||
|
// sure we'll actually stop.
|
||||
|
if _, ok := service.(types.Poller); ok { |
||||
|
if err := polling.StartPolling(service); err != nil { |
||||
|
log.WithFields(log.Fields{ |
||||
|
"service_id": service.ServiceID(), |
||||
|
log.ErrorKey: err, |
||||
|
}).Error("Failed to start poll loop.") |
||||
|
} |
||||
|
} |
||||
|
|
||||
|
service.PostRegister(old) |
||||
|
metrics.IncrementConfigureService(service.ServiceType()) |
||||
|
|
||||
|
return &struct { |
||||
|
ID string |
||||
|
Type string |
||||
|
OldConfig types.Service |
||||
|
NewConfig types.Service |
||||
|
}{service.ServiceID(), service.ServiceType(), oldService, service}, nil |
||||
|
} |
||||
|
|
||||
|
func (s *ConfigureService) createService(req *http.Request) (types.Service, *errors.HTTPError) { |
||||
|
var body api.ConfigureServiceRequest |
||||
|
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
||||
|
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
||||
|
} |
||||
|
|
||||
|
if err := body.Check(); err != nil { |
||||
|
return nil, &errors.HTTPError{err, err.Error(), 400} |
||||
|
} |
||||
|
|
||||
|
service, err := types.CreateService(body.ID, body.Type, body.UserID, body.Config) |
||||
|
if err != nil { |
||||
|
return nil, &errors.HTTPError{err, "Error parsing config JSON", 400} |
||||
|
} |
||||
|
return service, nil |
||||
|
} |
||||
|
|
||||
|
type GetService struct { |
||||
|
Db *database.ServiceDB |
||||
|
} |
||||
|
|
||||
|
func (h *GetService) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
||||
|
if req.Method != "POST" { |
||||
|
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
||||
|
} |
||||
|
var body struct { |
||||
|
ID string |
||||
|
} |
||||
|
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
||||
|
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
||||
|
} |
||||
|
|
||||
|
if body.ID == "" { |
||||
|
return nil, &errors.HTTPError{nil, `Must supply a "ID"`, 400} |
||||
|
} |
||||
|
|
||||
|
srv, err := h.Db.LoadService(body.ID) |
||||
|
if err != nil { |
||||
|
if err == sql.ErrNoRows { |
||||
|
return nil, &errors.HTTPError{err, `Service not found`, 404} |
||||
|
} |
||||
|
return nil, &errors.HTTPError{err, `Failed to load service`, 500} |
||||
|
} |
||||
|
|
||||
|
return &struct { |
||||
|
ID string |
||||
|
Type string |
||||
|
Config types.Service |
||||
|
}{srv.ServiceID(), srv.ServiceType(), srv}, nil |
||||
|
} |
@ -0,0 +1,63 @@ |
|||||
|
package handlers |
||||
|
|
||||
|
import ( |
||||
|
"encoding/base64" |
||||
|
log "github.com/Sirupsen/logrus" |
||||
|
"github.com/matrix-org/go-neb/clients" |
||||
|
"github.com/matrix-org/go-neb/database" |
||||
|
"github.com/matrix-org/go-neb/metrics" |
||||
|
"net/http" |
||||
|
"strings" |
||||
|
) |
||||
|
|
||||
|
// Webhook represents an HTTP handler capable of accepting webhook requests on behalf of services.
|
||||
|
type Webhook struct { |
||||
|
db *database.ServiceDB |
||||
|
clients *clients.Clients |
||||
|
} |
||||
|
|
||||
|
// NewWebhook returns a new webhook HTTP handler
|
||||
|
func NewWebhook(db *database.ServiceDB, cli *clients.Clients) *Webhook { |
||||
|
return &Webhook{db, cli} |
||||
|
} |
||||
|
|
||||
|
// Handle an incoming webhook HTTP request.
|
||||
|
//
|
||||
|
// The webhook MUST have a known base64 encoded service ID as the last path segment
|
||||
|
// in order for this request to be passed to the correct service.
|
||||
|
func (wh *Webhook) Handle(w http.ResponseWriter, req *http.Request) { |
||||
|
log.WithField("path", req.URL.Path).Print("Incoming webhook request") |
||||
|
segments := strings.Split(req.URL.Path, "/") |
||||
|
// last path segment is the service ID which we will pass the incoming request to,
|
||||
|
// but we've base64d it.
|
||||
|
base64srvID := segments[len(segments)-1] |
||||
|
bytesSrvID, err := base64.RawURLEncoding.DecodeString(base64srvID) |
||||
|
if err != nil { |
||||
|
log.WithError(err).WithField("base64_service_id", base64srvID).Print( |
||||
|
"Not a b64 encoded string", |
||||
|
) |
||||
|
w.WriteHeader(400) |
||||
|
return |
||||
|
} |
||||
|
srvID := string(bytesSrvID) |
||||
|
|
||||
|
service, err := wh.db.LoadService(srvID) |
||||
|
if err != nil { |
||||
|
log.WithError(err).WithField("service_id", srvID).Print("Failed to load service") |
||||
|
w.WriteHeader(404) |
||||
|
return |
||||
|
} |
||||
|
cli, err := wh.clients.Client(service.ServiceUserID()) |
||||
|
if err != nil { |
||||
|
log.WithError(err).WithField("user_id", service.ServiceUserID()).Print( |
||||
|
"Failed to retrieve matrix client instance") |
||||
|
w.WriteHeader(500) |
||||
|
return |
||||
|
} |
||||
|
log.WithFields(log.Fields{ |
||||
|
"service_id": service.ServiceID(), |
||||
|
"service_type": service.ServiceType(), |
||||
|
}).Print("Incoming webhook for service") |
||||
|
metrics.IncrementWebhook(service.ServiceType()) |
||||
|
service.OnReceiveWebhook(w, req, cli) |
||||
|
} |
@ -1,419 +0,0 @@ |
|||||
package main |
|
||||
|
|
||||
import ( |
|
||||
"database/sql" |
|
||||
"encoding/base64" |
|
||||
"encoding/json" |
|
||||
log "github.com/Sirupsen/logrus" |
|
||||
"github.com/matrix-org/go-neb/api" |
|
||||
"github.com/matrix-org/go-neb/clients" |
|
||||
"github.com/matrix-org/go-neb/database" |
|
||||
"github.com/matrix-org/go-neb/errors" |
|
||||
"github.com/matrix-org/go-neb/metrics" |
|
||||
"github.com/matrix-org/go-neb/polling" |
|
||||
"github.com/matrix-org/go-neb/types" |
|
||||
"net/http" |
|
||||
"strings" |
|
||||
"sync" |
|
||||
) |
|
||||
|
|
||||
type heartbeatHandler struct{} |
|
||||
|
|
||||
func (*heartbeatHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|
||||
return &struct{}{}, nil |
|
||||
} |
|
||||
|
|
||||
type requestAuthSessionHandler struct { |
|
||||
db *database.ServiceDB |
|
||||
} |
|
||||
|
|
||||
func (h *requestAuthSessionHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|
||||
if req.Method != "POST" { |
|
||||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|
||||
} |
|
||||
var body struct { |
|
||||
RealmID string |
|
||||
UserID string |
|
||||
Config json.RawMessage |
|
||||
} |
|
||||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|
||||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|
||||
} |
|
||||
log.WithFields(log.Fields{ |
|
||||
"realm_id": body.RealmID, |
|
||||
"user_id": body.UserID, |
|
||||
}).Print("Incoming auth session request") |
|
||||
|
|
||||
if body.UserID == "" || body.RealmID == "" || body.Config == nil { |
|
||||
return nil, &errors.HTTPError{nil, `Must supply a "UserID", a "RealmID" and a "Config"`, 400} |
|
||||
} |
|
||||
|
|
||||
realm, err := h.db.LoadAuthRealm(body.RealmID) |
|
||||
if err != nil { |
|
||||
return nil, &errors.HTTPError{err, "Unknown RealmID", 400} |
|
||||
} |
|
||||
|
|
||||
response := realm.RequestAuthSession(body.UserID, body.Config) |
|
||||
if response == nil { |
|
||||
return nil, &errors.HTTPError{nil, "Failed to request auth session", 500} |
|
||||
} |
|
||||
|
|
||||
metrics.IncrementAuthSession(realm.Type()) |
|
||||
|
|
||||
return response, nil |
|
||||
} |
|
||||
|
|
||||
type removeAuthSessionHandler struct { |
|
||||
db *database.ServiceDB |
|
||||
} |
|
||||
|
|
||||
func (h *removeAuthSessionHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|
||||
if req.Method != "POST" { |
|
||||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|
||||
} |
|
||||
var body struct { |
|
||||
RealmID string |
|
||||
UserID string |
|
||||
} |
|
||||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|
||||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|
||||
} |
|
||||
log.WithFields(log.Fields{ |
|
||||
"realm_id": body.RealmID, |
|
||||
"user_id": body.UserID, |
|
||||
}).Print("Incoming remove auth session request") |
|
||||
|
|
||||
if body.UserID == "" || body.RealmID == "" { |
|
||||
return nil, &errors.HTTPError{nil, `Must supply a "UserID", a "RealmID"`, 400} |
|
||||
} |
|
||||
|
|
||||
_, err := h.db.LoadAuthRealm(body.RealmID) |
|
||||
if err != nil { |
|
||||
return nil, &errors.HTTPError{err, "Unknown RealmID", 400} |
|
||||
} |
|
||||
|
|
||||
if err := h.db.RemoveAuthSession(body.RealmID, body.UserID); err != nil { |
|
||||
return nil, &errors.HTTPError{err, "Failed to remove auth session", 500} |
|
||||
} |
|
||||
|
|
||||
return []byte(`{}`), nil |
|
||||
} |
|
||||
|
|
||||
type realmRedirectHandler struct { |
|
||||
db *database.ServiceDB |
|
||||
} |
|
||||
|
|
||||
func (rh *realmRedirectHandler) handle(w http.ResponseWriter, req *http.Request) { |
|
||||
segments := strings.Split(req.URL.Path, "/") |
|
||||
// last path segment is the base64d realm ID which we will pass the incoming request to
|
|
||||
base64realmID := segments[len(segments)-1] |
|
||||
bytesRealmID, err := base64.RawURLEncoding.DecodeString(base64realmID) |
|
||||
realmID := string(bytesRealmID) |
|
||||
if err != nil { |
|
||||
log.WithError(err).WithField("base64_realm_id", base64realmID).Print( |
|
||||
"Not a b64 encoded string", |
|
||||
) |
|
||||
w.WriteHeader(400) |
|
||||
return |
|
||||
} |
|
||||
|
|
||||
realm, err := rh.db.LoadAuthRealm(realmID) |
|
||||
if err != nil { |
|
||||
log.WithError(err).WithField("realm_id", realmID).Print("Failed to load realm") |
|
||||
w.WriteHeader(404) |
|
||||
return |
|
||||
} |
|
||||
log.WithFields(log.Fields{ |
|
||||
"realm_id": realmID, |
|
||||
}).Print("Incoming realm redirect request") |
|
||||
realm.OnReceiveRedirect(w, req) |
|
||||
} |
|
||||
|
|
||||
type configureAuthRealmHandler struct { |
|
||||
db *database.ServiceDB |
|
||||
} |
|
||||
|
|
||||
func (h *configureAuthRealmHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|
||||
if req.Method != "POST" { |
|
||||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|
||||
} |
|
||||
var body api.ConfigureAuthRealmRequest |
|
||||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|
||||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|
||||
} |
|
||||
|
|
||||
if err := body.Check(); err != nil { |
|
||||
return nil, &errors.HTTPError{err, err.Error(), 400} |
|
||||
} |
|
||||
|
|
||||
realm, err := types.CreateAuthRealm(body.ID, body.Type, body.Config) |
|
||||
if err != nil { |
|
||||
return nil, &errors.HTTPError{err, "Error parsing config JSON", 400} |
|
||||
} |
|
||||
|
|
||||
if err = realm.Register(); err != nil { |
|
||||
return nil, &errors.HTTPError{err, "Error registering auth realm", 400} |
|
||||
} |
|
||||
|
|
||||
oldRealm, err := h.db.StoreAuthRealm(realm) |
|
||||
if err != nil { |
|
||||
return nil, &errors.HTTPError{err, "Error storing realm", 500} |
|
||||
} |
|
||||
|
|
||||
return &struct { |
|
||||
ID string |
|
||||
Type string |
|
||||
OldConfig types.AuthRealm |
|
||||
NewConfig types.AuthRealm |
|
||||
}{body.ID, body.Type, oldRealm, realm}, nil |
|
||||
} |
|
||||
|
|
||||
type webhookHandler struct { |
|
||||
db *database.ServiceDB |
|
||||
clients *clients.Clients |
|
||||
} |
|
||||
|
|
||||
func (wh *webhookHandler) handle(w http.ResponseWriter, req *http.Request) { |
|
||||
log.WithField("path", req.URL.Path).Print("Incoming webhook request") |
|
||||
segments := strings.Split(req.URL.Path, "/") |
|
||||
// last path segment is the service ID which we will pass the incoming request to,
|
|
||||
// but we've base64d it.
|
|
||||
base64srvID := segments[len(segments)-1] |
|
||||
bytesSrvID, err := base64.RawURLEncoding.DecodeString(base64srvID) |
|
||||
if err != nil { |
|
||||
log.WithError(err).WithField("base64_service_id", base64srvID).Print( |
|
||||
"Not a b64 encoded string", |
|
||||
) |
|
||||
w.WriteHeader(400) |
|
||||
return |
|
||||
} |
|
||||
srvID := string(bytesSrvID) |
|
||||
|
|
||||
service, err := wh.db.LoadService(srvID) |
|
||||
if err != nil { |
|
||||
log.WithError(err).WithField("service_id", srvID).Print("Failed to load service") |
|
||||
w.WriteHeader(404) |
|
||||
return |
|
||||
} |
|
||||
cli, err := wh.clients.Client(service.ServiceUserID()) |
|
||||
if err != nil { |
|
||||
log.WithError(err).WithField("user_id", service.ServiceUserID()).Print( |
|
||||
"Failed to retrieve matrix client instance") |
|
||||
w.WriteHeader(500) |
|
||||
return |
|
||||
} |
|
||||
log.WithFields(log.Fields{ |
|
||||
"service_id": service.ServiceID(), |
|
||||
"service_type": service.ServiceType(), |
|
||||
}).Print("Incoming webhook for service") |
|
||||
metrics.IncrementWebhook(service.ServiceType()) |
|
||||
service.OnReceiveWebhook(w, req, cli) |
|
||||
} |
|
||||
|
|
||||
type configureClientHandler struct { |
|
||||
db *database.ServiceDB |
|
||||
clients *clients.Clients |
|
||||
} |
|
||||
|
|
||||
func (s *configureClientHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|
||||
if req.Method != "POST" { |
|
||||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|
||||
} |
|
||||
|
|
||||
var body api.ClientConfig |
|
||||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|
||||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|
||||
} |
|
||||
|
|
||||
if err := body.Check(); err != nil { |
|
||||
return nil, &errors.HTTPError{err, "Error parsing client config", 400} |
|
||||
} |
|
||||
|
|
||||
oldClient, err := s.clients.Update(body) |
|
||||
if err != nil { |
|
||||
return nil, &errors.HTTPError{err, "Error storing token", 500} |
|
||||
} |
|
||||
|
|
||||
return &struct { |
|
||||
OldClient api.ClientConfig |
|
||||
NewClient api.ClientConfig |
|
||||
}{oldClient, body}, nil |
|
||||
} |
|
||||
|
|
||||
type configureServiceHandler struct { |
|
||||
db *database.ServiceDB |
|
||||
clients *clients.Clients |
|
||||
mapMutex sync.Mutex |
|
||||
mutexByServiceID map[string]*sync.Mutex |
|
||||
} |
|
||||
|
|
||||
func newConfigureServiceHandler(db *database.ServiceDB, clients *clients.Clients) *configureServiceHandler { |
|
||||
return &configureServiceHandler{ |
|
||||
db: db, |
|
||||
clients: clients, |
|
||||
mutexByServiceID: make(map[string]*sync.Mutex), |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
func (s *configureServiceHandler) getMutexForServiceID(serviceID string) *sync.Mutex { |
|
||||
s.mapMutex.Lock() |
|
||||
defer s.mapMutex.Unlock() |
|
||||
m := s.mutexByServiceID[serviceID] |
|
||||
if m == nil { |
|
||||
// XXX TODO: There's a memory leak here. The amount of mutexes created is unbounded, as there will be 1 per service which are never deleted.
|
|
||||
// A better solution would be to have a striped hash map with a bounded pool of mutexes. We can't live with a single global mutex because the Register()
|
|
||||
// function this is protecting does many many HTTP requests which can take a long time on bad networks and will head of line block other services.
|
|
||||
m = &sync.Mutex{} |
|
||||
s.mutexByServiceID[serviceID] = m |
|
||||
} |
|
||||
return m |
|
||||
} |
|
||||
|
|
||||
func (s *configureServiceHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|
||||
if req.Method != "POST" { |
|
||||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|
||||
} |
|
||||
|
|
||||
service, httpErr := s.createService(req) |
|
||||
if httpErr != nil { |
|
||||
return nil, httpErr |
|
||||
} |
|
||||
log.WithFields(log.Fields{ |
|
||||
"service_id": service.ServiceID(), |
|
||||
"service_type": service.ServiceType(), |
|
||||
"service_user_id": service.ServiceUserID(), |
|
||||
}).Print("Incoming configure service request") |
|
||||
|
|
||||
// Have mutexes around each service to queue up multiple requests for the same service ID
|
|
||||
mut := s.getMutexForServiceID(service.ServiceID()) |
|
||||
mut.Lock() |
|
||||
defer mut.Unlock() |
|
||||
|
|
||||
old, err := s.db.LoadService(service.ServiceID()) |
|
||||
if err != nil && err != sql.ErrNoRows { |
|
||||
return nil, &errors.HTTPError{err, "Error loading old service", 500} |
|
||||
} |
|
||||
|
|
||||
client, err := s.clients.Client(service.ServiceUserID()) |
|
||||
if err != nil { |
|
||||
return nil, &errors.HTTPError{err, "Unknown matrix client", 400} |
|
||||
} |
|
||||
|
|
||||
if err = service.Register(old, client); err != nil { |
|
||||
return nil, &errors.HTTPError{err, "Failed to register service: " + err.Error(), 500} |
|
||||
} |
|
||||
|
|
||||
oldService, err := s.db.StoreService(service) |
|
||||
if err != nil { |
|
||||
return nil, &errors.HTTPError{err, "Error storing service", 500} |
|
||||
} |
|
||||
|
|
||||
// Start any polling NOW because they may decide to stop it in PostRegister, and we want to make
|
|
||||
// sure we'll actually stop.
|
|
||||
if _, ok := service.(types.Poller); ok { |
|
||||
if err := polling.StartPolling(service); err != nil { |
|
||||
log.WithFields(log.Fields{ |
|
||||
"service_id": service.ServiceID(), |
|
||||
log.ErrorKey: err, |
|
||||
}).Error("Failed to start poll loop.") |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
service.PostRegister(old) |
|
||||
metrics.IncrementConfigureService(service.ServiceType()) |
|
||||
|
|
||||
return &struct { |
|
||||
ID string |
|
||||
Type string |
|
||||
OldConfig types.Service |
|
||||
NewConfig types.Service |
|
||||
}{service.ServiceID(), service.ServiceType(), oldService, service}, nil |
|
||||
} |
|
||||
|
|
||||
func (s *configureServiceHandler) createService(req *http.Request) (types.Service, *errors.HTTPError) { |
|
||||
var body api.ConfigureServiceRequest |
|
||||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|
||||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|
||||
} |
|
||||
|
|
||||
if err := body.Check(); err != nil { |
|
||||
return nil, &errors.HTTPError{err, err.Error(), 400} |
|
||||
} |
|
||||
|
|
||||
service, err := types.CreateService(body.ID, body.Type, body.UserID, body.Config) |
|
||||
if err != nil { |
|
||||
return nil, &errors.HTTPError{err, "Error parsing config JSON", 400} |
|
||||
} |
|
||||
return service, nil |
|
||||
} |
|
||||
|
|
||||
type getServiceHandler struct { |
|
||||
db *database.ServiceDB |
|
||||
} |
|
||||
|
|
||||
func (h *getServiceHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|
||||
if req.Method != "POST" { |
|
||||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|
||||
} |
|
||||
var body struct { |
|
||||
ID string |
|
||||
} |
|
||||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|
||||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|
||||
} |
|
||||
|
|
||||
if body.ID == "" { |
|
||||
return nil, &errors.HTTPError{nil, `Must supply a "ID"`, 400} |
|
||||
} |
|
||||
|
|
||||
srv, err := h.db.LoadService(body.ID) |
|
||||
if err != nil { |
|
||||
if err == sql.ErrNoRows { |
|
||||
return nil, &errors.HTTPError{err, `Service not found`, 404} |
|
||||
} |
|
||||
return nil, &errors.HTTPError{err, `Failed to load service`, 500} |
|
||||
} |
|
||||
|
|
||||
return &struct { |
|
||||
ID string |
|
||||
Type string |
|
||||
Config types.Service |
|
||||
}{srv.ServiceID(), srv.ServiceType(), srv}, nil |
|
||||
} |
|
||||
|
|
||||
type getSessionHandler struct { |
|
||||
db *database.ServiceDB |
|
||||
} |
|
||||
|
|
||||
func (h *getSessionHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) { |
|
||||
if req.Method != "POST" { |
|
||||
return nil, &errors.HTTPError{nil, "Unsupported Method", 405} |
|
||||
} |
|
||||
var body struct { |
|
||||
RealmID string |
|
||||
UserID string |
|
||||
} |
|
||||
if err := json.NewDecoder(req.Body).Decode(&body); err != nil { |
|
||||
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400} |
|
||||
} |
|
||||
|
|
||||
if body.RealmID == "" || body.UserID == "" { |
|
||||
return nil, &errors.HTTPError{nil, `Must supply a "RealmID" and "UserID"`, 400} |
|
||||
} |
|
||||
|
|
||||
session, err := h.db.LoadAuthSessionByUser(body.RealmID, body.UserID) |
|
||||
if err != nil && err != sql.ErrNoRows { |
|
||||
return nil, &errors.HTTPError{err, `Failed to load session`, 500} |
|
||||
} |
|
||||
if err == sql.ErrNoRows { |
|
||||
return &struct { |
|
||||
Authenticated bool |
|
||||
}{false}, nil |
|
||||
} |
|
||||
|
|
||||
return &struct { |
|
||||
ID string |
|
||||
Authenticated bool |
|
||||
Info interface{} |
|
||||
}{session.ID(), session.Authenticated(), session.Info()}, nil |
|
||||
} |
|
Write
Preview
Loading…
Cancel
Save
Reference in new issue