"""Outbound stream consumer for gateway nodes."""
import asyncio
import logging
import base64
from typing import Any
from redis.asyncio import Redis
from core.event_bus import OUTBOUND_STREAM_PREFIX, GATEWAY_GROUP_PREFIX
from core.serialization import deserialize_stream_payload
logger = logging.getLogger("stargazer.outbound_consumer")
# [docs]  (appears to be a Sphinx source-link marker left by extraction; not program code)
class OutboundStreamConsumer:
    """Consumes outbound response envelopes and dispatches to local adapters.

    One instance per platform per gateway node.

    The instance reads the per-platform outbound Redis stream as a member of
    the platform's consumer group, runs each entry in its own asyncio task,
    and acknowledges the entry once it has been handled. A second background
    loop periodically re-claims entries that were delivered but never
    acknowledged.
    """
# [docs]  (appears to be a Sphinx source-link marker left by extraction; not program code)
def __init__(
    self,
    redis: Redis,
    platform: str,
    adapter: Any,  # PlatformAdapter
    consumer_name: str,
    autoclaim_interval: float = 30.0,
    autoclaim_min_idle: int = 60000,
    tools_dir: str | None = None,
) -> None:
    """Bind collaborators and derive the per-platform stream and group names.

    This only stores state: no Redis command is issued and no task is
    scheduled until :meth:`start` is awaited. The tool registry starts out as
    ``None`` and is built on demand by the ``execute_tool`` RPC branch.

    Args:
        redis: Async Redis client used for stream reads, acknowledgements,
            claim/idempotency keys and RPC reply lists.
        platform: Platform name; appended to ``OUTBOUND_STREAM_PREFIX`` and
            ``GATEWAY_GROUP_PREFIX`` to form the stream and group names, and
            used as a logging dimension.
        adapter: The local platform adapter whose send/typing/reaction/
            presence and RPC methods are invoked.
        consumer_name: Consumer name passed to ``xreadgroup`` and
            ``xautoclaim``.
        autoclaim_interval: Seconds slept before each autoclaim sweep.
        autoclaim_min_idle: Value passed as ``min_idle_time`` to
            ``xautoclaim`` (milliseconds).
        tools_dir: Directory that delegated tools are loaded from; a falsy
            value falls back to ``"tools"``.
    """
    # Stream / consumer-group identity, both suffixed with the platform name.
    self._platform = platform
    self._stream = f"{OUTBOUND_STREAM_PREFIX}:{platform}"
    self._group = f"{GATEWAY_GROUP_PREFIX}:{platform}"
    self._consumer_name = consumer_name

    # Injected collaborators.
    self._redis = redis
    self._adapter = adapter

    # Autoclaim tuning knobs.
    self._autoclaim_interval = autoclaim_interval
    self._autoclaim_min_idle = autoclaim_min_idle

    # Delegated-tool support; the registry is compiled lazily.
    self._tools_dir = tools_dir if tools_dir else "tools"
    self._tool_registry = None

    # Lifecycle state: nothing is running until start() is awaited.
    self._running = False
    self._task: asyncio.Task | None = None
    self._autoclaim_task: asyncio.Task | None = None
    self._active_tasks: set[asyncio.Task] = set()
# [docs]  (appears to be a Sphinx source-link marker left by extraction; not program code)
async def start(self) -> None:
    """Launch the background consume and autoclaim loops for this platform.

    Sets ``_running`` and schedules :meth:`_consume_loop` and
    :meth:`_autoclaim_loop` as named asyncio tasks, then logs a startup
    line. Returns immediately; the loops keep running until :meth:`stop`
    cancels them.

    Calling ``start`` while the consume task is still alive is a no-op (a
    warning is logged). Without that guard the task handles would be
    overwritten and the first pair of loops could no longer be cancelled by
    :meth:`stop`. Starting again after :meth:`stop` works as before.
    """
    # Guard against a duplicate start: the handles below are the only
    # references stop() has for cancelling the loops.
    if self._task is not None and not self._task.done():
        logger.warning(
            "OutboundStreamConsumer already started; ignoring duplicate start",
            extra={"platform": self._platform, "stream": self._stream},
        )
        return

    self._running = True
    self._task = asyncio.create_task(
        self._consume_loop(), name=f"outbound_{self._platform}"
    )
    self._autoclaim_task = asyncio.create_task(
        self._autoclaim_loop(), name=f"outbound_autoclaim_{self._platform}"
    )
    logger.info(
        "OutboundStreamConsumer started",
        extra={"platform": self._platform, "stream": self._stream},
    )
