Source code for platforms.matrix

"""Matrix platform adapter using matrix-nio.

Wraps the matrix-nio ``AsyncClient`` and converts Matrix events into
:class:`~platforms.base.IncomingMessage` instances for the shared
:class:`~message_processor.MessageProcessor`.
"""

from __future__ import annotations

import asyncio
import base64
import io
import json
import logging
import re
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Union

import markdown as md
import olm

import aiohttp
import aiofiles
from nio import (
    AsyncClient,
    AsyncClientConfig,
    InviteMemberEvent,
    JoinError,
    JoinResponse,
    LoginResponse,
    MatrixRoom,
    MegolmEvent,
    RoomMessageText,
    RoomSendResponse,
)
from nio.api import Api as NioApi
from nio.crypto import Sas
from nio.crypto.attachments import decrypt_attachment
from nio.crypto.device import OlmDevice
from nio.events.room_events import (
    RoomEncryptedAudio,
    RoomEncryptedFile,
    RoomEncryptedImage,
    RoomEncryptedMedia,
    RoomEncryptedVideo,
    RoomMessageAudio,
    RoomMessageFile,
    RoomMessageImage,
    RoomMessageMedia,
    RoomMessageUnknown,
    RoomMessageVideo,
    UnknownEvent,
)
from nio.events.to_device import (
    KeyVerificationCancel,
    KeyVerificationKey,
    KeyVerificationMac,
    KeyVerificationStart,
    UnknownToDeviceEvent,
)
from nio import ToDeviceMessage
from nio.responses import DownloadError, RoomMessagesResponse

from media_cache import MediaCache
from config import Config
from platforms.base import (
    Attachment,
    HistoricalMessage,
    IncomingMessage,
    MessageHandler,
    PlatformAdapter,
)

logger = logging.getLogger(__name__)

# Every media event class handled by on_media, plain and end-to-end
# encrypted variants alike, gathered into one typing alias.
MediaEvent = Union[
    RoomMessageImage,
    RoomEncryptedImage,
    RoomMessageAudio,
    RoomEncryptedAudio,
    RoomMessageVideo,
    RoomEncryptedVideo,
    RoomMessageFile,
    RoomEncryptedFile,
]


def _get_reply_to_id(event: Any) -> str:
    """Extract the reply-to event ID from a Matrix event, or ``""``."""
    source = getattr(event, "source", None) or {}
    content = source.get("content", {})
    relates = content.get("m.relates_to", {})
    reply_to = relates.get("m.in_reply_to", {})
    return reply_to.get("event_id", "")


_STRIP_HTML_RE = re.compile(r"<[^>]+>")
_MD_EXTENSIONS = ["fenced_code", "tables", "nl2br"]

# Matches Matrix user IDs like @localpart:server or @localpart:server:port
_MATRIX_MENTION_RE = re.compile(r"@[\w.=\-/+]+:[\w.\-]+(:\d+)?")
# Matches any HTML tag (used to skip mention replacement inside tags)
_HTML_TAG_RE = re.compile(r"<[^>]+>")


def _markdown_to_html(text: str) -> str:
    """Render *text* as Markdown, producing HTML for Matrix ``formatted_body``.

    The Markdown extensions enabled are those listed in ``_MD_EXTENSIONS``.
    """
    rendered = md.markdown(text, extensions=_MD_EXTENSIONS)
    return rendered


def _strip_html(text: str) -> str:
    """Convert HTML into the plain-text ``body`` fallback.

    Tags are removed first, then HTML entities are decoded, so rendered
    Markdown such as ``&amp;`` or ``&lt;tag&gt;`` appears as ``&`` and
    ``<tag>`` in the plain-text body rather than as raw entity text.
    Decoding after stripping ensures escaped angle brackets are never
    mistaken for tags.
    """
    from html import unescape

    return unescape(_STRIP_HTML_RE.sub("", text))


def _linkify_matrix_mentions(html: str) -> tuple[str, list[str]]:
    """Replace bare @user:server mentions in HTML with matrix.to anchor links.

    Only text nodes are rewritten. Tags and their attributes are copied
    verbatim. Text inside ``<a>``, ``<code>`` and ``<pre>`` elements is also
    left untouched, for two reasons: an existing link must not get a second
    anchor nested inside it (invalid HTML), and user IDs quoted in code should
    not ping anyone. A user already linked through an
    ``<a href="https://matrix.to/#/@user:server">`` anchor is still reported
    as a mention.

    Returns ``(new_html, unique_user_ids)``. ``unique_user_ids`` is the list
    of Matrix user IDs found, de-duplicated in first-seen order, and is
    suitable for ``m.mentions.user_ids``.
    """
    verbatim_tags = {"a", "code", "pre"}
    matrix_to_prefix = "https://matrix.to/#/"
    mentions: list[str] = []

    def _replace(m: re.Match) -> str:
        uid = m.group(0)
        mentions.append(uid)
        return f'<a href="https://matrix.to/#/{uid}">{uid}</a>'

    def _linkify(text: str, verbatim: bool) -> str:
        # Text nested in a verbatim element is emitted unchanged.
        return text if verbatim else _MATRIX_MENTION_RE.sub(_replace, text)

    parts: list[str] = []
    pos = 0
    verbatim_depth = 0  # > 0 while inside an <a>, <code> or <pre> element
    for tag_match in _HTML_TAG_RE.finditer(html):
        parts.append(_linkify(html[pos:tag_match.start()], verbatim_depth > 0))
        tag = tag_match.group(0)
        parts.append(tag)
        pos = tag_match.end()

        # Work out the tag name and whether it opens or closes an element.
        inner = tag[1:-1].strip()
        is_closing = inner.startswith("/")
        words = inner.lstrip("/").split(None, 1)
        name = words[0].rstrip("/").lower() if words else ""
        if name not in verbatim_tags:
            continue
        if is_closing:
            verbatim_depth = max(verbatim_depth - 1, 0)
            continue
        if inner.endswith("/"):
            continue  # self-closing tag: opens no element
        verbatim_depth += 1
        if name == "a":
            # Keep reporting users that are already linked via matrix.to.
            start = tag.find(matrix_to_prefix)
            if start != -1:
                linked = _MATRIX_MENTION_RE.match(tag, start + len(matrix_to_prefix))
                if linked:
                    mentions.append(linked.group(0))
    parts.append(_linkify(html[pos:], verbatim_depth > 0))

    return "".join(parts), list(dict.fromkeys(mentions))


# ------------------------------------------------------------------
# Matrix-specific media download
# ------------------------------------------------------------------

