You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

242 lines
7.0 KiB

8 years ago
8 years ago
8 years ago
  1. package handlers
  2. import (
  3. "database/sql"
  4. "encoding/json"
  5. "fmt"
  6. "net/http"
  7. "sync"
  8. "github.com/matrix-org/go-neb/api"
  9. "github.com/matrix-org/go-neb/clients"
  10. "github.com/matrix-org/go-neb/database"
  11. "github.com/matrix-org/go-neb/matrix"
  12. "github.com/matrix-org/go-neb/metrics"
  13. "github.com/matrix-org/go-neb/polling"
  14. "github.com/matrix-org/go-neb/types"
  15. "github.com/matrix-org/gomatrix"
  16. "github.com/matrix-org/util"
  17. log "github.com/sirupsen/logrus"
  18. )
  19. // ConfigureService represents an HTTP handler which can process /admin/configureService requests.
  20. type ConfigureService struct {
  21. db *database.ServiceDB
  22. clients *clients.Clients
  23. mapMutex sync.Mutex
  24. mutexByServiceID map[string]*sync.Mutex
  25. }
  26. // NewConfigureService creates a new ConfigureService handler
  27. func NewConfigureService(db *database.ServiceDB, clients *clients.Clients) *ConfigureService {
  28. return &ConfigureService{
  29. db: db,
  30. clients: clients,
  31. mutexByServiceID: make(map[string]*sync.Mutex),
  32. }
  33. }
  34. func (s *ConfigureService) getMutexForServiceID(serviceID string) *sync.Mutex {
  35. s.mapMutex.Lock()
  36. defer s.mapMutex.Unlock()
  37. m := s.mutexByServiceID[serviceID]
  38. if m == nil {
  39. // XXX TODO: There's a memory leak here. The amount of mutexes created is unbounded, as there will be 1 per service which are never deleted.
  40. // A better solution would be to have a striped hash map with a bounded pool of mutexes. We can't live with a single global mutex because the Register()
  41. // function this is protecting does many many HTTP requests which can take a long time on bad networks and will head of line block other services.
  42. m = &sync.Mutex{}
  43. s.mutexByServiceID[serviceID] = m
  44. }
  45. return m
  46. }
  47. // OnIncomingRequest handles POST requests to /admin/configureService.
  48. //
  49. // The request body MUST be of type "api.ConfigureServiceRequest".
  50. //
  51. // Request:
  52. // POST /admin/configureService
  53. // {
  54. // "ID": "my_service_id",
  55. // "Type": "service-type",
  56. // "UserID": "@my_bot:localhost",
  57. // "Config": {
  58. // // service-specific config information
  59. // }
  60. // }
  61. // Response:
  62. // HTTP/1.1 200 OK
  63. // {
  64. // "ID": "my_service_id",
  65. // "Type": "service-type",
  66. // "OldConfig": {
  67. // // old service-specific config information
  68. // },
  69. // "NewConfig": {
  70. // // new service-specific config information
  71. // },
  72. // }
  73. func (s *ConfigureService) OnIncomingRequest(req *http.Request) util.JSONResponse {
  74. if req.Method != "POST" {
  75. return util.MessageResponse(405, "Unsupported Method")
  76. }
  77. service, httpErr := s.createService(req)
  78. if httpErr != nil {
  79. return *httpErr
  80. }
  81. logger := util.GetLogger(req.Context())
  82. logger.WithFields(log.Fields{
  83. "service_id": service.ServiceID(),
  84. "service_type": service.ServiceType(),
  85. "service_user_id": service.ServiceUserID(),
  86. }).Print("Incoming configure service request")
  87. // Have mutexes around each service to queue up multiple requests for the same service ID
  88. mut := s.getMutexForServiceID(service.ServiceID())
  89. mut.Lock()
  90. defer mut.Unlock()
  91. old, err := s.db.LoadService(service.ServiceID())
  92. if err != nil && err != sql.ErrNoRows {
  93. logger.WithError(err).Error("Failed to LoadService")
  94. return util.MessageResponse(500, "Error loading old service")
  95. }
  96. client, err := s.clients.Client(service.ServiceUserID())
  97. if err != nil {
  98. return util.MessageResponse(400, "Unknown matrix client")
  99. }
  100. if err := checkClientForService(service, client); err != nil {
  101. return util.MessageResponse(400, err.Error())
  102. }
  103. if err = service.Register(old, client); err != nil {
  104. return util.MessageResponse(500, "Failed to register service: "+err.Error())
  105. }
  106. oldService, err := s.db.StoreService(service)
  107. if err != nil {
  108. logger.WithError(err).Error("Failed to StoreService")
  109. return util.MessageResponse(500, "Error storing service")
  110. }
  111. // Start any polling NOW because they may decide to stop it in PostRegister, and we want to make
  112. // sure we'll actually stop.
  113. if _, ok := service.(types.Poller); ok {
  114. if err := polling.StartPolling(service); err != nil {
  115. logger.WithFields(log.Fields{
  116. "service_id": service.ServiceID(),
  117. log.ErrorKey: err,
  118. }).Error("Failed to start poll loop.")
  119. }
  120. }
  121. service.PostRegister(old)
  122. metrics.IncrementConfigureService(service.ServiceType())
  123. return util.JSONResponse{
  124. Code: 200,
  125. JSON: struct {
  126. ID string
  127. Type string
  128. OldConfig types.Service
  129. NewConfig types.Service
  130. }{service.ServiceID(), service.ServiceType(), oldService, service},
  131. }
  132. }
  133. func (s *ConfigureService) createService(req *http.Request) (types.Service, *util.JSONResponse) {
  134. var body api.ConfigureServiceRequest
  135. if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
  136. res := util.MessageResponse(400, "Error parsing request JSON")
  137. return nil, &res
  138. }
  139. if err := body.Check(); err != nil {
  140. res := util.MessageResponse(400, err.Error())
  141. return nil, &res
  142. }
  143. service, err := types.CreateService(body.ID, body.Type, body.UserID, body.Config)
  144. if err != nil {
  145. res := util.MessageResponse(400, "Error parsing config JSON")
  146. return nil, &res
  147. }
  148. return service, nil
  149. }
  150. // GetService represents an HTTP handler which can process /admin/getService requests.
  151. type GetService struct {
  152. Db *database.ServiceDB
  153. }
  154. // OnIncomingRequest handles POST requests to /admin/getService.
  155. //
  156. // The request body MUST be a JSON body which has an "ID" key which represents
  157. // the service ID to get.
  158. //
  159. // Request:
  160. // POST /admin/getService
  161. // {
  162. // "ID": "my_service_id"
  163. // }
  164. // Response:
  165. // HTTP/1.1 200 OK
  166. // {
  167. // "ID": "my_service_id",
  168. // "Type": "github",
  169. // "Config": {
  170. // // service-specific config information
  171. // }
  172. // }
  173. func (h *GetService) OnIncomingRequest(req *http.Request) util.JSONResponse {
  174. if req.Method != "POST" {
  175. return util.MessageResponse(405, "Unsupported Method")
  176. }
  177. var body struct {
  178. ID string
  179. }
  180. if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
  181. return util.MessageResponse(400, "Error parsing request JSON")
  182. }
  183. if body.ID == "" {
  184. return util.MessageResponse(400, `Must supply a "ID"`)
  185. }
  186. srv, err := h.Db.LoadService(body.ID)
  187. if err != nil {
  188. if err == sql.ErrNoRows {
  189. return util.MessageResponse(404, `Service not found`)
  190. }
  191. util.GetLogger(req.Context()).WithError(err).Error("Failed to LoadService")
  192. return util.MessageResponse(500, `Failed to load service`)
  193. }
  194. return util.JSONResponse{
  195. Code: 200,
  196. JSON: struct {
  197. ID string
  198. Type string
  199. Config types.Service
  200. }{srv.ServiceID(), srv.ServiceType(), srv},
  201. }
  202. }
  203. func checkClientForService(service types.Service, client *gomatrix.Client) error {
  204. // If there are any commands or expansions for this Service then the service user ID
  205. // MUST be a syncing client or else the Service will never get the incoming command/expansion!
  206. cmds := service.Commands(client)
  207. expans := service.Expansions(client)
  208. if len(cmds) > 0 || len(expans) > 0 {
  209. nebStore := client.Store.(*matrix.NEBStore)
  210. if !nebStore.ClientConfig.Sync {
  211. return fmt.Errorf(
  212. "Service type '%s' requires a syncing client", service.ServiceType(),
  213. )
  214. }
  215. }
  216. return nil
  217. }