Source code for classifiers.update_tool_embeddings

#!/usr/bin/env python3
"""Incrementally update tool embeddings in Redis.

Discovers all registered tools, compares against what already
exists in Redis, generates synthetic queries for missing tools
via the Gemini API (using the shared embedding key pool), and
adds only the new embeddings -- without touching existing ones.

Usage::

    python -m classifiers.update_tool_embeddings \
        [--force-index] [--force-all] [--openrouter-only] [--tools-dir tools]

``--openrouter-only`` skips Gemini for both synthetic query generation (OpenRouter
chat: ``google/gemini-3.1-flash-lite``) and embeddings (OpenRouter
``/embeddings``: ``google/gemini-embedding-001`` via :func:`gemini_embed_pool.set_openrouter_only`).
Requires ``OPENROUTER_QUERY_GEN_API_KEY``, ``OPENROUTER_API_KEY``, or ``API_KEY``.
Raises default ``TOOL_EMBED_OR_MAX_CONCURRENT`` to 32 when unset (heavy parallel
batches; override with env).

``--force-all`` regenerates synthetic queries and embeddings for **every**
registered tool, not only keys missing from Redis. It implies ``--force-index``,
so queries are regenerated even when the index file already holds enough of them.

To **regenerate synthetic queries and embeddings for specific tools only**
(e.g. after fixing query-gen for one tool), use::

    python -m classifiers.refresh_tool_embeddings --tools vmware_control

Comma-separate multiple names. This updates ``tool_index_data.json`` and Redis
for those tools only; other tools are unchanged.

Environment variables:
    REDIS_URL                  -- defaults to redis://localhost:6379/0
    TOOL_EMBED_OR_MAX_CONCURRENT -- max concurrent embedding HTTP batches in
                                  OpenRouter-only mode (default 32 if unset)
"""

from __future__ import annotations

import argparse
import asyncio
import jsonutil as json
import logging
import os
import sys
from typing import Any

import numpy as np
import redis.asyncio as aioredis
import httpx

# Prepend the parent of this file's directory to sys.path. Presumably this lets
# the project-level imports below (hence their ``noqa: E402`` markers) resolve
# when the module is executed directly -- confirm against the repo layout.
sys.path.insert(
    0,
    os.path.abspath(
        os.path.join(os.path.dirname(__file__), ".."),
    ),
)

from gemini_embed_pool import (  # noqa: E402
    clear_openrouter_only,
    init_quota_tracking,
    set_openrouter_only,
)

from tools import ToolRegistry  # noqa: E402
from tool_loader import load_tools  # noqa: E402
from rag_system.openrouter_embeddings import (  # noqa: E402
    OpenRouterEmbeddings,
)
from classifiers.vector_classifier import (  # noqa: E402
    TOOL_EMBEDDINGS_HASH_KEY,
    TOOL_METADATA_HASH_KEY,
)
from classifiers.redis_vector_index import (  # noqa: E402
    delete_tool_embedding_hash,
    store_tool_embedding_hash,
)
from classifiers.build_tool_index import (  # noqa: E402
    SYNTHETIC_QUERY_COUNT,
    _openrouter_query_gen_api_key,
    generate_synthetic_queries,
)
from classifiers.tool_embedding_batch import (  # noqa: E402
    compute_tool_centroids_bulk,
)

# Configure logging at import time: INFO level with timestamp, logger name and
# level in every record.
logging.basicConfig(
    level=logging.INFO,
    format=("%(asctime)s - %(name)s - " "%(levelname)s - %(message)s"),
)
logger = logging.getLogger(__name__)

# Path of the on-disk index file, located in the same directory as this module.
# It is read by load_index_file() and written by save_index_file().
INDEX_FILE = os.path.join(
    os.path.dirname(__file__),
    "tool_index_data.json",
)


async def get_existing_redis_tools(
    redis_client: aioredis.Redis,
) -> set[str]:
    """Return the tool names that already have an entry in the embeddings hash.

    Reads every field name of the ``TOOL_EMBEDDINGS_HASH_KEY`` hash with a
    single ``HKEYS`` call. Field names that arrive as ``bytes`` are decoded
    as UTF-8; anything else is kept as-is.

    Any failure (the Redis call itself or the decoding) is logged and turned
    into an empty set, so the caller sees "nothing stored yet" instead of an
    exception.

    Args:
        redis_client (aioredis.Redis): Async Redis connection to read from.

    Returns:
        set[str]: Names found in the embeddings hash, or an empty set when
        the read failed.
    """
    found: set[str] = set()
    try:
        raw_fields = await redis_client.hkeys(TOOL_EMBEDDINGS_HASH_KEY)
        for field in raw_fields:
            if isinstance(field, bytes):
                field = field.decode("utf-8")
            found.add(field)
    except Exception as exc:
        logger.error("Failed to fetch existing tools: %s", exc)
        return set()
    return found
