Source code for core.remote_tool_registry
"""Inference-side stand-in for ``ToolRegistry`` that delegates tool execution.
``RemoteToolRegistry`` duck-types the read surface of :class:`tools.ToolRegistry`
(OpenAI schemas, ``tool_names``, ``repeat_allowed_tools``, ``is_allowed`` …) from
the Redis-published tool catalog, so the inference tier can present tools to the
LLM and rank them without importing any handler module. Its ``call()`` routes
each invocation through three tiers:
1. ``GATEWAY_PINNED_TOOLS`` -> delegate to the gateway (live discord.py client).
2. ``INFERENCE_PINNED_TOOLS`` (and nested calls, and ``tools_force_in_process``)
-> run in-process via the small local registry the inference service loads.
3. everything else -> delegate to the dedicated ``tools`` service over
``sg:stream:tools``; the result + write-back (``sent_files`` media bytes,
``injected_tools``, ``sent_rich_messages``) come back on a per-worker reply
stream and are merged onto the live ``ctx`` so the LLM loop is unchanged.
In ``tools_service_mode == "in_process"`` (the safe default) every call runs on
the local registry, exactly as before the split. The remote path is opt-in.
Drop-in: the executor and transport only touch the registry's read surface plus
``call()``, so swapping ``ToolRegistry`` for this requires no changes there.
"""
from __future__ import annotations
import asyncio
import logging
import uuid
from typing import Any, Optional
from core.event_bus import RedisEventBus, TOOLS_REPLY_STREAM_PREFIX
from core.gateway_pinned import GATEWAY_PINNED_TOOLS, INFERENCE_PINNED_TOOLS
from core.serialization import deserialize_stream_payload
from core.tool_catalog import (
CATALOG_PUBSUB_CHANNEL,
load_catalog,
openai_tool_dict,
)
logger = logging.getLogger("stargazer.remote_tool_registry")
[docs]
class RemoteToolRegistry:
"""Catalog-backed, delegation-capable replacement for ``ToolRegistry``.
Args:
event_bus: bus used to publish tool-exec requests.
redis: raw (``decode_responses=False``) async client — used to read the
catalog and consume the per-worker reply stream.
config: live :class:`config.Config` (routing mode, timeouts, overrides).
local_registry: a real :class:`tools.ToolRegistry` for in-process tiers
(gateway-pinned delegation needs ``ctx.adapter``; inference-pinned +
``in_process`` mode run handlers here). During rollout it may hold
every tool; in the lean end state only ``INFERENCE_PINNED_TOOLS``.
worker_id: stable id for this inference worker (reply-stream name).
"""
def __init__(
    self,
    *,
    event_bus: RedisEventBus,
    redis: Any,
    config: Any,
    local_registry: Any,
    worker_id: str,
) -> None:
    """Store the collaborators and initialise empty catalog / runtime state.

    Only attribute assignments happen here; nothing is read from Redis and no
    background task is created.
    """
    # Injected collaborators.
    self._worker_id = worker_id
    self._local = local_registry
    self._cfg = config
    self._redis = redis
    self._bus = event_bus

    # Per-worker reply stream and the consumer group used to read it.
    reply_stream = f"{TOOLS_REPLY_STREAM_PREFIX}:{worker_id}"
    self._reply_stream = reply_stream
    self._reply_group = f"{reply_stream}:grp"

    # Catalog-derived read surface. ``_catalog`` stays None until a catalog
    # is loaded; while it is None the read surface uses the local registry.
    self._catalog: Optional[dict[str, Any]] = None
    self._by_name: dict[str, dict[str, Any]] = {}
    self._perms: dict[str, list[str]] = {}
    self._repeat: frozenset[str] = frozenset()
    self._schema_hash: str = ""

    # Runtime bookkeeping: futures awaiting a reply (keyed by correlation
    # id), the background tasks, and the run flag checked by the loops.
    self._pending: dict[str, asyncio.Future] = {}
    self._tasks: list[asyncio.Task] = []
    self._running = False
