You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
216 lines
7.5 KiB
216 lines
7.5 KiB
# maunium-stickerpicker - A fast and simple Matrix sticker picker widget.
|
|
# Copyright (C) 2020 Tulir Asokan
|
|
#
|
|
# This program is free software: you can redistribute it and/or modify
|
|
# it under the terms of the GNU Affero General Public License as published by
|
|
# the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# This program is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# GNU Affero General Public License for more details.
|
|
#
|
|
# You should have received a copy of the GNU Affero General Public License
|
|
# along with this program. If not, see <https://www.gnu.org/licenses/>.
|
|
from typing import Tuple, Callable, Awaitable, Optional, TYPE_CHECKING
|
|
import logging
|
|
import json
|
|
|
|
from mautrix.client import Client
|
|
from mautrix.types import UserID
|
|
from mautrix.util.logging import TraceLogger
|
|
from aiohttp import web, hdrs, ClientError, ClientSession
|
|
from yarl import URL
|
|
|
|
from ..database import AccessToken, User
|
|
from ..config import Config
|
|
from .errors import Error
|
|
from . import fed_connector
|
|
|
|
if TYPE_CHECKING:
    # These declarations exist only for static type checkers; at runtime the
    # names are referenced solely through quoted annotations.
    from typing import TypedDict


    class OpenIDPayload(TypedDict):
        """Shape of the JSON body accepted by the /account/register endpoint."""
        access_token: str
        token_type: str
        matrix_server_name: str
        expires_in: int


    class OpenIDResponse(TypedDict):
        """Shape of the JSON returned by the federation OpenID userinfo endpoint."""
        sub: str
|
|
|
|
# Signature shared by aiohttp request handlers and by the `handler` argument
# that the middlewares in this module receive.
Handler = Callable[[web.Request], Awaitable[web.Response]]

log: TraceLogger = logging.getLogger("mau.api.auth")
# Route table that the /account endpoints below register themselves on.
routes = web.RouteTableDef()
# Declared here, assigned by init(); reading it before init() runs raises NameError.
config: Config
|
|
|
|
|
|
def get_ip(request: web.Request) -> str:
    """Return the address to record as the origin of *request*.

    If the ``server.trust_forward_headers`` config value is truthy and the
    request has an ``X-Forwarded-For`` header, that header's value is returned
    unmodified. In every other case ``request.remote`` is returned.
    """
    if not config["server.trust_forward_headers"]:
        return request.remote
    forwarded_for = request.headers.get("X-Forwarded-For")
    if forwarded_for is None:
        return request.remote
    return forwarded_for
|
|
|
|
|
|
def get_auth_header(request: web.Request) -> str:
    """Extract the bearer credential from the ``Authorization`` header.

    Returns the header value with the leading ``"Bearer "`` prefix removed.

    Raises:
        Error.missing_auth_header: the request has no ``Authorization`` header.
        Error.invalid_auth_header: the header does not start with ``"Bearer "``.
    """
    prefix = "Bearer "
    header_value = request.headers.get("Authorization")
    if header_value is None:
        raise Error.missing_auth_header
    if not header_value.startswith(prefix):
        raise Error.invalid_auth_header
    return header_value[len(prefix):]
|
|
|
|
|
|
async def get_user(request: web.Request) -> Tuple[User, AccessToken]:
    """Authenticate *request* using its ``<token id>:<token value>`` bearer token.

    On success the token's stored IP is updated with ``get_ip(request)`` and
    the result of ``User.get(token.user_id)`` is returned alongside the token.

    Raises:
        Error.invalid_auth_token: the credential is not exactly two
            colon-separated parts, the ID part is not an integer, no such
            token exists, or the value does not pass ``token.check``.
        Error.auth_token_expired: the token exists and matches but is expired.
        (Plus whatever ``get_auth_header`` raises for a bad header.)
    """
    credential = get_auth_header(request)
    try:
        id_part, secret = credential.split(":")
        token_id = int(id_part)
    except ValueError:
        # Raised both by a wrong number of parts and by a non-integer ID.
        raise Error.invalid_auth_token

    token = await AccessToken.get(token_id)
    if not (token and token.check(secret)):
        raise Error.invalid_auth_token
    if token.expired:
        raise Error.auth_token_expired

    client_ip = get_ip(request)
    await token.update_ip(client_ip)
    user = await User.get(token.user_id)
    return user, token
|
|
|
|
|
|
@web.middleware
async def token_middleware(request: web.Request, handler: Handler) -> web.Response:
    """Middleware that authenticates non-OPTIONS requests via ``get_user``.

    The authenticated user and access token are stored on the request under
    the ``"user"`` and ``"token"`` keys before the handler is invoked. OPTIONS
    requests are passed to the handler without authentication.
    """
    if request.method != hdrs.METH_OPTIONS:
        request["user"], request["token"] = await get_user(request)
    return await handler(request)
