# Source code for classifiers.update_tool_embeddings

#!/usr/bin/env python3
"""Incrementally update tool embeddings in Redis.

Discovers all registered tools, compares against what already
exists in Redis, generates synthetic queries for missing tools
via the Gemini API (using the shared embedding key pool), and
adds only the new embeddings -- without touching existing ones.

Usage::

    python -m classifiers.update_tool_embeddings \
        [--force-index] [--tools-dir tools]

Environment variables:
    REDIS_URL           -- defaults to redis://localhost:6379/0
"""

from __future__ import annotations

import argparse
import asyncio
import json
import logging
import os
import sys
from typing import Any

import numpy as np
import redis.asyncio as aioredis
import httpx

sys.path.insert(
    0,
    os.path.abspath(
        os.path.join(os.path.dirname(__file__), ".."),
    ),
)

from tools import ToolRegistry  # noqa: E402
from tool_loader import load_tools  # noqa: E402
from rag_system.openrouter_embeddings import (  # noqa: E402
    OpenRouterEmbeddings,
)
from classifiers.vector_classifier import (  # noqa: E402
    TOOL_EMBEDDINGS_HASH_KEY,
    TOOL_METADATA_HASH_KEY,
)
from classifiers.build_tool_index import (  # noqa: E402
    generate_synthetic_queries,
    SYNTHETIC_QUERY_COUNT,
)

# Root-logger configuration for CLI runs of this module.
logging.basicConfig(
    level=logging.INFO,
    format=(
        "%(asctime)s - %(name)s - "
        "%(levelname)s - %(message)s"
    ),
)
logger = logging.getLogger(__name__)

# On-disk cache of generated synthetic queries, stored next to this module.
INDEX_FILE = os.path.join(
    os.path.dirname(__file__), "tool_index_data.json",
)


async def get_existing_redis_tools(
    redis_client: aioredis.Redis,
) -> set[str]:
    """Return the set of tool names that already have embeddings in Redis.

    Best-effort: any Redis failure is logged and an empty set is
    returned, so callers simply treat every tool as missing.
    """
    try:
        raw_keys = await redis_client.hkeys(TOOL_EMBEDDINGS_HASH_KEY)
        names: set[str] = set()
        for key in raw_keys:
            # Redis may return bytes or str depending on decode_responses.
            names.add(key.decode("utf-8") if isinstance(key, bytes) else key)
        return names
    except Exception as exc:
        logger.error("Failed to fetch existing tools: %s", exc)
        return set()
def discover_tools(
    tools_dir: str = "tools",
) -> dict[str, Any]:
    """Auto-discover all registered tools, keyed by tool name."""
    registry = ToolRegistry()
    load_tools(tools_dir, registry)
    discovered: dict[str, Any] = {}
    for definition in registry.list_tools():
        discovered[definition.name] = definition
    return discovered
def load_index_file() -> dict[str, Any]:
    """Load the synthetic-query index file.

    Returns:
        dict[str, Any]: Parsed index contents, or ``{}`` when the file
        is absent or unreadable (failure is logged, never raised).
    """
    if not os.path.exists(INDEX_FILE):
        return {}
    try:
        with open(INDEX_FILE, "r", encoding="utf-8") as fh:
            return json.load(fh)
    except Exception as exc:
        logger.warning("Failed to load index file: %s", exc)
        return {}
def save_index_file(data: dict[str, Any]) -> None:
    """Persist *data* to the index file as pretty-printed JSON.

    Args:
        data (dict[str, Any]): Index payload keyed by tool name.
    """
    with open(INDEX_FILE, "w", encoding="utf-8") as fh:
        json.dump(data, fh, indent=2)
    logger.info("Saved index file with %d tools", len(data))
async def compute_tool_embedding(
    embedding_client: OpenRouterEmbeddings,
    synthetic_queries: list[str],
    tool_name: str,
) -> np.ndarray | None:
    """Embed the queries and return their L2-normalised centroid.

    Returns ``None`` when embedding fails or yields no non-empty
    vectors (failures are logged, not raised).
    """
    try:
        vectors = await embedding_client.embed_texts(synthetic_queries)
        # Drop empty vectors the embedding backend may return.
        embs = [vec for vec in vectors if vec.size > 0]
    except Exception as exc:
        logger.warning("Failed to embed queries for %r: %s", tool_name, exc)
        return None
    if not embs:
        return None
    centroid = np.mean(embs, axis=0)
    norm = np.linalg.norm(centroid)
    # Normalise only when the centroid is non-zero.
    return centroid / norm if norm > 0 else centroid
