#!/usr/bin/env python3
"""Initialize tool embeddings in Redis.

Reads tool definitions and their synthetic queries from
``tool_index_data.json``, computes centroid embeddings via the
OpenRouter API, and stores them in Redis hashes for fast
vector-based tool selection at runtime.

Usage::

    python -m classifiers.init_tool_embeddings [--force]
        [--index-file PATH] [--redis-url URL]

Environment variables:
    OPENROUTER_API_KEY -- required
    REDIS_URL -- defaults to redis://localhost:6379/0
        (overridden by ``--redis-url``)
"""
from __future__ import annotations

import argparse
import asyncio
import logging
import os
import sys
from urllib.parse import urlsplit, urlunsplit

import redis.asyncio as aioredis
# Put the directory one level above this file at the front of ``sys.path``
# so the ``classifiers`` package import that follows resolves even when this
# file is executed directly instead of via ``python -m``.
sys.path.insert(
    0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")),
)
from classifiers.vector_classifier import ( # noqa: E402
initialize_tool_embeddings_from_file,
)
# Configure root logging for this script at import time (INFO and above).
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
# Module-level logger used by the entry point below.
logger = logging.getLogger(__name__)
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse this script's command-line options.

    Args:
        argv: Arguments to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Namespace with ``force`` (bool), ``index_file`` (str or None) and
        ``redis_url`` (str or None).
    """
    parser = argparse.ArgumentParser(
        description=(
            "Initialize tool embeddings in Redis "
            "for vector-based classification"
        ),
    )
    parser.add_argument(
        "--force", "-f",
        action="store_true",
        help="Force recomputation even if present",
    )
    parser.add_argument(
        "--index-file",
        type=str,
        default=None,
        help=(
            "Path to tool_index_data.json "
            "(defaults to classifiers/tool_index_data.json)"
        ),
    )
    parser.add_argument(
        "--redis-url",
        type=str,
        default=None,
        help="Redis URL (env REDIS_URL or localhost)",
    )
    return parser.parse_args(argv)


def _redact_url(url: str) -> str:
    """Return *url* with any password in its userinfo replaced by ``***``.

    Lets connection strings such as ``redis://:secret@host:6379/0`` be
    logged without leaking credentials. If *url* cannot be parsed, a fixed
    placeholder is returned rather than risking echoing a secret.
    """
    try:
        parts = urlsplit(url)
    except ValueError:
        return "<unparseable URL>"
    if parts.password is None:
        return url
    # rpartition: the password itself may contain '@'.
    userinfo, _, hostport = parts.netloc.rpartition("@")
    username = userinfo.partition(":")[0]
    return urlunsplit(parts._replace(netloc=f"{username}:***@{hostport}"))


def _load_ssl_kwargs() -> dict:
    """Return ``Config.load().redis_ssl_kwargs()`` from the project config.

    Falls back to ``{}`` when the ``config`` module cannot be imported or
    loading fails. The failure is logged rather than swallowed, because the
    Redis client is then created without those keyword arguments.
    """
    try:
        from config import Config
        return Config.load().redis_ssl_kwargs()
    except Exception as exc:
        logger.warning(
            "Could not load Redis SSL settings from config (%s); "
            "continuing without them",
            exc,
        )
        return {}


async def main() -> None:
    """Compute tool embeddings from the index file and store them in Redis.

    Resolution order:

    * index file: ``--index-file``, else ``tool_index_data.json`` next to
      this script;
    * Redis URL: ``--redis-url``, else ``$REDIS_URL``, else
      ``redis://localhost:6379/0``.

    Never returns normally. It exits the process via ``sys.exit`` with
    status 0 on success. It exits with status 1 if the index file is
    missing, if ``aioredis.from_url`` rejects the URL (``ValueError``), or
    if initialization reports failure or raises.
    """
    args = _parse_args()

    index_file = args.index_file or os.path.join(
        os.path.dirname(__file__), "tool_index_data.json",
    )
    # isfile (not exists) so that a directory path is rejected up front.
    if not os.path.isfile(index_file):
        logger.error("Index file not found: %s", index_file)
        sys.exit(1)

    redis_url = args.redis_url or os.environ.get(
        "REDIS_URL", "redis://localhost:6379/0",
    )
    ssl_kwargs = _load_ssl_kwargs()

    logger.info("=" * 60)
    logger.info("Tool Embeddings Initialization")
    logger.info("=" * 60)
    logger.info("Index file: %s", index_file)
    # Never write credentials embedded in the connection URL to the logs.
    logger.info("Redis URL: %s", _redact_url(redis_url))
    logger.info("Force: %s", args.force)
    logger.info("=" * 60)

    try:
        redis_client = aioredis.from_url(
            redis_url, decode_responses=True, **ssl_kwargs,
        )
    except ValueError as exc:
        logger.error("Invalid Redis URL %s: %s", _redact_url(redis_url), exc)
        sys.exit(1)

    try:
        success = await initialize_tool_embeddings_from_file(
            index_file_path=index_file,
            redis_client=redis_client,
            force_recompute=args.force,
        )
    except Exception as exc:
        # Top-level boundary: record the full traceback and report failure.
        logger.exception("Initialization failed: %s", exc)
        sys.exit(1)
    finally:
        # Close the client on every path, including sys.exit above and
        # task cancellation.
        await redis_client.aclose()

    if not success:
        logger.error("=" * 60)
        logger.error("Initialization failed!")
        logger.error("=" * 60)
        sys.exit(1)

    logger.info("=" * 60)
    logger.info("Initialization completed!")
    logger.info("=" * 60)
    sys.exit(0)
if __name__ == "__main__":
    # main() sets the exit status itself via sys.exit.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl-C: exit with the conventional 128 + SIGINT status instead of
        # dumping a traceback. asyncio.run has already cancelled main(), so
        # its cleanup has run by the time we get here.
        logger.warning("Interrupted by user")
        sys.exit(130)