# ── lifecycle ──────────────────────────────────────────────────
[docs]
async def start(self) -> None:
    """Load the catalog and start the reply reader + catalog watcher.

    In ``in_process`` mode this only marks the registry as running: no
    catalog is read (the read surface stays delegated to the local registry)
    and no background tasks are spawned, so the inference worker behaves
    exactly as before the split.

    Calling ``start()`` again while background tasks from an earlier call are
    still alive is a no-op, so the tasks are never duplicated and the first
    pair is never orphaned (``stop()`` only cancels what ``_tasks`` holds).
    """
    self._running = True
    if getattr(self._cfg, "tools_service_mode", "in_process") == "in_process":
        return
    # Idempotence guard: replacing ``_tasks`` while the old tasks still run
    # would leave them uncancellable and double-consume the reply stream.
    if any(not task.done() for task in self._tasks):
        return
    await self.reload_catalog()
    try:
        await self._redis.xgroup_create(
            self._reply_stream, self._reply_group, id="$", mkstream=True
        )
    except Exception as exc:
        # Redis answers BUSYGROUP when the group already exists, which is the
        # expected case on restart. Anything else is still tolerated (start
        # stays best-effort) but is surfaced instead of silently dropped.
        if "BUSYGROUP" not in str(exc):
            logger.warning(
                "Could not create reply consumer group %r on stream %r: %s",
                self._reply_group,
                self._reply_stream,
                exc,
            )
    self._tasks = [
        asyncio.create_task(self._reply_reader_loop(), name=f"tools-reply-{self._worker_id}"),
        asyncio.create_task(self._catalog_watch_loop(), name="tools-catalog-watch"),
    ]
[docs]
async def stop(self) -> None:
    """Stop the background tasks and cancel every call still awaiting a reply.

    Clears the run flag, cancels and awaits each task in ``_tasks`` (ignoring
    whatever they raise on the way out), then cancels any unfinished future
    in ``_pending`` and empties both containers.
    """
    self._running = False

    workers = self._tasks
    # Request cancellation of every worker first, then wait for each one.
    for worker in workers:
        worker.cancel()
    for worker in workers:
        try:
            await worker
        except (asyncio.CancelledError, Exception):
            pass
    self._tasks = []

    waiting = self._pending
    for waiter in waiting.values():
        if waiter.done():
            continue
        waiter.cancel()
    waiting.clear()
[docs]
async def reload_catalog(self) -> bool:
    """(Re)load the tool catalog from Redis and rebuild the derived lookups.

    Returns:
        ``True`` when a catalog was found and installed. ``False`` when none
        is published; ``_catalog`` is then reset to ``None`` so the read
        surface falls back to the local registry.

    Raises:
        Whatever ``load_catalog`` raises, and ``KeyError`` for a tool entry
        without a ``"name"``. In both cases the previously installed state is
        left untouched.
    """
    cat = await load_catalog(self._redis)
    if cat is None:
        logger.warning("No tool catalog published yet; using local registry read surface.")
        self._catalog = None
        return False

    # ``tools`` may be missing or explicitly null; treat both as empty, the
    # same way ``permissions`` is handled below.
    tools = cat.get("tools") or []

    # Build every derived table before touching ``self`` so a malformed entry
    # cannot leave the new catalog paired with the old lookup tables.
    by_name = {t["name"]: t for t in tools}
    perms = {k: list(v) for k, v in (cat.get("permissions") or {}).items()}
    repeat = frozenset(t["name"] for t in tools if t.get("allow_repeat"))

    self._catalog = cat
    self._by_name = by_name
    self._perms = perms
    self._repeat = repeat
    self._schema_hash = cat.get("schema_hash", "")

    # Keep the local registry's permission view in sync for in-process tiers.
    # Best-effort: a local registry that cannot accept it must not fail the
    # reload, but the failure is recorded rather than dropped.
    try:
        self._local.set_permissions(self._perms)
    except Exception:
        logger.debug("Could not sync permissions to the local registry", exc_info=True)

    logger.info(
        "Loaded tool catalog v%s (%d tools)", cat.get("version"), len(self._by_name)
    )
    return True
