mirror of https://github.com/matrix-org/go-neb.git
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
376 lines
11 KiB
376 lines
11 KiB
package main
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
log "github.com/Sirupsen/logrus"
|
|
"github.com/matrix-org/go-neb/clients"
|
|
"github.com/matrix-org/go-neb/database"
|
|
"github.com/matrix-org/go-neb/errors"
|
|
"github.com/matrix-org/go-neb/types"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
)
|
|
|
|
type heartbeatHandler struct{}
|
|
|
|
func (*heartbeatHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
|
|
return &struct{}{}, nil
|
|
}
|
|
|
|
type requestAuthSessionHandler struct {
|
|
db *database.ServiceDB
|
|
}
|
|
|
|
func (h *requestAuthSessionHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
|
|
if req.Method != "POST" {
|
|
return nil, &errors.HTTPError{nil, "Unsupported Method", 405}
|
|
}
|
|
var body struct {
|
|
RealmID string
|
|
UserID string
|
|
Config json.RawMessage
|
|
}
|
|
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
|
|
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400}
|
|
}
|
|
log.WithFields(log.Fields{
|
|
"realm_id": body.RealmID,
|
|
"user_id": body.UserID,
|
|
}).Print("Incoming auth session request")
|
|
|
|
if body.UserID == "" || body.RealmID == "" || body.Config == nil {
|
|
return nil, &errors.HTTPError{nil, `Must supply a "UserID", a "RealmID" and a "Config"`, 400}
|
|
}
|
|
|
|
realm, err := h.db.LoadAuthRealm(body.RealmID)
|
|
if err != nil {
|
|
return nil, &errors.HTTPError{err, "Unknown RealmID", 400}
|
|
}
|
|
|
|
response := realm.RequestAuthSession(body.UserID, body.Config)
|
|
if response == nil {
|
|
return nil, &errors.HTTPError{nil, "Failed to request auth session", 500}
|
|
}
|
|
|
|
return response, nil
|
|
}
|
|
|
|
type realmRedirectHandler struct {
|
|
db *database.ServiceDB
|
|
}
|
|
|
|
func (rh *realmRedirectHandler) handle(w http.ResponseWriter, req *http.Request) {
|
|
segments := strings.Split(req.URL.Path, "/")
|
|
// last path segment is the base64d realm ID which we will pass the incoming request to
|
|
base64realmID := segments[len(segments)-1]
|
|
bytesRealmID, err := base64.RawURLEncoding.DecodeString(base64realmID)
|
|
realmID := string(bytesRealmID)
|
|
if err != nil {
|
|
log.WithError(err).WithField("base64_realm_id", base64realmID).Print(
|
|
"Not a b64 encoded string",
|
|
)
|
|
w.WriteHeader(400)
|
|
return
|
|
}
|
|
|
|
realm, err := rh.db.LoadAuthRealm(realmID)
|
|
if err != nil {
|
|
log.WithError(err).WithField("realm_id", realmID).Print("Failed to load realm")
|
|
w.WriteHeader(404)
|
|
return
|
|
}
|
|
log.WithFields(log.Fields{
|
|
"realm_id": realmID,
|
|
}).Print("Incoming realm redirect request")
|
|
realm.OnReceiveRedirect(w, req)
|
|
}
|
|
|
|
type configureAuthRealmHandler struct {
|
|
db *database.ServiceDB
|
|
}
|
|
|
|
func (h *configureAuthRealmHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
|
|
if req.Method != "POST" {
|
|
return nil, &errors.HTTPError{nil, "Unsupported Method", 405}
|
|
}
|
|
var body struct {
|
|
ID string
|
|
Type string
|
|
Config json.RawMessage
|
|
}
|
|
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
|
|
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400}
|
|
}
|
|
|
|
if body.ID == "" || body.Type == "" || body.Config == nil {
|
|
return nil, &errors.HTTPError{nil, `Must supply a "ID", a "Type" and a "Config"`, 400}
|
|
}
|
|
|
|
realm, err := types.CreateAuthRealm(body.ID, body.Type, body.Config)
|
|
if err != nil {
|
|
return nil, &errors.HTTPError{err, "Error parsing config JSON", 400}
|
|
}
|
|
|
|
if err = realm.Register(); err != nil {
|
|
return nil, &errors.HTTPError{err, "Error registering auth realm", 400}
|
|
}
|
|
|
|
oldRealm, err := h.db.StoreAuthRealm(realm)
|
|
if err != nil {
|
|
return nil, &errors.HTTPError{err, "Error storing realm", 500}
|
|
}
|
|
|
|
return &struct {
|
|
ID string
|
|
Type string
|
|
OldConfig types.AuthRealm
|
|
NewConfig types.AuthRealm
|
|
}{body.ID, body.Type, oldRealm, realm}, nil
|
|
}
|
|
|
|
type webhookHandler struct {
|
|
db *database.ServiceDB
|
|
clients *clients.Clients
|
|
}
|
|
|
|
func (wh *webhookHandler) handle(w http.ResponseWriter, req *http.Request) {
|
|
log.WithField("path", req.URL.Path).Print("Incoming webhook request")
|
|
segments := strings.Split(req.URL.Path, "/")
|
|
// last path segment is the service ID which we will pass the incoming request to,
|
|
// but we've base64d it.
|
|
base64srvID := segments[len(segments)-1]
|
|
bytesSrvID, err := base64.RawURLEncoding.DecodeString(base64srvID)
|
|
srvID := string(bytesSrvID)
|
|
if err != nil {
|
|
log.WithError(err).WithField("base64_service_id", base64srvID).Print(
|
|
"Not a b64 encoded string",
|
|
)
|
|
w.WriteHeader(400)
|
|
return
|
|
}
|
|
|
|
service, err := wh.db.LoadService(srvID)
|
|
if err != nil {
|
|
log.WithError(err).WithField("service_id", srvID).Print("Failed to load service")
|
|
w.WriteHeader(404)
|
|
return
|
|
}
|
|
cli, err := wh.clients.Client(service.ServiceUserID())
|
|
if err != nil {
|
|
log.WithError(err).WithField("user_id", service.ServiceUserID()).Print(
|
|
"Failed to retrieve matrix client instance")
|
|
w.WriteHeader(500)
|
|
return
|
|
}
|
|
log.WithFields(log.Fields{
|
|
"service_id": service.ServiceID(),
|
|
"service_typ": service.ServiceType(),
|
|
}).Print("Incoming webhook for service")
|
|
service.OnReceiveWebhook(w, req, cli)
|
|
}
|
|
|
|
type configureClientHandler struct {
|
|
db *database.ServiceDB
|
|
clients *clients.Clients
|
|
}
|
|
|
|
func (s *configureClientHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
|
|
if req.Method != "POST" {
|
|
return nil, &errors.HTTPError{nil, "Unsupported Method", 405}
|
|
}
|
|
|
|
var body types.ClientConfig
|
|
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
|
|
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400}
|
|
}
|
|
|
|
if err := body.Check(); err != nil {
|
|
return nil, &errors.HTTPError{err, "Error parsing client config", 400}
|
|
}
|
|
|
|
oldClient, err := s.clients.Update(body)
|
|
if err != nil {
|
|
return nil, &errors.HTTPError{err, "Error storing token", 500}
|
|
}
|
|
|
|
return &struct {
|
|
OldClient types.ClientConfig
|
|
NewClient types.ClientConfig
|
|
}{oldClient, body}, nil
|
|
}
|
|
|
|
type configureServiceHandler struct {
|
|
db *database.ServiceDB
|
|
clients *clients.Clients
|
|
mapMutex sync.Mutex
|
|
mutexByServiceID map[string]*sync.Mutex
|
|
}
|
|
|
|
func newConfigureServiceHandler(db *database.ServiceDB, clients *clients.Clients) *configureServiceHandler {
|
|
return &configureServiceHandler{
|
|
db: db,
|
|
clients: clients,
|
|
mutexByServiceID: make(map[string]*sync.Mutex),
|
|
}
|
|
}
|
|
|
|
func (s *configureServiceHandler) getMutexForServiceID(serviceID string) *sync.Mutex {
|
|
s.mapMutex.Lock()
|
|
defer s.mapMutex.Unlock()
|
|
m := s.mutexByServiceID[serviceID]
|
|
if m == nil {
|
|
// XXX TODO: There's a memory leak here. The amount of mutexes created is unbounded, as there will be 1 per service which are never deleted.
|
|
// A better solution would be to have a striped hash map with a bounded pool of mutexes. We can't live with a single global mutex because the Register()
|
|
// function this is protecting does many many HTTP requests which can take a long time on bad networks and will head of line block other services.
|
|
m = &sync.Mutex{}
|
|
s.mutexByServiceID[serviceID] = m
|
|
}
|
|
return m
|
|
}
|
|
|
|
func (s *configureServiceHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
|
|
if req.Method != "POST" {
|
|
return nil, &errors.HTTPError{nil, "Unsupported Method", 405}
|
|
}
|
|
|
|
service, httpErr := s.createService(req)
|
|
if httpErr != nil {
|
|
return nil, httpErr
|
|
}
|
|
log.WithFields(log.Fields{
|
|
"service_id": service.ServiceID(),
|
|
"service_type": service.ServiceType(),
|
|
"service_user_id": service.ServiceUserID(),
|
|
}).Print("Incoming configure service request")
|
|
|
|
// Have mutexes around each service to queue up multiple requests for the same service ID
|
|
mut := s.getMutexForServiceID(service.ServiceID())
|
|
mut.Lock()
|
|
defer mut.Unlock()
|
|
|
|
old, err := s.db.LoadService(service.ServiceID())
|
|
if err != nil && err != sql.ErrNoRows {
|
|
return nil, &errors.HTTPError{err, "Error loading old service", 500}
|
|
}
|
|
|
|
client, err := s.clients.Client(service.ServiceUserID())
|
|
if err != nil {
|
|
return nil, &errors.HTTPError{err, "Unknown matrix client", 400}
|
|
}
|
|
|
|
if err = service.Register(old, client); err != nil {
|
|
return nil, &errors.HTTPError{err, "Failed to register service: " + err.Error(), 500}
|
|
}
|
|
|
|
oldService, err := s.db.StoreService(service)
|
|
if err != nil {
|
|
return nil, &errors.HTTPError{err, "Error storing service", 500}
|
|
}
|
|
|
|
service.PostRegister(old)
|
|
|
|
return &struct {
|
|
ID string
|
|
Type string
|
|
OldConfig types.Service
|
|
NewConfig types.Service
|
|
}{service.ServiceID(), service.ServiceType(), oldService, service}, nil
|
|
}
|
|
|
|
func (s *configureServiceHandler) createService(req *http.Request) (types.Service, *errors.HTTPError) {
|
|
var body struct {
|
|
ID string
|
|
Type string
|
|
UserID string
|
|
Config json.RawMessage
|
|
}
|
|
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
|
|
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400}
|
|
}
|
|
|
|
if body.ID == "" || body.Type == "" || body.UserID == "" || body.Config == nil {
|
|
return nil, &errors.HTTPError{
|
|
nil, `Must supply an "ID", a "Type", a "UserID" and a "Config"`, 400,
|
|
}
|
|
}
|
|
|
|
service, err := types.CreateService(body.ID, body.Type, body.UserID, body.Config)
|
|
if err != nil {
|
|
return nil, &errors.HTTPError{err, "Error parsing config JSON", 400}
|
|
}
|
|
return service, nil
|
|
}
|
|
|
|
type getServiceHandler struct {
|
|
db *database.ServiceDB
|
|
}
|
|
|
|
func (h *getServiceHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
|
|
if req.Method != "POST" {
|
|
return nil, &errors.HTTPError{nil, "Unsupported Method", 405}
|
|
}
|
|
var body struct {
|
|
ID string
|
|
}
|
|
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
|
|
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400}
|
|
}
|
|
|
|
if body.ID == "" {
|
|
return nil, &errors.HTTPError{nil, `Must supply a "ID"`, 400}
|
|
}
|
|
|
|
srv, err := h.db.LoadService(body.ID)
|
|
if err != nil {
|
|
if err == sql.ErrNoRows {
|
|
return nil, &errors.HTTPError{err, `Service not found`, 404}
|
|
}
|
|
return nil, &errors.HTTPError{err, `Failed to load service`, 500}
|
|
}
|
|
|
|
return &struct {
|
|
ID string
|
|
Type string
|
|
Config types.Service
|
|
}{srv.ServiceID(), srv.ServiceType(), srv}, nil
|
|
}
|
|
|
|
type getSessionHandler struct {
|
|
db *database.ServiceDB
|
|
}
|
|
|
|
func (h *getSessionHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
|
|
if req.Method != "POST" {
|
|
return nil, &errors.HTTPError{nil, "Unsupported Method", 405}
|
|
}
|
|
var body struct {
|
|
RealmID string
|
|
UserID string
|
|
}
|
|
if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
|
|
return nil, &errors.HTTPError{err, "Error parsing request JSON", 400}
|
|
}
|
|
|
|
if body.RealmID == "" || body.UserID == "" {
|
|
return nil, &errors.HTTPError{nil, `Must supply a "RealmID" and "UserID"`, 400}
|
|
}
|
|
|
|
session, err := h.db.LoadAuthSessionByUser(body.RealmID, body.UserID)
|
|
if err != nil && err != sql.ErrNoRows {
|
|
return nil, &errors.HTTPError{err, `Failed to load session`, 500}
|
|
}
|
|
if err == sql.ErrNoRows {
|
|
return &struct {
|
|
Authenticated bool
|
|
}{false}, nil
|
|
}
|
|
|
|
return &struct {
|
|
ID string
|
|
Authenticated bool
|
|
Info interface{}
|
|
}{session.ID(), session.Authenticated(), session.Info()}, nil
|
|
}
|