You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

312 lines
8.7 KiB

  1. package handlers
  2. import (
  3. "database/sql"
  4. "encoding/base64"
  5. "encoding/json"
  6. "net/http"
  7. "strings"
  8. "github.com/matrix-org/go-neb/api"
  9. "github.com/matrix-org/go-neb/database"
  10. "github.com/matrix-org/go-neb/metrics"
  11. "github.com/matrix-org/go-neb/types"
  12. "github.com/matrix-org/util"
  13. log "github.com/sirupsen/logrus"
  14. "maunium.net/go/mautrix/id"
  15. )
  16. // RequestAuthSession represents an HTTP handler capable of processing /admin/requestAuthSession requests.
  17. type RequestAuthSession struct {
  18. Db *database.ServiceDB
  19. }
  20. // OnIncomingRequest handles POST requests to /admin/requestAuthSession. The HTTP body MUST be
  21. // a JSON object representing type "api.RequestAuthSessionRequest".
  22. //
  23. // This will return HTTP 400 if there are missing fields or the Realm ID is unknown.
  24. // For the format of the response, see the specific AuthRealm that the Realm ID corresponds to.
  25. //
  26. // Request:
  27. // POST /admin/requestAuthSession
  28. // {
  29. // "RealmID": "github_realm_id",
  30. // "UserID": "@my_user:localhost",
  31. // "Config": {
  32. // // AuthRealm specific config info
  33. // }
  34. // }
  35. // Response:
  36. // HTTP/1.1 200 OK
  37. // {
  38. // // AuthRealm-specific information
  39. // }
  40. func (h *RequestAuthSession) OnIncomingRequest(req *http.Request) util.JSONResponse {
  41. logger := util.GetLogger(req.Context())
  42. if req.Method != "POST" {
  43. return util.MessageResponse(405, "Unsupported Method")
  44. }
  45. var body api.RequestAuthSessionRequest
  46. if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
  47. return util.MessageResponse(400, "Error parsing request JSON")
  48. }
  49. logger.WithFields(log.Fields{
  50. "realm_id": body.RealmID,
  51. "user_id": body.UserID,
  52. }).Print("Incoming auth session request")
  53. if err := body.Check(); err != nil {
  54. logger.WithError(err).Info("Failed Check")
  55. return util.MessageResponse(400, err.Error())
  56. }
  57. realm, err := h.Db.LoadAuthRealm(body.RealmID)
  58. if err != nil {
  59. logger.WithError(err).Info("Failed to LoadAuthRealm")
  60. return util.MessageResponse(400, "Unknown RealmID")
  61. }
  62. response := realm.RequestAuthSession(body.UserID, body.Config)
  63. if response == nil {
  64. logger.WithField("body", body).Error("Failed to RequestAuthSession")
  65. return util.MessageResponse(500, "Failed to request auth session")
  66. }
  67. metrics.IncrementAuthSession(realm.Type())
  68. return util.JSONResponse{
  69. Code: 200,
  70. JSON: response,
  71. }
  72. }
  73. // RemoveAuthSession represents an HTTP handler capable of processing /admin/removeAuthSession requests.
  74. type RemoveAuthSession struct {
  75. Db *database.ServiceDB
  76. }
  77. // OnIncomingRequest handles POST requests to /admin/removeAuthSession.
  78. //
  79. // The JSON object MUST contain the keys "RealmID" and "UserID" to identify the session to remove.
  80. //
  81. // Request
  82. // POST /admin/removeAuthSession
  83. // {
  84. // "RealmID": "github-realm",
  85. // "UserID": "@my_user:localhost"
  86. // }
  87. // Response:
  88. // HTTP/1.1 200 OK
  89. // {}
  90. func (h *RemoveAuthSession) OnIncomingRequest(req *http.Request) util.JSONResponse {
  91. logger := util.GetLogger(req.Context())
  92. if req.Method != "POST" {
  93. return util.MessageResponse(405, "Unsupported Method")
  94. }
  95. var body struct {
  96. RealmID string
  97. UserID id.UserID
  98. }
  99. if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
  100. return util.MessageResponse(400, "Error parsing request JSON")
  101. }
  102. logger.WithFields(log.Fields{
  103. "realm_id": body.RealmID,
  104. "user_id": body.UserID,
  105. }).Print("Incoming remove auth session request")
  106. if body.UserID == "" || body.RealmID == "" {
  107. return util.MessageResponse(400, `Must supply a "UserID", a "RealmID"`)
  108. }
  109. _, err := h.Db.LoadAuthRealm(body.RealmID)
  110. if err != nil {
  111. return util.MessageResponse(400, "Unknown RealmID")
  112. }
  113. if err := h.Db.RemoveAuthSession(body.RealmID, body.UserID); err != nil {
  114. logger.WithError(err).Error("Failed to RemoveAuthSession")
  115. return util.MessageResponse(500, "Failed to remove auth session")
  116. }
  117. return util.JSONResponse{
  118. Code: 200,
  119. JSON: struct{}{},
  120. }
  121. }
  122. // RealmRedirect represents an HTTP handler which can process incoming redirects for auth realms.
  123. type RealmRedirect struct {
  124. Db *database.ServiceDB
  125. }
  126. // Handle requests for an auth realm.
  127. //
  128. // The last path segment of the URL MUST be the base64 form of the Realm ID. What response
  129. // this returns depends on the specific AuthRealm implementation.
  130. func (rh *RealmRedirect) Handle(w http.ResponseWriter, req *http.Request) {
  131. segments := strings.Split(req.URL.Path, "/")
  132. // last path segment is the base64d realm ID which we will pass the incoming request to
  133. base64realmID := segments[len(segments)-1]
  134. bytesRealmID, err := base64.RawURLEncoding.DecodeString(base64realmID)
  135. realmID := string(bytesRealmID)
  136. if err != nil {
  137. log.WithError(err).WithField("base64_realm_id", base64realmID).Print(
  138. "Not a b64 encoded string",
  139. )
  140. w.WriteHeader(400)
  141. return
  142. }
  143. realm, err := rh.Db.LoadAuthRealm(realmID)
  144. if err != nil {
  145. log.WithError(err).WithField("realm_id", realmID).Print("Failed to load realm")
  146. w.WriteHeader(404)
  147. return
  148. }
  149. log.WithFields(log.Fields{
  150. "realm_id": realmID,
  151. }).Print("Incoming realm redirect request")
  152. realm.OnReceiveRedirect(w, req)
  153. }
  154. // ConfigureAuthRealm represents an HTTP handler capable of processing /admin/configureAuthRealm requests.
  155. type ConfigureAuthRealm struct {
  156. Db *database.ServiceDB
  157. }
  158. // OnIncomingRequest handles POST requests to /admin/configureAuthRealm. The JSON object
  159. // provided is of type "api.ConfigureAuthRealmRequest".
  160. //
  161. // Request:
  162. // POST /admin/configureAuthRealm
  163. // {
  164. // "ID": "my-realm-id",
  165. // "Type": "github",
  166. // "Config": {
  167. // // Realm-specific configuration information
  168. // }
  169. // }
  170. // Response:
  171. // HTTP/1.1 200 OK
  172. // {
  173. // "ID": "my-realm-id",
  174. // "Type": "github",
  175. // "OldConfig": {
  176. // // Old auth realm config information
  177. // },
  178. // "NewConfig": {
  179. // // New auth realm config information
  180. // },
  181. // }
  182. func (h *ConfigureAuthRealm) OnIncomingRequest(req *http.Request) util.JSONResponse {
  183. logger := util.GetLogger(req.Context())
  184. if req.Method != "POST" {
  185. return util.MessageResponse(405, "Unsupported Method")
  186. }
  187. var body api.ConfigureAuthRealmRequest
  188. if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
  189. return util.MessageResponse(400, "Error parsing request JSON")
  190. }
  191. if err := body.Check(); err != nil {
  192. return util.MessageResponse(400, err.Error())
  193. }
  194. realm, err := types.CreateAuthRealm(body.ID, body.Type, body.Config)
  195. if err != nil {
  196. return util.MessageResponse(400, "Error parsing config JSON")
  197. }
  198. if err = realm.Register(); err != nil {
  199. return util.MessageResponse(400, "Error registering auth realm")
  200. }
  201. oldRealm, err := h.Db.StoreAuthRealm(realm)
  202. if err != nil {
  203. logger.WithError(err).Error("Failed to StoreAuthRealm")
  204. return util.MessageResponse(500, "Error storing realm")
  205. }
  206. return util.JSONResponse{
  207. Code: 200,
  208. JSON: struct {
  209. ID string
  210. Type string
  211. OldConfig types.AuthRealm
  212. NewConfig types.AuthRealm
  213. }{body.ID, body.Type, oldRealm, realm},
  214. }
  215. }
  216. // GetSession represents an HTTP handler capable of processing /admin/getSession requests.
  217. type GetSession struct {
  218. Db *database.ServiceDB
  219. }
  220. // OnIncomingRequest handles POST requests to /admin/getSession.
  221. //
  222. // The JSON object provided MUST have a "RealmID" and "UserID" in order to fetch the
  223. // correct AuthSession. If there is no session for this tuple of realm and user ID,
  224. // a 200 OK is still returned with "Authenticated" set to false.
  225. //
  226. // Request:
  227. // POST /admin/getSession
  228. // {
  229. // "RealmID": "my-realm",
  230. // "UserID": "@my_user:localhost"
  231. // }
  232. // Response:
  233. // HTTP/1.1 200 OK
  234. // {
  235. // "ID": "session_id",
  236. // "Authenticated": true,
  237. // "Info": {
  238. // // Session-specific config info
  239. // }
  240. // }
  241. // Response if session not found:
  242. // HTTP/1.1 200 OK
  243. // {
  244. // "Authenticated": false
  245. // }
  246. func (h *GetSession) OnIncomingRequest(req *http.Request) util.JSONResponse {
  247. logger := util.GetLogger(req.Context())
  248. if req.Method != "POST" {
  249. return util.MessageResponse(405, "Unsupported Method")
  250. }
  251. var body struct {
  252. RealmID string
  253. UserID id.UserID
  254. }
  255. if err := json.NewDecoder(req.Body).Decode(&body); err != nil {
  256. return util.MessageResponse(400, "Error parsing request JSON")
  257. }
  258. if body.RealmID == "" || body.UserID == "" {
  259. return util.MessageResponse(400, `Must supply a "RealmID" and "UserID"`)
  260. }
  261. session, err := h.Db.LoadAuthSessionByUser(body.RealmID, body.UserID)
  262. if err != nil && err != sql.ErrNoRows {
  263. logger.WithError(err).WithField("body", body).Error("Failed to LoadAuthSessionByUser")
  264. return util.MessageResponse(500, `Failed to load session`)
  265. }
  266. if err == sql.ErrNoRows {
  267. return util.JSONResponse{
  268. Code: 200,
  269. JSON: struct {
  270. Authenticated bool
  271. }{false},
  272. }
  273. }
  274. return util.JSONResponse{
  275. Code: 200,
  276. JSON: struct {
  277. ID string
  278. Authenticated bool
  279. Info interface{}
  280. }{session.ID(), session.Authenticated(), session.Info()},
  281. }
  282. }