"""Embeddings client for RAG system.
Uses the native Google Gemini API only, via the shared key pool in
gemini_embed_pool. Provides both async (OpenRouterEmbeddings) and
synchronous (SyncOpenRouterEmbeddings, ChromaDB-compatible) interfaces.
"""
import asyncio
import logging
import time
from typing import List, Optional
import httpx
import numpy as np
from gemini_embed_pool import (
GEMINI_EMBED_BASE,
PAID_KEY_FALLBACK_THRESHOLD,
check_openrouter_only,
check_openrouter_only_sync,
get_paid_fallback_key,
next_gemini_embed_key,
openrouter_embed_batch,
openrouter_embed_batch_sync,
set_openrouter_only,
)
logger = logging.getLogger(__name__)
# Retry tuning for embedding requests.
MAX_499_RETRIES = 3  # NOTE(review): unused in this module — confirm before removing
EMBED_RETRY_BASE_DELAY = 1.0  # seconds; doubled each retry round (exponential back-off)
MAX_EMBED_DELAY = 8.0  # cap (seconds) on the exponential back-off delay
_RETRIABLE_STATUSES = {429, 500, 502, 503, 504}  # HTTP statuses treated as transient
def _gemini_model_name(model: str) -> str:
"""Convert ``google/gemini-embedding-001`` → ``gemini-embedding-001``."""
return model.removeprefix("google/")
class OpenRouterEmbeddings:
    """Async embeddings client using the Gemini API via the shared key pool.

    Despite the legacy name, all embedding traffic goes to the native
    Gemini API; OpenRouter is used only as a last-resort fallback once
    even the paid Gemini key is rate-limited.
    """

    DEFAULT_MODEL = "google/gemini-embedding-001"
    MAX_BATCH_SIZE = 50       # max texts per batchEmbedContents request
    MAX_BATCH_CHARS = 50_000  # max total characters per request

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = DEFAULT_MODEL,
        dimensions: Optional[int] = None,
        timeout: float = 30.0,
        gemini_api_key: Optional[str] = None,
        gemini_only: bool = True,
    ):
        """Initialize the client.

        Args:
            api_key (Optional[str]): Unused; kept for backward compatibility.
            model (str): Embedding model id (``google/``-prefixed or bare).
            dimensions (Optional[int]): Output embedding dimensionality;
                defaults to 3072 when omitted or falsy.
            timeout (float): Per-request HTTP timeout in seconds.
            gemini_api_key (Optional[str]): Unused; the shared pool is used.
            gemini_only (bool): Ignored; embeddings always use the Gemini API.
        """
        self.model = model
        self.dimensions = dimensions or 3072
        self.timeout = timeout
        self.gemini_only = True  # Always use Gemini API
        self._client = httpx.AsyncClient(timeout=timeout)
        logger.info(
            "Initialized embeddings client with model: %s (Gemini API)", model,
        )

    async def embed_text(self, text: str) -> np.ndarray:
        """Embed a single text.

        Args:
            text (str): Text content.

        Returns:
            np.ndarray: The embedding vector.
        """
        embeddings = await self.embed_texts([text])
        return embeddings[0]

    async def embed_texts(self, texts: List[str]) -> List[np.ndarray]:
        """Embed *texts*, preserving input order.

        Args:
            texts (List[str]): Texts to embed; an empty list is a no-op.

        Returns:
            List[np.ndarray]: One vector per input text.
        """
        if not texts:
            return []
        all_embeddings: List[np.ndarray] = []
        for batch in self._create_batches(texts):
            all_embeddings.extend(await self._embed_batch(batch))
        return all_embeddings

    def _create_batches(self, texts: List[str]) -> List[List[str]]:
        """Split *texts* into batches respecting item and character caps.

        A single text longer than ``MAX_BATCH_CHARS`` still forms its own
        batch — a batch is never left empty.

        Args:
            texts (List[str]): Texts to partition, order preserved.

        Returns:
            List[List[str]]: The batches.
        """
        batches: List[List[str]] = []
        current_batch: List[str] = []
        current_chars = 0
        for text in texts:
            text_len = len(text)
            would_exceed_items = len(current_batch) >= self.MAX_BATCH_SIZE
            would_exceed_chars = current_chars + text_len > self.MAX_BATCH_CHARS
            if current_batch and (would_exceed_items or would_exceed_chars):
                batches.append(current_batch)
                current_batch = []
                current_chars = 0
            current_batch.append(text)
            current_chars += text_len
        if current_batch:
            batches.append(current_batch)
        return batches

    async def _embed_batch(self, texts: List[str]) -> List[np.ndarray]:
        """Embed one batch, retrying with capped exponential back-off.

        NOTE(review): this loop has no retry cap — a permanently failing
        batch blocks the caller forever; confirm this is intended.
        """
        round_num = 0
        while True:
            try:
                return await self._embed_batch_gemini(texts)
            except Exception as exc:
                round_num += 1
                delay = min(
                    EMBED_RETRY_BASE_DELAY * (2 ** (round_num - 1)),
                    MAX_EMBED_DELAY,
                )
                logger.warning(
                    "Gemini embed failed (round %d), retrying in %.1fs: %s",
                    round_num, delay, exc,
                )
                await asyncio.sleep(delay)

    async def _embed_batch_gemini(
        self, texts: List[str], task_type: Optional[str] = None,
    ) -> List[np.ndarray]:
        """Embed a batch of texts via the native Gemini API (shared key pool).

        Rotates to a fresh key on every 429 and falls back to the paid
        key after exhausting ``PAID_KEY_FALLBACK_THRESHOLD`` rotations;
        if the paid key is also rate-limited, tries OpenRouter directly.

        Args:
            texts: Texts to embed.
            task_type: Optional Gemini task type (e.g. ``QUESTION_ANSWERING``,
                ``RETRIEVAL_DOCUMENT``). Omit for default behaviour.

        Raises:
            RuntimeError: When all attempts are exhausted.
        """
        if await check_openrouter_only():
            logger.info("OpenRouter-only mode — bypassing Gemini for %d texts", len(texts))
            vecs = await openrouter_embed_batch(texts, model=self.model)
            return [np.array(v, dtype=np.float32) for v in vecs]
        gemini_model = _gemini_model_name(self.model)
        requests_list = []
        for t in texts:
            req: dict = {
                "model": f"models/{gemini_model}",
                "content": {"parts": [{"text": t}]},
                # Honour the configured dimensionality (was hard-coded 3072,
                # silently ignoring the ``dimensions`` constructor argument).
                "output_dimensionality": self.dimensions,
            }
            if task_type:
                req["taskType"] = task_type
            requests_list.append(req)
        payload = {"requests": requests_list}
        consecutive_429 = 0
        switched_to_paid = False
        max_attempts = 20
        last_error: str | None = None
        for attempt in range(max_attempts):
            # NOTE(review): the key always comes from the rotating pool even
            # after "switching" to the paid key — presumably the pool starts
            # handing out the paid key once get_paid_fallback_key() is
            # called; confirm in gemini_embed_pool.
            api_key = next_gemini_embed_key()
            url = (
                f"{GEMINI_EMBED_BASE}/{gemini_model}:batchEmbedContents"
                f"?key={api_key}"
            )
            if attempt > 0:
                # Flat 1s delay while rotating free keys; capped exponential
                # back-off once past the paid-fallback threshold.
                delay = 1.0 if attempt <= PAID_KEY_FALLBACK_THRESHOLD else min(
                    EMBED_RETRY_BASE_DELAY * (2 ** (attempt - PAID_KEY_FALLBACK_THRESHOLD - 1)),
                    MAX_EMBED_DELAY,
                )
                await asyncio.sleep(delay)
            try:
                response = await self._client.post(url, json=payload)
                if response.status_code == 429:
                    consecutive_429 += 1
                    if (
                        consecutive_429 >= PAID_KEY_FALLBACK_THRESHOLD
                        and not switched_to_paid
                    ):
                        paid = get_paid_fallback_key()
                        if paid:
                            logger.warning(
                                "Switching to paid Gemini key after %d "
                                "consecutive 429s",
                                consecutive_429,
                            )
                            switched_to_paid = True
                    elif switched_to_paid:
                        # Paid key is rate-limited too: flip the shared
                        # OpenRouter-only flag and try OpenRouter directly.
                        await set_openrouter_only()
                        try:
                            logger.warning(
                                "Paid Gemini key 429'd — trying OpenRouter",
                            )
                            vecs = await openrouter_embed_batch(
                                texts, model=self.model,
                            )
                            return [
                                np.array(v, dtype=np.float32) for v in vecs
                            ]
                        except Exception:
                            logger.warning(
                                "OpenRouter embed fallback also failed",
                                exc_info=True,
                            )
                    last_error = "HTTP 429"
                    continue
                if response.status_code in _RETRIABLE_STATUSES:
                    last_error = f"HTTP {response.status_code}"
                    continue
                # NOTE(review): non-retriable 4xx responses raise here but are
                # caught below and retried anyway — confirm this is intended.
                response.raise_for_status()
                data = response.json()
                return [
                    np.array(item["values"], dtype=np.float32)
                    for item in data["embeddings"]
                ]
            except Exception as exc:
                last_error = str(exc)
        raise RuntimeError(
            f"Gemini embed failed after {max_attempts} attempts: {last_error}"
        )

    async def embed_text_for_search(
        self, text: str, task_type: str = "QUESTION_ANSWERING",
    ) -> List[float]:
        """Embed a single text using the Gemini API only, with a task type.

        Intended for pre-computing a query embedding before passing it to
        ``FileRAGManager.search(query_embedding=...)``. Retries on
        transient errors with exponential back-off (no retry cap).
        """
        round_num = 0
        while True:
            try:
                results = await self._embed_batch_gemini(
                    [text], task_type=task_type,
                )
                return results[0].tolist()
            except Exception as exc:
                round_num += 1
                delay = min(
                    EMBED_RETRY_BASE_DELAY * (2 ** (round_num - 1)),
                    MAX_EMBED_DELAY,
                )
                logger.warning(
                    "Gemini embed_text_for_search failed (round %d), "
                    "retrying in %.1fs: %s",
                    round_num, delay, exc,
                )
                await asyncio.sleep(delay)

    async def close(self):
        """Close the underlying HTTP client."""
        await self._client.aclose()

    async def __aenter__(self):
        """Async context-manager entry; returns self."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context-manager exit; closes the HTTP client.

        Args:
            exc_type: Exception type, if any.
            exc_val: Exception value, if any.
            exc_tb: Exception traceback, if any.
        """
        await self.close()
