# Source code for classifiers.update_changed_tool_embeddings
#!/usr/bin/env python3
"""Surgical embedding update for tools whose descriptions changed.
This script is intentionally **smaller and safer** than
``update_tool_embeddings`` and ``refresh_tool_embeddings``:
- It **never** calls ``DEL``, ``HDEL``, ``FLUSHDB``, or any other delete
primitive on Redis. Only ``HSET`` for the explicit per-tool keys it
is updating.
- It does **not** prune orphaned entries (use ``update_tool_embeddings``
for that — once you're sure you actually want the prune).
- It does **not** regenerate embeddings for tools that did not change.
- It edits ``tool_index_data.json`` only for the targeted tools, leaving
every other entry byte-for-byte the same.
- If a target's query generation or centroid embedding fails, the script
skips that one tool and leaves Redis untouched for it.
Resolution order for "what changed":
1. ``--tools name1,name2`` — explicit allowlist (intersected with the
live tool registry; unknown names are warned and skipped).
2. Otherwise auto-detect by reading
``stargazer:tool_metadata`` from Redis and comparing each tool's
stored ``description`` against the live ``TOOL_DESCRIPTION``. Tools
missing from Redis entirely are **NOT** added — that's
``update_tool_embeddings``'s job.
Pass ``--dry-run`` to print the plan without writing anything.
Usage::
python -m classifiers.update_changed_tool_embeddings
python -m classifiers.update_changed_tool_embeddings --tools redis_admin,rag_index_file
python -m classifiers.update_changed_tool_embeddings --dry-run
python -m classifiers.update_changed_tool_embeddings --openrouter-only
Env vars: ``REDIS_URL`` (default ``redis://localhost:6379/0``),
``OPENROUTER_QUERY_GEN_API_KEY`` / ``OPENROUTER_API_KEY`` / ``API_KEY``
when ``--openrouter-only`` is active.
"""
from __future__ import annotations
import argparse
import asyncio
import jsonutil as json
import logging
import os
import sys
from pathlib import Path
from typing import Any
import httpx
import numpy as np
import redis.asyncio as aioredis
sys.path.insert(
0,
os.path.abspath(os.path.join(os.path.dirname(__file__), "..")),
)
from gemini_embed_pool import ( # noqa: E402
clear_openrouter_only,
init_quota_tracking,
set_openrouter_only,
)
from classifiers.build_tool_index import ( # noqa: E402
SYNTHETIC_QUERY_COUNT,
_openrouter_query_gen_api_key,
generate_synthetic_queries,
)
from classifiers.redis_vector_index import ( # noqa: E402
store_tool_embedding_hash,
)
from classifiers.tool_embedding_batch import ( # noqa: E402
compute_tool_centroids_bulk,
)
from classifiers.update_tool_embeddings import ( # noqa: E402
INDEX_FILE,
discover_tools,
load_index_file,
save_index_file,
)
from classifiers.vector_classifier import ( # noqa: E402
TOOL_EMBEDDINGS_HASH_KEY,
TOOL_METADATA_HASH_KEY,
)
from rag_system.openrouter_embeddings import ( # noqa: E402
OpenRouterEmbeddings,
)
# Configure logging once at import time: INFO level, "LEVEL: message" lines.
logging.basicConfig(
    level=logging.INFO,
    format="%(levelname)s: %(message)s",
)
# Module-level logger shared by every function in this file.
logger = logging.getLogger(__name__)
# Model identifier passed to OpenRouterEmbeddings when computing centroids.
_EMBED_MODEL_OPENROUTER = "google/gemini-embedding-001"
# [docs]
async def find_stale_tools(
    registered: dict[str, Any],
    redis_client: aioredis.Redis,
) -> list[str]:
    """Return the registered tools whose stored description is out of date.

    Reads the whole ``TOOL_METADATA_HASH_KEY`` hash from Redis and, for each
    registered tool, compares the whitespace-stripped live ``description``
    attribute with the ``description`` field of the stored JSON metadata.

    A tool is reported as stale when:

    - the stored value cannot be parsed as JSON,
    - the stored value parses to something other than a JSON object,
    - the stored ``description`` is not a string, or
    - the stripped stored description differs from the stripped live one.

    Tools whose name is **missing** from the Redis hash are intentionally NOT
    included; they are only counted and logged (first eight names, sorted).

    Args:
        registered: Mapping of tool name to tool definition. Only the
            ``description`` attribute of each definition is read; a missing or
            falsy attribute is treated as an empty description.
        redis_client: Async Redis client used for the single ``HGETALL``.
            The lookups by ``str`` name assume the client decodes responses
            to ``str`` — the caller in this file creates it with
            ``decode_responses=True``.

    Returns:
        Sorted list of stale tool names (possibly empty).
    """
    meta_raw: dict[str, str] = await redis_client.hgetall(
        TOOL_METADATA_HASH_KEY,
    )
    stale: list[str] = []
    skipped_missing: list[str] = []
    for name, tool_def in registered.items():
        stored = meta_raw.get(name)
        if stored is None:
            skipped_missing.append(name)
            continue
        try:
            meta = json.loads(stored)
        except (json.JSONDecodeError, TypeError):
            # Unparseable metadata cannot be trusted; refresh the tool.
            stale.append(name)
            continue
        if not isinstance(meta, dict):
            # Valid JSON of the wrong shape (list, string, number, null):
            # treat it like corrupt metadata instead of crashing on ``.get``.
            stale.append(name)
            continue
        live_desc = (getattr(tool_def, "description", "") or "").strip()
        stored_desc = meta.get("description", "") or ""
        if not isinstance(stored_desc, str):
            # A non-string description can never equal the live text.
            stale.append(name)
            continue
        if live_desc != stored_desc.strip():
            stale.append(name)
    if skipped_missing:
        logger.info(
            "Skipping %d tool(s) missing from Redis (use update_tool_embeddings to add): %s",
            len(skipped_missing),
            ", ".join(sorted(skipped_missing)[:8])
            + (", ..." if len(skipped_missing) > 8 else ""),
        )
    return sorted(stale)