async def download_matrix_media(
    client: AsyncClient,
    event: RoomMessageMedia | RoomEncryptedMedia,
) -> tuple[bytes, str, str]:
    """Download (and optionally decrypt) a media attachment from Matrix.

    Parameters
    ----------
    client:
        nio client used to fetch the ``mxc://`` content.
    event:
        Plain or end-to-end encrypted media event.

    Returns
    -------
    tuple of (data, mimetype, filename)
        ``data`` is decrypted for :class:`RoomEncryptedMedia` events.
        ``mimetype`` is the type declared in the event, falling back to the
        download's ``Content-Type``, then ``application/octet-stream``.
        ``filename`` is the event's ``filename`` field when present,
        falling back to ``body``, then ``"attachment"``.

    Raises
    ------
    RuntimeError
        If the homeserver answers with a :class:`DownloadError`.
    """
    mxc_url: str = event.url
    content = event.source.get("content")
    if not isinstance(content, dict):
        content = {}

    # Since Matrix v1.10 a ``filename`` field, when present, carries the real
    # file name and ``body`` may instead be a user-written caption.
    declared_name = content.get("filename")
    if isinstance(declared_name, str) and declared_name:
        filename = declared_name
    else:
        filename = event.body or "attachment"

    resp = await client.download(mxc=mxc_url)
    if isinstance(resp, DownloadError):
        raise RuntimeError(
            f"Failed to download {mxc_url}: {resp.message}"
        )
    data: bytes = resp.body  # type: ignore[assignment]

    # Determine MIME type: event-declared first, then the HTTP response.
    encrypted = isinstance(event, RoomEncryptedMedia)
    if encrypted:
        declared_mime = event.mimetype
    else:
        info = content.get("info")
        declared_mime = info.get("mimetype") if isinstance(info, dict) else None
    mimetype = declared_mime or resp.content_type or "application/octet-stream"

    # Encrypted attachments are stored as ciphertext on the media repository.
    if encrypted:
        data = decrypt_attachment(
            ciphertext=data,
            key=event.key["k"],
            hash=event.hashes["sha256"],
            iv=event.iv,
        )
    return data, mimetype, filename
# ------------------------------------------------------------------
# Credential helpers
# ------------------------------------------------------------------


async def save_matrix_credentials(
    credentials_file: str,
    homeserver: str,
    client: AsyncClient,
) -> None:
    """Persist Matrix login credentials (preserving extra keys like seeds).

    Keys already stored in *credentials_file* (for example
    ``cross_signing_seeds``) are kept. The login fields ``homeserver``,
    ``user_id``, ``device_id`` and ``access_token`` are overwritten.

    The file holds the access token, so it is written owner-only (mode
    ``0o600`` on POSIX). The JSON goes to a sibling ``.tmp`` file, which is
    then renamed over the target. A crash mid-write therefore leaves the
    previous file intact rather than truncated.
    """
    cred_path = Path(credentials_file)
    existing: dict[str, Any] = {}
    if cred_path.exists():
        try:
            async with aiofiles.open(credentials_file, "r") as f:
                loaded = json.loads(await f.read())
        except (ValueError, OSError):
            # ValueError covers JSONDecodeError and UnicodeDecodeError.
            logger.warning(
                "Could not read existing %s; extra keys will not be preserved",
                credentials_file,
                exc_info=True,
            )
        else:
            if isinstance(loaded, dict):
                existing = loaded
            else:
                logger.warning(
                    "Ignoring non-object JSON in %s", credentials_file,
                )

    existing.update({
        "homeserver": homeserver,
        "user_id": client.user_id,
        "device_id": client.device_id,
        "access_token": client.access_token,
    })

    tmp_path = cred_path.with_name(cred_path.name + ".tmp")
    tmp_path.unlink(missing_ok=True)
    # Create the temp file with restrictive permissions *before* the token is
    # written into it (umask can only remove bits from 0o600).
    tmp_path.touch(mode=0o600)
    async with aiofiles.open(tmp_path, "w") as f:
        await f.write(json.dumps(existing, indent=2))
    # Same-directory rename: atomic on POSIX, overwrites the target on Windows.
    tmp_path.replace(cred_path)
    logger.info("Matrix credentials saved to %s", credentials_file)
async def load_matrix_credentials(credentials_file: str) -> dict | None:
    """Load previously saved Matrix credentials, or return ``None``.

    ``None`` is returned in any of these cases:

    * the file does not exist;
    * the file cannot be read or decoded;
    * the file is not valid JSON;
    * the file does not hold a JSON object.

    An unusable file is therefore treated like a missing one rather than
    raising.
    """
    cred_path = Path(credentials_file)
    if not cred_path.exists():
        return None
    try:
        async with aiofiles.open(credentials_file, "r") as f:
            content = await f.read()
        data = json.loads(content)
    except (OSError, ValueError):
        # ValueError covers JSONDecodeError and UnicodeDecodeError.
        logger.warning(
            "Could not load Matrix credentials from %s",
            credentials_file,
            exc_info=True,
        )
        return None
    if not isinstance(data, dict):
        logger.warning(
            "Ignoring %s: expected a JSON object", credentials_file,
        )
        return None
    return data
# ------------------------------------------------------------------
# Auto-trust helper
# ------------------------------------------------------------------


def trust_all_devices(client: AsyncClient) -> None:
    """Trust every device nio knows about, for every tracked user.

    Devices that ``client.olm`` already reports as verified are left alone.
    Every other device is passed to ``client.verify_device``, and each one is
    logged at debug level.
    """
    for owner, device_map in client.device_store.items():
        for dev_id, device in device_map.items():
            if client.olm.is_device_verified(device):
                continue
            client.verify_device(device)
            logger.debug(
                "Auto-trusted device %s of %s",
                dev_id,
                owner,
            )
