|
|
from typing import Tuple, Any, NamedTuple, Dict, Optional from time import time import ipaddress import logging import asyncio import json
from mautrix.util.logging import TraceLogger from aiohttp import ClientRequest, TCPConnector, ClientSession, ClientTimeout, ClientError from aiohttp.client_proto import ResponseHandler from yarl import URL import aiodns
# Module logger. It is annotated as mautrix's TraceLogger because this module calls log.trace(),
# which the standard logging.Logger does not provide.
log: TraceLogger = logging.getLogger("mau.federation")
class ResolvedServerName(NamedTuple):
    """Result of resolving a Matrix server name, as cached and returned by resolve_server_name."""

    # Value for the HTTP Host header: the server name itself, or its .well-known delegation target.
    host_header: str
    # Host to open the TCP connection to (may differ from host_header after an SRV lookup).
    host: str
    # TCP port to connect to.
    port: int
    # Unix timestamp in whole seconds after which a cached entry is no longer used.
    expire: int
class ServerNameSplit(NamedTuple):
    """A Matrix server name split into its host and optional port, as returned by parse_server_name."""

    # Host part of the server name.
    host: str
    # Explicit port, or None when the server name did not include one.
    port: Optional[int]
    # True if the host is an IP address literal rather than a DNS name.
    is_ip: bool
# Async DNS resolver used for the SRV lookups in resolve_server_name; assigned by init().
dns_resolver: aiodns.DNSResolver
# HTTP session (using MatrixFederationTCPConnector) used to fetch .well-known files; assigned by init().
http: ClientSession
# Resolution results keyed by the server name originally passed to resolve_server_name.
# Entries are reused until their `expire` timestamp has passed, then overwritten.
server_name_cache: Dict[str, ResolvedServerName] = {}
class MatrixFederationTCPConnector(TCPConnector):
    """An extension to aiohttp's TCPConnector that correctly sets the TLS SNI for Matrix federation
    requests, where the TCP host may not match the SNI/Host header.

    The TCP connection goes to whatever host the request URL names, but the TLS server name
    is taken from the host part of the request's ``Host`` header.
    """

    async def _wrap_create_connection(
        self,
        *args: Any,
        server_hostname: Optional[str],
        req: ClientRequest,
        **kwargs: Any,
    ) -> Tuple[asyncio.Transport, ResponseHandler]:
        """Open the connection, replacing the TLS server name with the Host header's host.

        aiohttp passes ``server_hostname=None`` for non-TLS connections. In that case, the value
        is left alone, because asyncio refuses a server hostname without TLS. If the request has
        no Host header, aiohttp's own choice of server hostname is kept.
        """
        host_header = req.headers.get("Host")
        if server_hostname is not None and host_header:
            server_hostname = parse_server_name(host_header).host
        return await super()._wrap_create_connection(
            *args, server_hostname=server_hostname, req=req, **kwargs
        )
def parse_server_name(name: str) -> ServerNameSplit:
    """Split a Matrix server name into host, optional port and an is-IP-literal flag.

    Accepts ``hostname``, ``hostname:port``, ``1.2.3.4[:port]`` and ``[ipv6][:port]``.
    Square brackets around an IPv6 literal are removed from the returned host, so it can be
    recognised by :mod:`ipaddress`.

    Args:
        name: Server name such as ``example.com``, ``example.com:8448`` or ``[::1]:8448``.

    Returns:
        A ``ServerNameSplit`` whose ``port`` is an ``int``, or ``None`` if no port was given.
    """
    host = name
    port: Optional[int] = None
    host_part, sep, port_part = name.rpartition(":")
    # Only treat the suffix as a port when it is plain ASCII digits; otherwise the colon
    # belongs to the host. isdecimal() alone also accepts non-ASCII digits.
    if sep and port_part.isascii() and port_part.isdecimal():
        host, port = host_part, int(port_part)
    # IPv6 literals appear in server names as "[addr]"; ipaddress needs the bare address.
    if len(host) >= 2 and host.startswith("[") and host.endswith("]"):
        host = host[1:-1]
    try:
        ipaddress.ip_address(host)
        is_ip = True
    except ValueError:
        is_ip = False
    res = ServerNameSplit(host=host, port=port, is_ip=is_ip)
    log.trace(f"Parsed server name {name} into {res}")
    return res
