# Source code for core.stream_consumer

"""Stream consumer workers for Stargazer distributed architecture."""

import asyncio
import logging
from typing import Any, Callable, Awaitable, Optional

from redis.asyncio import Redis

from core.event_bus import INBOUND_STREAM, WORKER_GROUP
from core.serialization import deserialize_stream_payload
from core.dlq import handle_failed_message

logger = logging.getLogger("stargazer.stream_consumer")


class InboundStreamConsumer:
    """Consume inbound message envelopes and dispatch them to the processor.

    Runs as asyncio tasks on worker nodes. Each worker instance uses a unique
    consumer name within the ``WORKER_GROUP`` consumer group so that Redis
    distributes messages across workers.
    """

    # Redis hash used to audit how many times each entry has been autoclaimed.
    _AUTOCLAIM_ATTEMPTS_KEY = "sg:autoclaim:attempts"
    # An entry claimed more often than this is quarantined instead of retried.
    _MAX_CLAIM_ATTEMPTS = 3

    def __init__(
        self,
        redis: Redis,
        consumer_name: str,
        process_fn: Callable[[dict[str, Any]], Awaitable[None]],
        autoclaim_interval: float = 30.0,
        autoclaim_min_idle: int = 60000,  # 60 seconds
        state_machine: Optional[Any] = None,
    ) -> None:
        """Configure an inbound consumer without starting any background work.

        Only stores its arguments and initialises empty bookkeeping; no Redis
        I/O happens here and no tasks exist until :meth:`start` is called.

        Args:
            redis: Async Redis client used for ``XREADGROUP``, ``XACK``,
                ``XPENDING``, ``XAUTOCLAIM``, the autoclaim audit hash and the
                per-channel distributed lock.
            consumer_name: Unique consumer name within ``WORKER_GROUP``.
            process_fn: Async callback awaited with the deserialized payload
                dict of each successfully decoded message.
            autoclaim_interval: Seconds the autoclaim loop sleeps between
                sweeps.
            autoclaim_min_idle: Minimum idle time in milliseconds before a
                pending message may be reclaimed.
            state_machine: Optional object whose async ``transition`` method
                records ``CLAIMED`` / ``COMPLETED`` / ``ERRORED`` trace
                transitions; ``None`` disables the hooks.
        """
        self._redis = redis
        self._consumer_name = consumer_name
        self._process_fn = process_fn
        self._autoclaim_interval = autoclaim_interval
        self._autoclaim_min_idle = autoclaim_min_idle
        self._state_machine = state_machine
        self._running = False
        self._task: asyncio.Task | None = None
        self._autoclaim_task: asyncio.Task | None = None
        # Track retry attempts per message ID
        self._attempt_counts: dict[str, int] = {}
        self._active_tasks: set[asyncio.Task] = set()
        # Strong references to fire-and-forget trace-hook tasks. The event
        # loop only keeps weak references to tasks, so an unreferenced task
        # can be garbage-collected before it finishes.
        self._hook_tasks: set[asyncio.Task] = set()

    async def start(self) -> None:
        """Start the inbound consume loop and the autoclaim loop.

        Sets ``_running`` and spawns :meth:`_consume_loop` and
        :meth:`_autoclaim_loop` as named asyncio tasks, keeping both
        references on ``self`` so :meth:`stop` can cancel them.
        """
        self._running = True
        self._task = asyncio.create_task(self._consume_loop(), name="inbound_consumer")
        self._autoclaim_task = asyncio.create_task(
            self._autoclaim_loop(), name="inbound_autoclaim"
        )
        logger.info(
            "InboundStreamConsumer started",
            extra={"consumer_name": self._consumer_name, "stream": INBOUND_STREAM},
        )

    async def stop(self) -> None:
        """Stop both loops and drain in-flight work.

        Clears ``_running``, cancels the consume and autoclaim tasks, then
        awaits every per-message task still in ``_active_tasks`` (collecting
        exceptions instead of raising) so messages that are mid-processing
        can finish their ``XACK`` / DLQ path. Pending trace-hook tasks are
        awaited the same way so none are left dangling.
        """
        self._running = False
        if self._task:
            self._task.cancel()
        if self._autoclaim_task:
            self._autoclaim_task.cancel()
        if self._active_tasks:
            logger.info("Awaiting %d active inbound processing tasks...", len(self._active_tasks))
            await asyncio.gather(*self._active_tasks, return_exceptions=True)
        if self._hook_tasks:
            await asyncio.gather(*self._hook_tasks, return_exceptions=True)
        logger.info("InboundStreamConsumer stopped")

    @staticmethod
    def _decode_id(msg_id: Any) -> str:
        """Return a stream entry ID as ``str`` whether Redis gave bytes or str."""
        return msg_id.decode() if isinstance(msg_id, bytes) else str(msg_id)

    @staticmethod
    def _extract_trace_id(raw: Any) -> str:
        """Return the ``trace_id`` field of a raw stream entry, or ``""``.

        Looks the field up under both the bytes and the str key; anything
        that is not a dict yields ``""``. May raise (for example on
        undecodable bytes) -- callers treat the hook as best-effort and
        catch the error.
        """
        if not isinstance(raw, dict):
            return ""
        trace_id = raw.get(b"trace_id", b"") or raw.get("trace_id", b"")
        if not trace_id:
            return ""
        return trace_id.decode() if isinstance(trace_id, bytes) else str(trace_id)

    def _on_hook_done(self, task: asyncio.Task) -> None:
        """Drop a finished hook task and consume its exception, if any."""
        self._hook_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.debug("Trace hook task failed: %s", exc)

    def _schedule_claimed_hook(self, msg_id: str, raw: Any, failure_msg: str) -> None:
        """Schedule a best-effort ``CLAIMED`` trace transition for one entry.

        No-op when no state machine was supplied or the entry carries no
        ``trace_id``. Any error is logged at debug level with
        ``failure_msg`` and never reaches the caller.
        """
        if not self._state_machine:
            return
        try:
            trace_id = self._extract_trace_id(raw)
            if trace_id:
                hook = asyncio.create_task(
                    self._state_machine.transition(trace_id, "CLAIMED", {"msg_id": msg_id})
                )
                self._hook_tasks.add(hook)
                hook.add_done_callback(self._on_hook_done)
        except Exception as e:
            logger.debug(failure_msg, e)

    def _spawn_handler(self, msg_id: str, raw: Any, name_prefix: str) -> None:
        """Start :meth:`_handle_message_task` for an entry and track the task."""
        task = asyncio.create_task(
            self._handle_message_task(msg_id, raw),
            name=f"{name_prefix}_msg_{msg_id}",
        )
        self._active_tasks.add(task)
        task.add_done_callback(self._active_tasks.discard)

    async def _consume_loop(self) -> None:
        """Block-read inbound entries and fan them out to per-message tasks.

        While ``_running`` is set, issues ``XREADGROUP`` on ``INBOUND_STREAM``
        under ``WORKER_GROUP`` (up to 10 entries, 5000 ms block). For every
        entry it schedules the optional ``CLAIMED`` hook and spawns a tracked
        :meth:`_handle_message_task`. ``CancelledError`` ends the loop; any
        other exception is logged at critical level and retried after a
        two-second sleep.
        """
        while self._running:
            try:
                logger.debug(
                    "Blocking on XREADGROUP for inbound stream",
                    extra={
                        "stream": INBOUND_STREAM,
                        "group": WORKER_GROUP,
                        "consumer": self._consumer_name,
                    },
                )
                messages = await self._redis.xreadgroup(
                    WORKER_GROUP,
                    self._consumer_name,
                    {INBOUND_STREAM: ">"},
                    count=10,
                    block=5000,
                )
                if not messages:
                    continue

                logger.debug("Successfully read messages from stream", extra={"count": len(messages)})
                for _stream_name, entries in messages:
                    for msg_id, raw in entries:
                        msg_id_str = self._decode_id(msg_id)
                        # Hook: TraceState.CLAIMED
                        self._schedule_claimed_hook(
                            msg_id_str, raw, "Failed to extract trace_id for hook: %s"
                        )
                        self._spawn_handler(msg_id_str, raw, "inbound")
            except asyncio.CancelledError:
                break
            except Exception:
                logger.critical(
                    "Inbound consumer loop crashed, restarting in 2s",
                    exc_info=True,
                )
                await asyncio.sleep(2)

    async def _handle_message_task(self, msg_id: str, raw: dict) -> None:
        """Serialize per-channel handling behind a Redis lock, then process.

        Deserializes the entry to read ``channel_id`` (falling back to
        ``"global"``). If deserialization fails, or the payload has
        ``is_addressed`` set to a falsy value, the entry goes straight to
        :meth:`_handle_message` without locking. Otherwise a
        ``DistributedLock`` on ``sg:lock:channel:{channel_id}`` is acquired
        around the call and released in a ``finally``.

        Args:
            msg_id: Redis Streams entry ID being handled.
            raw: The raw field-value mapping read from the stream.
        """
        try:
            from core.event_types import deserialize_envelope

            env = deserialize_envelope(raw)
            if env is None:
                raise ValueError("Unsupported schema version in envelope")
            payload = deserialize_stream_payload(raw)
            channel_id = payload.get("channel_id") or "global"
        except Exception:
            # If payload cannot be deserialized, call direct handler to trigger normal logging and DLQ
            await self._handle_message(msg_id, raw)
            return

        is_addressed = payload.get("is_addressed", True)
        if not is_addressed:
            await self._handle_message(msg_id, raw)
            return

        from core.distributed_lock import DistributedLock

        lock = DistributedLock(
            self._redis,
            f"sg:lock:channel:{channel_id}",
            ttl=60,
            auto_renew=True,
            max_renewals=0,
        )
        try:
            logger.debug(
                "Awaiting channel serialization lock",
                extra={"stream_msg_id": msg_id, "channel_id": channel_id},
            )
            await lock.acquire_blocking()
            await self._handle_message(msg_id, raw)
        finally:
            # NOTE(review): release() also runs when acquire_blocking() raised;
            # presumably DistributedLock.release() tolerates an unheld lock --
            # confirm against core.distributed_lock.
            await lock.release()

    async def _handle_message(self, msg_id: str, raw: dict) -> None:
        """Deserialize, process, acknowledge, and DLQ a single inbound entry.

        Steps:

        1. Idempotency guard: ``xpending_range`` for exactly this ID. An
           empty result means the entry is no longer pending, so the local
           attempt counter is dropped and the method returns. If the guard
           itself fails, a warning is logged and processing continues.
        2. Increment the per-message attempt counter, deserialize the
           envelope and payload, and await ``process_fn``.
        3. On success send ``XACK``, drop the attempt counter and record a
           best-effort ``COMPLETED`` trace transition. A failing hook is only
           logged: the message is already processed and acknowledged, so it
           must not be routed to the DLQ.
        4. ``CancelledError`` from ``process_fn`` is re-raised; if the current
           task carries a truthy ``_user_cancelled`` attribute the entry is
           ``XACK``-ed first so it is not retried.
        5. A ``TimeoutError`` whose message mentions "lock" re-adds the entry
           to ``INBOUND_STREAM`` and acknowledges the old ID. Any other
           failure (or a failed requeue) is logged, passed to
           ``handle_failed_message`` and followed by a best-effort
           ``ERRORED`` trace transition.

        Args:
            msg_id: Redis Streams entry ID of the message being processed.
            raw: The raw field-value mapping read from the stream.
        """
        try:
            # Inbound idempotency check:
            # Query Redis to verify if the message is still in the pending entries list (PEL).
            # If a message has already been acknowledged (XACKed) by another worker, it will be gone from the PEL.
            pending_info = await self._redis.xpending_range(
                INBOUND_STREAM, WORKER_GROUP, msg_id, msg_id, 1
            )
            if not pending_info:
                logger.info(
                    "Inbound message already acknowledged by another worker (idempotency check hit)",
                    extra={"stream_msg_id": msg_id},
                )
                self._attempt_counts.pop(msg_id, None)
                return
        except Exception:
            # If xpending_range fails for any reason, we fall back to normal processing
            logger.warning(
                "Failed to run inbound idempotency check, falling back to processing",
                exc_info=True,
                extra={"stream_msg_id": msg_id},
            )

        attempt = self._attempt_counts.get(msg_id, 0) + 1
        self._attempt_counts[msg_id] = attempt

        # Bound up front so the error path can tell whether deserialization
        # got far enough to produce a payload.
        payload = None
        try:
            from core.event_types import deserialize_envelope

            env = deserialize_envelope(raw)
            if env is None:
                raise ValueError("Unsupported schema version in envelope")
            payload = deserialize_stream_payload(raw)
            logger.info(
                "Processing inbound message",
                extra={
                    "stream_msg_id": msg_id,
                    "channel_id": payload.get("channel_id"),
                    "platform": payload.get("platform"),
                    "user_id": payload.get("user_id"),
                    "attempt": attempt,
                },
            )
            try:
                await self._process_fn(payload)
            except asyncio.CancelledError:
                # Check if this task was cancelled by the user stop request
                current_task = asyncio.current_task()
                if current_task and getattr(current_task, "_user_cancelled", False):
                    logger.info(
                        "InboundStreamConsumer: Inbound message cancelled by user stop request; acknowledging message to prevent retry. msg_id=%s",
                        msg_id,
                    )
                    try:
                        await self._redis.xack(INBOUND_STREAM, WORKER_GROUP, msg_id)
                        self._attempt_counts.pop(msg_id, None)
                    except Exception:
                        logger.exception("Failed to XACK user-cancelled message")
                raise

            await self._redis.xack(INBOUND_STREAM, WORKER_GROUP, msg_id)
            self._attempt_counts.pop(msg_id, None)
            logger.debug("XACK sent", extra={"stream_msg_id": msg_id})

            # Hook: TraceState.COMPLETED
            # Best-effort like the other hooks: the entry is already processed
            # and acknowledged, so a tracing failure must not fall through to
            # the DLQ path below.
            if self._state_machine and payload.get("trace_id"):
                try:
                    await self._state_machine.transition(payload["trace_id"], "COMPLETED")
                except Exception as e:
                    logger.debug("Failed to record COMPLETED trace hook: %s", e)
        except Exception as exc:
            if isinstance(exc, TimeoutError) and "lock" in str(exc).lower():
                from core.dlq import extract_stream_payload_bytes, extract_stream_aux_fields

                try:
                    payload_bytes = extract_stream_payload_bytes(raw)
                    aux = extract_stream_aux_fields(raw)
                    ts_val = aux.get("ts", "")
                    trace_val = aux.get("trace_id", "")
                    channel_id = payload.get("channel_id") if payload else None
                    logger.warning(
                        "Lock collision detected (TimeoutError), requeueing message to stream",
                        extra={"stream_msg_id": msg_id, "channel_id": channel_id},
                    )
                    await self._redis.xadd(
                        INBOUND_STREAM,
                        {
                            b"data": payload_bytes,
                            b"ts": ts_val.encode("utf-8") if ts_val else b"",
                            b"trace_id": trace_val.encode("utf-8") if trace_val else b"",
                        },
                    )
                    await self._redis.xack(INBOUND_STREAM, WORKER_GROUP, msg_id)
                    self._attempt_counts.pop(msg_id, None)
                    return
                except Exception:
                    logger.error(
                        "Failed to requeue message after lock collision; falling back to standard DLQ path",
                        exc_info=True,
                        extra={"stream_msg_id": msg_id},
                    )

            logger.error(
                "Failed to process inbound message",
                exc_info=True,
                extra={"stream_msg_id": msg_id, "attempt": attempt},
            )
            await handle_failed_message(
                self._redis,
                INBOUND_STREAM,
                WORKER_GROUP,
                msg_id,
                raw,
                exc,
                attempt,
            )

            # Hook: TraceState.ERRORED
            if self._state_machine:
                try:
                    trace_id_str = self._extract_trace_id(raw)
                    if trace_id_str:
                        await self._state_machine.transition(trace_id_str, "ERRORED")
                except Exception as e:
                    logger.debug("Failed to extract trace_id for hook on error: %s", e)

    async def _autoclaim_loop(self) -> None:
        """Periodically claim orphaned messages from crashed consumers.

        After each ``autoclaim_interval`` sleep, uses ``XAUTOCLAIM`` (from
        ``0-0``, up to 10 entries) to take over messages that have been
        pending longer than ``autoclaim_min_idle`` milliseconds. Each claim
        is counted in the ``sg:autoclaim:attempts`` hash; an entry claimed
        more than three times is sent to ``handle_failed_message``,
        acknowledged and removed from the hash. Every other entry gets the
        optional ``CLAIMED`` hook and a tracked :meth:`_handle_message_task`.
        ``CancelledError`` ends the loop; other errors are logged and the
        loop continues with the next sweep.
        """
        while self._running:
            try:
                await asyncio.sleep(self._autoclaim_interval)
                logger.debug(
                    "Executing background autoclaim sweep",
                    extra={
                        "stream": INBOUND_STREAM,
                        "group": WORKER_GROUP,
                        "min_idle_ms": self._autoclaim_min_idle,
                    },
                )
                result = await self._redis.xautoclaim(
                    INBOUND_STREAM,
                    WORKER_GROUP,
                    self._consumer_name,
                    min_idle_time=self._autoclaim_min_idle,
                    start_id="0-0",
                    count=10,
                )
                # The claimed entries are the second element of the reply.
                claimed = result[1] if result and len(result) >= 2 else None
                if not claimed:
                    continue

                logger.warning(
                    "Autoclaimed %d orphaned messages",
                    len(claimed),
                    extra={"consumer_name": self._consumer_name},
                )
                for msg_id, raw in claimed:
                    msg_id_str = self._decode_id(msg_id)

                    # Audit claim attempts globally using a Redis hash
                    # NOTE(review): the counter is only deleted on the
                    # quarantine path below; entries for messages that later
                    # succeed appear to stay in the hash -- confirm whether
                    # something else cleans them up.
                    attempts = await self._redis.hincrby(
                        self._AUTOCLAIM_ATTEMPTS_KEY, msg_id_str, 1
                    )
                    if attempts > self._MAX_CLAIM_ATTEMPTS:
                        logger.warning(
                            "Quarantining toxic message after exceeding max claim attempts",
                            extra={"stream_msg_id": msg_id_str, "attempts": attempts},
                        )
                        await handle_failed_message(
                            self._redis,
                            INBOUND_STREAM,
                            WORKER_GROUP,
                            msg_id_str,
                            raw,
                            ValueError(f"Max claim attempts exceeded: {attempts}"),
                            attempts,
                        )
                        await self._redis.xack(INBOUND_STREAM, WORKER_GROUP, msg_id_str)
                        await self._redis.hdel(self._AUTOCLAIM_ATTEMPTS_KEY, msg_id_str)
                        continue

                    # Hook: TraceState.CLAIMED
                    self._schedule_claimed_hook(
                        msg_id_str, raw, "Failed to extract trace_id for hook in autoclaim: %s"
                    )
                    # Concurrent claimed processing that respects locks
                    self._spawn_handler(msg_id_str, raw, "autoclaim")
            except asyncio.CancelledError:
                break
            except Exception:
                logger.error("Autoclaim loop error", exc_info=True)
