# Source code for core.tools_consumer

"""Tool-execution stream consumer for the dedicated ``tools`` service.

``ToolExecConsumer`` mirrors :class:`core.stream_consumer.InboundStreamConsumer`
— blocking ``XREADGROUP`` on ``sg:stream:tools`` under the ``sg:tools`` group,
an autoclaim sweep, and DLQ on repeated failure — but with two deliberate
differences:

* **No per-channel distributed lock.** Tool execution must run fully in parallel
  and load-balance across ``tools`` instances; user-visible message ordering is
  already enforced downstream by the gateway's per-channel outbound lock.
* The injected ``process_fn`` is expected to **always reply** (it pushes the
  result, or an error envelope, onto the caller's reply stream) and to swallow
  tool-level exceptions. Only infrastructure failures propagate to the DLQ.
"""

from __future__ import annotations

import asyncio
import logging
from typing import Any, Awaitable, Callable

from redis.asyncio import Redis

from core.dlq import handle_failed_message
from core.event_bus import TOOLS_GROUP, TOOLS_STREAM
from core.serialization import deserialize_stream_payload

logger = logging.getLogger(name="stargazer.tools_consumer")

# Redis hash holding, per stream message ID, how many times the autoclaim
# sweep has reclaimed that tool-exec request; past the limit the request is
# quarantined instead of being retried again.
_AUTOCLAIM_ATTEMPTS_KEY = "sg:autoclaim:tools:attempts"

# Other consumers that have been idle for at least this many milliseconds
# while holding no pending entries are assumed to be registrations left behind
# by crashed workers, and are deleted from the group. The value sits well above
# the autoclaim idle window, so a live worker that merely paused is never hit.
_REAP_IDLE_MS = 10 * 60 * 1_000  # ten minutes


