Source code for classifiers.init_tool_embeddings

#!/usr/bin/env python3
"""Initialize tool embeddings in Redis.

Reads tool definitions and their synthetic queries from
``tool_index_data.json``, computes centroid embeddings via the
OpenRouter API, and stores them in Redis hashes for fast
vector-based tool selection at runtime.

Usage::

    python -m classifiers.init_tool_embeddings [--force]
        [--index-file PATH] [--redis-url URL]

Environment variables:
    OPENROUTER_API_KEY  -- required
    REDIS_URL           -- defaults to redis://localhost:6379/0;
                           ``--redis-url`` takes precedence
"""

from __future__ import annotations

import argparse
import asyncio
import logging
import os
import sys
from urllib.parse import urlsplit, urlunsplit

import redis.asyncio as aioredis

# Prepend the project root (the parent of this file's directory) to the
# import path. This lets the ``classifiers`` and ``config`` imports below
# resolve when this file is executed directly instead of via ``python -m``.
_project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, _project_root)

from classifiers.vector_classifier import (  # noqa: E402
    initialize_tool_embeddings_from_file,
)

# Root logging for CLI output. logging.basicConfig is a no-op when the root
# logger already has handlers, so importing this module elsewhere is harmless.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)


def _build_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description=(
            "Initialize tool embeddings in Redis "
            "for vector-based classification"
        ),
    )
    parser.add_argument(
        "--force",
        "-f",
        action="store_true",
        help="Force recomputation even if present",
    )
    parser.add_argument(
        "--index-file",
        type=str,
        default=None,
        help=(
            "Path to tool_index_data.json "
            "(defaults to classifiers/tool_index_data.json)"
        ),
    )
    parser.add_argument(
        "--redis-url",
        type=str,
        default=None,
        help="Redis URL (env REDIS_URL or localhost)",
    )
    return parser


def _redact_url(url: str) -> str:
    """Return *url* with any userinfo password replaced by ``***``.

    Used so connection URLs can be logged without leaking credentials.
    URLs that carry no password are returned unchanged.
    """
    try:
        parts = urlsplit(url)
    except ValueError:
        # Malformed URL: emit nothing from it rather than risk a leak.
        return "<unparseable URL>"
    if not parts.password:
        return url
    # Split on the *last* "@" (as urlsplit itself does) so that passwords
    # containing "@" are masked in full.
    userinfo, _, hostinfo = parts.netloc.rpartition("@")
    username, colon, _ = userinfo.partition(":")
    masked_netloc = f"{username}{colon}***@{hostinfo}"
    return urlunsplit(parts._replace(netloc=masked_netloc))


def _load_ssl_kwargs() -> dict[str, object]:
    """Return Redis SSL keyword arguments from the project ``config``.

    Best-effort: if ``config`` cannot be imported or loaded, a warning is
    logged and ``{}`` is returned, so the connection proceeds without
    extra SSL options.
    """
    try:
        from config import Config

        return Config.load().redis_ssl_kwargs()
    except Exception as exc:  # config is optional for this script
        logger.warning(
            "Could not load Redis SSL options from config (%s); "
            "connecting without them",
            exc,
        )
        return {}


async def main(argv: list[str] | None = None) -> None:
    """Initialize tool embeddings in Redis from the command line.

    Parses CLI options, resolves the index file and Redis URL, and awaits
    ``initialize_tool_embeddings_from_file``. The Redis client, once
    created, is closed before the process exits.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behavior.

    Raises:
        SystemExit: On normal completion. Status 0 when initialization
            reports success. Status 1 when the index file is missing,
            initialization returns a falsy result, or it raises.
            argparse exits with status 2 on invalid options.
    """
    args = _build_parser().parse_args(argv)

    index_file = args.index_file or os.path.join(
        os.path.dirname(__file__),
        "tool_index_data.json",
    )
    # isfile (not exists): a directory path is rejected up front instead
    # of failing later inside the loader.
    if not os.path.isfile(index_file):
        logger.error("Index file not found: %s", index_file)
        sys.exit(1)

    redis_url = args.redis_url or os.environ.get(
        "REDIS_URL",
        "redis://localhost:6379/0",
    )
    ssl_kwargs = _load_ssl_kwargs()

    banner = "=" * 60
    logger.info(banner)
    logger.info("Tool Embeddings Initialization")
    logger.info(banner)
    logger.info("Index file: %s", index_file)
    # Never write credentials embedded in the URL to the logs.
    logger.info("Redis URL: %s", _redact_url(redis_url))
    logger.info("Force: %s", args.force)
    logger.info(banner)

    redis_client = aioredis.from_url(
        redis_url,
        decode_responses=True,
        **ssl_kwargs,
    )
    try:
        success = await initialize_tool_embeddings_from_file(
            index_file_path=index_file,
            redis_client=redis_client,
            force_recompute=args.force,
        )
    except Exception as exc:
        logger.exception("Initialization failed: %s", exc)
        sys.exit(1)
    finally:
        # Runs on success, on the failure path, and during the sys.exit above.
        await redis_client.aclose()

    if not success:
        logger.error(banner)
        logger.error("Initialization failed!")
        logger.error(banner)
        sys.exit(1)

    logger.info(banner)
    logger.info("Initialization completed!")
    logger.info(banner)
    sys.exit(0)
if __name__ == "__main__":
    # main() terminates the process itself via sys.exit with its status.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # On Ctrl-C, exit with the conventional SIGINT status (128 + 2)
        # instead of dumping an asyncio traceback.
        logger.warning("Interrupted by user")
        sys.exit(130)