"""Shared Gemini API key pool for rate-limit distribution.
All embedding calls use next_gemini_embed_key() and all flash-lite generation
calls use next_gemini_flash_key() for round-robin key selection. Both
accessors draw from the same underlying key pool but maintain independent
rotation cycles so they don't interfere with each other.
Supports GEMINI_EMBED_KEY_POOL env var (comma-separated keys) with fallback
to the default hardcoded pool.
When the free pool is exhausted (429s exceed PAID_KEY_FALLBACK_THRESHOLD),
callers should switch to the paid key returned by get_paid_fallback_key().
Daily quota tracking
--------------------
Quotas are tracked **per model class** (``"embed"`` vs ``"generate"``).
Embedding RPD and generation RPD are separate quotas on the same key, so a
key spent for embeddings can still serve flash-lite generation and vice versa.
Each key's usage is tracked in Redis and two in-memory spent sets. Keys that
receive a daily-quota 429 (``PerDay`` in the ``quotaId``) are excluded from
rotation for the relevant model class until midnight Pacific Time. A
background probe every 2 hours detects keys exhausted by external usage.
"""
from __future__ import annotations
import asyncio
import itertools
import logging
import os
import threading
from datetime import datetime, timezone, timedelta
from typing import Any, Literal
import httpx
# Model classes with independent daily quotas on the same key.
ModelClass = Literal["embed", "generate"]

logger = logging.getLogger(__name__)

# Base URL for the native Gemini API (embedding and generateContent calls).
GEMINI_EMBED_BASE = "https://generativelanguage.googleapis.com/v1beta/models"
# Always request full-resolution 3072-dimension embeddings.
# gemini-embedding-001 at max dimensions is pre-normalized by Google.
EMBED_DIMENSIONS = 3072
# After this many consecutive 429 responses, switch to the paid key.
PAID_KEY_FALLBACK_THRESHOLD = 4
# Timezone for daily quota reset (midnight Pacific). Use the IANA zone so
# DST transitions (PST/PDT) are handled automatically instead of a fixed
# offset; fall back to a fixed PDT offset if tz data is unavailable.
try:
    from zoneinfo import ZoneInfo
    _PACIFIC = ZoneInfo("America/Los_Angeles")
except Exception:  # pragma: no cover - zoneinfo/tzdata missing on this host
    _PACIFIC = timezone(timedelta(hours=-7))  # PDT approximation
# Redis key prefix for daily usage tracking
_REDIS_KEY_PREFIX = "gemini_key_daily_usage"
# Redis keys for user-donated embedding keys
DONATED_KEYS_SET = "gemini_embed_pool:donated_keys"
DONATED_KEYS_META = "gemini_embed_pool:donated_meta"
# Default round-robin pool of Gemini API keys for embedding rate-limit
# distribution. Deduplicated: a repeated key would silently get double
# rotation weight (the original list contained ...ITsNsZ8 twice).
# SECURITY NOTE: hardcoded API keys in source are a credential-leak risk;
# prefer supplying keys via the GEMINI_EMBED_KEY_POOL env var and rotating
# these defaults out of version control.
_DEFAULT_KEY_POOL = [
    "AIzaSyDSwQzMMYS2Cyu1XRHHKpoCkLik6XM3q9E",
    "AIzaSyBLf0DpYMCDlNB0ChAE_uDDr_3Fq4k60uo",
    "AIzaSyDHI-vLbCFUcXQ7xdV7nrr-Jqbpo0Sg1bE",
    "AIzaSyBr5tROOuolC1uyZZIW-z5z_p45tfRlx1c",
    "AIzaSyBRYPBboNGfLGzY0YyLptr98GnIK38PF_0",
    "AIzaSyBwNI9aJG0HZ_AMiw2S5SWEHWaDZu0h0es",
    "AIzaSyDEK3elQDZoTx5HRFXhsmuCqIe-ITsNsZ8",
    "AIzaSyDLrc8-Rx3RvrIT5Zz3T-efMcEzxNX1MUc",
    "AIzaSyBvOzFRDmw5W3Kg89mxPQgqLf2JXk6JCy0",
    "AIzaSyDZ4siJzWHLwdajN94TvdJ5gXcu6bfKyyw",
    "AIzaSyA1lV0_zn7pIIbvvTdu4ukbQYRXFc70MF4",
    "AIzaSyDsjV9CA3sDIzH1O7TkPg0VWfiCw4lOILo"
]
# Paid-tier fallback key used once the free pool is exhausted
# (overridable via the GEMINI_EMBED_PAID_KEY env var).
_DEFAULT_PAID_KEY = "AIzaSyCCwz9WCsIKSWsfufU6E-JbPsP1acLhZTU"
def _get_key_pool() -> list[str]:
    """Return the embedding key pool from the environment or the default.

    Reads ``GEMINI_EMBED_KEY_POOL`` (comma-separated keys), falling back to
    the hardcoded default pool. Duplicates are removed while preserving
    first-seen order so no key silently receives extra rotation weight.
    """
    env = os.environ.get("GEMINI_EMBED_KEY_POOL", "").strip()
    if env:
        keys = [k.strip() for k in env.split(",") if k.strip()]
        if keys:
            # dict.fromkeys dedupes while keeping insertion order.
            return list(dict.fromkeys(keys))
    return list(dict.fromkeys(_DEFAULT_KEY_POOL))