# [docs]  (appears to be a Sphinx source-link marker left by extraction; not program code)
async def stop(self) -> None:
    """Stop the consumer, cancelling its loops and draining in-flight dispatches.

    Clears ``_running``, cancels whichever of the consume and autoclaim
    tasks exist, and waits for both to finish. Then it gathers any dispatch
    tasks still tracked in ``_active_tasks`` so sends that are mid-flight can
    complete before shutdown returns.

    Safe to call when :meth:`start` was never awaited: with no loop tasks
    there is nothing to cancel or await.
    """
    self._running = False

    # Only tasks that were actually created; both are None before start().
    loop_tasks = [
        task for task in (self._task, self._autoclaim_task) if task is not None
    ]
    for task in loop_tasks:
        task.cancel()
    if loop_tasks:
        # return_exceptions=True turns each loop's CancelledError (or any
        # other terminal exception) into a result instead of raising it
        # here, while cancellation of stop() itself still propagates.
        await asyncio.gather(*loop_tasks, return_exceptions=True)

    if self._active_tasks:
        logger.info("Awaiting %d active outbound dispatch tasks...", len(self._active_tasks))
        await asyncio.gather(*self._active_tasks, return_exceptions=True)
async def _consume_loop(self) -> None:
    """Read new outbound entries and fan each one out to its own task.

    While ``_running`` is set, blocks on ``xreadgroup`` for never-delivered
    entries (``">"``) of ``_stream`` in ``_group``, asking for at most 10
    entries and blocking up to 5000 ms. Every delivered entry is handed to
    :meth:`_handle_dispatch_task` in a new task that is tracked in
    ``_active_tasks`` and removed from the set when it completes.

    Cancellation ends the loop. Any other exception is logged at critical
    level and followed by a two-second pause before the next read.
    """
    while self._running:
        try:
            batch = await self._redis.xreadgroup(
                self._group,
                self._consumer_name,
                {self._stream: ">"},
                count=10,
                block=5000,
            )
            # A falsy reply (block timeout) simply yields no iterations.
            for _, stream_entries in batch or ():
                for entry_id, fields in stream_entries:
                    # Ids arrive as bytes unless the client decodes responses.
                    if isinstance(entry_id, bytes):
                        entry_key = entry_id.decode()
                    else:
                        entry_key = str(entry_id)
                    worker = asyncio.create_task(
                        self._handle_dispatch_task(entry_key, fields),
                        name=f"outbound_msg_{entry_key}",
                    )
                    # Hold a reference until done so stop() can drain it.
                    self._active_tasks.add(worker)
                    worker.add_done_callback(self._active_tasks.discard)
        except asyncio.CancelledError:
            break
        except Exception:
            logger.critical("Outbound consumer loop crashed", exc_info=True)
            await asyncio.sleep(2)