class SyncOpenRouterEmbeddings:
    """Synchronous wrapper used by ChromaDB's embedding function interface.

    Uses the Gemini API via the shared key pool. Batches are dispatched
    concurrently via a ThreadPoolExecutor when there are multiple batches.
    """

    MAX_BATCH_SIZE = 50       # max texts per batchEmbedContents request
    MAX_BATCH_CHARS = 50_000  # max total characters per request
    MAX_EMBED_WORKERS = 20    # thread-pool cap for concurrent batches

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = OpenRouterEmbeddings.DEFAULT_MODEL,
        dimensions: Optional[int] = None,
        timeout: float = 30.0,
        gemini_api_key: Optional[str] = None,
        gemini_only: bool = True,
        document_task_type: Optional[str] = None,
        query_task_type: Optional[str] = None,
    ):
        """Initialize the instance.

        Args:
            api_key (Optional[str]): Unused; kept for backward compatibility.
            model (str): Embedding model id (``google/``-prefixed or bare).
            dimensions (Optional[int]): Output embedding dimensionality;
                defaults to 3072 when omitted or falsy.
            timeout (float): Per-request HTTP timeout in seconds.
            gemini_api_key (Optional[str]): Unused; pool is used instead.
            gemini_only (bool): Unused; always Gemini API.
            document_task_type: Optional Gemini ``taskType`` for corpus
                (e.g. ``RETRIEVAL_DOCUMENT``); used by ``embed_documents``.
            query_task_type: Optional Gemini ``taskType`` for queries
                (e.g. ``RETRIEVAL_QUERY``); used by ``embed_query``.
        """
        self.model = model
        self.dimensions = dimensions or 3072
        self.timeout = timeout
        self._name = f"openrouter_{model.replace('/', '_')}"
        self.is_legacy = False
        self.document_task_type = document_task_type
        self.query_task_type = query_task_type
        import threading
        # Thread-local storage for per-thread HTTP clients (see _get_client).
        self._local = threading.local()

    def _get_client(self) -> httpx.Client:
        """Return a thread-local httpx.Client to avoid connection pool corruption."""
        client = getattr(self._local, "client", None)
        if client is None or client.is_closed:
            client = httpx.Client(timeout=self.timeout)
            self._local.client = client
        return client

    def name(self) -> str:
        """Return the ChromaDB embedding-function name.

        Returns:
            str: Name derived from the model id.
        """
        return self._name

    def dimension(self) -> int:
        """Return the embedding dimensionality requested from the API.

        Returns:
            int: The configured dimensionality (was hard-coded to 3072,
            inconsistent with the ``dimensions`` constructor argument).
        """
        return self.dimensions

    def __call__(self, input: List[str]) -> List[List[float]]:
        """ChromaDB EmbeddingFunction interface (legacy).

        Uses ``document_task_type`` when set (same as :meth:`embed_documents`).
        """
        return self._embed_inputs(input, self.document_task_type)

    def _embed_inputs(
        self,
        input: List[str],
        task_type: Optional[str],
    ) -> List[List[float]]:
        """Batch-embed *input* with optional Gemini ``taskType`` per batch."""
        if not input:
            return []
        batches = self._create_batches(input)
        if len(batches) == 1:
            return self._embed_batch(batches[0], task_type=task_type)
        from concurrent.futures import ThreadPoolExecutor, as_completed
        workers = min(len(batches), self.MAX_EMBED_WORKERS)
        ordered: list[tuple[int, List[List[float]]]] = []
        with ThreadPoolExecutor(max_workers=workers) as pool:
            futures = {
                pool.submit(self._embed_batch, batch, task_type): idx
                for idx, batch in enumerate(batches)
            }
            for future in as_completed(futures):
                ordered.append((futures[future], future.result()))
        # Restore the original batch order before flattening.
        ordered.sort(key=lambda x: x[0])
        result: List[List[float]] = []
        for _, embs in ordered:
            result.extend(embs)
        return result

    # ChromaDB >=0.6 calls these instead of __call__

    def embed_documents(self, input: List[str]) -> List[List[float]]:
        """Embed documents (ChromaDB interface for upsert)."""
        return self._embed_inputs(input, self.document_task_type)

    def embed_query(self, input: List[str]) -> List[List[float]]:
        """Embed query texts (ChromaDB interface for query)."""
        return self._embed_inputs(input, self.query_task_type)

    def _create_batches(self, texts: List[str]) -> List[List[str]]:
        """Split *texts* into batches respecting item and character caps.

        A single text longer than ``MAX_BATCH_CHARS`` still forms its own
        batch — a batch is never left empty.

        Args:
            texts (List[str]): Texts to partition, order preserved.

        Returns:
            List[List[str]]: The batches.
        """
        batches: List[List[str]] = []
        current_batch: List[str] = []
        current_chars = 0
        for text in texts:
            text_len = len(text)
            would_exceed_items = len(current_batch) >= self.MAX_BATCH_SIZE
            would_exceed_chars = current_chars + text_len > self.MAX_BATCH_CHARS
            if current_batch and (would_exceed_items or would_exceed_chars):
                batches.append(current_batch)
                current_batch = []
                current_chars = 0
            current_batch.append(text)
            current_chars += text_len
        if current_batch:
            batches.append(current_batch)
        return batches

    def _embed_batch(
        self,
        texts: List[str],
        task_type: Optional[str] = None,
    ) -> List[List[float]]:
        """Embed one batch, retrying with capped exponential back-off.

        NOTE(review): this loop has no retry cap — a permanently failing
        batch blocks the caller forever; confirm this is intended.
        """
        round_num = 0
        while True:
            try:
                return self._embed_batch_gemini(texts, task_type=task_type)
            except Exception as exc:
                round_num += 1
                delay = min(
                    EMBED_RETRY_BASE_DELAY * (2 ** (round_num - 1)),
                    MAX_EMBED_DELAY,
                )
                logger.warning(
                    "Gemini embed failed (round %d), retrying in %.1fs: %s",
                    round_num, delay, exc,
                )
                time.sleep(delay)

    def _embed_batch_gemini(
        self, texts: List[str], task_type: Optional[str] = None,
    ) -> List[List[float]]:
        """Embed a batch of texts via the native Gemini API (shared key pool).

        Rotates to a fresh key on every 429 and falls back to the paid
        key after exhausting ``PAID_KEY_FALLBACK_THRESHOLD`` rotations;
        if the paid key is also rate-limited, tries OpenRouter directly.

        Args:
            texts: Texts to embed.
            task_type: Optional Gemini task type (e.g. ``RETRIEVAL_DOCUMENT``).

        Raises:
            RuntimeError: When all attempts are exhausted.
        """
        if check_openrouter_only_sync():
            logger.info("OpenRouter-only mode (sync) — bypassing Gemini for %d texts", len(texts))
            return openrouter_embed_batch_sync(texts, model=self.model)
        gemini_model = _gemini_model_name(self.model)
        requests_list = []
        for t in texts:
            req: dict = {
                "model": f"models/{gemini_model}",
                "content": {"parts": [{"text": t}]},
                # Honour the configured dimensionality (was hard-coded 3072,
                # silently ignoring the ``dimensions`` constructor argument).
                "output_dimensionality": self.dimensions,
            }
            if task_type:
                req["taskType"] = task_type
            requests_list.append(req)
        payload = {"requests": requests_list}
        client = self._get_client()
        consecutive_429 = 0
        switched_to_paid = False
        max_attempts = 20
        last_error: str | None = None
        for attempt in range(max_attempts):
            # NOTE(review): the key always comes from the rotating pool even
            # after "switching" to the paid key — presumably the pool starts
            # handing out the paid key once get_paid_fallback_key() is
            # called; confirm in gemini_embed_pool.
            api_key = next_gemini_embed_key()
            url = (
                f"{GEMINI_EMBED_BASE}/{gemini_model}:batchEmbedContents"
                f"?key={api_key}"
            )
            if attempt > 0:
                # Flat 1s delay while rotating free keys; capped exponential
                # back-off once past the paid-fallback threshold.
                delay = 1.0 if attempt <= PAID_KEY_FALLBACK_THRESHOLD else min(
                    EMBED_RETRY_BASE_DELAY * (2 ** (attempt - PAID_KEY_FALLBACK_THRESHOLD - 1)),
                    MAX_EMBED_DELAY,
                )
                time.sleep(delay)
            try:
                response = client.post(url, json=payload)
                if response.status_code == 429:
                    consecutive_429 += 1
                    if (
                        consecutive_429 >= PAID_KEY_FALLBACK_THRESHOLD
                        and not switched_to_paid
                    ):
                        paid = get_paid_fallback_key()
                        if paid:
                            logger.warning(
                                "Switching to paid Gemini key after %d "
                                "consecutive 429s",
                                consecutive_429,
                            )
                            switched_to_paid = True
                    elif switched_to_paid:
                        # NOTE(review): pokes a private module global instead
                        # of a public setter — the async path uses
                        # set_openrouter_only(); consider a sync equivalent.
                        import gemini_embed_pool as _gep
                        _gep._openrouter_only = True
                        logger.warning(
                            "OpenRouter-only mode ACTIVATED (sync, in-memory only)",
                        )
                        try:
                            logger.warning(
                                "Paid Gemini key 429'd — trying OpenRouter",
                            )
                            return openrouter_embed_batch_sync(
                                texts, model=self.model,
                            )
                        except Exception:
                            logger.warning(
                                "OpenRouter embed fallback also failed",
                                exc_info=True,
                            )
                    last_error = "HTTP 429"
                    continue
                if response.status_code in _RETRIABLE_STATUSES:
                    last_error = f"HTTP {response.status_code}"
                    continue
                # NOTE(review): non-retriable 4xx responses raise here but are
                # caught below and retried anyway — confirm this is intended.
                response.raise_for_status()
                data = response.json()
                return [item["values"] for item in data["embeddings"]]
            except Exception as exc:
                last_error = str(exc)
        raise RuntimeError(
            f"Gemini embed failed after {max_attempts} attempts: {last_error}"
        )