[docs]
def get_paid_fallback_key() -> str | None:
"""Return the paid Gemini API key for embedding fallback, or *None*."""
key = os.environ.get("GEMINI_EMBED_PAID_KEY", "").strip()
return key or _DEFAULT_PAID_KEY
# Live pool; may be extended at runtime by reload_pool() with donated keys.
_GEMINI_EMBED_KEY_POOL = _get_key_pool()
# Independent round-robin cycles over the same pool — one for embedding
# calls, one for flash-lite generation — each guarded by its own lock so
# the two rotations never block or skew each other.
_gemini_embed_key_cycle = itertools.cycle(_GEMINI_EMBED_KEY_POOL)
_gemini_embed_key_lock = threading.Lock()
_gemini_flash_key_cycle = itertools.cycle(_GEMINI_EMBED_KEY_POOL)
_gemini_flash_key_lock = threading.Lock()
# ---------------------------------------------------------------------------
# Daily quota tracking state (per model class)
# ---------------------------------------------------------------------------
# Async Redis client injected via init_quota_tracking(); while None, all
# Redis persistence/lookups are silently skipped.
_redis_client: Any | None = None
# Separate spent sets per model class so that a key exhausted for embeddings
# can still serve generation requests and vice versa.
_spent_keys_embed: set[str] = set()
_spent_keys_generate: set[str] = set()
# Guards both spent sets; every read/write of them happens under this lock.
_spent_keys_lock = threading.Lock()
# Dispatch table from model class to its spent set. The sets are shared by
# reference, so mutations through this mapping are visible module-wide.
_MODEL_CLASS_SETS: dict[ModelClass, set[str]] = {
    "embed": _spent_keys_embed,
    "generate": _spent_keys_generate,
}
# Track the current "quota day" so we know when midnight PT has passed.
_current_quota_day: str = datetime.now(_PACIFIC).strftime("%Y-%m-%d")
def _key_suffix(api_key: str) -> str:
    """Short identifier for *api_key*: its trailing 8 characters.

    Safe to embed in Redis field names and log lines without exposing
    the full credential.
    """
    tail_length = 8
    return api_key[-tail_length:]
async def reload_pool() -> int:
    """Merge donated keys from Redis into the live pool and rebuild cycles.

    Donated keys (members of ``DONATED_KEYS_SET``) are appended after the
    base pool, skipping any that duplicate a base key. Both rotation
    cycles are rebuilt under their respective locks so concurrent
    ``next_*`` calls stay consistent.

    Returns the new total pool size. Safe to call multiple times.
    """
    global _GEMINI_EMBED_KEY_POOL, _gemini_embed_key_cycle, _gemini_flash_key_cycle
    base = _get_key_pool()
    donated: list[str] = []
    if _redis_client is not None:
        try:
            raw_members = await _redis_client.smembers(DONATED_KEYS_SET)
            for m in raw_members:
                # Redis may return bytes or str depending on decode_responses.
                k = m if isinstance(m, str) else m.decode()
                if k and k not in base:
                    donated.append(k)
        except Exception:
            logger.warning("Failed to load donated embed keys from Redis", exc_info=True)
    merged = base + donated
    with _gemini_embed_key_lock:
        _GEMINI_EMBED_KEY_POOL = merged
        _gemini_embed_key_cycle = itertools.cycle(merged)
    with _gemini_flash_key_lock:
        _gemini_flash_key_cycle = itertools.cycle(merged)
    if donated:
        logger.info(
            "Embed key pool reloaded: %d base + %d donated = %d total",
            len(base), len(donated), len(merged),
        )
    return len(merged)
def init_quota_tracking(redis_client: Any) -> None:
    """Provide the async Redis client for daily quota persistence.

    Must be called once at startup (from ``main.py``) before any
    embedding calls are made. Call :func:`reload_pool` afterwards
    (from an async context) to merge donated keys.
    """
    global _redis_client
    _redis_client = redis_client
    logger.info(
        "Gemini key pool quota tracking initialised (%d keys)",
        len(_GEMINI_EMBED_KEY_POOL),
    )