async def _autoclaim_loop(self) -> None:
    """Re-claim and re-drive stream entries left pending in the consumer group.

    While ``_running`` is set, each iteration sleeps ``_autoclaim_interval``
    seconds and then calls ``xautoclaim`` on ``_stream``/``_group`` from
    ``"0-0"`` for up to 10 entries idle for at least ``_autoclaim_min_idle``.

    For every claimed entry a per-message counter in the Redis hash
    ``sg:autoclaim:outbound:attempts`` is incremented. Up to three attempts
    the entry is re-dispatched through :meth:`_handle_dispatch_task` in a
    tracked task. Beyond that it is passed to
    ``core.dlq.handle_failed_message`` and its counter is deleted.

    Cancellation ends the loop. Any other exception is logged and the loop
    carries on; the sleep at the top of each iteration paces retries.
    """
    from core.dlq import handle_failed_message

    attempts_hash = "sg:autoclaim:outbound:attempts"

    while self._running:
        try:
            await asyncio.sleep(self._autoclaim_interval)
            reply = await self._redis.xautoclaim(
                self._stream,
                self._group,
                self._consumer_name,
                min_idle_time=self._autoclaim_min_idle,
                start_id="0-0",
                count=10,
            )
            # The claimed entries are the second element of the reply.
            if not reply or len(reply) < 2:
                continue
            orphans = reply[1]
            if not orphans:
                continue

            logger.warning(
                "Autoclaimed %d orphaned outbound messages for %s",
                len(orphans),
                self._platform,
            )
            for entry_id, fields in orphans:
                if isinstance(entry_id, bytes):
                    entry_key = entry_id.decode()
                else:
                    entry_key = str(entry_id)

                tries = await self._redis.hincrby(attempts_hash, entry_key, 1)
                if tries > 3:
                    # Too many redeliveries: hand off to the dead-letter
                    # handler and forget the counter instead of retrying.
                    await handle_failed_message(
                        self._redis,
                        self._stream,
                        self._group,
                        entry_key,
                        fields,
                        ValueError(f"Max outbound claim attempts exceeded: {tries}"),
                        tries,
                    )
                    await self._redis.hdel(attempts_hash, entry_key)
                else:
                    retry = asyncio.create_task(
                        self._handle_dispatch_task(entry_key, fields),
                        name=f"outbound_autoclaim_{entry_key}",
                    )
                    self._active_tasks.add(retry)
                    retry.add_done_callback(self._active_tasks.discard)
        except asyncio.CancelledError:
            break
        except Exception:
            logger.error("Outbound autoclaim loop error", exc_info=True)
