You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

242 lines
7.0 KiB

package handlers
import (
"database/sql"
"encoding/json"
"fmt"
"net/http"
"sync"
"github.com/matrix-org/go-neb/api"
"github.com/matrix-org/go-neb/clients"
"github.com/matrix-org/go-neb/database"
"github.com/matrix-org/go-neb/matrix"
"github.com/matrix-org/go-neb/metrics"
"github.com/matrix-org/go-neb/polling"
"github.com/matrix-org/go-neb/types"
"github.com/matrix-org/util"
log "github.com/sirupsen/logrus"
"maunium.net/go/mautrix"
)
// ConfigureService is the HTTP handler for /admin/configureService requests.
//
// Build it with NewConfigureService: the per-service mutex table must be a
// non-nil map before the first request is handled.
type ConfigureService struct {
	// db is where the previous service is loaded from and the new one is stored.
	db *database.ServiceDB
	// clients resolves a service's user ID to its Matrix client.
	clients *clients.Clients
	// mapMutex guards every access to mutexByServiceID.
	mapMutex sync.Mutex
	// mutexByServiceID holds one lazily created mutex per service ID; entries
	// are never removed.
	mutexByServiceID map[string]*sync.Mutex
}
// NewConfigureService builds a ConfigureService handler that uses db for
// service storage and clients for Matrix client lookup. The per-service mutex
// table starts out empty.
func NewConfigureService(db *database.ServiceDB, clients *clients.Clients) *ConfigureService {
	handler := new(ConfigureService)
	handler.db = db
	handler.clients = clients
	handler.mutexByServiceID = map[string]*sync.Mutex{}
	return handler
}
// getMutexForServiceID returns the mutex dedicated to serviceID, creating and
// recording a new one the first time that ID is seen. The table lookup and
// insertion are serialised by s.mapMutex; the returned mutex is NOT locked.
func (s *ConfigureService) getMutexForServiceID(serviceID string) *sync.Mutex {
	s.mapMutex.Lock()
	defer s.mapMutex.Unlock()

	if existing := s.mutexByServiceID[serviceID]; existing != nil {
		return existing
	}

	// XXX TODO: There's a memory leak here. The number of mutexes is unbounded:
	// one is created per service ID and none are ever deleted. A better solution
	// would be a striped hash map with a bounded pool of mutexes. A single global
	// mutex is not an option because the Register() call this protects makes many
	// HTTP requests, which can take a long time on bad networks and would
	// head-of-line block other services.
	created := new(sync.Mutex)
	s.mutexByServiceID[serviceID] = created
	return created
}
// OnIncomingRequest handles POST requests to /admin/configureService.
//
// The request body MUST decode as an "api.ConfigureServiceRequest". The handler
// loads any service already stored under the same ID, validates the Matrix
// client, registers the new service, stores it, starts polling when the service
// implements types.Poller, and finally calls PostRegister. Requests for the
// same service ID are processed one at a time.
//
// Request:
//
//	POST /admin/configureService
//	{
//	    "ID": "my_service_id",
//	    "Type": "service-type",
//	    "UserID": "@my_bot:localhost",
//	    "Config": {
//	        // service-specific config information
//	    }
//	}
//
// Response:
//
//	HTTP/1.1 200 OK
//	{
//	    "ID": "my_service_id",
//	    "Type": "service-type",
//	    "OldConfig": {
//	        // old service-specific config information
//	    },
//	    "NewConfig": {
//	        // new service-specific config information
//	    }
//	}
//
// Errors: 405 for any method other than POST; 400 when the request cannot be
// turned into a service, the Matrix client is unknown, or the client fails
// checkClientForService; 500 when loading, registering or storing fails.
func (s *ConfigureService) OnIncomingRequest(req *http.Request) util.JSONResponse {
	if req.Method != http.MethodPost {
		return util.MessageResponse(405, "Unsupported Method")
	}

	service, errResp := s.createService(req)
	if errResp != nil {
		return *errResp
	}

	logger := util.GetLogger(req.Context())
	requestFields := log.Fields{
		"service_id":      service.ServiceID(),
		"service_type":    service.ServiceType(),
		"service_user_id": service.ServiceUserID(),
	}
	logger.WithFields(requestFields).Print("Incoming configure service request")

	// Serialise concurrent requests that target the same service ID.
	serviceLock := s.getMutexForServiceID(service.ServiceID())
	serviceLock.Lock()
	defer serviceLock.Unlock()

	previous, err := s.db.LoadService(service.ServiceID())
	switch err {
	case nil, sql.ErrNoRows:
		// Having no stored row is tolerated; previous is used exactly as returned.
	default:
		logger.WithError(err).Error("Failed to LoadService")
		return util.MessageResponse(500, "Error loading old service")
	}

	client, err := s.clients.Client(service.ServiceUserID())
	if err != nil {
		return util.MessageResponse(400, "Unknown matrix client")
	}
	if checkErr := checkClientForService(service, client); checkErr != nil {
		return util.MessageResponse(400, checkErr.Error())
	}
	if regErr := service.Register(previous, client); regErr != nil {
		return util.MessageResponse(500, "Failed to register service: "+regErr.Error())
	}

	replaced, err := s.db.StoreService(service)
	if err != nil {
		logger.WithError(err).Error("Failed to StoreService")
		return util.MessageResponse(500, "Error storing service")
	}

	// Polling has to be started BEFORE PostRegister: the service may decide to
	// stop polling there, and that stop must act on an already running poller.
	// A failure to start is logged but does not fail the request.
	if _, isPoller := service.(types.Poller); isPoller {
		if pollErr := polling.StartPolling(service); pollErr != nil {
			logger.WithFields(log.Fields{
				"service_id": service.ServiceID(),
				log.ErrorKey: pollErr,
			}).Error("Failed to start poll loop.")
		}
	}

	service.PostRegister(previous)
	metrics.IncrementConfigureService(service.ServiceType())

	// configureResponse is the JSON body of a successful reply.
	type configureResponse struct {
		ID        string
		Type      string
		OldConfig types.Service
		NewConfig types.Service
	}
	return util.JSONResponse{
		Code: 200,
		JSON: configureResponse{
			ID:        service.ServiceID(),
			Type:      service.ServiceType(),
			OldConfig: replaced,
			NewConfig: service,
		},
	}
}
// createService turns the request body into a types.Service.
//
// The body is decoded as an api.ConfigureServiceRequest, validated with its
// Check method, and then passed to types.CreateService. On success the service
// is returned with a nil response. On any failure the service is nil and the
// returned response is a 400 whose message identifies the failing step.
func (s *ConfigureService) createService(req *http.Request) (types.Service, *util.JSONResponse) {
	// badRequest builds the (nil, 400) result shared by every failure path.
	badRequest := func(msg string) (types.Service, *util.JSONResponse) {
		res := util.MessageResponse(400, msg)
		return nil, &res
	}

	var body api.ConfigureServiceRequest
	if decodeErr := json.NewDecoder(req.Body).Decode(&body); decodeErr != nil {
		return badRequest("Error parsing request JSON")
	}
	if checkErr := body.Check(); checkErr != nil {
		return badRequest(checkErr.Error())
	}

	service, createErr := types.CreateService(body.ID, body.Type, body.UserID, body.Config)
	if createErr != nil {
		return badRequest("Error parsing config JSON")
	}
	return service, nil
}
// GetService is the HTTP handler for /admin/getService requests.
type GetService struct {
	// Db is the service database in which requested service IDs are looked up.
	Db *database.ServiceDB
}
// OnIncomingRequest handles POST requests to /admin/getService.
//
// The request body MUST be a JSON object with an "ID" key naming the service
// to fetch.
//
// Request:
//
//	POST /admin/getService
//	{
//	    "ID": "my_service_id"
//	}
//
// Response:
//
//	HTTP/1.1 200 OK
//	{
//	    "ID": "my_service_id",
//	    "Type": "github",
//	    "Config": {
//	        // service-specific config information
//	    }
//	}
//
// Errors: 405 for any method other than POST; 400 when the body is not valid
// JSON or "ID" is empty; 404 when no service is stored under that ID; 500 for
// any other load failure.
func (h *GetService) OnIncomingRequest(req *http.Request) util.JSONResponse {
	if req.Method != http.MethodPost {
		return util.MessageResponse(405, "Unsupported Method")
	}

	var query struct {
		ID string
	}
	if decodeErr := json.NewDecoder(req.Body).Decode(&query); decodeErr != nil {
		return util.MessageResponse(400, "Error parsing request JSON")
	}
	if len(query.ID) == 0 {
		return util.MessageResponse(400, `Must supply a "ID"`)
	}

	found, err := h.Db.LoadService(query.ID)
	switch err {
	case nil:
		// Loaded; continue below and build the success response.
	case sql.ErrNoRows:
		return util.MessageResponse(404, `Service not found`)
	default:
		util.GetLogger(req.Context()).WithError(err).Error("Failed to LoadService")
		return util.MessageResponse(500, `Failed to load service`)
	}

	// getServiceResponse is the JSON body of a successful reply.
	type getServiceResponse struct {
		ID     string
		Type   string
		Config types.Service
	}
	return util.JSONResponse{
		Code: 200,
		JSON: getServiceResponse{
			ID:     found.ServiceID(),
			Type:   found.ServiceType(),
			Config: found,
		},
	}
}
// checkClientForService verifies that client is suitable for running service.
//
// If the service has any commands or expansions then the service user ID MUST
// belong to a syncing client, or else the service will never receive the
// incoming command/expansion. A service with neither is accepted with any
// client.
//
// It returns nil when the client is acceptable. It returns an error when a
// syncing client is required but the client is not configured to sync, or when
// the client's store is not a usable *matrix.NEBStore so its sync setting
// cannot be read (reported as an error instead of panicking on the type
// assertion).
func checkClientForService(service types.Service, client *mautrix.Client) error {
	cmds := service.Commands(client)
	expans := service.Expansions(client)
	if len(cmds) == 0 && len(expans) == 0 {
		// Nothing needs to be delivered via sync, so any client will do.
		return nil
	}

	// The sync flag lives on the NEB-specific store. Use a checked assertion so
	// that a client backed by a different store yields an error, not a panic.
	nebStore, ok := client.Store.(*matrix.NEBStore)
	if !ok || nebStore == nil {
		return fmt.Errorf("client store is not a usable *matrix.NEBStore (got %T)", client.Store)
	}
	if !nebStore.ClientConfig.Sync {
		return fmt.Errorf(
			"Service type '%s' requires a syncing client", service.ServiceType(),
		)
	}
	return nil
}