# Source code for core.outbound_consumer

"""Outbound stream consumer for gateway nodes."""

import asyncio
import logging
import base64
from typing import Any

from redis.asyncio import Redis

from core.event_bus import OUTBOUND_STREAM_PREFIX, GATEWAY_GROUP_PREFIX
from core.serialization import deserialize_stream_payload

logger = logging.getLogger("stargazer.outbound_consumer")


[docs] class OutboundStreamConsumer: """Consumes outbound response envelopes and dispatches to local adapters. One instance per platform per gateway node. """
[docs] def __init__( self, redis: Redis, platform: str, adapter: Any, # PlatformAdapter consumer_name: str, autoclaim_interval: float = 30.0, autoclaim_min_idle: int = 60000, tools_dir: str | None = None, ) -> None: """Initialize a per-platform outbound consumer and derive its stream/group names. Stores the injected collaborators and computes the outbound stream name ``{OUTBOUND_STREAM_PREFIX}:{platform}`` and consumer-group name ``{GATEWAY_GROUP_PREFIX}:{platform}`` that the consume and autoclaim loops read from. Background-task handles and the active-dispatch task set are left empty here; nothing is scheduled until :meth:`start` is called. The ``ToolRegistry`` is left unbuilt (``None``) and only lazily compiled when an ``execute_tool`` RPC first arrives. Constructed once per platform per gateway node, primarily by :class:`gateway_main.GatewayRunner` in its ``on_start`` adapter loop (and directly by the test suite under ``tests/core``). Args: redis: Async Redis client used for stream reads, ACKs, locks, the idempotency/claim keys, and RPC reply queues. platform: Platform name (e.g. ``"discord"``); used to build the stream and group names and as a log dimension. adapter: The concrete ``PlatformAdapter`` whose send/typing/reaction/ presence and RPC methods are invoked to actually reach the platform. consumer_name: Unique consumer name within the group (typically the gateway instance id) used for ``xreadgroup``/``xautoclaim``. autoclaim_interval: Seconds to sleep between autoclaim sweeps. autoclaim_min_idle: Minimum idle time in milliseconds before a pending entry is eligible to be reclaimed from another consumer. tools_dir: Directory to load delegated tools from for ``execute_tool`` RPCs; defaults to ``"tools"`` when not provided. 
""" self._redis = redis self._platform = platform self._adapter = adapter self._consumer_name = consumer_name self._autoclaim_interval = autoclaim_interval self._autoclaim_min_idle = autoclaim_min_idle self._stream = f"{OUTBOUND_STREAM_PREFIX}:{platform}" self._group = f"{GATEWAY_GROUP_PREFIX}:{platform}" self._tools_dir = tools_dir or "tools" self._tool_registry = None self._running = False self._task: asyncio.Task | None = None self._autoclaim_task: asyncio.Task | None = None self._active_tasks: set[asyncio.Task] = set()
[docs] async def start(self) -> None: """Launch the background consume and autoclaim loops for this platform. Flips ``_running`` true and schedules :meth:`_consume_loop` (live stream reads) and :meth:`_autoclaim_loop` (recovery of stranded entries) as named asyncio tasks, then emits a startup log line. Returns immediately; the loops run until :meth:`stop` cancels them. Called by :class:`gateway_main.GatewayRunner` in its ``on_start`` adapter setup loop, deliberately before each ``adapter.start()`` so the response path exists before the first inbound message can arrive (also exercised directly in the test suite). """ self._running = True self._task = asyncio.create_task( self._consume_loop(), name=f"outbound_{self._platform}" ) self._autoclaim_task = asyncio.create_task( self._autoclaim_loop(), name=f"outbound_autoclaim_{self._platform}" ) logger.info( "OutboundStreamConsumer started", extra={"platform": self._platform, "stream": self._stream}, )
[docs] async def stop(self) -> None: """Stop the consumer, cancelling its loops and draining in-flight dispatches. Clears ``_running`` so the loops exit, cancels the consume and autoclaim tasks, awaits the (cancelled) consume task to absorb its ``CancelledError``, and finally gathers any still-running per-message dispatch tasks tracked in ``_active_tasks`` so outbound sends that are mid-flight are allowed to finish before shutdown completes. Called by :class:`gateway_main.GatewayRunner` during graceful teardown, which iterates ``self.outbound_consumers`` and awaits ``consumer.stop()`` for each (also covered by ``tests/core/test_outbound_consumer.py``). """ self._running = False if self._task: self._task.cancel() if self._autoclaim_task: self._autoclaim_task.cancel() try: await self._task except asyncio.CancelledError: pass if self._active_tasks: logger.info("Awaiting %d active outbound dispatch tasks...", len(self._active_tasks)) await asyncio.gather(*self._active_tasks, return_exceptions=True)
async def _consume_loop(self) -> None: """Continuously read new outbound entries and fan them out to dispatch tasks. The hot path of the consumer: blocks on ``xreadgroup`` over ``_stream`` for this consumer group (up to 10 entries, 5s block), and for every delivered entry spawns a :meth:`_handle_dispatch_task` task tracked in ``_active_tasks`` (with a done-callback that discards it) so independent channels can be dispatched concurrently while each channel is serialized by a Redis lock downstream. Honours cancellation by breaking out, and logs and backs off two seconds on any other unexpected error so the loop self-heals. Started as a background task by :meth:`start`; not called directly except by the consume-loop unit tests. """ while self._running: try: messages = await self._redis.xreadgroup( self._group, self._consumer_name, {self._stream: ">"}, count=10, block=5000, ) if not messages: continue for _stream_name, entries in messages: for msg_id, raw in entries: msg_id_str = msg_id.decode() if isinstance(msg_id, bytes) else str(msg_id) task = asyncio.create_task( self._handle_dispatch_task(msg_id_str, raw), name=f"outbound_msg_{msg_id_str}", ) self._active_tasks.add(task) task.add_done_callback(self._active_tasks.discard) except asyncio.CancelledError: break except Exception: logger.critical("Outbound consumer loop crashed", exc_info=True) await asyncio.sleep(2) async def _autoclaim_loop(self) -> None: """Recover outbound stream entries stranded in the consumer-group PEL. Safety net for the consume path: if a consumer crashes after reading but before acking, its delivered-but-unprocessed entries linger in the pending-entries list and would otherwise never be sent. Every ``_autoclaim_interval`` seconds this sweeps ``xautoclaim`` over ``_stream``/``_group`` for entries idle past ``_autoclaim_min_idle`` and re-drives each one through a fresh :meth:`_handle_dispatch_task` (tracked in ``_active_tasks``), so a stalled message resumes on a healthy consumer. 
To bound poison-message replay it ``hincrby``\\ s a per-message attempt count in the ``sg:autoclaim:outbound:attempts`` Redis hash; after three attempts the entry is routed to the dead-letter queue via ``core.dlq.handle_failed_message`` and its counter ``hdel``\\ ed instead of being retried. Honours cancellation by breaking out and logs (without backing off) on any other error so the sweep keeps running. Started as a background task by :meth:`start`; invoked directly only by the autoclaim/PEL-leak unit tests (e.g. ``tests/adversarial/test_pel_leak_protection.py``, ``tests/core/test_stream_idempotency.py``). """ from core.dlq import handle_failed_message while self._running: try: await asyncio.sleep(self._autoclaim_interval) result = await self._redis.xautoclaim( self._stream, self._group, self._consumer_name, min_idle_time=self._autoclaim_min_idle, start_id="0-0", count=10, ) if not result or len(result) < 2: continue claimed = result[1] if not claimed: continue logger.warning( "Autoclaimed %d orphaned outbound messages for %s", len(claimed), self._platform, ) for msg_id, raw in claimed: msg_id_str = ( msg_id.decode() if isinstance(msg_id, bytes) else str(msg_id) ) attempts = await self._redis.hincrby( "sg:autoclaim:outbound:attempts", msg_id_str, 1, ) if attempts > 3: await handle_failed_message( self._redis, self._stream, self._group, msg_id_str, raw, ValueError(f"Max outbound claim attempts exceeded: {attempts}"), attempts, ) await self._redis.hdel( "sg:autoclaim:outbound:attempts", msg_id_str, ) continue task = asyncio.create_task( self._handle_dispatch_task(msg_id_str, raw), name=f"outbound_autoclaim_{msg_id_str}", ) self._active_tasks.add(task) task.add_done_callback(self._active_tasks.discard) except asyncio.CancelledError: break except Exception: logger.error("Outbound autoclaim loop error", exc_info=True) async def _handle_rpc_request(self, payload: dict[str, Any]) -> None: """Execute a delegated adapter RPC on the gateway and return the result. 
Serves the gateway side of the cross-service RPC bridge: the inference worker's :class:`core.proxy_adapter.ProxyPlatformAdapter` cannot touch the live platform connection, so it publishes ``rpc_request`` envelopes onto the outbound stream and blocks on a Redis list reply key. This method runs the requested method against the real local ``self._adapter`` and pushes the JSON-encoded result back. Dispatches on ``rpc_method``: ``fetch_channel_history`` calls ``adapter.fetch_history`` and serializes each message (via ``to_dict`` or a hand-built dict); ``should_skip_channel_heartbeat`` calls the like-named adapter method; and ``execute_tool`` reconstructs a :class:`tool_context.ToolContext` (wiring in this adapter and Redis), lazily compiles a :class:`tools.ToolRegistry` by ``load_tools`` from ``_tools_dir`` (off-thread) on first use, and runs the named tool. Any exception is caught and turned into an ``{"error": ...}`` result rather than propagated. The result is delivered by ``rpush`` to ``reply_key`` with a 60-second ``expire`` so the proxy's ``blpop`` unblocks; a missing ``reply_key`` is logged and dropped. Called by :meth:`_handle_dispatch_task` when a deserialized payload has ``type == "rpc_request"`` (also covered by ``tests/core/migration/test_gateway_tool_execution.py``). Args: payload: Deserialized RPC envelope carrying ``rpc_method``/``action``, ``rpc_args``/``kwargs``, ``reply_key`` and ``channel_id``. 
""" rpc_method = payload.get("rpc_method") or payload.get("action") rpc_args = payload.get("rpc_args") or payload.get("kwargs") or {} reply_key = payload.get("reply_key") channel_id = payload.get("channel_id") if not reply_key: logger.error("RPC request missing reply_key") return logger.info( "Handling RPC request %s on platform %s for channel %s", rpc_method, self._platform, channel_id, ) import json result_value = None try: if rpc_method == "fetch_channel_history": limit = rpc_args.get("limit", 10) history = await self._adapter.fetch_history(channel_id, limit=limit) result_value = [] for msg in history: if hasattr(msg, "to_dict"): result_value.append(msg.to_dict()) else: result_value.append({ "user_id": msg.user_id, "user_name": msg.user_name, "text": msg.text, "timestamp": msg.timestamp.timestamp() if hasattr(msg.timestamp, "timestamp") else msg.timestamp, "message_id": msg.message_id, "is_bot": msg.is_bot, "reply_to_id": msg.reply_to_id, "reactions": msg.reactions, }) elif rpc_method == "should_skip_channel_heartbeat": result_value = await self._adapter.should_skip_channel_heartbeat(channel_id) elif rpc_method == "delete_egregore_webhook": # Worker → gateway: delete an egregore webhook on the live client # (the worker's ProxyPlatformAdapter has no .client). Used by the # !snes_end / !snes_blow cleanup paths. 
from tools._egregore_discord import delete_egregore_webhook_by_id client = ( getattr(self._adapter, "_client", None) or getattr(self._adapter, "client", None) ) ok, err = await delete_egregore_webhook_by_id( client, rpc_args.get("guild_id"), rpc_args.get("webhook_id"), ) result_value = {"ok": ok, "error": err} elif rpc_method == "execute_tool": tool_name = rpc_args.get("tool_name") tool_args = rpc_args.get("tool_args") or {} tool_ctx_dict = rpc_args.get("tool_ctx") or {} # Reconstruct ToolContext from tool_context import ToolContext t_ctx = ToolContext( platform=tool_ctx_dict.get("platform") or self._platform, channel_id=tool_ctx_dict.get("channel_id") or channel_id or "", user_id=tool_ctx_dict.get("user_id") or "", user_name=tool_ctx_dict.get("user_name") or "", guild_id=tool_ctx_dict.get("guild_id") or "", message_id=tool_ctx_dict.get("message_id") or "", adapter=self._adapter, redis=self._redis, config=getattr(self._adapter, "config", None), ) # Lazy compile ToolRegistry if self._tool_registry is None: from tools import ToolRegistry from tool_loader import load_tools self._tool_registry = ToolRegistry() tools_dir = self._tools_dir config_obj = getattr(self._adapter, "config", None) if config_obj and hasattr(config_obj, "tools_dir"): tools_dir = config_obj.tools_dir logger.info("OutboundStreamConsumer lazy-loading tools from %s", tools_dir) await asyncio.to_thread(load_tools, tools_dir, self._tool_registry) logger.info("OutboundStreamConsumer executing delegated tool %s", tool_name) result_value = await self._tool_registry.call( name=tool_name, arguments=tool_args, user_id=t_ctx.user_id, ctx=t_ctx, ) else: logger.error("Unknown RPC method: %s", rpc_method) result_value = {"error": f"Unknown RPC method: {rpc_method}"} except Exception as e: logger.exception("Error processing RPC request %s", rpc_method) result_value = {"error": str(e)} await self._redis.rpush(reply_key, json.dumps(result_value)) await self._redis.expire(reply_key, 60) async def 
_handle_dispatch_task(self, msg_id: str, raw: dict) -> None: """Process one outbound stream entry, serialized per channel and deduplicated. The per-message unit of work spawned by the consume and autoclaim loops. It deserializes the entry via :func:`core.serialization.deserialize_stream_payload`; an ``rpc_request`` envelope is short-circuited to :meth:`_handle_rpc_request` and acked. Otherwise it acquires a :class:`core.distributed_lock.DistributedLock` on ``sg:lock:outbound:{platform}:{channel_id}`` so concurrent tasks for the same channel send in order while different channels proceed in parallel, then runs :meth:`_dispatch` and :meth:`_ack_outbound`. Two Redis-backed idempotency guards prevent double-sends across redelivery or lock failover: a completion-marker ``hget`` of ``message_id`` on the payload's ``message_key`` skips entries already delivered, and an atomic ``set`` with ``nx`` on ``sg:outbound:claim:{message_key}`` (300s TTL) claims the send before it happens, releasing the claim with ``delete`` if :meth:`_dispatch` raises so a legitimate retry can proceed. The lock is always released in a ``finally``. Dispatch errors are logged and swallowed so one bad message cannot wedge the loop. Scheduled as a task by :meth:`_consume_loop` and :meth:`_autoclaim_loop`; also driven directly by the concurrency and idempotency tests (e.g. ``tests/core/test_concurrent_consumers.py``, ``tests/core/test_stream_idempotency.py``). Args: msg_id: Redis stream entry id, used for acking and claim bookkeeping. raw: The raw stream field map to deserialize into an outbound payload. 
""" try: payload = deserialize_stream_payload(raw) # Intercept RPC requests before lock acquisition if payload.get("type") == "rpc_request": await self._handle_rpc_request(payload) await self._ack_outbound(msg_id) return channel_id = payload.get("channel_id") or "global" except Exception: # If payload cannot be deserialized, call direct dispatch to trigger standard error handling/ACKs try: payload = deserialize_stream_payload(raw) await self._dispatch(payload) await self._ack_outbound(msg_id) except Exception: logger.error( "Failed to dispatch outbound message", exc_info=True, extra={"stream_msg_id": msg_id}, ) return from core.distributed_lock import DistributedLock lock = DistributedLock( self._redis, f"sg:lock:outbound:{self._platform}:{channel_id}", ttl=30, auto_renew=True, ) try: await lock.acquire_blocking() # Idempotency check (completion marker): # If the platform sent ID already exists in the Redis message cache under message_key, # this message has already been successfully dispatched to Discord/Matrix. message_key = payload.get("message_key") if message_key: sent_id = await self._redis.hget(message_key, "message_id") if sent_id: sent_id_str = sent_id.decode() if isinstance(sent_id, bytes) else str(sent_id) if sent_id_str.strip(): logger.info( "Outbound message already dispatched (idempotency check hit)", extra={ "stream_msg_id": msg_id, "message_key": message_key, "sent_id": sent_id_str, }, ) await self._ack_outbound(msg_id) return # Atomic pre-dispatch claim: closes the check-then-act window so a # duplicate send cannot happen if the per-channel lock ever fails to # provide mutual exclusion (e.g. a Redis failover split-brain) or the # stream entry is redelivered concurrently. The completion marker # above still guards against re-sending after a successful dispatch. 
claim_key = None if message_key: claim_key = f"sg:outbound:claim:{message_key}" claimed = await self._redis.set(claim_key, msg_id, nx=True, ex=300) if not claimed: logger.info( "Outbound message already claimed; skipping duplicate dispatch", extra={ "stream_msg_id": msg_id, "message_key": message_key, }, ) await self._ack_outbound(msg_id) return try: await self._dispatch(payload) await self._ack_outbound(msg_id) except Exception: # A propagated exception means the send did not complete (the # post-send marker write is swallowed inside _dispatch), so it # is safe to release the claim for a legitimate retry. if claim_key: try: await self._redis.delete(claim_key) except Exception: pass raise except Exception: logger.error( "Failed to dispatch outbound message", exc_info=True, extra={"stream_msg_id": msg_id}, ) finally: await lock.release() async def _ack_outbound(self, msg_id: str) -> None: """Acknowledge an outbound stream entry and clear its autoclaim counter. ``xack``\\ s ``msg_id`` against ``_stream``/``_group`` so it leaves the pending entries list, then best-effort ``hdel``\\ s its attempt count from the ``sg:autoclaim:outbound:attempts`` hash (swallowing errors) so a successfully handled message is not later treated as a retry by the autoclaim path. Called by :meth:`_handle_dispatch_task` after every terminal outcome (successful dispatch, idempotency/claim skip, or completed RPC) and never on its own. Args: msg_id: Stream entry id to acknowledge. """ await self._redis.xack(self._stream, self._group, msg_id) try: await self._redis.hdel("sg:autoclaim:outbound:attempts", msg_id) except Exception: pass async def _dispatch(self, payload: dict[str, Any]) -> None: """Translate an outbound envelope into the matching live-adapter call. The bridge between a serialized envelope and the real platform client. 
Branching on the payload ``type`` it invokes the corresponding method on ``self._adapter``: ``message`` to ``send``, ``file`` (base64-decoding ``file_data`` first) to ``send_file``, ``buttons`` to ``send_with_buttons``, ``typing``/``start_typing``/``stop_typing`` to the typing methods (feature- detected with ``hasattr``), and ``reaction`` to ``add_reaction``. Presence is handled up front before the ``channel_id`` guard because it is a global, channel-less action; a payload missing ``channel_id`` for any other type is logged and dropped, and an unknown type is logged as a warning. On a successful send that returns a platform message id, and when the payload carries a ``message_key``, it writes that id back to the Redis message cache via ``hset`` (self-healing the key with a 90-day ``expire`` if it did not already exist) so the idempotency completion marker in :meth:`_handle_dispatch_task` and later lookups can find it; a failure to re-sync is logged but not raised. Called by :meth:`_handle_dispatch_task` (both the normal and deserialize-fallback paths) and directly by the outbound/dispatch unit tests (e.g. ``tests/core/test_outbound_consumer.py``, ``tests/core/test_message_id_resync.py``). Args: payload: The deserialized outbound envelope; its ``type`` selects the adapter method and the remaining fields supply that method's args. """ msg_type = payload.get("type", "message") channel_id = payload.get("channel_id") # Presence is a global (non-channel) action — handle it before the # channel_id guard, which would otherwise drop it. 
if msg_type == "presence": logger.info( "Dispatching outbound presence", extra={"platform": self._platform}, ) if hasattr(self._adapter, "set_presence"): await self._adapter.set_presence( payload.get("text", ""), payload.get("emoji"), ) else: logger.warning( "Adapter %s has no set_presence; dropping presence update", self._platform, ) return if not channel_id: logger.error("Outbound payload missing channel_id") return logger.info( "Dispatching outbound %s", msg_type, extra={ "channel_id": channel_id, "platform": self._platform, "type": msg_type, }, ) sent_id = None if msg_type == "message": # 💀🔥 Attach persistent ✧ toggle button to every message if enabled _toggle_view = None if self._platform in ("discord", "discord-self"): try: import feature_toggles channel_key = f"{self._platform}:{channel_id}" # Resolve guild_id from Discord channel cache # 💀🔥 _guild_id = None try: _client = getattr(self._adapter, "_client", None) if _client is not None: _ch = _client.get_channel(int(channel_id)) if _ch is not None and hasattr(_ch, "guild") and _ch.guild: _guild_id = str(_ch.guild.id) except Exception: pass disabled = await feature_toggles.is_disabled_resolving_discord_aliases( self._redis, "toggle_menu", channel_key, guild_id=_guild_id, ) if not disabled: from star_toggle_ui import build_star_toggle_view _toggle_view = build_star_toggle_view(channel_id) except Exception: logger.debug("star_toggle_ui or toggle checks failed", exc_info=True) if _toggle_view is not None: sent_id = await self._adapter.send_with_buttons( channel_id, payload["text"], _toggle_view, ) else: sent_id = await self._adapter.send(channel_id, payload["text"]) elif msg_type == "file": file_data = payload["file_data"] decoded_data = base64.b64decode(file_data) if isinstance(file_data, str) else file_data sent_id = await self._adapter.send_file( channel_id, decoded_data, payload.get("filename", "file"), payload.get("mimetype", "application/octet-stream"), ) elif msg_type == "buttons": sent_id = await 
self._adapter.send_with_buttons( channel_id, payload["text"], payload["buttons"], ) elif msg_type in ("typing", "start_typing"): if hasattr(self._adapter, "start_typing"): await self._adapter.start_typing(channel_id) elif hasattr(self._adapter, "send_typing"): await self._adapter.send_typing(channel_id) elif msg_type == "stop_typing": if hasattr(self._adapter, "stop_typing"): await self._adapter.stop_typing(channel_id) elif msg_type == "reaction": if hasattr(self._adapter, "add_reaction"): await self._adapter.add_reaction( channel_id, payload["message_id"], payload["emoji"], ) else: logger.warning( "Adapter %s has no add_reaction; dropping reaction envelope", getattr(self._adapter, "name", "?"), ) else: logger.warning("Unknown outbound message type: %s", msg_type) message_key = payload.get("message_key") if sent_id and message_key: try: # Check if the cache entry exists first (merging behavior) exists = await self._redis.exists(message_key) await self._redis.hset(message_key, "message_id", str(sent_id)) if not exists: # Self-heal and apply a 90-day TTL (same as standard messages) # to prevent any orphaned keys if this was written before log_message. await self._redis.expire(message_key, 90 * 86400) logger.info( "Re-synced platform message_id to Redis message cache", extra={ "message_key": message_key, "message_id": sent_id, "platform": self._platform, "cache_existed": bool(exists), }, ) except Exception: logger.warning( "Failed to re-sync platform message_id to Redis cache", exc_info=True, extra={ "message_key": message_key, "message_id": sent_id, }, )