def _seconds_until_midnight_pt() -> int:
    """Seconds from now until the next midnight Pacific Time (minimum 1).

    Used as the TTL for Redis counters and spent flags so they expire
    exactly when the daily quota resets.
    """
    now = datetime.now(_PACIFIC)
    tomorrow = now + timedelta(days=1)
    next_midnight = tomorrow.replace(hour=0, minute=0, second=0, microsecond=0)
    remaining = (next_midnight - now).total_seconds()
    return max(int(remaining), 1)
def _maybe_reset_day() -> None:
    """Clear both in-memory spent sets once the Pacific date rolls over."""
    global _current_quota_day
    today = datetime.now(_PACIFIC).strftime("%Y-%m-%d")
    if today == _current_quota_day:
        return
    # New quota day: every key gets a fresh daily allowance.
    with _spent_keys_lock:
        _spent_keys_embed.clear()
        _spent_keys_generate.clear()
        _current_quota_day = today
    logger.info("Gemini quota day rolled over to %s – spent keys cleared", today)
def _next_active_key(
    cycle: itertools.cycle,  # type: ignore[type-arg]
    lock: threading.Lock,
    model_class: ModelClass,
) -> str:
    """Return the next non-spent key from *cycle*, or the paid fallback.

    Only checks the spent set for *model_class* so that a key exhausted
    for embeddings can still be returned for generation and vice versa.

    Raises RuntimeError when every free key is spent and no paid fallback
    key is configured.
    """
    _maybe_reset_day()
    spent_set = _MODEL_CLASS_SETS[model_class]
    # Snapshot of the pool size bounds the scan to one full sweep; if the
    # pool grows via reload_pool() mid-scan the new keys are picked up on
    # the next call.
    pool_size = len(_GEMINI_EMBED_KEY_POOL)
    with lock:
        # Lock order is always rotation-lock -> _spent_keys_lock; keep it
        # that way everywhere to avoid deadlock.
        for _ in range(pool_size):
            key = next(cycle)
            with _spent_keys_lock:
                if key not in spent_set:
                    return key
    # One full sweep found no active key: the free tier is exhausted for
    # this model class today.
    paid = get_paid_fallback_key()
    if paid:
        logger.warning(
            "All free Gemini keys spent for %s today; using paid key",
            model_class,
        )
        return paid
    raise RuntimeError(
        f"All Gemini keys are daily-spent for {model_class} "
        "and no paid fallback is set"
    )
def next_gemini_embed_key() -> str:
    """Thread-safe round-robin selection from the Gemini embedding key pool.

    Skips keys that have been marked as daily-spent for embeddings;
    falls back to the paid key (or raises) once every free key is spent.
    """
    return _next_active_key(
        _gemini_embed_key_cycle, _gemini_embed_key_lock, "embed",
    )
def next_gemini_flash_key() -> str:
    """Thread-safe round-robin selection for Gemini flash-lite generation calls.

    Uses the same key pool as embeddings but an independent cycle so
    embed and generation rotations don't interfere. Skips keys that
    have been marked as daily-spent for generation.
    """
    return _next_active_key(
        _gemini_flash_key_cycle, _gemini_flash_key_lock, "generate",
    )
# ---------------------------------------------------------------------------
# 429 parsing & spent-key management
# ---------------------------------------------------------------------------
# Model-name prefixes that appear in quotaDimensions.model for each class.
_EMBED_MODEL_PREFIXES = ("gemini-embedding",)
# NOTE: currently documentation-only — is_daily_quota_429_for_model treats
# every non-embedding model as "generate" rather than matching this tuple.
_GENERATE_MODEL_PREFIXES = ("gemini-",)  # catch-all for non-embedding models


def is_daily_quota_429(resp: httpx.Response) -> bool:
    """Return True if *resp* is a 429 caused by a daily (RPD) quota limit.

    Parses the structured ``QuotaFailure`` violations in the response body
    and looks for a ``quotaId`` containing ``PerDay``. Any parse failure
    (non-JSON body, unexpected shape) is treated as "not a daily 429".
    """
    if resp.status_code != 429:
        return False
    try:
        body = resp.json()
        for detail in body.get("error", {}).get("details", []):
            for violation in detail.get("violations", []):
                quota_id = violation.get("quotaId", "")
                if "PerDay" in quota_id:
                    return True
    except Exception:
        pass
    return False


