You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

376 lines
11 KiB

8 years ago
8 years ago
8 years ago
  1. package main
  2. import (
  3. "database/sql"
  4. "encoding/base64"
  5. "encoding/json"
  6. log "github.com/Sirupsen/logrus"
  7. "github.com/matrix-org/go-neb/clients"
  8. "github.com/matrix-org/go-neb/database"
  9. "github.com/matrix-org/go-neb/errors"
  10. "github.com/matrix-org/go-neb/types"
  11. "net/http"
  12. "strings"
  13. "sync"
  14. )
  15. type heartbeatHandler struct{}
  16. func (*heartbeatHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
  17. return &struct{}{}, nil
  18. }
  19. type requestAuthSessionHandler struct {
  20. db *database.ServiceDB
  21. }
  22. func (h *requestAuthSessionHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
  23. if req.Method != "POST" {
  24. return nil, &errors.HTTPError{nil, "Unsupported Method", 405}
  25. }
  26. var body struct {
  27. RealmID string
  28. UserID string
  29. Config json.RawMessage
  30. }
  31. if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
  32. return nil, &errors.HTTPError{err, "Error parsing request JSON", 400}
  33. }
  34. log.WithFields(log.Fields{
  35. "realm_id": body.RealmID,
  36. "user_id": body.UserID,
  37. }).Print("Incoming auth session request")
  38. if body.UserID == "" || body.RealmID == "" || body.Config == nil {
  39. return nil, &errors.HTTPError{nil, `Must supply a "UserID", a "RealmID" and a "Config"`, 400}
  40. }
  41. realm, err := h.db.LoadAuthRealm(body.RealmID)
  42. if err != nil {
  43. return nil, &errors.HTTPError{err, "Unknown RealmID", 400}
  44. }
  45. response := realm.RequestAuthSession(body.UserID, body.Config)
  46. if response == nil {
  47. return nil, &errors.HTTPError{nil, "Failed to request auth session", 500}
  48. }
  49. return response, nil
  50. }
  51. type realmRedirectHandler struct {
  52. db *database.ServiceDB
  53. }
  54. func (rh *realmRedirectHandler) handle(w http.ResponseWriter, req *http.Request) {
  55. segments := strings.Split(req.URL.Path, "/")
  56. // last path segment is the base64d realm ID which we will pass the incoming request to
  57. base64realmID := segments[len(segments)-1]
  58. bytesRealmID, err := base64.RawURLEncoding.DecodeString(base64realmID)
  59. realmID := string(bytesRealmID)
  60. if err != nil {
  61. log.WithError(err).WithField("base64_realm_id", base64realmID).Print(
  62. "Not a b64 encoded string",
  63. )
  64. w.WriteHeader(400)
  65. return
  66. }
  67. realm, err := rh.db.LoadAuthRealm(realmID)
  68. if err != nil {
  69. log.WithError(err).WithField("realm_id", realmID).Print("Failed to load realm")
  70. w.WriteHeader(404)
  71. return
  72. }
  73. log.WithFields(log.Fields{
  74. "realm_id": realmID,
  75. }).Print("Incoming realm redirect request")
  76. realm.OnReceiveRedirect(w, req)
  77. }
  78. type configureAuthRealmHandler struct {
  79. db *database.ServiceDB
  80. }
  81. func (h *configureAuthRealmHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
  82. if req.Method != "POST" {
  83. return nil, &errors.HTTPError{nil, "Unsupported Method", 405}
  84. }
  85. var body struct {
  86. ID string
  87. Type string
  88. Config json.RawMessage
  89. }
  90. if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
  91. return nil, &errors.HTTPError{err, "Error parsing request JSON", 400}
  92. }
  93. if body.ID == "" || body.Type == "" || body.Config == nil {
  94. return nil, &errors.HTTPError{nil, `Must supply a "ID", a "Type" and a "Config"`, 400}
  95. }
  96. realm, err := types.CreateAuthRealm(body.ID, body.Type, body.Config)
  97. if err != nil {
  98. return nil, &errors.HTTPError{err, "Error parsing config JSON", 400}
  99. }
  100. if err = realm.Register(); err != nil {
  101. return nil, &errors.HTTPError{err, "Error registering auth realm", 400}
  102. }
  103. oldRealm, err := h.db.StoreAuthRealm(realm)
  104. if err != nil {
  105. return nil, &errors.HTTPError{err, "Error storing realm", 500}
  106. }
  107. return &struct {
  108. ID string
  109. Type string
  110. OldConfig types.AuthRealm
  111. NewConfig types.AuthRealm
  112. }{body.ID, body.Type, oldRealm, realm}, nil
  113. }
  114. type webhookHandler struct {
  115. db *database.ServiceDB
  116. clients *clients.Clients
  117. }
  118. func (wh *webhookHandler) handle(w http.ResponseWriter, req *http.Request) {
  119. log.WithField("path", req.URL.Path).Print("Incoming webhook request")
  120. segments := strings.Split(req.URL.Path, "/")
  121. // last path segment is the service ID which we will pass the incoming request to,
  122. // but we've base64d it.
  123. base64srvID := segments[len(segments)-1]
  124. bytesSrvID, err := base64.RawURLEncoding.DecodeString(base64srvID)
  125. srvID := string(bytesSrvID)
  126. if err != nil {
  127. log.WithError(err).WithField("base64_service_id", base64srvID).Print(
  128. "Not a b64 encoded string",
  129. )
  130. w.WriteHeader(400)
  131. return
  132. }
  133. service, err := wh.db.LoadService(srvID)
  134. if err != nil {
  135. log.WithError(err).WithField("service_id", srvID).Print("Failed to load service")
  136. w.WriteHeader(404)
  137. return
  138. }
  139. cli, err := wh.clients.Client(service.ServiceUserID())
  140. if err != nil {
  141. log.WithError(err).WithField("user_id", service.ServiceUserID()).Print(
  142. "Failed to retrieve matrix client instance")
  143. w.WriteHeader(500)
  144. return
  145. }
  146. log.WithFields(log.Fields{
  147. "service_id": service.ServiceID(),
  148. "service_typ": service.ServiceType(),
  149. }).Print("Incoming webhook for service")
  150. service.OnReceiveWebhook(w, req, cli)
  151. }
  152. type configureClientHandler struct {
  153. db *database.ServiceDB
  154. clients *clients.Clients
  155. }
  156. func (s *configureClientHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
  157. if req.Method != "POST" {
  158. return nil, &errors.HTTPError{nil, "Unsupported Method", 405}
  159. }
  160. var body types.ClientConfig
  161. if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
  162. return nil, &errors.HTTPError{err, "Error parsing request JSON", 400}
  163. }
  164. if err := body.Check(); err != nil {
  165. return nil, &errors.HTTPError{err, "Error parsing client config", 400}
  166. }
  167. oldClient, err := s.clients.Update(body)
  168. if err != nil {
  169. return nil, &errors.HTTPError{err, "Error storing token", 500}
  170. }
  171. return &struct {
  172. OldClient types.ClientConfig
  173. NewClient types.ClientConfig
  174. }{oldClient, body}, nil
  175. }
  176. type configureServiceHandler struct {
  177. db *database.ServiceDB
  178. clients *clients.Clients
  179. mapMutex sync.Mutex
  180. mutexByServiceID map[string]*sync.Mutex
  181. }
  182. func newConfigureServiceHandler(db *database.ServiceDB, clients *clients.Clients) *configureServiceHandler {
  183. return &configureServiceHandler{
  184. db: db,
  185. clients: clients,
  186. mutexByServiceID: make(map[string]*sync.Mutex),
  187. }
  188. }
  189. func (s *configureServiceHandler) getMutexForServiceID(serviceID string) *sync.Mutex {
  190. s.mapMutex.Lock()
  191. defer s.mapMutex.Unlock()
  192. m := s.mutexByServiceID[serviceID]
  193. if m == nil {
  194. // XXX TODO: There's a memory leak here. The amount of mutexes created is unbounded, as there will be 1 per service which are never deleted.
  195. // A better solution would be to have a striped hash map with a bounded pool of mutexes. We can't live with a single global mutex because the Register()
  196. // function this is protecting does many many HTTP requests which can take a long time on bad networks and will head of line block other services.
  197. m = &sync.Mutex{}
  198. s.mutexByServiceID[serviceID] = m
  199. }
  200. return m
  201. }
  202. func (s *configureServiceHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
  203. if req.Method != "POST" {
  204. return nil, &errors.HTTPError{nil, "Unsupported Method", 405}
  205. }
  206. service, httpErr := s.createService(req)
  207. if httpErr != nil {
  208. return nil, httpErr
  209. }
  210. log.WithFields(log.Fields{
  211. "service_id": service.ServiceID(),
  212. "service_type": service.ServiceType(),
  213. "service_user_id": service.ServiceUserID(),
  214. }).Print("Incoming configure service request")
  215. // Have mutexes around each service to queue up multiple requests for the same service ID
  216. mut := s.getMutexForServiceID(service.ServiceID())
  217. mut.Lock()
  218. defer mut.Unlock()
  219. old, err := s.db.LoadService(service.ServiceID())
  220. if err != nil && err != sql.ErrNoRows {
  221. return nil, &errors.HTTPError{err, "Error loading old service", 500}
  222. }
  223. client, err := s.clients.Client(service.ServiceUserID())
  224. if err != nil {
  225. return nil, &errors.HTTPError{err, "Unknown matrix client", 400}
  226. }
  227. if err = service.Register(old, client); err != nil {
  228. return nil, &errors.HTTPError{err, "Failed to register service: " + err.Error(), 500}
  229. }
  230. oldService, err := s.db.StoreService(service)
  231. if err != nil {
  232. return nil, &errors.HTTPError{err, "Error storing service", 500}
  233. }
  234. service.PostRegister(old)
  235. return &struct {
  236. ID string
  237. Type string
  238. OldConfig types.Service
  239. NewConfig types.Service
  240. }{service.ServiceID(), service.ServiceType(), oldService, service}, nil
  241. }
  242. func (s *configureServiceHandler) createService(req *http.Request) (types.Service, *errors.HTTPError) {
  243. var body struct {
  244. ID string
  245. Type string
  246. UserID string
  247. Config json.RawMessage
  248. }
  249. if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
  250. return nil, &errors.HTTPError{err, "Error parsing request JSON", 400}
  251. }
  252. if body.ID == "" || body.Type == "" || body.UserID == "" || body.Config == nil {
  253. return nil, &errors.HTTPError{
  254. nil, `Must supply an "ID", a "Type", a "UserID" and a "Config"`, 400,
  255. }
  256. }
  257. service, err := types.CreateService(body.ID, body.Type, body.UserID, body.Config)
  258. if err != nil {
  259. return nil, &errors.HTTPError{err, "Error parsing config JSON", 400}
  260. }
  261. return service, nil
  262. }
  263. type getServiceHandler struct {
  264. db *database.ServiceDB
  265. }
  266. func (h *getServiceHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
  267. if req.Method != "POST" {
  268. return nil, &errors.HTTPError{nil, "Unsupported Method", 405}
  269. }
  270. var body struct {
  271. ID string
  272. }
  273. if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
  274. return nil, &errors.HTTPError{err, "Error parsing request JSON", 400}
  275. }
  276. if body.ID == "" {
  277. return nil, &errors.HTTPError{nil, `Must supply a "ID"`, 400}
  278. }
  279. srv, err := h.db.LoadService(body.ID)
  280. if err != nil {
  281. if err == sql.ErrNoRows {
  282. return nil, &errors.HTTPError{err, `Service not found`, 404}
  283. }
  284. return nil, &errors.HTTPError{err, `Failed to load service`, 500}
  285. }
  286. return &struct {
  287. ID string
  288. Type string
  289. Config types.Service
  290. }{srv.ServiceID(), srv.ServiceType(), srv}, nil
  291. }
  292. type getSessionHandler struct {
  293. db *database.ServiceDB
  294. }
  295. func (h *getSessionHandler) OnIncomingRequest(req *http.Request) (interface{}, *errors.HTTPError) {
  296. if req.Method != "POST" {
  297. return nil, &errors.HTTPError{nil, "Unsupported Method", 405}
  298. }
  299. var body struct {
  300. RealmID string
  301. UserID string
  302. }
  303. if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
  304. return nil, &errors.HTTPError{err, "Error parsing request JSON", 400}
  305. }
  306. if body.RealmID == "" || body.UserID == "" {
  307. return nil, &errors.HTTPError{nil, `Must supply a "RealmID" and "UserID"`, 400}
  308. }
  309. session, err := h.db.LoadAuthSessionByUser(body.RealmID, body.UserID)
  310. if err != nil && err != sql.ErrNoRows {
  311. return nil, &errors.HTTPError{err, `Failed to load session`, 500}
  312. }
  313. if err == sql.ErrNoRows {
  314. return &struct {
  315. Authenticated bool
  316. }{false}, nil
  317. }
  318. return &struct {
  319. ID string
  320. Authenticated bool
  321. Info interface{}
  322. }{session.ID(), session.Authenticated(), session.Info()}, nil
  323. }