def discover_tools(tools_dir: str = "tools") -> dict[str, Any]:
    """Load every tool found under ``tools_dir`` and index them by name.

    A fresh :class:`ToolRegistry` is created and handed to
    :func:`load_tools` together with ``tools_dir``; whatever the registry
    then reports from ``list_tools()`` is collected into a dictionary keyed
    by each definition's ``name`` attribute. If two definitions share a
    name, the later one wins.

    Args:
        tools_dir (str): Directory passed to :func:`load_tools`. Defaults to
            ``"tools"``.

    Returns:
        dict[str, Any]: Mapping of tool name to the registered tool
        definition object.
    """
    registry = ToolRegistry()
    load_tools(tools_dir, registry)

    by_name: dict[str, Any] = {}
    for definition in registry.list_tools():
        by_name[definition.name] = definition
    return by_name
def load_index_file() -> dict[str, Any]:
    """Read the tool index from ``INDEX_FILE`` and return it as a dict.

    The function never raises for a missing or unusable file; it degrades to
    an empty index instead:

    * the file does not exist -> ``{}`` (silently);
    * the file cannot be opened or parsed -> ``{}`` plus a warning;
    * the file parses but its top-level value is not a JSON object (for
      example a list or ``null``) -> ``{}`` plus a warning. Without this
      check a non-dict value would be returned and callers that use
      ``.keys()`` / ``.get()`` on the result would fail.

    Returns:
        dict[str, Any]: The parsed index, keyed by tool name as written by
        :func:`save_index_file`, or an empty dict when nothing usable could
        be loaded.
    """
    if not os.path.exists(INDEX_FILE):
        return {}

    try:
        with open(INDEX_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
    except Exception as exc:
        # Best-effort: an unreadable index is treated like a missing one.
        logger.warning("Failed to load index file: %s", exc)
        return {}

    if not isinstance(data, dict):
        logger.warning(
            "Failed to load index file: expected a JSON object, got %s",
            type(data).__name__,
        )
        return {}
    return data
def save_index_file(data: dict[str, Any]) -> None:
    """Write ``data`` to ``INDEX_FILE`` as indented JSON, replacing it atomically.

    The payload is first serialized into a temporary file next to
    ``INDEX_FILE`` and only then moved over the real file with
    :func:`os.replace`. Opening the real file with mode ``"w"`` directly
    would truncate it before serialization starts, so a serialization error
    or an interrupted run would destroy the previous index; with the
    temporary file the old index stays intact unless the new one was written
    completely.

    Args:
        data (dict[str, Any]): Index contents to persist; ``len(data)`` is
            reported in the log as the number of tools.

    Raises:
        Exception: Whatever ``open``, ``json.dump`` or ``os.replace`` raise
            is propagated after the temporary file has been removed.
    """
    # Include the PID so two processes saving at once do not share a temp file.
    tmp_path = f"{INDEX_FILE}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
        os.replace(tmp_path, INDEX_FILE)
    except BaseException:
        # Do not leave a partial temp file behind; the original error wins.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    logger.info(
        "Saved index file with %d tools",
        len(data),
    )