# ── read surface (duck-types ToolRegistry) ─────────────────────
@property
def _use_catalog(self) -> bool:
return self._catalog is not None
[docs]
def get_openai_tools(self) -> list[dict[str, Any]]:
    """Return the OpenAI tool dicts for every tool.

    Built from the catalog entries when a catalog is loaded, otherwise
    delegated to the local registry.
    """
    if not self._use_catalog:
        return self._local.get_openai_tools()
    return list(map(openai_tool_dict, self._by_name.values()))
[docs]
def get_openai_tools_by_names(self, names: set[str]) -> list[dict[str, Any]]:
    """Return the OpenAI tool dicts for the requested tool names.

    With a catalog loaded, names that are not in the catalog are skipped;
    otherwise the request is delegated to the local registry.
    """
    if not self._use_catalog:
        return self._local.get_openai_tools_by_names(names)
    known = self._by_name
    selected = []
    for wanted in names:
        if wanted in known:
            selected.append(openai_tool_dict(known[wanted]))
    return selected
[docs]
def tool_names(self) -> frozenset[str]:
    """Return the set of known tool names (catalog first, else local registry)."""
    if not self._use_catalog:
        return self._local.tool_names()
    return frozenset(self._by_name)
[docs]
def repeat_allowed_tools(self) -> frozenset[str]:
    """Return the names of tools flagged ``allow_repeat``.

    Served from the set precomputed at catalog load when a catalog is
    present, otherwise from the local registry.
    """
    if not self._use_catalog:
        return self._local.repeat_allowed_tools()
    return self._repeat
[docs]
def list_tools(self) -> list[Any]:
    """Return tool definitions for every known tool.

    Without a catalog this is the local registry's list. With a catalog, one
    ``ToolDefinition`` is built per catalog entry; its handler is a stub that
    raises ``RuntimeError`` when invoked, since the entry carries no handler.
    """
    if not self._use_catalog:
        return self._local.list_tools()

    # Imported lazily, only on the catalog path.
    from tools import ToolDefinition

    def _stub(*_a: Any, **_k: Any) -> Any:
        raise RuntimeError("RemoteToolRegistry tool handlers run remotely; use call().")

    definitions = []
    for entry in self._by_name.values():
        definitions.append(
            ToolDefinition(
                name=entry["name"],
                description=entry.get("description", ""),
                parameters=entry.get("parameters", {}),
                handler=_stub,
                no_background=bool(entry.get("no_background", False)),
                allow_repeat=bool(entry.get("allow_repeat", False)),
            )
        )
    return definitions
[docs]
def set_permissions(self, permissions: dict[str, list[str]]) -> None:
    """Replace the permission table and mirror it onto the local registry.

    A shallow copy of ``permissions`` is stored on this registry; the local
    registry receives the caller's mapping as-is. Forwarding is best-effort:
    any error raised by the local registry is ignored.
    """
    self._perms = {**permissions}
    local = self._local
    try:
        local.set_permissions(permissions)
    except Exception:
        pass
[docs]
def is_allowed(self, tool_name: str, user_id: str) -> bool:
    """Tell whether ``user_id`` may run ``tool_name``.

    A tool with no entry in the permission table is unrestricted. Otherwise
    the user must be listed, or the list must contain the ``"*"`` wildcard.
    """
    allow_list = self._perms.get(tool_name)
    if allow_list is None:
        # No entry for this tool: everyone may use it.
        return True
    return "*" in allow_list or user_id in allow_list
