#!/usr/bin/env python3
"""Build the tool index used by the vector classifier.
Auto-discovers every registered tool via :mod:`tool_loader`, then
calls an LLM to generate 50 diverse synthetic user queries per tool
(reverse-HyDE). Results are saved to ``tool_index_data.json`` in
this directory.
Usage::

    python -m classifiers.build_tool_index [--tools-dir tools]
"""
from __future__ import annotations
import asyncio
import argparse
import json
import logging
import os
import re
import sys
from typing import Any
import httpx
# Ensure the project root is importable when running as a script.
sys.path.insert(
0,
os.path.abspath(
os.path.join(os.path.dirname(__file__), ".."),
),
)
from tools import ToolRegistry # noqa: E402
from tool_loader import load_tools # noqa: E402
# Script-style logging: bare "LEVEL: message" lines on stderr.
logging.basicConfig(
    level=logging.INFO,
    format="%(levelname)s: %(message)s",
)
logger = logging.getLogger(__name__)

# The generated index is written next to this module.
OUTPUT_FILE = os.path.join(
    os.path.dirname(__file__), "tool_index_data.json",
)
# Target number of synthetic queries per tool (reverse-HyDE fan-out).
SYNTHETIC_QUERY_COUNT = 50

# Prompt template for reverse-HyDE generation. Placeholders
# {tool_name}, {tool_description} and {count} are filled via
# str.format(); the doubled braces in the trailing JSON example are
# escapes so the literal braces survive formatting.
REVERSE_HYDE_PROMPT = """\
You are an expert prompt engineer building a semantic search \
index for an AI assistant's tool-calling system.
## Tool under analysis
- **Name:** {tool_name}
- **Description:** {tool_description}
## Your task
Generate exactly {count} synthetic user messages that a human \
would write in a chat with an AI assistant when they *need* \
this tool -- even if they don't know the tool exists. The \
messages must be realistic, varied, and collectively span the \
full semantic surface area of the tool.
Distribute the {count} messages across the following \
categories. Exact counts per category are guidelines -- some \
tools may naturally skew toward certain categories -- but \
every category MUST have at least one entry.
### Categories
1. **Direct commands** (~5): Explicit imperative sentences.
"Run ...", "Execute ...", "Do ...", "Set up ..."
2. **Natural-language requests** (~8): Polite or \
conversational requests that clearly imply the tool.
"Could you ...", "I'd like you to ...", "Please ..."
3. **Vague / ambiguous intents** (~8): The user's goal is \
clear but no specific tool is named; this tool is one \
plausible match.
4. **Questions** (~5): "Can you ...?", "Is it possible \
to ...?", "How do I ...?"
5. **Contextual / embedded** (~5): The request is buried \
inside a longer message giving background.
"I was working on X earlier and now I need ..."
6. **Multi-step scenarios** (~5): The user describes a \
workflow where this tool is one step.
"First do A, then B, and finally ..."
7. **Error / frustration** (~4): The user is stuck or \
something broke and they need this tool to fix it.
"X isn't working ...", "I keep getting ...", \
"Help, I can't ..."
8. **Follow-up / conversational** (~5): Short messages \
that only make sense in context.
"Actually, also ...", "And then ...", \
"One more thing -- ..."
9. **Domain-specific jargon** (~5): Technical or domain \
language that implies the tool without naming it.
## Constraints
- Do NOT mention the tool name in more than 5 of the 50 \
messages.
- Focus on *user intent*, not tool internals.
- Vary sentence length (short single-clause to \
multi-sentence).
- Include at least 5 messages that are <= 8 words.
- Return a JSON object with a single key `"queries"` whose \
value is the array of strings. No markdown fences, no \
extra keys, no explanation.
Example (abbreviated):
{{"queries": ["find the file", \
"search for documents matching sales Q3", \
"where is that report I uploaded yesterday"]}}
"""

# Gemini REST endpoint base; the model name and API key are appended
# per request in generate_synthetic_queries().
_GEMINI_GENERATE_BASE = (
    "https://generativelanguage.googleapis.com/v1beta/models"
)
_DEFAULT_GENERATE_MODEL = "gemini-3.1-flash-lite-preview"
def _strip_llm_wrappers(content: str) -> str:
    """Strip non-JSON chrome the model sometimes emits around its payload.

    Removes ``<thinking>...</thinking>`` blocks and markdown code fences
    (```` ```json ```` / ```` ``` ````) so the remainder can be fed to
    :func:`json.loads`. Returns the stripped text (possibly empty).
    """
    content = content.strip()
    content = re.sub(
        r"<thinking>.*?</thinking>\s*",
        "", content, flags=re.DOTALL,
    )
    if content.startswith("```json"):
        content = content[7:]
    if content.startswith("```"):
        content = content[3:]
    if content.endswith("```"):
        content = content[:-3]
    return content.strip()