def is_daily_quota_429_for_model(resp: httpx.Response) -> ModelClass | None:
    """Identify which model class a daily-quota 429 belongs to.

    Returns ``"embed"`` or ``"generate"`` based on the
    ``quotaDimensions.model`` field, or *None* if the response is not a
    daily 429. Any model that is not an embedding model counts as
    generation.
    """
    if resp.status_code != 429:
        return None
    try:
        body = resp.json()
        for detail in body.get("error", {}).get("details", []):
            for violation in detail.get("violations", []):
                quota_id = violation.get("quotaId", "")
                if "PerDay" not in quota_id:
                    continue
                dims = violation.get("quotaDimensions", {})
                model_name = dims.get("model", "")
                if model_name.startswith(_EMBED_MODEL_PREFIXES):
                    return "embed"
                return "generate"
    except Exception:
        pass
    return None
async def record_key_usage(api_key: str) -> None:
    """Increment the daily request counter for *api_key* in Redis.

    Best-effort: a missing Redis client or a Redis error is ignored so
    usage accounting never blocks an embedding/generation call. The
    counter expires at the next midnight Pacific so each day starts at
    zero.
    """
    if _redis_client is None:
        return
    suffix = _key_suffix(api_key)
    redis_key = f"{_REDIS_KEY_PREFIX}:{suffix}:count"
    try:
        pipe = _redis_client.pipeline(transaction=False)
        pipe.incr(redis_key)
        pipe.expire(redis_key, _seconds_until_midnight_pt())
        await pipe.execute()
    except Exception:
        # Include the traceback at debug level for diagnosability.
        logger.debug(
            "Failed to record Gemini key usage for ...%s", suffix, exc_info=True,
        )
async def mark_key_daily_spent(
    api_key: str,
    model_class: ModelClass = "embed",
) -> None:
    """Mark *api_key* as daily-spent for *model_class*.

    Updates the in-memory spent set immediately, then best-effort persists
    a Redis flag (plus a ``spent_at`` timestamp) that expires at the next
    midnight Pacific so other workers pick up the exclusion.
    """
    suffix = _key_suffix(api_key)
    spent_set = _MODEL_CLASS_SETS[model_class]
    with _spent_keys_lock:
        spent_set.add(api_key)
    logger.warning(
        "Gemini key ...%s marked as daily-spent for %s", suffix, model_class,
    )
    if _redis_client is None:
        return
    try:
        ttl = _seconds_until_midnight_pt()
        pipe = _redis_client.pipeline(transaction=False)
        pipe.set(
            f"{_REDIS_KEY_PREFIX}:{suffix}:spent:{model_class}", "1", ex=ttl,
        )
        pipe.set(
            f"{_REDIS_KEY_PREFIX}:{suffix}:spent_at:{model_class}",
            datetime.now(timezone.utc).isoformat(),
            ex=ttl,
        )
        await pipe.execute()
    except Exception:
        logger.debug(
            "Failed to persist spent flag for ...%s/%s",
            suffix, model_class, exc_info=True,
        )
async def sync_spent_keys_from_redis() -> None:
    """Refresh both in-memory spent sets from Redis.

    Rebuilds the sets from the per-model flags, also honouring the legacy
    un-suffixed ``:spent`` flag (which counts for both classes). On any
    Redis error the current in-memory state is left untouched rather than
    partially overwritten.
    """
    if _redis_client is None:
        return
    _maybe_reset_day()
    new_embed: set[str] = set()
    new_generate: set[str] = set()
    try:
        for key in _GEMINI_EMBED_KEY_POOL:
            suffix = _key_suffix(key)
            # Check per-model flags
            val_e = await _redis_client.get(
                f"{_REDIS_KEY_PREFIX}:{suffix}:spent:embed",
            )
            if val_e is not None:
                new_embed.add(key)
            val_g = await _redis_client.get(
                f"{_REDIS_KEY_PREFIX}:{suffix}:spent:generate",
            )
            if val_g is not None:
                new_generate.add(key)
            # Backward compat: old `:spent` key (no model suffix) counts for both
            val_old = await _redis_client.get(
                f"{_REDIS_KEY_PREFIX}:{suffix}:spent",
            )
            if val_old is not None:
                new_embed.add(key)
                new_generate.add(key)
    except Exception:
        logger.debug("Failed to sync spent keys from Redis", exc_info=True)
        return
    # Swap in the freshly-built sets atomically under the lock.
    with _spent_keys_lock:
        _spent_keys_embed.clear()
        _spent_keys_embed.update(new_embed)
        _spent_keys_generate.clear()
        _spent_keys_generate.update(new_generate)
    if new_embed or new_generate:
        logger.info(
            "Synced spent keys from Redis: embed=%d/%d, generate=%d/%d",
            len(new_embed), len(_GEMINI_EMBED_KEY_POOL),
            len(new_generate), len(_GEMINI_EMBED_KEY_POOL),
        )