async def update_tool_embeddings(
    force_index: bool = False,
    tools_dir: str = "tools",
) -> bool:
    """Main incremental update routine.

    Discovers registered tools, prunes orphaned Redis/index entries,
    generates synthetic queries for tools that lack them, computes
    centroid embeddings, and stores only the new embeddings in Redis.

    Args:
        force_index: Regenerate synthetic queries even when cached.
        tools_dir: Directory scanned for tool scripts.

    Returns:
        True on success (including when no tools were missing).
    """
    logger.info("=" * 60)
    logger.info("Incremental Tool Embeddings Update")
    logger.info("=" * 60)

    # Imported lazily so module import stays cheap and side-effect free.
    from config import Config
    cfg = Config.load()
    redis_url = cfg.redis_url or os.environ.get(
        "REDIS_URL", "redis://localhost:6379/0",
    )
    redis_client = aioredis.from_url(
        redis_url,
        decode_responses=True,
        **cfg.redis_ssl_kwargs(),
    )
    try:
        # 1. Discover tools
        logger.info("Discovering registered tools...")
        registered = discover_tools(tools_dir)
        logger.info(" Found %d registered tools", len(registered))

        # 2. Check Redis
        logger.info("Checking existing embeddings...")
        existing = await get_existing_redis_tools(redis_client)
        logger.info(" Found %d with embeddings", len(existing))

        # 3. Identify missing and orphaned
        registered_names = set(registered.keys())
        missing = registered_names - existing
        orphaned = existing - registered_names

        # 3a. Prune orphaned embeddings (tools removed from codebase)
        if orphaned:
            logger.info("Pruning %d orphaned embeddings:", len(orphaned))
            for name in sorted(orphaned):
                logger.info(" - %s", name)
            await redis_client.hdel(TOOL_EMBEDDINGS_HASH_KEY, *orphaned)
            await redis_client.hdel(TOOL_METADATA_HASH_KEY, *orphaned)
            logger.info(
                "Removed %d orphaned entries from Redis", len(orphaned),
            )

        # 3b. Load index file and prune orphans from it too
        logger.info("Loading tool index file...")
        index_data = load_index_file()
        pruned_from_index = 0
        for name in list(index_data.keys()):
            if name not in registered_names:
                del index_data[name]
                pruned_from_index += 1
        if pruned_from_index:
            save_index_file(index_data)
            logger.info(
                "Pruned %d orphaned entries from index file",
                pruned_from_index,
            )

        if not missing:
            logger.info("All tools have embeddings!")
            return True

        logger.info("Found %d tools without embeddings:", len(missing))
        for name in sorted(missing):
            logger.info(" - %s", name)

        # 4. Decide which missing tools need (re)generated queries.
        needing_queries: list[str] = []
        for name in missing:
            qs = index_data.get(name, {}).get("synthetic_queries", [])
            if not qs or len(qs) < SYNTHETIC_QUERY_COUNT or force_index:
                needing_queries.append(name)

        # 5. Generate synthetic queries
        if needing_queries:
            logger.info(
                "Generating queries for %d tools...", len(needing_queries),
            )
            # FIX: context-manage the HTTP client so its connection pool
            # is always released (previously it was never closed).
            async with httpx.AsyncClient(
                timeout=httpx.Timeout(120.0, connect=10.0),
            ) as http_client:
                # Cap concurrent generation requests at 3.
                sem = asyncio.Semaphore(3)

                async def gen(tool_name: str) -> None:
                    async with sem:
                        tool = registered[tool_name]
                        desc = getattr(tool, "description", "") or ""
                        logger.info(" Generating for: %s", tool_name)
                        qs = await generate_synthetic_queries(
                            http_client, None, None, tool_name, desc,
                        )
                        if qs:
                            index_data[tool_name] = {
                                "name": tool_name,
                                "description": desc,
                                "synthetic_queries": qs,
                            }
                            logger.info(
                                " %s: %d queries", tool_name, len(qs),
                            )
                        else:
                            logger.warning(" %s: failed", tool_name)

                await asyncio.gather(*(gen(n) for n in needing_queries))
            save_index_file(index_data)

        # 6. Compute embeddings (Gemini API via shared key pool)
        logger.info("Computing embeddings...")
        emb_client = OpenRouterEmbeddings()
        embs_store: dict[str, str] = {}
        meta_store: dict[str, str] = {}
        try:
            for tool_name in missing:
                info = index_data.get(tool_name, {})
                qs = info.get("synthetic_queries", [])
                desc = info.get("description", "")
                if not qs:
                    # Fall back to the description (or the bare tool
                    # name) so every missing tool gets some embedding.
                    logger.warning(
                        " %s: no queries, using desc", tool_name,
                    )
                    qs = [desc] if desc else [f"use {tool_name}"]
                centroid = await compute_tool_embedding(
                    emb_client, qs, tool_name,
                )
                if centroid is not None:
                    embs_store[tool_name] = json.dumps(centroid.tolist())
                    meta_store[tool_name] = json.dumps({
                        "name": tool_name,
                        "description": desc,
                        "query_count": len(qs),
                    })
                    logger.info(
                        " %s: computed from %d queries",
                        tool_name,
                        len(qs),
                    )
                else:
                    logger.warning(" %s: failed to compute", tool_name)
        finally:
            # FIX: close the embedding client even when a computation
            # raises (previously only closed on the success path).
            await emb_client.close()

        # 7. Store in Redis (additive)
        if embs_store:
            logger.info("Storing %d new embeddings...", len(embs_store))
            await redis_client.hset(
                TOOL_EMBEDDINGS_HASH_KEY, mapping=embs_store,
            )
            await redis_client.hset(
                TOOL_METADATA_HASH_KEY, mapping=meta_store,
            )
            logger.info(
                "Added %d tool embeddings to Redis", len(embs_store),
            )
        else:
            logger.warning("No new embeddings to store")

        logger.info("=" * 60)
        logger.info("Incremental update complete!")
        logger.info("=" * 60)
        return True
    finally:
        await redis_client.aclose()
async def main() -> None:
    """CLI entry point: parse arguments and run the incremental update."""
    parser = argparse.ArgumentParser(
        description=(
            "Update tool embeddings incrementally "
            "(only add missing tools)"
        ),
    )
    parser.add_argument(
        "--force-index",
        "-f",
        action="store_true",
        help="Force regeneration of synthetic queries",
    )
    parser.add_argument(
        "--tools-dir",
        default="tools",
        help="Tool scripts directory (default: tools)",
    )
    args = parser.parse_args()

    try:
        ok = await update_tool_embeddings(
            force_index=args.force_index,
            tools_dir=args.tools_dir,
        )
    except Exception as exc:
        logger.error("Update failed: %s", exc, exc_info=True)
        sys.exit(1)
    # SystemExit is not an Exception, so this exit is never re-caught.
    sys.exit(0 if ok else 1)
# Script entry point: run the async updater under a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())