# ------------------------------------------------------------------ # Cross-signing setup # ------------------------------------------------------------------ def _sign_json(signing_key: olm.PkSigning, obj: dict) -> str: """Canonical-JSON-sign *obj* with *signing_key*, return the signature.""" canonical = json.dumps(obj, ensure_ascii=False, sort_keys=True, separators=(",", ":")) return signing_key.sign(canonical)
async def setup_cross_signing(
    client: AsyncClient,
    password: str,
    credentials_file: str,
    saved_seeds: dict[str, str] | None = None,
) -> None:
    """Generate cross-signing keys, upload them, and self-sign the device.

    If *saved_seeds* is provided the keys are re-derived from persisted seeds
    instead of generating new ones. Seeds are persisted to *credentials_file*
    for future restarts.

    Steps:

    1. Derive (or generate) the master, self-signing and user-signing keys.
    2. If the homeserver already publishes this master key, only sign the
       bot's device and return.
    3. Otherwise, if there is no *password*, log a warning and return. If
       there is one, upload all three keys and answer the user-interactive
       auth (UIA) challenge with password auth.
    4. After a successful upload, sign the device and persist the seeds.

    A non-200 upload response is logged, and the function returns without
    signing or persisting anything.

    Raises
    ------
    RuntimeError
        If ``client.user_id`` or ``client.device_id`` is not set.
    """
    user_id = client.user_id
    device_id = client.device_id
    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if not (user_id and device_id):
        raise RuntimeError(
            "setup_cross_signing requires a logged-in client "
            "(user_id and device_id must be set)",
        )

    # --- Derive or generate keys ----------------------------------------
    if saved_seeds:
        master_seed = base64.b64decode(saved_seeds["master"])
        self_signing_seed = base64.b64decode(saved_seeds["self_signing"])
        user_signing_seed = base64.b64decode(saved_seeds["user_signing"])
    else:
        master_seed = olm.PkSigning.generate_seed()
        self_signing_seed = olm.PkSigning.generate_seed()
        user_signing_seed = olm.PkSigning.generate_seed()
    master_key = olm.PkSigning(master_seed)
    self_signing_key = olm.PkSigning(self_signing_seed)
    user_signing_key = olm.PkSigning(user_signing_seed)

    # --- Nothing to upload if this master key is already published ------
    if await _cross_signing_master_published(client, user_id, master_key):
        logger.info(
            "Cross-signing master key already published — skipping upload",
        )
        await _sign_own_device(client, self_signing_key, user_id, device_id)
        return

    if not password:
        logger.warning(
            "Cross-signing setup requires a password for UIA — skipping. "
            "Set up will be attempted on next fresh login.",
        )
        return

    # --- Upload, then sign own device and persist seeds -----------------
    upload_body = _build_cross_signing_upload(
        user_id, master_key, self_signing_key, user_signing_key,
    )
    if not await _upload_cross_signing_keys(
        client, upload_body, user_id, password,
    ):
        return
    logger.info("Cross-signing keys uploaded successfully")

    await _sign_own_device(client, self_signing_key, user_id, device_id)

    seeds = {
        "master": base64.b64encode(master_seed).decode(),
        "self_signing": base64.b64encode(self_signing_seed).decode(),
        "user_signing": base64.b64encode(user_signing_seed).decode(),
    }
    await _persist_cross_signing_seeds(credentials_file, seeds)


async def _cross_signing_master_published(
    client: AsyncClient,
    user_id: str,
    master_key: olm.PkSigning,
) -> bool:
    """Return True if the server already lists *master_key* as *user_id*'s master key.

    Query errors, non-200 responses and unexpected JSON are logged at debug
    level and reported as ``False``.
    """
    query_path = (
        f"/_matrix/client/v3/keys/query"
        f"?access_token={client.access_token}"
    )
    query_body = json.dumps({"device_keys": {user_id: []}})
    try:
        resp = await client.send("POST", query_path, query_body)
        if resp.status != 200:
            return False
        data = json.loads(await resp.read())
        existing_master = (
            data.get("master_keys", {}).get(user_id, {}).get("keys", {})
        )
        return f"ed25519:{master_key.public_key}" in existing_master
    except Exception:
        logger.debug("Could not query existing cross-signing keys", exc_info=True)
        return False


def _build_cross_signing_upload(
    user_id: str,
    master_key: olm.PkSigning,
    self_signing_key: olm.PkSigning,
    user_signing_key: olm.PkSigning,
) -> dict[str, Any]:
    """Build the ``/keys/device_signing/upload`` body; sub-keys are signed by the master key."""
    master_key_id = f"ed25519:{master_key.public_key}"

    def _key_obj(key: olm.PkSigning, usage: str) -> dict[str, Any]:
        return {
            "user_id": user_id,
            "usage": [usage],
            "keys": {f"ed25519:{key.public_key}": key.public_key},
        }

    def _signed_by_master(key: olm.PkSigning, usage: str) -> dict[str, Any]:
        obj = _key_obj(key, usage)
        # Signature is computed before the ``signatures`` member is attached.
        obj["signatures"] = {user_id: {master_key_id: _sign_json(master_key, obj)}}
        return obj

    return {
        "master_key": _key_obj(master_key, "master"),
        "self_signing_key": _signed_by_master(self_signing_key, "self_signing"),
        "user_signing_key": _signed_by_master(user_signing_key, "user_signing"),
    }


async def _upload_cross_signing_keys(
    client: AsyncClient,
    upload_body: dict[str, Any],
    user_id: str,
    password: str,
) -> bool:
    """POST the cross-signing keys, answering a 401 UIA challenge with the password.

    Returns True on HTTP 200. Any other final status is logged as an error,
    and False is returned. *upload_body* is not mutated.
    """
    upload_path = (
        f"/_matrix/client/v3/keys/device_signing/upload"
        f"?access_token={client.access_token}"
    )
    # The first request has no auth; the 401 reply carries the UIA session.
    resp = await client.send("POST", upload_path, json.dumps(upload_body))
    if resp.status == 401:
        uia_data = json.loads(await resp.read())
        authed_body = {
            **upload_body,
            "auth": {
                "type": "m.login.password",
                # ``identifier`` is the current spec form; the deprecated
                # top-level ``user`` is kept for older servers.
                "identifier": {"type": "m.id.user", "user": user_id},
                "user": user_id,
                "password": password,
                "session": uia_data.get("session", ""),
            },
        }
        resp = await client.send("POST", upload_path, json.dumps(authed_body))
    if resp.status != 200:
        body = await resp.read()
        logger.error(
            "Cross-signing key upload failed (HTTP %d): %s",
            resp.status,
            body.decode(errors="replace"),
        )
        return False
    return True


async def _persist_cross_signing_seeds(
    credentials_file: str,
    seeds: dict[str, str],
) -> None:
    """Best-effort: store *seeds* under ``cross_signing_seeds`` in *credentials_file*.

    The other keys in the file are preserved. The seeds are private key
    material, so the file is rewritten owner-only (``0o600`` on POSIX) through
    a temporary sibling file and an atomic rename. An unreadable existing
    file is left untouched. Any failure is logged as a warning and never
    raised.
    """
    try:
        cred_path = Path(credentials_file)
        creds: dict[str, Any] = {}
        if cred_path.exists():
            async with aiofiles.open(credentials_file, "r") as f:
                creds = json.loads(await f.read())
        creds["cross_signing_seeds"] = seeds
        tmp_path = cred_path.with_name(cred_path.name + ".tmp")
        tmp_path.unlink(missing_ok=True)
        tmp_path.touch(mode=0o600)
        async with aiofiles.open(tmp_path, "w") as f:
            await f.write(json.dumps(creds, indent=2))
        tmp_path.replace(cred_path)
        logger.info("Cross-signing seeds persisted to %s", credentials_file)
    except Exception:
        logger.warning("Failed to persist cross-signing seeds", exc_info=True)