def _coerce_query_list(data: Any) -> list[str]:
    """Extract a list of query strings from whatever JSON shape came back.

    Accepts the requested ``{"queries": [...]}`` object, any dict whose
    first list-valued entry holds the queries, or a bare JSON array.
    Anything else yields ``[]``. Non-string list items that are ints are
    stringified; other item types are dropped.
    """
    if isinstance(data, dict):
        queries = data.get("queries")
        if not isinstance(queries, list):
            # Model used a different key: take the first list value, if any.
            queries = next(
                (v for v in data.values() if isinstance(v, list)),
                [],
            )
    elif isinstance(data, list):
        queries = data
    else:
        queries = []
    # NOTE(review): bool is a subclass of int, so True/False would pass
    # through as "True"/"False" -- preserved from the original filter.
    return [str(q) for q in queries if isinstance(q, (str, int))]


async def generate_synthetic_queries(
    client: httpx.AsyncClient,
    base_url: str | None,
    api_key: str | None,
    tool_name: str,
    tool_description: str,
    count: int = SYNTHETIC_QUERY_COUNT,
    model: str = _DEFAULT_GENERATE_MODEL,
) -> list[str]:
    """Call the Gemini API directly to produce *count* synthetic queries.

    Uses the shared flash-lite key pool for authentication.
    *base_url* and *api_key* are accepted for backward compatibility
    but ignored -- all calls go through the Gemini API.

    Args:
        client: Shared async HTTP client used for every request.
        base_url: Ignored (legacy parameter).
        api_key: Ignored (legacy parameter).
        tool_name: Registered name of the tool being indexed.
        tool_description: Human-readable description fed to the prompt.
        count: Number of synthetic queries to request.
        model: Gemini model identifier to call.

    Returns:
        The parsed query strings, or ``[]`` when the response has no
        candidates/content, strips down to nothing, or all retry
        attempts are exhausted.
    """
    # Imported lazily so importing this module does not require the key
    # pool to be configured.
    from gemini_embed_pool import (
        is_daily_quota_429,
        mark_key_daily_spent,
        next_gemini_flash_key,
    )

    prompt = REVERSE_HYDE_PROMPT.format(
        tool_name=tool_name,
        tool_description=tool_description,
        count=count,
    )

    def _endpoint(key: str) -> str:
        # Gemini REST auth travels as a ``key`` query parameter.
        return (
            f"{_GEMINI_GENERATE_BASE}/{model}"
            f":generateContent?key={key}"
        )

    gemini_key = next_gemini_flash_key()
    payload = {
        "contents": [{"parts": [{"text": prompt}]}],
        "systemInstruction": {
            "parts": [{"text": "You are a helpful assistant that outputs strict JSON."}],
        },
        "generationConfig": {
            "temperature": 0.8,  # deliberately warm: we want diverse phrasings
            "maxOutputTokens": 16000,
            "responseMimeType": "application/json",
        },
    }

    max_attempts = 20
    for attempt in range(max_attempts):
        if attempt > 0:
            # Rotate to a fresh key on every retry. Flat 1s delay for the
            # first retries, then exponential backoff capped at 30s.
            gemini_key = next_gemini_flash_key()
            delay = 1.0 if attempt <= 13 else min(2.0 ** (attempt - 13), 30.0)
            await asyncio.sleep(delay)
        try:
            resp = await client.post(_endpoint(gemini_key), json=payload)
            if resp.status_code in (429, 500, 502, 503, 504):
                # Transient or quota error: flag daily-exhausted keys so the
                # pool stops handing them out, then rotate and retry.
                if resp.status_code == 429 and is_daily_quota_429(resp):
                    await mark_key_daily_spent(gemini_key, "generate")
                logger.debug(
                    "HTTP %d for %s (key ...%s), rotating key",
                    resp.status_code, tool_name, gemini_key[-8:],
                )
                continue
            resp.raise_for_status()

            resp_data = resp.json()
            candidates = resp_data.get("candidates", [])
            if not candidates:
                logger.warning("No candidates for %s", tool_name)
                return []
            parts = (
                candidates[0]
                .get("content", {})
                .get("parts", [])
            )
            content = parts[0].get("text", "") if parts else ""
            if not content:
                logger.warning("Empty content for %s", tool_name)
                return []

            content = _strip_llm_wrappers(content)
            if not content:
                logger.warning(
                    "No JSON payload after stripping for %s", tool_name,
                )
                return []
            # A json.JSONDecodeError here falls through to the except
            # below and triggers another attempt (malformed LLM output).
            return _coerce_query_list(json.loads(content))
        except Exception as exc:
            logger.error(
                "Error generating queries for %s: %s",
                tool_name, exc,
            )
            if attempt == max_attempts - 1:
                return []
    return []