async def _handle_rpc_request(self, payload: dict[str, Any]) -> None:
    """Run a delegated adapter RPC locally and push the JSON reply to Redis.

    The method name comes from ``rpc_method`` (or ``action``) and the
    arguments from ``rpc_args`` (or ``kwargs``). Supported methods:

    * ``fetch_channel_history``: ``adapter.fetch_history(channel_id, limit=...)``,
      each message serialized via ``to_dict()`` when available, else via a
      hand-built dict.
    * ``should_skip_channel_heartbeat``: the like-named adapter method.
    * ``delete_egregore_webhook``: ``delete_egregore_webhook_by_id`` against
      the adapter's ``_client``/``client``.
    * ``execute_tool``: rebuilds a ``ToolContext``, lazily loads the tool
      registry, and calls the named tool.

    Any other method, and any exception raised while handling, becomes an
    ``{"error": ...}`` result. The result is ``rpush``-ed to ``reply_key``
    and given a 60-second expiry. A request without ``reply_key`` is logged
    and dropped.

    A reply is always pushed once ``reply_key`` is known: if the result
    cannot be JSON-encoded, an error envelope is sent instead, so the caller
    waiting on the reply list is not left without an answer. The tool
    registry is stored on ``self`` only after ``load_tools`` succeeds, so a
    failed or still-running load never leaves a half-built registry for
    later calls.

    Args:
        payload: Deserialized RPC envelope carrying ``rpc_method``/``action``,
            ``rpc_args``/``kwargs``, ``reply_key`` and ``channel_id``.
    """
    import json

    rpc_method = payload.get("rpc_method") or payload.get("action")
    rpc_args = payload.get("rpc_args") or payload.get("kwargs") or {}
    reply_key = payload.get("reply_key")
    channel_id = payload.get("channel_id")

    if not reply_key:
        logger.error("RPC request missing reply_key")
        return

    logger.info(
        "Handling RPC request %s on platform %s for channel %s",
        rpc_method,
        self._platform,
        channel_id,
    )

    result_value = None
    try:
        if rpc_method == "fetch_channel_history":
            limit = rpc_args.get("limit", 10)
            history = await self._adapter.fetch_history(channel_id, limit=limit)
            result_value = []
            for msg in history:
                if hasattr(msg, "to_dict"):
                    result_value.append(msg.to_dict())
                else:
                    result_value.append({
                        "user_id": msg.user_id,
                        "user_name": msg.user_name,
                        "text": msg.text,
                        # datetime-like values become epoch seconds; anything
                        # else is passed through unchanged.
                        "timestamp": msg.timestamp.timestamp() if hasattr(msg.timestamp, "timestamp") else msg.timestamp,
                        "message_id": msg.message_id,
                        "is_bot": msg.is_bot,
                        "reply_to_id": msg.reply_to_id,
                        "reactions": msg.reactions,
                    })
        elif rpc_method == "should_skip_channel_heartbeat":
            result_value = await self._adapter.should_skip_channel_heartbeat(channel_id)
        elif rpc_method == "delete_egregore_webhook":
            # Worker -> gateway: delete an egregore webhook on the live client
            # (the worker's ProxyPlatformAdapter has no .client). Used by the
            # !snes_end / !snes_blow cleanup paths.
            from tools._egregore_discord import delete_egregore_webhook_by_id

            client = (
                getattr(self._adapter, "_client", None)
                or getattr(self._adapter, "client", None)
            )
            ok, err = await delete_egregore_webhook_by_id(
                client,
                rpc_args.get("guild_id"),
                rpc_args.get("webhook_id"),
            )
            result_value = {"ok": ok, "error": err}
        elif rpc_method == "execute_tool":
            tool_name = rpc_args.get("tool_name")
            tool_args = rpc_args.get("tool_args") or {}
            tool_ctx_dict = rpc_args.get("tool_ctx") or {}

            # Reconstruct ToolContext, wiring in the live adapter and Redis.
            from tool_context import ToolContext

            t_ctx = ToolContext(
                platform=tool_ctx_dict.get("platform") or self._platform,
                channel_id=tool_ctx_dict.get("channel_id") or channel_id or "",
                user_id=tool_ctx_dict.get("user_id") or "",
                user_name=tool_ctx_dict.get("user_name") or "",
                guild_id=tool_ctx_dict.get("guild_id") or "",
                message_id=tool_ctx_dict.get("message_id") or "",
                adapter=self._adapter,
                redis=self._redis,
                config=getattr(self._adapter, "config", None),
            )

            # Lazy compile ToolRegistry. Build into a local and publish it
            # on self only after load_tools succeeded: publishing first would
            # let a failed load (or a concurrent RPC arriving mid-load) see
            # an empty registry that is never reloaded.
            registry = self._tool_registry
            if registry is None:
                from tools import ToolRegistry
                from tool_loader import load_tools

                registry = ToolRegistry()
                tools_dir = self._tools_dir
                config_obj = getattr(self._adapter, "config", None)
                if config_obj and hasattr(config_obj, "tools_dir"):
                    tools_dir = config_obj.tools_dir
                logger.info("OutboundStreamConsumer lazy-loading tools from %s", tools_dir)
                await asyncio.to_thread(load_tools, tools_dir, registry)
                self._tool_registry = registry

            logger.info("OutboundStreamConsumer executing delegated tool %s", tool_name)
            result_value = await registry.call(
                name=tool_name,
                arguments=tool_args,
                user_id=t_ctx.user_id,
                ctx=t_ctx,
            )
        else:
            logger.error("Unknown RPC method: %s", rpc_method)
            result_value = {"error": f"Unknown RPC method: {rpc_method}"}
    except Exception as e:
        logger.exception("Error processing RPC request %s", rpc_method)
        result_value = {"error": str(e)}

    # Encoding can fail on its own (e.g. a tool returning a non-JSON type);
    # degrade to an error envelope rather than skipping the reply.
    try:
        encoded = json.dumps(result_value)
    except (TypeError, ValueError) as e:
        logger.error(
            "RPC result for %s is not JSON-serializable",
            rpc_method,
            exc_info=True,
        )
        encoded = json.dumps({"error": f"RPC result not JSON-serializable: {e}"})

    await self._redis.rpush(reply_key, encoded)
    await self._redis.expire(reply_key, 60)