# [docs]
async def update_changed_tool_embeddings(
    *,
    tool_names: list[str] | None = None,
    tools_dir: str = "tools",
    dry_run: bool = False,
    openrouter_only: bool = False,
    paid_key: str | None = None,
    concurrency: int | None = None,
) -> bool:
    """Re-embed only the listed or description-changed tools.

    Workflow:

    1. Resolve targets: the explicit ``tool_names`` allowlist intersected with
       the registry returned by ``discover_tools``, or, when no names are
       given, the result of ``find_stale_tools``.
    2. Regenerate synthetic queries per target through the nested ``gen_one``
       closure (bounded by a semaphore).
    3. Compute one centroid per tool with ``compute_tool_centroids_bulk``
       using an ``OpenRouterEmbeddings`` client.
    4. ``HSET`` the new centroid and metadata into ``TOOL_EMBEDDINGS_HASH_KEY``
       and ``TOOL_METADATA_HASH_KEY`` and call ``store_tool_embedding_hash``
       for each updated tool.
    5. Persist the in-memory index through ``save_index_file``.

    The only Redis write this function issues directly is ``hset``. A tool
    whose query generation or centroid computation fails is skipped and
    reported in the final summary; it does not abort the other targets.

    Side effects: opens and always closes its own async Redis connection,
    calls ``init_quota_tracking``, and, when ``paid_key`` is given, sets the
    ``GEMINI_EMBED_PAID_KEY`` and ``GEMINI_QUERY_GEN_PAID_AFTER_FAILURES``
    environment variables (this happens even on a dry run). With
    ``openrouter_only`` it calls ``set_openrouter_only`` and undoes it with
    ``clear_openrouter_only`` on exit.

    Args:
        tool_names: Explicit tool names to refresh. Names not present in the
            registry are warned about and skipped. ``None`` or an empty list
            switches to auto-detection by description compare.
        tools_dir: Directory passed to ``discover_tools``.
        dry_run: Log the plan and return ``True`` before any query
            generation, embedding, Redis write, or index-file write. Target
            resolution (including the Redis read used for auto-detection)
            still runs.
        openrouter_only: Forwarded to ``generate_synthetic_queries``; requires
            ``_openrouter_query_gen_api_key()`` to return a truthy value.
            Also defaults ``TOOL_EMBED_OR_MAX_CONCURRENT`` to ``"32"`` and
            clamps the embedding client's ``MAX_BATCH_SIZE`` to 50..128.
        paid_key: Paid Gemini key exported through the environment variables
            listed above; also lowers the default generation concurrency to 1.
        concurrency: Positive override for the query-generation concurrency.
            Otherwise 1 with ``paid_key``, 8 with ``openrouter_only``, else 3.

    Returns:
        ``True`` on success, on a dry run, or when auto-detection finds
        nothing stale. ``False`` when the OpenRouter credential is missing,
        no valid explicit target remains, the index file exists but loads
        empty, or no target produced enough queries or a centroid.
    """
    if paid_key:
        os.environ["GEMINI_EMBED_PAID_KEY"] = paid_key
        os.environ["GEMINI_QUERY_GEN_PAID_AFTER_FAILURES"] = "0"
        logger.info(
            "Pinned to caller-supplied paid Gemini key (suffix ...%s) and set "
            "GEMINI_QUERY_GEN_PAID_AFTER_FAILURES=0 so the first call uses it.",
            paid_key[-6:],
        )
    logger.info("=" * 60)
    logger.info(
        "Surgical Tool Embeddings Update %s",
        "(DRY RUN)" if dry_run else "",
    )
    logger.info("=" * 60)
    from config import Config

    cfg = Config.load()
    redis_url = cfg.redis_url or os.environ.get(
        "REDIS_URL",
        "redis://localhost:6379/0",
    )
    redis_client = aioredis.from_url(
        redis_url,
        decode_responses=True,
        **cfg.redis_connection_kwargs_for_url(redis_url),
    )
    openrouter_activated = False
    try:
        # Inside the ``try`` so the connection is closed even if quota
        # tracking setup raises.
        init_quota_tracking(redis_client)
        if openrouter_only:
            if not _openrouter_query_gen_api_key():
                logger.error(
                    "--openrouter-only requires OPENROUTER_QUERY_GEN_API_KEY, "
                    "OPENROUTER_API_KEY, or API_KEY.",
                )
                return False
            await set_openrouter_only()
            openrouter_activated = True
            os.environ.setdefault("TOOL_EMBED_OR_MAX_CONCURRENT", "32")
        logger.info("Discovering registered tools...")
        registered = discover_tools(tools_dir)
        logger.info(" Found %d registered tools", len(registered))
        if tool_names:
            requested = {n.strip() for n in tool_names if n.strip()}
            unknown = sorted(requested - set(registered.keys()))
            if unknown:
                logger.warning(
                    "Unknown tools (skipped): %s",
                    ", ".join(unknown),
                )
            targets = sorted(requested & set(registered.keys()))
            if not targets:
                logger.error("No valid tool names provided.")
                return False
            logger.info(
                "Explicit targets (%d): %s",
                len(targets),
                ", ".join(targets),
            )
        else:
            logger.info(
                "Auto-detecting stale tools by description compare...",
            )
            targets = await find_stale_tools(registered, redis_client)
            if not targets:
                logger.info(
                    "No tools have a stale description. Nothing to do.",
                )
                return True
            logger.info(
                "Stale (%d): %s",
                len(targets),
                ", ".join(targets),
            )
        index_path = Path(INDEX_FILE)
        index_data = load_index_file()
        # An existing file that loads as empty would be overwritten with only
        # the targeted entries by the final save, so refuse to continue.
        if index_path.is_file() and not index_data:
            logger.error(
                "%s exists but did not parse to a non-empty object. "
                "Refusing to overwrite the index. Repair the JSON and retry.",
                index_path,
            )
            return False
        if dry_run:
            logger.info("=" * 60)
            logger.info(
                "DRY RUN: would regenerate synthetic queries + centroid for:",
            )
            for name in targets:
                logger.info(" - %s", name)
            logger.info(
                "DRY RUN: would HSET %s and %s for those keys only.",
                TOOL_EMBEDDINGS_HASH_KEY,
                TOOL_METADATA_HASH_KEY,
            )
            logger.info("DRY RUN: would NOT call HDEL / DEL / FLUSHDB.")
            logger.info(
                "DRY RUN: would NOT touch %d other tool entries.",
                len(registered) - len(targets),
            )
            logger.info("=" * 60)
            return True
        # ------------------------------------------------------------------
        # 1. Generate fresh synthetic queries (only for targets)
        # ------------------------------------------------------------------
        if concurrency is not None and concurrency > 0:
            sem_n = concurrency
        elif paid_key:
            sem_n = 1
        elif openrouter_only:
            sem_n = 8
        else:
            sem_n = 3
        sem = asyncio.Semaphore(sem_n)
        logger.info("Synthetic-query concurrency: %d", sem_n)
        regenerated: list[str] = []
        gen_failures: list[str] = []
        embed_failures: list[str] = []
        http_client = httpx.AsyncClient(
            timeout=httpx.Timeout(600.0, connect=30.0),
        )

        async def gen_one(tool_name: str) -> None:
            """Regenerate synthetic queries for one target tool.

            Holds the enclosing ``sem`` semaphore while calling
            ``generate_synthetic_queries`` with the shared ``http_client``.
            On success the tool's entry in the enclosing ``index_data`` dict
            is replaced with the fresh name, description and queries, and the
            name is appended to ``regenerated``.

            If generation raises, or returns fewer than
            ``SYNTHETIC_QUERY_COUNT`` queries, the name is appended to
            ``gen_failures`` and ``index_data`` is left unchanged for it.

            Args:
                tool_name: Key of the enclosing ``registered`` mapping.
            """
            async with sem:
                tool = registered[tool_name]
                desc = getattr(tool, "description", "") or ""
                logger.info(" Generating queries for: %s", tool_name)
                try:
                    qs = await generate_synthetic_queries(
                        http_client,
                        None,
                        None,
                        tool_name,
                        desc,
                        openrouter_only=openrouter_only,
                    )
                except Exception as exc:
                    # Per-tool isolation: record the failure and move on.
                    logger.error(
                        "Query generation failed for %s: %s",
                        tool_name,
                        exc,
                    )
                    gen_failures.append(tool_name)
                    return
                if len(qs) < SYNTHETIC_QUERY_COUNT:
                    logger.warning(
                        "%s: only got %d/%d queries; skipping this tool.",
                        tool_name,
                        len(qs),
                        SYNTHETIC_QUERY_COUNT,
                    )
                    gen_failures.append(tool_name)
                    return
                index_data[tool_name] = {
                    "name": tool_name,
                    "description": desc,
                    "synthetic_queries": qs,
                }
                regenerated.append(tool_name)
                logger.info(
                    " %s: %d queries",
                    tool_name,
                    len(qs),
                )

        try:
            await asyncio.gather(*(gen_one(n) for n in targets))
        finally:
            # Close the client even when gather propagates (e.g. cancellation).
            await http_client.aclose()
        ready = sorted(set(regenerated))
        if not ready:
            logger.error(
                "No targets produced enough synthetic queries. "
                "Redis was not modified. Failed: %s",
                ", ".join(sorted(gen_failures)) or "(none)",
            )
            return False
        # ------------------------------------------------------------------
        # 2. Compute centroids (only for ready targets)
        # ------------------------------------------------------------------
        logger.info("Computing centroids for %d tool(s)...", len(ready))
        emb_client = OpenRouterEmbeddings(model=_EMBED_MODEL_OPENROUTER)
        if openrouter_only:
            emb_client.MAX_BATCH_SIZE = min(
                128,
                max(50, emb_client.MAX_BATCH_SIZE),
            )
        try:
            tool_queries = {
                name: index_data[name]["synthetic_queries"] for name in ready
            }
            centroids = await compute_tool_centroids_bulk(
                emb_client,
                tool_queries,
            )
        finally:
            await emb_client.close()
        embs_store: dict[str, str] = {}
        meta_store: dict[str, str] = {}
        for name in ready:
            centroid = centroids.get(name)
            if centroid is None:
                logger.warning(
                    "%s: centroid computation failed; skipping HSET.",
                    name,
                )
                # Remember it so the final summary reports embedding failures.
                embed_failures.append(name)
                continue
            desc = index_data[name].get("description", "")
            qs = index_data[name].get("synthetic_queries", [])
            embs_store[name] = json.dumps(centroid.tolist())
            meta_store[name] = json.dumps(
                {
                    "name": name,
                    "description": desc,
                    "query_count": len(qs),
                }
            )
        if not embs_store:
            logger.error(
                "No centroids computed. Redis was not modified.",
            )
            return False
        # ------------------------------------------------------------------
        # 3. ADDITIVE writes only — HSET on the targeted keys.
        #    No HDEL, DEL, FLUSHDB, or unrelated tool keys.
        # ------------------------------------------------------------------
        logger.info(
            "HSET %d key(s) into %s and %s (other %d tool entries untouched)...",
            len(embs_store),
            TOOL_EMBEDDINGS_HASH_KEY,
            TOOL_METADATA_HASH_KEY,
            len(registered) - len(embs_store),
        )
        await redis_client.hset(
            TOOL_EMBEDDINGS_HASH_KEY,
            mapping=embs_store,
        )
        await redis_client.hset(
            TOOL_METADATA_HASH_KEY,
            mapping=meta_store,
        )
        for name, blob in embs_store.items():
            vec = np.array(json.loads(blob), dtype=np.float32)
            meta = json.loads(meta_store[name])
            await store_tool_embedding_hash(
                redis_client,
                name,
                vec,
                meta,
            )
        # ------------------------------------------------------------------
        # 4. Persist index file (other entries preserved verbatim).
        # ------------------------------------------------------------------
        # NOTE(review): ``index_data`` also carries the freshly generated
        # queries of tools whose centroid failed above, so the file can be
        # ahead of Redis for those tools — confirm this is intended.
        save_index_file(index_data)
        logger.info("=" * 60)
        logger.info(
            "Done. Updated %d tool(s): %s",
            len(embs_store),
            ", ".join(sorted(embs_store)),
        )
        skipped = sorted(set(gen_failures) | set(embed_failures))
        if skipped:
            logger.warning(
                "Skipped %d tool(s) due to query gen / embedding failure: %s",
                len(skipped),
                ", ".join(skipped),
            )
        logger.info("=" * 60)
        return True
    finally:
        if openrouter_activated:
            await clear_openrouter_only()
        await redis_client.aclose()
