# Source code for classifiers.update_tool_embeddings
#!/usr/bin/env python3
"""Incrementally update tool embeddings in Redis.
Discovers all registered tools, compares against what already
exists in Redis, generates synthetic queries for missing tools
via the Gemini API (using the shared embedding key pool), and
adds only the new embeddings -- without touching existing ones.
Usage::
python -m classifiers.update_tool_embeddings \
[--force-index] [--tools-dir tools]
Environment variables:
REDIS_URL -- defaults to redis://localhost:6379/0
"""
from __future__ import annotations
import argparse
import asyncio
import json
import logging
import os
import sys
from typing import Any
import numpy as np
import redis.asyncio as aioredis
import httpx
sys.path.insert(
0,
os.path.abspath(
os.path.join(os.path.dirname(__file__), ".."),
),
)
from tools import ToolRegistry # noqa: E402
from tool_loader import load_tools # noqa: E402
from rag_system.openrouter_embeddings import ( # noqa: E402
OpenRouterEmbeddings,
)
from classifiers.vector_classifier import ( # noqa: E402
TOOL_EMBEDDINGS_HASH_KEY,
TOOL_METADATA_HASH_KEY,
)
from classifiers.build_tool_index import ( # noqa: E402
generate_synthetic_queries,
SYNTHETIC_QUERY_COUNT,
)
# Configure logging once at import time so both direct execution and
# `python -m classifiers.update_tool_embeddings` get timestamped output.
logging.basicConfig(
    level=logging.INFO,
    format=(
        "%(asctime)s - %(name)s - "
        "%(levelname)s - %(message)s"
    ),
)
logger = logging.getLogger(__name__)
# On-disk JSON cache of per-tool synthetic queries, kept next to this module.
INDEX_FILE = os.path.join(
    os.path.dirname(__file__), "tool_index_data.json",
)
# [docs]
async def get_existing_redis_tools(
    redis_client: aioredis.Redis,
) -> set[str]:
    """Return tool names already in Redis.

    Any Redis failure is logged and treated as "nothing present", so the
    caller can proceed and regenerate embeddings from scratch.
    """
    try:
        raw_keys = await redis_client.hkeys(
            TOOL_EMBEDDINGS_HASH_KEY,
        )
        names: set[str] = set()
        for key in raw_keys:
            # Keys may arrive as bytes (client without decode_responses);
            # normalise everything to str.
            if isinstance(key, bytes):
                key = key.decode("utf-8")
            names.add(key)
        return names
    except Exception as exc:
        logger.error(
            "Failed to fetch existing tools: %s", exc,
        )
        return set()
# [docs]
def discover_tools(
    tools_dir: str = "tools",
) -> dict[str, Any]:
    """Auto-discover all registered tools keyed by name."""
    registry = ToolRegistry()
    load_tools(tools_dir, registry)
    discovered: dict[str, Any] = {}
    for definition in registry.list_tools():
        discovered[definition.name] = definition
    return discovered
# [docs]
def load_index_file() -> dict[str, Any]:
    """Read the synthetic-query index from ``INDEX_FILE``.

    Returns:
        dict[str, Any]: Parsed index contents, or an empty dict when the
        file is missing or unreadable (best-effort cache semantics).
    """
    if not os.path.exists(INDEX_FILE):
        return {}
    try:
        with open(
            INDEX_FILE, "r", encoding="utf-8",
        ) as handle:
            return json.load(handle)
    except Exception as exc:
        logger.warning(
            "Failed to load index file: %s", exc,
        )
        return {}
# [docs]
def save_index_file(data: dict[str, Any]) -> None:
    """Write the synthetic-query index to ``INDEX_FILE`` as pretty JSON.

    Args:
        data (dict[str, Any]): Input data payload.
    """
    with open(INDEX_FILE, "w", encoding="utf-8") as handle:
        json.dump(data, handle, indent=2)
    logger.info(
        "Saved index file with %d tools", len(data),
    )