# Embedding model identifier passed as ``model=`` to OpenRouterEmbeddings when
# update_tool_embeddings() computes tool centroids.
_EMBED_MODEL_OPENROUTER = "google/gemini-embedding-001"
async def update_tool_embeddings(
    force_index: bool = False,
    tools_dir: str = "tools",
    *,
    force_all: bool = False,
    openrouter_only: bool = False,
) -> bool:
    """Incrementally reconcile the tool embeddings in Redis with the registry.

    Steps, in order:

    1. Discover registered tools via :func:`discover_tools`.
    2. Read the names already stored in Redis via
       :func:`get_existing_redis_tools`.
    3. Compute the *missing* set (registered but not stored, or every
       registered tool when ``force_all`` is set) and the *orphaned* set
       (stored but no longer registered). Orphans are removed from
       ``TOOL_EMBEDDINGS_HASH_KEY``, ``TOOL_METADATA_HASH_KEY``, the per-tool
       entries handled by :func:`delete_tool_embedding_hash`, and from the
       on-disk index file.
    4. Generate synthetic queries for every missing tool whose index entry
       has fewer than ``SYNTHETIC_QUERY_COUNT`` queries (or for all missing
       tools when ``force_index`` / ``force_all`` is set) and save the index
       file.
    5. Compute one centroid per missing tool with
       :func:`compute_tool_centroids_bulk`.
    6. Store the new vectors and metadata in Redis.

    Resource handling: the HTTP client used for query generation, the
    embedding client, the OpenRouter-only mode and the Redis connection are
    all released even when a step raises.

    Called by :func:`main` in this module.

    Args:
        force_index (bool): Regenerate synthetic queries even when the index
            file already has enough for a tool. Implied by ``force_all``.
        tools_dir (str): Directory to scan for tools. Defaults to
            ``"tools"``.
        force_all (bool): Treat every registered tool as missing, so queries
            and embeddings are recomputed for all of them.
        openrouter_only (bool): Switch the embedding pool to OpenRouter-only
            mode and pass ``openrouter_only=True`` to query generation.
            Requires a credential accepted by
            ``_openrouter_query_gen_api_key``; sets
            ``TOOL_EMBED_OR_MAX_CONCURRENT`` to ``"32"`` if it is unset.

    Returns:
        bool: ``True`` on success (including the case where nothing was
        missing), ``False`` when ``openrouter_only`` was requested without a
        usable credential.

    Raises:
        RuntimeError: If a tool slated for embedding has fewer than
            ``SYNTHETIC_QUERY_COUNT`` synthetic queries in the index.
    """
    logger.info("=" * 60)
    logger.info("Incremental Tool Embeddings Update")
    logger.info("=" * 60)

    from config import Config

    cfg = Config.load()
    # Connection precedence: explicit URL (config, then env), Sentinel, localhost.
    redis_url = cfg.redis_url or os.environ.get("REDIS_URL", "")
    if redis_url:
        redis_client = aioredis.from_url(
            redis_url,
            decode_responses=True,
            **cfg.redis_connection_kwargs_for_url(redis_url),
        )
    elif cfg.redis_sentinels:
        redis_client = cfg.build_async_redis_client(decode_responses=True)
        redis_url = f"sentinel:{cfg.redis_sentinel_master}"
    else:
        redis_url = "redis://localhost:6379/0"
        redis_client = aioredis.from_url(
            redis_url,
            decode_responses=True,
            **cfg.redis_connection_kwargs_for_url(redis_url),
        )

    openrouter_activated = False
    try:
        # Inside the try so the Redis client is closed even if this raises.
        init_quota_tracking(redis_client)

        if openrouter_only:
            if not _openrouter_query_gen_api_key():
                logger.error(
                    "openrouter_only requires OPENROUTER_QUERY_GEN_API_KEY, "
                    "OPENROUTER_API_KEY, or API_KEY",
                )
                return False
            await set_openrouter_only()
            openrouter_activated = True
            os.environ.setdefault("TOOL_EMBED_OR_MAX_CONCURRENT", "32")

        # 1. Discover tools
        logger.info("Discovering registered tools...")
        registered = discover_tools(tools_dir)
        logger.info(
            " Found %d registered tools",
            len(registered),
        )

        # 2. Check Redis
        logger.info("Checking existing embeddings...")
        existing = await get_existing_redis_tools(
            redis_client,
        )
        logger.info(
            " Found %d with embeddings",
            len(existing),
        )

        # 3. Identify missing and orphaned
        registered_names = set(registered.keys())
        if force_all:
            missing = set(registered_names)
        else:
            missing = registered_names - existing
        orphaned = existing - registered_names

        # 3a. Prune orphaned embeddings (tools removed from codebase)
        # NOTE(review): if discovery returns no tools (e.g. a wrong
        # --tools-dir), every stored entry counts as orphaned and is removed
        # here and from the index file below -- confirm this is intended.
        if orphaned:
            logger.info(
                "Pruning %d orphaned embeddings:",
                len(orphaned),
            )
            for name in sorted(orphaned):
                logger.info(" - %s", name)
            await redis_client.hdel(
                TOOL_EMBEDDINGS_HASH_KEY,
                *orphaned,
            )
            await redis_client.hdel(
                TOOL_METADATA_HASH_KEY,
                *orphaned,
            )
            for name in orphaned:
                await delete_tool_embedding_hash(redis_client, name)
            logger.info(
                "Removed %d orphaned entries from Redis",
                len(orphaned),
            )

        # 3b. Load index file and prune orphans from it too
        logger.info("Loading tool index file...")
        index_data = load_index_file()
        pruned_from_index = 0
        for name in list(index_data.keys()):
            if name not in registered_names:
                del index_data[name]
                pruned_from_index += 1
        if pruned_from_index:
            save_index_file(index_data)
            logger.info(
                "Pruned %d orphaned entries from index file",
                pruned_from_index,
            )

        if not missing:
            logger.info("All tools have embeddings!")
            return True

        logger.info(
            "Processing %d tools (%s)...",
            len(missing),
            "force-all" if force_all else "missing from Redis",
        )
        for name in sorted(missing):
            logger.info(" - %s", name)

        # 4. Decide which tools need (re)generated synthetic queries
        force_idx = force_index or force_all
        needing_queries: list[str] = []
        for name in missing:
            qs = index_data.get(name, {}).get(
                "synthetic_queries",
                [],
            )
            if not qs or len(qs) < SYNTHETIC_QUERY_COUNT or force_idx:
                needing_queries.append(name)

        # 5. Generate synthetic queries
        if needing_queries:
            logger.info(
                "Generating queries for %d tools...",
                len(needing_queries),
            )
            _qsem = 8 if openrouter_only else 3
            sem = asyncio.Semaphore(_qsem)

            # The context manager closes the HTTP client (and its connection
            # pool) whether generation succeeds or fails.
            async with httpx.AsyncClient(
                timeout=httpx.Timeout(600.0, connect=30.0),
            ) as http_client:

                async def gen(tool_name: str) -> None:
                    """Generate synthetic queries for one tool and record them.

                    Runs under the shared ``sem`` semaphore (8 slots in
                    OpenRouter-only mode, otherwise 3). Looks the tool up in
                    ``registered``, reads its ``description`` (empty string
                    when absent or falsy), awaits
                    :func:`generate_synthetic_queries` over the shared
                    ``http_client``, and stores ``name`` / ``description`` /
                    ``synthetic_queries`` under ``tool_name`` in the shared
                    ``index_data`` dict.

                    Args:
                        tool_name (str): Registered tool name; must be a key
                            of ``registered``.

                    Returns:
                        None: Results are stored in ``index_data``.
                    """
                    async with sem:
                        tool = registered[tool_name]
                        desc = getattr(tool, "description", "") or ""
                        logger.info(
                            " Generating for: %s",
                            tool_name,
                        )
                        qs = await generate_synthetic_queries(
                            http_client,
                            None,
                            None,
                            tool_name,
                            desc,
                            openrouter_only=openrouter_only,
                        )
                        index_data[tool_name] = {
                            "name": tool_name,
                            "description": desc,
                            "synthetic_queries": qs,
                        }
                        logger.info(
                            " %s: %d queries",
                            tool_name,
                            len(qs),
                        )

                await asyncio.gather(
                    *(gen(n) for n in needing_queries),
                )
            save_index_file(index_data)

        # 6. Compute embeddings (Gemini pool or OpenRouter when OR-only active)
        logger.info("Computing embeddings...")
        emb_client = OpenRouterEmbeddings(model=_EMBED_MODEL_OPENROUTER)
        tool_queries: dict[str, list[str]] = {}
        meta_by_tool: dict[str, tuple[str, list[str]]] = {}
        try:
            if openrouter_only:
                emb_client.MAX_BATCH_SIZE = min(
                    128,
                    max(50, emb_client.MAX_BATCH_SIZE),
                )
            for tool_name in missing:
                info = index_data.get(tool_name, {})
                qs = info.get("synthetic_queries", [])
                desc = info.get("description", "")
                if len(qs) < SYNTHETIC_QUERY_COUNT:
                    raise RuntimeError(
                        f"{tool_name!r}: expected {SYNTHETIC_QUERY_COUNT} synthetic "
                        f"queries in the index, got {len(qs)}. "
                        "Regenerate with: python -m classifiers.update_tool_embeddings "
                        "--force-index",
                    )
                tool_queries[tool_name] = qs
                meta_by_tool[tool_name] = (desc, qs)

            centroids = await compute_tool_centroids_bulk(
                emb_client,
                tool_queries,
            )
        finally:
            # Close the embedding client even when validation or the
            # embedding request fails.
            await emb_client.close()

        embs_store: dict[str, str] = {}
        meta_store: dict[str, str] = {}
        for tool_name, centroid in centroids.items():
            desc, qs = meta_by_tool[tool_name]
            if centroid is not None:
                embs_store[tool_name] = json.dumps(
                    centroid.tolist(),
                )
                meta_store[tool_name] = json.dumps(
                    {
                        "name": tool_name,
                        "description": desc,
                        "query_count": len(qs),
                    }
                )
                logger.info(
                    " %s: computed from %d queries",
                    tool_name,
                    len(qs),
                )
            else:
                logger.warning(
                    " %s: failed to compute",
                    tool_name,
                )

        # 7. Store in Redis (additive or full refresh)
        if embs_store:
            logger.info(
                "Storing %d new embeddings...",
                len(embs_store),
            )
            await redis_client.hset(
                TOOL_EMBEDDINGS_HASH_KEY,
                mapping=embs_store,
            )
            await redis_client.hset(
                TOOL_METADATA_HASH_KEY,
                mapping=meta_store,
            )
            for tn, js in embs_store.items():
                vec = np.array(json.loads(js), dtype=np.float32)
                meta = json.loads(meta_store[tn])
                await store_tool_embedding_hash(
                    redis_client,
                    tn,
                    vec,
                    meta,
                )
            logger.info(
                "Added %d tool embeddings to Redis",
                len(embs_store),
            )
        else:
            logger.warning("No new embeddings to store")

        logger.info("=" * 60)
        logger.info("Incremental update complete!")
        logger.info("=" * 60)
        return True
    finally:
        # Nested so a failure while leaving OpenRouter-only mode cannot skip
        # closing the Redis connection.
        try:
            if openrouter_activated:
                await clear_openrouter_only()
        finally:
            await redis_client.aclose()
