# Source code for classifiers.build_tool_index

#!/usr/bin/env python3
"""Build the tool index used by the vector classifier.

Auto-discovers every registered tool via :mod:`tool_loader`, then
calls an LLM to generate 50 diverse synthetic user queries per tool
(reverse-HyDE).  Results are saved to ``tool_index_data.json`` in
this directory.

Usage::

    python -m classifiers.build_tool_index [--tools-dir tools]
"""

from __future__ import annotations

import asyncio
import argparse
import json
import logging
import os
import re
import sys
from typing import Any

import httpx

# Ensure the project root is importable when running as a script.
sys.path.insert(
    0,
    os.path.abspath(
        os.path.join(os.path.dirname(__file__), ".."),
    ),
)

# These imports must come *after* the sys.path tweak above, hence the
# noqa markers for the "import not at top of file" lint.
from tools import ToolRegistry  # noqa: E402
from tool_loader import load_tools  # noqa: E402

logging.basicConfig(
    level=logging.INFO,
    format="%(levelname)s: %(message)s",
)
logger = logging.getLogger(__name__)

# Output lives next to this module so the classifier can find it.
OUTPUT_FILE = os.path.join(
    os.path.dirname(__file__), "tool_index_data.json",
)

# Number of synthetic queries generated (and required) per tool.
SYNTHETIC_QUERY_COUNT = 50

# Reverse-HyDE prompt: instead of hallucinating a document for a user
# query, we ask the model to hallucinate the *user queries* that would
# match a given tool.  Filled via str.format with {tool_name},
# {tool_description} and {count}; the literal JSON braces in the example
# at the bottom are escaped as {{ }}.
REVERSE_HYDE_PROMPT = """\
You are an expert prompt engineer building a semantic search \
index for an AI assistant's tool-calling system.

## Tool under analysis

- **Name:** {tool_name}
- **Description:** {tool_description}

## Your task

Generate exactly {count} synthetic user messages that a human \
would write in a chat with an AI assistant when they *need* \
this tool -- even if they don't know the tool exists.  The \
messages must be realistic, varied, and collectively span the \
full semantic surface area of the tool.

Distribute the {count} messages across the following \
categories.  Exact counts per category are guidelines -- some \
tools may naturally skew toward certain categories -- but \
every category MUST have at least one entry.

### Categories

1. **Direct commands** (~5): Explicit imperative sentences.
   "Run ...", "Execute ...", "Do ...", "Set up ..."
2. **Natural-language requests** (~8): Polite or \
conversational requests that clearly imply the tool.
   "Could you ...", "I'd like you to ...", "Please ..."
3. **Vague / ambiguous intents** (~8): The user's goal is \
clear but no specific tool is named; this tool is one \
plausible match.
4. **Questions** (~5): "Can you ...?", "Is it possible \
to ...?", "How do I ...?"
5. **Contextual / embedded** (~5): The request is buried \
inside a longer message giving background.
   "I was working on X earlier and now I need ..."
6. **Multi-step scenarios** (~5): The user describes a \
workflow where this tool is one step.
   "First do A, then B, and finally ..."
7. **Error / frustration** (~4): The user is stuck or \
something broke and they need this tool to fix it.
   "X isn't working ...", "I keep getting ...", \
"Help, I can't ..."
8. **Follow-up / conversational** (~5): Short messages \
that only make sense in context.
   "Actually, also ...", "And then ...", \
"One more thing -- ..."
9. **Domain-specific jargon** (~5): Technical or domain \
language that implies the tool without naming it.

## Constraints

- Do NOT mention the tool name in more than 5 of the 50 \
messages.
- Focus on *user intent*, not tool internals.
- Vary sentence length (short single-clause to \
multi-sentence).
- Include at least 5 messages that are <= 8 words.
- Return a JSON object with a single key `"queries"` whose \
value is the array of strings.  No markdown fences, no \
extra keys, no explanation.

Example (abbreviated):
{{"queries": ["find the file", \
"search for documents matching sales Q3", \
"where is that report I uploaded yesterday"]}}
"""


# Base URL of the Gemini generateContent REST endpoint (v1beta).
_GEMINI_GENERATE_BASE = (
    "https://generativelanguage.googleapis.com/v1beta/models"
)
# Default model used for synthetic-query generation.
_DEFAULT_GENERATE_MODEL = "gemini-3.1-flash-lite-preview"


