You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

241 lines
7.0 KiB

8 years ago
  1. package handlers
  2. import (
  3. "database/sql"
  4. "encoding/json"
  5. "fmt"
  6. "net/http"
  7. "sync"
  8. "github.com/matrix-org/go-neb/api"
  9. "github.com/matrix-org/go-neb/clients"
  10. "github.com/matrix-org/go-neb/database"
  11. "github.com/matrix-org/go-neb/matrix"
  12. "github.com/matrix-org/go-neb/metrics"
  13. "github.com/matrix-org/go-neb/polling"
  14. "github.com/matrix-org/go-neb/types"
  15. "github.com/matrix-org/util"
  16. log "github.com/sirupsen/logrus"
  17. )
  18. // ConfigureService represents an HTTP handler which can process /admin/configureService requests.
  19. type ConfigureService struct {
  20. db *database.ServiceDB
  21. clients *clients.Clients
  22. mapMutex sync.Mutex
  23. mutexByServiceID map[string]*sync.Mutex
  24. }
  25. // NewConfigureService creates a new ConfigureService handler
  26. func NewConfigureService(db *database.ServiceDB, clients *clients.Clients) *ConfigureService {
  27. return &ConfigureService{
  28. db: db,
  29. clients: clients,
  30. mutexByServiceID: make(map[string]*sync.Mutex),
  31. }
  32. }
  33. func (s *ConfigureService) getMutexForServiceID(serviceID string) *sync.Mutex {
  34. s.mapMutex.Lock()
  35. defer s.mapMutex.Unlock()
  36. m := s.mutexByServiceID[serviceID]
  37. if m == nil {
  38. // XXX TODO: There's a memory leak here. The amount of mutexes created is unbounded, as there will be 1 per service which are never deleted.
  39. // A better solution would be to have a striped hash map with a bounded pool of mutexes. We can't live with a single global mutex because the Register()
  40. // function this is protecting does many many HTTP requests which can take a long time on bad networks and will head of line block other services.
  41. m = &sync.Mutex{}
  42. s.mutexByServiceID[serviceID] = m
  43. }
  44. return m
  45. }
  46. // OnIncomingRequest handles POST requests to /admin/configureService.
  47. //
  48. // The request body MUST be of type "api.ConfigureServiceRequest".
  49. //
  50. // Request:
  51. // POST /admin/configureService
  52. // {
  53. // "ID": "my_service_id",
  54. // "Type": "service-type",
  55. // "UserID": "@my_bot:localhost",
  56. // "Config": {
  57. // // service-specific config information
  58. // }
  59. // }
  60. // Response:
  61. // HTTP/1.1 200 OK
  62. // {
  63. // "ID": "my_service_id",
  64. // "Type": "service-type",
  65. // "OldConfig": {
  66. // // old service-specific config information
  67. // },
  68. // "NewConfig": {
  69. // // new service-specific config information
  70. // },
  71. // }
  72. func (s *ConfigureService) OnIncomingRequest(req *http.Request) util.JSONResponse {
  73. if req.Method != "POST" {
  74. return util.MessageResponse(405, "Unsupported Method")
  75. }
  76. service, httpErr := s.createService(req)
  77. if httpErr != nil {
  78. return *httpErr
  79. }
  80. logger := util.GetLogger(req.Context())
  81. logger.WithFields(log.Fields{
  82. "service_id": service.ServiceID(),
  83. "service_type": service.ServiceType(),
  84. "service_user_id": service.ServiceUserID(),
  85. }).Print("Incoming configure service request")
  86. // Have mutexes around each service to queue up multiple requests for the same service ID
  87. mut := s.getMutexForServiceID(service.ServiceID())
  88. mut.Lock()
  89. defer mut.Unlock()
  90. old, err := s.db.LoadService(service.ServiceID())
  91. if err != nil && err != sql.ErrNoRows {
  92. logger.WithError(err).Error("Failed to LoadService")
  93. return util.MessageResponse(500, "Error loading old service")
  94. }
  95. client, err := s.clients.Client(service.ServiceUserID())
  96. if err != nil {
  97. return util.MessageResponse(400, "Unknown matrix client")
  98. }
  99. if err := checkClientForService(service, client); err != nil {
  100. return util.MessageResponse(400, err.Error())
  101. }
  102. if err = service.Register(old, client); err != nil {
  103. return util.MessageResponse(500, "Failed to register service: "+err.Error())
  104. }
  105. oldService, err := s.db.StoreService(service)
  106. if err != nil {
  107. logger.WithError(err).Error("Failed to StoreService")
  108. return util.MessageResponse(500, "Error storing service")
  109. }
  110. // Start any polling NOW because they may decide to stop it in PostRegister, and we want to make
  111. // sure we'll actually stop.
  112. if _, ok := service.(types.Poller); ok {
  113. if err := polling.StartPolling(service); err != nil {
  114. logger.WithFields(log.Fields{
  115. "service_id": service.ServiceID(),
  116. log.ErrorKey: err,
  117. }).Error("Failed to start poll loop.")
  118. }
  119. }
  120. service.PostRegister(old)
  121. metrics.IncrementConfigureService(service.ServiceType())
  122. return util.JSONResponse{
  123. Code: 200,
  124. JSON: struct {
  125. ID string
  126. Type string
  127. OldConfig types.Service
  128. NewConfig types.Service
  129. }{service.ServiceID(), service.ServiceType(), oldService, service},
  130. }
  131. }
  132. func (s *ConfigureService) createService(req *http.Request) (types.Service, *util.JSONResponse) {
  133. var body api.ConfigureServiceRequest
  134. if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
  135. res := util.MessageResponse(400, "Error parsing request JSON")
  136. return nil, &res
  137. }
  138. if err := body.Check(); err != nil {
  139. res := util.MessageResponse(400, err.Error())
  140. return nil, &res
  141. }
  142. service, err := types.CreateService(body.ID, body.Type, body.UserID, body.Config)
  143. if err != nil {
  144. res := util.MessageResponse(400, "Error parsing config JSON")
  145. return nil, &res
  146. }
  147. return service, nil
  148. }
  149. // GetService represents an HTTP handler which can process /admin/getService requests.
  150. type GetService struct {
  151. DB *database.ServiceDB
  152. }
  153. // OnIncomingRequest handles POST requests to /admin/getService.
  154. //
  155. // The request body MUST be a JSON body which has an "ID" key which represents
  156. // the service ID to get.
  157. //
  158. // Request:
  159. // POST /admin/getService
  160. // {
  161. // "ID": "my_service_id"
  162. // }
  163. // Response:
  164. // HTTP/1.1 200 OK
  165. // {
  166. // "ID": "my_service_id",
  167. // "Type": "github",
  168. // "Config": {
  169. // // service-specific config information
  170. // }
  171. // }
  172. func (h *GetService) OnIncomingRequest(req *http.Request) util.JSONResponse {
  173. if req.Method != "POST" {
  174. return util.MessageResponse(405, "Unsupported Method")
  175. }
  176. var body struct {
  177. ID string
  178. }
  179. if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
  180. return util.MessageResponse(400, "Error parsing request JSON")
  181. }
  182. if body.ID == "" {
  183. return util.MessageResponse(400, `Must supply a "ID"`)
  184. }
  185. srv, err := h.DB.LoadService(body.ID)
  186. if err != nil {
  187. if err == sql.ErrNoRows {
  188. return util.MessageResponse(404, `Service not found`)
  189. }
  190. util.GetLogger(req.Context()).WithError(err).Error("Failed to LoadService")
  191. return util.MessageResponse(500, `Failed to load service`)
  192. }
  193. return util.JSONResponse{
  194. Code: 200,
  195. JSON: struct {
  196. ID string
  197. Type string
  198. Config types.Service
  199. }{srv.ServiceID(), srv.ServiceType(), srv},
  200. }
  201. }
  202. func checkClientForService(service types.Service, client *clients.BotClient) error {
  203. // If there are any commands or expansions for this Service then the service user ID
  204. // MUST be a syncing client or else the Service will never get the incoming command/expansion!
  205. cmds := service.Commands(client)
  206. expans := service.Expansions(client)
  207. if len(cmds) > 0 || len(expans) > 0 {
  208. nebStore := client.Store.(*matrix.NEBStore)
  209. if !nebStore.ClientConfig.Sync {
  210. return fmt.Errorf(
  211. "Service type '%s' requires a syncing client", service.ServiceType(),
  212. )
  213. }
  214. }
  215. return nil
  216. }