async def main() -> None:
    """Parse the command line, run the update, and exit with a status code.

    Recognized flags: ``--force-index`` / ``-f``, ``--force-all``,
    ``--openrouter-only`` (all boolean switches) and ``--tools-dir``
    (defaults to ``tools``). The parsed values are forwarded to
    :func:`update_tool_embeddings`.

    The process always ends through :func:`sys.exit`: status ``0`` when the
    update returns a truthy value, ``1`` when it returns a falsy value or
    raises an exception (which is logged with its traceback).

    Returns:
        None: The function does not return normally.
    """
    cli = argparse.ArgumentParser(
        description="Update tool embeddings incrementally (only add missing tools)",
    )
    cli.add_argument(
        "--force-index",
        "-f",
        action="store_true",
        help="Force regeneration of synthetic queries",
    )
    cli.add_argument(
        "--force-all",
        action="store_true",
        help=(
            "Regenerate synthetic queries and embeddings for every registered "
            "tool (not only keys missing from Redis). Implies full query regen."
        ),
    )
    cli.add_argument(
        "--openrouter-only",
        action="store_true",
        help=(
            "Use OpenRouter only: chat "
            "(google/gemini-3.1-flash-lite) for synthetic queries and "
            "OpenRouter embeddings (google/gemini-embedding-001). "
            "Requires OPENROUTER_* or API_KEY."
        ),
    )
    cli.add_argument(
        "--tools-dir",
        default="tools",
        help="Tool scripts directory (default: tools)",
    )
    opts = cli.parse_args()

    try:
        ok = await update_tool_embeddings(
            force_index=opts.force_index,
            tools_dir=opts.tools_dir,
            force_all=opts.force_all,
            openrouter_only=opts.openrouter_only,
        )
    except Exception as exc:
        logger.error(
            "Update failed: %s",
            exc,
            exc_info=True,
        )
        exit_code = 1
    else:
        exit_code = 0 if ok else 1
    sys.exit(exit_code)
# Script entry point: run the async CLI (main) to completion via asyncio.run
# when this file is executed rather than imported.
if __name__ == "__main__": asyncio.run(main())