Source code for platforms.matrix
"""Matrix platform adapter using matrix-nio.
Wraps the matrix-nio ``AsyncClient`` and converts Matrix events into
:class:`~platforms.base.IncomingMessage` instances for the shared
:class:`~message_processor.MessageProcessor`.
"""
from __future__ import annotations
import asyncio
import base64
import io
import json
import logging
import re
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Union
import markdown as md
import olm
import aiohttp
import aiofiles
from nio import (
AsyncClient,
AsyncClientConfig,
InviteMemberEvent,
JoinError,
JoinResponse,
LoginResponse,
MatrixRoom,
MegolmEvent,
RoomMessageText,
RoomSendResponse,
)
from nio.api import Api as NioApi
from nio.crypto import Sas
from nio.crypto.attachments import decrypt_attachment
from nio.crypto.device import OlmDevice
from nio.events.room_events import (
RoomEncryptedAudio,
RoomEncryptedFile,
RoomEncryptedImage,
RoomEncryptedMedia,
RoomEncryptedVideo,
RoomMessageAudio,
RoomMessageFile,
RoomMessageImage,
RoomMessageMedia,
RoomMessageUnknown,
RoomMessageVideo,
UnknownEvent,
)
from nio.events.to_device import (
KeyVerificationCancel,
KeyVerificationKey,
KeyVerificationMac,
KeyVerificationStart,
UnknownToDeviceEvent,
)
from nio import ToDeviceMessage
from nio.responses import DownloadError, RoomMessagesResponse
from media_cache import MediaCache
from config import Config
from platforms.base import (
Attachment,
HistoricalMessage,
IncomingMessage,
MessageHandler,
PlatformAdapter,
)
logger = logging.getLogger(__name__)

# Type alias covering every media event class (plain and encrypted variants
# of image, audio, video and file messages) accepted by the media callback.
MediaEvent = Union[
    RoomMessageImage,
    RoomEncryptedImage,
    RoomMessageAudio,
    RoomEncryptedAudio,
    RoomMessageVideo,
    RoomEncryptedVideo,
    RoomMessageFile,
    RoomEncryptedFile,
]
def _get_reply_to_id(event: Any) -> str:
"""Extract the reply-to event ID from a Matrix event, or ``""``."""
source = getattr(event, "source", None) or {}
content = source.get("content", {})
relates = content.get("m.relates_to", {})
reply_to = relates.get("m.in_reply_to", {})
return reply_to.get("event_id", "")
_STRIP_HTML_RE = re.compile(r"<[^>]+>")
_MD_EXTENSIONS = ["fenced_code", "tables", "nl2br"]
# Matches Matrix user IDs like @localpart:server or @localpart:server:port
_MATRIX_MENTION_RE = re.compile(r"@[\w.=\-/+]+:[\w.\-]+(:\d+)?")
# Matches any HTML tag (used to skip mention replacement inside tags)
_HTML_TAG_RE = re.compile(r"<[^>]+>")
def _markdown_to_html(text: str) -> str:
    """Render Markdown *text* to the HTML used as Matrix ``formatted_body``.

    Rendering is delegated to ``markdown.markdown`` with the extensions
    listed in ``_MD_EXTENSIONS``.
    """
    rendered = md.markdown(text, extensions=_MD_EXTENSIONS)
    return rendered
def _strip_html(text: str) -> str:
    """Return *text* with every HTML tag removed (plain-text ``body`` fallback).

    Only the tags themselves are dropped; the text between them is kept.
    """
    without_tags = _STRIP_HTML_RE.sub("", text)
    return without_tags
# Elements whose text content must not be rewritten into mention links:
# wrapping text inside an existing <a> would nest anchors, and code samples
# must stay verbatim.
_NO_LINKIFY_TAGS = frozenset({"a", "code", "pre"})
# Extracts the optional closing slash and the element name from a tag.
_TAG_NAME_RE = re.compile(r"<\s*(/?)\s*([A-Za-z][A-Za-z0-9]*)")


def _linkify_matrix_mentions(html: str) -> tuple[str, list[str]]:
    """Replace bare @user:server mentions in HTML with matrix.to anchor links.

    Operates only on text nodes (tags themselves are copied verbatim), and
    additionally leaves the text *inside* certain elements untouched:

    * ``<code>`` / ``<pre>``: text is neither linkified nor recorded.
    * ``<a>``: text is not wrapped in a second anchor, but user IDs found
      there are still recorded as mentions.

    Returns ``(new_html, unique_user_ids)`` where ``unique_user_ids`` is the
    ordered, de-duplicated list of Matrix user IDs found, suitable for
    ``m.mentions.user_ids``.
    """
    mentions: list[str] = []
    # Current nesting depth of each protected element.
    depth: dict[str, int] = dict.fromkeys(_NO_LINKIFY_TAGS, 0)

    def _replace(m: re.Match) -> str:
        uid = m.group(0)
        mentions.append(uid)
        return f'<a href="https://matrix.to/#/{uid}">{uid}</a>'

    def _process_text(text: str) -> str:
        """Linkify one text node according to the enclosing elements."""
        if depth["code"] or depth["pre"]:
            return text
        if depth["a"]:
            # Already a link: keep the markup, but still report the mention.
            mentions.extend(
                found.group(0) for found in _MATRIX_MENTION_RE.finditer(text)
            )
            return text
        return _MATRIX_MENTION_RE.sub(_replace, text)

    def _track_tag(tag: str) -> None:
        """Update the nesting depth when *tag* opens/closes a protected element."""
        parsed = _TAG_NAME_RE.match(tag)
        if parsed is None:
            return
        closing, tag_name = parsed.group(1), parsed.group(2).lower()
        if tag_name not in depth:
            return
        if closing:
            # Clamp at zero so a stray closing tag cannot go negative.
            depth[tag_name] = max(0, depth[tag_name] - 1)
        elif not tag.endswith("/>"):
            depth[tag_name] += 1

    parts: list[str] = []
    pos = 0
    for tag_match in _HTML_TAG_RE.finditer(html):
        # Process the text before the tag, then keep the tag itself verbatim.
        parts.append(_process_text(html[pos:tag_match.start()]))
        tag = tag_match.group(0)
        parts.append(tag)
        _track_tag(tag)
        pos = tag_match.end()
    parts.append(_process_text(html[pos:]))

    # dict.fromkeys keeps first-seen order while dropping duplicates.
    unique = list(dict.fromkeys(mentions))
    return "".join(parts), unique
# ------------------------------------------------------------------
# Matrix-specific media download
# ------------------------------------------------------------------
[docs]
async def download_matrix_media(
    client: AsyncClient,
    event: RoomMessageMedia | RoomEncryptedMedia,
) -> tuple[bytes, str, str]:
    """Download (and, for encrypted events, decrypt) a Matrix media attachment.

    The MIME type is taken from the event (``event.mimetype`` for encrypted
    media, ``content.info.mimetype`` otherwise), then from the download
    response's ``content_type``, and finally defaults to
    ``application/octet-stream``. The filename is ``event.body`` or
    ``"attachment"`` when the body is empty.

    Returns
    -------
    tuple of (data, mimetype, filename)

    Raises
    ------
    RuntimeError
        If the download returns a ``DownloadError``.
    """
    mxc_url: str = event.url
    filename: str = event.body or "attachment"

    resp = await client.download(mxc=mxc_url)
    if isinstance(resp, DownloadError):
        raise RuntimeError(
            f"Failed to download {mxc_url}: {resp.message}"
        )
    payload: bytes = resp.body  # type: ignore[assignment]

    # Shared fallback chain once the event itself carries no MIME type.
    fallback_mime = resp.content_type or "application/octet-stream"

    if isinstance(event, RoomEncryptedMedia):
        mimetype = event.mimetype or fallback_mime
        # Encrypted attachments are stored as ciphertext on the server.
        payload = decrypt_attachment(
            ciphertext=payload,
            key=event.key["k"],
            hash=event.hashes["sha256"],
            iv=event.iv,
        )
    else:
        declared = event.source.get("content", {}).get("info", {}).get("mimetype")
        mimetype = declared or fallback_mime

    return payload, mimetype, filename
