# Source code for classifiers.tool_embedding_batch

"""Shared helpers for tool centroid embeddings (batch-friendly)."""

from __future__ import annotations

import asyncio
import logging
import os
from typing import Any, Sequence, Union

import numpy as np

from gemini_embed_pool import check_openrouter_only
from rag_system.openrouter_embeddings import (
    OpenRouterEmbeddings,
    _normalize_embed_texts_input,
)

logger = logging.getLogger(__name__)


def _tool_embed_or_max_concurrent() -> int:
    """Resolve the max concurrent embedding HTTP requests for tool batches.

    Reads the ``TOOL_EMBED_OR_MAX_CONCURRENT`` environment variable (default
    ``12``) and clamps it to the inclusive range 1..64, falling back to 12 when
    the value is not a valid integer. This caps the parallelism of the
    OpenRouter-only fan-out in :func:`embed_texts_for_tool_scripts` so tool
    index / Redis refreshes do not overwhelm the embedding endpoint.

    Touches the process environment only (no network, Redis, or filesystem
    I/O). Called by :func:`embed_texts_for_tool_scripts` as the default when no
    explicit ``max_concurrent`` is passed. No other in-repo callers were found.

    Returns:
        The clamped concurrency limit as an int in ``[1, 64]``.
    """
    raw = os.environ.get("TOOL_EMBED_OR_MAX_CONCURRENT", "12")
    try:
        n = int(raw.strip())
    except ValueError:
        return 12
    return max(1, min(64, n))


async def embed_texts_for_tool_scripts(
    embedding_client: OpenRouterEmbeddings,
    texts: Union[str, Sequence[str]],
    *,
    max_concurrent: int | None = None,
) -> list[np.ndarray]:
    """Embed texts for tool-index / Redis flows only.

    When OpenRouter-only mode is active (``check_openrouter_only()``) and the
    input splits into more than one client batch, the batches are embedded in
    parallel, bounded by a semaphore. Otherwise the client's sequential
    ``embed_texts`` is used so Gemini rate limits are not compounded.

    Args:
        embedding_client: Embeddings client. Its ``_create_batches`` and
            ``_embed_batch`` helpers drive the parallel path, and its
            ``embed_texts`` drives the sequential path.
        texts: A single string or a sequence of strings, normalised with
            ``_normalize_embed_texts_input`` before batching.
        max_concurrent: Upper bound on in-flight batch requests. ``None`` uses
            :func:`_tool_embed_or_max_concurrent`. An explicit value below 1
            is raised to 1, because a semaphore of 0 would block forever.

    Returns:
        The embedding arrays for all batches, concatenated in batch order.
        Returns ``[]`` when normalisation yields no texts.

    Raises:
        Whatever the client raises. In the parallel path, if any batch fails,
        the still-pending batches are cancelled before the error is re-raised.
    """
    texts_norm = _normalize_embed_texts_input(texts)
    if not texts_norm:
        return []

    batches = embedding_client._create_batches(texts_norm)
    if len(batches) <= 1 or not await check_openrouter_only():
        return await embedding_client.embed_texts(texts_norm)

    # Clamp explicit values: Semaphore(0) would deadlock every acquire below.
    mc = (
        _tool_embed_or_max_concurrent()
        if max_concurrent is None
        else max(1, max_concurrent)
    )
    logger.info(
        "Tool embeddings (OpenRouter-only): %d batches, up to %d concurrent",
        len(batches),
        mc,
    )
    sem = asyncio.Semaphore(mc)

    async def _run(batch: list[str]) -> list[np.ndarray]:
        """Embed one client batch while holding the concurrency semaphore."""
        async with sem:
            return await embedding_client._embed_batch(batch)

    # Explicit tasks let us cancel the remaining batches if one fails. Without
    # this, plain gather would leave them issuing requests in the background.
    tasks = [asyncio.create_task(_run(batch)) for batch in batches]
    try:
        # gather returns results in argument order regardless of completion
        # order, so no re-sorting is needed to preserve input order.
        parts = await asyncio.gather(*tasks)
    except Exception:
        for task in tasks:
            task.cancel()
        # Drain the cancelled or failed tasks so none are left unawaited.
        await asyncio.gather(*tasks, return_exceptions=True)
        raise

    out: list[np.ndarray] = []
    for vecs in parts:
        out.extend(vecs)
    return out


def normalize_synthetic_queries(qs: Any) -> list[str]:
    """Coerce a tool's synthetic queries into a list of strings.

    A bare ``str`` (for example, left over from malformed JSON) is wrapped as a
    single query instead of being iterated. Iterating it would later make
    ``embed_texts`` embed it one character at a time. This includes the empty
    string, which yields ``[""]``.

    Any other falsy value (``None``, an empty list, ...) yields ``[]``. Any
    other iterable has each element converted with ``str()``.

    Args:
        qs: A string, an iterable of query-like values, or a falsy value.

    Returns:
        A new list of query strings.
    """
    if isinstance(qs, str):
        return [qs]
    return list(map(str, qs)) if qs else []