@property
def has_tools(self) -> bool:
if self._use_catalog:
return len(self._by_name) > 0
return self._local.has_tools
def __len__(self) -> int:
if self._use_catalog:
return len(self._by_name)
return len(self._local)
# TaskManager passthrough so background-offload semantics survive locally.
@property
def task_manager(self) -> Any:
return getattr(self._local, "task_manager", None)
@task_manager.setter
def task_manager(self, value: Any) -> None:
try:
self._local.task_manager = value
except Exception:
pass
# ── execution: three-tier routing ──────────────────────────────
[docs]
async def call(
    self,
    name: str,
    arguments: dict[str, Any],
    user_id: str = "",
    ctx: Any = None,
    *,
    nested: bool = False,
) -> str:
    """Execute tool ``name`` and return its result string.

    Routing, in order:

    1. A non-empty ``user_id`` that fails :meth:`is_allowed` gets a
       permission-denied string straight away.
    2. Nested calls, and every call in ``in_process`` mode, run on the local
       registry.
    3. Otherwise: gateway-pinned tools are delegated to the gateway;
       inference-pinned or ``tools_force_in_process`` tools run locally;
       everything else goes through :meth:`_remote_call`.

    If the remote call raises, the tool runs locally when
    ``tools_local_fallback`` is truthy and the local registry knows it;
    otherwise a "temporarily unavailable" string is returned.
    """
    cfg = self._cfg
    mode = getattr(cfg, "tools_service_mode", "in_process")
    forced_names = set(getattr(cfg, "tools_force_in_process", []) or [])
    force_local = name in forced_names

    # Permission fast-fail (UX only; per the original note the authoritative
    # check is re-run service-side). Mirrors ToolRegistry.call's denial string.
    if user_id and not self.is_allowed(name, user_id):
        return f"Permission denied: user '{user_id}' is not allowed to run tool '{name}'."

    async def _run_locally() -> str:
        return await self._local.call(
            name, arguments, user_id=user_id, ctx=ctx, nested=nested
        )

    # Nested inner calls (compound tools) and in_process mode always stay on
    # the local registry.
    if nested or mode == "in_process":
        return await _run_locally()

    # Remote / shadow mode.
    if name in GATEWAY_PINNED_TOOLS:
        return await self._delegate_to_gateway(name, arguments, ctx)
    if name in INFERENCE_PINNED_TOOLS or force_local:
        return await _run_locally()

    try:
        return await self._remote_call(name, arguments, user_id, ctx)
    except Exception as exc:
        logger.warning("Remote tool exec failed for %r: %s", name, exc)
        fallback_enabled = getattr(self._cfg, "tools_local_fallback", True)
        if fallback_enabled and name in self._local.tool_names():
            logger.info("Falling back to in-process execution for %r", name)
            return await _run_locally()
        return f"Tool '{name}' is temporarily unavailable (tools service error). Try again."