|
|
|
|
|
|
async def get_widget_user(request: web.Request) -> User:
    """Look up the user named by the request's ``X-Matrix-User-ID`` header.

    Raises:
        Error.missing_user_id_header: the header is absent.
        Error.user_not_found: ``User.get`` returned ``None`` for that ID.
    """
    raw_user_id = request.headers.get("X-Matrix-User-ID")
    if raw_user_id is None:
        raise Error.missing_user_id_header
    found = await User.get(UserID(raw_user_id))
    if found is None:
        raise Error.user_not_found
    return found
|
|
|
|
|
|
@web.middleware
async def widget_secret_middleware(request: web.Request, handler: Handler) -> web.Response:
    """Middleware that attaches the widget user to non-OPTIONS requests.

    The user returned by ``get_widget_user`` is stored on the request under
    the ``"user"`` key before the handler is invoked. OPTIONS requests are
    passed straight to the handler.
    """
    if request.method != hdrs.METH_OPTIONS:
        request["user"] = await get_widget_user(request)
    return await handler(request)
|
|
|
|
|
|
# CORS headers sent with every response from the /account endpoints,
# including the OPTIONS preflight responses.
account_cors_headers = {
    "Access-Control-Allow-Origin": "*",
    "Access-Control-Allow-Methods": ", ".join(("OPTIONS", "GET", "POST")),
    "Access-Control-Allow-Headers": ", ".join(("Authorization", "Content-Type")),
}
|
|
|
|
|
|
@routes.get("/account")
async def get_auth(request: web.Request) -> web.Response:
    """Respond with the user ID that the request's access token belongs to.

    Authentication failures surface as the errors raised by ``get_user``.
    """
    _, token = await get_user(request)
    body = {"user_id": token.user_id}
    return web.json_response(body, headers=account_cors_headers)
|
|
|
|
|
|
async def check_openid_token(homeserver: str, token: str) -> Optional[UserID]:
    """Ask *homeserver* which user an OpenID access token belongs to.

    The server name is resolved with ``fed_connector.resolve_server_name`` and
    a GET request is sent to ``/_matrix/federation/v1/openid/userinfo`` on the
    resolved host, with the token as the ``access_token`` query parameter.

    Args:
        homeserver: Matrix server name that supposedly issued the token.
        token: The OpenID access token to verify.

    Returns:
        The ``sub`` field of the response as a ``UserID``, or ``None`` if the
        request failed or the response was not a JSON object with a string
        ``sub`` field. Failures are logged at debug level.
    """
    server_info = await fed_connector.resolve_server_name(homeserver)
    headers = {"Host": server_info.host_header}
    userinfo_url = URL.build(scheme="https", host=server_info.host, port=server_info.port,
                             path="/_matrix/federation/v1/openid/userinfo",
                             query={"access_token": token})
    try:
        async with fed_connector.http.get(userinfo_url, headers=headers) as resp:
            data: 'OpenIDResponse' = await resp.json()
        sub = data["sub"]
    # TypeError covers a body that is valid JSON but not an object (e.g. a
    # list or null), where subscripting with "sub" fails; the response comes
    # from a remote server, so it must not be able to crash the request.
    except (ClientError, json.JSONDecodeError, KeyError, ValueError, TypeError):
        log.debug(f"Failed to check OpenID token from {homeserver}", exc_info=True)
        return None
    if not isinstance(sub, str):
        # A non-string "sub" cannot be a user ID; treat it as a failed check.
        log.debug(f"Failed to check OpenID token from {homeserver}: "
                  f"non-string sub field in response")
        return None
    return UserID(sub)
|
|
|
|
|
|
@routes.route(hdrs.METH_OPTIONS, "/account/register")
@routes.route(hdrs.METH_OPTIONS, "/account/logout")
@routes.route(hdrs.METH_OPTIONS, "/account")
async def cors_token(_: web.Request) -> web.Response:
    """Answer OPTIONS requests to the /account endpoints.

    The request itself is ignored; the reply is an empty 200 response carrying
    ``account_cors_headers``.
    """
    preflight_reply = web.Response(headers=account_cors_headers, status=200)
    return preflight_reply