async def _sign_own_device(
    client: AsyncClient,
    self_signing_key: olm.PkSigning,
    user_id: str,
    device_id: str,
) -> None:
    """Sign the bot's own device key with the self-signing key.

    Rebuilds the device-key object for *device_id* from the local Olm
    account's identity keys, signs it with *self_signing_key* and uploads
    the signed object to ``/keys/signatures/upload``.  Everything is
    best-effort: failures are logged as warnings, never raised.

    Args:
        client: Logged-in nio client whose ``olm`` account owns the device.
        self_signing_key: The user's cross-signing self-signing key.
        user_id: Matrix user ID of the bot.
        device_id: Device ID of the current session.
    """
    identity_keys = client.olm.account.identity_keys
    device_key_obj: dict[str, Any] = {
        # Must be identical to the object the device published via
        # /keys/upload, otherwise the server rejects the signature.  The
        # spec identifier for Olm is "m.olm.v1.curve25519-aes-sha2"
        # (there is no "-256" suffix).
        "algorithms": [
            "m.olm.v1.curve25519-aes-sha2",
            "m.megolm.v1.aes-sha2",
        ],
        "device_id": device_id,
        "keys": {
            f"curve25519:{device_id}": identity_keys["curve25519"],
            f"ed25519:{device_id}": identity_keys["ed25519"],
        },
        "user_id": user_id,
    }

    dev_sig = _sign_json(self_signing_key, device_key_obj)
    # The signatures endpoint expects the complete signed object (keys +
    # signatures), not just a key-id -> signature map.
    device_key_obj["signatures"] = {
        user_id: {f"ed25519:{self_signing_key.public_key}": dev_sig},
    }
    sig_upload = {user_id: {device_id: device_key_obj}}

    # Send the token as a header rather than a query parameter so it does
    # not end up in proxy/server access logs.
    headers = {
        "Authorization": f"Bearer {client.access_token}",
        "Content-Type": "application/json",
    }
    try:
        resp = await client.send(
            "POST",
            "/_matrix/client/v3/keys/signatures/upload",
            json.dumps(sig_upload),
            headers=headers,
        )
        # Reading the body also releases the underlying connection.
        body = await resp.read()
    except Exception:
        logger.warning("Failed to upload device signature", exc_info=True)
        return

    if resp.status != 200:
        logger.warning(
            "Device signature upload failed (HTTP %d): %s",
            resp.status,
            body.decode(errors="replace"),
        )
        return

    # A 200 response may still report per-key rejections under "failures".
    try:
        failures = json.loads(body).get("failures") or {}
    except (ValueError, AttributeError):
        failures = {}
    if failures:
        logger.warning("Device signature upload rejected: %s", failures)
        return

    logger.info(
        "Device %s self-signed with cross-signing key",
        device_id,
    )


# ------------------------------------------------------------------
# MatrixPlatform adapter
# ------------------------------------------------------------------
class MatrixPlatform(PlatformAdapter):
    """Platform adapter for Matrix via matrix-nio.

    Parameters
    ----------
    message_handler:
        Async callback that receives :class:`IncomingMessage` instances.
    homeserver:
        Matrix homeserver URL.
    user_id:
        Matrix user ID for the bot.
    password:
        Password (only needed for first login).
    store_path:
        Path to the nio E2EE key store.
    credentials_file:
        Path to the JSON file for persisting login credentials.
    media_cache:
        Optional cache used when downloading media attachments.
    config:
        Optional bot configuration (e.g. custom-emoji resolution settings).
    """

    def __init__(
        self,
        message_handler: MessageHandler,
        *,
        homeserver: str,
        user_id: str,
        password: str = "",
        store_path: str = "nio_store",
        credentials_file: str = "matrix_credentials.json",
        media_cache: MediaCache | None = None,
        config: Config | None = None,
    ) -> None:
        """Record connection settings; the connection itself is made by ``start``.

        Args:
            message_handler (MessageHandler): Callback for incoming messages.
            homeserver (str): Matrix homeserver URL.
            user_id (str): Bot's Matrix user ID.
            password (str): Login password, used only without saved credentials.
            store_path (str): Directory for the nio E2EE store.
            credentials_file (str): JSON file holding saved credentials.
            media_cache (MediaCache | None): Optional media download cache.
            config (Config | None): Optional bot configuration.
        """
        super().__init__(message_handler)

        # Connection / login settings.
        self._homeserver = homeserver
        self._user_id = user_id
        self._password = password
        self._store_path = store_path
        self._credentials_file = credentials_file

        # Optional collaborators.
        self._media_cache = media_cache
        self._config = config

        # Runtime state; populated by start().
        self.client: AsyncClient | None = None
        self.credentials: dict | None = None
        self._sync_task: asyncio.Task | None = None
        self._stop_event = asyncio.Event()

        # room_id -> event IDs of messages this bot sent (reply detection).
        self._sent_events: dict[str, set[str]] = {}
        # room_id -> task refreshing the typing indicator.
        self._typing_tasks: dict[str, asyncio.Task[None]] = {}
        # request_event_id -> (Sas, OlmDevice, room_id) for in-room SAS.
        self._in_room_sas: dict[str, tuple[Sas, OlmDevice, str]] = {}