async def _delegate_to_gateway(self, name: str, arguments: dict[str, Any], ctx: Any) -> str:
"""Replicate ToolRegistry.call's gateway-pinned delegation (lines 277-293)."""
delegate = getattr(getattr(ctx, "adapter", None), "delegate_to_gateway", None)
if delegate is None:
return f"Tool '{name}' requires the gateway client, which is unavailable."
return await delegate(
"execute_tool",
tool_name=name,
tool_args=arguments,
tool_ctx={
"platform": ctx.platform,
"channel_id": ctx.channel_id,
"user_id": ctx.user_id,
"user_name": ctx.user_name,
"guild_id": ctx.guild_id,
"message_id": ctx.message_id,
},
)
async def _remote_call(
    self,
    name: str,
    arguments: dict[str, Any],
    user_id: str,
    ctx: Any,
    *,
    _stale_retry: bool = True,
) -> str:
    """Publish a ``tool_exec`` request and wait for the matching reply.

    The reply arrives on the per-worker reply stream; the reader loop
    resolves the future registered in ``_pending`` under this call's
    correlation id.

    Args:
        name: tool to execute.
        arguments: tool arguments, sent as ``tool_args``.
        user_id: caller id placed in the envelope (and in the session record
            when ``ctx`` is ``None``).
        ctx: live tool context, or ``None``. Write-back channels from the
            reply are merged onto it.
        _stale_retry: internal. ``True`` on the first attempt; the single
            retry after a ``CATALOG_STALE`` reply passes ``False`` so a
            service that keeps answering "stale" cannot cause unbounded
            recursion.

    Returns:
        The reply's ``error`` (as a string) when present, otherwise its
        ``result`` (as a string, ``""`` when absent).

    Raises:
        asyncio.TimeoutError: no reply within ``tools_exec_timeout`` seconds
            (default 120).
        RuntimeError: the service reported ``CATALOG_STALE`` again after the
            catalog had been reloaded.
    """
    correlation_id = uuid.uuid4().hex
    trace_id = getattr(ctx, "observability_request_id", "") or correlation_id
    envelope = {
        "type": "tool_exec",
        "tool_name": name,
        "tool_args": arguments,
        "user_id": user_id,
        "correlation_id": correlation_id,
        "idem_key": f"{trace_id}:{correlation_id}",
        "reply_to": self._reply_stream,
        "schema_hash": self._schema_hash,
        "ctx": self._ctx_payload(ctx, trace_id),
    }
    # Write the AUTHENTICATED identity to a session record keyed by trace_id
    # BEFORE publishing, so the tools service resolves identity from there
    # rather than trusting the envelope (anti-privilege-escalation).
    from core.tool_session import write_session

    await write_session(
        self._redis,
        trace_id,
        {
            "user_id": getattr(ctx, "user_id", "") if ctx is not None else user_id,
            "guild_id": getattr(ctx, "guild_id", "") if ctx is not None else "",
            "channel_id": getattr(ctx, "channel_id", "") if ctx is not None else "",
            "platform": getattr(ctx, "platform", "") if ctx is not None else "",
        },
    )
    # We are inside a coroutine, so the running loop is the right owner for
    # the future (``get_event_loop()`` is the legacy spelling here).
    fut: asyncio.Future = asyncio.get_running_loop().create_future()
    # Register before publishing so a fast reply cannot arrive unmatched.
    self._pending[correlation_id] = fut
    timeout = float(getattr(self._cfg, "tools_exec_timeout", 120.0))
    try:
        await self._bus.publish_tools_request(envelope)
        reply = await asyncio.wait_for(fut, timeout=timeout)
    finally:
        # Always drop the registration: success, timeout, or publish failure.
        self._pending.pop(correlation_id, None)

    # Stale catalog: reload and retry exactly once with the fresh schema hash.
    if reply.get("status") == "CATALOG_STALE":
        if not _stale_retry:
            raise RuntimeError(
                f"tools service still reports CATALOG_STALE for {name!r} after catalog reload"
            )
        await self.reload_catalog()
        return await self._remote_call(name, arguments, user_id, ctx, _stale_retry=False)

    if reply.get("error"):
        return str(reply["error"])
    self._merge_writeback(ctx, reply.get("writeback") or {})
    return str(reply.get("result", ""))
