# Source code for core.remote_tool_registry

"""Inference-side stand-in for ``ToolRegistry`` that delegates tool execution.

``RemoteToolRegistry`` duck-types the read surface of :class:`tools.ToolRegistry`
(OpenAI schemas, ``tool_names``, ``repeat_allowed_tools``, ``is_allowed`` …) from
the Redis-published tool catalog, so the inference tier can present tools to the
LLM and rank them without importing any handler module. Its ``call()`` routes
each invocation through three tiers:

1. ``GATEWAY_PINNED_TOOLS``   -> delegate to the gateway (live discord.py client).
2. ``INFERENCE_PINNED_TOOLS`` (and nested calls, and ``tools_force_in_process``)
   -> run in-process via the small local registry the inference service loads.
3. everything else            -> delegate to the dedicated ``tools`` service over
   ``sg:stream:tools``; the result + write-back (``sent_files`` media bytes,
   ``injected_tools``, ``sent_rich_messages``) come back on a per-worker reply
   stream and are merged onto the live ``ctx`` so the LLM loop is unchanged.

In ``tools_service_mode == "in_process"`` (the safe default) every call runs on
the local registry, exactly as before the split. The remote path is opt-in.

Drop-in: the executor and transport only touch the registry's read surface plus
``call()``, so swapping ``ToolRegistry`` for this requires no changes there.
"""

from __future__ import annotations

import asyncio
import logging
import uuid
from typing import Any, Optional

from core.event_bus import RedisEventBus, TOOLS_REPLY_STREAM_PREFIX
from core.gateway_pinned import GATEWAY_PINNED_TOOLS, INFERENCE_PINNED_TOOLS
from core.serialization import deserialize_stream_payload
from core.tool_catalog import (
    CATALOG_PUBSUB_CHANNEL,
    load_catalog,
    openai_tool_dict,
)

logger = logging.getLogger("stargazer.remote_tool_registry")


