# Source code for tools

"""Extensible tool-calling framework.

Register tools with the ``@registry.tool`` decorator. Each tool is an async
callable that receives keyword arguments and returns a string result.

Example usage::

    from tools import ToolRegistry

    registry = ToolRegistry()

    @registry.tool(
        name="get_weather",
        description="Get the current weather for a location.",
        parameters={
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "City and state, e.g. 'San Francisco, CA'",
                },
            },
            "required": ["location"],
        },
    )
    async def get_weather(location: str) -> str:
        return f"The weather in {location} is sunny and 72°F."
"""

from __future__ import annotations

import inspect
import json
import logging
import weakref
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Awaitable, Callable

if TYPE_CHECKING:
    from task_manager import TaskManager
    from tool_context import ToolContext

# Module-level logger; used by ToolRegistry.call and the argument-filtering
# helper to report unknown tools, permission denials and tool failures.
logger = logging.getLogger(__name__)


@dataclass
class ToolDefinition:
    """Internal representation of a registered tool.

    Bundles the metadata exposed to the model (name, description and JSON
    Schema parameters) with the async callable that implements the tool,
    plus two behaviour flags consulted by the registry.
    """

    # Tool name exposed to the LLM; the registry also stores it under this key.
    name: str
    # Human-readable description of what the tool does.
    description: str
    # JSON Schema object describing the tool's keyword arguments.
    parameters: dict[str, Any]
    # Async callable invoked with the tool's keyword arguments.
    handler: Callable[..., Awaitable[str]]
    # When True the handler is awaited directly even if the registry has a
    # task manager (i.e. it is never handed off for background execution).
    no_background: bool = False
    # When True the tool is exempt from repetition-loop detection.
    allow_repeat: bool = False