@staticmethod
def _ctx_payload(ctx: Any, trace_id: str) -> dict[str, Any]:
if ctx is None:
return {"trace_id": trace_id}
return {
"platform": ctx.platform,
"channel_id": ctx.channel_id,
"user_id": ctx.user_id,
"user_name": ctx.user_name,
"guild_id": ctx.guild_id,
"message_id": ctx.message_id,
"observability_request_id": ctx.observability_request_id,
"disclosed_skill_ids": list(ctx.disclosed_skill_ids or []),
"room_context": ctx.room_context,
"trace_id": trace_id,
}
@staticmethod
def _merge_writeback(ctx: Any, wb: dict[str, Any]) -> None:
"""Append the remote call's write-back channels onto the live ctx.
Mirrors what the in-process tool would have mutated, so the executor's
post-gather reads (sent_files -> multimodal parts, injected_tools ->
active tool list) and generate_and_send's sent_rich_messages all work.
"""
if ctx is None:
return
sent_files = wb.get("sent_files") or []
if sent_files:
ctx.sent_files.extend(sent_files)
injected = wb.get("injected_tools") or []
if injected:
if ctx.injected_tools is None:
ctx.injected_tools = []
ctx.injected_tools.extend(injected)
rich = wb.get("sent_rich_messages") or []
if rich:
ctx.sent_rich_messages.extend(rich)
# ── background loops ───────────────────────────────────────────
async def _reply_reader_loop(self) -> None:
    """Consume the per-worker reply stream and resolve waiting futures.

    While ``_running`` is set, reads up to 20 entries at a time from the
    reply stream through the worker's consumer group (blocking up to 5000 ms
    per read). Each entry is deserialized; when its ``correlation_id``
    matches an unfinished future in ``_pending`` that future is resolved with
    the payload. Every entry is acknowledged afterwards, matched or not.

    Failures never stop the loop. A failed read is retried after 0.5 s, with
    one warning logged at the start of each failure streak (rather than one
    per retry) so a broken stream is visible without flooding the log.
    """
    read_failing = False
    while self._running:
        try:
            msgs = await self._redis.xreadgroup(
                self._reply_group,
                self._worker_id,
                {self._reply_stream: ">"},
                count=20,
                block=5000,
            )
        except Exception:
            if not read_failing:
                logger.warning(
                    "Reading tool reply stream %r failed; retrying",
                    self._reply_stream,
                    exc_info=True,
                )
                read_failing = True
            await asyncio.sleep(0.5)
            continue
        read_failing = False
        if not msgs:
            continue
        for _stream, entries in msgs:
            for msg_id, raw in entries:
                try:
                    payload = deserialize_stream_payload(raw)
                    cid = payload.get("correlation_id")
                    fut = self._pending.get(cid)
                    # Unknown or already-finished ids (e.g. a reply that
                    # arrives after its caller gave up) are just dropped.
                    if fut is not None and not fut.done():
                        fut.set_result(payload)
                except Exception:
                    logger.debug("bad tool reply entry", exc_info=True)
                finally:
                    # Ack even undecodable entries so they do not linger in
                    # the group's pending list.
                    try:
                        await self._redis.xack(self._reply_stream, self._reply_group, msg_id)
                    except Exception:
                        logger.debug("failed to ack tool reply %r", msg_id, exc_info=True)
async def _catalog_watch_loop(self) -> None:
    """Reload the catalog whenever a message arrives on the catalog channel.

    Subscribes to ``CATALOG_PUBSUB_CHANNEL`` and calls
    :meth:`reload_catalog` for every pub/sub entry of type ``"message"``
    until ``_running`` is cleared or the task is cancelled.

    A failing reload is logged and the watcher keeps listening, so one
    transient error cannot permanently stop catalog refreshes. The pub/sub
    connection is closed on every exit path, including a failed subscribe.
    """
    pubsub = None
    try:
        pubsub = self._redis.pubsub()
        await pubsub.subscribe(CATALOG_PUBSUB_CHANNEL)
    except Exception:
        logger.debug("catalog pubsub subscribe failed", exc_info=True)
        if pubsub is not None:
            # Do not leak the connection obtained before subscribe failed.
            try:
                await pubsub.aclose()
            except Exception:
                pass
        return
    try:
        async for msg in pubsub.listen():
            if not self._running:
                break
            if msg.get("type") != "message":
                continue
            try:
                await self.reload_catalog()
            except Exception:
                # Keep watching: the next publish gets another chance.
                logger.warning(
                    "catalog reload failed; will retry on next publish", exc_info=True
                )
    except asyncio.CancelledError:
        # Cancellation (from stop()) is the normal way out; finish quietly.
        pass
    except Exception:
        logger.debug("catalog watch loop error", exc_info=True)
    finally:
        # Separate guards so a failing unsubscribe cannot skip the close.
        try:
            await pubsub.unsubscribe(CATALOG_PUBSUB_CHANNEL)
        except Exception:
            pass
        try:
            await pubsub.aclose()
        except Exception:
            pass