Source code for classifiers.refresh_tool_embeddings
#!/usr/bin/env python3
"""Refresh embeddings for tools whose descriptions have changed.
Compares each registered tool's live description against the stored
metadata in Redis. Tools with mismatched descriptions are "stale".
Tools whose ``tool_index_data.json`` entry is missing or does not have
the expected number of non-empty string ``synthetic_queries`` are "invalid".
Both sets are refreshed (union) unless ``--invalid-only`` is passed.
**Redis:** This script does **not** delete or flush the embedding hash. It only
HSETs keys for tools that successfully produced an embedding. If query
generation or embedding fails for every target, Redis is left unchanged.
**Index file:** If ``tool_index_data.json`` exists but fails to parse, this
module used to load ``{}`` and could overwrite the file with a partial index.
That is now blocked (see load guard below).
Run ``python -m classifiers.refresh_tool_embeddings`` (see ``--help``) with
optional arguments:
- ``--force`` — Re-embed every tool regardless of whether its description
changed (useful after switching embedding models).
- ``--tools`` — Comma-separated list of specific tool names to refresh.
- ``--tools-dir`` — Tool scripts directory (default: ``tools``).
- ``--embed-only`` — Skip Gemini query generation; recompute embeddings from
existing ``tool_index_data.json`` only (use after ``build_tool_index``).
- ``--invalid-only`` — Only tools with invalid/missing index queries (skip
stale-description check). Ignored with ``--force`` or ``--embed-only``.
Synthetic query generation uses exponential backoff, rotates Gemini models
(``gemini-2.5-flash``, ``gemini-3-flash-preview`` after the default), then
OpenRouter ``google/<primary>`` if all Gemini calls fail (key:
``OPENROUTER_QUERY_GEN_API_KEY`` or ``OPENROUTER_API_KEY``; 429 backoff).
See ``classifiers/build_tool_index.generate_synthetic_queries`` and
``GEMINI_QUERY_GEN_*`` / ``GEMINI_QUERY_GEN_429_*`` env vars.
"""
from __future__ import annotations
import argparse
import asyncio
import jsonutil as json
import logging
import os
import shutil
import sys
from pathlib import Path
from typing import Any
import numpy as np
import redis.asyncio as aioredis
import httpx
# Put the parent of this file's directory at the front of ``sys.path`` before
# the project imports below run, so they can resolve when this module is
# executed directly (presumably that parent is the repository root -- confirm).
sys.path.insert(
    0,
    os.path.abspath(
        os.path.join(os.path.dirname(__file__), ".."),
    ),
)
from gemini_embed_pool import init_quota_tracking # noqa: E402
from rag_system.openrouter_embeddings import ( # noqa: E402
OpenRouterEmbeddings,
)
from classifiers.vector_classifier import ( # noqa: E402
TOOL_EMBEDDINGS_HASH_KEY,
TOOL_METADATA_HASH_KEY,
)
from classifiers.redis_vector_index import store_tool_embedding_hash # noqa: E402
from classifiers.build_tool_index import ( # noqa: E402
SYNTHETIC_QUERY_COUNT,
discover_invalid_query_index_tools,
generate_synthetic_queries,
)
from classifiers.tool_embedding_batch import ( # noqa: E402
compute_tool_centroids_bulk,
)
from classifiers.update_tool_embeddings import ( # noqa: E402
INDEX_FILE,
load_index_file,
save_index_file,
discover_tools,
)
# Root logging setup performed at import time: INFO and above, rendered as
# "LEVEL: message". ``basicConfig`` does nothing if the root logger already
# has handlers.
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.INFO)

# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
[docs]
async def find_stale_tools(
    registered: dict[str, Any],
    redis_client: aioredis.Redis,
) -> list[str]:
    """Return names of tools whose live descriptions differ from Redis.

    A tool is "stale" when its currently registered ``description`` no longer
    matches the description stored in the ``TOOL_METADATA_HASH_KEY`` Redis
    hash. The hash is read once with ``HGETALL``; each entry is JSON-decoded
    and its ``description`` field is compared with the live ``description``
    attribute of the registered tool (a missing or falsy attribute counts as
    the empty string).

    Tools absent from the hash are skipped because there is nothing to
    compare against. Tools whose stored metadata cannot be decoded, or decodes
    to something other than a JSON object, are reported as stale so they get
    rebuilt. This function only reads from Redis.

    Called by :func:`refresh_tool_embeddings` on both the ``--embed-only`` and
    the default stale/invalid paths.

    Args:
        registered: Mapping of tool name to the registered tool definition;
            each value is read through its ``description`` attribute.
        redis_client: Async Redis client used to read the metadata hash.

    Returns:
        Tool names, in ``registered`` iteration order, whose live description
        differs from (or has no usable stored counterpart for) the value in
        Redis.
    """
    meta_raw: dict[str, str] = await redis_client.hgetall(
        TOOL_METADATA_HASH_KEY,
    )

    stale: list[str] = []
    for name, tool_def in registered.items():
        stored = meta_raw.get(name)
        if stored is None:
            # Never embedded before: nothing to compare against.
            continue

        try:
            meta = json.loads(stored)
        except (json.JSONDecodeError, TypeError):
            stale.append(name)
            continue

        # Valid JSON that is not an object (list, string, number, null) has no
        # ``description`` to compare; treat it like corrupt metadata instead
        # of crashing on ``.get``.
        if not isinstance(meta, dict):
            stale.append(name)
            continue

        live_desc = getattr(tool_def, "description", "") or ""
        if live_desc != meta.get("description", ""):
            stale.append(name)

    return stale
[docs]
def _resolve_named_targets(
    tool_names: list[str],
    registered: dict[str, Any],
) -> list[str]:
    """Return the sorted subset of *tool_names* that are registered tools.

    Unknown names are logged as a warning and dropped. When nothing valid
    remains an error is logged and the (empty) list is returned so the caller
    can abort.
    """
    requested = set(tool_names)
    known_names = set(registered.keys())
    unknown = requested - known_names
    if unknown:
        logger.warning(
            "Unknown tools (skipped): %s",
            ", ".join(sorted(unknown)),
        )
    targets = requested & known_names
    if not targets:
        logger.error("No valid tool names to refresh")
    return sorted(targets)


def _index_unreadable(
    index_path: Path,
    index_data: Any,
    *,
    overwrite_risk: bool,
) -> bool:
    """Report (and log) an index file that exists but loaded as empty.

    Returns ``True`` when *index_path* is a file yet *index_data* is falsy,
    i.e. the file could not be read or parsed into a non-empty object.
    *overwrite_risk* selects the longer message used on paths that later
    rewrite the index file.
    """
    if index_data or not index_path.is_file():
        return False
    if overwrite_risk:
        logger.error(
            "%s exists but could not be read or parsed to a non-empty "
            "object. Refusing to run (would risk overwriting the full "
            "index with a partial result). Repair the JSON and retry.",
            index_path,
        )
    else:
        logger.error(
            "%s exists but could not be read or parsed to a non-empty "
            "object. Repair the JSON and retry.",
            index_path,
        )
    return True