# [docs]
async def main() -> None:
    """Parse the command line, run the surgical update, and exit.

    Recognised flags: ``--tools`` / ``-t`` (comma-separated allowlist; stale
    tools are auto-detected when omitted), ``--tools-dir``, ``--dry-run``,
    ``--openrouter-only``, ``--paid-key`` and ``--concurrency``.

    The parsed options are forwarded to
    :func:`update_changed_tool_embeddings`. The process always ends through
    :func:`sys.exit`: status ``0`` when the update returns a truthy value,
    ``1`` when it returns a falsy value or raises an ``Exception`` (which is
    logged together with its traceback).

    Returns:
        None: Control never returns normally; ``sys.exit`` is always called.
    """
    cli = argparse.ArgumentParser(
        description=(
            "Surgically update embeddings for tools whose descriptions "
            "changed. Never wipes the DB; never regenerates the whole cache."
        ),
    )
    cli.add_argument(
        "--tools",
        "-t",
        default="",
        help=(
            "Comma-separated tool names to refresh. "
            "If omitted, auto-detect stale tools via description compare."
        ),
    )
    cli.add_argument(
        "--tools-dir",
        default="tools",
        help="Tool scripts directory (default: tools).",
    )
    cli.add_argument(
        "--dry-run",
        action="store_true",
        help=(
            "Print the plan (which tools would be re-embedded) without "
            "calling Gemini/OpenRouter or writing to Redis or the index file."
        ),
    )
    cli.add_argument(
        "--openrouter-only",
        action="store_true",
        help=(
            "Use OpenRouter only for both synthetic-query generation and "
            "embeddings (google/gemini-embedding-001). Requires "
            "OPENROUTER_QUERY_GEN_API_KEY / OPENROUTER_API_KEY / API_KEY."
        ),
    )
    cli.add_argument(
        "--paid-key",
        default="",
        help=(
            "Pin synthetic-query generation to a single paid Gemini API key "
            "(also exported as GEMINI_EMBED_PAID_KEY for the embed pool's "
            "fallback). Sets GEMINI_QUERY_GEN_PAID_AFTER_FAILURES=0 and "
            "drops generation concurrency to 1 (override with --concurrency)."
        ),
    )
    cli.add_argument(
        "--concurrency",
        type=int,
        default=0,
        help=(
            "Override the synthetic-query gen concurrency. Default: 1 with "
            "--paid-key, 8 with --openrouter-only, otherwise 3."
        ),
    )
    opts = cli.parse_args()

    # An omitted/empty --tools means "auto-detect"; otherwise split on commas
    # and drop blank fragments.
    requested = None
    if opts.tools:
        requested = [
            piece.strip() for piece in opts.tools.split(",") if piece.strip()
        ]

    try:
        succeeded = await update_changed_tool_embeddings(
            tool_names=requested,
            tools_dir=opts.tools_dir,
            dry_run=opts.dry_run,
            openrouter_only=opts.openrouter_only,
            paid_key=opts.paid_key.strip() or None,
            concurrency=opts.concurrency or None,
        )
    except Exception as exc:
        logger.error("Surgical update failed: %s", exc, exc_info=True)
        sys.exit(1)
    sys.exit(0 if succeeded else 1)
# Script entry point: drive the async CLI on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(main())