"""Embeddings client for RAG system.
Uses the native Google Gemini API only, via the shared key pool in
gemini_embed_pool. Provides both async (OpenRouterEmbeddings) and
synchronous (SyncOpenRouterEmbeddings, ChromaDB-compatible) interfaces.
"""
import asyncio
import logging
import time
from typing import List, Optional, Sequence, Union
import httpx
import numpy as np
from gemini_embed_pool import (
GEMINI_EMBED_BASE,
PAID_KEY_FALLBACK_THRESHOLD,
OpenRouterEmbedParseError,
check_openrouter_only,
check_openrouter_only_sync,
gemini_embed_paid_fallback,
gemini_embed_paid_fallback_sync,
next_gemini_embed_key,
openrouter_embed_batch,
openrouter_embed_batch_sync,
set_openrouter_only,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): not referenced anywhere in the visible part of this file;
# kept in case other modules import it — confirm before removing.
MAX_499_RETRIES = 3
# First back-off delay (seconds) used by the retry loops; doubled each round.
EMBED_RETRY_BASE_DELAY = 1.0
# Upper bound (seconds) applied to every computed back-off delay.
MAX_EMBED_DELAY = 8.0
# HTTP status codes that make the per-request loop try again instead of
# calling ``raise_for_status``. 429 is additionally handled on its own path
# (key rotation / escalation) before this set is consulted.
_RETRIABLE_STATUSES = {429, 500, 502, 503, 504}
def _normalize_embed_texts_input(
    texts: Union[str, Sequence[str]],
) -> List[str]:
    """Return *texts* as a fresh ``List[str]``.

    A bare ``str`` is wrapped in a one-element list so that it is embedded as
    a single document; iterating it directly would yield one entry per
    character. Any other sequence is copied into a new list.

    Args:
        texts: A single string or a sequence of strings.

    Returns:
        List[str]: ``[texts]`` for a string, otherwise ``list(texts)``.
    """
    return [texts] if isinstance(texts, str) else list(texts)
def _gemini_model_name(model: str) -> str:
    """Drop a single leading ``google/`` vendor prefix from a model id.

    Turns an id such as ``google/gemini-embedding-001`` into the bare
    ``gemini-embedding-001`` form. Ids that do not start with the prefix are
    returned unchanged, and only the first occurrence at the very start of the
    string is removed.

    Args:
        model: Model id, optionally carrying a ``google/`` prefix.

    Returns:
        str: The model id without the leading ``google/``.
    """
    vendor_prefix = "google/"
    if model.startswith(vendor_prefix):
        return model[len(vendor_prefix):]
    return model