class RemoteToolRegistry:
    """Catalog-backed, delegation-capable replacement for ``ToolRegistry``.

    Args:
        event_bus: bus used to publish tool-exec requests.
        redis: raw (``decode_responses=False``) async client — used to read the
            catalog and consume the per-worker reply stream.
        config: live :class:`config.Config` (routing mode, timeouts, overrides).
        local_registry: a real :class:`tools.ToolRegistry` for in-process tiers
            (gateway-pinned delegation needs ``ctx.adapter``; inference-pinned +
            ``in_process`` mode run handlers here). During rollout it may hold
            every tool; in the lean end state only ``INFERENCE_PINNED_TOOLS``.
        worker_id: stable id for this inference worker (reply-stream name).
    """

    def __init__(
        self,
        *,
        event_bus: RedisEventBus,
        redis: Any,
        config: Any,
        local_registry: Any,
        worker_id: str,
    ) -> None:
        self._bus = event_bus
        self._redis = redis
        self._cfg = config
        self._local = local_registry
        self._worker_id = worker_id
        self._reply_stream = f"{TOOLS_REPLY_STREAM_PREFIX}:{worker_id}"
        self._reply_group = f"{self._reply_stream}:grp"

        # Catalog-derived read surface (None until loaded; falls back to local).
        self._catalog: Optional[dict[str, Any]] = None
        self._by_name: dict[str, dict[str, Any]] = {}
        self._perms: dict[str, list[str]] = {}
        self._repeat: frozenset[str] = frozenset()
        self._schema_hash: str = ""

        # correlation_id -> future resolved by _reply_reader_loop.
        self._pending: dict[str, asyncio.Future] = {}
        self._tasks: list[asyncio.Task] = []
        self._running = False

    # ── lifecycle ──────────────────────────────────────────────────

    async def start(self) -> None:
        """Load the catalog and start the reply reader + catalog watcher.

        In ``in_process`` mode this is a no-op: no catalog is read (the read
        surface stays delegated to the local registry) and no background tasks
        are spawned, so the inference worker behaves exactly as before the split.
        """
        self._running = True
        if getattr(self._cfg, "tools_service_mode", "in_process") == "in_process":
            return
        await self.reload_catalog()
        try:
            await self._redis.xgroup_create(
                self._reply_stream, self._reply_group, id="$", mkstream=True
            )
        except Exception as exc:
            # BUSYGROUP just means the group already exists (normal on restart);
            # anything else (e.g. connection trouble) should not vanish silently.
            if "BUSYGROUP" not in str(exc):
                logger.warning(
                    "Could not create tool reply group %s: %s", self._reply_group, exc
                )
        self._tasks = [
            asyncio.create_task(self._reply_reader_loop(), name=f"tools-reply-{self._worker_id}"),
            asyncio.create_task(self._catalog_watch_loop(), name="tools-catalog-watch"),
        ]

    async def stop(self) -> None:
        """Cancel background tasks and fail any in-flight remote calls.

        Pending reply futures are failed with ``RuntimeError`` rather than
        cancelled. Cancelling a future that another task awaits would surface
        as ``CancelledError`` in that unrelated caller task, and ``call()``'s
        ``except Exception`` cannot handle it. An ordinary exception lets
        ``call()`` take its normal fallback / "unavailable" path.
        """
        self._running = False
        for t in self._tasks:
            t.cancel()
        for t in self._tasks:
            try:
                await t
            except (asyncio.CancelledError, Exception):
                pass
        self._tasks = []
        for fut in self._pending.values():
            if not fut.done():
                fut.set_exception(RuntimeError("RemoteToolRegistry stopped"))
        self._pending.clear()

    async def reload_catalog(self) -> bool:
        """(Re)load the catalog from Redis. Returns True if a catalog was found."""
        cat = await load_catalog(self._redis)
        if cat is None:
            logger.warning("No tool catalog published yet; using local registry read surface.")
            self._catalog = None
            return False
        self._catalog = cat
        self._by_name = {t["name"]: t for t in cat.get("tools", [])}
        self._perms = {k: list(v) for k, v in (cat.get("permissions") or {}).items()}
        self._repeat = frozenset(
            t["name"] for t in cat.get("tools", []) if t.get("allow_repeat")
        )
        self._schema_hash = cat.get("schema_hash", "")
        # Keep the local registry's permission view in sync for in-process tiers.
        try:
            self._local.set_permissions(self._perms)
        except Exception:
            logger.debug("local registry permission sync failed", exc_info=True)
        logger.info(
            "Loaded tool catalog v%s (%d tools)", cat.get("version"), len(self._by_name)
        )
        return True

    # ── read surface (duck-types ToolRegistry) ─────────────────────

    @property
    def _use_catalog(self) -> bool:
        """True once a catalog has been loaded; otherwise reads go to the local registry."""
        return self._catalog is not None

    def get_openai_tools(self) -> list[dict[str, Any]]:
        """OpenAI tool schemas for every known tool."""
        if self._use_catalog:
            return [openai_tool_dict(t) for t in self._by_name.values()]
        return self._local.get_openai_tools()

    def get_openai_tools_by_names(self, names: set[str]) -> list[dict[str, Any]]:
        """OpenAI tool schemas for the subset of ``names`` that exist (unknown names skipped)."""
        if self._use_catalog:
            return [openai_tool_dict(self._by_name[n]) for n in names if n in self._by_name]
        return self._local.get_openai_tools_by_names(names)

    def tool_names(self) -> frozenset[str]:
        """Names of all known tools."""
        if self._use_catalog:
            return frozenset(self._by_name.keys())
        return self._local.tool_names()

    def repeat_allowed_tools(self) -> frozenset[str]:
        """Names of tools whose catalog entry sets ``allow_repeat``."""
        if self._use_catalog:
            return self._repeat
        return self._local.repeat_allowed_tools()

    def list_tools(self) -> list[Any]:
        """``ToolDefinition`` objects; catalog-backed ones carry a handler that raises."""
        if self._use_catalog:
            from tools import ToolDefinition

            def _stub(*_a: Any, **_k: Any) -> Any:
                raise RuntimeError("RemoteToolRegistry tool handlers run remotely; use call().")

            return [
                ToolDefinition(
                    name=t["name"],
                    description=t.get("description", ""),
                    parameters=t.get("parameters", {}),
                    handler=_stub,
                    no_background=bool(t.get("no_background", False)),
                    allow_repeat=bool(t.get("allow_repeat", False)),
                )
                for t in self._by_name.values()
            ]
        return self._local.list_tools()

    def set_permissions(self, permissions: dict[str, list[str]]) -> None:
        """Replace the permission map here and (best effort) on the local registry."""
        self._perms = dict(permissions)
        try:
            self._local.set_permissions(permissions)
        except Exception:
            logger.debug("local registry permission sync failed", exc_info=True)

    def is_allowed(self, tool_name: str, user_id: str) -> bool:
        """True when the tool has no entry, its entry contains ``"*"``, or lists ``user_id``."""
        allowed = self._perms.get(tool_name)
        if allowed is None:
            return True
        if "*" in allowed:
            return True
        return user_id in allowed

    @property
    def has_tools(self) -> bool:
        """True when at least one tool is known."""
        if self._use_catalog:
            return len(self._by_name) > 0
        return self._local.has_tools

    def __len__(self) -> int:
        if self._use_catalog:
            return len(self._by_name)
        return len(self._local)

    # TaskManager passthrough so background-offload semantics survive locally.
    @property
    def task_manager(self) -> Any:
        return getattr(self._local, "task_manager", None)

    @task_manager.setter
    def task_manager(self, value: Any) -> None:
        try:
            self._local.task_manager = value
        except Exception:
            pass

    # ── execution: three-tier routing ──────────────────────────────

    async def call(
        self,
        name: str,
        arguments: dict[str, Any],
        user_id: str = "",
        ctx: Any = None,
        *,
        nested: bool = False,
    ) -> str:
        """Run tool ``name`` on the appropriate tier and return its string result.

        Order: permission fast-fail; nested calls and ``in_process`` mode run
        locally; gateway-pinned tools go to the gateway; inference-pinned or
        ``tools_force_in_process`` tools run locally; everything else goes to
        the tools service. If the remote path raises, the call falls back to
        local execution (when ``tools_local_fallback`` is on and the local
        registry has the tool). Otherwise it returns an "unavailable" message.
        """
        mode = getattr(self._cfg, "tools_service_mode", "in_process")
        force_local = name in set(getattr(self._cfg, "tools_force_in_process", []) or [])

        # Permission fast-fail from the catalog (UX only; the authoritative check
        # is re-run service-side). Mirrors ToolRegistry.call's denial string.
        if user_id and not self.is_allowed(name, user_id):
            return (
                f"Permission denied: user '{user_id}' is not allowed to run tool '{name}'."
            )

        # Nested inner calls (compound tools) and in_process mode always run on
        # the local registry, which itself handles gateway-pinned delegation.
        if nested or mode == "in_process":
            return await self._local.call(name, arguments, user_id=user_id, ctx=ctx, nested=nested)

        # Remote / shadow mode.
        if name in GATEWAY_PINNED_TOOLS:
            return await self._delegate_to_gateway(name, arguments, ctx)
        if name in INFERENCE_PINNED_TOOLS or force_local:
            return await self._local.call(name, arguments, user_id=user_id, ctx=ctx, nested=nested)

        try:
            return await self._remote_call(name, arguments, user_id, ctx)
        except Exception as exc:
            logger.warning("Remote tool exec failed for %r: %s", name, exc)
            # NOTE(review): a timeout does not prove the tools service did not run
            # the tool, so this fallback may execute it twice — confirm the
            # affected tools are idempotent (the envelope carries an idem_key).
            if getattr(self._cfg, "tools_local_fallback", True) and name in self._local.tool_names():
                logger.info("Falling back to in-process execution for %r", name)
                return await self._local.call(name, arguments, user_id=user_id, ctx=ctx, nested=nested)
            return f"Tool '{name}' is temporarily unavailable (tools service error). Try again."

    async def _delegate_to_gateway(self, name: str, arguments: dict[str, Any], ctx: Any) -> str:
        """Run a gateway-pinned tool via ``ctx.adapter.delegate_to_gateway``.

        Replicates ToolRegistry.call's gateway-pinned delegation. Returns an
        explanatory string when ``ctx`` has no adapter delegate.
        """
        delegate = getattr(getattr(ctx, "adapter", None), "delegate_to_gateway", None)
        if delegate is None:
            return f"Tool '{name}' requires the gateway client, which is unavailable."
        return await delegate(
            "execute_tool",
            tool_name=name,
            tool_args=arguments,
            tool_ctx={
                "platform": ctx.platform,
                "channel_id": ctx.channel_id,
                "user_id": ctx.user_id,
                "user_name": ctx.user_name,
                "guild_id": ctx.guild_id,
                "message_id": ctx.message_id,
            },
        )

    async def _remote_call(
        self,
        name: str,
        arguments: dict[str, Any],
        user_id: str,
        ctx: Any,
        *,
        _retry_on_stale: bool = True,
    ) -> str:
        """Execute ``name`` on the tools service and merge its write-back onto ``ctx``.

        A ``CATALOG_STALE`` reply triggers one catalog reload and exactly one
        retry with the fresh schema hash. A second stale reply raises instead
        of recursing forever, e.g. when no catalog is published at all.

        Returns:
            The reply's ``error`` string if set, else its ``result`` as a string.

        Raises:
            RuntimeError: the service still reports a stale catalog after the retry.
            asyncio.TimeoutError: no reply within ``tools_exec_timeout`` seconds.
        """
        reply = await self._send_tool_request(name, arguments, user_id, ctx)

        if reply.get("status") == "CATALOG_STALE":
            if not _retry_on_stale:
                raise RuntimeError(
                    f"tools service still reports a stale catalog for {name!r} after reload"
                )
            await self.reload_catalog()
            return await self._remote_call(
                name, arguments, user_id, ctx, _retry_on_stale=False
            )

        if reply.get("error"):
            return str(reply["error"])
        self._merge_writeback(ctx, reply.get("writeback") or {})
        return str(reply.get("result", ""))

    async def _send_tool_request(
        self, name: str, arguments: dict[str, Any], user_id: str, ctx: Any
    ) -> dict[str, Any]:
        """Publish one ``tool_exec`` request and wait for the matching reply payload.

        Each attempt uses a fresh correlation id. The future is registered
        before publishing so ``_reply_reader_loop`` can resolve it, and it is
        always deregistered afterwards.
        """
        correlation_id = uuid.uuid4().hex
        trace_id = getattr(ctx, "observability_request_id", "") or correlation_id
        envelope = {
            "type": "tool_exec",
            "tool_name": name,
            "tool_args": arguments,
            "user_id": user_id,
            "correlation_id": correlation_id,
            "idem_key": f"{trace_id}:{correlation_id}",
            "reply_to": self._reply_stream,
            "schema_hash": self._schema_hash,
            "ctx": self._ctx_payload(ctx, trace_id),
        }

        # Write the AUTHENTICATED identity to a session record keyed by trace_id
        # BEFORE publishing, so the tools service resolves identity from there
        # rather than trusting the envelope (anti-privilege-escalation).
        from core.tool_session import write_session

        await write_session(
            self._redis,
            trace_id,
            {
                "user_id": getattr(ctx, "user_id", "") if ctx is not None else user_id,
                "guild_id": getattr(ctx, "guild_id", "") if ctx is not None else "",
                "channel_id": getattr(ctx, "channel_id", "") if ctx is not None else "",
                "platform": getattr(ctx, "platform", "") if ctx is not None else "",
            },
        )

        fut: asyncio.Future = asyncio.get_running_loop().create_future()
        self._pending[correlation_id] = fut
        timeout = float(getattr(self._cfg, "tools_exec_timeout", 120.0))
        try:
            await self._bus.publish_tools_request(envelope)
            return await asyncio.wait_for(fut, timeout=timeout)
        finally:
            self._pending.pop(correlation_id, None)

    @staticmethod
    def _ctx_payload(ctx: Any, trace_id: str) -> dict[str, Any]:
        """Build the ``ctx`` section of the envelope (only ``trace_id`` when ctx is None)."""
        if ctx is None:
            return {"trace_id": trace_id}
        return {
            "platform": ctx.platform,
            "channel_id": ctx.channel_id,
            "user_id": ctx.user_id,
            "user_name": ctx.user_name,
            "guild_id": ctx.guild_id,
            "message_id": ctx.message_id,
            "observability_request_id": ctx.observability_request_id,
            "disclosed_skill_ids": list(ctx.disclosed_skill_ids or []),
            "room_context": ctx.room_context,
            "trace_id": trace_id,
        }

    @staticmethod
    def _merge_writeback(ctx: Any, wb: dict[str, Any]) -> None:
        """Append the remote call's write-back channels onto the live ctx.

        Mirrors what the in-process tool would have mutated, so the executor's
        post-gather reads (sent_files -> multimodal parts, injected_tools ->
        active tool list) and generate_and_send's sent_rich_messages all work.
        """
        if ctx is None:
            return
        sent_files = wb.get("sent_files") or []
        if sent_files:
            ctx.sent_files.extend(sent_files)
        injected = wb.get("injected_tools") or []
        if injected:
            if ctx.injected_tools is None:
                ctx.injected_tools = []
            ctx.injected_tools.extend(injected)
        rich = wb.get("sent_rich_messages") or []
        if rich:
            ctx.sent_rich_messages.extend(rich)

    # ── background loops ───────────────────────────────────────────

    async def _reply_reader_loop(self) -> None:
        """Consume this worker's reply stream and resolve matching pending futures.

        Every entry is acked, even malformed ones or replies whose caller has
        already timed out.
        """
        while self._running:
            try:
                msgs = await self._redis.xreadgroup(
                    self._reply_group,
                    self._worker_id,
                    {self._reply_stream: ">"},
                    count=20,
                    block=5000,
                )
            except Exception:
                logger.debug("tool reply read failed; retrying", exc_info=True)
                await asyncio.sleep(0.5)
                continue
            if not msgs:
                continue
            for _stream, entries in msgs:
                for msg_id, raw in entries:
                    try:
                        payload = deserialize_stream_payload(raw)
                        cid = payload.get("correlation_id")
                        fut = self._pending.get(cid)
                        if fut is not None and not fut.done():
                            fut.set_result(payload)
                    except Exception:
                        logger.debug("bad tool reply entry", exc_info=True)
                    finally:
                        try:
                            await self._redis.xack(self._reply_stream, self._reply_group, msg_id)
                        except Exception:
                            pass

    async def _catalog_watch_loop(self) -> None:
        """Reload the catalog on every message published to ``CATALOG_PUBSUB_CHANNEL``."""
        try:
            pubsub = self._redis.pubsub()
            await pubsub.subscribe(CATALOG_PUBSUB_CHANNEL)
        except Exception:
            logger.debug("catalog pubsub subscribe failed", exc_info=True)
            return
        try:
            async for msg in pubsub.listen():
                if not self._running:
                    break
                if msg.get("type") != "message":
                    continue
                try:
                    await self.reload_catalog()
                except Exception:
                    # A single failed reload (e.g. transient Redis error) must not
                    # end the watcher; the next notification or a CATALOG_STALE
                    # reply will trigger another attempt.
                    logger.warning("catalog reload after pubsub notification failed", exc_info=True)
        except asyncio.CancelledError:
            pass
        except Exception:
            logger.debug("catalog watch loop error", exc_info=True)
        finally:
            try:
                await pubsub.unsubscribe(CATALOG_PUBSUB_CHANNEL)
                await pubsub.aclose()
            except Exception:
                pass