async def build_index(tools_dir: str = "tools") -> None:
    """Discover tools and generate synthetic queries for each of them.

    Supports resuming: entries already present in ``OUTPUT_FILE`` with a
    full complement of queries are kept; incomplete ones are regenerated.
    The merged index is written back to ``OUTPUT_FILE`` as JSON.

    Args:
        tools_dir: Directory scanned by :func:`tool_loader.load_tools`.
    """
    logger.info(
        "--- Auto-discovering tools from %s ---",
        tools_dir,
    )
    registry = ToolRegistry()
    load_tools(tools_dir, registry)
    all_tools = registry.list_tools()
    logger.info("Found %d tools.", len(all_tools))

    # Resume support: best-effort load of a previous run's output. A
    # corrupt or unreadable file just means a full rebuild -- but log it
    # instead of failing silently (was a bare ``except: pass``).
    index_data: dict[str, Any] = {}
    if os.path.exists(OUTPUT_FILE):
        try:
            with open(
                OUTPUT_FILE, "r", encoding="utf-8",
            ) as f:
                index_data = json.load(f)
            logger.info(
                "Loaded existing data with %d entries.",
                len(index_data),
            )
        except (OSError, json.JSONDecodeError) as exc:
            logger.warning(
                "Could not load %s (%s); rebuilding from scratch.",
                OUTPUT_FILE, exc,
            )
            index_data = {}

    logger.info(
        "--- Generating %d queries per tool ---",
        SYNTHETIC_QUERY_COUNT,
    )
    sem = asyncio.Semaphore(3)  # cap concurrent Gemini calls

    # BUG FIX: the AsyncClient was previously never closed (connection
    # leak); the context manager guarantees cleanup even on errors.
    async with httpx.AsyncClient(
        timeout=httpx.Timeout(120.0, connect=10.0),
    ) as client:

        async def process_tool(tool: Any) -> None:
            """Generate (or skip) the synthetic-query entry for one tool."""
            if tool.name in index_data:
                existing = index_data[tool.name].get(
                    "synthetic_queries", [],
                )
                if len(existing) >= SYNTHETIC_QUERY_COUNT:
                    logger.info(
                        "Skipping %s (%d queries)",
                        tool.name, len(existing),
                    )
                    return
                logger.info(
                    "Regenerating %s (%d queries)",
                    tool.name, len(existing),
                )
            async with sem:
                logger.info("Generating for: %s", tool.name)
                queries = await generate_synthetic_queries(
                    client, None, None,
                    tool.name,
                    tool.description or "",
                )
                if queries:
                    # Single-threaded event loop: this dict write needs
                    # no extra locking.
                    index_data[tool.name] = {
                        "name": tool.name,
                        "description": tool.description,
                        "synthetic_queries": queries,
                    }
                    logger.info(
                        "Finished %s (%d queries)",
                        tool.name, len(queries),
                    )
                else:
                    logger.warning(
                        "Failed to generate for %s",
                        tool.name,
                    )

        await asyncio.gather(*(process_tool(t) for t in all_tools))

    logger.info("--- Saving to %s ---", OUTPUT_FILE)
    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        json.dump(index_data, f, indent=2)
    logger.info("Done! %d tools indexed.", len(index_data))
if __name__ == "__main__":
    # CLI entry point; indentation restored (the pasted source had the
    # guard body flattened to column 0, a SyntaxError).
    parser = argparse.ArgumentParser(
        description=(
            "Build tool index (synthetic queries) "
            "for vector classifier"
        ),
    )
    parser.add_argument(
        "--tools-dir",
        default="tools",
        help="Directory containing tool scripts",
    )
    args = parser.parse_args()
    asyncio.run(build_index(tools_dir=args.tools_dir))