|
|
|
|
|
|
async def resolve_client_well_known(server_name: str) -> str:
    """Fetch the homeserver base URL from a server's client .well-known file.

    Requests ``https://<server_name>:443/.well-known/matrix/client`` using a
    fresh ``ClientSession`` and returns the ``m.homeserver`` -> ``base_url``
    value of the JSON response. Request, decoding and lookup errors are not
    handled here; they propagate to the caller.
    """
    well_known_url = URL.build(
        scheme="https",
        host=server_name,
        port=443,
        path="/.well-known/matrix/client",
    )
    async with ClientSession() as session:
        async with session.get(well_known_url) as response:
            document = await response.json()
            return document["m.homeserver"]["base_url"]
|
|
|
|
|
|
@routes.post("/account/register")
async def exchange_token(request: web.Request) -> web.Response:
    """Exchange a Matrix OpenID token for an access token of this service.

    The JSON request body must be an object with string ``matrix_server_name``
    and ``access_token`` fields. The token is verified against the named
    server, the resulting user's permissions are checked, the server's client
    .well-known is resolved, and the user is created (or their stored
    homeserver URL updated) before a new access token is issued.

    Returns:
        A JSON response with ``user_id``, ``token`` and ``permissions`` keys,
        sent with ``account_cors_headers``.

    Raises:
        Error.request_not_json: the body could not be decoded as JSON.
        Error.invalid_openid_payload: the body is not an object, or one of the
            required fields is missing or not a string.
        Error.invalid_openid_token: the token could not be verified, or the
            user ID reported for it could not be parsed.
        Error.homeserver_mismatch: the verified user ID belongs to a different
            server than ``matrix_server_name``.
        Error.no_access: the user's permissions do not grant access.
        Error.client_well_known_error: resolving the client .well-known failed.
    """
    try:
        data: 'OpenIDPayload' = await request.json()
    except json.JSONDecodeError:
        raise Error.request_not_json
    try:
        matrix_server_name = data["matrix_server_name"]
        access_token = data["access_token"]
    except (KeyError, TypeError):
        # TypeError: the body was valid JSON but not an object (list, string,
        # number, null), so it cannot be subscripted with a field name.
        raise Error.invalid_openid_payload
    if not isinstance(matrix_server_name, str) or not isinstance(access_token, str):
        # Both values are used to build URLs further down; reject anything
        # that is not a string here rather than failing obscurely later.
        raise Error.invalid_openid_payload
    log.trace(f"Validating OpenID token from {matrix_server_name}")
    user_id = await check_openid_token(matrix_server_name, access_token)
    if user_id is None:
        raise Error.invalid_openid_token
    try:
        _, homeserver = Client.parse_user_id(user_id)
    except ValueError:
        # NOTE(review): assumes parse_user_id signals a malformed ID with
        # ValueError; the ID comes from the remote server, so a bad value
        # should be reported as an invalid token. Confirm against mautrix.
        raise Error.invalid_openid_token
    if homeserver != matrix_server_name:
        raise Error.homeserver_mismatch

    permissions = config.get_permissions(user_id)
    if not permissions.access:
        raise Error.no_access

    try:
        log.trace(f"Trying to resolve {matrix_server_name}'s client .well-known")
        homeserver_url = await resolve_client_well_known(matrix_server_name)
        log.trace(f"Got {homeserver_url} from {matrix_server_name}'s client .well-known")
    except (ClientError, json.JSONDecodeError, KeyError, ValueError, TypeError):
        log.trace(f"Failed to resolve {matrix_server_name}'s client .well-known", exc_info=True)
        raise Error.client_well_known_error

    user = await User.get(user_id)
    if user is None:
        log.debug(f"Creating user {user_id} with homeserver client URL {homeserver_url}")
        user = User.new(user_id, homeserver_url=homeserver_url)
        await user.insert()
    elif user.homeserver_url != homeserver_url:
        log.debug(f"Updating {user_id}'s homeserver client URL from {user.homeserver_url} "
                  f"to {homeserver_url}")
        await user.set_homeserver_url(homeserver_url)
    token = await user.new_access_token(get_ip(request))
    return web.json_response({
        "user_id": user_id,
        "token": token,
        "permissions": permissions._asdict(),
    }, headers=account_cors_headers)
|
|
|
|
|
|
@routes.post("/account/logout")
async def logout(request: web.Request) -> web.Response:
    """Delete the access token that the request authenticated with.

    Responds with status 204 and ``account_cors_headers``. Authentication
    failures surface as the errors raised by ``get_user``.
    """
    _, current_token = await get_user(request)
    await current_token.delete()
    empty_body = {}
    return web.json_response(empty_body, headers=account_cors_headers, status=204)
|
|
|
|
|
|
def init(cfg: Config) -> None:
    """Store *cfg* as this module's ``config`` global.

    Functions in this module that read ``config`` rely on this having been
    called first.
    """
    global config
    config = cfg
|