async def generate_synthetic_queries(
    client: httpx.AsyncClient,
    base_url: str | None,
    api_key: str | None,
    tool_name: str,
    tool_description: str,
    count: int = SYNTHETIC_QUERY_COUNT,
    model: str = _DEFAULT_GENERATE_MODEL,
) -> list[str]:
    """Call the Gemini API directly to produce *count* synthetic queries.

    Uses the shared flash-lite key pool for authentication.  *base_url*
    and *api_key* are accepted for backward compatibility but ignored --
    all calls go through the Gemini API.

    Args:
        client: Shared async HTTP client used for every attempt.
        base_url: Ignored (kept for backward compatibility).
        api_key: Ignored (kept for backward compatibility).
        tool_name: Name of the tool being indexed.
        tool_description: Human-readable tool description.
        count: Number of synthetic queries requested from the model.
        model: Gemini model identifier to call.

    Returns:
        List of query strings, or an empty list if every attempt fails
        or the response contains no usable content.
    """
    # Project-local import, deferred so importing this module does not
    # require the key pool to be configured.
    from gemini_embed_pool import (
        is_daily_quota_429,
        mark_key_daily_spent,
        next_gemini_flash_key,
    )

    prompt = REVERSE_HYDE_PROMPT.format(
        tool_name=tool_name,
        tool_description=tool_description,
        count=count,
    )
    gemini_key = next_gemini_flash_key()
    url = (
        f"{_GEMINI_GENERATE_BASE}/{model}"
        f":generateContent?key={gemini_key}"
    )
    payload = {
        "contents": [{"parts": [{"text": prompt}]}],
        "systemInstruction": {
            "parts": [{"text": "You are a helpful assistant that outputs strict JSON."}],
        },
        "generationConfig": {
            "temperature": 0.8,
            "maxOutputTokens": 16000,
            "responseMimeType": "application/json",
        },
    }

    max_attempts = 20
    for attempt in range(max_attempts):
        if attempt > 0:
            # Rotate to a fresh key on every retry; the URL embeds the key.
            gemini_key = next_gemini_flash_key()
            url = (
                f"{_GEMINI_GENERATE_BASE}/{model}"
                f":generateContent?key={gemini_key}"
            )
            # Flat 1 s backoff for the first 13 retries, then
            # exponential up to 30 s.
            delay = 1.0 if attempt <= 13 else min(2.0 ** (attempt - 13), 30.0)
            await asyncio.sleep(delay)
        try:
            resp = await client.post(url, json=payload)
            if resp.status_code in (429, 500, 502, 503, 504):
                # Retryable status: mark daily-exhausted keys so the
                # pool stops handing them out, then rotate.
                if resp.status_code == 429 and is_daily_quota_429(resp):
                    await mark_key_daily_spent(gemini_key, "generate")
                logger.debug(
                    "HTTP %d for %s (key ...%s), rotating key",
                    resp.status_code,
                    tool_name,
                    gemini_key[-8:],
                )
                continue
            resp.raise_for_status()

            resp_data = resp.json()
            candidates = resp_data.get("candidates", [])
            if not candidates:
                logger.warning("No candidates for %s", tool_name)
                return []
            parts = (
                candidates[0]
                .get("content", {})
                .get("parts", [])
            )
            # BUG FIX: previously only parts[0] was read; Gemini can
            # split output across several parts, so join them all.
            content = "".join(p.get("text", "") for p in parts)
            if not content:
                logger.warning("Empty content for %s", tool_name)
                return []

            content = content.strip()
            # Strip any inline thinking block the model may emit.
            content = re.sub(
                r"<thinking>.*?</thinking>\s*",
                "",
                content,
                flags=re.DOTALL,
            )
            # Strip markdown code fences despite the JSON mime type.
            if content.startswith("```json"):
                content = content[7:]
            if content.startswith("```"):
                content = content[3:]
            if content.endswith("```"):
                content = content[:-3]
            content = content.strip()
            if not content:
                logger.warning(
                    "No JSON payload after stripping for %s",
                    tool_name,
                )
                return []

            data = json.loads(content)
            if isinstance(data, dict):
                queries = data.get("queries")
                if not isinstance(queries, list):
                    # Tolerate a misnamed key: take the first list value.
                    for v in data.values():
                        if isinstance(v, list):
                            queries = v
                            break
                    else:
                        queries = []
            elif isinstance(data, list):
                queries = data
            else:
                queries = []
            # Coerce numbers to strings, drop anything else.
            return [
                str(q) for q in queries
                if isinstance(q, (str, int))
            ]
        except Exception as exc:
            logger.error(
                "Error generating queries for %s: %s",
                tool_name,
                exc,
            )
            if attempt == max_attempts - 1:
                return []
    return []