async def refresh_tool_embeddings(
    *,
    force: bool = False,
    embed_only: bool = False,
    invalid_only: bool = False,
    tool_names: list[str] | None = None,
    tools_dir: str = "tools",
) -> bool:
    """Re-embed the selected tools and write their vectors back to Redis.

    Steps, in order:

    1. Load ``Config``, build the async Redis client and discover registered
       tools with ``discover_tools``.
    2. Pick the target tools. An explicit ``tool_names`` list wins; otherwise
       ``force`` or the stale/invalid checks (:func:`find_stale_tools`,
       ``discover_invalid_query_index_tools``) decide.
    3. Unless ``embed_only`` is set, regenerate synthetic queries for every
       target with ``generate_synthetic_queries`` (at most 3 concurrently),
       back the index file up to ``<index>.bak`` and save the updated index.
    4. Compute one centroid per target with ``compute_tool_centroids_bulk``.
    5. HSET the embeddings and metadata hashes and call
       ``store_tool_embedding_hash`` for each tool that produced a vector.
       Redis is not written when no vector was produced.

    The HTTP client used for query generation, the embeddings client and the
    Redis client are all closed before returning, including on error paths.

    Called by :func:`main`. Also reported to be invoked programmatically by
    the ``write_python_tool`` and ``import_mcp_tool`` tools after a new tool
    is created (those callers are not visible in this module).

    Args:
        force: Re-embed every registered tool that already has an embedding
            in Redis (or, with ``embed_only``, every registered tool with
            synthetic queries in the index), skipping the stale check.
        embed_only: Recompute embeddings from the existing index file only;
            no synthetic-query generation and no index rewrite.
        invalid_only: Refresh only tools reported by
            ``discover_invalid_query_index_tools``. Ignored with ``force``,
            ``embed_only`` or an explicit ``tool_names`` list.
        tool_names: Explicit tool names to refresh; unknown names are warned
            about and skipped.
        tools_dir: Directory passed to ``discover_tools``.

    Returns:
        ``True`` when the run completed, including the no-op case and the case
        where no embedding could be computed (a warning is logged).
        ``False`` on a precondition failure: missing API key, an index file
        that exists but loads empty, or no valid targets.

    Raises:
        RuntimeError: If a target has fewer than ``SYNTHETIC_QUERY_COUNT``
            synthetic queries in the index when embeddings are computed.
        Exception: Anything raised by query generation or embedding
            propagates to the caller.
    """
    logger.info("=" * 60)
    logger.info("Tool Embeddings Refresh")
    logger.info("=" * 60)

    from config import Config

    cfg = Config.load()
    api_key = cfg.api_key
    if not api_key:
        logger.error("api_key required (config.yaml or API_KEY env var)")
        return False

    redis_url = cfg.redis_url or os.environ.get(
        "REDIS_URL",
        "redis://localhost:6379/0",
    )
    redis_client = aioredis.from_url(
        redis_url,
        decode_responses=True,
        **cfg.redis_connection_kwargs_for_url(redis_url),
    )
    init_quota_tracking(redis_client)

    try:
        logger.info("Discovering registered tools...")
        registered = discover_tools(tools_dir)
        logger.info(" Found %d registered tools", len(registered))
        index_path = Path(INDEX_FILE)

        if embed_only and invalid_only:
            logger.warning(
                "--invalid-only is ignored with --embed-only "
                "(cannot repair queries without generation)",
            )

        if embed_only:
            # ---- Embed-only: targets must already have queries on disk ----
            index_data = load_index_file()
            if _index_unreadable(index_path, index_data, overwrite_risk=False):
                return False

            if tool_names:
                targets_list = _resolve_named_targets(tool_names, registered)
                if not targets_list:
                    return False
            elif force:
                targets_list = sorted(
                    n
                    for n in registered.keys()
                    if index_data.get(n, {}).get("synthetic_queries")
                )
                if not targets_list:
                    logger.error(
                        "No tools with synthetic_queries in %s; run "
                        "build_tool_index first.",
                        index_path.name,
                    )
                    return False
            else:
                stale_names = await find_stale_tools(registered, redis_client)
                # Stale tools without queries cannot be embedded here.
                targets_list = sorted(
                    n
                    for n in stale_names
                    if index_data.get(n, {}).get("synthetic_queries")
                )

            if not targets_list:
                logger.info("No tools need refreshing!")
                return True
            logger.info(
                "Embed-only: refreshing %d tool(s) from index (no Gemini query gen)",
                len(targets_list),
            )
        else:
            # ---- Default: pick targets, then regenerate their queries ----
            if tool_names:
                targets_list = _resolve_named_targets(tool_names, registered)
                if not targets_list:
                    return False
            elif force:
                if invalid_only:
                    logger.info(
                        "--invalid-only ignored with --force",
                    )
                existing_keys = await redis_client.hkeys(
                    TOOL_EMBEDDINGS_HASH_KEY,
                )
                existing_names = {
                    k.decode("utf-8") if isinstance(k, bytes) else k
                    for k in existing_keys
                }
                targets_list = sorted(
                    set(registered.keys()) & existing_names,
                )
            else:
                logger.info(
                    "Checking stale descriptions and invalid index entries...",
                )
                index_data = load_index_file()
                if _index_unreadable(index_path, index_data, overwrite_risk=True):
                    return False
                stale = set(
                    await find_stale_tools(registered, redis_client),
                )
                invalid = set(
                    discover_invalid_query_index_tools(
                        index_data,
                        registered,
                    ),
                )
                targets_list = sorted(invalid if invalid_only else stale | invalid)
                if stale or invalid:
                    logger.info(
                        " %d stale description(s), %d invalid index entr(y/ies), "
                        "%d in both, %d unique to refresh",
                        len(stale),
                        len(invalid),
                        len(stale & invalid),
                        len(targets_list),
                    )

            if not targets_list:
                logger.info("No tools need refreshing!")
                return True

            if tool_names or force:
                # These two paths have not loaded the index yet.
                index_data = load_index_file()
                if _index_unreadable(index_path, index_data, overwrite_risk=True):
                    return False

            base_url = cfg.llm_base_url
            sem = asyncio.Semaphore(3)

            # ``async with`` closes the HTTP client even when generation fails.
            async with httpx.AsyncClient(
                timeout=httpx.Timeout(600.0, connect=30.0),
            ) as http_client:

                async def gen(tool_name: str) -> None:
                    """Generate synthetic queries for one tool and stage them.

                    Waits on the enclosing semaphore (3 concurrent calls),
                    calls ``generate_synthetic_queries`` with the tool's live
                    description, and writes the ``name`` / ``description`` /
                    ``synthetic_queries`` entry into the shared in-memory
                    ``index_data``. Exceptions from generation propagate
                    through ``asyncio.gather``.
                    """
                    async with sem:
                        tool = registered[tool_name]
                        desc = getattr(tool, "description", "") or ""
                        logger.info(" Generating for: %s", tool_name)
                        qs = await generate_synthetic_queries(
                            http_client,
                            base_url,
                            api_key,
                            tool_name,
                            desc,
                        )
                        index_data[tool_name] = {
                            "name": tool_name,
                            "description": desc,
                            "synthetic_queries": qs,
                        }
                        logger.info(
                            " %s: %d queries",
                            tool_name,
                            len(qs),
                        )

                logger.info("Generating synthetic queries...")
                await asyncio.gather(*(gen(n) for n in targets_list))

            # Best-effort backup; a failed copy must not block the save.
            if index_path.is_file():
                try:
                    shutil.copy2(index_path, f"{index_path}.bak")
                    logger.info("Backed up index to %s.bak", index_path.name)
                except OSError as exc:
                    logger.warning("Could not back up index file: %s", exc)
            save_index_file(index_data)

        logger.info(
            "Refreshing %d tool(s):",
            len(targets_list),
        )
        for name in targets_list:
            logger.info(" - %s", name)

        logger.info("Computing embeddings...")
        tool_queries: dict[str, list[str]] = {}
        meta_by_tool: dict[str, tuple[str, list[str]]] = {}
        emb_client = OpenRouterEmbeddings(api_key=api_key)
        try:
            for tool_name in targets_list:
                info = index_data.get(tool_name, {})
                qs = info.get("synthetic_queries", [])
                desc = info.get("description", "")
                if len(qs) < SYNTHETIC_QUERY_COUNT:
                    raise RuntimeError(
                        f"{tool_name!r}: expected {SYNTHETIC_QUERY_COUNT} synthetic "
                        f"queries in the index, got {len(qs)}. "
                        "Regenerate queries for this tool (e.g. force refresh / "
                        "--force-index) and retry.",
                    )
                tool_queries[tool_name] = qs
                meta_by_tool[tool_name] = (desc, qs)
            centroids = await compute_tool_centroids_bulk(
                emb_client,
                tool_queries,
            )
        finally:
            # Close the embeddings client on the failure paths too.
            await emb_client.close()

        embs_store: dict[str, str] = {}
        meta_store: dict[str, str] = {}
        for tool_name, centroid in centroids.items():
            desc, qs = meta_by_tool[tool_name]
            if centroid is None:
                logger.warning(
                    " %s: failed to compute",
                    tool_name,
                )
                continue
            embs_store[tool_name] = json.dumps(
                centroid.tolist(),
            )
            meta_store[tool_name] = json.dumps(
                {
                    "name": tool_name,
                    "description": desc,
                    "query_count": len(qs),
                }
            )
            logger.info(
                " %s: computed from %d queries",
                tool_name,
                len(qs),
            )

        if embs_store:
            logger.info(
                "Storing %d refreshed embeddings (other Redis keys unchanged)...",
                len(embs_store),
            )
            await redis_client.hset(
                TOOL_EMBEDDINGS_HASH_KEY,
                mapping=embs_store,
            )
            await redis_client.hset(
                TOOL_METADATA_HASH_KEY,
                mapping=meta_store,
            )
            for tn, js in embs_store.items():
                vec = np.array(json.loads(js), dtype=np.float32)
                meta = json.loads(meta_store[tn])
                await store_tool_embedding_hash(
                    redis_client,
                    tn,
                    vec,
                    meta,
                )
            logger.info(
                "Updated %d tool embeddings in Redis",
                len(embs_store),
            )
        else:
            logger.warning(
                "No embeddings to store — Redis tool hashes were not modified "
                "(query gen or embedding failed for all targets).",
            )

        logger.info("=" * 60)
        logger.info("Refresh complete!")
        logger.info("=" * 60)
        return True
    finally:
        await redis_client.aclose()