# [docs]
async def compute_tool_embedding(
    embedding_client: OpenRouterEmbeddings,
    synthetic_queries: list[str],
    tool_name: str,
) -> np.ndarray | None:
    """Embed the queries and return their L2-normalised centroid.

    Returns ``None`` when the embedding call fails or produces no
    usable (non-empty) vectors.
    """
    try:
        vectors = [
            vec
            for vec in await embedding_client.embed_texts(
                synthetic_queries,
            )
            if vec.size > 0
        ]
    except Exception as exc:
        logger.warning(
            "Failed to embed queries for %r: %s",
            tool_name, exc,
        )
        return None
    if not vectors:
        return None
    centroid = np.mean(vectors, axis=0)
    length = np.linalg.norm(centroid)
    # Guard against an all-zero centroid: dividing by zero would emit NaNs.
    return centroid / length if length > 0 else centroid
# [docs]
async def update_tool_embeddings(
    force_index: bool = False,
    tools_dir: str = "tools",
) -> bool:
    """Incrementally sync Redis tool embeddings with the tool registry.

    Discovers registered tools, prunes orphaned Redis/index entries,
    generates synthetic queries for tools lacking them, computes
    normalised centroid embeddings, and additively stores the new
    embeddings -- existing entries are never recomputed.

    Args:
        force_index (bool): Regenerate synthetic queries even when the
            index file already has enough cached queries.
        tools_dir (str): Directory scanned for tool scripts.

    Returns:
        bool: True on success (including the no-op "all up to date" case).
    """
    logger.info("=" * 60)
    logger.info("Incremental Tool Embeddings Update")
    logger.info("=" * 60)
    # Imported lazily so importing this module has no config side effects.
    from config import Config
    cfg = Config.load()
    redis_url = cfg.redis_url or os.environ.get(
        "REDIS_URL", "redis://localhost:6379/0",
    )
    redis_client = aioredis.from_url(
        redis_url, decode_responses=True, **cfg.redis_ssl_kwargs(),
    )
    try:
        # 1. Discover tools
        logger.info("Discovering registered tools...")
        registered = discover_tools(tools_dir)
        logger.info(
            " Found %d registered tools",
            len(registered),
        )
        # 2. Check Redis
        logger.info("Checking existing embeddings...")
        existing = await get_existing_redis_tools(
            redis_client,
        )
        logger.info(
            " Found %d with embeddings", len(existing),
        )
        # 3. Identify missing and orphaned
        registered_names = set(registered.keys())
        missing = registered_names - existing
        orphaned = existing - registered_names
        # 3a. Prune orphaned embeddings (tools removed from codebase)
        if orphaned:
            logger.info(
                "Pruning %d orphaned embeddings:",
                len(orphaned),
            )
            for name in sorted(orphaned):
                logger.info(" - %s", name)
            await redis_client.hdel(
                TOOL_EMBEDDINGS_HASH_KEY, *orphaned,
            )
            await redis_client.hdel(
                TOOL_METADATA_HASH_KEY, *orphaned,
            )
            logger.info(
                "Removed %d orphaned entries from Redis",
                len(orphaned),
            )
        # 3b. Load index file and prune orphans from it too
        logger.info("Loading tool index file...")
        index_data = load_index_file()
        pruned_from_index = 0
        for name in list(index_data.keys()):
            if name not in registered_names:
                del index_data[name]
                pruned_from_index += 1
        if pruned_from_index:
            save_index_file(index_data)
            logger.info(
                "Pruned %d orphaned entries from index file",
                pruned_from_index,
            )
        if not missing:
            logger.info("All tools have embeddings!")
            return True
        logger.info(
            "Found %d tools without embeddings:",
            len(missing),
        )
        for name in sorted(missing):
            logger.info(" - %s", name)
        # 4. Decide which missing tools need fresh synthetic queries:
        # cache miss, too few cached queries, or an explicit --force-index.
        needing_queries: list[str] = []
        for name in missing:
            qs = index_data.get(name, {}).get(
                "synthetic_queries", [],
            )
            if (
                not qs
                or len(qs) < SYNTHETIC_QUERY_COUNT
                or force_index
            ):
                needing_queries.append(name)
        # 5. Generate synthetic queries
        if needing_queries:
            logger.info(
                "Generating queries for %d tools...",
                len(needing_queries),
            )
            http_client = httpx.AsyncClient(
                timeout=httpx.Timeout(120.0, connect=10.0),
            )
            # Cap concurrent generation calls to avoid rate limits.
            sem = asyncio.Semaphore(3)

            async def gen(tool_name: str) -> None:
                async with sem:
                    tool = registered[tool_name]
                    desc = (
                        getattr(tool, "description", "")
                        or ""
                    )
                    logger.info(
                        " Generating for: %s", tool_name,
                    )
                    qs = await generate_synthetic_queries(
                        http_client, None, None,
                        tool_name, desc,
                    )
                    if qs:
                        index_data[tool_name] = {
                            "name": tool_name,
                            "description": desc,
                            "synthetic_queries": qs,
                        }
                        logger.info(
                            " %s: %d queries",
                            tool_name, len(qs),
                        )
                    else:
                        logger.warning(
                            " %s: failed", tool_name,
                        )
            try:
                await asyncio.gather(
                    *(gen(n) for n in needing_queries),
                )
            finally:
                # Fix: close the HTTP client; it was previously never
                # closed, leaking its connection pool.
                await http_client.aclose()
            save_index_file(index_data)
        # 6. Compute embeddings (Gemini API via shared key pool)
        logger.info("Computing embeddings...")
        emb_client = OpenRouterEmbeddings()
        embs_store: dict[str, str] = {}
        meta_store: dict[str, str] = {}
        try:
            for tool_name in missing:
                info = index_data.get(tool_name, {})
                qs = info.get("synthetic_queries", [])
                desc = info.get("description", "")
                if not qs:
                    # Fall back to the description (or bare tool name) so
                    # every missing tool still gets some embedding.
                    logger.warning(
                        " %s: no queries, using desc",
                        tool_name,
                    )
                    qs = (
                        [desc] if desc
                        else [f"use {tool_name}"]
                    )
                centroid = await compute_tool_embedding(
                    emb_client, qs, tool_name,
                )
                if centroid is not None:
                    embs_store[tool_name] = json.dumps(
                        centroid.tolist(),
                    )
                    meta_store[tool_name] = json.dumps({
                        "name": tool_name,
                        "description": desc,
                        "query_count": len(qs),
                    })
                    logger.info(
                        " %s: computed from %d queries",
                        tool_name, len(qs),
                    )
                else:
                    logger.warning(
                        " %s: failed to compute",
                        tool_name,
                    )
        finally:
            # Fix: release the embedding client even if the loop raises.
            await emb_client.close()
        # 7. Store in Redis (additive)
        if embs_store:
            logger.info(
                "Storing %d new embeddings...",
                len(embs_store),
            )
            await redis_client.hset(
                TOOL_EMBEDDINGS_HASH_KEY,
                mapping=embs_store,
            )
            await redis_client.hset(
                TOOL_METADATA_HASH_KEY,
                mapping=meta_store,
            )
            logger.info(
                "Added %d tool embeddings to Redis",
                len(embs_store),
            )
        else:
            logger.warning("No new embeddings to store")
        logger.info("=" * 60)
        logger.info("Incremental update complete!")
        logger.info("=" * 60)
        return True
    finally:
        await redis_client.aclose()
# [docs]
async def main() -> None:
    """CLI entry point: parse flags, run the update, exit with a status code."""
    parser = argparse.ArgumentParser(
        description=(
            "Update tool embeddings incrementally "
            "(only add missing tools)"
        ),
    )
    parser.add_argument(
        "--force-index", "-f",
        action="store_true",
        help="Force regeneration of synthetic queries",
    )
    parser.add_argument(
        "--tools-dir",
        default="tools",
        help="Tool scripts directory (default: tools)",
    )
    args = parser.parse_args()
    try:
        ok = await update_tool_embeddings(
            force_index=args.force_index,
            tools_dir=args.tools_dir,
        )
    except Exception as exc:
        logger.error(
            "Update failed: %s", exc, exc_info=True,
        )
        sys.exit(1)
    # SystemExit is a BaseException, so raising it here (or inside the
    # try above) is never swallowed by the except clause.
    sys.exit(0 if ok else 1)


if __name__ == "__main__":
    asyncio.run(main())