async def build_index(tools_dir: str = "tools") -> None:
    """Discover tools and generate synthetic queries.

    Loads any existing ``tool_index_data.json`` so tools that already
    have a full query set are skipped (resumable builds), fans out
    generation with bounded concurrency, and writes the merged result
    back to :data:`OUTPUT_FILE`.

    Args:
        tools_dir: Directory containing the tool scripts to discover.
    """
    logger.info(
        "--- Auto-discovering tools from %s ---", tools_dir,
    )
    registry = ToolRegistry()
    load_tools(tools_dir, registry)
    all_tools = registry.list_tools()
    logger.info("Found %d tools.", len(all_tools))

    # Resume support: merge into any previously saved index.
    index_data: dict[str, Any] = {}
    if os.path.exists(OUTPUT_FILE):
        try:
            with open(
                OUTPUT_FILE, "r", encoding="utf-8",
            ) as f:
                index_data = json.load(f)
            logger.info(
                "Loaded existing data with %d entries.",
                len(index_data),
            )
        except Exception:
            # Best-effort resume: a corrupt/unreadable file just means
            # we rebuild from scratch -- but say so instead of hiding it.
            logger.warning(
                "Could not read existing %s; rebuilding from scratch.",
                OUTPUT_FILE,
            )

    logger.info(
        "--- Generating %d queries per tool ---",
        SYNTHETIC_QUERY_COUNT,
    )
    # Cap concurrent LLM calls to 3 to stay within rate limits.
    sem = asyncio.Semaphore(3)

    # BUG FIX: the client was previously created but never closed,
    # leaking connections.  The context manager guarantees cleanup
    # even if a task raises.
    async with httpx.AsyncClient(
        timeout=httpx.Timeout(120.0, connect=10.0),
    ) as client:

        async def process_tool(tool: Any) -> None:
            """Generate (or skip) synthetic queries for one tool."""
            if tool.name in index_data:
                existing = index_data[tool.name].get(
                    "synthetic_queries", [],
                )
                if len(existing) >= SYNTHETIC_QUERY_COUNT:
                    logger.info(
                        "Skipping %s (%d queries)",
                        tool.name,
                        len(existing),
                    )
                    return
                logger.info(
                    "Regenerating %s (%d queries)",
                    tool.name,
                    len(existing),
                )
            async with sem:
                logger.info("Generating for: %s", tool.name)
                queries = await generate_synthetic_queries(
                    client,
                    None,
                    None,
                    tool.name,
                    tool.description or "",
                )
            if queries:
                index_data[tool.name] = {
                    "name": tool.name,
                    "description": tool.description,
                    "synthetic_queries": queries,
                }
                logger.info(
                    "Finished %s (%d queries)",
                    tool.name,
                    len(queries),
                )
            else:
                logger.warning(
                    "Failed to generate for %s", tool.name,
                )

        tasks = [process_tool(t) for t in all_tools]
        await asyncio.gather(*tasks)

    logger.info("--- Saving to %s ---", OUTPUT_FILE)
    with open(OUTPUT_FILE, "w", encoding="utf-8") as f:
        json.dump(index_data, f, indent=2)
    logger.info("Done! %d tools indexed.", len(index_data))
if __name__ == "__main__":
    # CLI entry point: parse the tools directory and run the builder.
    arg_parser = argparse.ArgumentParser(
        description=(
            "Build tool index (synthetic queries) "
            "for vector classifier"
        ),
    )
    arg_parser.add_argument(
        "--tools-dir",
        default="tools",
        help="Directory containing tool scripts",
    )
    cli_args = arg_parser.parse_args()
    asyncio.run(build_index(tools_dir=cli_args.tools_dir))