class OpenRouterEmbeddings:
    """Async embeddings client that calls the Gemini API via the shared key pool.

    Despite the historical name, requests go to the native Gemini
    ``batchEmbedContents`` endpoint using keys obtained from
    ``next_gemini_embed_key``. OpenRouter and the paid Gemini key are only
    used as fallbacks after repeated HTTP 429 responses (or while the pool
    reports openrouter-only mode). Inputs are split into batches bounded by
    ``MAX_BATCH_SIZE`` / ``MAX_BATCH_CHARS`` and results are returned as
    ``float32`` :class:`numpy.ndarray` vectors.

    The instance owns an :class:`httpx.AsyncClient`; release it with
    :meth:`close` or by using the instance in an ``async with`` block.
    """

    DEFAULT_MODEL = "google/gemini-embedding-001"
    MAX_BATCH_SIZE = 50
    MAX_BATCH_CHARS = 50_000

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = DEFAULT_MODEL,
        dimensions: Optional[int] = None,
        timeout: float = 30.0,
        gemini_api_key: Optional[str] = None,
        gemini_only: bool = True,
    ):
        """Create the client and its underlying HTTP connection pool.

        Args:
            api_key (Optional[str]): Unused; kept for backward compatibility.
            model (str): Embedding model id; a ``google/`` prefix is stripped
                before calling the Gemini endpoint.
            dimensions (Optional[int]): Requested vector width; falls back to
                3072 when omitted (or falsy).
            timeout (float): httpx timeout in seconds.
            gemini_api_key (Optional[str]): Unused; the key pool is used.
            gemini_only (bool): Ignored; the attribute is always ``True``.
        """
        self.model = model
        self.dimensions = dimensions or 3072
        self.timeout = timeout
        self.gemini_only = True  # Always use Gemini API
        self._client = httpx.AsyncClient(timeout=timeout)
        logger.info(
            "Initialized embeddings client with model: %s (Gemini API)",
            model,
        )

    @staticmethod
    def _capped_backoff(round_num: int, base: float, cap: float) -> float:
        """Return ``min(base * 2 ** (round_num - 1), cap)`` without overflow.

        The retry loops in this class are unbounded, so ``round_num`` can
        grow past 1024; at that point ``float * 2 ** n`` raises
        ``OverflowError`` and would escape the loop. Clamping the exponent
        keeps the result identical for every realistic ``base``/``cap``.
        """
        exponent = min(max(round_num - 1, 0), 62)
        return min(base * (2 ** exponent), cap)

    @staticmethod
    def _redact_key(message: str, api_key: Optional[str]) -> str:
        """Mask *api_key* inside *message* so it cannot leak into logs.

        The key is part of the request URL, and httpx error strings can
        include that URL verbatim.
        """
        if not api_key:
            return message
        return message.replace(str(api_key), "***")

    async def embed_text(self, text: str) -> np.ndarray:
        """Embed a single string and return its vector.

        Wraps ``text`` in a one-element list and delegates to
        :meth:`embed_texts`.

        Args:
            text (str): The text to embed.

        Returns:
            np.ndarray: The ``float32`` embedding for ``text``.
        """
        embeddings = await self.embed_texts([text])
        return embeddings[0]

    async def embed_texts(
        self,
        texts: Union[str, Sequence[str]],
    ) -> List[np.ndarray]:
        """Embed one or more texts, batching as needed.

        The input is normalised with :func:`_normalize_embed_texts_input`
        (a bare string counts as one document), partitioned by
        :meth:`_create_batches`, and each batch is embedded in turn by
        :meth:`_embed_batch`.

        Args:
            texts: A sequence of strings, or a single string.

        Returns:
            List[np.ndarray]: One ``float32`` vector per input text, in input
            order; an empty list when ``texts`` is empty.
        """
        texts = _normalize_embed_texts_input(texts)
        if not texts:
            return []
        all_embeddings: List[np.ndarray] = []
        # Batches are processed sequentially so results stay in input order.
        for batch in self._create_batches(texts):
            all_embeddings.extend(await self._embed_batch(batch))
        return all_embeddings

    def _create_batches(self, texts: List[str]) -> List[List[str]]:
        """Split texts into batches bounded by item count and total chars.

        Consecutive texts are packed greedily; a new batch is started when
        the current one already holds ``MAX_BATCH_SIZE`` items or when adding
        the next text would push it past ``MAX_BATCH_CHARS`` characters. A
        single text longer than ``MAX_BATCH_CHARS`` still gets a batch of its
        own.

        Args:
            texts (List[str]): The texts to partition, in order.

        Returns:
            List[List[str]]: Batches that together contain every input text
            in its original order.
        """
        batches: List[List[str]] = []
        current_batch: List[str] = []
        current_chars = 0
        for text in texts:
            text_len = len(text)
            batch_full = len(current_batch) >= self.MAX_BATCH_SIZE
            too_many_chars = current_chars + text_len > self.MAX_BATCH_CHARS
            # Never flush an empty batch: an oversized text must go somewhere.
            if current_batch and (batch_full or too_many_chars):
                batches.append(current_batch)
                current_batch = []
                current_chars = 0
            current_batch.append(text)
            current_chars += text_len
        if current_batch:
            batches.append(current_batch)
        return batches

    async def _embed_batch(
        self,
        texts: List[str],
    ) -> List[np.ndarray]:
        """Embed one batch, retrying failures indefinitely with back-off.

        Calls :meth:`_embed_batch_gemini` in an unbounded loop. Any exception
        other than ``OpenRouterEmbedParseError`` triggers a sleep of
        ``EMBED_RETRY_BASE_DELAY`` doubled per round (capped at
        ``MAX_EMBED_DELAY``) followed by another attempt.
        ``OpenRouterEmbedParseError`` ends the loop and yields zero vectors.

        Args:
            texts (List[str]): One pre-sized batch of texts to embed.

        Returns:
            List[np.ndarray]: One ``float32`` vector per input text; all-zero
            vectors when ``OpenRouterEmbedParseError`` was raised.
        """
        round_num = 0
        while True:
            try:
                return await self._embed_batch_gemini(texts)
            except OpenRouterEmbedParseError as exc:
                logger.warning(
                    "OpenRouter returned 200 but malformed payload — "
                    "giving up on %d texts (zero vectors): %s",
                    len(texts),
                    exc,
                )
                return [np.zeros(self.dimensions, dtype=np.float32) for _ in texts]
            except Exception as exc:
                round_num += 1
                delay = self._capped_backoff(
                    round_num,
                    EMBED_RETRY_BASE_DELAY,
                    MAX_EMBED_DELAY,
                )
                logger.warning(
                    "Gemini embed failed (round %d), retrying in %.1fs: %s",
                    round_num,
                    delay,
                    exc,
                )
                await asyncio.sleep(delay)

    async def _escalate_after_429(
        self,
        texts: List[str],
        task_type: Optional[str],
        hits: int,
    ) -> Optional[List[np.ndarray]]:
        """Try OpenRouter, then the paid Gemini key, after repeated 429s.

        On OpenRouter success the pool's openrouter-only mode is switched on
        via ``set_openrouter_only``. Failures of either fallback are logged
        and swallowed.

        Returns:
            The embeddings from the first fallback that succeeds, or ``None``
            when both failed and the caller should keep rotating pool keys.
        """
        logger.warning(
            "Free Gemini pool 429'd %d times — escalating "
            "to OpenRouter, then paid Gemini key",
            hits,
        )
        try:
            vecs = await openrouter_embed_batch(
                texts,
                model=self.model,
                dimensions=self.dimensions,
            )
            await set_openrouter_only()
            return [np.array(v, dtype=np.float32) for v in vecs]
        except Exception:
            logger.warning(
                "OpenRouter embed fallback failed — "
                "trying paid Gemini key as last resort",
                exc_info=True,
            )
        try:
            paid_vecs = await gemini_embed_paid_fallback(
                texts,
                model=self.model,
                dimensions=self.dimensions,
                task_type=task_type,
            )
            return [np.array(v, dtype=np.float32) for v in paid_vecs]
        except Exception:
            logger.warning(
                "Paid Gemini key fallback also failed",
                exc_info=True,
            )
        return None

    async def _embed_batch_gemini(
        self,
        texts: List[str],
        task_type: Optional[str] = None,
    ) -> List[np.ndarray]:
        """Embed a batch of texts via the native Gemini API (shared key pool).

        Empty / whitespace-only texts are not sent; they receive zero vectors
        at their original positions. When the pool reports openrouter-only
        mode, OpenRouter is called directly with the paid Gemini key as the
        fallback. Otherwise up to 20 requests are made, each with a fresh key
        from ``next_gemini_embed_key``. Once ``PAID_KEY_FALLBACK_THRESHOLD``
        429 responses have been seen, OpenRouter and then the paid key are
        tried once (see :meth:`_escalate_after_429`).

        Args:
            texts: Texts to embed.
            task_type: Optional Gemini task type (e.g. ``QUESTION_ANSWERING``,
                ``RETRIEVAL_DOCUMENT``). Omit for default behaviour.

        Returns:
            List[np.ndarray]: One ``float32`` vector per input text.

        Raises:
            RuntimeError: When every attempt failed. The API key is redacted
                from the message.
        """
        if not texts:
            return []
        valid_indices: List[int] = []
        valid_texts: List[str] = []
        for i, t in enumerate(texts):
            if t and t.strip():
                valid_indices.append(i)
                valid_texts.append(t)
        if not valid_texts:
            return [np.zeros(self.dimensions, dtype=np.float32) for _ in texts]
        if len(valid_texts) < len(texts):
            # Embed only the non-blank texts, then scatter them back.
            partial = await self._embed_batch_gemini(
                valid_texts,
                task_type=task_type,
            )
            out: List[np.ndarray] = [
                np.zeros(self.dimensions, dtype=np.float32) for _ in texts
            ]
            for idx, emb in zip(valid_indices, partial):
                out[idx] = emb
            return out
        texts = valid_texts

        if await check_openrouter_only():
            logger.info(
                "OpenRouter-only mode — bypassing Gemini for %d texts", len(texts)
            )
            try:
                vecs = await openrouter_embed_batch(
                    texts,
                    model=self.model,
                    dimensions=self.dimensions,
                )
                return [np.array(v, dtype=np.float32) for v in vecs]
            except Exception:
                logger.warning(
                    "OpenRouter failed while pinned to openrouter_only — "
                    "trying paid Gemini key as last resort",
                    exc_info=True,
                )
            paid_vecs = await gemini_embed_paid_fallback(
                texts,
                model=self.model,
                dimensions=self.dimensions,
                task_type=task_type,
            )
            return [np.array(v, dtype=np.float32) for v in paid_vecs]

        gemini_model = _gemini_model_name(self.model)
        requests_list = []
        for t in texts:
            req: dict = {
                "model": f"models/{gemini_model}",
                "content": {"parts": [{"text": t}]},
                "output_dimensionality": self.dimensions,
            }
            if task_type:
                req["taskType"] = task_type
            requests_list.append(req)
        payload = {"requests": requests_list}

        # NOTE(review): this counter is never reset by a non-429 response, so
        # it counts total 429s in this call rather than strictly consecutive
        # ones — confirm which is intended.
        consecutive_429 = 0
        tried_or_then_paid = False
        max_attempts = 20
        last_error: Optional[str] = None
        for attempt in range(max_attempts):
            api_key = next_gemini_embed_key()
            # NOTE(review): the key travels in the query string; consider the
            # ``x-goog-api-key`` header instead — confirm endpoint support.
            url = (
                f"{GEMINI_EMBED_BASE}/{gemini_model}:batchEmbedContents"
                f"?key={api_key}"
            )
            if attempt > 0:
                # Flat 1s pause while still below the escalation threshold,
                # exponential (capped) afterwards.
                delay = (
                    1.0
                    if attempt <= PAID_KEY_FALLBACK_THRESHOLD
                    else min(
                        EMBED_RETRY_BASE_DELAY
                        * (2 ** (attempt - PAID_KEY_FALLBACK_THRESHOLD - 1)),
                        MAX_EMBED_DELAY,
                    )
                )
                await asyncio.sleep(delay)
            try:
                response = await self._client.post(url, json=payload)
                if response.status_code == 429:
                    consecutive_429 += 1
                    if (
                        consecutive_429 >= PAID_KEY_FALLBACK_THRESHOLD
                        and not tried_or_then_paid
                    ):
                        # Escalate at most once per call.
                        tried_or_then_paid = True
                        rescued = await self._escalate_after_429(
                            texts,
                            task_type,
                            consecutive_429,
                        )
                        if rescued is not None:
                            return rescued
                    last_error = "HTTP 429"
                    continue
                if response.status_code in _RETRIABLE_STATUSES:
                    last_error = f"HTTP {response.status_code}"
                    continue
                # NOTE(review): non-retriable 4xx responses raise here and are
                # then retried like any other error by the handler below.
                response.raise_for_status()
                data = response.json()
                return [
                    np.array(item["values"], dtype=np.float32)
                    for item in data["embeddings"]
                ]
            except Exception as exc:
                # httpx error strings can embed the full URL (with the key).
                last_error = self._redact_key(str(exc), api_key)
        raise RuntimeError(
            f"Gemini embed failed after {max_attempts} attempts: {last_error}"
        )

    async def embed_text_for_search(
        self,
        text: str,
        task_type: str = "QUESTION_ANSWERING",
    ) -> List[float]:
        """Embed a single text with an explicit Gemini task type.

        Blank input short-circuits to a zero vector. Otherwise
        :meth:`_embed_batch_gemini` is called in an unbounded retry loop with
        capped exponential back-off; ``OpenRouterEmbedParseError`` ends the
        loop with a zero vector.

        Args:
            text (str): The text to embed.
            task_type (str): Gemini ``taskType`` sent with the request.

        Returns:
            List[float]: The embedding as a plain list of floats.
        """
        if not text or not text.strip():
            return [0.0] * self.dimensions
        round_num = 0
        while True:
            try:
                results = await self._embed_batch_gemini(
                    [text],
                    task_type=task_type,
                )
                return results[0].tolist()
            except OpenRouterEmbedParseError as exc:
                logger.warning(
                    "OpenRouter returned 200 but malformed payload for "
                    "embed_text_for_search — giving up (zero vector): %s",
                    exc,
                )
                return [0.0] * self.dimensions
            except Exception as exc:
                round_num += 1
                delay = self._capped_backoff(
                    round_num,
                    EMBED_RETRY_BASE_DELAY,
                    MAX_EMBED_DELAY,
                )
                logger.warning(
                    "Gemini embed_text_for_search failed (round %d), "
                    "retrying in %.1fs: %s",
                    round_num,
                    delay,
                    exc,
                )
                await asyncio.sleep(delay)

    async def close(self):
        """Close the underlying :class:`httpx.AsyncClient`."""
        await self._client.aclose()

    async def __aenter__(self):
        """Enter the async context manager and return this instance."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Exit the async context manager by calling :meth:`close`.

        Returns ``None``, so exceptions raised in the block are not
        suppressed.

        Args:
            exc_type: Exception type if the block raised, else ``None``.
            exc_val: Exception instance if the block raised, else ``None``.
            exc_tb: Traceback if the block raised, else ``None``.
        """
        await self.close()
