Source code for classifiers.update_tool_embeddings
#!/usr/bin/env python3
"""Incrementally update tool embeddings in Redis.
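Discovers all registered tools, compares against what already
exists in Redis, generates synthetic queries for missing tools
via the Gemini API (using the shared embedding key pool), and
adds only the new embeddings -- without touching existing ones
(unless ``--force-all`` is given). Entries for tools that are no
longer registered are pruned from Redis and from
``tool_index_data.json``.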
Usage::
python -m classifiers.update_tool_embeddings \
[--force-index] [--force-all] [--openrouter-only] [--tools-dir tools]
``--openrouter-only`` skips Gemini for both synthetic query generation (OpenRouter
chat: ``google/gemini-3.1-flash-lite``) and embeddings (OpenRouter
``/embeddings``: ``google/gemini-embedding-001`` via :func:`gemini_embed_pool.set_openrouter_only`).
Requires ``OPENROUTER_QUERY_GEN_API_KEY``, ``OPENROUTER_API_KEY``, or ``API_KEY``.
Raises default ``TOOL_EMBED_OR_MAX_CONCURRENT`` to 32 when unset (heavy parallel
batches; override with env).
``--force-all`` implies ``--force-index``: it regenerates synthetic queries
and embeddings for **every** registered tool, not only keys missing
from Redis.
To **regenerate synthetic queries and embeddings for specific tools only**
(e.g. after fixing query-gen for one tool), use::
python -m classifiers.refresh_tool_embeddings --tools vmware_control
Comma-separate multiple names. This updates ``tool_index_data.json`` and Redis
for those tools only; other tools are unchanged.
Environment variables:
REDIS_URL -- defaults to redis://localhost:6379/0
TOOL_EMBED_OR_MAX_CONCURRENT -- max concurrent embedding HTTP batches in
OpenRouter-only mode (default 32 if unset)
"""
from __future__ import annotations
import argparse
import asyncio
import jsonutil as json
import logging
import os
import sys
from typing import Any
import numpy as np
import redis.asyncio as aioredis
import httpx
# Put the project root (the parent of this package's directory) first on the
# import path so the top-level project modules imported below
# (``gemini_embed_pool``, ``tools``, ``tool_loader``, ...) resolve regardless
# of the current working directory.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from gemini_embed_pool import ( # noqa: E402
clear_openrouter_only,
init_quota_tracking,
set_openrouter_only,
)
from tools import ToolRegistry # noqa: E402
from tool_loader import load_tools # noqa: E402
from rag_system.openrouter_embeddings import ( # noqa: E402
OpenRouterEmbeddings,
)
from classifiers.vector_classifier import ( # noqa: E402
TOOL_EMBEDDINGS_HASH_KEY,
TOOL_METADATA_HASH_KEY,
)
from classifiers.redis_vector_index import ( # noqa: E402
delete_tool_embedding_hash,
store_tool_embedding_hash,
)
from classifiers.build_tool_index import ( # noqa: E402
SYNTHETIC_QUERY_COUNT,
_openrouter_query_gen_api_key,
generate_synthetic_queries,
)
from classifiers.tool_embedding_batch import ( # noqa: E402
compute_tool_centroids_bulk,
)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Per-tool synthetic-query cache, kept next to this module.
INDEX_FILE = os.path.join(os.path.dirname(__file__), "tool_index_data.json")
[docs]
async def get_existing_redis_tools(
    redis_client: aioredis.Redis,
) -> set[str]:
    """Return the names of tools that already have a centroid embedding stored.

    Reads the field names of the ``TOOL_EMBEDDINGS_HASH_KEY`` hash with a
    single ``HKEYS`` call and normalises any ``bytes`` names to ``str``. The
    result is diffed against the live registry by :func:`update_tool_embeddings`
    to decide which tools are missing and which are orphaned.

    Any exception (including a decode failure) is logged and converted into
    an empty set, so a transient Redis problem does not abort the run.

    Args:
        redis_client (aioredis.Redis): Async Redis connection to query.

    Returns:
        set[str]: Stored tool names, or an empty set if the read failed.
    """
    try:
        fields = await redis_client.hkeys(TOOL_EMBEDDINGS_HASH_KEY)
        tool_names: set[str] = set()
        for field in fields:
            if isinstance(field, bytes):
                field = field.decode("utf-8")
            tool_names.add(field)
        return tool_names
    except Exception as exc:
        logger.error("Failed to fetch existing tools: %s", exc)
        return set()
[docs]
def discover_tools(
    tools_dir: str = "tools",
) -> dict[str, Any]:
    """Load every tool found in *tools_dir* and index the definitions by name.

    A fresh :class:`tools.ToolRegistry` is populated through
    :func:`tool_loader.load_tools`, which imports the tool modules from disk.
    The registry's ``list_tools()`` output is then turned into a
    ``{name: definition}`` mapping. If two definitions share a name, the later
    one wins. Redis is not touched.

    Args:
        tools_dir (str): Directory to scan for tool modules. Defaults to
            ``"tools"``.

    Returns:
        dict[str, Any]: Tool name -> registered tool definition.
    """
    registry = ToolRegistry()
    load_tools(tools_dir, registry)
    by_name: dict[str, Any] = {}
    for definition in registry.list_tools():
        by_name[definition.name] = definition
    return by_name
[docs]
def load_index_file() -> dict[str, Any]:
    """Load the cached per-tool synthetic-query index from ``INDEX_FILE``.

    A missing file is the normal first-run case and silently yields an empty
    dict. An unreadable or unparsable file is logged as a warning and also
    yields an empty dict, as does a file whose top-level JSON value is not an
    object. Callers iterate and mutate the result as a dict, so a list or
    scalar here would otherwise crash them.

    Returns:
        dict[str, Any]: Mapping of tool name to its index entry, or ``{}``.
    """
    try:
        with open(INDEX_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        # EAFP instead of os.path.exists(): no check-then-open race.
        return {}
    except Exception as exc:
        logger.warning(
            "Failed to load index file: %s",
            exc,
        )
        return {}
    if not isinstance(data, dict):
        logger.warning(
            "Index file %s does not contain a JSON object (got %s); ignoring it",
            INDEX_FILE,
            type(data).__name__,
        )
        return {}
    return data
[docs]
def save_index_file(data: dict[str, Any]) -> None:
    """Persist the per-tool synthetic-query index to ``INDEX_FILE`` atomically.

    The JSON is first written to a temporary file in the same directory and
    then moved over ``INDEX_FILE`` with :func:`os.replace`. A crash or error
    mid-write therefore leaves the previous index intact instead of a
    truncated file. A truncated file would otherwise load as ``{}`` and force
    every tool's queries to be regenerated. The temporary file is removed if
    the write fails.

    Args:
        data (dict[str, Any]): Mapping of tool name to its index entry.
    """
    tmp_path = f"{INDEX_FILE}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
        os.replace(tmp_path, INDEX_FILE)
    except BaseException:
        # Best-effort cleanup; the original error is the one worth surfacing.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    logger.info(
        "Saved index file with %d tools",
        len(data),
    )
# Model id handed to OpenRouterEmbeddings when update_tool_embeddings computes
# tool centroid embeddings; matches the model named in the --openrouter-only
# CLI help text.
_EMBED_MODEL_OPENROUTER = "google/gemini-embedding-001"
[docs]
def _open_redis_client(cfg: Any) -> aioredis.Redis:
    """Build the async Redis client: explicit URL, then Sentinel, then localhost.

    ``cfg.redis_url`` (or the ``REDIS_URL`` env var) wins. Otherwise the
    configured Sentinel set is used if present, and the final fallback is
    ``redis://localhost:6379/0``. Every client uses ``decode_responses=True``.

    Args:
        cfg (Any): Loaded :class:`config.Config` instance.

    Returns:
        aioredis.Redis: Async client; the caller must close it with ``aclose()``.
    """
    redis_url = cfg.redis_url or os.environ.get("REDIS_URL", "")
    if not redis_url and cfg.redis_sentinels:
        return cfg.build_async_redis_client(decode_responses=True)
    if not redis_url:
        redis_url = "redis://localhost:6379/0"
    return aioredis.from_url(
        redis_url,
        decode_responses=True,
        **cfg.redis_connection_kwargs_for_url(redis_url),
    )


async def _prune_orphaned_redis_entries(
    redis_client: aioredis.Redis,
    orphaned: set[str],
) -> None:
    """Remove embeddings, metadata and per-tool vector docs for *orphaned* tools."""
    if not orphaned:
        return
    logger.info("Pruning %d orphaned embeddings:", len(orphaned))
    for name in sorted(orphaned):
        logger.info(" - %s", name)
    await redis_client.hdel(TOOL_EMBEDDINGS_HASH_KEY, *orphaned)
    await redis_client.hdel(TOOL_METADATA_HASH_KEY, *orphaned)
    for name in orphaned:
        await delete_tool_embedding_hash(redis_client, name)
    logger.info("Removed %d orphaned entries from Redis", len(orphaned))


def _prune_orphaned_index_entries(
    index_data: dict[str, Any],
    registered_names: set[str],
) -> None:
    """Drop index entries for unregistered tools (in place); save only if any were removed."""
    stale = [name for name in index_data if name not in registered_names]
    for name in stale:
        del index_data[name]
    if stale:
        save_index_file(index_data)
        logger.info("Pruned %d orphaned entries from index file", len(stale))


async def _generate_missing_queries(
    tool_names: list[str],
    registered: dict[str, Any],
    index_data: dict[str, Any],
    *,
    openrouter_only: bool,
) -> None:
    """Generate synthetic queries for *tool_names* and persist them to the index file.

    Runs one :func:`generate_synthetic_queries` call per tool concurrently.
    Concurrency is bounded by a semaphore (8 slots in OpenRouter-only mode,
    otherwise 3). All calls share one :class:`httpx.AsyncClient`, which is
    closed after every call has finished. Results are written into
    *index_data* in place.

    The index file is saved even if some tools failed, so queries that were
    generated successfully are not lost. The first failure, in *tool_names*
    order, is then re-raised.

    Args:
        tool_names (list[str]): Registered tool names needing queries.
        registered (dict[str, Any]): Tool name -> registered definition.
        index_data (dict[str, Any]): Shared index dict updated in place.
        openrouter_only (bool): Use OpenRouter chat for query generation.

    Raises:
        Exception: The first exception raised by a per-tool generation call.
    """
    logger.info("Generating queries for %d tools...", len(tool_names))
    sem = asyncio.Semaphore(8 if openrouter_only else 3)
    async with httpx.AsyncClient(
        timeout=httpx.Timeout(600.0, connect=30.0),
    ) as http_client:

        async def gen(tool_name: str) -> None:
            """Generate and record synthetic queries for one registered tool.

            Holds one ``sem`` slot while it reads the tool's ``description``
            and awaits :func:`generate_synthetic_queries` over the shared
            ``http_client``. It then stores ``name`` / ``description`` /
            ``synthetic_queries`` under ``index_data[tool_name]``.

            Args:
                tool_name (str): Key of ``registered`` to generate queries for.
            """
            async with sem:
                tool = registered[tool_name]
                desc = getattr(tool, "description", "") or ""
                logger.info(" Generating for: %s", tool_name)
                qs = await generate_synthetic_queries(
                    http_client,
                    None,
                    None,
                    tool_name,
                    desc,
                    openrouter_only=openrouter_only,
                )
                index_data[tool_name] = {
                    "name": tool_name,
                    "description": desc,
                    "synthetic_queries": qs,
                }
                logger.info(" %s: %d queries", tool_name, len(qs))

        # return_exceptions=True: wait for every task before the HTTP client
        # is closed, instead of abandoning still-running siblings on error.
        results = await asyncio.gather(
            *(gen(name) for name in tool_names),
            return_exceptions=True,
        )
    save_index_file(index_data)
    failures = [
        (name, result)
        for name, result in zip(tool_names, results)
        if isinstance(result, BaseException)
    ]
    for name, exc in failures:
        logger.error("Query generation failed for %s: %s", name, exc)
    if failures:
        raise failures[0][1]


async def _compute_and_store_embeddings(
    redis_client: aioredis.Redis,
    tool_names: list[str],
    index_data: dict[str, Any],
    *,
    openrouter_only: bool,
) -> None:
    """Compute centroid embeddings for *tool_names* and write them to Redis.

    Every tool must already have at least ``SYNTHETIC_QUERY_COUNT`` queries in
    *index_data*; this is checked before any embedding client is opened.
    Centroids come from :func:`compute_tool_centroids_bulk`. The embeddings
    client is always closed, even if that call fails. Successful centroids
    are HSET into the embeddings and metadata hashes and mirrored into the
    per-tool vector documents. Tools whose centroid is ``None`` are logged
    and skipped.

    Raises:
        RuntimeError: If a tool has fewer than ``SYNTHETIC_QUERY_COUNT`` queries.
    """
    logger.info("Computing embeddings...")
    tool_queries: dict[str, list[str]] = {}
    descriptions: dict[str, str] = {}
    for tool_name in tool_names:
        info = index_data.get(tool_name, {})
        qs = info.get("synthetic_queries", [])
        if len(qs) < SYNTHETIC_QUERY_COUNT:
            raise RuntimeError(
                f"{tool_name!r}: expected {SYNTHETIC_QUERY_COUNT} synthetic "
                f"queries in the index, got {len(qs)}. "
                "Regenerate with: python -m classifiers.update_tool_embeddings "
                "--force-index",
            )
        tool_queries[tool_name] = qs
        descriptions[tool_name] = info.get("description", "")

    emb_client = OpenRouterEmbeddings(model=_EMBED_MODEL_OPENROUTER)
    try:
        if openrouter_only:
            emb_client.MAX_BATCH_SIZE = min(128, max(50, emb_client.MAX_BATCH_SIZE))
        centroids = await compute_tool_centroids_bulk(emb_client, tool_queries)
    finally:
        await emb_client.close()

    embs_store: dict[str, str] = {}
    meta_store: dict[str, str] = {}
    # Keep the decoded vector + metadata so the per-tool docs below do not
    # need to re-parse the JSON we just produced.
    vectors: dict[str, tuple[np.ndarray, dict[str, Any]]] = {}
    for tool_name, centroid in centroids.items():
        qs = tool_queries[tool_name]
        if centroid is None:
            logger.warning(" %s: failed to compute", tool_name)
            continue
        values = centroid.tolist()
        meta = {
            "name": tool_name,
            "description": descriptions[tool_name],
            "query_count": len(qs),
        }
        embs_store[tool_name] = json.dumps(values)
        meta_store[tool_name] = json.dumps(meta)
        vectors[tool_name] = (np.array(values, dtype=np.float32), meta)
        logger.info(" %s: computed from %d queries", tool_name, len(qs))

    if not embs_store:
        logger.warning("No new embeddings to store")
        return
    logger.info("Storing %d new embeddings...", len(embs_store))
    await redis_client.hset(TOOL_EMBEDDINGS_HASH_KEY, mapping=embs_store)
    await redis_client.hset(TOOL_METADATA_HASH_KEY, mapping=meta_store)
    for tool_name, (vec, meta) in vectors.items():
        await store_tool_embedding_hash(redis_client, tool_name, vec, meta)
    logger.info("Added %d tool embeddings to Redis", len(embs_store))


async def update_tool_embeddings(
    force_index: bool = False,
    tools_dir: str = "tools",
    *,
    force_all: bool = False,
    openrouter_only: bool = False,
) -> bool:
    """Incrementally reconcile the tool vector index in Redis with the live registry.

    Steps:

    1. Discover registered tools with :func:`discover_tools`.
    2. Read the stored tool names with :func:`get_existing_redis_tools`.
    3. Compute the *missing* set (every registered tool when ``force_all``)
       and the *orphaned* set (stored but no longer registered).
    4. Prune orphans from Redis and from ``tool_index_data.json``.
    5. Generate synthetic queries where the index has too few, or for every
       processed tool when forced. Failures still persist the queries that
       did succeed.
    6. Compute centroid embeddings and HSET them plus metadata into Redis.

    The Redis client, the query-generation HTTP client, the embeddings client
    and any OpenRouter-only state are all released, even on failure.

    Args:
        force_index (bool): Regenerate synthetic queries even when the index
            file already has enough for a tool. Implied by ``force_all``.
        tools_dir (str): Directory to scan for tools. Defaults to ``"tools"``.
        force_all (bool): Treat every registered tool as missing, recomputing
            queries and embeddings for all of them.
        openrouter_only (bool): Route query generation and embeddings through
            OpenRouter. Requires an OpenRouter or ``API_KEY`` credential and
            defaults ``TOOL_EMBED_OR_MAX_CONCURRENT`` to 32 when unset.

    Returns:
        bool: ``True`` on success (including the no-op case), ``False`` if a
        required OpenRouter credential is missing.

    Raises:
        RuntimeError: If a tool slated for embedding lacks the expected number
            of synthetic queries in the index file.
    """
    logger.info("=" * 60)
    logger.info("Incremental Tool Embeddings Update")
    logger.info("=" * 60)
    from config import Config

    redis_client = _open_redis_client(Config.load())
    openrouter_activated = False
    try:
        # Inside the try so the Redis client is still closed if this raises.
        init_quota_tracking(redis_client)
        if openrouter_only:
            if not _openrouter_query_gen_api_key():
                logger.error(
                    "openrouter_only requires OPENROUTER_QUERY_GEN_API_KEY, "
                    "OPENROUTER_API_KEY, or API_KEY",
                )
                return False
            await set_openrouter_only()
            openrouter_activated = True
            os.environ.setdefault("TOOL_EMBED_OR_MAX_CONCURRENT", "32")

        # 1. Discover tools
        logger.info("Discovering registered tools...")
        registered = discover_tools(tools_dir)
        logger.info(" Found %d registered tools", len(registered))

        # 2. Check Redis
        logger.info("Checking existing embeddings...")
        existing = await get_existing_redis_tools(redis_client)
        logger.info(" Found %d with embeddings", len(existing))

        # 3. Identify missing and orphaned tools
        registered_names = set(registered)
        missing = set(registered_names) if force_all else registered_names - existing
        orphaned = existing - registered_names

        # 4. Prune tools removed from the codebase (Redis, then index file)
        await _prune_orphaned_redis_entries(redis_client, orphaned)
        logger.info("Loading tool index file...")
        index_data = load_index_file()
        _prune_orphaned_index_entries(index_data, registered_names)

        if not missing:
            logger.info("All tools have embeddings!")
            return True
        logger.info(
            "Processing %d tools (%s)...",
            len(missing),
            "force-all" if force_all else "missing from Redis",
        )
        to_process = sorted(missing)
        for name in to_process:
            logger.info(" - %s", name)

        # 5. Generate synthetic queries where the index is short (or forced)
        force_queries = force_index or force_all
        needing_queries: list[str] = []
        for name in to_process:
            qs = index_data.get(name, {}).get("synthetic_queries", [])
            if force_queries or not qs or len(qs) < SYNTHETIC_QUERY_COUNT:
                needing_queries.append(name)
        if needing_queries:
            await _generate_missing_queries(
                needing_queries,
                registered,
                index_data,
                openrouter_only=openrouter_only,
            )

        # 6. Compute embeddings and store them in Redis (additive or full refresh)
        await _compute_and_store_embeddings(
            redis_client,
            to_process,
            index_data,
            openrouter_only=openrouter_only,
        )
        logger.info("=" * 60)
        logger.info("Incremental update complete!")
        logger.info("=" * 60)
        return True
    finally:
        if openrouter_activated:
            await clear_openrouter_only()
        await redis_client.aclose()
[docs]
async def main() -> None:
    """Parse CLI flags, run :func:`update_tool_embeddings`, and exit with its status.

    Supported flags:

    * ``--force-index`` / ``-f``: regenerate synthetic queries.
    * ``--force-all``: reprocess every registered tool.
    * ``--openrouter-only``: use OpenRouter for query generation and
      embeddings.
    * ``--tools-dir``: directory scanned for tools (default ``tools``).

    The exit status is ``0`` when the update reports success. It is ``1``
    when the update reports failure or raises; a raised exception is logged
    with its traceback first.

    Returns:
        None: The process always terminates via :func:`sys.exit`.
    """
    cli = argparse.ArgumentParser(
        description=(
            "Update tool embeddings incrementally " "(only add missing tools)"
        ),
    )
    cli.add_argument(
        "--force-index",
        "-f",
        action="store_true",
        help="Force regeneration of synthetic queries",
    )
    cli.add_argument(
        "--force-all",
        action="store_true",
        help=(
            "Regenerate synthetic queries and embeddings for every registered "
            "tool (not only keys missing from Redis). Implies full query regen."
        ),
    )
    cli.add_argument(
        "--openrouter-only",
        action="store_true",
        help=(
            "Use OpenRouter only: chat "
            "(google/gemini-3.1-flash-lite) for synthetic queries and "
            "OpenRouter embeddings (google/gemini-embedding-001). "
            "Requires OPENROUTER_* or API_KEY."
        ),
    )
    cli.add_argument(
        "--tools-dir",
        default="tools",
        help="Tool scripts directory (default: tools)",
    )
    opts = cli.parse_args()

    try:
        succeeded = await update_tool_embeddings(
            force_index=opts.force_index,
            tools_dir=opts.tools_dir,
            force_all=opts.force_all,
            openrouter_only=opts.openrouter_only,
        )
    except Exception as exc:
        logger.error("Update failed: %s", exc, exc_info=True)
        sys.exit(1)
    # SystemExit is a BaseException, so it was never caught above either way.
    sys.exit(0 if succeeded else 1)
if __name__ == "__main__":
    # Script entry point: drive the async CLI; main() ends the process via sys.exit().
    asyncio.run(main())