async def get_pool_status() -> dict[str, Any]:
    """Return a summary of pool health for diagnostics / logging.

    Includes active/spent counts per model class and, when Redis is
    available, today's per-key request counts keyed by ``...<suffix>``
    (possibly partial if Redis errors mid-scan).
    """
    _maybe_reset_day()
    with _spent_keys_lock:
        embed_spent = len(_spent_keys_embed)
        generate_spent = len(_spent_keys_generate)
    total = len(_GEMINI_EMBED_KEY_POOL)
    result: dict[str, Any] = {
        "total_keys": total,
        "embed_active": total - embed_spent,
        "embed_spent": embed_spent,
        "generate_active": total - generate_spent,
        "generate_spent": generate_spent,
        "quota_day": _current_quota_day,
    }
    if _redis_client is not None:
        counts: dict[str, int] = {}
        try:
            for key in _GEMINI_EMBED_KEY_POOL:
                suffix = _key_suffix(key)
                raw = await _redis_client.get(
                    f"{_REDIS_KEY_PREFIX}:{suffix}:count",
                )
                counts[f"...{suffix}"] = int(raw) if raw else 0
        except Exception:
            # Best-effort: keep whatever counts were gathered, but leave a
            # trace instead of swallowing the error silently.
            logger.debug("Failed to read usage counts from Redis", exc_info=True)
        result["usage_counts"] = counts
    return result
# ---------------------------------------------------------------------------
# Background probe -- runs every 2 hours
# ---------------------------------------------------------------------------
# Cheap generation model used only to detect daily-quota exhaustion.
_PROBE_GENERATE_MODEL = "gemini-3.1-flash-lite-preview"


async def probe_all_keys() -> None:
    """Probe each non-spent key for embedding and generation daily exhaustion.

    Sends one minimal request per endpoint per key and marks the key spent
    for the relevant model class on a daily-quota 429. Keys already known
    to be spent (after a Redis sync) are skipped to conserve quota. This
    catches keys exhausted by usage outside this process.
    """
    _maybe_reset_day()
    await sync_spent_keys_from_redis()
    with _spent_keys_lock:
        already_embed = set(_spent_keys_embed)
        already_generate = set(_spent_keys_generate)
    newly_embed = 0
    newly_generate = 0
    async with httpx.AsyncClient(timeout=30.0) as client:
        for key in _GEMINI_EMBED_KEY_POOL:
            suffix = _key_suffix(key)
            # --- Embedding probe ---
            if key not in already_embed:
                url = (
                    f"{GEMINI_EMBED_BASE}/gemini-embedding-001"
                    f":batchEmbedContents?key={key}"
                )
                payload = {
                    "requests": [{
                        "model": "models/gemini-embedding-001",
                        "content": {"parts": [{"text": "probe"}]},
                        # Minimal dimensionality — the probe only checks quota.
                        "output_dimensionality": 768,
                    }],
                }
                try:
                    resp = await client.post(url, json=payload)
                    await record_key_usage(key)
                    if is_daily_quota_429(resp):
                        await mark_key_daily_spent(key, "embed")
                        newly_embed += 1
                except Exception:
                    logger.debug(
                        "Embed probe failed for ...%s", suffix, exc_info=True,
                    )
            # --- Generation probe ---
            if key not in already_generate:
                url = (
                    f"{GEMINI_EMBED_BASE}/{_PROBE_GENERATE_MODEL}"
                    f":generateContent?key={key}"
                )
                payload = {
                    "contents": [{"parts": [{"text": "probe"}]}],
                    "generationConfig": {"maxOutputTokens": 1},
                }
                try:
                    resp = await client.post(url, json=payload)
                    await record_key_usage(key)
                    if is_daily_quota_429(resp):
                        await mark_key_daily_spent(key, "generate")
                        newly_generate += 1
                except Exception:
                    logger.debug(
                        "Generate probe failed for ...%s", suffix, exc_info=True,
                    )
    total = len(_GEMINI_EMBED_KEY_POOL)
    logger.info(
        "Gemini key probe complete (%d keys): "
        "embed %d active / %d newly spent / %d prev spent, "
        "generate %d active / %d newly spent / %d prev spent",
        total,
        total - len(already_embed) - newly_embed, newly_embed, len(already_embed),
        total - len(already_generate) - newly_generate, newly_generate, len(already_generate),
    )