# ------------------------------------------------------------------
# Credential helpers
# ------------------------------------------------------------------
[docs]
async def save_matrix_credentials(
    credentials_file: str,
    homeserver: str,
    client: AsyncClient,
) -> None:
    """Persist Matrix login credentials (preserving extra keys like seeds).

    Any JSON object already stored in *credentials_file* is loaded first and
    then updated with ``homeserver``, ``user_id``, ``device_id`` and
    ``access_token`` taken from *client*, so unrelated keys survive.

    If the existing file cannot be read, is not valid JSON, or does not hold
    a JSON object, a warning is logged and the file is rewritten with only
    the login credentials.
    """
    existing: dict[str, Any] = {}
    cred_path = Path(credentials_file)
    if cred_path.exists():
        try:
            async with aiofiles.open(credentials_file, "r") as f:
                loaded = json.loads(await f.read())
        except (json.JSONDecodeError, OSError):
            # Best effort: saving the fresh login must not fail because the
            # previous file is damaged.
            logger.warning(
                "Could not read existing Matrix credentials from %s; "
                "extra keys will not be preserved",
                credentials_file, exc_info=True,
            )
        else:
            if isinstance(loaded, dict):
                existing = loaded
            else:
                # Valid JSON but not an object: .update() below would crash.
                logger.warning(
                    "Existing Matrix credentials in %s are not a JSON "
                    "object; overwriting",
                    credentials_file,
                )
    existing.update({
        "homeserver": homeserver,
        "user_id": client.user_id,
        "device_id": client.device_id,
        "access_token": client.access_token,
    })
    async with aiofiles.open(credentials_file, "w") as f:
        await f.write(json.dumps(existing, indent=2))
    logger.info("Matrix credentials saved to %s", credentials_file)
[docs]
async def load_matrix_credentials(credentials_file: str) -> dict | None:
    """Read saved Matrix credentials from *credentials_file*.

    Returns the parsed JSON content, or ``None`` when the file does not
    exist or does not contain valid JSON.
    """
    if Path(credentials_file).exists():
        async with aiofiles.open(credentials_file, "r") as f:
            raw = await f.read()
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            return None
    return None
# ------------------------------------------------------------------
# Auto-trust helper
# ------------------------------------------------------------------
[docs]
def trust_all_devices(client: AsyncClient) -> None:
    """Verify every device in the client's device store that is not yet verified.

    Each newly verified device is reported at debug level.
    """
    for user_id, devices in client.device_store.items():
        for device_id, olm_device in devices.items():
            if client.olm.is_device_verified(olm_device):
                continue  # already trusted, nothing to do
            client.verify_device(olm_device)
            logger.debug(
                "Auto-trusted device %s of %s", device_id, user_id,
            )