async def _handle_dispatch_task(self, msg_id: str, raw: dict) -> None:
    """Process one outbound stream entry, serialized per channel and deduplicated.

    Steps, in order:

    1. Deserialize ``raw``. An ``rpc_request`` payload goes straight to
       :meth:`_handle_rpc_request`, is acknowledged, and the method returns
       without taking any lock.
    2. If step 1 raised anywhere, a fallback deserializes again, calls
       :meth:`_dispatch` and acknowledges; a second failure is logged and the
       entry is left unacknowledged.
    3. Otherwise a ``DistributedLock`` on
       ``sg:lock:outbound:{platform}:{channel_id}`` (``"global"`` when the
       payload has no channel) is acquired.
    4. Completion marker: if the hash at ``message_key`` already has a
       non-blank ``message_id`` field, the entry is acknowledged and skipped.
    5. Claim: ``SET sg:outbound:claim:{message_key} NX EX 300``; if the key
       already exists the entry is acknowledged and skipped.
    6. :meth:`_dispatch`, then :meth:`_ack_outbound`.

    The claim is deleted only when :meth:`_dispatch` itself raises. If the
    send succeeded and only the acknowledgement failed, the claim is kept so
    a redelivery of the same entry is skipped rather than sent twice. Errors
    from steps 3-6 are logged and swallowed, and the lock is released in a
    ``finally``.

    Args:
        msg_id: Redis stream entry id, used for acknowledging and as the
            claim value.
        raw: Raw stream field map to deserialize into an outbound payload.
    """
    try:
        payload = deserialize_stream_payload(raw)
        # Intercept RPC requests before lock acquisition
        if payload.get("type") == "rpc_request":
            await self._handle_rpc_request(payload)
            await self._ack_outbound(msg_id)
            return
        channel_id = payload.get("channel_id") or "global"
    except Exception:
        # If payload cannot be deserialized, call direct dispatch to trigger standard error handling/ACKs
        # NOTE(review): this branch is also reached when RPC handling or its
        # ack raised above, in which case the rpc_request payload is handed
        # to _dispatch — confirm that is intended.
        try:
            payload = deserialize_stream_payload(raw)
            await self._dispatch(payload)
            await self._ack_outbound(msg_id)
        except Exception:
            logger.error(
                "Failed to dispatch outbound message",
                exc_info=True,
                extra={"stream_msg_id": msg_id},
            )
        return

    from core.distributed_lock import DistributedLock

    lock = DistributedLock(
        self._redis,
        f"sg:lock:outbound:{self._platform}:{channel_id}",
        ttl=30,
        auto_renew=True,
    )
    try:
        await lock.acquire_blocking()

        # Idempotency check (completion marker):
        # If the platform sent ID already exists in the Redis message cache under message_key,
        # this message has already been successfully dispatched to Discord/Matrix.
        message_key = payload.get("message_key")
        if message_key:
            sent_id = await self._redis.hget(message_key, "message_id")
            if sent_id:
                sent_id_str = sent_id.decode() if isinstance(sent_id, bytes) else str(sent_id)
                if sent_id_str.strip():
                    logger.info(
                        "Outbound message already dispatched (idempotency check hit)",
                        extra={
                            "stream_msg_id": msg_id,
                            "message_key": message_key,
                            "sent_id": sent_id_str,
                        },
                    )
                    await self._ack_outbound(msg_id)
                    return

        # Atomic pre-dispatch claim: closes the check-then-act window so a
        # duplicate send cannot happen if the per-channel lock ever fails to
        # provide mutual exclusion (e.g. a Redis failover split-brain) or the
        # stream entry is redelivered concurrently. The completion marker
        # above still guards against re-sending after a successful dispatch.
        claim_key = None
        if message_key:
            claim_key = f"sg:outbound:claim:{message_key}"
            claimed = await self._redis.set(claim_key, msg_id, nx=True, ex=300)
            if not claimed:
                logger.info(
                    "Outbound message already claimed; skipping duplicate dispatch",
                    extra={
                        "stream_msg_id": msg_id,
                        "message_key": message_key,
                    },
                )
                await self._ack_outbound(msg_id)
                return

        try:
            await self._dispatch(payload)
        except Exception:
            # An exception out of _dispatch means the send did not complete
            # (the post-send marker write is swallowed inside _dispatch), so
            # it is safe to release the claim for a legitimate retry.
            if claim_key:
                try:
                    await self._redis.delete(claim_key)
                except Exception:
                    # Best effort: the claim expires on its own after 300s.
                    logger.debug(
                        "Failed to release outbound claim",
                        exc_info=True,
                        extra={"stream_msg_id": msg_id},
                    )
            raise

        # The ack is deliberately outside the claim-releasing handler: when
        # the send succeeded but the ack fails, the claim must survive so the
        # redelivered entry is skipped instead of being sent a second time.
        await self._ack_outbound(msg_id)
    except Exception:
        logger.error(
            "Failed to dispatch outbound message",
            exc_info=True,
            extra={"stream_msg_id": msg_id},
        )
    finally:
        await lock.release()
