# Source code for tools
"""Extensible tool-calling framework.
Register tools with the ``@registry.tool`` decorator. Each tool is an async
callable that receives keyword arguments and returns a string result.
Example usage::
from tools import ToolRegistry
registry = ToolRegistry()
@registry.tool(
name="get_weather",
description="Get the current weather for a location.",
parameters={
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "City and state, e.g. 'San Francisco, CA'",
},
},
"required": ["location"],
},
)
async def get_weather(location: str) -> str:
return f"The weather in {location} is sunny and 72°F."
"""
from __future__ import annotations

import functools
import inspect
import json
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Awaitable, Callable

if TYPE_CHECKING:
    from task_manager import TaskManager
    from tool_context import ToolContext
logger = logging.getLogger(__name__)
[docs]
@dataclass
class ToolDefinition:
"""Internal representation of a registered tool."""
name: str
description: str
parameters: dict[str, Any]
handler: Callable[..., Awaitable[str]]
no_background: bool = False
allow_repeat: bool = False
class ToolRegistry:
    """Registry that stores tool definitions and executes tool calls."""

    def __init__(
        self,
        task_manager: TaskManager | None = None,
    ) -> None:
        """Initialize an empty registry.

        Args:
            task_manager (TaskManager | None): Optional manager. When set,
                tool coroutines are handed to it in :meth:`call` unless the
                tool was registered with ``no_background``.
        """
        self._tools: dict[str, ToolDefinition] = {}
        self._permissions: dict[str, list[str]] = {}
        self.task_manager: TaskManager | None = task_manager
        # Cached OpenAI-format tool lists, rebuilt lazily on first access
        # after any mutation of ``_tools``.
        self._openai_tools_cache: list[dict[str, Any]] | None = None
        self._openai_tools_by_name: dict[str, dict[str, Any]] | None = None

    # ------------------------------------------------------------------
    # Permissions
    # ------------------------------------------------------------------

    def set_permissions(self, permissions: dict[str, list[str]]) -> None:
        """Set per-tool user whitelists.

        *permissions* maps tool names to lists of allowed user IDs.
        A special value ``"*"`` in the list means *everyone*.
        Tools **not** present in the dict are allowed for all users.
        """
        # Copy so later mutation of the caller's dict can't change policy.
        self._permissions = dict(permissions)

    def is_allowed(self, tool_name: str, user_id: str) -> bool:
        """Return ``True`` if *user_id* may execute *tool_name*.

        Rules:

        1. Tool not in the permissions dict -> allow (backward compatible).
        2. ``"*"`` in the tool's allowed list -> allow.
        3. *user_id* in the tool's allowed list -> allow.
        4. Otherwise -> deny.
        """
        allowed = self._permissions.get(tool_name)
        if allowed is None:
            return True
        if "*" in allowed:
            return True
        return user_id in allowed

    # ------------------------------------------------------------------
    # Registration
    # ------------------------------------------------------------------

    def tool(
        self,
        name: str,
        description: str,
        parameters: dict[str, Any],
    ) -> Callable:
        """Decorator to register an async function as a tool.

        Parameters
        ----------
        name:
            The tool name exposed to the LLM.
        description:
            Human-readable description of what the tool does.
        parameters:
            JSON Schema object describing the tool's parameters.
        """

        def decorator(fn: Callable[..., Awaitable[str]]) -> Callable[..., Awaitable[str]]:
            """Store *fn* under *name* and return it unchanged."""
            self._tools[name] = ToolDefinition(
                name=name,
                description=description,
                parameters=parameters,
                handler=fn,
            )
            # The OpenAI-format caches are now stale.
            self.invalidate_cache()
            return fn

        return decorator

    # ------------------------------------------------------------------
    # Execution
    # ------------------------------------------------------------------

    async def call(
        self,
        name: str,
        arguments: dict[str, Any],
        user_id: str = "",
        ctx: ToolContext | None = None,
    ) -> str:
        """Execute a registered tool by name and return the result string.

        If *user_id* is provided, the tool's permission whitelist is
        checked first. If *ctx* is provided **and** the handler declares
        a ``ctx`` parameter, the context is injected automatically.

        If the tool raises an exception an error summary is returned so
        the LLM can see that something went wrong and recover.
        """
        tool_def = self._tools.get(name)
        if tool_def is None:
            error_msg = f"Unknown tool: {name}"
            logger.error(error_msg)
            return error_msg
        # Permission check
        if user_id and not self.is_allowed(name, user_id):
            error_msg = (
                f"Permission denied: user '{user_id}' is not allowed "
                f"to run tool '{name}'."
            )
            logger.warning(error_msg)
            return error_msg
        try:
            # Filter against a *copy* of the caller's dict: secret
            # resolution below rewrites values in place, and decrypted
            # secrets must never leak back into the dict the caller holds.
            arguments = _filter_arguments(tool_def.handler, dict(arguments), name)
            if ctx is not None and user_id:
                err = await _resolve_secret_refs(arguments, user_id, ctx)
                if err is not None:
                    return err
            if ctx is not None and _handler_accepts_ctx(tool_def.handler):
                coro = tool_def.handler(**arguments, ctx=ctx)
            else:
                coro = tool_def.handler(**arguments)
            # Long-running tools may be pushed to the task manager unless
            # the tool opted out via ``no_background``.
            if self.task_manager is not None and not tool_def.no_background:
                return await self.task_manager.execute(
                    coro, tool_name=name, user_id=user_id,
                    channel_id=ctx.channel_id if ctx else "",
                    platform=ctx.platform if ctx else "",
                )
            return str(await coro)
        except Exception as exc:
            logger.exception("Tool %r failed", name)
            # Deliberately terse: details go to the logs, not to the LLM.
            return (
                f"Tool '{name}' raised {type(exc).__name__}. "
                "See server logs for details."
            )

    # ------------------------------------------------------------------
    # Schema export & caching
    # ------------------------------------------------------------------

    def invalidate_cache(self) -> None:
        """Clear the cached OpenAI tool representations.

        Called automatically when tools are registered via the ``tool``
        decorator. Must also be called explicitly after bulk mutations
        such as ``_tools.clear()`` (e.g. in ``reload_tools``).
        """
        self._openai_tools_cache = None
        self._openai_tools_by_name = None

    def _ensure_cache(self) -> None:
        """Rebuild the OpenAI-format caches if stale."""
        if self._openai_tools_cache is not None:
            return
        tools = [
            {
                "type": "function",
                "function": {
                    "name": td.name,
                    "description": td.description,
                    "parameters": td.parameters,
                },
            }
            for td in self._tools.values()
        ]
        self._openai_tools_cache = tools
        self._openai_tools_by_name = {
            t["function"]["name"]: t for t in tools
        }

    def get_openai_tools(self) -> list[dict[str, Any]]:
        """Return tool definitions in the OpenAI function-calling JSON format.

        Returns an empty list when no tools are registered, which means the
        ``tools`` parameter can be omitted from the API call.
        """
        self._ensure_cache()
        # Shallow copy so callers can't mutate the cached list itself.
        return list(self._openai_tools_cache)

    def get_openai_tools_by_names(
        self,
        names: set[str],
    ) -> list[dict[str, Any]]:
        """Return only the OpenAI tool dicts whose names are in *names*.

        Uses a cached ``dict`` for O(1) per-name lookup instead of
        rebuilding and filtering the full list each time. Unknown names
        are silently skipped.
        """
        self._ensure_cache()
        by_name = self._openai_tools_by_name
        return [by_name[n] for n in names if n in by_name]

    def list_tools(self) -> list[ToolDefinition]:
        """Return all registered tool definitions."""
        return list(self._tools.values())

    def repeat_allowed_tools(self) -> frozenset[str]:
        """Return names of tools that are exempt from repetition-loop detection."""
        return frozenset(
            td.name for td in self._tools.values() if td.allow_repeat
        )

    @property
    def has_tools(self) -> bool:
        """``True`` when at least one tool is registered."""
        return len(self._tools) > 0

    def __len__(self) -> int:
        """Return the number of registered tools."""
        return len(self._tools)
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
# Cache so we only inspect each handler once.
# NOTE(review): keyed by ``id(handler)`` — id() values may be reused after an
# object is garbage-collected, so entries are only safe while the handlers
# themselves remain referenced; verify all callers keep handlers alive.
_ctx_cache: dict[int, bool] = {}
_sig_cache: dict[int, inspect.Signature | None] = {}
def _filter_arguments(
handler: Callable,
arguments: dict[str, Any],
tool_name: str,
) -> dict[str, Any]:
"""Strip keyword arguments that *handler* does not accept.
If the handler accepts ``**kwargs`` all arguments are passed through.
Otherwise only recognised parameter names are kept, and any dropped
keys are logged as a warning so hallucinated params don't crash the
call.
"""
key = id(handler)
sig = _sig_cache.get(key, ...) # sentinel to distinguish "not cached" from "None"
if sig is ...:
try:
sig = inspect.signature(handler)
except (ValueError, TypeError):
sig = None
_sig_cache[key] = sig
if sig is None:
return arguments # can't introspect — pass everything through
# If the handler has a **kwargs catch-all, everything is fine.
if any(
p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values()
):
return arguments
valid_names = set(sig.parameters.keys()) - {"ctx"}
unexpected = set(arguments.keys()) - valid_names
if not unexpected:
return arguments
logger.warning(
"Tool '%s': dropping unexpected argument(s) %s",
tool_name,
", ".join(sorted(unexpected)),
)
return {k: v for k, v in arguments.items() if k in valid_names}
def _handler_accepts_ctx(handler: Callable) -> bool:
"""Return *True* if *handler* has a ``ctx`` parameter."""
key = id(handler)
cached = _ctx_cache.get(key)
if cached is not None:
return cached
try:
sig = inspect.signature(handler)
result = "ctx" in sig.parameters
except (ValueError, TypeError):
result = False
_ctx_cache[key] = result
return result
async def _resolve_secret_refs(
    arguments: dict[str, Any],
    user_id: str,
    ctx: "ToolContext",
) -> str | None:
    """Replace ``secret:<name>`` string arguments with their stored values.

    Mutates *arguments* in place. Returns an error string for the first
    reference that is malformed or cannot be resolved, or ``None`` when
    every reference (if any) resolved successfully.
    """
    # Local import: only needed when a call actually carries secret refs.
    from tools.manage_secrets import SECRET_REF_PREFIX, resolve_user_secret

    for key, value in list(arguments.items()):
        if not isinstance(value, str) or not value.startswith(SECRET_REF_PREFIX):
            continue
        secret_name = value[len(SECRET_REF_PREFIX):].strip()
        if not secret_name:
            return "Invalid secret reference: secret: must be followed by a name."
        plaintext = await resolve_user_secret(
            user_id,
            secret_name,
            redis_client=ctx.redis,
            config=getattr(ctx, "config", None),
        )
        if plaintext is None:
            return f"Secret '{secret_name}' not found or could not be decrypted."
        arguments[key] = plaintext
    return None