# Source code for rag_system.openrouter_embeddings

"""Embeddings client for RAG system.

Uses the native Google Gemini API only, via the shared key pool in
gemini_embed_pool. Provides both async (OpenRouterEmbeddings) and
synchronous (SyncOpenRouterEmbeddings, ChromaDB-compatible) interfaces.
"""

import asyncio
import logging
import time
from typing import List, Optional

import httpx
import numpy as np

from gemini_embed_pool import (
    GEMINI_EMBED_BASE,
    PAID_KEY_FALLBACK_THRESHOLD,
    check_openrouter_only,
    check_openrouter_only_sync,
    get_paid_fallback_key,
    next_gemini_embed_key,
    openrouter_embed_batch,
    openrouter_embed_batch_sync,
    set_openrouter_only,
)

# Module-level logger; handlers/levels are configured by the application.
logger = logging.getLogger(__name__)

# Retry/backoff tuning shared by both client classes below.
MAX_499_RETRIES = 3  # NOTE(review): unreferenced in this module — presumably used by importers; confirm
EMBED_RETRY_BASE_DELAY = 1.0  # seconds: base of the exponential backoff between embed retries
MAX_EMBED_DELAY = 8.0  # seconds: cap on any single backoff sleep
_RETRIABLE_STATUSES = {429, 500, 502, 503, 504}  # HTTP statuses that trigger a retry with a fresh key


def _gemini_model_name(model: str) -> str:
    """Convert ``google/gemini-embedding-001`` → ``gemini-embedding-001``."""
    return model.removeprefix("google/")