async def _ack_outbound(self, msg_id: str) -> None:
    """Acknowledge a stream entry and drop its autoclaim attempt counter.

    Issues ``xack`` for ``msg_id`` on ``_stream``/``_group``; a failure there
    propagates to the caller. Afterwards the entry's field in the
    ``sg:autoclaim:outbound:attempts`` hash is removed on a best-effort
    basis: any error from that ``hdel`` is ignored.

    Args:
        msg_id: Stream entry id to acknowledge.
    """
    redis = self._redis
    await redis.xack(self._stream, self._group, msg_id)
    try:
        await redis.hdel("sg:autoclaim:outbound:attempts", msg_id)
    except Exception:
        # Counter cleanup is optional; the ack above already succeeded.
        return
async def _dispatch(self, payload: dict[str, Any]) -> None:
    """Translate an outbound envelope into the matching adapter call.

    The payload ``type`` (default ``"message"``) selects the action:

    * ``presence``: ``set_presence(text, emoji)`` if the adapter has it;
      handled before the channel check because it carries no channel.
    * ``message``: ``send``, or ``send_with_buttons`` with a toggle view when
      the platform is ``discord``/``discord-self`` and the ``toggle_menu``
      feature is not disabled for the channel. Any failure while building
      the view is logged at debug level and the plain ``send`` is used.
    * ``file``: ``send_file``; a ``str`` ``file_data`` is base64-decoded first.
    * ``buttons``: ``send_with_buttons``.
    * ``typing``/``start_typing``: ``start_typing``, else ``send_typing``.
    * ``stop_typing``: ``stop_typing`` when available.
    * ``reaction``: ``add_reaction`` when available, else a warning.

    Any non-presence payload without ``channel_id`` is logged and dropped;
    an unknown type is logged as a warning.

    When the adapter call returned a truthy id and the payload carries
    ``message_key``, the id is written to field ``message_id`` of that Redis
    hash. If the key did not exist beforehand it also gets a 90-day expiry.
    Failures of this write-back are logged and not raised.

    Args:
        payload: Deserialized outbound envelope; ``type`` selects the adapter
            method and the remaining fields supply its arguments.
    """
    kind = payload.get("type", "message")
    channel_id = payload.get("channel_id")
    adapter = self._adapter

    # Presence is a global (non-channel) action, so it must be handled
    # before the channel_id guard, which would otherwise drop it.
    if kind == "presence":
        logger.info(
            "Dispatching outbound presence",
            extra={"platform": self._platform},
        )
        if not hasattr(adapter, "set_presence"):
            logger.warning(
                "Adapter %s has no set_presence; dropping presence update",
                self._platform,
            )
            return
        await adapter.set_presence(
            payload.get("text", ""),
            payload.get("emoji"),
        )
        return

    if not channel_id:
        logger.error("Outbound payload missing channel_id")
        return

    logger.info(
        "Dispatching outbound %s",
        kind,
        extra={
            "channel_id": channel_id,
            "platform": self._platform,
            "type": kind,
        },
    )

    platform_msg_id = None
    if kind == "message":
        # Optionally attach the persistent toggle button to the message.
        toggle_view = None
        if self._platform in ("discord", "discord-self"):
            try:
                import feature_toggles

                toggle_scope = f"{self._platform}:{channel_id}"

                # Resolve the guild id from the client's channel cache;
                # failures here just leave it as None.
                guild_id = None
                try:
                    live_client = getattr(adapter, "_client", None)
                    if live_client is not None:
                        channel_obj = live_client.get_channel(int(channel_id))
                        guild = (
                            getattr(channel_obj, "guild", None)
                            if channel_obj is not None
                            else None
                        )
                        if guild:
                            guild_id = str(guild.id)
                except Exception:
                    pass

                menu_disabled = await feature_toggles.is_disabled_resolving_discord_aliases(
                    self._redis, "toggle_menu", toggle_scope,
                    guild_id=guild_id,
                )
                if not menu_disabled:
                    from star_toggle_ui import build_star_toggle_view

                    toggle_view = build_star_toggle_view(channel_id)
            except Exception:
                logger.debug("star_toggle_ui or toggle checks failed", exc_info=True)

        if toggle_view is None:
            platform_msg_id = await adapter.send(channel_id, payload["text"])
        else:
            platform_msg_id = await adapter.send_with_buttons(
                channel_id, payload["text"], toggle_view,
            )
    elif kind == "file":
        blob = payload["file_data"]
        if isinstance(blob, str):
            # Text payloads carry the file base64-encoded.
            blob = base64.b64decode(blob)
        platform_msg_id = await adapter.send_file(
            channel_id,
            blob,
            payload.get("filename", "file"),
            payload.get("mimetype", "application/octet-stream"),
        )
    elif kind == "buttons":
        platform_msg_id = await adapter.send_with_buttons(
            channel_id,
            payload["text"],
            payload["buttons"],
        )
    elif kind in ("typing", "start_typing"):
        if hasattr(adapter, "start_typing"):
            await adapter.start_typing(channel_id)
        elif hasattr(adapter, "send_typing"):
            await adapter.send_typing(channel_id)
    elif kind == "stop_typing":
        if hasattr(adapter, "stop_typing"):
            await adapter.stop_typing(channel_id)
    elif kind == "reaction":
        if hasattr(adapter, "add_reaction"):
            await adapter.add_reaction(
                channel_id,
                payload["message_id"],
                payload["emoji"],
            )
        else:
            logger.warning(
                "Adapter %s has no add_reaction; dropping reaction envelope",
                getattr(adapter, "name", "?"),
            )
    else:
        logger.warning("Unknown outbound message type: %s", kind)

    # Write the platform id back only when there is both an id and a key.
    message_key = payload.get("message_key")
    if not (platform_msg_id and message_key):
        return

    try:
        # Check existence first so a key created here gets its own TTL.
        cache_hit = await self._redis.exists(message_key)
        await self._redis.hset(message_key, "message_id", str(platform_msg_id))
        if not cache_hit:
            # Key was created by this write: apply a 90-day TTL so it
            # cannot linger without an expiry.
            await self._redis.expire(message_key, 90 * 86400)
        logger.info(
            "Re-synced platform message_id to Redis message cache",
            extra={
                "message_key": message_key,
                "message_id": platform_msg_id,
                "platform": self._platform,
                "cache_existed": bool(cache_hit),
            },
        )
    except Exception:
        logger.warning(
            "Failed to re-sync platform message_id to Redis cache",
            exc_info=True,
            extra={
                "message_key": message_key,
                "message_id": platform_msg_id,
            },
        )