[docs]
async def main() -> None:
    """Command-line entry point: parse flags, run one refresh, exit.

    Defines the ``--force``/``-f``, ``--tools``/``-t``, ``--tools-dir``,
    ``--embed-only`` and ``--invalid-only`` options, turns the comma-separated
    ``--tools`` value into a list of stripped, non-empty names (``None`` when
    the option is absent or empty), and passes everything to
    :func:`refresh_tool_embeddings`.

    The process always terminates through ``sys.exit``: status 0 when the
    refresh returns a truthy value, 1 when it returns a falsy value, and 1
    after logging the traceback when it raises.

    Invoked from the ``if __name__ == "__main__"`` guard via
    ``asyncio.run(main())``.
    """
    cli = argparse.ArgumentParser(
        description="Refresh embeddings for tools whose descriptions have changed",
    )
    cli.add_argument(
        "--force",
        "-f",
        action="store_true",
        help="Re-embed all tools regardless of description changes",
    )
    cli.add_argument(
        "--tools",
        "-t",
        default="",
        help="Comma-separated tool names to refresh",
    )
    cli.add_argument(
        "--tools-dir",
        default="tools",
        help="Tool scripts directory (default: tools)",
    )
    cli.add_argument(
        "--embed-only",
        action="store_true",
        help=(
            "Recompute embeddings from tool_index_data.json only "
            "(no Gemini synthetic-query generation). Use with --force "
            "after build_tool_index."
        ),
    )
    cli.add_argument(
        "--invalid-only",
        action="store_true",
        help=(
            "Only refresh tools with bad/missing synthetic_queries in the "
            "index file (skip Redis stale-description check)"
        ),
    )
    opts = cli.parse_args()

    # An absent/empty --tools means "no explicit selection".
    if opts.tools:
        tool_names = [
            piece.strip() for piece in opts.tools.split(",") if piece.strip()
        ]
    else:
        tool_names = None

    try:
        succeeded = await refresh_tool_embeddings(
            force=opts.force,
            embed_only=opts.embed_only,
            invalid_only=opts.invalid_only,
            tool_names=tool_names,
            tools_dir=opts.tools_dir,
        )
    except Exception as exc:
        logger.error(
            "Refresh failed: %s",
            exc,
            exc_info=True,
        )
        sys.exit(1)

    # SystemExit is not an ``Exception``, so exiting here is equivalent to
    # exiting inside the ``try`` body.
    sys.exit(0 if succeeded else 1)
# Script entry point: run the async CLI driver on a fresh event loop.
# ``main`` always ends the process via ``sys.exit``.
if __name__ == "__main__":
    asyncio.run(main())