# Source code for classifiers.refresh_tool_embeddings

#!/usr/bin/env python3
"""Refresh embeddings for tools whose descriptions have changed.

Compares each registered tool's live description against the stored
metadata in Redis.  Tools with mismatched descriptions are "stale"
and get their synthetic queries regenerated, re-embedded, and
overwritten in Redis.

Run ``python -m classifiers.refresh_tool_embeddings`` (see ``--help``) with
optional arguments:

- ``--force`` — Re-embed every tool regardless of whether its description
  changed (useful after switching embedding models).
- ``--tools`` — Comma-separated list of specific tool names to refresh.
- ``--tools-dir`` — Tool scripts directory (default: ``tools``).
"""

from __future__ import annotations

import argparse
import asyncio
import json
import logging
import os
import sys
from typing import Any

import numpy as np
import redis.asyncio as aioredis
import httpx

sys.path.insert(
    0,
    os.path.abspath(
        os.path.join(os.path.dirname(__file__), ".."),
    ),
)

from tools import ToolRegistry  # noqa: E402
from tool_loader import load_tools  # noqa: E402
from rag_system.openrouter_embeddings import (  # noqa: E402
    OpenRouterEmbeddings,
)
from classifiers.vector_classifier import (  # noqa: E402
    TOOL_EMBEDDINGS_HASH_KEY,
    TOOL_METADATA_HASH_KEY,
)
from classifiers.build_tool_index import (  # noqa: E402
    generate_synthetic_queries,
)
from classifiers.update_tool_embeddings import (  # noqa: E402
    compute_tool_embedding,
    load_index_file,
    save_index_file,
    discover_tools,
)

# Configure root logging for command-line use: INFO and above, printed as
# "<LEVEL>: <message>".
logging.basicConfig(
    format="%(levelname)s: %(message)s",
    level=logging.INFO,
)

# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)


async def find_stale_tools(
    registered: dict[str, Any],
    redis_client: aioredis.Redis,
) -> list[str]:
    """Return names of tools whose descriptions differ from Redis.

    Reads the whole ``TOOL_METADATA_HASH_KEY`` hash once and compares each
    registered tool's live ``description`` attribute with the
    ``"description"`` field of its stored JSON metadata.

    Args:
        registered: Mapping of tool name to tool definition.  Only the
            ``description`` attribute of each definition is read; a
            missing or falsy description is treated as ``""``.
        redis_client: Async Redis client used for the single ``HGETALL``.

    Returns:
        Names of stale tools, in the iteration order of ``registered``.
        A tool is stale when its stored metadata cannot be parsed as
        JSON, is not a JSON object, or carries a different description.
        Tools with no stored metadata entry at all are skipped, not
        reported as stale.
    """
    meta_raw = await redis_client.hgetall(TOOL_METADATA_HASH_KEY)

    # A client created without ``decode_responses`` returns bytes keys;
    # normalise them (as the ``--force`` path does for HKEYS) so lookups
    # by tool name still match.
    stored_by_name = {
        (key.decode("utf-8") if isinstance(key, bytes) else key): value
        for key, value in meta_raw.items()
    }

    stale: list[str] = []
    for name, tool_def in registered.items():
        stored = stored_by_name.get(name)
        if stored is None:
            # Never indexed: nothing to compare against.
            continue

        try:
            meta = json.loads(stored)
        except (ValueError, TypeError):
            # ValueError covers both JSONDecodeError and the
            # UnicodeDecodeError raised for undecodable bytes payloads.
            stale.append(name)
            continue

        if not isinstance(meta, dict):
            # Valid JSON but not an object (list, string, null, ...):
            # there is no description to trust, so rebuild the entry
            # instead of crashing on ``meta.get``.
            stale.append(name)
            continue

        live_desc = getattr(tool_def, "description", "") or ""
        stored_desc = meta.get("description", "")
        if live_desc != stored_desc:
            stale.append(name)

    return stale