class OutboundStreamConsumer:
    """Consume outbound messages intended for platform adapters.

    Runs on gateway nodes and reads ``sg:stream:outbound:{platform}`` for
    every configured platform under the ``stargazer_gateways`` consumer
    group.
    """

    def __init__(
        self,
        redis: Redis,
        consumer_name: str,
        process_fn: Callable[[dict[str, Any]], Awaitable[None]],
        platforms: list[str],
    ) -> None:
        """Configure an outbound consumer for one or more platform streams.

        Performs no Redis I/O; the consumer stays idle until :meth:`start`.

        Args:
            redis: Async Redis client used for ``XREADGROUP`` and ``XACK`` on
                the outbound platform streams.
            consumer_name: Unique consumer name within the
                ``stargazer_gateways`` group.
            process_fn: Async callback awaited with each deserialized
                outbound payload dict.
            platforms: Platform identifiers whose outbound streams should be
                read; with an empty list :meth:`start` only logs a warning.
        """
        self._redis = redis
        self._consumer_name = consumer_name
        self._process_fn = process_fn
        self._platforms = platforms
        # Map every platform stream to the ">" cursor (new entries only);
        # the mapping is handed to XREADGROUP unchanged.
        stream_keys = (f"sg:stream:outbound:{platform}" for platform in platforms)
        self._streams = dict.fromkeys(stream_keys, ">")
        self._running = False
        self._task: asyncio.Task | None = None
        self._active_tasks: set[asyncio.Task] = set()

    async def start(self) -> None:
        """Spawn the consume loop, or warn and return when no streams exist."""
        if self._streams:
            self._running = True
            self._task = asyncio.create_task(self._consume_loop(), name="outbound_consumer")
            logger.info(
                "OutboundStreamConsumer started",
                extra={"consumer_name": self._consumer_name, "streams": list(self._streams)},
            )
            return

        logger.warning("No platforms provided, OutboundStreamConsumer will not start")

    async def stop(self) -> None:
        """Cancel the consume loop and wait for in-flight sends to finish.

        Exceptions raised by the drained tasks are collected by ``gather``
        rather than propagated.
        """
        self._running = False
        if self._task:
            self._task.cancel()

        in_flight = self._active_tasks
        if in_flight:
            logger.info("Awaiting %d active outbound tasks...", len(in_flight))
            await asyncio.gather(*in_flight, return_exceptions=True)

        logger.info("OutboundStreamConsumer stopped")

    async def _consume_loop(self) -> None:
        """Block-read the outbound streams and start one task per entry.

        Each poll reads up to 10 entries with a 5000 ms block. Every entry
        becomes a tracked :meth:`_handle_message` task that also receives the
        stream it came from. ``CancelledError`` ends the loop; any other
        error is logged at critical level and retried after two seconds.
        """
        while self._running:
            try:
                batch = await self._redis.xreadgroup(
                    "stargazer_gateways",
                    self._consumer_name,
                    self._streams,
                    count=10,
                    block=5000,
                )
                # An empty/None reply simply yields no iterations.
                for stream_name, entries in batch or ():
                    for msg_id, raw in entries:
                        if isinstance(msg_id, bytes):
                            entry_id = msg_id.decode()
                        else:
                            entry_id = str(msg_id)

                        sender = asyncio.create_task(
                            self._handle_message(entry_id, raw, stream_name),
                            name=f"outbound_msg_{entry_id}",
                        )
                        self._active_tasks.add(sender)
                        sender.add_done_callback(self._active_tasks.discard)
            except asyncio.CancelledError:
                break
            except Exception:
                logger.critical(
                    "Outbound consumer loop crashed, restarting in 2s",
                    exc_info=True,
                )
                await asyncio.sleep(2)

    async def _handle_message(self, msg_id: str, raw: dict, stream_name: bytes | str) -> None:
        """Deliver one outbound entry via ``process_fn`` and acknowledge it.

        On any failure the error is logged and the entry is passed to
        ``handle_failed_message`` with the attempt count fixed at 1.

        Args:
            msg_id: Redis Streams entry ID of the outbound message.
            raw: The raw field-value mapping read from the stream.
            stream_name: Stream the entry was read from (bytes or str); it is
                the ``XACK`` / DLQ target.
        """
        if isinstance(stream_name, bytes):
            target_stream = stream_name.decode()
        else:
            target_stream = str(stream_name)

        try:
            from core.event_types import deserialize_envelope

            if deserialize_envelope(raw) is None:
                raise ValueError("Unsupported schema version in envelope")

            outbound = deserialize_stream_payload(raw)
            # Method routing for the payload happens inside `gateway_main.py`
            await self._process_fn(outbound)
            await self._redis.xack(target_stream, "stargazer_gateways", msg_id)
        except Exception as exc:
            logger.error(
                "Failed to process outbound message",
                exc_info=True,
                extra={"stream_msg_id": msg_id},
            )
            await handle_failed_message(
                self._redis,
                target_stream,
                "stargazer_gateways",
                msg_id,
                raw,
                exc,
                1,
            )