You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

110 lines
4.1 KiB

  1. from typing import Tuple, Any, NamedTuple, Dict, Optional
  2. from time import time
  3. import ipaddress
  4. import logging
  5. import asyncio
  6. import json
  7. from mautrix.util.logging import TraceLogger
  8. from aiohttp import ClientRequest, TCPConnector, ClientSession, ClientTimeout, ClientError
  9. from aiohttp.client_proto import ResponseHandler
  10. from yarl import URL
  11. import aiodns
  # Logger for federation server-name resolution.
  # NOTE(review): annotated as TraceLogger because this module calls log.trace(); that assumes
  # mautrix's TraceLogger has been installed as the logging logger class — confirm.
  log: TraceLogger = logging.getLogger(name="mau.federation")
  class ResolvedServerName(NamedTuple):
      """Outcome of resolving a Matrix server name; stored in ``server_name_cache``.

      Fields:
          host_header: the server name used for the request (either the original name or the
              ``m.server`` value delegated via .well-known), intended for the Host header.
          host: hostname or IP address to open the TCP connection to.
          port: TCP port to connect to.
          expire: Unix timestamp in whole seconds after which the result must be re-resolved.
      """

      host_header: str
      host: str
      port: int
      expire: int
  class ServerNameSplit(NamedTuple):
      """A server name broken into its components by ``parse_server_name``.

      Fields:
          host: the host part of the server name (a hostname or an IP literal).
          port: the explicit port given in the server name, or None when there was none.
          is_ip: whether ``host`` parsed as an IP address literal.
      """

      host: str
      port: Optional[int]
      is_ip: bool
  # Module-level clients. Declared here without values; init() assigns both, and
  # resolve_server_name() uses them, so init() must run first.
  dns_resolver: aiodns.DNSResolver
  http: ClientSession
  # Maps a server name to its most recent resolution; an entry is reused until its ``expire``.
  server_name_cache: Dict[str, ResolvedServerName] = {}
  class MatrixFederationTCPConnector(TCPConnector):
      """An extension to aiohttp's TCPConnector that correctly sets the TLS SNI for Matrix federation
      requests, where the TCP host may not match the SNI/Host header.

      The TCP target chosen by aiohttp is passed through untouched (in ``*args``/``**kwargs``);
      only the TLS ``server_hostname`` is replaced with the host part of the request's Host header.
      """

      async def _wrap_create_connection(self, *args: Any, server_hostname: Optional[str],
                                        req: ClientRequest, **kwargs: Any
                                        ) -> Tuple[asyncio.Transport, ResponseHandler]:
          host_header = req.headers.get("Host")
          # aiohttp passes server_hostname=None for non-TLS connections, and asyncio's
          # create_connection raises ValueError if server_hostname is set without TLS, so the SNI
          # is only overridden when TLS is actually in use (and a Host header is available).
          if server_hostname is not None and host_header:
              server_hostname = parse_server_name(host_header).host
          return await super()._wrap_create_connection(*args, server_hostname=server_hostname,
                                                       req=req, **kwargs)
  def parse_server_name(name: str) -> ServerNameSplit:
      """Split a Matrix server name into host, optional explicit port and an IP-literal flag.

      Accepted forms are ``hostname``, ``hostname:port``, ``1.2.3.4[:port]`` and
      ``[ipv6][:port]``. Brackets around an IPv6 literal are removed from the returned ``host``
      so it can be handed directly to TCP/TLS/URL-building APIs.

      Args:
          name: The server name, e.g. ``example.com:8448`` or ``[2001:db8::1]:8448``.

      Returns:
          A ServerNameSplit whose ``port`` is an ``int`` (or None when no port was given) and
          whose ``is_ip`` is True for IPv4/IPv6 literals.
      """
      host = name
      port: Optional[int] = None
      host_part, sep, port_part = name.rpartition(":")
      # Only treat a trailing ":digits" as a port when the preceding text cannot be an unbracketed
      # IPv6 address; otherwise e.g. "::1" would be split into host "::" and port 1.
      if sep and port_part.isdecimal() and (host_part.endswith("]") or ":" not in host_part):
          host, port = host_part, int(port_part)
      if host.startswith("[") and host.endswith("]"):
          ip_candidate = host[1:-1]
      else:
          ip_candidate = host
      try:
          ipaddress.ip_address(ip_candidate)
      except ValueError:
          is_ip = False
      else:
          is_ip = True
          host = ip_candidate
      res = ServerNameSplit(host=host, port=port, is_ip=is_ip)
      log.trace(f"Parsed server name {name} into {res}")
      return res
  async def _fetch_well_known_server(server_name: str) -> Optional[str]:
      """Fetch ``https://<server_name>/.well-known/matrix/server`` and return its ``m.server``.

      Returns None (after logging) when the URL cannot be built, the request fails or times out,
      the status is not 200, or the body is not a JSON object with a non-empty string ``m.server``.
      """
      try:
          well_known_url = URL.build(scheme="https", host=server_name, port=443,
                                     path="/.well-known/matrix/server")
          log.trace(f"Requesting {well_known_url} to resolve {server_name}'s .well-known")
          async with http.get(well_known_url) as resp:
              if resp.status != 200:
                  log.trace(f"Got non-200 status {resp.status} from {server_name}'s .well-known")
                  return None
              # content_type=None disables aiohttp's Content-Type check, so a valid JSON body
              # served as e.g. text/plain is accepted instead of raising ContentTypeError.
              well_known_data = await resp.json(content_type=None)
      except (ClientError, asyncio.TimeoutError, json.JSONDecodeError, ValueError) as e:
          # asyncio.TimeoutError can come from the session's ClientTimeout and is not
          # necessarily a ClientError, so it is listed explicitly.
          log.debug(f"Failed to fetch .well-known for {server_name}: {e}")
          return None
      delegated = well_known_data.get("m.server") if isinstance(well_known_data, dict) else None
      if not isinstance(delegated, str) or not delegated:
          log.debug(f"Failed to fetch .well-known for {server_name}: no valid m.server field")
          return None
      return delegated


  async def _query_matrix_srv(host: str) -> Optional[Any]:
      """Look up ``_matrix._tcp.<host>`` and return the most preferred SRV record, or None.

      aiodns signals "no such record" by raising DNSError (NXDOMAIN / no data) rather than
      returning an empty result; those cases return None. Other DNS failures are re-raised so a
      transient problem is not cached as a port-8448 fallback.
      """
      srv_name = f"_matrix._tcp.{host}"
      log.trace(f"Querying SRV at {srv_name}")
      try:
          records = await dns_resolver.query(srv_name, "SRV")
      except aiodns.error.DNSError as e:
          if e.args and e.args[0] in (aiodns.error.ARES_ENOTFOUND, aiodns.error.ARES_ENODATA):
              log.trace(f"No SRV records found at {srv_name}")
              return None
          raise
      if not records:
          log.trace(f"No SRV records found at {srv_name}")
          return None
      # Lowest priority value is most preferred (RFC 2782); among equal priorities the first
      # record returned is kept, as before.
      return min(records, key=lambda record: record.priority)


  async def resolve_server_name(server_name: str) -> ResolvedServerName:
      """Resolve a Matrix server name to the host and port to connect to for federation.

      Steps:
      1. Return the cached result if there is one that has not expired.
      2. If the name has no explicit port and is not an IP literal, try .well-known delegation;
         a valid ``m.server`` value replaces the name used for the remaining steps.
      3. If there is still no explicit port and no IP literal, look up the ``_matrix._tcp`` SRV
         record and use its target and port.
      4. Otherwise connect to the host itself, on port 8448 when no port was found.

      Results are cached in ``server_name_cache`` for 86400 seconds, or for the SRV record's TTL
      (at least 300 seconds) when an SRV record was used.

      Raises:
          aiodns.error.DNSError: if the SRV lookup fails for a reason other than the record not
              existing.
      """
      cached = server_name_cache.get(server_name)
      if cached is None:
          log.trace(f"No cached server name resolution for {server_name}")
      elif cached.expire > int(time()):
          log.trace(f"Using cached server name resolution for {server_name}: {cached}")
          return cached
      host_header = server_name
      hostname, port, is_ip = parse_server_name(host_header)
      ttl = 86400
      if port is None and not is_ip:
          delegated = await _fetch_well_known_server(server_name)
          if delegated is not None:
              host_header = delegated
              log.debug(f"Got {host_header} from {server_name}'s .well-known")
              hostname, port, is_ip = parse_server_name(host_header)
      if port is None and not is_ip:
          srv = await _query_matrix_srv(host_header)
          if srv is not None:
              hostname, port = srv.host, srv.port
              # Never cache for less than 5 minutes, even if the record's TTL is lower.
              ttl = max(srv.ttl, 300)
              log.debug(f"Got {hostname}:{port} from {host_header}'s Matrix SRV record")
      result = ResolvedServerName(host_header=host_header, host=hostname, port=port or 8448,
                                  expire=int(time()) + ttl)
      server_name_cache[server_name] = result
      log.debug(f"Resolved server name {server_name} -> {result}")
      return result
  def init():
      """Create the module-level DNS resolver and HTTP session used for resolution.

      Must be called while an event loop is running: the resolver is bound to
      ``asyncio.get_running_loop()``, which raises RuntimeError otherwise.
      """
      global dns_resolver, http
      running_loop = asyncio.get_running_loop()
      dns_resolver = aiodns.DNSResolver(loop=running_loop)
      session_timeout = ClientTimeout(total=10)
      http = ClientSession(timeout=session_timeout, connector=MatrixFederationTCPConnector())