def _centroid_from_vectors(vectors: list[np.ndarray]) -> np.ndarray | None: """Average a list of embedding vectors into a single L2-normalised centroid. Stacks *vectors* into a matrix, takes the per-dimension mean, and divides by the result's L2 norm so the returned centroid is unit length (when the mean is non-zero). This is the shared reduction step that turns a tool's (or a category's) synthetic-query embeddings into one representative vector for cosine search in RediSearch. This is a pure NumPy helper with no I/O or external collaborators. It is called internally by :func:`compute_tool_embedding` (single tool) and :func:`compute_tool_centroids_bulk` (per-tool slices); no callers outside this module were found. Args: vectors: Embedding vectors to average. Callers are expected to have already filtered out empty (``size == 0``) vectors. Returns: The unit-length centroid as a ``numpy.ndarray``, or ``None`` when *vectors* is empty. If the mean vector has zero norm it is returned un-normalised rather than dividing by zero. """ if not vectors: return None centroid = np.mean(np.stack(vectors, axis=0), axis=0) norm = np.linalg.norm(centroid) if norm > 0: centroid = centroid / norm return centroid
async def compute_tool_embedding(
    embedding_client: OpenRouterEmbeddings,
    synthetic_queries: list[str],
    tool_name: str,
) -> np.ndarray | None:
    """Embed one tool's synthetic queries and reduce them to a centroid.

    This is the single-tool counterpart of
    :func:`compute_tool_centroids_bulk`. It proceeds in four steps:

    1. Normalise the queries with :func:`normalize_synthetic_queries`.
    2. Embed them via :func:`embed_texts_for_tool_scripts`.
    3. Discard any zero-size vectors.
    4. Average the rest with :func:`_centroid_from_vectors`.

    Any ``Exception`` raised while embedding or filtering is logged as a
    warning and turned into ``None``. This lets one bad tool be skipped
    without aborting the caller. Errors from the final centroid reduction are
    not caught.

    Args:
        embedding_client: Embeddings client used to embed the queries.
        synthetic_queries: The tool's synthetic queries. A bare string is
            treated as a single query.
        tool_name: Tool name, used only in the failure log message.

    Returns:
        The centroid array, or ``None`` when there are no queries, embedding
        failed, or no non-empty vectors came back.
    """
    queries = normalize_synthetic_queries(synthetic_queries)
    if not queries:
        return None
    try:
        vectors = await embed_texts_for_tool_scripts(embedding_client, queries)
        usable = [vec for vec in vectors if vec.size > 0]
    except Exception as exc:
        logger.warning(
            "Failed to embed queries for %r: %s",
            tool_name,
            exc,
        )
        return None
    return _centroid_from_vectors(usable)


async def compute_tool_centroids_bulk(
    embedding_client: OpenRouterEmbeddings,
    tool_queries: dict[str, list[str]],
) -> dict[str, np.ndarray | None]:
    """Compute normalised centroids for many tools in minimal ``embed_texts`` calls.

    All tools' normalised queries are concatenated into one list, embedded in a
    single :func:`embed_texts_for_tool_scripts` call (subject to the client's
    own batching), then sliced back per tool and reduced with
    :func:`_centroid_from_vectors`. This avoids one HTTP round-trip per tool
    when refreshing many tools.

    Args:
        embedding_client: Embeddings client used for the combined request.
        tool_queries: Mapping of tool name to its synthetic queries. Each value
            is normalised with :func:`normalize_synthetic_queries`.

    Returns:
        A dict with every key of *tool_queries*. A tool maps to ``None`` when
        any of these holds:

        - it had no queries;
        - the embedding call raised;
        - the number of returned vectors did not match the number of texts;
        - none of its vectors were non-empty.
    """
    result: dict[str, np.ndarray | None] = dict.fromkeys(tool_queries)

    flat_queries: list[str] = []
    slices: list[tuple[str, int, int]] = []
    for tool, raw_queries in tool_queries.items():
        queries = normalize_synthetic_queries(raw_queries)
        if queries:
            begin = len(flat_queries)
            flat_queries += queries
            slices.append((tool, begin, len(flat_queries)))

    if not flat_queries:
        return result

    try:
        logger.info(
            "Sending %d query strings to embedding API for tool centroids...",
            len(flat_queries),
        )
        vectors = await embed_texts_for_tool_scripts(
            embedding_client,
            flat_queries,
        )
        logger.info("Received %d vectors from API.", len(vectors))
    except Exception as exc:
        logger.warning("Bulk tool embedding failed: %s", exc)
        return result

    # Per-tool slicing is only valid when the vectors line up 1:1 with texts.
    if len(vectors) != len(flat_queries):
        logger.warning(
            "Bulk embed length mismatch: got %d vectors for %d texts",
            len(vectors),
            len(flat_queries),
        )
        return dict.fromkeys(tool_queries)

    for tool, begin, end in slices:
        kept = [vec for vec in vectors[begin:end] if vec.size > 0]
        result[tool] = _centroid_from_vectors(kept)
    return result