def _gemini_model_name(model: str) -> str:
    """Strip an optional ``google/`` vendor prefix from *model*.

    E.g. ``google/gemini-embedding-001`` -> ``gemini-embedding-001``;
    names without the prefix pass through unchanged.
    """
    prefix = "google/"
    if model.startswith(prefix):
        return model[len(prefix):]
    return model
# ---------------------------------------------------------------------------
# OpenRouter embedding fallback (last resort when paid Gemini key 429s)
# ---------------------------------------------------------------------------
_OPENROUTER_EMBED_URL = "https://openrouter.ai/api/v1/embeddings"
_DEFAULT_OPENROUTER_KEY = "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"
[docs]
def get_openrouter_api_key() -> str | None:
"""Return the OpenRouter API key from environment or default."""
key = (
os.environ.get("OPENROUTER_API_KEY", "").strip()
or os.environ.get("API_KEY", "").strip()
)
return key or _DEFAULT_OPENROUTER_KEY
async def openrouter_embed_batch(
    texts: list[str],
    *,
    model: str = "google/gemini-embedding-001",
    api_key: str | None = None,
    dimensions: int = EMBED_DIMENSIONS,
) -> list[list[float]]:
    """Embed texts via the OpenRouter /embeddings endpoint (async).

    Used as a last-resort fallback when all Gemini keys (including the
    paid key) are rate-limited. Returns one vector per input text, in
    input order (the response is re-sorted by its ``index`` field).

    Raises:
        RuntimeError: if no API key is configured or the endpoint returns
            a non-200 status.
    """
    key = api_key or get_openrouter_api_key()
    if not key:
        raise RuntimeError(
            "OpenRouter fallback unavailable: no API key "
            "(set OPENROUTER_API_KEY or API_KEY)"
        )
    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }
    payload: dict = {"model": model, "input": texts}
    if dimensions:
        payload["dimensions"] = dimensions
    async with httpx.AsyncClient(timeout=60.0) as client:
        resp = await client.post(
            _OPENROUTER_EMBED_URL, json=payload, headers=headers,
        )
    if resp.status_code != 200:
        body = resp.text[:500]
        raise RuntimeError(
            f"OpenRouter embed fallback failed "
            f"(HTTP {resp.status_code}): {body}"
        )
    data = resp.json()
    # Responses are not guaranteed to arrive in input order.
    items = sorted(data["data"], key=lambda x: x["index"])
    return [item["embedding"] for item in items]
def openrouter_embed_batch_sync(
    texts: list[str],
    *,
    model: str = "google/gemini-embedding-001",
    api_key: str | None = None,
    dimensions: int = EMBED_DIMENSIONS,
) -> list[list[float]]:
    """Embed texts via the OpenRouter /embeddings endpoint (sync).

    Synchronous twin of :func:`openrouter_embed_batch` for callers that
    cannot use ``await`` (e.g. ChromaDB embedding functions). Keep the two
    implementations in lockstep.

    Raises:
        RuntimeError: if no API key is configured or the endpoint returns
            a non-200 status.
    """
    key = api_key or get_openrouter_api_key()
    if not key:
        raise RuntimeError(
            "OpenRouter fallback unavailable: no API key "
            "(set OPENROUTER_API_KEY or API_KEY)"
        )
    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }
    payload: dict = {"model": model, "input": texts}
    if dimensions:
        payload["dimensions"] = dimensions
    with httpx.Client(timeout=60.0) as client:
        resp = client.post(
            _OPENROUTER_EMBED_URL, json=payload, headers=headers,
        )
    if resp.status_code != 200:
        body = resp.text[:500]
        raise RuntimeError(
            f"OpenRouter embed fallback failed "
            f"(HTTP {resp.status_code}): {body}"
        )
    data = resp.json()
    # Responses are not guaranteed to arrive in input order.
    items = sorted(data["data"], key=lambda x: x["index"])
    return [item["embedding"] for item in items]
# ---------------------------------------------------------------------------
# "OpenRouter-only" mode (set when paid Gemini key 429s)
# ---------------------------------------------------------------------------
# Redis key whose presence activates the mode for all workers.
_OPENROUTER_ONLY_KEY = "embed:openrouter_only"
_OPENROUTER_ONLY_TTL = 4 * 3600  # 4 hours
# Process-local cache of the flag, refreshed by the check_* helpers.
_openrouter_only: bool = False
# Lazily-created sync Redis client for check_openrouter_only_sync().
_sync_redis_client: Any | None = None
def is_openrouter_only() -> bool:
    """Return the in-memory OpenRouter-only flag (no I/O).

    May lag the shared Redis state; use :func:`check_openrouter_only`
    (or the sync variant) when freshness matters.
    """
    return _openrouter_only