class ToolExecConsumer:
    """Consumes tool-execution requests and dispatches them to *process_fn*.

    One instance reads ``TOOLS_STREAM`` as consumer *consumer_name* of
    ``TOOLS_GROUP``, runs every entry in its own task (no per-channel lock),
    and periodically sweeps the group with ``XAUTOCLAIM`` to adopt requests
    left pending by other consumers.

    Args:
        redis: Async Redis client used for all stream / hash commands.
        consumer_name: This instance's consumer name inside ``TOOLS_GROUP``.
        process_fn: Coroutine function awaited with each deserialized payload.
        autoclaim_interval: Seconds to sleep between reap/autoclaim sweeps.
        autoclaim_min_idle: Minimum pending idle time (passed to XAUTOCLAIM as
            ``min_idle_time``, i.e. milliseconds) before an entry is reclaimed.

    NOTE(review): nothing here refreshes a pending entry's idle time while
    *process_fn* runs, so a request running longer than *autoclaim_min_idle*
    may be reclaimed and executed a second time by another instance. Confirm
    tool runtimes stay under that bound.
    """

    def __init__(
        self,
        redis: Redis,
        consumer_name: str,
        process_fn: Callable[[dict[str, Any]], Awaitable[None]],
        autoclaim_interval: float = 30.0,
        autoclaim_min_idle: int = 60000,
    ) -> None:
        self._redis = redis
        self._consumer_name = consumer_name
        self._process_fn = process_fn
        self._autoclaim_interval = autoclaim_interval
        self._autoclaim_min_idle = autoclaim_min_idle
        self._running = False
        self._task: asyncio.Task | None = None
        self._autoclaim_task: asyncio.Task | None = None
        # Handling attempts made by *this* instance, keyed by stream message ID.
        # NOTE(review): if handle_failed_message acks/dead-letters a request,
        # its count is only popped if _handle runs for it again here — confirm
        # whether this needs bounding.
        self._attempt_counts: dict[str, int] = {}
        # Strong references to in-flight handler tasks (asyncio only keeps
        # weak references to tasks); also what stop() drains.
        self._active_tasks: set[asyncio.Task] = set()
        # Resume point for the next XAUTOCLAIM sweep; "0-0" = start of the PEL.
        self._autoclaim_cursor: str = "0-0"

    @staticmethod
    def _as_str(value: Any) -> str:
        """Return a stream ID given as ``bytes`` (or anything else) as ``str``."""
        return value.decode() if isinstance(value, bytes) else str(value)

    @staticmethod
    def _field(info: dict, key: str) -> Any:
        """Read *key* from an ``XINFO CONSUMERS`` entry.

        Both a ``bytes`` and a ``str`` key are tried, and a ``bytes`` value is
        decoded; returns ``None`` when neither key yields a truthy value.
        """
        value = info.get(key.encode()) or info.get(key)
        return value.decode() if isinstance(value, bytes) else value

    def _spawn(self, msg_id: str, raw: dict, name: str, *, reclaimed: bool = False) -> None:
        """Run :meth:`_handle` for one entry in a tracked background task."""
        task = asyncio.create_task(self._handle(msg_id, raw, reclaimed=reclaimed), name=name)
        self._active_tasks.add(task)
        task.add_done_callback(self._active_tasks.discard)

    async def start(self) -> None:
        """Start the consume loop and the autoclaim loop as background tasks."""
        self._running = True
        self._task = asyncio.create_task(self._consume_loop(), name="tools_consumer")
        self._autoclaim_task = asyncio.create_task(
            self._autoclaim_loop(), name="tools_autoclaim"
        )
        logger.info(
            "ToolExecConsumer started",
            extra={"consumer_name": self._consumer_name, "stream": TOOLS_STREAM},
        )

    async def stop(self) -> None:
        """Stop consuming, drain in-flight handlers, then deregister if idle.

        Both loop tasks are cancelled *and awaited* before the in-flight set is
        drained, so neither loop can spawn a new handler once the drain starts.
        """
        self._running = False
        loops = [t for t in (self._task, self._autoclaim_task) if t is not None]
        for loop_task in loops:
            loop_task.cancel()
        if loops:
            await asyncio.gather(*loops, return_exceptions=True)
        if self._active_tasks:
            logger.info("Awaiting %d active tool-exec tasks...", len(self._active_tasks))
            await asyncio.gather(*self._active_tasks, return_exceptions=True)
        # Drained: remove our own consumer from the group so it does not linger
        # as a phantom entry (skipped if anything is still pending, so no message
        # is dropped — autoclaim on another worker will reclaim it).
        await self._delconsumer_self()
        logger.info("ToolExecConsumer stopped")

    async def _consumer_pending(self, name: str) -> int:
        """Pending count for a named consumer, or -1 if unknown."""
        try:
            for c in await self._redis.xinfo_consumers(TOOLS_STREAM, TOOLS_GROUP):
                if self._field(c, "name") == name:
                    return int(self._field(c, "pending") or 0)
        except Exception:
            logger.debug("xinfo_consumers failed", exc_info=True)
        return -1

    async def _delconsumer_self(self) -> None:
        """Delete this consumer from the group, but only if it has 0 pending."""
        pending = await self._consumer_pending(self._consumer_name)
        if pending == 0:
            try:
                await self._redis.xgroup_delconsumer(TOOLS_STREAM, TOOLS_GROUP, self._consumer_name)
                logger.info(
                    "Removed own consumer %s from %s on shutdown",
                    self._consumer_name,
                    TOOLS_GROUP,
                )
            except Exception:
                logger.debug("delconsumer self failed", exc_info=True)
        elif pending > 0:
            logger.warning(
                "Leaving consumer %s in %s on shutdown: %d pending (autoclaim will reclaim)",
                self._consumer_name,
                TOOLS_GROUP,
                pending,
            )
        # pending == -1 (lookup failed): leave the registration alone.

    async def _reap_idle_consumers(self) -> None:
        """Delete other consumers idle past the threshold with no pending — the
        leftover registrations of crashed workers that never ran shutdown."""
        try:
            consumers = await self._redis.xinfo_consumers(TOOLS_STREAM, TOOLS_GROUP)
        except Exception:
            logger.debug("xinfo_consumers failed; skipping reap", exc_info=True)
            return
        for c in consumers:
            name = self._field(c, "name")
            if name == self._consumer_name:
                continue
            idle = int(self._field(c, "idle") or 0)
            pending = int(self._field(c, "pending") or 0)
            if idle >= _REAP_IDLE_MS and pending == 0:
                try:
                    await self._redis.xgroup_delconsumer(TOOLS_STREAM, TOOLS_GROUP, name)
                    logger.info(
                        "Reaped stale tools consumer %s (idle %dms, 0 pending)", name, idle
                    )
                except Exception:
                    logger.debug("reap delconsumer failed for %s", name, exc_info=True)

    async def _consume_loop(self) -> None:
        """Read new entries with ``XREADGROUP`` and spawn a handler for each."""
        while self._running:
            try:
                messages = await self._redis.xreadgroup(
                    TOOLS_GROUP,
                    self._consumer_name,
                    {TOOLS_STREAM: ">"},
                    count=10,
                    block=5000,
                )
                if not messages:
                    continue
                for _stream_name, entries in messages:
                    for msg_id, raw in entries:
                        msg_id_str = self._as_str(msg_id)
                        self._spawn(msg_id_str, raw, f"tool_exec_{msg_id_str}")
            except asyncio.CancelledError:
                break
            except Exception:
                logger.critical("Tools consumer loop crashed, restarting in 2s", exc_info=True)
                await asyncio.sleep(2)

    async def _clear_autoclaim_attempts(self, msg_id: str) -> None:
        """Best-effort removal of *msg_id* from the autoclaim attempts hash."""
        try:
            await self._redis.hdel(_AUTOCLAIM_ATTEMPTS_KEY, msg_id)
        except Exception:
            logger.debug("autoclaim attempts cleanup failed for %s", msg_id, exc_info=True)

    async def _handle(self, msg_id: str, raw: dict, reclaimed: bool = False) -> None:
        """Process one request: ack on success, hand failures to the DLQ helper.

        Args:
            msg_id: Stream entry ID as ``str``.
            raw: Raw stream fields, passed to ``deserialize_stream_payload``.
            reclaimed: True when the entry came from the autoclaim sweep, i.e.
                it may have a counter in ``_AUTOCLAIM_ATTEMPTS_KEY`` to clear.
        """
        # Idempotency: if no longer pending, another instance handled it.
        try:
            pending = await self._redis.xpending_range(
                TOOLS_STREAM, TOOLS_GROUP, msg_id, msg_id, 1
            )
        except Exception:
            logger.debug("tools idempotency check failed; processing", exc_info=True)
        else:
            if not pending:
                self._attempt_counts.pop(msg_id, None)
                if reclaimed:
                    await self._clear_autoclaim_attempts(msg_id)
                return

        attempt = self._attempt_counts.get(msg_id, 0) + 1
        self._attempt_counts[msg_id] = attempt
        try:
            payload = deserialize_stream_payload(raw)
            await self._process_fn(payload)
            await self._redis.xack(TOOLS_STREAM, TOOLS_GROUP, msg_id)
        except Exception as exc:
            logger.error(
                "Failed to process tool-exec request",
                exc_info=True,
                extra={"stream_msg_id": msg_id, "attempt": attempt},
            )
            await handle_failed_message(
                self._redis, TOOLS_STREAM, TOOLS_GROUP, msg_id, raw, exc, attempt
            )
            return
        self._attempt_counts.pop(msg_id, None)
        # Acked entries are never reclaimed again; drop their claim counter so
        # the attempts hash does not grow without bound.
        if reclaimed:
            await self._clear_autoclaim_attempts(msg_id)

    async def _autoclaim_loop(self) -> None:
        """Periodically reap dead consumers and reclaim orphaned requests.

        Each reclaimed entry's claim count is tracked in
        ``_AUTOCLAIM_ATTEMPTS_KEY``; once it exceeds 3 the entry is sent to
        ``handle_failed_message`` and acked instead of being re-run.
        """
        while self._running:
            try:
                await asyncio.sleep(self._autoclaim_interval)
                await self._reap_idle_consumers()
                result = await self._redis.xautoclaim(
                    TOOLS_STREAM,
                    TOOLS_GROUP,
                    self._consumer_name,
                    min_idle_time=self._autoclaim_min_idle,
                    start_id=self._autoclaim_cursor,
                    count=10,
                )
                if not (result and len(result) >= 2):
                    continue
                # XAUTOCLAIM scans a bounded slice of the PEL per call and returns
                # where it stopped ("0-0" after a full pass); resume from there so
                # entries deep in the PEL are not starved.
                self._autoclaim_cursor = self._as_str(result[0]) if result[0] else "0-0"
                claimed = result[1]
                if not claimed:
                    continue
                logger.warning("Autoclaimed %d orphaned tool-exec requests", len(claimed))
                for msg_id, raw in claimed:
                    msg_id_str = self._as_str(msg_id)
                    attempts = await self._redis.hincrby(_AUTOCLAIM_ATTEMPTS_KEY, msg_id_str, 1)
                    if attempts > 3:
                        logger.warning(
                            "Quarantining toxic tool-exec request after %d attempts",
                            attempts,
                            extra={"stream_msg_id": msg_id_str},
                        )
                        await handle_failed_message(
                            self._redis,
                            TOOLS_STREAM,
                            TOOLS_GROUP,
                            msg_id_str,
                            raw,
                            ValueError(f"Max claim attempts exceeded: {attempts}"),
                            attempts,
                        )
                        await self._redis.xack(TOOLS_STREAM, TOOLS_GROUP, msg_id_str)
                        await self._redis.hdel(_AUTOCLAIM_ATTEMPTS_KEY, msg_id_str)
                        continue
                    self._spawn(
                        msg_id_str, raw, f"tool_exec_reclaim_{msg_id_str}", reclaimed=True
                    )
            except asyncio.CancelledError:
                break
            except Exception:
                logger.error("Tools autoclaim loop error", exc_info=True)