class OpenRouterEmbeddings:
    """Async embeddings client using the Gemini API via the shared key pool.

    Despite the class name (kept for backward compatibility), embeddings
    come from the native Google Gemini API; OpenRouter is only a last-resort
    fallback when the key pool is exhausted.
    """

    DEFAULT_MODEL = "google/gemini-embedding-001"
    MAX_BATCH_SIZE = 50       # max texts per batchEmbedContents request
    MAX_BATCH_CHARS = 50_000  # max total characters per request

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = DEFAULT_MODEL,
        dimensions: Optional[int] = None,
        timeout: float = 30.0,
        gemini_api_key: Optional[str] = None,
        gemini_only: bool = True,
    ):
        """Initialize the client.

        Args:
            api_key: Unused; kept for backward compatibility.
            model: Embedding model name (``google/`` prefix allowed).
            dimensions: Output embedding dimensionality (default 3072).
            timeout: Per-request HTTP timeout in seconds.
            gemini_api_key: Unused; the shared key pool is used instead.
            gemini_only: Ignored; embeddings always use the Gemini API.
        """
        self.model = model
        self.dimensions = dimensions or 3072
        self.timeout = timeout
        self.gemini_only = True  # Always use Gemini API
        self._client = httpx.AsyncClient(timeout=timeout)
        logger.info(
            "Initialized embeddings client with model: %s (Gemini API)",
            model,
        )

    async def embed_text(self, text: str) -> np.ndarray:
        """Embed a single text and return its vector."""
        embeddings = await self.embed_texts([text])
        return embeddings[0]

    async def embed_texts(self, texts: List[str]) -> List[np.ndarray]:
        """Embed *texts*, batching to respect size/character limits.

        Returns one ``np.ndarray`` per input text, in input order.
        """
        if not texts:
            return []
        batches = self._create_batches(texts)
        all_embeddings: List[np.ndarray] = []
        for batch in batches:
            all_embeddings.extend(await self._embed_batch(batch))
        return all_embeddings

    def _create_batches(self, texts: List[str]) -> List[List[str]]:
        """Greedily pack *texts* into batches.

        A batch is closed when adding the next text would exceed either
        ``MAX_BATCH_SIZE`` items or ``MAX_BATCH_CHARS`` total characters.
        A single oversized text still forms its own (oversized) batch.
        """
        batches: List[List[str]] = []
        current_batch: List[str] = []
        current_chars = 0
        for text in texts:
            text_len = len(text)
            would_exceed_items = len(current_batch) >= self.MAX_BATCH_SIZE
            would_exceed_chars = (
                current_chars + text_len > self.MAX_BATCH_CHARS
            )
            if current_batch and (would_exceed_items or would_exceed_chars):
                batches.append(current_batch)
                current_batch = []
                current_chars = 0
            current_batch.append(text)
            current_chars += text_len
        if current_batch:
            batches.append(current_batch)
        return batches

    async def _embed_batch(self, texts: List[str]) -> List[np.ndarray]:
        """Embed one batch, retrying forever with capped exponential backoff.

        NOTE(review): this loop never gives up on its own; it relies on
        ``_embed_batch_gemini`` eventually succeeding or the task being
        cancelled externally.
        """
        round_num = 0
        while True:
            try:
                return await self._embed_batch_gemini(texts)
            except Exception as exc:
                round_num += 1
                delay = min(
                    EMBED_RETRY_BASE_DELAY * (2 ** (round_num - 1)),
                    MAX_EMBED_DELAY,
                )
                logger.warning(
                    "Gemini embed failed (round %d), retrying in %.1fs: %s",
                    round_num,
                    delay,
                    exc,
                )
                await asyncio.sleep(delay)

    async def _embed_batch_gemini(
        self,
        texts: List[str],
        task_type: Optional[str] = None,
    ) -> List[np.ndarray]:
        """Embed a batch of texts via the native Gemini API (shared key pool).

        Rotates to a fresh key on every 429 and falls back to the paid key
        after ``PAID_KEY_FALLBACK_THRESHOLD`` consecutive 429s; if the paid
        key is also rate-limited, flips the pool into OpenRouter-only mode
        and tries OpenRouter directly.

        Args:
            texts: Texts to embed.
            task_type: Optional Gemini task type (e.g.
                ``QUESTION_ANSWERING``, ``RETRIEVAL_DOCUMENT``). Omit for
                default behaviour.

        Raises:
            RuntimeError: After ``max_attempts`` failed attempts.
        """
        if await check_openrouter_only():
            logger.info("OpenRouter-only mode — bypassing Gemini for %d texts", len(texts))
            vecs = await openrouter_embed_batch(texts, model=self.model)
            return [np.array(v, dtype=np.float32) for v in vecs]

        gemini_model = _gemini_model_name(self.model)
        requests_list = []
        for t in texts:
            req: dict = {
                "model": f"models/{gemini_model}",
                "content": {"parts": [{"text": t}]},
                # Fix: honour the configured dimensionality instead of a
                # hard-coded 3072 (default value is unchanged).
                # NOTE(review): Gemini REST docs spell this field
                # "outputDimensionality"; snake_case presumably accepted via
                # the protobuf JSON mapping — confirm.
                "output_dimensionality": self.dimensions,
            }
            if task_type:
                req["taskType"] = task_type
            requests_list.append(req)
        payload = {"requests": requests_list}

        consecutive_429 = 0
        switched_to_paid = False
        max_attempts = 20
        last_error: str | None = None
        for attempt in range(max_attempts):
            api_key = next_gemini_embed_key()
            url = (
                f"{GEMINI_EMBED_BASE}/{gemini_model}:batchEmbedContents"
                f"?key={api_key}"
            )
            if attempt > 0:
                # Flat 1s while rotating free keys, then capped exponential
                # backoff once past the paid-fallback threshold.
                delay = (
                    1.0
                    if attempt <= PAID_KEY_FALLBACK_THRESHOLD
                    else min(
                        EMBED_RETRY_BASE_DELAY
                        * (2 ** (attempt - PAID_KEY_FALLBACK_THRESHOLD - 1)),
                        MAX_EMBED_DELAY,
                    )
                )
                await asyncio.sleep(delay)
            try:
                response = await self._client.post(url, json=payload)
                if response.status_code == 429:
                    consecutive_429 += 1
                    if (
                        consecutive_429 >= PAID_KEY_FALLBACK_THRESHOLD
                        and not switched_to_paid
                    ):
                        # NOTE(review): the paid key is never substituted into
                        # the request URL here — presumably
                        # get_paid_fallback_key() routes it through the pool so
                        # a later next_gemini_embed_key() returns it; confirm.
                        paid = get_paid_fallback_key()
                        if paid:
                            logger.warning(
                                "Switching to paid Gemini key after %d "
                                "consecutive 429s",
                                consecutive_429,
                            )
                            switched_to_paid = True
                    elif switched_to_paid:
                        # Paid key is rate-limited too: persist the
                        # OpenRouter-only flag and try OpenRouter once.
                        await set_openrouter_only()
                        try:
                            logger.warning(
                                "Paid Gemini key 429'd — trying OpenRouter",
                            )
                            vecs = await openrouter_embed_batch(
                                texts,
                                model=self.model,
                            )
                            return [
                                np.array(v, dtype=np.float32) for v in vecs
                            ]
                        except Exception:
                            logger.warning(
                                "OpenRouter embed fallback also failed",
                                exc_info=True,
                            )
                    last_error = "HTTP 429"
                    continue
                if response.status_code in _RETRIABLE_STATUSES:
                    last_error = f"HTTP {response.status_code}"
                    continue
                response.raise_for_status()
                data = response.json()
                return [
                    np.array(item["values"], dtype=np.float32)
                    for item in data["embeddings"]
                ]
            except Exception as exc:
                # Network errors and non-2xx raise_for_status errors are
                # recorded and retried; the final error is surfaced below.
                last_error = str(exc)
        # Fix: raising here (after the loop) guarantees the coroutine can
        # never fall through and implicitly return None once attempts are
        # exhausted.
        raise RuntimeError(
            f"Gemini embed failed after {max_attempts} attempts: {last_error}"
        )

    async def close(self):
        """Close the underlying async HTTP client."""
        await self._client.aclose()

    async def __aenter__(self):
        """Enter the async context manager."""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Exit the async context manager, closing the HTTP client."""
        await self.close()
class SyncOpenRouterEmbeddings:
    """Synchronous wrapper used by ChromaDB's embedding function interface.

    Uses the Gemini API via the shared key pool. Batches are dispatched
    concurrently via a ThreadPoolExecutor when there are multiple batches.
    """

    MAX_BATCH_SIZE = 50       # max texts per batchEmbedContents request
    MAX_BATCH_CHARS = 50_000  # max total characters per request
    MAX_EMBED_WORKERS = 20    # cap on concurrent embedding threads

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        dimensions: Optional[int] = None,
        timeout: float = 30.0,
        gemini_api_key: Optional[str] = None,
        gemini_only: bool = True,
        document_task_type: Optional[str] = None,
        query_task_type: Optional[str] = None,
    ):
        """Initialize the instance.

        Args:
            api_key: Unused; kept for backward compatibility.
            model: Embedding model name; ``None`` selects
                ``OpenRouterEmbeddings.DEFAULT_MODEL`` (same value as the
                previous eager default, now resolved at call time).
            dimensions: Output embedding dimensionality (default 3072).
            timeout: Per-request HTTP timeout in seconds.
            gemini_api_key: Unused; the shared key pool is used instead.
            gemini_only: Unused; always Gemini API.
            document_task_type: Optional Gemini ``taskType`` for corpus
                (e.g. ``RETRIEVAL_DOCUMENT``); used by ``embed_documents``.
            query_task_type: Optional Gemini ``taskType`` for queries
                (e.g. ``RETRIEVAL_QUERY``); used by ``embed_query``.
        """
        self.model = (
            model if model is not None else OpenRouterEmbeddings.DEFAULT_MODEL
        )
        self.dimensions = dimensions or 3072
        self.timeout = timeout
        self._name = f"openrouter_{self.model.replace('/', '_')}"
        self.is_legacy = False
        self.document_task_type = document_task_type
        self.query_task_type = query_task_type
        import threading

        # One httpx.Client per thread (see _get_client).
        self._local = threading.local()

    def _get_client(self) -> "httpx.Client":
        """Return a thread-local httpx.Client to avoid connection pool
        corruption when batches are embedded from worker threads."""
        client = getattr(self._local, "client", None)
        if client is None or client.is_closed:
            client = httpx.Client(timeout=self.timeout)
            self._local.client = client
        return client

    def name(self) -> str:
        """Return the embedding-function name derived from the model."""
        return self._name

    def dimension(self) -> int:
        """Return the embedding dimensionality.

        Fix: previously hard-coded to 3072, silently ignoring a
        non-default ``dimensions`` constructor argument.
        """
        return self.dimensions

    def __call__(self, input: List[str]) -> List[List[float]]:
        """ChromaDB EmbeddingFunction interface (legacy).

        Uses ``document_task_type`` when set (same as ``embed_documents``).
        """
        return self._embed_inputs(input, self.document_task_type)

    def _embed_inputs(
        self,
        input: List[str],
        task_type: Optional[str],
    ) -> List[List[float]]:
        """Batch-embed *input* with optional Gemini ``taskType`` per batch."""
        if not input:
            return []
        batches = self._create_batches(input)
        if len(batches) == 1:
            return self._embed_batch(batches[0], task_type=task_type)

        from concurrent.futures import ThreadPoolExecutor, as_completed

        workers = min(len(batches), self.MAX_EMBED_WORKERS)
        # Collect (batch_index, embeddings) pairs, then restore input order.
        ordered: list[tuple[int, List[List[float]]]] = []
        with ThreadPoolExecutor(max_workers=workers) as pool:
            futures = {
                pool.submit(self._embed_batch, batch, task_type): idx
                for idx, batch in enumerate(batches)
            }
            for future in as_completed(futures):
                ordered.append((futures[future], future.result()))
        ordered.sort(key=lambda x: x[0])
        result: List[List[float]] = []
        for _, embs in ordered:
            result.extend(embs)
        return result

    # ChromaDB >=0.6 calls these instead of __call__

    def embed_documents(self, input: List[str]) -> List[List[float]]:
        """Embed documents (ChromaDB interface for upsert)."""
        return self._embed_inputs(input, self.document_task_type)

    def embed_query(self, input: List[str]) -> List[List[float]]:
        """Embed query texts (ChromaDB interface for query)."""
        return self._embed_inputs(input, self.query_task_type)

    def _create_batches(self, texts: List[str]) -> List[List[str]]:
        """Greedily pack *texts* into batches.

        A batch is closed when adding the next text would exceed either
        ``MAX_BATCH_SIZE`` items or ``MAX_BATCH_CHARS`` total characters.
        A single oversized text still forms its own (oversized) batch.
        """
        batches: List[List[str]] = []
        current_batch: List[str] = []
        current_chars = 0
        for text in texts:
            text_len = len(text)
            would_exceed_items = len(current_batch) >= self.MAX_BATCH_SIZE
            would_exceed_chars = (
                current_chars + text_len > self.MAX_BATCH_CHARS
            )
            if current_batch and (would_exceed_items or would_exceed_chars):
                batches.append(current_batch)
                current_batch = []
                current_chars = 0
            current_batch.append(text)
            current_chars += text_len
        if current_batch:
            batches.append(current_batch)
        return batches

    def _embed_batch(
        self,
        texts: List[str],
        task_type: Optional[str] = None,
    ) -> List[List[float]]:
        """Embed one batch, retrying forever with capped exponential backoff.

        NOTE(review): this loop never gives up on its own; it relies on
        ``_embed_batch_gemini`` eventually succeeding.
        """
        round_num = 0
        while True:
            try:
                return self._embed_batch_gemini(texts, task_type=task_type)
            except Exception as exc:
                round_num += 1
                delay = min(
                    EMBED_RETRY_BASE_DELAY * (2 ** (round_num - 1)),
                    MAX_EMBED_DELAY,
                )
                logger.warning(
                    "Gemini embed failed (round %d), retrying in %.1fs: %s",
                    round_num,
                    delay,
                    exc,
                )
                time.sleep(delay)

    def _embed_batch_gemini(
        self,
        texts: List[str],
        task_type: Optional[str] = None,
    ) -> List[List[float]]:
        """Embed a batch of texts via the native Gemini API (shared key pool).

        Rotates to a fresh key on every 429 and falls back to the paid key
        after ``PAID_KEY_FALLBACK_THRESHOLD`` consecutive 429s; if the paid
        key is also rate-limited, switches to OpenRouter.

        Args:
            texts: Texts to embed.
            task_type: Optional Gemini task type (e.g.
                ``RETRIEVAL_DOCUMENT``).

        Raises:
            RuntimeError: After ``max_attempts`` failed attempts.
        """
        if check_openrouter_only_sync():
            logger.info("OpenRouter-only mode (sync) — bypassing Gemini for %d texts", len(texts))
            return openrouter_embed_batch_sync(texts, model=self.model)

        gemini_model = _gemini_model_name(self.model)
        requests_list = []
        for t in texts:
            req: dict = {
                "model": f"models/{gemini_model}",
                "content": {"parts": [{"text": t}]},
                # Fix: honour the configured dimensionality instead of a
                # hard-coded 3072 (default value is unchanged).
                "output_dimensionality": self.dimensions,
            }
            if task_type:
                req["taskType"] = task_type
            requests_list.append(req)
        payload = {"requests": requests_list}

        client = self._get_client()
        consecutive_429 = 0
        switched_to_paid = False
        max_attempts = 20
        last_error: str | None = None
        for attempt in range(max_attempts):
            api_key = next_gemini_embed_key()
            url = (
                f"{GEMINI_EMBED_BASE}/{gemini_model}:batchEmbedContents"
                f"?key={api_key}"
            )
            if attempt > 0:
                # Flat 1s while rotating free keys, then capped exponential
                # backoff once past the paid-fallback threshold.
                delay = (
                    1.0
                    if attempt <= PAID_KEY_FALLBACK_THRESHOLD
                    else min(
                        EMBED_RETRY_BASE_DELAY
                        * (2 ** (attempt - PAID_KEY_FALLBACK_THRESHOLD - 1)),
                        MAX_EMBED_DELAY,
                    )
                )
                time.sleep(delay)
            try:
                response = client.post(url, json=payload)
                if response.status_code == 429:
                    consecutive_429 += 1
                    if (
                        consecutive_429 >= PAID_KEY_FALLBACK_THRESHOLD
                        and not switched_to_paid
                    ):
                        # NOTE(review): the paid key is never substituted into
                        # the request URL here — presumably
                        # get_paid_fallback_key() routes it through the pool so
                        # a later next_gemini_embed_key() returns it; confirm.
                        paid = get_paid_fallback_key()
                        if paid:
                            logger.warning(
                                "Switching to paid Gemini key after %d "
                                "consecutive 429s",
                                consecutive_429,
                            )
                            switched_to_paid = True
                    elif switched_to_paid:
                        # Sync path cannot await set_openrouter_only(), so it
                        # flips the pool's private flag in-process only.
                        # NOTE(review): reaches into a private attribute of
                        # gemini_embed_pool — consider adding a sync setter.
                        import gemini_embed_pool as _gep

                        _gep._openrouter_only = True
                        logger.warning(
                            "OpenRouter-only mode ACTIVATED (sync, in-memory only)",
                        )
                        try:
                            logger.warning(
                                "Paid Gemini key 429'd — trying OpenRouter",
                            )
                            return openrouter_embed_batch_sync(
                                texts,
                                model=self.model,
                            )
                        except Exception:
                            logger.warning(
                                "OpenRouter embed fallback also failed",
                                exc_info=True,
                            )
                    last_error = "HTTP 429"
                    continue
                if response.status_code in _RETRIABLE_STATUSES:
                    last_error = f"HTTP {response.status_code}"
                    continue
                response.raise_for_status()
                data = response.json()
                return [item["values"] for item in data["embeddings"]]
            except Exception as exc:
                # Network errors and non-2xx raise_for_status errors are
                # recorded and retried; the final error is surfaced below.
                last_error = str(exc)
        # Fix: raising here (after the loop) guarantees the method can never
        # fall through and implicitly return None once attempts are exhausted.
        raise RuntimeError(
            f"Gemini embed failed after {max_attempts} attempts: {last_error}"
        )