async def check_openrouter_only() -> bool:
    """Return whether OpenRouter-only mode is active.

    Always checks Redis before each embedding call so the mode expires
    when the Redis TTL lapses or is manually disabled. On a Redis error
    (or when no client is configured) the flag is conservatively reset
    to off.
    """
    global _openrouter_only
    if _redis_client is None:
        _openrouter_only = False
        return False
    try:
        val = await _redis_client.get(_OPENROUTER_ONLY_KEY)
        _openrouter_only = val is not None
        if _openrouter_only:
            logger.info("OpenRouter-only flag detected in Redis")
        return _openrouter_only
    except Exception:
        logger.debug("Failed to check openrouter_only flag", exc_info=True)
    _openrouter_only = False
    return False
def check_openrouter_only_sync() -> bool:
    """Sync variant: check Redis before each embedding call.

    Used by SyncOpenRouterEmbeddings (ChromaDB path). Lazy-creates a sync
    Redis client from config on first use; if the client cannot be
    created or the lookup fails, the last known in-memory value is
    returned unchanged.
    """
    global _openrouter_only, _sync_redis_client
    if _sync_redis_client is None:
        try:
            from config import Config
            cfg = Config.load()
            if not cfg.redis_url:
                return _openrouter_only
            ssl_kwargs = cfg.redis_ssl_kwargs()
            # __import__ keeps the redis dependency optional at module load.
            _sync_redis_client = __import__("redis").Redis.from_url(
                cfg.redis_url, decode_responses=True, **ssl_kwargs
            )
        except Exception:
            logger.debug("Failed to create sync Redis for openrouter_only", exc_info=True)
            return _openrouter_only
    try:
        val = _sync_redis_client.get(_OPENROUTER_ONLY_KEY)
        _openrouter_only = val is not None
        if _openrouter_only:
            logger.info("OpenRouter-only flag detected in Redis (sync)")
        return _openrouter_only
    except Exception:
        logger.debug("Failed to check openrouter_only flag (sync)", exc_info=True)
    return _openrouter_only
async def set_openrouter_only() -> None:
    """Activate OpenRouter-only mode for embedding calls.

    Sets the in-memory flag immediately, then best-effort persists a
    Redis key with a 4-hour TTL so other workers see the mode too.
    """
    global _openrouter_only
    _openrouter_only = True
    logger.warning(
        "OpenRouter-only mode ACTIVATED for embeddings (TTL=%ds)",
        _OPENROUTER_ONLY_TTL,
    )
    if _redis_client is None:
        return
    try:
        await _redis_client.set(
            _OPENROUTER_ONLY_KEY, "1", ex=_OPENROUTER_ONLY_TTL,
        )
    except Exception:
        logger.debug("Failed to persist openrouter_only flag", exc_info=True)
async def clear_openrouter_only() -> None:
    """Deactivate OpenRouter-only mode (manual override).

    Clears the in-memory flag and best-effort deletes the Redis key so
    other workers drop out of the mode on their next check.
    """
    global _openrouter_only
    _openrouter_only = False
    logger.info("OpenRouter-only mode CLEARED for embeddings")
    if _redis_client is None:
        return
    try:
        await _redis_client.delete(_OPENROUTER_ONLY_KEY)
    except Exception:
        logger.debug("Failed to delete openrouter_only flag", exc_info=True)
