"""Stream consumer workers for Stargazer distributed architecture."""
import asyncio
import logging
from typing import Any, Callable, Awaitable, Optional
from redis.asyncio import Redis
from core.event_bus import INBOUND_STREAM, WORKER_GROUP
from core.serialization import deserialize_stream_payload
from core.dlq import handle_failed_message
# Module logger; the dotted name nests it under the "stargazer" logger hierarchy.
logger = logging.getLogger("stargazer.stream_consumer")
class InboundStreamConsumer:
"""Consumes inbound message envelopes and dispatches to the processor.
Runs as an asyncio task on worker nodes. Each worker instance
has a unique consumer name within the WORKER_GROUP consumer group,
ensuring Redis distributes messages across workers.
"""
[docs]
def __init__(
self,
redis: Redis,
consumer_name: str,
process_fn: Callable[[dict[str, Any]], Awaitable[None]],
autoclaim_interval: float = 30.0,
autoclaim_min_idle: int = 60000, # 60 seconds
state_machine: Optional[Any] = None,
) -> None:
"""Configure an inbound consumer without starting any background work.
Stores the async Redis client, this worker's unique consumer name within
the ``WORKER_GROUP`` consumer group (so Redis Streams fans messages out
across workers), the async ``process_fn`` callback that handles a single
deserialized payload, and the autoclaim tuning knobs. An optional
``state_machine`` is retained so the consumer can emit trace lifecycle
transitions (``CLAIMED`` / ``COMPLETED`` / ``ERRORED``) as messages move
through. Per-message retry counters (``_attempt_counts``) and the live
``_active_tasks`` set are initialised empty, and the loop tasks are left
as ``None`` until :meth:`start` is called -- no Redis I/O happens here.
Constructed by ``inference_main.py`` (the inference worker) and by a wide
range of tests under ``tests/core`` and ``tests/integration``.
Args:
redis: Async Redis client used for ``XREADGROUP``, ``XACK``,
``XPENDING``, ``XAUTOCLAIM``, the autoclaim audit hash, and the
per-channel distributed lock.
consumer_name: Unique consumer name within ``WORKER_GROUP``.
process_fn: Async callback invoked with the deserialized payload dict
for each successfully decoded message.
autoclaim_interval: Seconds the autoclaim daemon sleeps between sweeps.
autoclaim_min_idle: Minimum idle time in milliseconds before a pending
message is eligible to be reclaimed from a crashed consumer.
state_machine: Optional trace state machine whose ``transition`` is
awaited / scheduled to record lifecycle hooks; ``None`` disables
hooks.
"""
self._redis = redis
self._consumer_name = consumer_name
self._process_fn = process_fn
self._autoclaim_interval = autoclaim_interval
self._autoclaim_min_idle = autoclaim_min_idle
self._state_machine = state_machine
self._running = False
self._task: asyncio.Task | None = None
self._autoclaim_task: asyncio.Task | None = None
# Track retry attempts per message ID
self._attempt_counts: dict[str, int] = {}
self._active_tasks: set[asyncio.Task] = set()
[docs]
async def start(self) -> None:
"""Start the inbound consume loop and the autoclaim daemon.
Flips ``_running`` true and spawns two long-lived asyncio tasks: the
:meth:`_consume_loop` that blocks on ``XREADGROUP`` for new inbound
envelopes, and the :meth:`_autoclaim_loop` that periodically reclaims
messages orphaned by crashed workers. This is the activation point that
turns a freshly constructed, idle consumer into a running worker; both
tasks keep references on ``self`` so :meth:`stop` can cancel them. No
Redis I/O happens directly here beyond the loop tasks it launches, and an
informational log line records the consumer name and ``INBOUND_STREAM``.
Called by the inference worker in ``inference_main.py`` once Redis is
connected, and by tests under ``tests/core``.
"""
self._running = True
self._task = asyncio.create_task(self._consume_loop(), name="inbound_consumer")
self._autoclaim_task = asyncio.create_task(
self._autoclaim_loop(), name="inbound_autoclaim"
)
logger.info(
"InboundStreamConsumer started",
extra={"consumer_name": self._consumer_name, "stream": INBOUND_STREAM},
)
[docs]
async def stop(self) -> None:
"""Gracefully shut down the inbound consumer and drain in-flight work.
Clears ``_running`` so neither loop schedules new work, cancels the
consume-loop and autoclaim tasks started by :meth:`start`, then awaits any
per-message tasks still tracked in ``_active_tasks`` (gathering exceptions
rather than raising) so that messages already mid-processing finish their
``XACK`` / DLQ path before the worker exits. This is the lifecycle
counterpart to :meth:`start`, invoked during graceful worker shutdown in
``inference_main.py`` and by tests; it touches Redis only indirectly via
the tasks it drains.
"""
self._running = False
if self._task:
self._task.cancel()
if self._autoclaim_task:
self._autoclaim_task.cancel()
if self._active_tasks:
logger.info("Awaiting %d active inbound processing tasks...", len(self._active_tasks))
await asyncio.gather(*self._active_tasks, return_exceptions=True)
logger.info("InboundStreamConsumer stopped")
async def _consume_loop(self) -> None:
"""Block-read inbound envelopes and fan them out to per-message tasks.
The hot path of the inbound worker. While running it issues a blocking
``XREADGROUP`` against ``INBOUND_STREAM`` under ``WORKER_GROUP`` with this
consumer's name, reading up to ten new entries per poll. For each entry it
opportunistically extracts the ``trace_id`` and, when a trace state machine
was supplied, schedules a ``CLAIMED`` transition, then spawns a
:meth:`_handle_message_task` coroutine tracked in ``_active_tasks`` so
many channels are processed concurrently while a single channel stays
serialized downstream. The loop is resilient: ``CancelledError`` breaks out
cleanly for shutdown, and any other exception is logged at critical level
before sleeping two seconds and retrying so a transient Redis fault never
kills the worker. Launched as a background task by :meth:`start`; not
called directly elsewhere.
"""
while self._running:
try:
logger.debug(
"Blocking on XREADGROUP for inbound stream",
extra={
"stream": INBOUND_STREAM,
"group": WORKER_GROUP,
"consumer": self._consumer_name,
},
)
messages = await self._redis.xreadgroup(
WORKER_GROUP,
self._consumer_name,
{INBOUND_STREAM: ">"},
count=10,
block=5000,
)
if not messages:
continue
logger.debug("Successfully read messages from stream", extra={"count": len(messages)})
for _stream_name, entries in messages:
for msg_id, raw in entries:
msg_id_str = msg_id.decode() if isinstance(msg_id, bytes) else str(msg_id)
# Hook: TraceState.CLAIMED
if self._state_machine:
try:
trace_id = b""
if isinstance(raw, dict):
trace_id = raw.get(b"trace_id", b"") or raw.get("trace_id", b"")
if trace_id:
trace_id_str = trace_id.decode() if isinstance(trace_id, bytes) else str(trace_id)
asyncio.create_task(
self._state_machine.transition(trace_id_str, "CLAIMED", {"msg_id": msg_id_str})
)
except Exception as e:
logger.debug("Failed to extract trace_id for hook: %s", e)
task = asyncio.create_task(
self._handle_message_task(msg_id_str, raw),
name=f"inbound_msg_{msg_id_str}",
)
self._active_tasks.add(task)
task.add_done_callback(self._active_tasks.discard)
except asyncio.CancelledError:
break
except Exception:
logger.critical(
"Inbound consumer loop crashed, restarting in 2s",
exc_info=True,
)
await asyncio.sleep(2)
async def _handle_message_task(self, msg_id: str, raw: dict) -> None:
"""Serialize per-channel handling behind a Redis lock, then process.
Concurrency wrapper around :meth:`_handle_message` that preserves
per-channel ordering while still letting different channels run in
parallel. It first deserializes the envelope and stream payload to read
``channel_id``; if deserialization fails it skips locking and hands the raw
entry straight to :meth:`_handle_message` so the normal error-logging and
DLQ path runs. Messages flagged ``is_addressed=False`` (passive, not
directed at the bot) also bypass the lock since ordering does not matter
for them. Otherwise it acquires a :class:`core.distributed_lock.DistributedLock`
on ``sg:lock:channel:{channel_id}`` (60s TTL, auto-renewing) before
delegating to :meth:`_handle_message`, releasing the lock in a ``finally``.
Touches Redis only through that lock; the actual ``XACK`` / DLQ work
happens inside the delegate. Spawned as a task by :meth:`_consume_loop`
and :meth:`_autoclaim_loop`.
Args:
msg_id: Redis Streams entry ID being handled.
raw: The raw field-value mapping read from the stream.
"""
try:
from core.event_types import deserialize_envelope
env = deserialize_envelope(raw)
if env is None:
raise ValueError("Unsupported schema version in envelope")
payload = deserialize_stream_payload(raw)
channel_id = payload.get("channel_id") or "global"
except Exception as exc:
# If payload cannot be deserialized, call direct handler to trigger normal logging and DLQ
await self._handle_message(msg_id, raw)
return
is_addressed = payload.get("is_addressed", True)
if not is_addressed:
await self._handle_message(msg_id, raw)
return
from core.distributed_lock import DistributedLock
lock = DistributedLock(
self._redis,
f"sg:lock:channel:{channel_id}",
ttl=60,
auto_renew=True,
max_renewals=0,
)
try:
logger.debug(
"Awaiting channel serialization lock",
extra={"stream_msg_id": msg_id, "channel_id": channel_id},
)
await lock.acquire_blocking()
await self._handle_message(msg_id, raw)
finally:
await lock.release()
async def _handle_message(self, msg_id: str, raw: dict) -> None:
"""Deserialize, process, acknowledge, and DLQ a single inbound entry.
The core per-message worker. It first runs an idempotency guard with
``XPENDING`` (``xpending_range``) on ``INBOUND_STREAM`` / ``WORKER_GROUP``:
if the entry is no longer in the pending list another worker already
acknowledged it, so it clears the local attempt counter and returns
without re-processing (on a guard failure it logs and falls through to
normal processing). It then increments the per-message attempt count in
``_attempt_counts``, deserializes the envelope and payload, and awaits the
injected ``process_fn`` to do the real work. On success it sends ``XACK``,
drops the attempt counter, and -- when a state machine and ``trace_id`` are
present -- records a ``COMPLETED`` trace transition. A ``CancelledError``
raised mid-process is honored, but if the current task was tagged
``_user_cancelled`` (a deliberate user stop) the message is ``XACK``-ed so
it is not retried. Any other exception is logged and routed to
:func:`core.dlq.handle_failed_message` for retry/quarantine, followed by an
``ERRORED`` trace transition. Reaches Redis for ``XPENDING`` and ``XACK``,
the DLQ, and trace hooks. Called by :meth:`_handle_message_task` (and
directly for the deserialize-failure and passive-message fast paths).
Args:
msg_id: Redis Streams entry ID of the message being processed.
raw: The raw field-value mapping read from the stream.
"""
try:
# Inbound idempotency check:
# Query Redis to verify if the message is still in the pending entries list (PEL).
# If a message has already been acknowledged (XACKed) by another worker, it will be gone from the PEL.
pending_info = await self._redis.xpending_range(
INBOUND_STREAM, WORKER_GROUP, msg_id, msg_id, 1
)
if not pending_info:
logger.info(
"Inbound message already acknowledged by another worker (idempotency check hit)",
extra={"stream_msg_id": msg_id},
)
self._attempt_counts.pop(msg_id, None)
return
except Exception as exc:
# If xpending_range fails for any reason, we fall back to normal processing
logger.warning(
"Failed to run inbound idempotency check, falling back to processing",
exc_info=True,
extra={"stream_msg_id": msg_id},
)
attempt = self._attempt_counts.get(msg_id, 0) + 1
self._attempt_counts[msg_id] = attempt
try:
from core.event_types import deserialize_envelope
env = deserialize_envelope(raw)
if env is None:
raise ValueError("Unsupported schema version in envelope")
payload = deserialize_stream_payload(raw)
logger.info(
"Processing inbound message",
extra={
"stream_msg_id": msg_id,
"channel_id": payload.get("channel_id"),
"platform": payload.get("platform"),
"user_id": payload.get("user_id"),
"attempt": attempt,
},
)
try:
await self._process_fn(payload)
except asyncio.CancelledError:
# Check if this task was cancelled by the user stop request
current_task = asyncio.current_task()
if current_task and getattr(current_task, "_user_cancelled", False):
logger.info(
"InboundStreamConsumer: Inbound message cancelled by user stop request; acknowledging message to prevent retry. msg_id=%s",
msg_id
)
try:
await self._redis.xack(INBOUND_STREAM, WORKER_GROUP, msg_id)
self._attempt_counts.pop(msg_id, None)
except Exception:
logger.exception("Failed to XACK user-cancelled message")
raise
await self._redis.xack(INBOUND_STREAM, WORKER_GROUP, msg_id)
self._attempt_counts.pop(msg_id, None)
logger.debug("XACK sent", extra={"stream_msg_id": msg_id})
# Hook: TraceState.COMPLETED
if self._state_machine and payload.get("trace_id"):
await self._state_machine.transition(payload["trace_id"], "COMPLETED")
except Exception as exc:
if isinstance(exc, TimeoutError) and "lock" in str(exc).lower():
from core.dlq import extract_stream_payload_bytes, extract_stream_aux_fields
try:
payload_bytes = extract_stream_payload_bytes(raw)
aux = extract_stream_aux_fields(raw)
ts_val = aux.get("ts", "")
trace_val = aux.get("trace_id", "")
channel_id = payload.get("channel_id") if 'payload' in locals() and payload else None
logger.warning(
"Lock collision detected (TimeoutError), requeueing message to stream",
extra={"stream_msg_id": msg_id, "channel_id": channel_id},
)
await self._redis.xadd(
INBOUND_STREAM,
{
b"data": payload_bytes,
b"ts": ts_val.encode("utf-8") if ts_val else b"",
b"trace_id": trace_val.encode("utf-8") if trace_val else b""
}
)
await self._redis.xack(INBOUND_STREAM, WORKER_GROUP, msg_id)
self._attempt_counts.pop(msg_id, None)
return
except Exception as requeue_err:
logger.error(
"Failed to requeue message after lock collision; falling back to standard DLQ path",
exc_info=True,
extra={"stream_msg_id": msg_id},
)
logger.error(
"Failed to process inbound message",
exc_info=True,
extra={"stream_msg_id": msg_id, "attempt": attempt},
)
await handle_failed_message(
self._redis, INBOUND_STREAM, WORKER_GROUP,
msg_id, raw, exc, attempt,
)
# Hook: TraceState.ERRORED
if self._state_machine:
try:
trace_id = b""
if isinstance(raw, dict):
trace_id = raw.get(b"trace_id", b"") or raw.get("trace_id", b"")
if trace_id:
trace_id_str = trace_id.decode() if isinstance(trace_id, bytes) else str(trace_id)
await self._state_machine.transition(trace_id_str, "ERRORED")
except Exception as e:
logger.debug("Failed to extract trace_id for hook on error: %s", e)
async def _autoclaim_loop(self) -> None:
"""Periodically claim orphaned messages from crashed consumers.
Uses XAUTOCLAIM to reassign messages that have been pending
for longer than autoclaim_min_idle milliseconds.
"""
while self._running:
try:
await asyncio.sleep(self._autoclaim_interval)
logger.debug(
"Executing background autoclaim sweep",
extra={
"stream": INBOUND_STREAM,
"group": WORKER_GROUP,
"min_idle_ms": self._autoclaim_min_idle,
},
)
result = await self._redis.xautoclaim(
INBOUND_STREAM,
WORKER_GROUP,
self._consumer_name,
min_idle_time=self._autoclaim_min_idle,
start_id="0-0",
count=10,
)
if result and len(result) >= 2:
claimed = result[1]
if claimed:
logger.warning(
"Autoclaimed %d orphaned messages",
len(claimed),
extra={"consumer_name": self._consumer_name},
)
for msg_id, raw in claimed:
msg_id_str = msg_id.decode() if isinstance(msg_id, bytes) else str(msg_id)
# Audit claim attempts globally using a Redis hash
attempts = await self._redis.hincrby("sg:autoclaim:attempts", msg_id_str, 1)
if attempts > 3:
logger.warning(
"Quarantining toxic message after exceeding max claim attempts",
extra={"stream_msg_id": msg_id_str, "attempts": attempts}
)
await handle_failed_message(
self._redis,
INBOUND_STREAM,
WORKER_GROUP,
msg_id_str,
raw,
ValueError(f"Max claim attempts exceeded: {attempts}"),
attempts
)
await self._redis.xack(INBOUND_STREAM, WORKER_GROUP, msg_id_str)
await self._redis.hdel("sg:autoclaim:attempts", msg_id_str)
continue
# Hook: TraceState.CLAIMED
if self._state_machine:
try:
trace_id = b""
if isinstance(raw, dict):
trace_id = raw.get(b"trace_id", b"") or raw.get("trace_id", b"")
if trace_id:
trace_id_str = trace_id.decode() if isinstance(trace_id, bytes) else str(trace_id)
asyncio.create_task(
self._state_machine.transition(trace_id_str, "CLAIMED", {"msg_id": msg_id_str})
)
except Exception as e:
logger.debug("Failed to extract trace_id for hook in autoclaim: %s", e)
# Concurrent claimed processing that respects locks
task = asyncio.create_task(
self._handle_message_task(msg_id_str, raw),
name=f"autoclaim_msg_{msg_id_str}",
)
self._active_tasks.add(task)
task.add_done_callback(self._active_tasks.discard)
except asyncio.CancelledError:
break
except Exception:
logger.error("Autoclaim loop error", exc_info=True)
[docs]
class OutboundStreamConsumer:
"""Consumes outbound messages intended for platform adapters.
Runs on gateway nodes. Reads from sg:stream:outbound:{platform} for
each configured platform.
"""
[docs]
def __init__(
self,
redis: Redis,
consumer_name: str,
process_fn: Callable[[dict[str, Any]], Awaitable[None]],
platforms: list[str],
) -> None:
"""Configure an outbound consumer for one or more platform streams.
Stores the async Redis client, this gateway's unique consumer name
(within the ``stargazer_gateways`` consumer group used by
:meth:`_consume_loop` / :meth:`_handle_message`), and the async
``process_fn`` that hands a deserialized outbound payload to the platform
adapter. From ``platforms`` it precomputes ``_streams``, a mapping of
``sg:stream:outbound:{platform}`` to the ``">"`` read cursor for each
configured platform, which is passed straight to ``XREADGROUP``. The
running flag, loop task, and ``_active_tasks`` set are initialised so the
consumer is idle until :meth:`start` runs; no Redis I/O happens here.
Constructed by ``gateway_main.py`` (the gateway service) and by tests
under ``tests/core`` and ``tests/integration``.
Args:
redis: Async Redis client used for ``XREADGROUP`` and ``XACK`` on the
outbound platform streams.
consumer_name: Unique consumer name within the ``stargazer_gateways``
group.
process_fn: Async callback invoked with the deserialized outbound
payload dict, which routes the send to the platform adapter.
platforms: Platform identifiers (e.g. ``"discord"``) whose outbound
streams this consumer should read; an empty list causes
:meth:`start` to no-op with a warning.
"""
self._redis = redis
self._consumer_name = consumer_name
self._process_fn = process_fn
self._platforms = platforms
self._streams = {f"sg:stream:outbound:{p}": ">" for p in platforms}
self._running = False
self._task: asyncio.Task | None = None
self._active_tasks: set[asyncio.Task] = set()
[docs]
async def start(self) -> None:
"""Start the outbound consume loop unless no platform streams exist.
Activation point for the gateway-side consumer. If ``_streams`` is empty
(no platforms were configured) it logs a warning and returns without
doing anything, since there is nothing to read. Otherwise it flips
``_running`` true and spawns the :meth:`_consume_loop` task that block-reads
the ``sg:stream:outbound:{platform}`` streams; the task reference is kept on
``self`` for :meth:`stop` to cancel. Unlike the inbound consumer there is
no autoclaim daemon. Called by the gateway in ``gateway_main.py`` -- and
importantly before each platform ``adapter.start()`` so responses are never
missed -- and by tests under ``tests/core``.
"""
if not self._streams:
logger.warning("No platforms provided, OutboundStreamConsumer will not start")
return
self._running = True
self._task = asyncio.create_task(self._consume_loop(), name="outbound_consumer")
logger.info(
"OutboundStreamConsumer started",
extra={"consumer_name": self._consumer_name, "streams": list(self._streams.keys())},
)
[docs]
async def stop(self) -> None:
"""Gracefully shut down the outbound consumer and drain in-flight sends.
Clears ``_running`` so the loop stops scheduling work, cancels the
consume-loop task launched by :meth:`start`, and then awaits any
per-message send tasks still tracked in ``_active_tasks`` (gathering
exceptions) so outbound deliveries already underway complete their
adapter call and ``XACK`` before the gateway exits. Lifecycle counterpart
to :meth:`start`, invoked during graceful gateway shutdown in
``gateway_main.py`` and by tests; reaches Redis only via the tasks it
drains.
"""
self._running = False
if self._task:
self._task.cancel()
if self._active_tasks:
logger.info("Awaiting %d active outbound tasks...", len(self._active_tasks))
await asyncio.gather(*self._active_tasks, return_exceptions=True)
logger.info("OutboundStreamConsumer stopped")
async def _consume_loop(self) -> None:
"""Block-read outbound messages and fan them out to per-message sends.
Hot path of the gateway consumer. While running it issues a blocking
``XREADGROUP`` under the ``stargazer_gateways`` group across every
``sg:stream:outbound:{platform}`` stream in ``_streams``, reading up to
ten entries per poll. For each entry it spawns a :meth:`_handle_message`
task -- passing the originating ``stream_name`` so the right stream is
acknowledged -- and tracks it in ``_active_tasks`` so multiple sends run
concurrently. ``CancelledError`` breaks the loop for shutdown; any other
exception is logged at critical level before a two-second backoff and
retry so a transient Redis fault never permanently stops delivery.
Launched as a background task by :meth:`start`; not called elsewhere.
"""
while self._running:
try:
messages = await self._redis.xreadgroup(
"stargazer_gateways",
self._consumer_name,
self._streams,
count=10,
block=5000,
)
if not messages:
continue
for stream_name, entries in messages:
for msg_id, raw in entries:
msg_id_str = msg_id.decode() if isinstance(msg_id, bytes) else str(msg_id)
task = asyncio.create_task(
self._handle_message(msg_id_str, raw, stream_name),
name=f"outbound_msg_{msg_id_str}",
)
self._active_tasks.add(task)
task.add_done_callback(self._active_tasks.discard)
except asyncio.CancelledError:
break
except Exception:
logger.critical(
"Outbound consumer loop crashed, restarting in 2s",
exc_info=True,
)
await asyncio.sleep(2)
async def _handle_message(self, msg_id: str, raw: dict, stream_name: bytes | str) -> None:
"""Deliver one outbound message to the platform adapter, then acknowledge.
Per-message worker for the gateway. It normalizes ``stream_name`` to a
string, deserializes the envelope and outbound payload, and awaits the
injected ``process_fn`` -- which routes the send to the platform adapter
in ``gateway_main.py`` -- then sends ``XACK`` to the originating outbound
stream under the ``stargazer_gateways`` group so the entry is not
redelivered. On any failure it logs the error and forwards the entry to
:func:`core.dlq.handle_failed_message` (attempt count fixed at 1, since
the outbound consumer does not track per-message retries). Reaches Redis
for the ``XACK`` and DLQ. Spawned as a task by :meth:`_consume_loop`.
Args:
msg_id: Redis Streams entry ID of the outbound message.
raw: The raw field-value mapping read from the stream.
stream_name: Outbound stream the entry came from, used as the ``XACK``
target; accepted as bytes or str.
"""
stream_name_str = stream_name.decode() if isinstance(stream_name, bytes) else str(stream_name)
try:
from core.event_types import deserialize_envelope
env = deserialize_envelope(raw)
if env is None:
raise ValueError("Unsupported schema version in envelope")
payload = deserialize_stream_payload(raw)
# The outbound payload structure handles method routing inside `gateway_main.py`
await self._process_fn(payload)
await self._redis.xack(stream_name_str, "stargazer_gateways", msg_id)
except Exception as exc:
logger.error(
"Failed to process outbound message",
exc_info=True,
extra={"stream_msg_id": msg_id},
)
await handle_failed_message(
self._redis, stream_name_str, "stargazer_gateways",
msg_id, raw, exc, 1,
)