async def _select_refresh_targets(
    registered: dict[str, Any],
    redis_client: aioredis.Redis,
    *,
    force: bool,
    tool_names: list[str] | None,
) -> list[str] | None:
    """Choose which tools to refresh; ``None`` means no valid explicit name."""
    if tool_names:
        requested = set(tool_names)
        known = set(registered)
        unknown = requested - known
        if unknown:
            logger.warning(
                "Unknown tools (skipped): %s",
                ", ".join(sorted(unknown)),
            )
        targets = requested & known
        if not targets:
            logger.error("No valid tool names to refresh")
            return None
        return sorted(targets)

    if force:
        # Only tools that already have an embedding in Redis are re-embedded.
        existing_keys = await redis_client.hkeys(TOOL_EMBEDDINGS_HASH_KEY)
        existing_names = {
            k.decode("utf-8") if isinstance(k, bytes) else k
            for k in existing_keys
        }
        return sorted(set(registered) & existing_names)

    logger.info("Checking for stale descriptions...")
    return sorted(await find_stale_tools(registered, redis_client))


async def _regenerate_queries(
    registered: dict[str, Any],
    targets: list[str],
    index_data: dict[str, Any],
    *,
    base_url: str,
    api_key: str,
) -> None:
    """Regenerate synthetic queries for *targets*, updating *index_data* in place."""
    # At most three generation requests are in flight at once.
    sem = asyncio.Semaphore(3)

    # ``async with`` guarantees the HTTP client is closed even when a
    # generation task raises.
    async with httpx.AsyncClient(
        timeout=httpx.Timeout(120.0, connect=10.0),
    ) as http_client:

        async def gen(tool_name: str) -> None:
            """Generate queries for one tool and record them on success."""
            async with sem:
                tool = registered[tool_name]
                desc = getattr(tool, "description", "") or ""
                logger.info(" Generating for: %s", tool_name)
                qs = await generate_synthetic_queries(
                    http_client, base_url, api_key, tool_name, desc,
                )
                if qs:
                    index_data[tool_name] = {
                        "name": tool_name,
                        "description": desc,
                        "synthetic_queries": qs,
                    }
                    logger.info(" %s: %d queries", tool_name, len(qs))
                else:
                    logger.warning(" %s: query gen failed", tool_name)

        logger.info("Generating synthetic queries...")
        await asyncio.gather(*(gen(n) for n in targets))


async def _embed_targets(
    registered: dict[str, Any],
    targets: list[str],
    index_data: dict[str, Any],
    *,
    api_key: str,
) -> tuple[dict[str, str], dict[str, str]]:
    """Compute embeddings for *targets*; return JSON-encoded (embeddings, metadata)."""
    embs_store: dict[str, str] = {}
    meta_store: dict[str, str] = {}

    emb_client = OpenRouterEmbeddings(api_key=api_key)
    try:
        for tool_name in targets:
            # NOTE(review): if query generation failed for this tool, an
            # older index-file entry (with its old description) may be
            # reused here -- confirm that is the intended fallback.
            info = index_data.get(tool_name, {})
            qs = info.get("synthetic_queries", [])
            desc = info.get("description", "")
            if not qs:
                # No queries available: embed the live description, or a
                # generic phrase when the tool has no description either.
                tool = registered[tool_name]
                desc = getattr(tool, "description", "") or ""
                logger.warning(" %s: no queries, using desc", tool_name)
                qs = [desc] if desc else [f"use {tool_name}"]

            centroid = await compute_tool_embedding(
                emb_client, qs, tool_name,
            )
            if centroid is None:
                logger.warning(" %s: failed to compute", tool_name)
                continue

            embs_store[tool_name] = json.dumps(centroid.tolist())
            meta_store[tool_name] = json.dumps({
                "name": tool_name,
                "description": desc,
                "query_count": len(qs),
            })
            logger.info(
                " %s: computed from %d queries", tool_name, len(qs),
            )
    finally:
        # Close the embeddings client even when an embedding call raises.
        await emb_client.close()

    return embs_store, meta_store