class ToolRegistry:
    """Registry that stores tool definitions and executes tool calls."""

    def __init__(
        self,
        task_manager: TaskManager | None = None,
    ) -> None:
        """Initialize the instance.

        Args:
            task_manager (TaskManager | None): Optional task manager. When
                set, the coroutine of every tool not flagged
                ``no_background`` is handed to ``task_manager.execute``
                instead of being awaited directly.
        """
        self._tools: dict[str, ToolDefinition] = {}
        self._permissions: dict[str, list[str]] = {}
        self.task_manager: TaskManager | None = task_manager
        # Cached OpenAI-format tool lists, rebuilt lazily on first access
        # after any mutation of ``_tools``.
        self._openai_tools_cache: list[dict[str, Any]] | None = None
        self._openai_tools_by_name: dict[str, dict[str, Any]] | None = None

    # ------------------------------------------------------------------
    # Permissions
    # ------------------------------------------------------------------

    def set_permissions(self, permissions: dict[str, list[str]]) -> None:
        """Set per-tool user whitelists.

        *permissions* maps tool names to lists of allowed user IDs. A special
        value ``"*"`` in the list means *everyone*. Tools **not** present in
        the dict are allowed for all users. The outer dict is copied, so
        later key additions/removals in the caller's dict have no effect.
        """
        self._permissions = dict(permissions)

    def is_allowed(self, tool_name: str, user_id: str) -> bool:
        """Return ``True`` if *user_id* may execute *tool_name*.

        Rules:
        1. Tool not in the permissions dict -> allow (backward compatible).
        2. ``"*"`` in the tool's allowed list -> allow.
        3. *user_id* in the tool's allowed list -> allow.
        4. Otherwise -> deny.
        """
        allowed = self._permissions.get(tool_name)
        if allowed is None:
            return True
        if "*" in allowed:
            return True
        return user_id in allowed

    # ------------------------------------------------------------------
    # Registration
    # ------------------------------------------------------------------

    def tool(
        self,
        name: str,
        description: str,
        parameters: dict[str, Any],
        *,
        no_background: bool = False,
        allow_repeat: bool = False,
    ) -> Callable:
        """Decorator to register an async function as a tool.

        Parameters
        ----------
        name:
            The tool name exposed to the LLM. Registering an existing name
            replaces the previous definition.
        description:
            Human-readable description of what the tool does.
        parameters:
            JSON Schema object describing the tool's parameters.
        no_background:
            Keyword-only. If true, ``call`` always awaits the handler
            directly, even when the registry has a task manager.
        allow_repeat:
            Keyword-only. If true, the tool is listed by
            ``repeat_allowed_tools`` (exempt from repetition-loop detection).
        """

        def decorator(fn: Callable[..., Awaitable[str]]) -> Callable[..., Awaitable[str]]:
            """Register *fn* under the enclosing *name* and return it unchanged."""
            self._tools[name] = ToolDefinition(
                name=name,
                description=description,
                parameters=parameters,
                handler=fn,
                no_background=no_background,
                allow_repeat=allow_repeat,
            )
            self.invalidate_cache()
            return fn

        return decorator

    # ------------------------------------------------------------------
    # Execution
    # ------------------------------------------------------------------

    async def call(
        self,
        name: str,
        arguments: dict[str, Any],
        user_id: str = "",
        ctx: ToolContext | None = None,
    ) -> str:
        """Execute a registered tool by name and return the result string.

        If *user_id* is non-empty, the tool's permission whitelist is checked
        first. An empty *user_id* skips both the permission check and
        secret-reference resolution.

        If *ctx* is provided **and** the handler declares a ``ctx``
        parameter, the context is injected automatically. A ``ctx`` key in
        *arguments* is always ignored. That name is reserved for injection
        and must never be controllable by the model.

        *arguments* is never mutated. Secret references are resolved into a
        private copy, so decrypted values cannot leak back into the caller's
        data (e.g. stored tool-call arguments).

        If the tool raises an exception, a generic message naming the
        exception type is returned (details go to the log), so the LLM can
        see that something went wrong and recover.
        """
        tool_def = self._tools.get(name)
        if tool_def is None:
            error_msg = f"Unknown tool: {name}"
            logger.error(error_msg)
            return error_msg

        # Permission check
        if user_id and not self.is_allowed(name, user_id):
            error_msg = (
                f"Permission denied: user '{user_id}' is not allowed "
                f"to run tool '{name}'."
            )
            logger.warning(error_msg)
            return error_msg

        try:
            raw_arguments = arguments or {}
            if "ctx" in raw_arguments:
                logger.warning(
                    "Tool '%s': ignoring reserved argument 'ctx'", name
                )
            # Always build a fresh dict: _resolve_secret_refs rewrites values
            # in place, and _filter_arguments may return its input unchanged.
            call_args = {
                key: value for key, value in raw_arguments.items() if key != "ctx"
            }
            call_args = _filter_arguments(tool_def.handler, call_args, name)
            if ctx is not None and user_id:
                err = await _resolve_secret_refs(call_args, user_id, ctx)
                if err is not None:
                    return err
            if ctx is not None and _handler_accepts_ctx(tool_def.handler):
                coro = tool_def.handler(**call_args, ctx=ctx)
            else:
                coro = tool_def.handler(**call_args)
            if self.task_manager is not None and not tool_def.no_background:
                return await self.task_manager.execute(
                    coro,
                    tool_name=name,
                    user_id=user_id,
                    channel_id=ctx.channel_id if ctx is not None else "",
                    platform=ctx.platform if ctx is not None else "",
                )
            return str(await coro)
        except Exception as exc:
            logger.exception("Tool %r failed", name)
            return (
                f"Tool '{name}' raised {type(exc).__name__}. "
                "See server logs for details."
            )

    # ------------------------------------------------------------------
    # Schema export & caching
    # ------------------------------------------------------------------

    def invalidate_cache(self) -> None:
        """Clear the cached OpenAI tool representations.

        Called automatically when tools are registered via the ``tool``
        decorator. Must also be called explicitly after bulk mutations such
        as ``_tools.clear()`` (e.g. in ``reload_tools``).
        """
        self._openai_tools_cache = None
        self._openai_tools_by_name = None

    def _ensure_cache(self) -> None:
        """Rebuild the OpenAI-format caches if stale."""
        if self._openai_tools_cache is not None:
            return
        tools = [
            {
                "type": "function",
                "function": {
                    "name": td.name,
                    "description": td.description,
                    "parameters": td.parameters,
                },
            }
            for td in self._tools.values()
        ]
        self._openai_tools_cache = tools
        # Built from the list above, so insertion order == registration order.
        self._openai_tools_by_name = {t["function"]["name"]: t for t in tools}

    def get_openai_tools(self) -> list[dict[str, Any]]:
        """Return tool definitions in the OpenAI function-calling JSON format.

        Returns an empty list when no tools are registered, which means the
        ``tools`` parameter can be omitted from the API call. The returned
        list is a fresh (shallow) copy of the cache.
        """
        self._ensure_cache()
        return list(self._openai_tools_cache)

    def get_openai_tools_by_names(
        self,
        names: set[str],
    ) -> list[dict[str, Any]]:
        """Return only the OpenAI tool dicts whose names are in *names*.

        Results follow tool registration order, not the iteration order of
        *names*. Set iteration order for strings differs between processes
        (hash randomisation), and a stable order keeps request payloads
        deterministic across runs. Unknown names are ignored.
        """
        self._ensure_cache()
        wanted = set(names)
        return [
            tool
            for tool_name, tool in self._openai_tools_by_name.items()
            if tool_name in wanted
        ]

    def list_tools(self) -> list[ToolDefinition]:
        """Return all registered tool definitions."""
        return list(self._tools.values())

    def repeat_allowed_tools(self) -> frozenset[str]:
        """Return names of tools that are exempt from repetition-loop detection."""
        return frozenset(td.name for td in self._tools.values() if td.allow_repeat)

    @property
    def has_tools(self) -> bool:
        """Whether at least one tool is registered."""
        return bool(self._tools)

    def __len__(self) -> int:
        """Return the number of registered tools."""
        return len(self._tools)
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------