# ------------------------------------------------------------------
# Cross-signing setup
# ------------------------------------------------------------------
def _sign_json(signing_key: olm.PkSigning, obj: dict) -> str:
"""Canonical-JSON-sign *obj* with *signing_key*, return the signature."""
canonical = json.dumps(obj, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
return signing_key.sign(canonical)
[docs]
async def setup_cross_signing(
    client: AsyncClient,
    password: str,
    credentials_file: str,
    saved_seeds: dict[str, str] | None = None,
) -> None:
    """Create (or re-derive) cross-signing keys, publish them, sign our device.

    Steps, in order:

    1. Build master, self-signing and user-signing keys, either from the
       base64 seeds in *saved_seeds* or from freshly generated seeds.
    2. Query the server; if the master key is already published, only the
       device signature is (re-)uploaded and the function returns.
    3. Without a *password* the upload is skipped with a warning.
    4. Upload the three keys; on an HTTP 401 the request is repeated once
       with ``m.login.password`` auth for the returned session.
    5. Sign the bot's own device and store the seeds under
       ``cross_signing_seeds`` in *credentials_file*.

    Upload failures are logged and end the function; a failure to persist
    the seeds is logged as a warning only.
    """
    user_id = client.user_id
    device_id = client.device_id
    assert user_id and device_id

    # --- Derive or generate keys ----------------------------------------
    if saved_seeds:
        master_seed, self_signing_seed, user_signing_seed = (
            base64.b64decode(saved_seeds[seed_name])
            for seed_name in ("master", "self_signing", "user_signing")
        )
    else:
        master_seed, self_signing_seed, user_signing_seed = (
            olm.PkSigning.generate_seed() for _ in range(3)
        )
    master_key = olm.PkSigning(master_seed)
    self_signing_key = olm.PkSigning(self_signing_seed)
    user_signing_key = olm.PkSigning(user_signing_seed)
    master_key_id = f"ed25519:{master_key.public_key}"

    # --- Check if master key is already published -----------------------
    query_path = (
        f"/_matrix/client/v3/keys/query?access_token={client.access_token}"
    )
    query_body = json.dumps({"device_keys": {user_id: []}})
    try:
        query_resp = await client.send("POST", query_path, query_body)
        if query_resp.status == 200:
            queried = json.loads(await query_resp.read())
            published_master = (
                queried.get("master_keys", {}).get(user_id, {}).get("keys", {})
            )
            if master_key_id in published_master:
                logger.info(
                    "Cross-signing master key already published — skipping upload",
                )
                await _sign_own_device(
                    client, self_signing_key, user_id, device_id,
                )
                return
    except Exception:
        # A failed query is not fatal; fall through to the upload attempt.
        logger.debug("Could not query existing cross-signing keys", exc_info=True)

    if not password:
        logger.warning(
            "Cross-signing setup requires a password for UIA — skipping. "
            "Set up will be attempted on next fresh login.",
        )
        return

    # --- Build the key upload payload -----------------------------------
    def _subkey_signed_by_master(usage: str, key: olm.PkSigning) -> dict[str, Any]:
        """Build a sub-key object and attach the master key's signature."""
        key_obj: dict[str, Any] = {
            "user_id": user_id,
            "usage": [usage],
            "keys": {f"ed25519:{key.public_key}": key.public_key},
        }
        # Sign before the "signatures" member exists on the object.
        signature = _sign_json(master_key, key_obj)
        key_obj["signatures"] = {user_id: {master_key_id: signature}}
        return key_obj

    upload_body: dict[str, Any] = {
        "master_key": {
            "user_id": user_id,
            "usage": ["master"],
            "keys": {master_key_id: master_key.public_key},
        },
        "self_signing_key": _subkey_signed_by_master(
            "self_signing", self_signing_key,
        ),
        "user_signing_key": _subkey_signed_by_master(
            "user_signing", user_signing_key,
        ),
    }

    # --- Upload with UIA (password auth) --------------------------------
    upload_path = (
        f"/_matrix/client/v3/keys/device_signing/upload"
        f"?access_token={client.access_token}"
    )
    # The first request is expected to be rejected with the UIA session.
    upload_resp = await client.send(
        "POST", upload_path, json.dumps(upload_body),
    )
    if upload_resp.status == 401:
        uia_data = json.loads(await upload_resp.read())
        upload_body["auth"] = {
            "type": "m.login.password",
            "user": user_id,
            "password": password,
            "session": uia_data.get("session", ""),
        }
        upload_resp = await client.send(
            "POST", upload_path, json.dumps(upload_body),
        )
    if upload_resp.status != 200:
        error_body = await upload_resp.read()
        logger.error(
            "Cross-signing key upload failed (HTTP %d): %s",
            upload_resp.status, error_body.decode(errors="replace"),
        )
        return
    logger.info("Cross-signing keys uploaded successfully")

    # --- Sign own device ------------------------------------------------
    await _sign_own_device(client, self_signing_key, user_id, device_id)

    # --- Persist seeds --------------------------------------------------
    seeds = {
        "master": base64.b64encode(master_seed).decode(),
        "self_signing": base64.b64encode(self_signing_seed).decode(),
        "user_signing": base64.b64encode(user_signing_seed).decode(),
    }
    try:
        stored: dict[str, Any] = {}
        if Path(credentials_file).exists():
            async with aiofiles.open(credentials_file, "r") as f:
                stored = json.loads(await f.read())
        stored["cross_signing_seeds"] = seeds
        async with aiofiles.open(credentials_file, "w") as f:
            await f.write(json.dumps(stored, indent=2))
        logger.info("Cross-signing seeds persisted to %s", credentials_file)
    except Exception:
        logger.warning("Failed to persist cross-signing seeds", exc_info=True)
async def _sign_own_device(
    client: AsyncClient,
    self_signing_key: olm.PkSigning,
    user_id: str,
    device_id: str,
) -> None:
    """Sign the bot's own device keys with *self_signing_key* and upload it.

    The device-key object is built from the Olm account's identity keys,
    signed via ``_sign_json`` and posted to ``/keys/signatures/upload``.
    The function never raises for upload problems: non-200 responses and
    exceptions are logged as warnings.
    """
    identity_keys = client.olm.account.identity_keys
    device_key_obj = {
        "algorithms": [
            "m.olm.v1.curve25519-aes-sha2-256",
            "m.megolm.v1.aes-sha2",
        ],
        "device_id": device_id,
        "keys": {
            f"curve25519:{device_id}": identity_keys["curve25519"],
            f"ed25519:{device_id}": identity_keys["ed25519"],
        },
        "user_id": user_id,
    }
    signature = _sign_json(self_signing_key, device_key_obj)
    # Payload shape: user -> device -> signing key id -> signature.
    sig_upload = {
        user_id: {
            device_id: {
                f"ed25519:{self_signing_key.public_key}": signature,
            },
        },
    }
    sig_path = (
        f"/_matrix/client/v3/keys/signatures/upload"
        f"?access_token={client.access_token}"
    )
    try:
        resp = await client.send(
            "POST", sig_path, json.dumps(sig_upload),
        )
        if resp.status != 200:
            failure_body = await resp.read()
            logger.warning(
                "Device signature upload failed (HTTP %d): %s",
                resp.status, failure_body.decode(errors="replace"),
            )
        else:
            logger.info(
                "Device %s self-signed with cross-signing key", device_id,
            )
    except Exception:
        logger.warning("Failed to upload device signature", exc_info=True)
# ------------------------------------------------------------------
# MatrixPlatform adapter
# ------------------------------------------------------------------
[docs]
class MatrixPlatform(PlatformAdapter):
"""Platform adapter for Matrix via matrix-nio.
Parameters
----------
message_handler:
Async callback that receives :class:`IncomingMessage` instances.
homeserver:
Matrix homeserver URL.
user_id:
Matrix user ID for the bot.
password:
Password (only needed for first login).
store_path:
Path to the nio E2EE key store.
credentials_file:
Path to the JSON file for persisting login credentials.
"""
[docs]
def __init__(
    self,
    message_handler: MessageHandler,
    *,
    homeserver: str,
    user_id: str,
    password: str = "",
    store_path: str = "nio_store",
    credentials_file: str = "matrix_credentials.json",
    media_cache: MediaCache | None = None,
    config: Config | None = None,
) -> None:
    """Store the configuration; no network activity happens here.

    Args:
        message_handler (MessageHandler): Forwarded to the base class
            initializer.
        homeserver (str): Matrix homeserver URL.
        user_id (str): Matrix user ID for the bot.
        password (str): Password used when no saved session exists.
        store_path (str): Directory of the nio E2EE key store.
        credentials_file (str): JSON file holding persisted credentials.
        media_cache (MediaCache | None): Optional media cache, kept on
            the instance.
        config (Config | None): Optional configuration object, kept on
            the instance.
    """
    super().__init__(message_handler)

    # Connection settings
    self._homeserver = homeserver
    self._user_id = user_id
    self._password = password
    self._store_path = store_path
    self._credentials_file = credentials_file

    # Optional collaborators
    self._media_cache = media_cache
    self._config = config

    # Runtime state; client and credentials are filled in by start()
    self.client: AsyncClient | None = None
    self.credentials: dict | None = None
    self._sync_task: asyncio.Task | None = None
    self._stop_event = asyncio.Event()

    # channel_id -> set of event IDs the bot has sent
    self._sent_events: dict[str, set[str]] = {}
    # channel_id -> background task that keeps the typing indicator alive
    self._typing_tasks: dict[str, asyncio.Task[None]] = {}
    # request_event_id -> (Sas, OlmDevice, room_id) for in-room SAS verifications
    self._in_room_sas: dict[str, tuple[Sas, OlmDevice, str]] = {}
# -- PlatformAdapter metadata --------------------------------------
@property
def name(self) -> str:
"""Name.
Returns:
str: Result string.
"""
return "matrix"
@property
def is_running(self) -> bool:
"""Check whether is running.
Returns:
bool: True on success, False otherwise.
"""
return self._sync_task is not None and not self._sync_task.done()
@property
def bot_identity(self) -> dict[str, str]:
uid = (self.client.user_id if self.client else None) or self._user_id
return {
"platform": "matrix",
"user_id": uid,
"display_name": uid,
"mention": uid,
}
# -- PlatformAdapter lifecycle -------------------------------------
[docs]
async def start(self) -> None:
    """Connect to Matrix and launch the background sync loop.

    Does nothing (beyond a warning) when the platform is already running.
    Otherwise it creates the nio client with encryption enabled, restores
    the saved session or logs in with the password, registers callbacks,
    performs the initial sync with key management, attempts cross-signing
    setup (failures are only logged) and finally starts the sync task.

    Raises:
        RuntimeError: If there are no saved credentials and no password,
            or if the password login does not return a ``LoginResponse``.
    """
    if self.is_running:
        logger.warning("Matrix platform is already running")
        return
    self._stop_event.clear()

    # The E2EE key store lives on disk; make sure its directory exists.
    store_dir = Path(self._store_path)
    store_dir.mkdir(parents=True, exist_ok=True)

    nio_config = AsyncClientConfig(
        max_limit_exceeded=0,
        max_timeouts=0,
        store_sync_tokens=True,
        encryption_enabled=True,
    )
    self.client = AsyncClient(
        homeserver=self._homeserver,
        user=self._user_id,
        store_path=str(store_dir),
        config=nio_config,
    )

    # Prefer credentials already held in memory, then the credentials file.
    saved = self.credentials or await load_matrix_credentials(
        self._credentials_file,
    )
    self.credentials = saved
    if saved and saved.get("access_token"):
        logger.info(
            "Restoring Matrix session for %s on device %s",
            saved["user_id"], saved["device_id"],
        )
        self.client.restore_login(
            user_id=saved["user_id"],
            device_id=saved["device_id"],
            access_token=saved["access_token"],
        )
    else:
        if not self._password:
            raise RuntimeError(
                "No saved Matrix credentials and no password "
                "configured – cannot login"
            )
        logger.info("Logging in to Matrix as %s …", self._user_id)
        login_resp = await self.client.login(
            self._password, device_name="MatrixLLMBot",
        )
        if not isinstance(login_resp, LoginResponse):
            await self.client.close()
            raise RuntimeError(f"Matrix login failed: {login_resp}")
        logger.info(
            "Logged in as %s on device %s",
            login_resp.user_id, login_resp.device_id,
        )
        await save_matrix_credentials(
            self._credentials_file, self._homeserver, self.client,
        )
        self.credentials = {
            "homeserver": self._homeserver,
            "user_id": login_resp.user_id,
            "device_id": login_resp.device_id,
            "access_token": login_resp.access_token,
        }

    self._register_callbacks()

    # Initial sync, then trust known devices and settle pending key work.
    logger.info("Matrix: performing initial sync …")
    await self.client.sync(timeout=30000, full_state=True)
    trust_all_devices(self.client)
    if self.client.should_upload_keys:
        await self.client.keys_upload()
    if self.client.should_query_keys:
        await self.client.keys_query()
    if self.client.should_claim_keys:
        await self.client.keys_claim(self.client.get_users_for_key_claiming())
    await self.client.send_to_device_messages()

    # Cross-signing: self-sign our device so other users don't see
    # "Encrypted by a device not verified by its owner"
    try:
        seeds = (self.credentials or {}).get("cross_signing_seeds")
        await setup_cross_signing(
            self.client,
            password=self._password,
            credentials_file=self._credentials_file,
            saved_seeds=seeds,
        )
    except Exception:
        logger.warning("Cross-signing setup failed", exc_info=True)

    logger.info("Matrix platform is running. Listening for messages …")
    self._sync_task = asyncio.create_task(self._sync_loop())
[docs]
async def stop(self) -> None:
    """Stop syncing and release the Matrix client.

    Sets the stop event, cancels the sync task if it is still running,
    cancels any typing-indicator tasks and closes the nio client.

    Cleanup is performed even when the sync task has already finished on
    its own (for example after crashing): in that state ``is_running`` is
    False, but the client still holds an open connection that must be
    closed. Calling this on a platform that was never started is a no-op.
    """
    sync_task = self._sync_task
    if sync_task is None and self.client is None:
        return  # never started, or already fully stopped

    self._stop_event.set()

    self._sync_task = None
    if sync_task is not None and not sync_task.done():
        sync_task.cancel()
        try:
            await sync_task
        except asyncio.CancelledError:
            pass

    # Typing loops call into the client; stop them before closing it.
    for typing_task in self._typing_tasks.values():
        typing_task.cancel()
    self._typing_tasks.clear()

    if self.client:
        await self.client.close()
        self.client = None
    logger.info("Matrix platform stopped")
# -- PlatformAdapter outbound messaging ----------------------------
[docs]
async def send(self, channel_id: str, text: str) -> str:
    """Send *text* to a room as an ``m.text`` message with HTML formatting.

    The text is rendered from Markdown to HTML, bare user IDs in the HTML
    are turned into matrix.to links, and the IDs found are listed under
    ``m.mentions``. The plain ``body`` is *text* with HTML tags removed.

    Args:
        channel_id (str): Room ID to send to.
        text (str): Markdown message text.

    Returns:
        str: The event ID of the sent message, or ``""`` when the client
        is not connected, the server returned a non-success response, or
        sending raised an exception (which is logged).
    """
    if self.client is None:
        logger.error("Matrix client is not connected")
        return ""

    formatted, mentioned_users = _linkify_matrix_mentions(
        _markdown_to_html(text),
    )
    content: dict[str, Any] = {
        "msgtype": "m.text",
        "body": _strip_html(text),
        "format": "org.matrix.custom.html",
        "formatted_body": formatted,
        "m.mentions": {"user_ids": mentioned_users},
    }
    try:
        resp = await self.client.room_send(
            room_id=channel_id,
            message_type="m.room.message",
            content=content,
            ignore_unverified_devices=True,
        )
        if isinstance(resp, RoomSendResponse):
            # Remember our own events per room.
            own_events = self._sent_events.setdefault(channel_id, set())
            own_events.add(resp.event_id)
            return resp.event_id
        logger.warning(
            "room_send to %s returned non-success response: %s",
            channel_id, type(resp).__name__,
        )
    except Exception:
        logger.exception(
            "Failed to send message to Matrix room %s", channel_id,
        )
    return ""
[docs]
async def send_file(
    self,
    channel_id: str,
    data: bytes,
    filename: str,
    mimetype: str = "application/octet-stream",
) -> str | None:
    """Upload *data* and post it to a room as a media message.

    The payload is uploaded via ``client.upload`` (with ``encrypt=True``
    when the room is known and marked encrypted) and then announced with
    an ``m.room.message`` event. The ``msgtype`` follows *mimetype*:
    ``m.image``, ``m.audio`` or ``m.video`` for the matching prefixes,
    ``m.file`` otherwise.

    Args:
        channel_id (str): Room ID to send to.
        data (bytes): Raw file content.
        filename (str): Name used as the message body and upload filename.
        mimetype (str): MIME type of the file.

    Returns:
        str | None: The mxc:// content URI once the message was sent.
        None when the client is not connected, the upload did not return
        a content URI, the server rejected the message, or an exception
        occurred (failures are logged).
    """
    if self.client is None:
        logger.error("Matrix client is not connected")
        return None
    try:
        room = self.client.rooms.get(channel_id)
        room_encrypted = room.encrypted if room else False
        upload_resp, encryption_keys = await self.client.upload(
            io.BytesIO(data),
            content_type=mimetype,
            filename=filename,
            filesize=len(data),
            encrypt=room_encrypted,
        )
        # An error response carries no content_uri; report it explicitly
        # instead of failing later with an AttributeError.
        content_uri: str | None = getattr(upload_resp, "content_uri", None)
        if not content_uri:
            logger.error(
                "Upload of %s for Matrix room %s failed: %s",
                filename, channel_id, upload_resp,
            )
            return None

        if mimetype.startswith("image/"):
            msgtype = "m.image"
        elif mimetype.startswith("audio/"):
            msgtype = "m.audio"
        elif mimetype.startswith("video/"):
            msgtype = "m.video"
        else:
            msgtype = "m.file"

        content: dict[str, Any] = {
            "msgtype": msgtype,
            "body": filename,
            "info": {
                "mimetype": mimetype,
                "size": len(data),
            },
        }
        if encryption_keys:
            # Encrypted uploads reference the URI inside the "file" object.
            encryption_keys["url"] = content_uri
            content["file"] = encryption_keys
        else:
            content["url"] = content_uri

        resp = await self.client.room_send(
            room_id=channel_id,
            message_type="m.room.message",
            content=content,
            ignore_unverified_devices=True,
        )
        if isinstance(resp, RoomSendResponse):
            self._sent_events.setdefault(
                channel_id, set(),
            ).add(resp.event_id)
            return content_uri
        logger.warning(
            "room_send (file) to %s returned non-success response: %s",
            channel_id, type(resp).__name__,
        )
    except Exception:
        logger.exception(
            "Failed to send file to Matrix room %s", channel_id,
        )
    return None
# -- Typing indicator ----------------------------------------------
[docs]
async def start_typing(self, channel_id: str) -> None:
    """Show the typing indicator in a room until ``stop_typing`` is called.

    Any previous typing task for the room is stopped first. A background
    task then sends a typing notification with a 30 s timeout and repeats
    it every 25 s until it is cancelled.

    Args:
        channel_id (str): Room ID in which to show the indicator.
    """
    await self.stop_typing(channel_id)
    if self.client is None:
        return

    async def _keep_typing(room_id: str) -> None:
        """Refresh the typing notification for *room_id* until cancelled."""
        try:
            while True:
                await self.client.room_typing(  # type: ignore[union-attr]
                    room_id, typing_state=True, timeout=30000,
                )
                # Refresh a little before the 30 s timeout runs out.
                await asyncio.sleep(25)
        except asyncio.CancelledError:
            pass

    typing_task = asyncio.create_task(_keep_typing(channel_id))
    self._typing_tasks[channel_id] = typing_task
[docs]
async def stop_typing(self, channel_id: str) -> None:
    """Cancel the room's typing task and clear the typing indicator.

    The stored task (if any) is removed and, when still running, cancelled
    and awaited. If a client is connected, a ``typing_state=False``
    notification is then sent; a failure to send it is logged at debug
    level only.

    Args:
        channel_id (str): Room ID whose indicator should be cleared.
    """
    typing_task = self._typing_tasks.pop(channel_id, None)
    still_running = typing_task is not None and not typing_task.done()
    if still_running:
        typing_task.cancel()
        try:
            await typing_task
        except asyncio.CancelledError:
            pass

    if self.client is None:
        return
    try:
        await self.client.room_typing(
            channel_id, typing_state=False,
        )
    except Exception:
        logger.debug(
            "Failed to clear typing indicator in %s",
            channel_id, exc_info=True,
        )
# -- Server/channel listing ----------------------------------------
[docs]
async def list_servers_and_channels(self) -> list[dict[str, Any]]:
    """List every room known to the client as a standalone "server" entry.

    Matrix has no guild/server hierarchy like Discord, so each room is
    reported on its own with an empty ``channels`` list. Returns an empty
    list when no client is connected.
    """
    if self.client is None:
        return []
    return [
        {
            "server_name": room.display_name,
            "server_id": room_id,
            "member_count": room.member_count,
            "channels": [],  # Matrix rooms are flat
        }
        for room_id, room in self.client.rooms.items()
    ]
# -- Channel history -----------------------------------------------
[docs]
async def fetch_history(
    self,
    channel_id: str,
    limit: int = 100,
) -> list[HistoricalMessage]:
    """Fetch recent ``m.room.message`` events of a room.

    Messages are requested starting at the client's current sync token.
    Text messages keep their body; media messages become a placeholder
    such as ``[Image: name] (url)``. Other event types are skipped. The
    fetched chunk is returned in reversed order.

    Args:
        channel_id (str): Room ID to read from.
        limit (int): Maximum number of events to request.

    Returns:
        list[HistoricalMessage]: The converted messages; empty when no
        client or sync token is available, the request raised, or the
        response is not a ``RoomMessagesResponse``.
    """
    if self.client is None:
        return []
    start_token = self.client.next_batch
    if not start_token:
        return []
    try:
        resp = await self.client.room_messages(
            room_id=channel_id,
            start=start_token,
            limit=limit,
            message_filter={"types": ["m.room.message"]},
        )
    except Exception:
        logger.debug(
            "Failed to fetch history for Matrix room %s",
            channel_id, exc_info=True,
        )
        return []
    if not isinstance(resp, RoomMessagesResponse):
        return []

    room = self.client.rooms.get(channel_id)
    bot_user_id = self.client.user_id

    image_types = (RoomMessageImage, RoomEncryptedImage)
    video_types = (RoomMessageVideo, RoomEncryptedVideo)
    audio_types = (RoomMessageAudio, RoomEncryptedAudio)
    file_types = (RoomMessageFile, RoomEncryptedFile)
    media_types = image_types + video_types + audio_types + file_types

    history: list[HistoricalMessage] = []
    for event in resp.chunk:
        text: str | None
        if isinstance(event, RoomMessageText):
            text = event.body
        elif isinstance(event, media_types):
            url = getattr(event, "url", "") or ""
            body = getattr(event, "body", "") or ""
            if isinstance(event, image_types):
                label = "Image"
            elif isinstance(event, video_types):
                label = "Video"
            elif isinstance(event, audio_types):
                label = "Audio"
            else:
                label = "File"
            text = f"[{label}: {body}]" if body else f"[{label}]"
            if url:
                text += f" ({url})"
        else:
            continue  # neither text nor a supported media type

        sender = event.sender
        display_name = sender
        if room is not None:
            display_name = room.user_name(sender) or sender
        sent_at = datetime.fromtimestamp(
            event.server_timestamp / 1000, tz=timezone.utc,
        )
        history.append(HistoricalMessage(
            user_id=sender,
            user_name=display_name,
            text=text,
            timestamp=sent_at,
            message_id=event.event_id,
            is_bot=(sender == bot_user_id),
            reply_to_id=_get_reply_to_id(event),
        ))
    return history[::-1]
# -- Internal: event callbacks -------------------------------------
def _register_callbacks(self) -> None:
    """Attach all matrix-nio room-event and to-device callbacks.

    Registration order is: message/invite/Megolm/media room callbacks,
    then the to-device callbacks, then the two in-room verification
    callbacks.
    """
    assert self.client is not None

    room_handlers = (
        (self._on_message, RoomMessageText),
        (self._on_invite, InviteMemberEvent),
        (self._on_megolm_event, MegolmEvent),
        # Plain and encrypted media all go to the same handler.
        (self._on_media, RoomMessageImage),
        (self._on_media, RoomEncryptedImage),
        (self._on_media, RoomMessageAudio),
        (self._on_media, RoomEncryptedAudio),
        (self._on_media, RoomMessageVideo),
        (self._on_media, RoomEncryptedVideo),
        (self._on_media, RoomMessageFile),
        (self._on_media, RoomEncryptedFile),
    )
    for handler, event_type in room_handlers:
        self.client.add_event_callback(handler, event_type)

    to_device_handlers = (
        (self._on_unknown_to_device, UnknownToDeviceEvent),
        (self._on_key_verification_start, KeyVerificationStart),
        (self._on_key_verification_key, KeyVerificationKey),
        (self._on_key_verification_mac, KeyVerificationMac),
        (self._on_key_verification_cancel, KeyVerificationCancel),
    )
    for handler, event_type in to_device_handlers:
        self.client.add_to_device_callback(handler, event_type)

    in_room_verification_handlers = (
        (self._on_room_verification_request, RoomMessageUnknown),
        (self._on_room_verification_event, UnknownEvent),
    )
    for handler, event_type in in_room_verification_handlers:
        self.client.add_event_callback(handler, event_type)
# -- In-room SAS verification helpers ---------------------------------
async def _in_room_verif_send(
    self,
    room_id: str,
    event_type: str,
    content: dict,
    request_event_id: str,
) -> None:
    """Send an in-room verification event that references the request.

    In-room verification events must NOT carry ``transaction_id`` in their
    content; the link to the request is expressed through ``m.relates_to``
    instead. Note that *content* is modified in place: ``transaction_id``
    is removed and ``m.relates_to`` is set before sending.
    """
    assert self.client is not None
    content.pop("transaction_id", None)
    relation = {
        "rel_type": "m.reference",
        "event_id": request_event_id,
    }
    content["m.relates_to"] = relation
    await self.client.room_send(
        room_id, event_type, content, ignore_unverified_devices=True,
    )
@staticmethod
def _to_device_dict(source: dict, tx_id: str) -> dict:
"""Build a to-device-style dict from a room event source dict.
nio's KeyVerification* from_dict() methods expect a to-device
envelope with ``sender`` and ``content.transaction_id``.
"""
content = dict(source.get("content", {}))
content["transaction_id"] = tx_id
return {"sender": source["sender"], "content": content}
# -- In-room SAS handlers ---------------------------------------------
async def _on_room_verification_request(
    self, room: MatrixRoom, event: RoomMessageUnknown,
) -> None:
    """Answer an in-room ``m.key.verification.request`` with ``ready``.

    Ignored when no client is connected, the message has a different
    msgtype, the bot itself sent it, or ``m.sas.v1`` is not among the
    offered methods.
    """
    if self.client is None:
        return
    if event.msgtype != "m.key.verification.request":
        return
    if event.sender == self.client.user_id:
        return

    offered_methods: list[str] = event.source.get("content", {}).get("methods", [])
    request_event_id: str = event.event_id
    if "m.sas.v1" not in offered_methods:
        logger.debug(
            "Ignoring in-room verification request from %s: unsupported methods %s",
            event.sender, offered_methods,
        )
        return

    logger.debug(
        "In-room SAS verification request from %s in %s (event %s)",
        event.sender, room.room_id, request_event_id,
    )
    await self._in_room_verif_send(
        room.room_id,
        "m.key.verification.ready",
        {
            "from_device": self.client.device_id,
            "methods": ["m.sas.v1"],
        },
        request_event_id,
    )
async def _on_room_verification_event(
    self, room: MatrixRoom, event: UnknownEvent,
) -> None:
    """Drive an in-room SAS verification from ``m.key.verification.*`` events.

    Handles, keyed by the request event ID found in ``m.relates_to``:

    * ``done``: logged only.
    * ``cancel``: logged; the tracked verification is dropped.
    * ``start``: a ``Sas`` object is created for the sending device, the
      verification is accepted and tracked.
    * ``key``: our key is shared, theirs is processed, the SAS is
      auto-accepted and our MAC is sent.
    * ``mac``: their MAC is processed; on success the device is verified
      and ``done`` is sent. The tracked verification is dropped.

    Events sent by the bot itself and non-verification events are ignored.
    """
    if self.client is None:
        return
    if event.sender == self.client.user_id:
        return

    ev_type: str = event.type
    source: dict = event.source
    content: dict = source.get("content", {})
    # transaction_id for in-room verification is the original request event_id
    request_event_id: str | None = content.get("m.relates_to", {}).get("event_id")

    if ev_type == "m.key.verification.done":
        logger.debug(
            "Received in-room m.key.verification.done from %s", event.sender,
        )
        return

    if ev_type == "m.key.verification.cancel":
        logger.debug(
            "In-room SAS canceled by %s (event %s): [%s] %s",
            event.sender, request_event_id,
            content.get("code", ""), content.get("reason", ""),
        )
        if request_event_id:
            self._in_room_sas.pop(request_event_id, None)
        return

    if not ev_type.startswith("m.key.verification."):
        return
    if not request_event_id:
        logger.debug("In-room verification event %s has no m.relates_to", ev_type)
        return

    # ---- start -------------------------------------------------------
    if ev_type == "m.key.verification.start":
        if content.get("method") != "m.sas.v1":
            return
        try:
            sender_devices = self.client.device_store[event.sender]
        except KeyError:
            sender_devices = {}
        from_device: str = content.get("from_device", "")
        olm_device: OlmDevice | None = sender_devices.get(from_device)
        if olm_device is None:
            logger.debug(
                "In-room SAS start from unknown device %s/%s",
                event.sender, from_device,
            )
            return

        own_fingerprint: str = self.client.olm.account.identity_keys["ed25519"]
        envelope = self._to_device_dict(source, request_event_id)
        try:
            typed_start = KeyVerificationStart.from_dict(envelope)
        except Exception:
            logger.exception("Failed to parse in-room KeyVerificationStart")
            return
        try:
            sas = Sas.from_key_verification_start(
                self.client.user_id,
                self.client.device_id,
                own_fingerprint,
                olm_device,
                typed_start,
            )
        except Exception:
            logger.exception("Failed to create Sas for in-room verification")
            return

        # Recompute commitment from the original room event content
        # (without the injected transaction_id) so it matches what the
        # initiator will verify against.
        original_content = source.get("content", {})
        sas.commitment = olm.sha256(
            sas.pubkey + NioApi.to_canonical_json(original_content),
        )
        if sas.canceled:
            logger.debug(
                "In-room SAS start from %s/%s was invalid: %s",
                event.sender, from_device, sas.cancel_reason,
            )
            return

        accept_msg = sas.accept_verification()
        await self._in_room_verif_send(
            room.room_id, accept_msg.type,
            dict(accept_msg.content), request_event_id,
        )
        self._in_room_sas[request_event_id] = (sas, olm_device, room.room_id)
        logger.debug(
            "In-room SAS started with %s/%s (event %s)",
            event.sender, from_device, request_event_id,
        )
        return

    # ---- key / mac ---------------------------------------------------
    tracked = self._in_room_sas.get(request_event_id)
    if tracked is None:
        logger.debug(
            "In-room verification event %s for unknown request %s",
            ev_type, request_event_id,
        )
        return
    sas, olm_device, room_id = tracked

    if ev_type == "m.key.verification.key":
        # As responder, send our key only after receiving theirs.
        key_msg = sas.share_key()
        await self._in_room_verif_send(
            room_id, key_msg.type, dict(key_msg.content), request_event_id,
        )
        envelope = self._to_device_dict(source, request_event_id)
        try:
            typed_key = KeyVerificationKey.from_dict(envelope)
        except Exception:
            logger.exception("Failed to parse in-room KeyVerificationKey")
            return
        sas.receive_key_event(typed_key)
        if sas.canceled:
            logger.debug(
                "In-room SAS canceled after key exchange: %s", sas.cancel_reason,
            )
            self._in_room_sas.pop(request_event_id, None)
            return

        emoji_str = " ".join(
            f"{emoji} {description}" for emoji, description in sas.get_emoji()
        )
        logger.debug(
            "In-room SAS emojis (event %s): %s — auto-accepting",
            request_event_id, emoji_str,
        )
        sas.accept_sas()
        mac_msg = sas.get_mac()
        await self._in_room_verif_send(
            room_id, mac_msg.type, dict(mac_msg.content), request_event_id,
        )
    elif ev_type == "m.key.verification.mac":
        envelope = self._to_device_dict(source, request_event_id)
        try:
            typed_mac = KeyVerificationMac.from_dict(envelope)
        except Exception:
            logger.exception("Failed to parse in-room KeyVerificationMac")
            return
        sas.receive_mac_event(typed_mac)
        if sas.verified:
            self.client.verify_device(olm_device)
            await self._in_room_verif_send(
                room_id, "m.key.verification.done", {}, request_event_id,
            )
            logger.debug(
                "In-room SAS verified device %s/%s (event %s)",
                olm_device.user_id, olm_device.id, request_event_id,
            )
        elif sas.canceled:
            logger.debug(
                "In-room SAS MAC verification failed (event %s): %s",
                request_event_id, sas.cancel_reason,
            )
        # The MAC step ends the exchange; stop tracking this request.
        self._in_room_sas.pop(request_event_id, None)
# -- SAS key-verification handlers ------------------------------------
async def _on_unknown_to_device(
    self, event: UnknownToDeviceEvent,
) -> None:
    """React to to-device events that arrive without a typed wrapper.

    ``m.key.verification.done`` is only logged.  An
    ``m.key.verification.request`` that offers ``m.sas.v1`` and carries both
    a transaction ID and a sending device is answered with an
    ``m.key.verification.ready`` message addressed to that device.  Any
    other event type is ignored.
    """
    kind = event.type
    if kind == "m.key.verification.done":
        logger.debug(
            "Received m.key.verification.done from %s (source: %s)",
            event.sender, event.source.get("content", {}),
        )
        return
    if kind != "m.key.verification.request" or self.client is None:
        return

    payload = event.source.get("content", {})
    offered: list[str] = payload.get("methods", [])
    tx_id: str | None = payload.get("transaction_id")
    requester_device: str | None = payload.get("from_device")
    logger.debug("Verification request content: %s", payload)

    # All three pieces are needed to build a usable "ready" reply.
    if not ("m.sas.v1" in offered and tx_id and requester_device):
        logger.debug(
            "Ignoring verification request from %s: unsupported methods %s",
            event.sender, offered,
        )
        return

    logger.debug(
        "SAS verification request from %s (device %s, tx %s)",
        event.sender, requester_device, tx_id,
    )
    reply = ToDeviceMessage(
        "m.key.verification.ready",
        event.sender,
        requester_device,
        {
            "from_device": self.client.device_id,
            "methods": ["m.sas.v1"],
            "transaction_id": tx_id,
        },
    )
    await self.client.to_device(reply)
async def _on_key_verification_start(
    self, event: KeyVerificationStart,
) -> None:
    """Accept an ``m.sas.v1`` verification start.

    The ``Sas`` object is looked up in the client's ``key_verifications``
    map by transaction ID.  If it exists and has not been canceled, its
    accept message is sent to the other device; otherwise the event is
    only logged.  Non-SAS methods are ignored.
    """
    client = self.client
    if client is None or event.method != "m.sas.v1":
        return

    tx_id = event.transaction_id
    sas = client.olm.key_verifications.get(tx_id)
    if sas is not None and not sas.canceled:
        await client.to_device(sas.accept_verification())
        logger.debug(
            "SAS verification accepted for %s/%s (tx %s)",
            event.sender, event.from_device, tx_id,
        )
        return

    logger.debug(
        "SAS start from %s/%s but no valid internal Sas (tx %s)",
        event.sender, event.from_device, tx_id,
    )
async def _on_key_verification_key(
    self, event: KeyVerificationKey,
) -> None:
    """Finish the key exchange step: log the emojis, accept, send our MAC.

    Does nothing when the client is missing or no ``Sas`` object is known
    for the transaction; a canceled ``Sas`` is logged and abandoned.
    """
    if self.client is None:
        return
    tx_id = event.transaction_id
    sas = self.client.olm.key_verifications.get(tx_id)
    if sas is None:
        return
    if sas.canceled:
        logger.debug(
            "SAS canceled after key exchange (tx %s): %s",
            tx_id, sas.cancel_reason,
        )
        return

    # nio's own handler has already consumed the key event and queued our
    # share_key() reply among the outgoing to-device messages; send it now.
    await self.client.send_to_device_messages()

    rendered = " ".join(
        f"{symbol} {label}" for symbol, label in sas.get_emoji()
    )
    logger.debug(
        "SAS emojis for tx %s: %s — auto-accepting",
        tx_id, rendered,
    )
    sas.accept_sas()
    await self.client.to_device(sas.get_mac())
async def _on_key_verification_mac(
    self, event: KeyVerificationMac,
) -> None:
    """Report the outcome of the MAC step and send ``done`` on success.

    When the ``Sas`` object for the transaction is verified, an
    ``m.key.verification.done`` to-device message is sent to the other
    device and the result is logged.  When it was canceled instead, only
    the cancel reason is logged.
    """
    if self.client is None:
        return
    tx_id = event.transaction_id
    sas = self.client.olm.key_verifications.get(tx_id)
    if sas is None:
        return

    # nio's own handler has already processed the MAC event (and marked the
    # device verified when the MAC was valid); here we only act on the result.
    if not sas.verified:
        if sas.canceled:
            logger.debug(
                "SAS MAC verification failed (tx %s): %s",
                tx_id, sas.cancel_reason,
            )
        return

    peer = sas.other_olm_device
    done_msg = ToDeviceMessage(
        "m.key.verification.done",
        peer.user_id,
        peer.id,
        {"transaction_id": tx_id},
    )
    resp = await self.client.to_device(done_msg)
    logger.debug(
        "Device %s/%s verified via SAS (tx %s), done sent -> %s",
        peer.user_id,
        peer.id,
        tx_id,
        type(resp).__name__,
    )
async def _on_key_verification_cancel(
    self, event: KeyVerificationCancel,
) -> None:
    """Record that the other side canceled a SAS verification.

    Only a debug log line is written, containing the sender, transaction
    ID, cancel code and reason.
    """
    canceled_by, tx_id = event.sender, event.transaction_id
    code, reason = event.code, event.reason
    logger.debug(
        "SAS verification canceled by %s (tx %s): [%s] %s",
        canceled_by, tx_id, code, reason,
    )
async def _on_message(
    self, room: MatrixRoom, event: RoomMessageText,
) -> None:
    """Turn a Matrix text event into an ``IncomingMessage`` and dispatch it.

    Events sent by the bot itself are dropped.  When the configuration
    enables ``resolve_emojis_as_images`` and the event has a
    ``formatted_body``, custom emojis found in it are downloaded as
    attachments and the text is rewritten to match.  Any failure in that
    step is logged at debug level and does not prevent delivery.
    """
    assert self.client is not None
    sender = event.sender
    if sender == self.client.user_id:
        return

    body = event.body
    files: list[Attachment] = []

    # --- Resolve custom emojis as images --------------------------
    cfg = self._config
    if cfg is not None and cfg.resolve_emojis_as_images and body:
        try:
            event_content = (event.source or {}).get("content", {})
            html = event_content.get("formatted_body", "")
            if html:
                # Imported lazily so the resolver is only loaded when used.
                from platforms.emoji_resolver import (
                    extract_matrix_emojis,
                    rewrite_matrix_emoji_text,
                    download_matrix_emojis,
                )
                found = extract_matrix_emojis(html)
                if found:
                    fetched = await download_matrix_emojis(
                        found,
                        self.client,
                        max_emojis=cfg.max_emojis_per_message,
                        media_cache=self._media_cache,
                    )
                    if fetched:
                        considered = found[:cfg.max_emojis_per_message]
                        body = rewrite_matrix_emoji_text(body, considered)
                        files.extend(fetched)
                        logger.info(
                            "Resolved %d/%d Matrix custom emojis as images",
                            len(fetched), len(found),
                        )
        except Exception:
            logger.debug("Matrix emoji resolution failed", exc_info=True)

    display_name = room.user_name(sender) or sender
    addressed = self._is_addressed(room, event)
    sent_at = datetime.fromtimestamp(
        event.server_timestamp / 1000, tz=timezone.utc,
    )
    if room.power_levels:
        is_admin = room.power_levels.get_user_level(sender) >= 50
    else:
        is_admin = False
    extra = {
        "bot_id": self.client.user_id,
        "member_count": room.member_count,
        "is_dm": room.member_count <= 2,
        "is_server_admin": is_admin,
    }

    msg = IncomingMessage(
        platform="matrix",
        channel_id=room.room_id,
        user_id=sender,
        user_name=display_name,
        text=body,
        is_addressed=addressed,
        attachments=files,
        channel_name=room.display_name,
        timestamp=sent_at,
        message_id=event.event_id,
        reply_to_id=_get_reply_to_id(event),
        extra=extra,
    )
    logger.debug(
        "[Matrix/%s] %s: %s",
        room.display_name,
        msg.user_name,
        msg.text[:120],
    )
    await self._message_handler(msg, self)
async def _on_media(
    self, room: MatrixRoom, event: MediaEvent,
) -> None:
    """Turn a Matrix media event into an ``IncomingMessage`` and dispatch it.

    Events sent by the bot itself are dropped.  The media is fetched with
    ``download_matrix_media`` (through the media cache when one is
    configured).  If that fails the error is logged and the message is
    still delivered, just without an attachment.
    """
    assert self.client is not None
    sender = event.sender
    if sender == self.client.user_id:
        return

    addressed = self._is_addressed(room, event)
    caption = event.body or ""
    logger.debug(
        "[Matrix/%s] %s: [media] %s",
        room.display_name,
        room.user_name(sender),
        caption[:120],
    )

    # Download the attachment (using cache when available)
    files: list[Attachment] = []
    mxc_url: str = event.url
    try:
        cache = self._media_cache
        if cache is None:
            fetched = await download_matrix_media(self.client, event)
        else:
            async def _fetch() -> tuple[bytes, str, str]:
                """Download this event's media; handed to the cache as its callback."""
                return await download_matrix_media(self.client, event)  # type: ignore[arg-type]

            fetched = await cache.get_or_download(mxc_url, _fetch)
        data, mimetype, filename = fetched
        files.append(Attachment(
            data=data, mimetype=mimetype, filename=filename,
            source_url=mxc_url,
        ))
    except Exception:
        logger.exception(
            "Failed to download Matrix media from %s", event.sender,
        )

    display_name = room.user_name(sender) or sender
    sent_at = datetime.fromtimestamp(
        event.server_timestamp / 1000, tz=timezone.utc,
    )
    if room.power_levels:
        is_admin = room.power_levels.get_user_level(sender) >= 50
    else:
        is_admin = False
    extra = {
        "bot_id": self.client.user_id,
        "member_count": room.member_count,
        "is_dm": room.member_count <= 2,
        "is_server_admin": is_admin,
    }

    msg = IncomingMessage(
        platform="matrix",
        channel_id=room.room_id,
        user_id=sender,
        user_name=display_name,
        text=caption,
        is_addressed=addressed,
        attachments=files,
        channel_name=room.display_name,
        timestamp=sent_at,
        message_id=event.event_id,
        reply_to_id=_get_reply_to_id(event),
        extra=extra,
    )
    await self._message_handler(msg, self)
async def _on_invite(
    self, room: MatrixRoom, event: InviteMemberEvent,
) -> None:
    """Auto-accept room invites directed at the bot.

    Invite events whose ``state_key`` is not the bot's own user ID are
    ignored.  For the bot's own invite a background task running
    :meth:`_join_room` is started so this callback returns immediately.
    """
    assert self.client is not None
    if event.state_key != self.client.user_id:
        return
    logger.info(
        "Matrix: received invite to %s from %s – joining",
        room.room_id, event.sender,
    )
    join_task = asyncio.ensure_future(self._join_room(room.room_id))

    # The event loop keeps only weak references to tasks, so a task nobody
    # else references may be garbage-collected before it finishes.  Hold a
    # strong reference until the join attempt completes.  The set is created
    # lazily so this method does not depend on __init__ declaring it.
    pending = getattr(self, "_pending_join_tasks", None)
    if pending is None:
        pending = set()
        self._pending_join_tasks = pending
    pending.add(join_task)
    join_task.add_done_callback(pending.discard)
async def _join_room(self, room_id: str) -> None:
    """Try to join *room_id*, retrying up to four times.

    Each attempt is preceded by a pause (1, 2, 4 and 8 seconds) so that an
    invite that has not been fully committed yet has time to settle.  A
    ``JoinError`` result triggers the next attempt; anything else counts as
    success.
    """
    assert self.client is not None
    pauses = (1, 2, 4, 8)  # seconds to wait before each attempt
    for attempt_no, pause in enumerate(pauses, start=1):
        await asyncio.sleep(pause)
        outcome = await self.client.join(room_id)
        if isinstance(outcome, JoinError):
            logger.warning(
                "Join attempt %d for %s failed: %s", attempt_no, room_id, outcome,
            )
            continue
        logger.info("Joined Matrix room %s", room_id)
        return
    logger.error("Giving up joining %s after 4 attempts", room_id)
async def _on_megolm_event(
    self, room: MatrixRoom, event: MegolmEvent,
) -> None:
    """Ask for the missing room key when a Megolm message cannot be decrypted.

    A warning is always logged.  If a client is available, a room-key
    request is made and the pending to-device messages are sent.  Any
    error while doing so is logged at debug level and otherwise ignored.
    """
    session = event.session_id
    logger.warning(
        "Matrix: unable to decrypt message in %s from %s (session %s). "
        "Requesting missing room key.",
        room.room_id, event.sender, session,
    )
    if self.client is None:
        return

    try:
        await self.client.request_room_key(event)
        await self.client.send_to_device_messages()
    except Exception:
        # Best effort: a failed key request must not break event handling.
        logger.debug(
            "Could not send key request for session %s",
            session, exc_info=True,
        )
    else:
        logger.debug(
            "Sent key request for session %s in %s",
            session, room.room_id,
        )
# -- Internal: helpers ---------------------------------------------
def _is_addressed(
    self, room: MatrixRoom, event: object,
) -> bool:
    """Return ``True`` when the bot should respond.

    The bot responds if any of the following hold:

    * The room has two or fewer members (DM / small room).
    * The bot's user ID appears in ``m.mentions.user_ids``.
    * The bot's user ID appears in the plain-text body.
    * The message is a reply to an event the bot sent.

    Event content is supplied by remote users, so every nested field is
    type-checked before use; a malformed field is treated as absent rather
    than raising inside the event callback.
    """
    assert self.client is not None
    # DM / small room -- always respond
    if room.member_count <= 2:
        return True

    bot_id = self.client.user_id
    source = getattr(event, "source", None)
    content = source.get("content") if isinstance(source, dict) else None
    if not isinstance(content, dict):
        content = {}

    # Modern m.mentions spec.  Only a real list is accepted: a plain string
    # would otherwise be matched by substring.
    mentions = content.get("m.mentions")
    if isinstance(mentions, dict):
        mentioned = mentions.get("user_ids")
        if isinstance(mentioned, (list, tuple)) and bot_id in mentioned:
            return True

    # Fallback: user ID appears in body text.  An empty user ID is skipped
    # because the empty string is contained in every body.
    body = getattr(event, "body", "")
    if bot_id and isinstance(body, str) and bot_id in body:
        return True

    # Reply to one of the bot's own messages
    relates = content.get("m.relates_to")
    reply_to = relates.get("m.in_reply_to") if isinstance(relates, dict) else None
    reply_event_id = (
        reply_to.get("event_id") if isinstance(reply_to, dict) else None
    )
    if isinstance(reply_event_id, str) and reply_event_id:
        return reply_event_id in self._sent_events.get(room.room_id, ())
    return False
async def _sync_loop(self) -> None:
    """Run the Matrix sync loop until stopped.

    Each iteration performs one ``sync`` call (``timeout=30000``), calls
    ``trust_all_devices``, performs whatever key upload / query / claim the
    client reports as needed, and sends pending to-device messages.

    Transient network errors are retried with exponential backoff starting
    at 5 seconds and capped at 60 seconds; the backoff resets after a
    successful sync.  Any other exception is logged and ends the loop.
    Cancellation also ends the loop and is not propagated to the caller.
    """
    backoff = 5.0
    max_backoff = 60.0
    # asyncio.TimeoutError only became an alias of the builtin TimeoutError
    # in Python 3.11.  On older interpreters it is a separate class, so it
    # must be listed explicitly or a request timeout would end the loop
    # through the generic handler below.
    transient_errors = (
        TimeoutError,
        asyncio.TimeoutError,
        aiohttp.ClientError,
        ConnectionError,
        OSError,
    )
    try:
        while not self._stop_event.is_set():
            try:
                await self.client.sync(  # type: ignore[union-attr]
                    timeout=30000,
                )
                backoff = 5.0  # reset on success
                trust_all_devices(
                    self.client,  # type: ignore[arg-type]
                )
                if self.client.should_upload_keys:  # type: ignore[union-attr]
                    await self.client.keys_upload()  # type: ignore[union-attr]
                if self.client.should_query_keys:  # type: ignore[union-attr]
                    await self.client.keys_query()  # type: ignore[union-attr]
                if self.client.should_claim_keys:  # type: ignore[union-attr]
                    await self.client.keys_claim(  # type: ignore[union-attr]
                        self.client.get_users_for_key_claiming(),  # type: ignore[union-attr]
                    )
                await self.client.send_to_device_messages()  # type: ignore[union-attr]
            except asyncio.CancelledError:
                # Handled first so cancellation is never mistaken for a
                # retryable or fatal error; the outer handler absorbs it.
                raise
            except transient_errors as e:
                logger.warning(
                    "Matrix sync transient error, retrying in %.0fs: %s",
                    backoff,
                    e,
                )
                await asyncio.sleep(backoff)
                backoff = min(backoff * 2, max_backoff)
                continue
            except Exception:
                logger.exception("Matrix sync loop encountered an error")
                break
    except asyncio.CancelledError:
        pass
    finally:
        logger.info("Matrix sync loop exited")