"""Shared helpers for tool centroid embeddings (batch-friendly)."""
from __future__ import annotations
import asyncio
import logging
import os
from typing import Any, Sequence, Union
import numpy as np
from gemini_embed_pool import check_openrouter_only
from rag_system.openrouter_embeddings import (
OpenRouterEmbeddings,
_normalize_embed_texts_input,
)
# Module-level logger named after this module; used to report the batch count
# and concurrency limit when tool embeddings are fanned out in parallel.
logger = logging.getLogger(__name__)
def _tool_embed_or_max_concurrent() -> int:
"""Resolve the max concurrent embedding HTTP requests for tool batches.
Reads the ``TOOL_EMBED_OR_MAX_CONCURRENT`` environment variable (default
``12``) and clamps it to the inclusive range 1..64, falling back to 12 when
the value is not a valid integer. This caps the parallelism of the
OpenRouter-only fan-out in :func:`embed_texts_for_tool_scripts` so tool
index / Redis refreshes do not overwhelm the embedding endpoint.
Touches the process environment only (no network, Redis, or filesystem
I/O). Called by :func:`embed_texts_for_tool_scripts` as the default when no
explicit ``max_concurrent`` is passed. No other in-repo callers were found.
Returns:
The clamped concurrency limit as an int in ``[1, 64]``.
"""
raw = os.environ.get("TOOL_EMBED_OR_MAX_CONCURRENT", "12")
try:
n = int(raw.strip())
except ValueError:
return 12
return max(1, min(64, n))
[docs]
async def embed_texts_for_tool_scripts(
    embedding_client: OpenRouterEmbeddings,
    texts: Union[str, Sequence[str]],
    *,
    max_concurrent: int | None = None,
) -> list[np.ndarray]:
    """Embed texts for tool-index / Redis flows only.

    Two paths are used:

    * Parallel: when the normalised input spans more than one client batch
      and OpenRouter-only mode is active, the batches are embedded
      concurrently, bounded by a semaphore.
    * Sequential: otherwise the client's ``embed_texts`` is used, so Gemini
      rate limits are not compounded.

    Args:
        embedding_client: Client used to split the texts into batches
            (``_create_batches``) and to embed them (``_embed_batch`` in the
            parallel path, ``embed_texts`` in the sequential path).
        texts: A single string or a sequence of strings. It is normalised
            with ``_normalize_embed_texts_input`` before batching.
        max_concurrent: Maximum number of batches embedded at once in the
            parallel path. When ``None``, :func:`_tool_embed_or_max_concurrent`
            supplies the value. Values below 1 are raised to 1: a semaphore
            with zero permits would block every batch forever, and a
            negative value is rejected by ``asyncio.Semaphore``.

    Returns:
        In the parallel path, the vectors from every batch, concatenated in
        the original batch order. In the sequential path, whatever
        ``embed_texts`` returns. An empty list when the normalised input is
        empty.

    Raises:
        Whatever the client raises. In the parallel path, the first batch
        failure propagates out of ``asyncio.gather``; batches still in
        flight are not cancelled by it.
    """
    texts_norm = _normalize_embed_texts_input(texts)
    if not texts_norm:
        return []
    batches = embedding_client._create_batches(texts_norm)
    if len(batches) <= 1 or not await check_openrouter_only():
        return await embedding_client.embed_texts(texts_norm)

    limit = (
        max_concurrent
        if max_concurrent is not None
        else _tool_embed_or_max_concurrent()
    )
    # Semaphore(0) would never grant a permit (all batches would hang) and a
    # negative value raises ValueError, so always allow at least one request.
    limit = max(1, limit)
    logger.info(
        "Tool embeddings (OpenRouter-only): %d batches, up to %d concurrent",
        len(batches),
        limit,
    )
    sem = asyncio.Semaphore(limit)

    async def _run(batch: list[str]) -> list[np.ndarray]:
        """Embed one batch while holding a permit from the shared semaphore."""
        async with sem:
            return await embedding_client._embed_batch(batch)

    # gather returns results in the order the awaitables were passed,
    # regardless of completion order, so batch order is preserved as-is.
    parts = await asyncio.gather(*(_run(batch) for batch in batches))
    return [vec for part in parts for vec in part]
[docs]
def normalize_synthetic_queries(qs: Any) -> list[str]:
    """Coerce a synthetic-queries value into a list of strings.

    The rules are:

    * A bare ``str`` (for example, the result of malformed JSON) becomes a
      single-element list. This is checked first, so ``""`` yields ``[""]``.
      Iterating a string in ``embed_texts`` would otherwise embed each
      character separately.
    * Any other falsy value (``None``, ``[]``, ...) becomes ``[]``.
    * Anything else is iterated, and each element is passed through ``str()``.
    """
    if isinstance(qs, str):
        return [qs]
    return [str(item) for item in qs] if qs else []
def _centroid_from_vectors(vectors: list[np.ndarray]) -> np.ndarray | None:
"""Average a list of embedding vectors into a single L2-normalised centroid.
Stacks *vectors* into a matrix, takes the per-dimension mean, and divides by
the result's L2 norm so the returned centroid is unit length (when the mean
is non-zero). This is the shared reduction step that turns a tool's (or a
category's) synthetic-query embeddings into one representative vector for
cosine search in RediSearch.
This is a pure NumPy helper with no I/O or external collaborators. It is
called internally by :func:`compute_tool_embedding` (single tool) and
:func:`compute_tool_centroids_bulk` (per-tool slices); no callers outside
this module were found.
Args:
vectors: Embedding vectors to average. Callers are expected to have
already filtered out empty (``size == 0``) vectors.
Returns:
The unit-length centroid as a ``numpy.ndarray``, or ``None`` when
*vectors* is empty. If the mean vector has zero norm it is returned
un-normalised rather than dividing by zero.
"""
if not vectors:
return None
centroid = np.mean(np.stack(vectors, axis=0), axis=0)
norm = np.linalg.norm(centroid)
if norm > 0:
centroid = centroid / norm
return centroid