async def _fetch_well_known(server_name: str) -> Optional[str]:
    """Fetch the ``m.server`` delegation target from ``server_name``'s .well-known file.

    Requests ``https://<server_name>:443/.well-known/matrix/server`` with the module's
    ``http`` session.

    Args:
        server_name: Server name without an explicit port.

    Returns:
        The non-empty ``m.server`` string. Returns ``None`` if the request failed or timed out,
        the status was not 200, or the body had no usable ``m.server`` value.
    """
    well_known_url = URL.build(
        scheme="https", host=server_name, port=443, path="/.well-known/matrix/server"
    )
    try:
        log.trace(f"Requesting {well_known_url} to resolve {server_name}'s .well-known")
        async with http.get(well_known_url) as resp:
            if resp.status != 200:
                log.trace(f"Got non-200 status {resp.status} from {server_name}'s .well-known")
                return None
            # content_type=None: parse the body as JSON even when the server sends
            # a different Content-Type header.
            well_known_data = await resp.json(content_type=None)
    except (ClientError, asyncio.TimeoutError, json.JSONDecodeError, ValueError) as e:
        # aiohttp reports ClientTimeout expiry as asyncio.TimeoutError, which is not a
        # ClientError. ValueError also covers undecodable body text.
        log.debug(f"Failed to fetch .well-known for {server_name}: {e}")
        return None
    # Guard against non-object JSON bodies and missing, non-string or empty m.server values.
    delegated = well_known_data.get("m.server") if isinstance(well_known_data, dict) else None
    if not isinstance(delegated, str) or not delegated:
        log.debug(f"Failed to fetch .well-known for {server_name}: no valid m.server value")
        return None
    log.debug(f"Got {delegated} from {server_name}'s .well-known")
    return delegated


async def _lookup_srv(hostname: str) -> Optional[Tuple[str, int, int]]:
    """Look up the ``_matrix._tcp.<hostname>`` SRV record.

    Returns:
        ``(target_host, port, ttl)`` of the record with the lowest priority value.
        Returns ``None`` if the query failed or returned no records.
    """
    srv_name = f"_matrix._tcp.{hostname}"
    log.trace(f"Querying SRV at {srv_name}")
    try:
        records = await dns_resolver.query(srv_name, "SRV")
    except aiodns.error.DNSError as e:
        # aiodns raises, rather than returning an empty list, when no record exists.
        # That is the normal case for servers without SRV, so fall back instead of failing.
        log.trace(f"No SRV records found at {srv_name}: {e}")
        return None
    if not records:
        log.trace(f"No SRV records found at {srv_name}")
        return None
    # Prefer the lowest priority value (RFC 2782). Weighted selection among equal
    # priorities is not implemented; min() keeps the first such record.
    best = min(records, key=lambda record: record.priority)
    return best.host, best.port, best.ttl


async def resolve_server_name(server_name: str) -> ResolvedServerName:
    """Resolve a Matrix server name to the Host header, host and port to use for federation.

    1. An explicit port or an IP literal is used directly.
    2. Otherwise, ``/.well-known/matrix/server`` is consulted for delegation.
    3. If there is still no port, the ``_matrix._tcp`` SRV record is queried.
    4. The port defaults to 8448.

    Results are stored in ``server_name_cache`` and reused until their ``expire`` time.

    Args:
        server_name: The server name to resolve, e.g. ``example.com`` or ``example.com:8448``.

    Returns:
        The resolved Host header value, connection host, port and cache expiry timestamp.
    """
    cached = server_name_cache.get(server_name)
    if cached is None:
        log.trace(f"No cached server name resolution for {server_name}")
    elif cached.expire > int(time()):
        log.trace(f"Using cached server name resolution for {server_name}: {cached}")
        return cached

    host_header = server_name
    hostname, port, is_ip = parse_server_name(host_header)
    # Cache lifetime in seconds (24h) unless an SRV record supplies its own TTL.
    ttl = 86400
    if port is None and not is_ip:
        delegated = await _fetch_well_known(server_name)
        if delegated is not None:
            host_header = delegated
            hostname, port, is_ip = parse_server_name(host_header)
    if port is None and not is_ip:
        srv = await _lookup_srv(host_header)
        if srv is not None:
            hostname, port, srv_ttl = srv
            # Never cache an SRV-based result for less than 5 minutes.
            ttl = max(srv_ttl, 300)
            log.debug(f"Got {hostname}:{port} from {host_header}'s Matrix SRV record")
    # int() guards against a port that was parsed from the server name as a string.
    result = ResolvedServerName(
        host_header=host_header,
        host=hostname,
        port=int(port or 8448),
        expire=int(time()) + ttl,
    )
    server_name_cache[server_name] = result
    log.debug(f"Resolved server name {server_name} -> {result}")
    return result
def init():
    """Create the module-wide DNS resolver and federation HTTP session.

    This must run inside a running event loop, because the resolver is bound to
    ``asyncio.get_running_loop()``. It rebinds the module globals ``dns_resolver`` and
    ``http``; the session has a 10-second total timeout.
    """
    global dns_resolver, http
    running_loop = asyncio.get_running_loop()
    dns_resolver = aiodns.DNSResolver(loop=running_loop)
    federation_connector = MatrixFederationTCPConnector()
    http = ClientSession(connector=federation_connector, timeout=ClientTimeout(total=10))
|