# HTTP statuses treated as transient and retried with exponential backoff.
_RETRIABLE_STATUSES = {429, 500, 502, 503, 504}
# Per-chunk retry budget: up to 12 attempts, 1s base delay doubling to an 8s cap.
_POOL_MAX_RETRIES = 12
_POOL_RETRY_BASE = 1.0
_POOL_RETRY_CAP = 8.0
async def embed_batch_via_gemini(
    texts: list[str],
    model: str = "google/gemini-embedding-001",
    *,
    chunk_size: int = 50,
) -> list[list[float]]:
    """Embed a batch of texts via the native Gemini API using the shared key pool.

    Returns one embedding vector per input text, in order. Empty or
    whitespace-only texts are replaced with zero vectors. Each chunk is
    retried on transient errors with exponential backoff; keys that hit a
    daily quota are marked spent and rotated out, the paid key takes over
    after ``PAID_KEY_FALLBACK_THRESHOLD`` consecutive 429s, and OpenRouter
    is the last resort once the paid key itself is throttled.

    Raises:
        RuntimeError: if a chunk cannot be embedded within
            ``_POOL_MAX_RETRIES`` attempts.
    """
    if not texts:
        return []
    # Separate out the texts worth embedding; empties keep zero vectors.
    valid_indices: list[int] = []
    valid_texts: list[str] = []
    for i, t in enumerate(texts):
        if t and t.strip():
            valid_indices.append(i)
            valid_texts.append(t)
    if not valid_texts:
        return [[0.0] * EMBED_DIMENSIONS for _ in texts]
    if await check_openrouter_only():
        logger.info("OpenRouter-only mode — bypassing Gemini for %d texts", len(valid_texts))
        or_vecs = await openrouter_embed_batch(valid_texts, model=model)
        result: list[list[float]] = [[0.0] * EMBED_DIMENSIONS for _ in texts]
        for idx, vec in zip(valid_indices, or_vecs):
            result[idx] = vec
        return result
    gemini_model = _gemini_model_name(model)

    def _embed_url(key: str) -> str:
        # Native Gemini API passes the key as a query parameter.
        return (
            f"{GEMINI_EMBED_BASE}/{gemini_model}:batchEmbedContents"
            f"?key={key}"
        )

    result: list[list[float]] = [[0.0] * EMBED_DIMENSIONS for _ in texts]
    async with httpx.AsyncClient(timeout=60.0) as client:
        for start in range(0, len(valid_texts), chunk_size):
            batch_texts = valid_texts[start : start + chunk_size]
            batch_indices = valid_indices[start : start + chunk_size]
            requests_list = [
                {
                    "model": f"models/{gemini_model}",
                    "content": {"parts": [{"text": t}]},
                    "output_dimensionality": EMBED_DIMENSIONS,
                }
                for t in batch_texts
            ]
            payload = {"requests": requests_list}
            attempt = 0
            consecutive_429 = 0
            # FIX: rotate to a fresh (non-spent) key per chunk. Previously a
            # single key, chosen once before the loop, served every chunk —
            # hammering one key and defeating the round-robin pool.
            cur_key = next_gemini_embed_key()
            cur_url = _embed_url(cur_key)
            while attempt < _POOL_MAX_RETRIES:
                if attempt > 0:
                    # Exponential backoff: 1s, 2s, 4s, then capped at 8s.
                    delay = min(
                        _POOL_RETRY_BASE * (2 ** (attempt - 1)),
                        _POOL_RETRY_CAP,
                    )
                    await asyncio.sleep(delay)
                try:
                    resp = await client.post(cur_url, json=payload)
                except Exception:
                    # Network-level failure: retry with backoff.
                    attempt += 1
                    continue
                await record_key_usage(cur_key)
                if resp.status_code == 429:
                    if is_daily_quota_429(resp):
                        # Daily quota: mark this key spent for embeddings and
                        # switch to the next active key immediately.
                        await mark_key_daily_spent(cur_key, "embed")
                        cur_key = next_gemini_embed_key()
                        cur_url = _embed_url(cur_key)
                        attempt += 1
                        continue
                    consecutive_429 += 1
                    # FIX: compare the paid key by value, not identity — the
                    # env-derived key may be a fresh string object per call,
                    # so `is not` could never match and the OpenRouter
                    # branch below would be unreachable.
                    paid = get_paid_fallback_key()
                    if (
                        consecutive_429 >= PAID_KEY_FALLBACK_THRESHOLD
                        and cur_key != paid
                    ):
                        if paid:
                            logger.warning(
                                "Switching to paid Gemini key after %d "
                                "consecutive 429s",
                                consecutive_429,
                            )
                            cur_key = paid
                            cur_url = _embed_url(paid)
                    elif cur_key == paid:
                        # Even the paid key is throttled: flip on
                        # OpenRouter-only mode and try OpenRouter for this chunk.
                        await set_openrouter_only()
                        try:
                            logger.warning(
                                "Paid Gemini key 429'd — trying OpenRouter",
                            )
                            or_vecs = await openrouter_embed_batch(
                                batch_texts, model=model,
                            )
                            for idx, vec in enumerate(or_vecs):
                                if idx < len(batch_indices):
                                    result[batch_indices[idx]] = vec
                            break
                        except Exception:
                            logger.warning(
                                "OpenRouter embed fallback also failed",
                                exc_info=True,
                            )
                    attempt += 1
                    continue
                if resp.status_code in _RETRIABLE_STATUSES:
                    attempt += 1
                    continue
                # Any other error status is fatal for this call.
                resp.raise_for_status()
                data = resp.json()
                for idx, item in enumerate(data.get("embeddings", [])):
                    vec = item.get("values", [])
                    if idx < len(batch_indices):
                        result[batch_indices[idx]] = vec
                break
            else:
                # while-loop exhausted without a break: give up on this chunk.
                raise RuntimeError(
                    f"embed_batch_via_gemini failed after "
                    f"{_POOL_MAX_RETRIES} attempts"
                )
    return result