You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

216 lines
7.5 KiB

  1. # maunium-stickerpicker - A fast and simple Matrix sticker picker widget.
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from typing import Tuple, Callable, Awaitable, Optional, TYPE_CHECKING
  17. import logging
  18. import json
  19. from mautrix.client import Client
  20. from mautrix.types import UserID
  21. from mautrix.util.logging import TraceLogger
  22. from aiohttp import web, hdrs, ClientError, ClientSession
  23. from yarl import URL
  24. from ..database import AccessToken, User
  25. from ..config import Config
  26. from .errors import Error
  27. from . import fed_connector
  28. if TYPE_CHECKING:
  29. from typing import TypedDict
  30. class OpenIDPayload(TypedDict):
  31. access_token: str
  32. token_type: str
  33. matrix_server_name: str
  34. expires_in: int
  35. class OpenIDResponse(TypedDict):
  36. sub: str
  37. Handler = Callable[[web.Request], Awaitable[web.Response]]
  38. log: TraceLogger = logging.getLogger("mau.api.auth")
  39. routes = web.RouteTableDef()
  40. config: Config
  41. def get_ip(request: web.Request) -> str:
  42. if config["server.trust_forward_headers"]:
  43. try:
  44. return request.headers["X-Forwarded-For"]
  45. except KeyError:
  46. pass
  47. return request.remote
  48. def get_auth_header(request: web.Request) -> str:
  49. try:
  50. auth = request.headers["Authorization"]
  51. if not auth.startswith("Bearer "):
  52. raise Error.invalid_auth_header
  53. return auth[len("Bearer "):]
  54. except KeyError:
  55. raise Error.missing_auth_header
  56. async def get_user(request: web.Request) -> Tuple[User, AccessToken]:
  57. auth = get_auth_header(request)
  58. try:
  59. token_id, token_val = auth.split(":")
  60. token_id = int(token_id)
  61. except ValueError:
  62. raise Error.invalid_auth_token
  63. token = await AccessToken.get(token_id)
  64. if not token or not token.check(token_val):
  65. raise Error.invalid_auth_token
  66. elif token.expired:
  67. raise Error.auth_token_expired
  68. await token.update_ip(get_ip(request))
  69. return await User.get(token.user_id), token
  70. @web.middleware
  71. async def token_middleware(request: web.Request, handler: Handler) -> web.Response:
  72. if request.method == hdrs.METH_OPTIONS:
  73. return await handler(request)
  74. user, token = await get_user(request)
  75. request["user"] = user
  76. request["token"] = token
  77. return await handler(request)
  78. async def get_widget_user(request: web.Request) -> User:
  79. try:
  80. user_id = UserID(request.headers["X-Matrix-User-ID"])
  81. except KeyError:
  82. raise Error.missing_user_id_header
  83. user = await User.get(user_id)
  84. if user is None:
  85. raise Error.user_not_found
  86. return user
  87. @web.middleware
  88. async def widget_secret_middleware(request: web.Request, handler: Handler) -> web.Response:
  89. if request.method == hdrs.METH_OPTIONS:
  90. return await handler(request)
  91. user = await get_widget_user(request)
  92. request["user"] = user
  93. return await handler(request)
  94. account_cors_headers = {
  95. "Access-Control-Allow-Origin": "*",
  96. "Access-Control-Allow-Methods": "OPTIONS, GET, POST",
  97. "Access-Control-Allow-Headers": "Authorization, Content-Type",
  98. }
  99. @routes.get("/account")
  100. async def get_auth(request: web.Request) -> web.Response:
  101. user, token = await get_user(request)
  102. return web.json_response({"user_id": token.user_id}, headers=account_cors_headers)
  103. async def check_openid_token(homeserver: str, token: str) -> Optional[UserID]:
  104. server_info = await fed_connector.resolve_server_name(homeserver)
  105. headers = {"Host": server_info.host_header}
  106. userinfo_url = URL.build(scheme="https", host=server_info.host, port=server_info.port,
  107. path="/_matrix/federation/v1/openid/userinfo",
  108. query={"access_token": token})
  109. try:
  110. async with fed_connector.http.get(userinfo_url, headers=headers) as resp:
  111. data: 'OpenIDResponse' = await resp.json()
  112. return UserID(data["sub"])
  113. except (ClientError, json.JSONDecodeError, KeyError, ValueError) as e:
  114. log.debug(f"Failed to check OpenID token from {homeserver}", exc_info=True)
  115. return None
  116. @routes.route(hdrs.METH_OPTIONS, "/account/register")
  117. @routes.route(hdrs.METH_OPTIONS, "/account/logout")
  118. @routes.route(hdrs.METH_OPTIONS, "/account")
  119. async def cors_token(_: web.Request) -> web.Response:
  120. return web.Response(status=200, headers=account_cors_headers)
  121. async def resolve_client_well_known(server_name: str) -> str:
  122. url = URL.build(scheme="https", host=server_name, port=443, path="/.well-known/matrix/client")
  123. async with ClientSession() as sess, sess.get(url) as resp:
  124. data = await resp.json()
  125. return data["m.homeserver"]["base_url"]
  126. @routes.post("/account/register")
  127. async def exchange_token(request: web.Request) -> web.Response:
  128. try:
  129. data: 'OpenIDPayload' = await request.json()
  130. except json.JSONDecodeError:
  131. raise Error.request_not_json
  132. try:
  133. matrix_server_name = data["matrix_server_name"]
  134. access_token = data["access_token"]
  135. except KeyError:
  136. raise Error.invalid_openid_payload
  137. log.trace(f"Validating OpenID token from {matrix_server_name}")
  138. user_id = await check_openid_token(matrix_server_name, access_token)
  139. if user_id is None:
  140. raise Error.invalid_openid_token
  141. _, homeserver = Client.parse_user_id(user_id)
  142. if homeserver != data["matrix_server_name"]:
  143. raise Error.homeserver_mismatch
  144. permissions = config.get_permissions(user_id)
  145. if not permissions.access:
  146. raise Error.no_access
  147. try:
  148. log.trace(f"Trying to resolve {matrix_server_name}'s client .well-known")
  149. homeserver_url = await resolve_client_well_known(matrix_server_name)
  150. log.trace(f"Got {homeserver_url} from {matrix_server_name}'s client .well-known")
  151. except (ClientError, json.JSONDecodeError, KeyError, ValueError, TypeError):
  152. log.trace(f"Failed to resolve {matrix_server_name}'s client .well-known", exc_info=True)
  153. raise Error.client_well_known_error
  154. user = await User.get(user_id)
  155. if user is None:
  156. log.debug(f"Creating user {user_id} with homeserver client URL {homeserver_url}")
  157. user = User.new(user_id, homeserver_url=homeserver_url)
  158. await user.insert()
  159. elif user.homeserver_url != homeserver_url:
  160. log.debug(f"Updating {user_id}'s homeserver client URL from {user.homeserver_url} "
  161. f"to {homeserver_url}")
  162. await user.set_homeserver_url(homeserver_url)
  163. token = await user.new_access_token(get_ip(request))
  164. return web.json_response({
  165. "user_id": user_id,
  166. "token": token,
  167. "permissions": permissions._asdict(),
  168. }, headers=account_cors_headers)
  169. @routes.post("/account/logout")
  170. async def logout(request: web.Request) -> web.Response:
  171. user, token = await get_user(request)
  172. await token.delete()
  173. return web.json_response({}, status=204, headers=account_cors_headers)
  174. def init(cfg: Config) -> None:
  175. global config
  176. config = cfg