class SyncOpenRouterEmbeddings:
    """Synchronous embedder intended for ChromaDB's embedding-function interface.

    Uses the Gemini API via the shared key pool, with OpenRouter and the paid
    Gemini key as fallbacks after repeated HTTP 429 responses. When the input
    splits into more than one batch, batches are embedded concurrently on a
    ``ThreadPoolExecutor``; each thread uses its own :class:`httpx.Client`.
    """

    MAX_BATCH_SIZE = 50
    MAX_BATCH_CHARS = 50_000
    MAX_EMBED_WORKERS = 20

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = OpenRouterEmbeddings.DEFAULT_MODEL,
        dimensions: Optional[int] = None,
        timeout: float = 30.0,
        gemini_api_key: Optional[str] = None,
        gemini_only: bool = True,
        document_task_type: Optional[str] = None,
        query_task_type: Optional[str] = None,
    ):
        """Store configuration; no network I/O happens here.

        Args:
            api_key (Optional[str]): Unused; kept for backward compatibility.
            model (str): Embedding model id; a ``google/`` prefix is stripped
                before calling the Gemini endpoint.
            dimensions (Optional[int]): Requested vector width; falls back to
                3072 when omitted (or falsy).
            timeout (float): httpx timeout in seconds.
            gemini_api_key (Optional[str]): Unused; the key pool is used.
            gemini_only (bool): Unused; always Gemini API.
            document_task_type: Optional Gemini ``taskType`` for corpus
                (e.g. ``RETRIEVAL_DOCUMENT``); used by ``embed_documents``
                and ``__call__``.
            query_task_type: Optional Gemini ``taskType`` for queries
                (e.g. ``RETRIEVAL_QUERY``); used by ``embed_query``.
        """
        import threading

        self.model = model
        self.dimensions = dimensions or 3072
        self.timeout = timeout
        self._name = f"openrouter_{model.replace('/', '_')}"
        self.is_legacy = False
        self.document_task_type = document_task_type
        self.query_task_type = query_task_type
        # One httpx.Client per thread; see _get_client.
        self._local = threading.local()

    @staticmethod
    def _capped_backoff(round_num: int, base: float, cap: float) -> float:
        """Return ``min(base * 2 ** (round_num - 1), cap)`` without overflow.

        :meth:`_embed_batch` retries without bound, so ``round_num`` can grow
        past 1024; at that point ``float * 2 ** n`` raises ``OverflowError``
        and would escape the loop. Clamping the exponent keeps the result
        identical for every realistic ``base``/``cap``.
        """
        exponent = min(max(round_num - 1, 0), 62)
        return min(base * (2 ** exponent), cap)

    @staticmethod
    def _redact_key(message: str, api_key: Optional[str]) -> str:
        """Mask *api_key* inside *message* so it cannot leak into logs.

        The key is part of the request URL, and httpx error strings can
        include that URL verbatim.
        """
        if not api_key:
            return message
        return message.replace(str(api_key), "***")

    def _get_client(self) -> httpx.Client:
        """Return the calling thread's httpx client, creating it if needed.

        The client lives on a ``threading.local`` store, so every thread
        (including ``ThreadPoolExecutor`` workers) gets its own instance. A
        missing or already-closed client is replaced with a new one
        configured with this instance's ``timeout``.

        Returns:
            httpx.Client: An open client bound to the calling thread.
        """
        client = getattr(self._local, "client", None)
        if client is None or client.is_closed:
            client = httpx.Client(timeout=self.timeout)
            self._local.client = client
        return client

    def name(self) -> str:
        """Return the embedder identifier derived from the model name.

        Returns:
            str: ``openrouter_`` followed by the model id with ``/`` replaced
            by ``_``, e.g. ``openrouter_google_gemini-embedding-001``.
        """
        return self._name

    def dimension(self) -> int:
        """Return the configured embedding width.

        This is the same value sent to the API as ``output_dimensionality``
        (3072 unless ``dimensions`` was given to :meth:`__init__`), so the
        reported width matches the vectors this embedder requests.

        Returns:
            int: The vector length.
        """
        return self.dimensions

    def __call__(self, input: List[str]) -> List[List[float]]:
        """Embed texts when the embedder object itself is called.

        Treats the inputs as corpus documents: applies ``document_task_type``
        and delegates to :meth:`_embed_inputs`, exactly like
        :meth:`embed_documents`.

        Args:
            input (List[str]): Texts to embed.

        Returns:
            List[List[float]]: One embedding per input text.
        """
        return self._embed_inputs(input, self.document_task_type)

    def _embed_inputs(
        self,
        input: Union[str, Sequence[str]],
        task_type: Optional[str],
    ) -> List[List[float]]:
        """Batch-embed texts synchronously, fanning batches out across threads.

        Shared core behind :meth:`__call__`, :meth:`embed_documents`, and
        :meth:`embed_query`. A single batch is embedded inline. Multiple
        batches run on a ``ThreadPoolExecutor`` with at most
        ``MAX_EMBED_WORKERS`` threads; ``Executor.map`` yields results in
        submission order, so the output lines up with the input.

        Args:
            input: A list of texts, or a single string (one document).
            task_type: Optional Gemini ``taskType`` applied to every batch,
                or ``None`` to omit it from the request.

        Returns:
            List[List[float]]: One embedding per input text, in input order;
            an empty list when ``input`` is empty.
        """
        input = _normalize_embed_texts_input(input)
        if not input:
            return []
        batches = self._create_batches(input)
        if len(batches) == 1:
            return self._embed_batch(batches[0], task_type=task_type)

        from concurrent.futures import ThreadPoolExecutor

        workers = min(len(batches), self.MAX_EMBED_WORKERS)
        with ThreadPoolExecutor(max_workers=workers) as pool:
            per_batch = list(
                pool.map(
                    self._embed_batch,
                    batches,
                    [task_type] * len(batches),
                )
            )
        result: List[List[float]] = []
        for embs in per_batch:
            result.extend(embs)
        return result

    # ChromaDB >=0.6 calls these instead of __call__
    def embed_documents(self, input: List[str]) -> List[List[float]]:
        """Embed corpus documents using ``document_task_type``.

        Args:
            input (List[str]): Document texts to embed.

        Returns:
            List[List[float]]: One embedding per document, in input order.
        """
        return self._embed_inputs(input, self.document_task_type)

    def embed_query(self, input: List[str]) -> List[List[float]]:
        """Embed query texts using ``query_task_type``.

        Args:
            input (List[str]): Query texts to embed.

        Returns:
            List[List[float]]: One embedding per query, in input order.
        """
        return self._embed_inputs(input, self.query_task_type)

    def _create_batches(self, texts: List[str]) -> List[List[str]]:
        """Split texts into batches bounded by item count and total chars.

        Consecutive texts are packed greedily; a new batch is started when
        the current one already holds ``MAX_BATCH_SIZE`` items or when adding
        the next text would push it past ``MAX_BATCH_CHARS`` characters. A
        single text longer than ``MAX_BATCH_CHARS`` still gets a batch of its
        own.

        Args:
            texts (List[str]): The texts to partition, in order.

        Returns:
            List[List[str]]: Batches that together contain every input text
            in its original order.
        """
        batches: List[List[str]] = []
        current_batch: List[str] = []
        current_chars = 0
        for text in texts:
            text_len = len(text)
            batch_full = len(current_batch) >= self.MAX_BATCH_SIZE
            too_many_chars = current_chars + text_len > self.MAX_BATCH_CHARS
            # Never flush an empty batch: an oversized text must go somewhere.
            if current_batch and (batch_full or too_many_chars):
                batches.append(current_batch)
                current_batch = []
                current_chars = 0
            current_batch.append(text)
            current_chars += text_len
        if current_batch:
            batches.append(current_batch)
        return batches

    def _embed_batch(
        self,
        texts: List[str],
        task_type: Optional[str] = None,
    ) -> List[List[float]]:
        """Embed one batch, retrying failures indefinitely with back-off.

        Calls :meth:`_embed_batch_gemini` in an unbounded loop. Any exception
        other than ``OpenRouterEmbedParseError`` triggers a ``time.sleep`` of
        ``EMBED_RETRY_BASE_DELAY`` doubled per round (capped at
        ``MAX_EMBED_DELAY``) followed by another attempt.
        ``OpenRouterEmbedParseError`` ends the loop and yields zero vectors.

        Args:
            texts (List[str]): One pre-sized batch of texts to embed.
            task_type (Optional[str]): Optional Gemini ``taskType`` forwarded
                to the underlying request.

        Returns:
            List[List[float]]: One embedding per input text; all-zero vectors
            when ``OpenRouterEmbedParseError`` was raised.
        """
        round_num = 0
        while True:
            try:
                return self._embed_batch_gemini(texts, task_type=task_type)
            except OpenRouterEmbedParseError as exc:
                logger.warning(
                    "OpenRouter returned 200 but malformed payload (sync) — "
                    "giving up on %d texts (zero vectors): %s",
                    len(texts),
                    exc,
                )
                return [[0.0] * self.dimensions for _ in texts]
            except Exception as exc:
                round_num += 1
                delay = self._capped_backoff(
                    round_num,
                    EMBED_RETRY_BASE_DELAY,
                    MAX_EMBED_DELAY,
                )
                logger.warning(
                    "Gemini embed failed (round %d), retrying in %.1fs: %s",
                    round_num,
                    delay,
                    exc,
                )
                time.sleep(delay)

    def _escalate_after_429(
        self,
        texts: List[str],
        task_type: Optional[str],
        hits: int,
    ) -> Optional[List[List[float]]]:
        """Try OpenRouter, then the paid Gemini key, after repeated 429s.

        On OpenRouter success the pool module's ``_openrouter_only`` flag is
        set directly (in memory only, as the log message states). Failures of
        either fallback are logged and swallowed.

        Returns:
            The embeddings from the first fallback that succeeds, or ``None``
            when both failed and the caller should keep rotating pool keys.
        """
        logger.warning(
            "Free Gemini pool 429'd %d times (sync) — "
            "escalating to OpenRouter, then paid Gemini key",
            hits,
        )
        try:
            vecs = openrouter_embed_batch_sync(
                texts,
                model=self.model,
                dimensions=self.dimensions,
            )
            import gemini_embed_pool as _gep

            _gep._openrouter_only = True
            logger.warning(
                "OpenRouter-only mode ACTIVATED (sync, in-memory only)",
            )
            return vecs
        except Exception:
            logger.warning(
                "OpenRouter embed fallback (sync) failed — "
                "trying paid Gemini key as last resort",
                exc_info=True,
            )
        try:
            return gemini_embed_paid_fallback_sync(
                texts,
                model=self.model,
                dimensions=self.dimensions,
                task_type=task_type,
            )
        except Exception:
            logger.warning(
                "Paid Gemini key fallback (sync) also failed",
                exc_info=True,
            )
        return None

    def _embed_batch_gemini(
        self,
        texts: List[str],
        task_type: Optional[str] = None,
    ) -> List[List[float]]:
        """Embed a batch of texts via the native Gemini API (shared key pool).

        Empty / whitespace-only texts are not sent; they receive zero vectors
        at their original positions. When the pool reports openrouter-only
        mode, OpenRouter is called directly with the paid Gemini key as the
        fallback. Otherwise up to 20 requests are made, each with a fresh key
        from ``next_gemini_embed_key``. Once ``PAID_KEY_FALLBACK_THRESHOLD``
        429 responses have been seen, OpenRouter and then the paid key are
        tried once (see :meth:`_escalate_after_429`).

        Args:
            texts: Texts to embed.
            task_type: Optional Gemini task type (e.g. ``RETRIEVAL_DOCUMENT``).

        Returns:
            List[List[float]]: One embedding per input text.

        Raises:
            RuntimeError: When every attempt failed. The API key is redacted
                from the message.
        """
        if not texts:
            return []
        valid_indices: List[int] = []
        valid_texts: List[str] = []
        for i, t in enumerate(texts):
            if t and t.strip():
                valid_indices.append(i)
                valid_texts.append(t)
        if not valid_texts:
            return [[0.0] * self.dimensions for _ in texts]
        if len(valid_texts) < len(texts):
            # Embed only the non-blank texts, then scatter them back.
            partial = self._embed_batch_gemini(valid_texts, task_type=task_type)
            out: List[List[float]] = [[0.0] * self.dimensions for _ in texts]
            for idx, emb in zip(valid_indices, partial):
                out[idx] = emb
            return out
        texts = valid_texts

        if check_openrouter_only_sync():
            logger.info(
                "OpenRouter-only mode (sync) — bypassing Gemini for %d texts",
                len(texts),
            )
            try:
                return openrouter_embed_batch_sync(
                    texts,
                    model=self.model,
                    dimensions=self.dimensions,
                )
            except Exception:
                logger.warning(
                    "OpenRouter (sync) failed while pinned to openrouter_only — "
                    "trying paid Gemini key as last resort",
                    exc_info=True,
                )
            return gemini_embed_paid_fallback_sync(
                texts,
                model=self.model,
                dimensions=self.dimensions,
                task_type=task_type,
            )

        gemini_model = _gemini_model_name(self.model)
        requests_list = []
        for t in texts:
            req: dict = {
                "model": f"models/{gemini_model}",
                "content": {"parts": [{"text": t}]},
                "output_dimensionality": self.dimensions,
            }
            if task_type:
                req["taskType"] = task_type
            requests_list.append(req)
        payload = {"requests": requests_list}
        client = self._get_client()

        # NOTE(review): this counter is never reset by a non-429 response, so
        # it counts total 429s in this call rather than strictly consecutive
        # ones — confirm which is intended.
        consecutive_429 = 0
        tried_or_then_paid = False
        max_attempts = 20
        last_error: Optional[str] = None
        for attempt in range(max_attempts):
            api_key = next_gemini_embed_key()
            # NOTE(review): the key travels in the query string; consider the
            # ``x-goog-api-key`` header instead — confirm endpoint support.
            url = (
                f"{GEMINI_EMBED_BASE}/{gemini_model}:batchEmbedContents"
                f"?key={api_key}"
            )
            if attempt > 0:
                # Flat 1s pause while still below the escalation threshold,
                # exponential (capped) afterwards.
                delay = (
                    1.0
                    if attempt <= PAID_KEY_FALLBACK_THRESHOLD
                    else min(
                        EMBED_RETRY_BASE_DELAY
                        * (2 ** (attempt - PAID_KEY_FALLBACK_THRESHOLD - 1)),
                        MAX_EMBED_DELAY,
                    )
                )
                time.sleep(delay)
            try:
                response = client.post(url, json=payload)
                if response.status_code == 429:
                    consecutive_429 += 1
                    if (
                        consecutive_429 >= PAID_KEY_FALLBACK_THRESHOLD
                        and not tried_or_then_paid
                    ):
                        # Escalate at most once per call.
                        tried_or_then_paid = True
                        rescued = self._escalate_after_429(
                            texts,
                            task_type,
                            consecutive_429,
                        )
                        if rescued is not None:
                            return rescued
                    last_error = "HTTP 429"
                    continue
                if response.status_code in _RETRIABLE_STATUSES:
                    last_error = f"HTTP {response.status_code}"
                    continue
                # NOTE(review): non-retriable 4xx responses raise here and are
                # then retried like any other error by the handler below.
                response.raise_for_status()
                data = response.json()
                return [item["values"] for item in data["embeddings"]]
            except Exception as exc:
                # httpx error strings can embed the full URL (with the key).
                last_error = self._redact_key(str(exc), api_key)
        raise RuntimeError(
            f"Gemini embed failed after {max_attempts} attempts: {last_error}"
        )