# Per-handler introspection caches, keyed weakly on the handler object itself
# rather than on ``id(handler)``. Ids are recycled once an object is
# garbage-collected, so after a tool reload a brand-new handler could
# otherwise be served the stale signature of a dead one. Weak keys also let
# entries disappear together with their handler instead of accumulating.
_ctx_cache: weakref.WeakKeyDictionary[Callable, bool] = weakref.WeakKeyDictionary()
_sig_cache: weakref.WeakKeyDictionary[Callable, inspect.Signature | None] = (
    weakref.WeakKeyDictionary()
)

# Parameter kinds that can legitimately be supplied as ``name=value``.
_KEYWORD_KINDS = (
    inspect.Parameter.POSITIONAL_OR_KEYWORD,
    inspect.Parameter.KEYWORD_ONLY,
)


def _cached(
    cache: weakref.WeakKeyDictionary,
    handler: Callable,
    compute: Callable[[Callable], Any],
) -> Any:
    """Return ``cache[handler]``, computing and storing it on a miss.

    Handlers that cannot be weakly referenced or hashed (e.g. builtin
    functions) are never cached; *compute* simply runs on every call.
    """
    try:
        return cache[handler]
    except KeyError:
        pass
    except TypeError:
        # Not weak-referenceable or not hashable: skip caching.
        return compute(handler)
    value = compute(handler)
    cache[handler] = value
    return value


def _compute_signature(handler: Callable) -> inspect.Signature | None:
    """Return ``inspect.signature(handler)``, or ``None`` if it cannot be read."""
    try:
        return inspect.signature(handler)
    except (ValueError, TypeError):
        return None


def _get_signature(handler: Callable) -> inspect.Signature | None:
    """Cached variant of :func:`_compute_signature`."""
    return _cached(_sig_cache, handler, _compute_signature)


def _filter_arguments(
    handler: Callable,
    arguments: dict[str, Any],
    tool_name: str,
) -> dict[str, Any]:
    """Strip keyword arguments that *handler* does not accept.

    The reserved name ``ctx`` is always removed, because the registry
    injects the tool context under that name itself. If the handler accepts
    ``**kwargs``, all other arguments are passed through. Otherwise only
    names that can be passed by keyword are kept, and any dropped keys are
    logged as a warning, so hallucinated params don't crash the call.

    The input dict is returned unchanged (same object) when nothing needs to
    be dropped.
    """
    if "ctx" in arguments:
        logger.warning("Tool '%s': dropping reserved argument 'ctx'", tool_name)
        arguments = {k: v for k, v in arguments.items() if k != "ctx"}

    sig = _get_signature(handler)
    if sig is None:
        return arguments  # can't introspect — pass everything through

    params = list(sig.parameters.values())
    # If the handler has a **kwargs catch-all, everything is fine.
    if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params):
        return arguments

    # ``*args`` and positional-only parameters cannot be bound by keyword,
    # so forwarding them as ``name=value`` would only raise TypeError.
    valid_names = {
        p.name for p in params if p.kind in _KEYWORD_KINDS and p.name != "ctx"
    }
    unexpected = set(arguments) - valid_names
    if not unexpected:
        return arguments

    logger.warning(
        "Tool '%s': dropping unexpected argument(s) %s",
        tool_name,
        ", ".join(sorted(unexpected)),
    )
    return {k: v for k, v in arguments.items() if k in valid_names}


def _compute_accepts_ctx(handler: Callable) -> bool:
    """Return True if *handler*'s signature has a parameter named ``ctx``."""
    sig = _get_signature(handler)
    return sig is not None and "ctx" in sig.parameters


def _handler_accepts_ctx(handler: Callable) -> bool:
    """Return *True* if *handler* has a ``ctx`` parameter (cached per handler)."""
    return _cached(_ctx_cache, handler, _compute_accepts_ctx)


async def _resolve_secret_refs(
    arguments: dict[str, Any],
    user_id: str,
    ctx: "ToolContext",
) -> str | None:
    """Resolve secret:name references in string arguments, in place.

    Every top-level string value starting with ``SECRET_REF_PREFIX`` is
    replaced in *arguments* by the secret returned from
    ``resolve_user_secret`` for *user_id* (called with ``ctx.redis`` and
    ``ctx.config`` when present). Non-string values are left untouched.

    Returns:
        ``None`` when every reference resolved. Otherwise, an error string
        for the first reference that failed; values resolved before that
        point have already been replaced in *arguments*.
    """
    # NOTE(review): imported lazily, presumably to avoid an import cycle
    # with the ``tools`` package — confirm before hoisting.
    from tools.manage_secrets import SECRET_REF_PREFIX, resolve_user_secret

    for key, val in list(arguments.items()):
        if isinstance(val, str) and val.startswith(SECRET_REF_PREFIX):
            name = val[len(SECRET_REF_PREFIX) :].strip()
            if not name:
                return "Invalid secret reference: secret: must be followed by a name."
            resolved = await resolve_user_secret(
                user_id,
                name,
                redis_client=ctx.redis,
                config=getattr(ctx, "config", None),
            )
            if resolved is None:
                return f"Secret '{name}' not found or could not be decrypted."
            arguments[key] = resolved
    return None