async def refresh_tool_embeddings(
    *,
    force: bool = False,
    tool_names: list[str] | None = None,
    tools_dir: str = "tools",
) -> bool:
    """Regenerate queries and embeddings for selected tools and store them.

    Target selection, in priority order:

    1. ``tool_names`` given: those names that are registered (unknown
       names are logged and skipped).
    2. ``force``: every registered tool that already has an embedding
       in Redis.
    3. Otherwise: the tools reported by :func:`find_stale_tools`.

    For each target the synthetic queries are regenerated and saved to
    the index file, an embedding is computed, and the embedding plus
    metadata are written to the Redis hashes ``TOOL_EMBEDDINGS_HASH_KEY``
    and ``TOOL_METADATA_HASH_KEY``.

    Args:
        force: Re-embed all already-indexed tools regardless of changes.
        tool_names: Explicit tool names to refresh; takes precedence
            over ``force``.
        tools_dir: Directory passed to ``discover_tools``.

    Returns:
        ``False`` when no API key is configured or none of the explicit
        ``tool_names`` is registered; ``True`` otherwise, including when
        nothing needed refreshing or no embedding could be computed.
    """
    logger.info("=" * 60)
    logger.info("Tool Embeddings Refresh")
    logger.info("=" * 60)

    # Project configuration is imported here rather than at module level.
    from config import Config

    cfg = Config.load()
    api_key = cfg.api_key
    if not api_key:
        logger.error("api_key required (config.yaml or API_KEY env var)")
        return False

    redis_url = cfg.redis_url or os.environ.get(
        "REDIS_URL", "redis://localhost:6379/0",
    )
    redis_client = aioredis.from_url(
        redis_url,
        decode_responses=True,
        **cfg.redis_ssl_kwargs(),
    )

    try:
        logger.info("Discovering registered tools...")
        registered = discover_tools(tools_dir)
        logger.info(" Found %d registered tools", len(registered))

        targets_list = await _select_refresh_targets(
            registered,
            redis_client,
            force=force,
            tool_names=tool_names,
        )
        if targets_list is None:
            return False
        if not targets_list:
            logger.info("No tools need refreshing!")
            return True

        logger.info("Refreshing %d tool(s):", len(targets_list))
        for name in targets_list:
            logger.info(" - %s", name)

        index_data = load_index_file()
        await _regenerate_queries(
            registered,
            targets_list,
            index_data,
            base_url=cfg.llm_base_url,
            api_key=api_key,
        )
        save_index_file(index_data)

        logger.info("Computing embeddings...")
        embs_store, meta_store = await _embed_targets(
            registered,
            targets_list,
            index_data,
            api_key=api_key,
        )

        if embs_store:
            logger.info(
                "Storing %d refreshed embeddings...", len(embs_store),
            )
            await redis_client.hset(
                TOOL_EMBEDDINGS_HASH_KEY, mapping=embs_store,
            )
            await redis_client.hset(
                TOOL_METADATA_HASH_KEY, mapping=meta_store,
            )
            logger.info(
                "Updated %d tool embeddings in Redis", len(embs_store),
            )
        else:
            logger.warning("No embeddings to store")

        logger.info("=" * 60)
        logger.info("Refresh complete!")
        logger.info("=" * 60)
        return True
    finally:
        await redis_client.aclose()
async def main() -> None:
    """Parse command-line options, run the refresh, and exit with its status.

    Exits the process with status 0 when ``refresh_tool_embeddings``
    returns a truthy result, and with status 1 when it returns a falsy
    result or raises an exception (which is logged with its traceback).
    """
    cli = argparse.ArgumentParser(
        description=(
            "Refresh embeddings for tools whose descriptions have changed"
        ),
    )
    cli.add_argument(
        "--force",
        "-f",
        action="store_true",
        help="Re-embed all tools regardless of description changes",
    )
    cli.add_argument(
        "--tools",
        "-t",
        default="",
        help="Comma-separated tool names to refresh",
    )
    cli.add_argument(
        "--tools-dir",
        default="tools",
        help="Tool scripts directory (default: tools)",
    )
    options = cli.parse_args()

    # An empty --tools value means "no explicit selection".
    selected = None
    if options.tools:
        selected = [
            part
            for part in map(str.strip, options.tools.split(","))
            if part
        ]

    try:
        succeeded = await refresh_tool_embeddings(
            force=options.force,
            tool_names=selected,
            tools_dir=options.tools_dir,
        )
    except Exception as exc:
        # Top-level boundary: log with traceback, then signal failure.
        logger.exception("Refresh failed: %s", exc)
        sys.exit(1)

    sys.exit(0 if succeeded else 1)
if __name__ == "__main__":
    # Executed as a script or via ``python -m``: drive the async CLI.
    asyncio.run(main())