# -- PlatformAdapter metadata -------------------------------------- @property def name(self) -> str: """Name. Returns: str: Result string. """ return "matrix" @property def is_running(self) -> bool: """Check whether is running. Returns: bool: True on success, False otherwise. """ return self._sync_task is not None and not self._sync_task.done() @property def bot_identity(self) -> dict[str, str]: uid = (self.client.user_id if self.client else None) or self._user_id return { "platform": "matrix", "user_id": uid, "display_name": uid, "mention": uid, } # -- PlatformAdapter lifecycle -------------------------------------
[docs] async def start(self) -> None: """Start. """ if self.is_running: logger.warning("Matrix platform is already running") return self._stop_event.clear() # Ensure store directory exists store_path = Path(self._store_path) store_path.mkdir(parents=True, exist_ok=True) # Configure nio with E2EE client_config = AsyncClientConfig( max_limit_exceeded=0, max_timeouts=0, store_sync_tokens=True, encryption_enabled=True, ) self.client = AsyncClient( homeserver=self._homeserver, user=self._user_id, store_path=str(store_path), config=client_config, ) # Login or restore session creds = self.credentials or await load_matrix_credentials( self._credentials_file, ) self.credentials = creds if creds and creds.get("access_token"): logger.info( "Restoring Matrix session for %s on device %s", creds["user_id"], creds["device_id"], ) self.client.restore_login( user_id=creds["user_id"], device_id=creds["device_id"], access_token=creds["access_token"], ) else: if not self._password: raise RuntimeError( "No saved Matrix credentials and no password " "configured – cannot login" ) logger.info("Logging in to Matrix as %s …", self._user_id) resp = await self.client.login( self._password, device_name="MatrixLLMBot", ) if not isinstance(resp, LoginResponse): await self.client.close() raise RuntimeError(f"Matrix login failed: {resp}") logger.info( "Logged in as %s on device %s", resp.user_id, resp.device_id, ) await save_matrix_credentials( self._credentials_file, self._homeserver, self.client, ) self.credentials = { "homeserver": self._homeserver, "user_id": resp.user_id, "device_id": resp.device_id, "access_token": resp.access_token, } # Register event callbacks self._register_callbacks() # Initial sync + auto-trust + key management logger.info("Matrix: performing initial sync …") await self.client.sync(timeout=30000, full_state=True) trust_all_devices(self.client) if self.client.should_upload_keys: await self.client.keys_upload() if self.client.should_query_keys: await 
self.client.keys_query() if self.client.should_claim_keys: await self.client.keys_claim(self.client.get_users_for_key_claiming()) await self.client.send_to_device_messages() # Cross-signing: self-sign our device so other users don't see # "Encrypted by a device not verified by its owner" try: saved_seeds = (self.credentials or {}).get("cross_signing_seeds") await setup_cross_signing( self.client, password=self._password, credentials_file=self._credentials_file, saved_seeds=saved_seeds, ) except Exception: logger.warning("Cross-signing setup failed", exc_info=True) logger.info("Matrix platform is running. Listening for messages …") self._sync_task = asyncio.create_task(self._sync_loop())
[docs] async def stop(self) -> None: """Stop. """ if not self.is_running: return self._stop_event.set() if self._sync_task: self._sync_task.cancel() try: await self._sync_task except asyncio.CancelledError: pass self._sync_task = None if self.client: await self.client.close() self.client = None logger.info("Matrix platform stopped")
# -- PlatformAdapter outbound messaging ----------------------------
[docs] async def send(self, channel_id: str, text: str) -> str: """Send. Args: channel_id (str): Discord/Matrix channel identifier. text (str): Text content. Returns: str: Result string. """ if self.client is None: logger.error("Matrix client is not connected") return "" html = _markdown_to_html(text) html, mentioned_users = _linkify_matrix_mentions(html) content: dict[str, Any] = { "msgtype": "m.text", "body": _strip_html(text), "format": "org.matrix.custom.html", "formatted_body": html, "m.mentions": {"user_ids": mentioned_users}, } try: resp = await self.client.room_send( room_id=channel_id, message_type="m.room.message", content=content, ignore_unverified_devices=True, ) if isinstance(resp, RoomSendResponse): self._sent_events.setdefault( channel_id, set(), ).add(resp.event_id) return resp.event_id logger.warning( "room_send to %s returned non-success response: %s", channel_id, type(resp).__name__, ) except Exception: logger.exception( "Failed to send message to Matrix room %s", channel_id, ) return ""
[docs] async def send_file( self, channel_id: str, data: bytes, filename: str, mimetype: str = "application/octet-stream", ) -> str | None: """Send file. Args: channel_id (str): Discord/Matrix channel identifier. data (bytes): Input data payload. filename (str): The filename value. mimetype (str): The mimetype value. Returns: str | None: The mxc:// content URI, or None on failure. """ if self.client is None: logger.error("Matrix client is not connected") return None try: room = self.client.rooms.get(channel_id) room_encrypted = room.encrypted if room else False upload_resp, encryption_keys = await self.client.upload( io.BytesIO(data), content_type=mimetype, filename=filename, filesize=len(data), encrypt=room_encrypted, ) content_uri: str = upload_resp.content_uri msgtype = "m.file" if mimetype.startswith("image/"): msgtype = "m.image" elif mimetype.startswith("audio/"): msgtype = "m.audio" elif mimetype.startswith("video/"): msgtype = "m.video" content = { "msgtype": msgtype, "body": filename, "info": { "mimetype": mimetype, "size": len(data), }, } if encryption_keys: encryption_keys["url"] = content_uri content["file"] = encryption_keys else: content["url"] = content_uri resp = await self.client.room_send( room_id=channel_id, message_type="m.room.message", content=content, ignore_unverified_devices=True, ) if isinstance(resp, RoomSendResponse): self._sent_events.setdefault( channel_id, set(), ).add(resp.event_id) return content_uri except Exception: logger.exception( "Failed to send file to Matrix room %s", channel_id, ) return None
# -- Typing indicator ----------------------------------------------
[docs] async def start_typing(self, channel_id: str) -> None: """Start typing. Args: channel_id (str): Discord/Matrix channel identifier. """ await self.stop_typing(channel_id) if self.client is None: return async def _typing_loop(room_id: str) -> None: """Internal helper: typing loop. Args: room_id (str): The room id value. """ try: while True: await self.client.room_typing( # type: ignore[union-attr] room_id, typing_state=True, timeout=30000, ) await asyncio.sleep(25) except asyncio.CancelledError: pass self._typing_tasks[channel_id] = asyncio.create_task( _typing_loop(channel_id), )
[docs] async def stop_typing(self, channel_id: str) -> None: """Stop typing. Args: channel_id (str): Discord/Matrix channel identifier. """ task = self._typing_tasks.pop(channel_id, None) if task is not None and not task.done(): task.cancel() try: await task except asyncio.CancelledError: pass if self.client is not None: try: await self.client.room_typing( channel_id, typing_state=False, ) except Exception: logger.debug( "Failed to clear typing indicator in %s", channel_id, exc_info=True, )
# -- Server/channel listing ----------------------------------------
[docs] async def list_servers_and_channels(self) -> list[dict[str, Any]]: """Return all Matrix rooms the bot is in. Matrix doesn't have a guild/server hierarchy in the same way Discord does — each room is listed as a standalone entry. """ if self.client is None: return [] rooms: list[dict[str, Any]] = [] for room_id, room in self.client.rooms.items(): rooms.append({ "server_name": room.display_name, "server_id": room_id, "member_count": room.member_count, "channels": [], # Matrix rooms are flat }) return rooms
# -- Channel history -----------------------------------------------
[docs] async def fetch_history( self, channel_id: str, limit: int = 100, ) -> list[HistoricalMessage]: """Fetch history. Args: channel_id (str): Discord/Matrix channel identifier. limit (int): Maximum number of items. Returns: list[HistoricalMessage]: The result. """ if self.client is None: return [] start_token = self.client.next_batch if not start_token: return [] try: resp = await self.client.room_messages( room_id=channel_id, start=start_token, limit=limit, message_filter={"types": ["m.room.message"]}, ) except Exception: logger.debug( "Failed to fetch history for Matrix room %s", channel_id, exc_info=True, ) return [] if not isinstance(resp, RoomMessagesResponse): return [] room = self.client.rooms.get(channel_id) bot_user_id = self.client.user_id messages: list[HistoricalMessage] = [] _media_types = ( RoomMessageImage, RoomMessageVideo, RoomMessageAudio, RoomMessageFile, RoomEncryptedImage, RoomEncryptedVideo, RoomEncryptedAudio, RoomEncryptedFile, ) for event in resp.chunk: text: str | None = None if isinstance(event, RoomMessageText): text = event.body elif isinstance(event, _media_types): url = getattr(event, "url", "") or "" body = getattr(event, "body", "") or "" if isinstance(event, (RoomMessageImage, RoomEncryptedImage)): text = f"[Image: {body}]" if body else "[Image]" elif isinstance(event, (RoomMessageVideo, RoomEncryptedVideo)): text = f"[Video: {body}]" if body else "[Video]" elif isinstance(event, (RoomMessageAudio, RoomEncryptedAudio)): text = f"[Audio: {body}]" if body else "[Audio]" else: text = f"[File: {body}]" if body else "[File]" if url: text += f" ({url})" else: continue sender = event.sender if room is not None: display_name = room.user_name(sender) or sender else: display_name = sender messages.append(HistoricalMessage( user_id=sender, user_name=display_name, text=text, timestamp=datetime.fromtimestamp( event.server_timestamp / 1000, tz=timezone.utc, ), message_id=event.event_id, is_bot=(sender == bot_user_id), 
reply_to_id=_get_reply_to_id(event), )) messages.reverse() return messages
# -- Internal: event callbacks ------------------------------------- def _register_callbacks(self) -> None: """Attach matrix-nio event callbacks.""" assert self.client is not None self.client.add_event_callback( self._on_message, RoomMessageText, ) self.client.add_event_callback( self._on_invite, InviteMemberEvent, ) self.client.add_event_callback( self._on_megolm_event, MegolmEvent, ) _media_types = ( RoomMessageImage, RoomEncryptedImage, RoomMessageAudio, RoomEncryptedAudio, RoomMessageVideo, RoomEncryptedVideo, RoomMessageFile, RoomEncryptedFile, ) for event_type in _media_types: self.client.add_event_callback( self._on_media, event_type, ) self.client.add_to_device_callback( self._on_unknown_to_device, UnknownToDeviceEvent, ) self.client.add_to_device_callback( self._on_key_verification_start, KeyVerificationStart, ) self.client.add_to_device_callback( self._on_key_verification_key, KeyVerificationKey, ) self.client.add_to_device_callback( self._on_key_verification_mac, KeyVerificationMac, ) self.client.add_to_device_callback( self._on_key_verification_cancel, KeyVerificationCancel, ) self.client.add_event_callback( self._on_room_verification_request, RoomMessageUnknown, ) self.client.add_event_callback( self._on_room_verification_event, UnknownEvent, ) # -- In-room SAS verification helpers --------------------------------- async def _in_room_verif_send( self, room_id: str, event_type: str, content: dict, request_event_id: str, ) -> None: """Send an in-room verification event with the required m.relates_to. In-room verification events must NOT include transaction_id in the content body; the relation is carried by m.relates_to instead. 
""" assert self.client is not None content.pop("transaction_id", None) content["m.relates_to"] = { "rel_type": "m.reference", "event_id": request_event_id, } await self.client.room_send( room_id, event_type, content, ignore_unverified_devices=True, ) @staticmethod def _to_device_dict(source: dict, tx_id: str) -> dict: """Build a to-device-style dict from a room event source dict. nio's KeyVerification* from_dict() methods expect a to-device envelope with ``sender`` and ``content.transaction_id``. """ content = dict(source.get("content", {})) content["transaction_id"] = tx_id return {"sender": source["sender"], "content": content} # -- In-room SAS handlers --------------------------------------------- async def _on_room_verification_request( self, room: MatrixRoom, event: RoomMessageUnknown, ) -> None: """Handle m.key.verification.request room messages and send ready.""" if self.client is None: return if event.msgtype != "m.key.verification.request": return if event.sender == self.client.user_id: return content = event.source.get("content", {}) methods: list[str] = content.get("methods", []) request_event_id: str = event.event_id if "m.sas.v1" not in methods: logger.debug( "Ignoring in-room verification request from %s: unsupported methods %s", event.sender, methods, ) return logger.debug( "In-room SAS verification request from %s in %s (event %s)", event.sender, room.room_id, request_event_id, ) ready_content = { "from_device": self.client.device_id, "methods": ["m.sas.v1"], } await self._in_room_verif_send( room.room_id, "m.key.verification.ready", ready_content, request_event_id, ) async def _on_room_verification_event( self, room: MatrixRoom, event: UnknownEvent, ) -> None: """Dispatch in-room m.key.verification.* events.""" if self.client is None: return if event.sender == self.client.user_id: return ev_type: str = event.type source: dict = event.source content: dict = source.get("content", {}) # transaction_id for in-room verification is the original request 
event_id relates_to: dict = content.get("m.relates_to", {}) request_event_id: str | None = relates_to.get("event_id") if ev_type == "m.key.verification.done": logger.debug( "Received in-room m.key.verification.done from %s", event.sender, ) return if ev_type == "m.key.verification.cancel": logger.debug( "In-room SAS canceled by %s (event %s): [%s] %s", event.sender, request_event_id, content.get("code", ""), content.get("reason", ""), ) if request_event_id: self._in_room_sas.pop(request_event_id, None) return if not ev_type.startswith("m.key.verification."): return if not request_event_id: logger.debug("In-room verification event %s has no m.relates_to", ev_type) return # ---- start ------------------------------------------------------- if ev_type == "m.key.verification.start": if content.get("method") != "m.sas.v1": return try: user_devices = self.client.device_store[event.sender] except KeyError: user_devices = {} from_device: str = content.get("from_device", "") olm_device: OlmDevice | None = user_devices.get(from_device) if olm_device is None: logger.debug( "In-room SAS start from unknown device %s/%s", event.sender, from_device, ) return fp_key: str = self.client.olm.account.identity_keys["ed25519"] td = self._to_device_dict(source, request_event_id) try: typed_start = KeyVerificationStart.from_dict(td) except Exception: logger.exception("Failed to parse in-room KeyVerificationStart") return try: sas = Sas.from_key_verification_start( self.client.user_id, self.client.device_id, fp_key, olm_device, typed_start, ) except Exception: logger.exception("Failed to create Sas for in-room verification") return # Recompute commitment from the original room event content # (without the injected transaction_id) so it matches what the # initiator will verify against. 
original_content = source.get("content", {}) sas.commitment = olm.sha256( sas.pubkey + NioApi.to_canonical_json(original_content), ) if sas.canceled: logger.debug( "In-room SAS start from %s/%s was invalid: %s", event.sender, from_device, sas.cancel_reason, ) return accept_msg = sas.accept_verification() await self._in_room_verif_send( room.room_id, accept_msg.type, dict(accept_msg.content), request_event_id, ) self._in_room_sas[request_event_id] = (sas, olm_device, room.room_id) logger.debug( "In-room SAS started with %s/%s (event %s)", event.sender, from_device, request_event_id, ) return # ---- key / mac --------------------------------------------------- entry = self._in_room_sas.get(request_event_id) if entry is None: logger.debug( "In-room verification event %s for unknown request %s", ev_type, request_event_id, ) return sas, olm_device, room_id = entry if ev_type == "m.key.verification.key": # As responder, send our key only after receiving theirs. key_msg = sas.share_key() await self._in_room_verif_send( room_id, key_msg.type, dict(key_msg.content), request_event_id, ) td = self._to_device_dict(source, request_event_id) try: typed_key = KeyVerificationKey.from_dict(td) except Exception: logger.exception("Failed to parse in-room KeyVerificationKey") return sas.receive_key_event(typed_key) if sas.canceled: logger.debug( "In-room SAS canceled after key exchange: %s", sas.cancel_reason, ) self._in_room_sas.pop(request_event_id, None) return emojis = sas.get_emoji() emoji_str = " ".join(f"{e} {d}" for e, d in emojis) logger.debug( "In-room SAS emojis (event %s): %s — auto-accepting", request_event_id, emoji_str, ) sas.accept_sas() mac_msg = sas.get_mac() await self._in_room_verif_send( room_id, mac_msg.type, dict(mac_msg.content), request_event_id, ) elif ev_type == "m.key.verification.mac": td = self._to_device_dict(source, request_event_id) try: typed_mac = KeyVerificationMac.from_dict(td) except Exception: logger.exception("Failed to parse in-room 
KeyVerificationMac") return sas.receive_mac_event(typed_mac) if sas.verified: self.client.verify_device(olm_device) await self._in_room_verif_send( room_id, "m.key.verification.done", {}, request_event_id, ) logger.debug( "In-room SAS verified device %s/%s (event %s)", olm_device.user_id, olm_device.id, request_event_id, ) elif sas.canceled: logger.debug( "In-room SAS MAC verification failed (event %s): %s", request_event_id, sas.cancel_reason, ) self._in_room_sas.pop(request_event_id, None) # -- SAS key-verification handlers ------------------------------------ async def _on_unknown_to_device( self, event: UnknownToDeviceEvent, ) -> None: """Handle untyped to-device events (verification.request, verification.done).""" if event.type == "m.key.verification.done": logger.debug( "Received m.key.verification.done from %s (source: %s)", event.sender, event.source.get("content", {}), ) return if event.type != "m.key.verification.request": return if self.client is None: return content = event.source.get("content", {}) methods: list[str] = content.get("methods", []) transaction_id: str | None = content.get("transaction_id") from_device: str | None = content.get("from_device") logger.debug("Verification request content: %s", content) if "m.sas.v1" not in methods or not transaction_id or not from_device: logger.debug( "Ignoring verification request from %s: unsupported methods %s", event.sender, methods, ) return logger.debug( "SAS verification request from %s (device %s, tx %s)", event.sender, from_device, transaction_id, ) ready_content = { "from_device": self.client.device_id, "methods": ["m.sas.v1"], "transaction_id": transaction_id, } ready_msg = ToDeviceMessage( "m.key.verification.ready", event.sender, from_device, ready_content, ) await self.client.to_device(ready_msg) async def _on_key_verification_start( self, event: KeyVerificationStart, ) -> None: """Accept a SAS verification start via nio's internal Sas object.""" if self.client is None: return if event.method 
!= "m.sas.v1": return sas = self.client.olm.key_verifications.get(event.transaction_id) if sas is None or sas.canceled: logger.debug( "SAS start from %s/%s but no valid internal Sas (tx %s)", event.sender, event.from_device, event.transaction_id, ) return await self.client.to_device(sas.accept_verification()) logger.debug( "SAS verification accepted for %s/%s (tx %s)", event.sender, event.from_device, event.transaction_id, ) async def _on_key_verification_key( self, event: KeyVerificationKey, ) -> None: """Log emojis, auto-accept, and send our MAC after key exchange.""" if self.client is None: return sas = self.client.olm.key_verifications.get(event.transaction_id) if sas is None: return if sas.canceled: logger.debug( "SAS canceled after key exchange (tx %s): %s", event.transaction_id, sas.cancel_reason, ) return # nio's internal handler already called receive_key_event() and # queued share_key() in outgoing_to_device_messages — flush it. await self.client.send_to_device_messages() emojis = sas.get_emoji() emoji_str = " ".join(f"{e} {d}" for e, d in emojis) logger.debug( "SAS emojis for tx %s: %s — auto-accepting", event.transaction_id, emoji_str, ) sas.accept_sas() await self.client.to_device(sas.get_mac()) async def _on_key_verification_mac( self, event: KeyVerificationMac, ) -> None: """Log the result after nio's internal handler verified the MAC.""" if self.client is None: return sas = self.client.olm.key_verifications.get(event.transaction_id) if sas is None: return # nio's internal handler already called receive_mac_event() and # verify_device() if the MAC was valid. 
if sas.verified: done_content = {"transaction_id": event.transaction_id} done_msg = ToDeviceMessage( "m.key.verification.done", sas.other_olm_device.user_id, sas.other_olm_device.id, done_content, ) resp = await self.client.to_device(done_msg) logger.debug( "Device %s/%s verified via SAS (tx %s), done sent -> %s", sas.other_olm_device.user_id, sas.other_olm_device.id, event.transaction_id, type(resp).__name__, ) elif sas.canceled: logger.debug( "SAS MAC verification failed (tx %s): %s", event.transaction_id, sas.cancel_reason, ) async def _on_key_verification_cancel( self, event: KeyVerificationCancel, ) -> None: """Log cancellation of an in-progress SAS verification.""" logger.debug( "SAS verification canceled by %s (tx %s): [%s] %s", event.sender, event.transaction_id, event.code, event.reason, ) async def _on_message( self, room: MatrixRoom, event: RoomMessageText, ) -> None: """Convert a Matrix text message to IncomingMessage.""" assert self.client is not None if event.sender == self.client.user_id: return text = event.body attachments: list[Attachment] = [] # --- Resolve custom emojis as images -------------------------- cfg = self._config if cfg is not None and cfg.resolve_emojis_as_images and text: try: source_content = (event.source or {}).get("content", {}) formatted_body = source_content.get("formatted_body", "") if formatted_body: from platforms.emoji_resolver import ( extract_matrix_emojis, rewrite_matrix_emoji_text, download_matrix_emojis, ) emoji_matches = extract_matrix_emojis(formatted_body) if emoji_matches: emoji_atts = await download_matrix_emojis( emoji_matches, self.client, max_emojis=cfg.max_emojis_per_message, media_cache=self._media_cache, ) if emoji_atts: text = rewrite_matrix_emoji_text(text, emoji_matches[:cfg.max_emojis_per_message]) attachments.extend(emoji_atts) logger.info( "Resolved %d/%d Matrix custom emojis as images", len(emoji_atts), len(emoji_matches), ) except Exception: logger.debug("Matrix emoji resolution failed", 
exc_info=True) msg = IncomingMessage( platform="matrix", channel_id=room.room_id, user_id=event.sender, user_name=room.user_name(event.sender) or event.sender, text=text, is_addressed=self._is_addressed(room, event), attachments=attachments, channel_name=room.display_name, timestamp=datetime.fromtimestamp( event.server_timestamp / 1000, tz=timezone.utc, ), message_id=event.event_id, reply_to_id=_get_reply_to_id(event), extra={ "bot_id": self.client.user_id, "member_count": room.member_count, "is_dm": room.member_count <= 2, "is_server_admin": ( room.power_levels.get_user_level(event.sender) >= 50 if room.power_levels else False ), }, ) logger.debug( "[Matrix/%s] %s: %s", room.display_name, msg.user_name, msg.text[:120], ) await self._message_handler(msg, self) async def _on_media( self, room: MatrixRoom, event: MediaEvent, ) -> None: """Convert a Matrix media message to IncomingMessage.""" assert self.client is not None if event.sender == self.client.user_id: return is_addressed = self._is_addressed(room, event) caption = event.body or "" logger.debug( "[Matrix/%s] %s: [media] %s", room.display_name, room.user_name(event.sender), caption[:120], ) # Download the attachment (using cache when available) attachments: list[Attachment] = [] mxc_url: str = event.url try: if self._media_cache is not None: async def _download() -> tuple[bytes, str, str]: """Internal helper: download. Returns: tuple[bytes, str, str]: The result. 
""" return await download_matrix_media(self.client, event) # type: ignore[arg-type] data, mimetype, filename = await self._media_cache.get_or_download( mxc_url, _download, ) else: data, mimetype, filename = await download_matrix_media( self.client, event, ) attachments.append(Attachment( data=data, mimetype=mimetype, filename=filename, source_url=mxc_url, )) except Exception: logger.exception( "Failed to download Matrix media from %s", event.sender, ) msg = IncomingMessage( platform="matrix", channel_id=room.room_id, user_id=event.sender, user_name=room.user_name(event.sender) or event.sender, text=caption, is_addressed=is_addressed, attachments=attachments, channel_name=room.display_name, timestamp=datetime.fromtimestamp( event.server_timestamp / 1000, tz=timezone.utc, ), message_id=event.event_id, reply_to_id=_get_reply_to_id(event), extra={ "bot_id": self.client.user_id, "member_count": room.member_count, "is_dm": room.member_count <= 2, "is_server_admin": ( room.power_levels.get_user_level(event.sender) >= 50 if room.power_levels else False ), }, ) await self._message_handler(msg, self) async def _on_invite( self, room: MatrixRoom, event: InviteMemberEvent, ) -> None: """Auto-accept room invites directed at the bot.""" assert self.client is not None if event.state_key != self.client.user_id: return logger.info( "Matrix: received invite to %s from %s – joining", room.room_id, event.sender, ) asyncio.ensure_future(self._join_room(room.room_id)) async def _join_room(self, room_id: str) -> None: """Join a room with retry logic to handle invite-commit race conditions.""" assert self.client is not None for attempt in range(4): delay = 1 if attempt == 0 else 2 ** attempt # 1, 2, 4, 8 s await asyncio.sleep(delay) result = await self.client.join(room_id) if not isinstance(result, JoinError): logger.info("Joined Matrix room %s", room_id) return logger.warning( "Join attempt %d for %s failed: %s", attempt + 1, room_id, result, ) logger.error("Giving up joining %s after 4 
attempts", room_id) async def _on_megolm_event( self, room: MatrixRoom, event: MegolmEvent, ) -> None: """Request missing room key for undecryptable Megolm messages.""" logger.warning( "Matrix: unable to decrypt message in %s from %s (session %s). " "Requesting missing room key.", room.room_id, event.sender, event.session_id, ) if self.client is not None: try: await self.client.request_room_key(event) await self.client.send_to_device_messages() logger.debug( "Sent key request for session %s in %s", event.session_id, room.room_id, ) except Exception: logger.debug( "Could not send key request for session %s", event.session_id, exc_info=True, ) # -- Internal: helpers --------------------------------------------- def _is_addressed( self, room: MatrixRoom, event: object, ) -> bool: """Return ``True`` when the bot should respond. The bot responds if any of the following hold: * The room has two or fewer members (DM / small room). * The bot's user ID appears in ``m.mentions.user_ids``. * The bot's user ID appears in the plain-text body. * The message is a reply to an event the bot sent. 
""" assert self.client is not None # DM / small room -- always respond if room.member_count <= 2: return True source = getattr(event, "source", None) or {} content = source.get("content", {}) # Modern m.mentions spec mentions = content.get("m.mentions", {}) if self.client.user_id in mentions.get("user_ids", []): return True # Fallback: user ID appears in body text body = getattr(event, "body", "") or "" if self.client.user_id in body: return True # Reply to one of the bot's own messages relates = content.get("m.relates_to", {}) reply_to = relates.get("m.in_reply_to", {}) reply_event_id = reply_to.get("event_id") if reply_event_id: sent = self._sent_events.get(room.room_id, set()) if reply_event_id in sent: return True return False async def _sync_loop(self) -> None: """Run the Matrix sync loop until stopped.""" backoff = 5.0 max_backoff = 60.0 try: while not self._stop_event.is_set(): try: await self.client.sync( # type: ignore[union-attr] timeout=30000, ) backoff = 5.0 # reset on success trust_all_devices( self.client, # type: ignore[arg-type] ) if self.client.should_upload_keys: # type: ignore[union-attr] await self.client.keys_upload() # type: ignore[union-attr] if self.client.should_query_keys: # type: ignore[union-attr] await self.client.keys_query() # type: ignore[union-attr] if self.client.should_claim_keys: # type: ignore[union-attr] await self.client.keys_claim( # type: ignore[union-attr] self.client.get_users_for_key_claiming(), # type: ignore[union-attr] ) await self.client.send_to_device_messages() # type: ignore[union-attr] except asyncio.CancelledError: raise except ( TimeoutError, aiohttp.ClientError, ConnectionError, OSError, ) as e: logger.warning( "Matrix sync transient error, retrying in %.0fs: %s", backoff, e, ) await asyncio.sleep(backoff) backoff = min(backoff * 2, max_backoff) continue except Exception: logger.exception("Matrix sync loop encountered an error") break except asyncio.CancelledError: pass finally: logger.info("Matrix sync loop 
exited")