"""Shared Gemini API key pool for rate-limit distribution.
All embedding calls use next_gemini_embed_key() and all flash-lite generation
calls use next_gemini_flash_key() for round-robin key selection. Both
accessors draw from the same underlying key pool but maintain independent
rotation cycles so they don't interfere with each other.
Supports GEMINI_EMBED_KEY_POOL env var (comma-separated keys) with fallback
to the default hardcoded pool.
Embedding fallback cascade
--------------------------
When the free pool starts 429-ing (consecutive failures exceed
``PAID_KEY_FALLBACK_THRESHOLD``), the embed call sites cascade in this order:
1. **Free-tier pool** (round-robin via :func:`next_gemini_embed_key`).
2. **OpenRouter** (:func:`openrouter_embed_batch` / ``..._sync``) — engaged via
:func:`set_openrouter_only` so subsequent calls go straight here.
3. **Paid tier-3 Gemini key** (:func:`gemini_embed_paid_fallback` /
``..._sync``) — used as the **absolute last resort** when both the free
pool and OpenRouter have failed. Its key is returned by
:func:`get_paid_fallback_key` (env override
``GEMINI_EMBED_PAID_KEY``, default ``_DEFAULT_PAID_KEY``).
Even when ``openrouter_only`` is pinned, a caller whose OpenRouter request fails
must still fall through to the paid-key helper before raising.
OpenRouter embed fallback (``openrouter_embed_batch`` / ``openrouter_embed_batch_sync``)
retries transient network errors and HTTP 429 / 5xx with exponential backoff.
Tune with env: ``OPENROUTER_EMBED_MAX_ATTEMPTS`` (default 24),
``OPENROUTER_EMBED_RETRY_BASE_DELAY``, ``OPENROUTER_EMBED_RETRY_MAX_DELAY``.
Daily quota tracking
--------------------
Quotas are tracked **per model class** (``"embed"`` vs ``"generate"``).
Embedding RPD and generation RPD are separate quotas on the same key, so a
key spent for embeddings can still serve flash-lite generation and vice versa.
Each key's daily request count is tracked in Redis; daily-spent state is kept
both in Redis and in two in-memory spent sets (one per model class). Keys that
receive a daily-quota 429 (``PerDay`` in the ``quotaId``) are excluded from
rotation for the relevant model class until midnight Pacific Time. A
background probe every 2 hours detects keys exhausted by external usage.
"""
from __future__ import annotations
import asyncio
import itertools
import logging
import os
import random
import threading
import time
from datetime import datetime, timezone, timedelta
from typing import Any, Literal
import httpx
from observability import publish_http_error_event
# Model classes whose daily (RPD) quotas are tracked independently per key.
ModelClass = Literal["embed", "generate"]
logger = logging.getLogger(__name__)
# Base URL for the Gemini REST API (``.../models/<model>:<method>``); used for
# both the embedding and the generateContent endpoints in this module.
GEMINI_EMBED_BASE = "https://generativelanguage.googleapis.com/v1beta/models"
# Always request full-resolution 3072-dimension embeddings.
# gemini-embedding-001 at max dimensions is pre-normalized by Google.
EMBED_DIMENSIONS = 3072
# After this many consecutive non-daily 429 responses on the free pool,
# escalate to the next tier of the cascade.
#
# NOTE: The name is historical. Under the current cascade
# (free pool -> OpenRouter -> paid key), reaching this threshold escalates to
# OpenRouter first; the paid key is only tried as a last resort after
# OpenRouter itself has failed. The name ``PAID_KEY_FALLBACK_THRESHOLD`` is
# kept to avoid touching every call site.
PAID_KEY_FALLBACK_THRESHOLD = 4
# Timezone for daily quota reset (midnight Pacific)
_PACIFIC = timezone(timedelta(hours=-7)) # PDT; adjust to -8 for PST if needed
# Redis key prefix for daily usage tracking
_REDIS_KEY_PREFIX = "gemini_key_daily_usage"
# Redis keys for user-donated embedding keys
DONATED_KEYS_SET = "gemini_embed_pool:donated_keys"
DONATED_KEYS_META = "gemini_embed_pool:donated_meta"
# Default round-robin pool of Gemini API keys for embedding rate-limit distribution
_DEFAULT_KEY_POOL = [
"AIzaSyDSwQzMMYS2Cyu1XRHHKpoCkLik6XM3q9E",
"AIzaSyBLf0DpYMCDlNB0ChAE_uDDr_3Fq4k60uo",
"AIzaSyDHI-vLbCFUcXQ7xdV7nrr-Jqbpo0Sg1bE",
"AIzaSyBr5tROOuolC1uyZZIW-z5z_p45tfRlx1c",
"AIzaSyBRYPBboNGfLGzY0YyLptr98GnIK38PF_0",
"AIzaSyBwNI9aJG0HZ_AMiw2S5SWEHWaDZu0h0es",
"AIzaSyDEK3elQDZoTx5HRFXhsmuCqIe-ITsNsZ8",
"AIzaSyDLrc8-Rx3RvrIT5Zz3T-efMcEzxNX1MUc",
"AIzaSyDEK3elQDZoTx5HRFXhsmuCqIe-ITsNsZ8",
"AIzaSyBvOzFRDmw5W3Kg89mxPQgqLf2JXk6JCy0",
"AIzaSyDZ4siJzWHLwdajN94TvdJ5gXcu6bfKyyw",
"AIzaSyA1lV0_zn7pIIbvvTdu4ukbQYRXFc70MF4",
"AIzaSyDsjV9CA3sDIzH1O7TkPg0VWfiCw4lOILo",
]
_DEFAULT_PAID_KEY = "AIzaSyCCwz9WCsIKSWsfufU6E-JbPsP1acLhZTU"
def _get_key_pool() -> list[str]:
"""Return the base round-robin Gemini key pool from env or the hardcoded default.
Reads ``GEMINI_EMBED_KEY_POOL`` (comma-separated) and falls back to
``_DEFAULT_KEY_POOL`` when unset or empty, so deployments can swap the free
pool without code changes. This is the *base* pool only; donated keys from
Redis are merged on top by :func:`reload_pool`.
Called at module import to seed ``_GEMINI_EMBED_KEY_POOL`` and the cycles,
again inside :func:`reload_pool` to recompute the merged pool, and directly
by ``scripts/embeddings_redis_cli.py``.
Returns:
list[str]: The base API keys, in pool order.
"""
env = os.environ.get("GEMINI_EMBED_KEY_POOL", "").strip()
if env:
keys = [k.strip() for k in env.split(",") if k.strip()]
if keys:
return keys
return _DEFAULT_KEY_POOL
[docs]
def get_paid_fallback_key() -> str | None:
    """Return the paid tier-3 Gemini API key used as the absolute last-resort fallback.

    Reads the ``GEMINI_EMBED_PAID_KEY`` env override (whitespace-stripped) and
    falls back to ``_DEFAULT_PAID_KEY``. This key is only reached once the free
    pool and OpenRouter have both failed. :func:`_next_active_key` also hands
    it out when every free key is daily-spent.

    Called within this module by :func:`_next_active_key`,
    :func:`gemini_embed_paid_fallback`, and :func:`gemini_embed_paid_fallback_sync`,
    and externally by ``gemini_kg_bulk_client.py`` and
    ``classifiers/build_tool_index.py`` for their own paid-key fallbacks.

    Returns:
        str | None: The paid key, or *None* if neither the env var nor the
        default provides a non-empty value.
    """
    key = os.environ.get("GEMINI_EMBED_PAID_KEY", "").strip()
    # The trailing ``or None`` honors the documented contract: an empty default
    # (e.g. once the hardcoded key is removed) yields None rather than "".
    return key or _DEFAULT_PAID_KEY or None
# Live key pool: the base pool at import time; reload_pool() later rebinds it
# to the base pool plus Redis-donated keys.
_GEMINI_EMBED_KEY_POOL = _get_key_pool()
# Embedding and flash-lite generation each get their own itertools.cycle over
# the pool, each guarded by its own lock, so the two rotations advance
# independently (reload_pool() rebuilds both).
_gemini_embed_key_cycle = itertools.cycle(_GEMINI_EMBED_KEY_POOL)
_gemini_embed_key_lock = threading.Lock()
_gemini_flash_key_cycle = itertools.cycle(_GEMINI_EMBED_KEY_POOL)
_gemini_flash_key_lock = threading.Lock()
# ---------------------------------------------------------------------------
# Daily quota tracking state (per model class)
# ---------------------------------------------------------------------------
# Async Redis client used for persistence. It stays None until
# init_quota_tracking() wires it; while None, the helpers here skip their
# Redis reads/writes.
_redis_client: Any | None = None
# Separate spent sets per model class so that a key exhausted for embeddings
# can still serve generation requests and vice versa.
_spent_keys_embed: set[str] = set()
_spent_keys_generate: set[str] = set()
# Guards both spent sets (membership checks, adds, and the rollover clear).
_spent_keys_lock = threading.Lock()
# Model class -> its spent set. These are the same set objects as above, and
# every reset elsewhere clears/updates them in place, so both names stay in sync.
_MODEL_CLASS_SETS: dict[ModelClass, set[str]] = {
    "embed": _spent_keys_embed,
    "generate": _spent_keys_generate,
}
# Track the current "quota day" so we know when midnight PT has passed.
# Initialised from the Pacific clock at import time.
_current_quota_day: str = datetime.now(_PACIFIC).strftime("%Y-%m-%d")
def _key_suffix(api_key: str) -> str:
"""Return the last 8 characters of an API key for safe Redis fields and logging.
Truncating to a suffix avoids persisting or logging the full secret while
still uniquely identifying a key within the pool. Used to build the
per-key Redis fields (``gemini_key_daily_usage:<suffix>:count``,
``...:spent:<class>``) and the ``...%s`` log markers throughout this module.
Called by :func:`record_key_usage`, :func:`mark_key_daily_spent`,
:func:`sync_spent_keys_from_redis`, :func:`get_pool_status`,
:func:`probe_all_keys`, :func:`gemini_embed_paid_fallback_sync`, and
externally by ``scripts/embeddings_redis_cli.py``.
Args:
api_key: The full Gemini API key.
Returns:
str: The trailing 8 characters of *api_key*.
"""
return api_key[-8:]
[docs]
async def reload_pool() -> int:
    """Rebuild the live key pool from the base pool plus Redis-donated keys.

    The base pool comes from :func:`_get_key_pool`. Members of the
    ``DONATED_KEYS_SET`` Redis set that are non-empty and not already in the
    base pool are appended after it. Both rotation cycles (embed and flash)
    are then restarted over the merged list. Without a Redis client only the
    base pool is used. On a Redis/decoding error the failure is logged, and
    any donated keys gathered before it are still merged. Safe to call
    repeatedly.

    Returns:
        int: The size of the merged pool.
    """
    global _GEMINI_EMBED_KEY_POOL, _gemini_embed_key_cycle, _gemini_flash_key_cycle
    base_keys = _get_key_pool()
    extra: list[str] = []
    if _redis_client is not None:
        try:
            members = await _redis_client.smembers(DONATED_KEYS_SET)
            for member in members:
                decoded = member.decode() if not isinstance(member, str) else member
                if decoded and decoded not in base_keys:
                    extra.append(decoded)
        except Exception:
            logger.warning(
                "Failed to load donated embed keys from Redis", exc_info=True
            )
    combined = [*base_keys, *extra]
    with _gemini_embed_key_lock:
        _GEMINI_EMBED_KEY_POOL = combined
        _gemini_embed_key_cycle = itertools.cycle(combined)
    with _gemini_flash_key_lock:
        _gemini_flash_key_cycle = itertools.cycle(combined)
    if extra:
        logger.info(
            "Embed key pool reloaded: %d base + %d donated = %d total",
            len(base_keys),
            len(extra),
            len(combined),
        )
    return len(combined)
[docs]
def init_quota_tracking(redis_client: Any) -> None:
    """Wire the async Redis client used to persist daily quota state.

    Call once at service startup, before any embedding traffic. This only
    stores the client and logs the current pool size. Donated keys are not
    loaded here: await :func:`reload_pool` afterwards (from an async context)
    to merge them.

    Args:
        redis_client: The async Redis client instance.
    """
    global _redis_client
    _redis_client = redis_client
    pool_len = len(_GEMINI_EMBED_KEY_POOL)
    logger.info("Gemini key pool quota tracking initialised (%d keys)", pool_len)
def _seconds_until_midnight_pt() -> int:
    """Return the seconds remaining until the next midnight Pacific Time.

    Gemini daily (RPD) quotas reset at midnight Pacific. This value is used as
    the Redis TTL for per-key usage counters and spent flags, so those keys
    expire automatically when the quota day rolls over. The wall-clock
    midnight is computed in ``_PACIFIC``. The gap is measured between absolute
    instants, so the result stays correct on DST-switch days when
    ``_PACIFIC`` is a DST-aware zone. The result is clamped to at least 1
    second.

    Called by :func:`record_key_usage`, :func:`mark_key_daily_spent`, and the
    sync paid-key path in :func:`gemini_embed_paid_fallback_sync`.

    Returns:
        int: Whole seconds until midnight Pacific, minimum 1.
    """
    now = datetime.now(_PACIFIC)
    midnight = (now + timedelta(days=1)).replace(
        hour=0,
        minute=0,
        second=0,
        microsecond=0,
    )
    # Subtracting two aware datetimes that share the same tzinfo object
    # ignores utcoffset() (a naive wall-clock difference), which is an hour
    # off on DST-switch days. Comparing POSIX timestamps uses each value's
    # own offset. For a fixed-offset zone this matches the plain subtraction.
    return max(int(midnight.timestamp() - now.timestamp()), 1)
def _maybe_reset_day() -> None:
    """Clear both in-memory spent sets when the Pacific quota day has rolled over.

    Compares the current Pacific date against the cached ``_current_quota_day``.
    On a mismatch it empties ``_spent_keys_embed`` and ``_spent_keys_generate``
    and advances the cached day under ``_spent_keys_lock``, so keys that were
    daily-spent yesterday rejoin rotation. The comparison is repeated under
    the lock, so when several threads notice the rollover at once only the
    first one clears the sets. Keys marked spent for the new day in the
    meantime are therefore not wiped. This is the in-memory complement to the
    Redis TTLs from :func:`_seconds_until_midnight_pt`.

    Called at the start of every selection/diagnostic path:
    :func:`_next_active_key`, :func:`sync_spent_keys_from_redis`,
    :func:`get_pool_status`, and :func:`probe_all_keys`.
    """
    global _current_quota_day
    today = datetime.now(_PACIFIC).strftime("%Y-%m-%d")
    if today == _current_quota_day:
        return  # fast path: no rollover, no lock needed
    with _spent_keys_lock:
        # Re-check under the lock: another thread may already have performed
        # this rollover, and fresh spent marks may have been added since.
        if today == _current_quota_day:
            return
        _spent_keys_embed.clear()
        _spent_keys_generate.clear()
        _current_quota_day = today
    logger.info("Gemini quota day rolled over to %s – spent keys cleared", today)
def _next_active_key(
    cycle: itertools.cycle,  # type: ignore[type-arg]
    lock: threading.Lock,
    model_class: ModelClass,
) -> str:
    """Pick the next key for *model_class* from *cycle*, skipping daily-spent keys.

    While holding *lock*, advances *cycle* at most once per pooled key and
    returns the first key absent from the spent set for *model_class*. Only
    that class's set is consulted, so a key spent for embeddings can still be
    returned for generation and vice versa. If every key is spent, the paid
    key from :func:`get_paid_fallback_key` is returned with a warning.

    Raises:
        RuntimeError: If all pooled keys are spent and no paid key is set.
    """
    _maybe_reset_day()
    exhausted = _MODEL_CLASS_SETS[model_class]
    remaining = len(_GEMINI_EMBED_KEY_POOL)
    with lock:
        while remaining > 0:
            remaining -= 1
            candidate = next(cycle)
            with _spent_keys_lock:
                usable = candidate not in exhausted
            if usable:
                return candidate
    paid = get_paid_fallback_key()
    if not paid:
        raise RuntimeError(
            f"All Gemini keys are daily-spent for {model_class} "
            "and no paid fallback is set"
        )
    logger.warning(
        "All free Gemini keys spent for %s today; using paid key",
        model_class,
    )
    return paid
[docs]
def next_gemini_embed_key() -> str:
    """Return the next usable Gemini key for an embedding call.

    Round-robins over the shared pool using the embedding-specific cycle and
    lock, skipping keys already marked daily-spent for ``"embed"``. See
    :func:`_next_active_key` for the paid-key fallback. Thread-safe.
    """
    return _next_active_key(
        cycle=_gemini_embed_key_cycle,
        lock=_gemini_embed_key_lock,
        model_class="embed",
    )
[docs]
def next_gemini_flash_key() -> str:
    """Return the next usable Gemini key for a flash-lite generation call.

    Draws from the same pool as embeddings but through an independent cycle
    and lock, so the embed and generation rotations do not disturb each other.
    Keys marked daily-spent for ``"generate"`` are skipped. Thread-safe.
    """
    return _next_active_key(
        cycle=_gemini_flash_key_cycle,
        lock=_gemini_flash_key_lock,
        model_class="generate",
    )
# ---------------------------------------------------------------------------
# 429 parsing & spent-key management
# ---------------------------------------------------------------------------
# Model names that appear in quotaDimensions.model for each class
# Model-name prefixes matched against ``quotaDimensions.model`` in daily-quota
# 429 bodies. is_daily_quota_429_for_model() returns "embed" for names that
# start with one of _EMBED_MODEL_PREFIXES and "generate" for everything else.
_EMBED_MODEL_PREFIXES = ("gemini-embedding",)
# NOTE(review): not referenced in this part of the module. The "generate"
# classification above is a catch-all, not a prefix match; confirm whether
# other code relies on this constant.
_GENERATE_MODEL_PREFIXES = ("gemini-",) # catch-all for non-embedding models
[docs]
def is_daily_quota_429(resp: httpx.Response) -> bool:
    """Return True if *resp* is a 429 caused by a daily (RPD) quota limit.

    Parses the structured ``QuotaFailure`` violations in the response body
    (``error.details[].violations[].quotaId``) and looks for a ``quotaId``
    containing ``PerDay``. Malformed entries (non-dict details/violations,
    non-string ``quotaId``) are skipped individually rather than aborting the
    scan, so a later well-formed daily violation is still recognised. A body
    that cannot be parsed as JSON is treated as "not a daily 429".

    Args:
        resp: The HTTP response from a Gemini endpoint.

    Returns:
        bool: True only for a 429 carrying at least one ``PerDay`` violation.
    """
    if resp.status_code != 429:
        return False
    try:
        body = resp.json()
    except Exception:
        # Non-JSON (or unreadable) 429 body: it cannot be classified as daily.
        return False
    error = body.get("error") if isinstance(body, dict) else None
    details = error.get("details") if isinstance(error, dict) else None
    if not isinstance(details, list):
        return False
    for detail in details:
        violations = detail.get("violations") if isinstance(detail, dict) else None
        if not isinstance(violations, list):
            continue
        for violation in violations:
            if not isinstance(violation, dict):
                continue
            quota_id = violation.get("quotaId")
            if isinstance(quota_id, str) and "PerDay" in quota_id:
                return True
    return False
[docs]
def is_daily_quota_429_for_model(resp: httpx.Response) -> ModelClass | None:
    """Identify which model class a daily-quota 429 belongs to.

    Scans ``error.details[].violations[]`` for the first violation whose
    ``quotaId`` contains ``PerDay``. Returns ``"embed"`` if its
    ``quotaDimensions.model`` starts with one of ``_EMBED_MODEL_PREFIXES``,
    otherwise ``"generate"`` (including when the model dimension is missing).
    Malformed entries are skipped individually rather than aborting the scan.

    Args:
        resp: The HTTP response from a Gemini endpoint.

    Returns:
        ModelClass | None: The quota class, or *None* if *resp* is not a
        daily-quota 429 (or its body is not parseable JSON).
    """
    if resp.status_code != 429:
        return None
    try:
        body = resp.json()
    except Exception:
        # Non-JSON (or unreadable) 429 body: cannot attribute it to a class.
        return None
    error = body.get("error") if isinstance(body, dict) else None
    details = error.get("details") if isinstance(error, dict) else None
    if not isinstance(details, list):
        return None
    for detail in details:
        violations = detail.get("violations") if isinstance(detail, dict) else None
        if not isinstance(violations, list):
            continue
        for violation in violations:
            if not isinstance(violation, dict):
                continue
            quota_id = violation.get("quotaId")
            if not isinstance(quota_id, str) or "PerDay" not in quota_id:
                continue
            dims = violation.get("quotaDimensions")
            model_name = dims.get("model") if isinstance(dims, dict) else None
            if isinstance(model_name, str) and model_name.startswith(
                _EMBED_MODEL_PREFIXES
            ):
                return "embed"
            return "generate"
    return None
[docs]
async def record_key_usage(api_key: str) -> None:
    """Best-effort bump of today's request counter for *api_key* in Redis.

    Increments ``gemini_key_daily_usage:<suffix>:count`` and (re)sets its TTL
    to the next Pacific midnight in a single non-transactional pipeline. This
    gives :func:`get_pool_status` a rough per-key call count that expires on
    its own each day. Does nothing if the async Redis client is not wired, and
    any Redis error is logged at debug level instead of disturbing the embed
    call.

    Called within this module by :func:`probe_all_keys` and
    :func:`gemini_embed_paid_fallback`, and externally by the embed transport in
    ``openrouter_client/transport.py`` after each Gemini POST.

    Args:
        api_key: The Gemini key whose usage to record.
    """
    if _redis_client is None:
        return
    suffix = _key_suffix(api_key)
    counter_key = f"{_REDIS_KEY_PREFIX}:{suffix}:count"
    try:
        pipeline = _redis_client.pipeline(transaction=False)
        pipeline.incr(counter_key)
        pipeline.expire(counter_key, _seconds_until_midnight_pt())
        await pipeline.execute()
    except Exception:
        logger.debug("Failed to record Gemini key usage for ...%s", suffix)
[docs]
async def mark_key_daily_spent(
    api_key: str,
    model_class: ModelClass = "embed",
) -> None:
    """Record that *api_key* has hit its daily quota for *model_class*.

    Adds the key to the in-memory spent set for that class (under
    ``_spent_keys_lock``) and logs a warning. If Redis is wired, it also writes
    ``...:spent:<class>`` = ``"1"`` and ``...:spent_at:<class>`` = current UTC
    ISO timestamp, both expiring at the next Pacific midnight. Redis errors
    are logged at debug level and otherwise ignored.

    Args:
        api_key: The exhausted Gemini key.
        model_class: Which quota class (``"embed"`` or ``"generate"``) is spent.
    """
    suffix = _key_suffix(api_key)
    target_set = _MODEL_CLASS_SETS[model_class]
    with _spent_keys_lock:
        target_set.add(api_key)
    logger.warning(
        "Gemini key ...%s marked as daily-spent for %s",
        suffix,
        model_class,
    )
    if _redis_client is None:
        return
    flag_base = f"{_REDIS_KEY_PREFIX}:{suffix}"
    try:
        expires_in = _seconds_until_midnight_pt()
        pipeline = _redis_client.pipeline(transaction=False)
        pipeline.set(f"{flag_base}:spent:{model_class}", "1", ex=expires_in)
        pipeline.set(
            f"{flag_base}:spent_at:{model_class}",
            datetime.now(timezone.utc).isoformat(),
            ex=expires_in,
        )
        await pipeline.execute()
    except Exception:
        logger.debug("Failed to persist spent flag for ...%s/%s", suffix, model_class)
[docs]
async def sync_spent_keys_from_redis() -> None:
    """Replace both in-memory daily-spent sets with the flags stored in Redis.

    Lets this process adopt exhaustion decisions made by other workers. For
    each pooled key it checks, in order,
    ``gemini_key_daily_usage:<suffix>:spent:embed``, ``...:spent:generate`` and
    the legacy class-less ``...:spent`` flag (which marks the key spent for
    both classes). It then swaps the results into ``_spent_keys_embed`` /
    ``_spent_keys_generate`` under ``_spent_keys_lock``. A quota-day rollover
    is honored first via :func:`_maybe_reset_day`. Does nothing when the Redis
    client is unwired. On any Redis error it logs at debug level and leaves
    the in-memory sets untouched.

    Called at the start of :func:`probe_all_keys`.
    """
    if _redis_client is None:
        return
    _maybe_reset_day()
    embed_hits: set[str] = set()
    generate_hits: set[str] = set()
    try:
        for pool_key in _GEMINI_EMBED_KEY_POOL:
            flag_base = f"{_REDIS_KEY_PREFIX}:{_key_suffix(pool_key)}"
            if await _redis_client.get(f"{flag_base}:spent:embed") is not None:
                embed_hits.add(pool_key)
            if await _redis_client.get(f"{flag_base}:spent:generate") is not None:
                generate_hits.add(pool_key)
            # Legacy flag without a model-class suffix applies to both classes.
            if await _redis_client.get(f"{flag_base}:spent") is not None:
                embed_hits.add(pool_key)
                generate_hits.add(pool_key)
    except Exception:
        logger.debug("Failed to sync spent keys from Redis", exc_info=True)
        return
    with _spent_keys_lock:
        for live_set, fresh in (
            (_spent_keys_embed, embed_hits),
            (_spent_keys_generate, generate_hits),
        ):
            live_set.clear()
            live_set.update(fresh)
    if embed_hits or generate_hits:
        pool_size = len(_GEMINI_EMBED_KEY_POOL)
        logger.info(
            "Synced spent keys from Redis: embed=%d/%d, generate=%d/%d",
            len(embed_hits),
            pool_size,
            len(generate_hits),
            pool_size,
        )
[docs]
async def get_pool_status() -> dict[str, Any]:
    """Return a snapshot of pool health (active vs spent keys, usage) for diagnostics.

    Honors a quota-day rollover via :func:`_maybe_reset_day`. It then reports
    the number of distinct pooled keys and, per model class (embed vs
    generate), how many of them are active vs daily-spent, plus the current
    quota day. Only keys currently in the pool count as spent, so the
    ``*_active`` figures cannot go negative when the in-memory spent sets
    still hold keys that are no longer pooled (e.g. after
    :func:`reload_pool` recomputes the pool). When the async Redis client is
    wired, it also reads each key's ``gemini_key_daily_usage:<suffix>:count``
    into a ``...<suffix>``-keyed ``usage_counts`` map. A Redis failure is
    logged at debug level and leaves only the counts read before it.

    Called externally by the ``donate_embed_key`` tool
    (``tools/donate_embed_key.py``) to surface pool status to users.

    Returns:
        dict[str, Any]: Pool health fields including ``total_keys``,
        ``embed_active`` / ``embed_spent``, ``generate_active`` /
        ``generate_spent``, ``quota_day``, and (when Redis is available)
        ``usage_counts``.
    """
    _maybe_reset_day()
    # Order-preserving de-duplication, taken from one read of the global so a
    # concurrent reload_pool() rebinding cannot skew the numbers mid-report.
    pool = list(dict.fromkeys(_GEMINI_EMBED_KEY_POOL))
    with _spent_keys_lock:
        embed_spent = sum(1 for k in pool if k in _spent_keys_embed)
        generate_spent = sum(1 for k in pool if k in _spent_keys_generate)
    total = len(pool)
    result: dict[str, Any] = {
        "total_keys": total,
        "embed_active": total - embed_spent,
        "embed_spent": embed_spent,
        "generate_active": total - generate_spent,
        "generate_spent": generate_spent,
        "quota_day": _current_quota_day,
    }
    if _redis_client is not None:
        counts: dict[str, int] = {}
        try:
            for key in pool:
                suffix = _key_suffix(key)
                raw = await _redis_client.get(
                    f"{_REDIS_KEY_PREFIX}:{suffix}:count",
                )
                counts[f"...{suffix}"] = int(raw) if raw else 0
        except Exception:
            # Diagnostics are best-effort: report whatever was read so far.
            logger.debug(
                "Failed to read Gemini key usage counts from Redis",
                exc_info=True,
            )
        result["usage_counts"] = counts
    return result
# ---------------------------------------------------------------------------
# Background probe -- runs every 2 hours
# ---------------------------------------------------------------------------
# Generation model hit by the background probe's minimal generateContent call
# (maxOutputTokens=1) to detect keys whose daily generate quota is exhausted.
_PROBE_GENERATE_MODEL = "gemini-3.1-flash-lite"
[docs]
# Strong references to in-flight observability tasks spawned by
# probe_all_keys(). asyncio keeps only weak references to running tasks, so a
# fire-and-forget task with no other reference may be garbage-collected before
# it finishes. Each task removes itself from this set when it completes.
_probe_event_tasks: set[asyncio.Task[Any]] = set()


async def _probe_key(
    client: httpx.AsyncClient,
    key: str,
    url: str,
    payload: dict[str, Any],
    model_class: ModelClass,
) -> bool:
    """POST one probe for *key* and mark it daily-spent on a daily-quota 429.

    Records the call in the per-key usage counter once the POST returns. Any
    exception (network failure, error while marking) is logged at debug level
    and treated as "not newly spent", so one failing key cannot abort the
    whole probe run.

    Args:
        client: The shared HTTP client for this probe run.
        key: The Gemini API key being probed.
        url: The fully-built endpoint URL (carries the ``key`` query param).
        payload: The minimal JSON request body.
        model_class: The quota class a daily 429 should be recorded under.

    Returns:
        bool: True if the key was newly marked daily-spent for *model_class*.
    """
    try:
        resp = await client.post(url, json=payload)
        await record_key_usage(key)
        if is_daily_quota_429(resp):
            await mark_key_daily_spent(key, model_class)
            return True
    except Exception:
        logger.debug(
            "%s probe failed for ...%s",
            "Embed" if model_class == "embed" else "Generate",
            _key_suffix(key),
            exc_info=True,
        )
    return False


async def probe_all_keys() -> None:
    """Probe each non-spent key for both embedding and generation daily exhaustion.

    After syncing spent state from Redis, sends a minimal request for each key
    not already spent for a class: ``gemini-embedding-001:batchEmbedContents``
    for embed and ``_PROBE_GENERATE_MODEL:generateContent`` for generate. The
    key is marked spent for that class on a daily-quota 429. Finally it
    publishes a ``gemini_key_probe`` debug event summarising embedding health
    (``ok`` / ``partial`` / ``all_exhausted``) and logs per-class counts.
    """
    t0_probe = time.monotonic()
    _maybe_reset_day()
    await sync_spent_keys_from_redis()
    # Snapshot the pool once so the probe loop and the summary agree even if
    # reload_pool() rebinds the global while requests are in flight.
    pool = _GEMINI_EMBED_KEY_POOL
    with _spent_keys_lock:
        already_embed = set(_spent_keys_embed)
        already_generate = set(_spent_keys_generate)
    newly_embed = 0
    newly_generate = 0
    async with httpx.AsyncClient(timeout=30.0) as client:
        for key in pool:
            # --- Embedding probe ---
            if key not in already_embed:
                embed_url = (
                    f"{GEMINI_EMBED_BASE}/gemini-embedding-001"
                    f":batchEmbedContents?key={key}"
                )
                embed_payload = {
                    "requests": [
                        {
                            "model": "models/gemini-embedding-001",
                            "content": {"parts": [{"text": "probe"}]},
                            "output_dimensionality": 768,
                        }
                    ],
                }
                if await _probe_key(client, key, embed_url, embed_payload, "embed"):
                    newly_embed += 1
            # --- Generation probe ---
            if key not in already_generate:
                generate_url = (
                    f"{GEMINI_EMBED_BASE}/{_PROBE_GENERATE_MODEL}"
                    f":generateContent?key={key}"
                )
                generate_payload = {
                    "contents": [{"parts": [{"text": "probe"}]}],
                    "generationConfig": {"maxOutputTokens": 1},
                }
                if await _probe_key(
                    client, key, generate_url, generate_payload, "generate"
                ):
                    newly_generate += 1
    total = len(pool)
    with _spent_keys_lock:
        final_spent_embed = set(_spent_keys_embed)
    exhausted_indices = [i for i, k in enumerate(pool) if k in final_spent_embed]
    exhausted = len(exhausted_indices)
    healthy = total - exhausted
    if healthy == total:
        status = "ok"
    elif healthy == 0:
        status = "all_exhausted"
    else:
        status = "partial"
    from observability import publish_debug_event

    task = asyncio.create_task(
        publish_debug_event(
            "gemini_key_probe",
            "embedding_queue",
            status=status,
            duration_ms=(time.monotonic() - t0_probe) * 1000,
            preview=f"keys_probed={total} healthy={healthy} exhausted={exhausted}",
            payload={
                "keys_probed": total,
                "keys_healthy": healthy,
                "keys_exhausted": exhausted,
                "exhausted_key_indices": exhausted_indices,
            },
        ),
        name="obs_gemini_key_probe",
    )
    # Hold a strong reference until the task completes (see _probe_event_tasks).
    _probe_event_tasks.add(task)
    task.add_done_callback(_probe_event_tasks.discard)
    logger.info(
        "Gemini key probe complete (%d keys): "
        "embed %d active / %d newly spent / %d prev spent, "
        "generate %d active / %d newly spent / %d prev spent",
        total,
        total - len(already_embed) - newly_embed,
        newly_embed,
        len(already_embed),
        total - len(already_generate) - newly_generate,
        newly_generate,
        len(already_generate),
    )
def _gemini_model_name(model: str) -> str:
"""Strip the ``google/`` provider prefix from an OpenRouter-style model id.
Internal callers pass model ids in OpenRouter form (e.g.
``google/gemini-embedding-001``), but the native Gemini REST endpoints want
the bare name. This normalizes by removing a leading ``google/`` and leaves
already-bare names untouched.
Called by :func:`_build_paid_embed_request` and :func:`embed_batch_via_gemini`
when constructing native Gemini URLs and request bodies. (A separately
defined ``OpenRouterClient._gemini_model_name`` method in
``openrouter_client/transport.py`` is the transport-side equivalent.)
Args:
model: A model id, optionally prefixed with ``google/``.
Returns:
str: The model id with any ``google/`` prefix removed.
"""
return model.removeprefix("google/")
# ---------------------------------------------------------------------------
# OpenRouter embedding fallback (preferred fallback when free pool 429s;
# the paid Gemini key is the absolute last resort if OpenRouter also fails)
# ---------------------------------------------------------------------------
# OpenRouter embeddings endpoint used by the fallback tier.
_OPENROUTER_EMBED_URL = "https://openrouter.ai/api/v1/embeddings"
# OpenRouter embed: retry until we get vectors or hit the cap. Transient DNS/TCP
# failures (e.g. errno 101 ENETUNREACH) must not abort after a single POST.
# Read once at import; values below 1 are clamped to 1, and a non-integer
# value raises ValueError at import time.
_OPENROUTER_EMBED_MAX_ATTEMPTS = max(
    1,
    int(os.environ.get("OPENROUTER_EMBED_MAX_ATTEMPTS", "24")),
)
# Base delay (seconds) for the exponential backoff between retries.
_OPENROUTER_EMBED_RETRY_BASE_DELAY = float(
    os.environ.get("OPENROUTER_EMBED_RETRY_BASE_DELAY", "2.0"),
)
# Cap (seconds) on the backoff delay, applied before jitter is added.
_OPENROUTER_EMBED_RETRY_MAX_DELAY = float(
    os.environ.get("OPENROUTER_EMBED_RETRY_MAX_DELAY", "120.0"),
)
# HTTP statuses treated as transient (rate limit / server-side) and retried.
_OPENROUTER_EMBED_HTTP_RETRIABLE = frozenset({429, 500, 502, 503, 504})
[docs]
class OpenRouterEmbedParseError(RuntimeError):
    """Malformed HTTP 200 response from the OpenRouter embeddings endpoint.

    Raised when a 200 OK body does not have the expected
    ``{"data": [{"index": int, "embedding": [..]}, ...]}`` shape. The outer
    retry loops in :mod:`openrouter_client` and
    :mod:`rag_system.openrouter_embeddings` treat it as non-retriable. Such a
    payload usually means the upstream provider surfaced an error as a 200,
    and an immediate retry would most likely get the same response.
    """
def _parse_openrouter_embed_body(
body_json: Any,
expected_count: int,
body_text: str,
) -> list[list[float]]:
"""Validate and extract vectors from an OpenRouter 200 embed response.
Raises :class:`OpenRouterEmbedParseError` (with a truncated body preview)
if the response is missing the ``data`` array, has malformed items, or
returns a different number of vectors than were requested.
"""
if not isinstance(body_json, dict):
raise OpenRouterEmbedParseError(
f"OpenRouter 200 OK body is not a JSON object: {body_text[:300]!r}"
)
items = body_json.get("data")
if not isinstance(items, list) or not items:
err = body_json.get("error")
raise OpenRouterEmbedParseError(
f"OpenRouter 200 OK missing/empty 'data' "
f"(error={err!r}, body={body_text[:500]!r})"
)
try:
ordered = sorted(items, key=lambda x: x["index"])
out = [item["embedding"] for item in ordered]
except (KeyError, TypeError) as exc:
raise OpenRouterEmbedParseError(
f"OpenRouter embed items malformed ({exc}); body={body_text[:500]!r}"
) from exc
if len(out) != expected_count:
raise OpenRouterEmbedParseError(
f"OpenRouter returned {len(out)} vectors, expected {expected_count}; "
f"body={body_text[:300]!r}"
)
return out
def _openrouter_embed_retry_delay(attempt_index: int) -> float:
    """Return how long to sleep before the next OpenRouter embed retry.

    Exponential backoff: ``_OPENROUTER_EMBED_RETRY_BASE_DELAY`` is doubled per
    attempt, with the exponent capped at 12 and the delay capped at
    ``_OPENROUTER_EMBED_RETRY_MAX_DELAY``. Random jitter of up to 15 percent
    of that delay (never more than 15 s) is then added, so concurrent
    retriers do not fire in lockstep while riding out transient 429 / 5xx /
    network failures.

    Used only by :func:`openrouter_embed_batch` and
    :func:`openrouter_embed_batch_sync` between attempts.

    Args:
        attempt_index: Zero-based retry index; 0 is the wait before the first
            retry.

    Returns:
        float: Seconds to sleep before the next attempt.
    """
    exponent = attempt_index if attempt_index < 12 else 12
    backoff = _OPENROUTER_EMBED_RETRY_BASE_DELAY * 2**exponent
    capped = min(backoff, _OPENROUTER_EMBED_RETRY_MAX_DELAY)
    jitter_ceiling = min(capped * 0.15, 15.0)
    return capped + random.uniform(0, jitter_ceiling)
def _is_openrouter_transient_http(status: int) -> bool:
    """Tell whether an OpenRouter HTTP *status* should be retried.

    Retriable statuses are those in ``_OPENROUTER_EMBED_HTTP_RETRIABLE``: 429
    (rate limited) and the transient server errors 500/502/503/504. The embed
    retry loops back off and try again on these rather than failing.

    Used after a non-200 response by :func:`openrouter_embed_batch` and
    :func:`openrouter_embed_batch_sync` to decide whether to ``continue`` the
    retry loop or surface the error. It has no callers outside this module.

    Args:
        status: HTTP status code from the OpenRouter request.

    Returns:
        bool: ``True`` for a retriable status, otherwise ``False``.
    """
    retriable = status in _OPENROUTER_EMBED_HTTP_RETRIABLE
    return retriable
def _is_openrouter_network_error(exc: BaseException) -> bool:
    """Classify *exc* as a retriable transport-level failure.

    Returns ``True`` for httpx timeout, connection, read/write and
    remote-protocol errors, and for the builtin :class:`TimeoutError`. Any
    other exception returns ``False`` and is treated as fatal. (Connection
    failures such as DNS errors or ``ENETUNREACH`` presumably reach here as
    :class:`httpx.ConnectError`.)

    Used in the ``except`` blocks of :func:`openrouter_embed_batch` and
    :func:`openrouter_embed_batch_sync`: ``False`` re-raises immediately,
    while ``True`` lets the backoff loop retry. It has no callers outside
    this module.

    Args:
        exc: The exception raised by the OpenRouter HTTP request.

    Returns:
        bool: ``True`` if *exc* is a transient network/transport error.
    """
    retriable_types: tuple[type[BaseException], ...] = (
        httpx.TimeoutException,
        httpx.ConnectError,
        httpx.ConnectTimeout,
        httpx.ReadTimeout,
        httpx.WriteTimeout,
        httpx.PoolTimeout,
        httpx.RemoteProtocolError,
        httpx.ReadError,
        httpx.WriteError,
        TimeoutError,
    )
    return isinstance(exc, retriable_types)
# httpx timeouts for OpenRouter embed calls: 45 s to connect, 120 s for the
# other phases (read/write/pool).
_OR_CLIENT_TIMEOUT = httpx.Timeout(120.0, connect=45.0)
# Built-in OpenRouter key used when neither OPENROUTER_API_KEY nor API_KEY is
# set (see get_openrouter_api_key).
# NOTE(review): this looks like a real credential committed to source; prefer
# supplying it via the environment / a secret store and rotating this one.
_DEFAULT_OPENROUTER_KEY = (
    "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"
)
[docs]
def get_openrouter_api_key() -> str | None:
    """Resolve the OpenRouter API key used by the embedding fallback.

    Checks ``OPENROUTER_API_KEY`` first, then the legacy ``API_KEY`` variable
    (each whitespace-stripped; blank values are skipped). If neither is set,
    it falls back to ``_DEFAULT_OPENROUTER_KEY``, so the OpenRouter embed path
    works without configuration. The result is sent as the ``Bearer`` token
    to the embeddings endpoint.

    Used by :func:`openrouter_embed_batch` and
    :func:`openrouter_embed_batch_sync` when no explicit ``api_key`` is given.

    Returns:
        str | None: The resolved key (in practice never *None*, given the
        built-in default).
    """
    for env_var in ("OPENROUTER_API_KEY", "API_KEY"):
        candidate = os.environ.get(env_var, "").strip()
        if candidate:
            return candidate
    return _DEFAULT_OPENROUTER_KEY
[docs]
# Strong references to fire-and-forget observability tasks. The event loop
# keeps only weak references to tasks, so an otherwise-unreferenced task can
# be garbage-collected before it finishes.
_BACKGROUND_TASKS: set[asyncio.Task[Any]] = set()


def _spawn_background_task(coro: Any) -> None:
    """Schedule *coro* as a fire-and-forget task and keep it alive until done.

    The task is held in ``_BACKGROUND_TASKS`` until it completes. If it fails,
    the exception is retrieved and logged at debug level, so it neither
    surfaces as an unretrieved task exception nor affects the caller.

    Args:
        coro: Coroutine to schedule on the running event loop.
    """
    task = asyncio.create_task(coro)
    _BACKGROUND_TASKS.add(task)

    def _on_done(done: asyncio.Task[Any]) -> None:
        _BACKGROUND_TASKS.discard(done)
        if not done.cancelled() and done.exception() is not None:
            logger.debug(
                "Background observability task failed",
                exc_info=done.exception(),
            )

    task.add_done_callback(_on_done)


async def openrouter_embed_batch(
    texts: list[str],
    *,
    model: str = "google/gemini-embedding-001",
    api_key: str | None = None,
    dimensions: int = EMBED_DIMENSIONS,
) -> list[list[float]]:
    """Embed texts via the OpenRouter /embeddings endpoint (async).

    Retries on transient network failures and retriable HTTP codes (429, 5xx)
    until success or ``_OPENROUTER_EMBED_MAX_ATTEMPTS`` is exhausted, sleeping
    ``_openrouter_embed_retry_delay(attempt - 1)`` seconds between attempts.

    Empty or whitespace-only strings are not sent to the API; those positions
    receive zero vectors of length *dimensions*.

    Observability events are published in the background through
    :func:`_spawn_background_task`. Each task is kept referenced until it
    finishes, and a failing publish never affects the embedding result or
    error.

    Args:
        texts: Texts to embed.
        model: OpenRouter model id.
        api_key: Explicit OpenRouter key; defaults to
            :func:`get_openrouter_api_key`.
        dimensions: Requested embedding size, omitted from the request when
            falsy. Also the length of the zero vectors used for blank inputs.

    Returns:
        One vector per input text, in input order.

    Raises:
        RuntimeError: No API key, a non-transient HTTP error, or retries
            exhausted.
        OpenRouterEmbedParseError: A 200 response whose body is not valid JSON
            or does not parse as embeddings. This is never retried.
        Exception: A network error on the final attempt is re-raised as-is.
            Any exception not classified as a network error propagates.
    """
    if not texts:
        return []
    valid_indices: list[int] = []
    valid_texts: list[str] = []
    for i, t in enumerate(texts):
        if t and t.strip():
            valid_indices.append(i)
            valid_texts.append(t)
    if not valid_texts:
        return [[0.0] * dimensions for _ in texts]
    if len(valid_texts) < len(texts):
        # Embed only the non-blank subset, then scatter it back over zero
        # vectors so blank positions stay zero.
        merged = await openrouter_embed_batch(
            valid_texts,
            model=model,
            api_key=api_key,
            dimensions=dimensions,
        )
        out: list[list[float]] = [[0.0] * dimensions for _ in texts]
        for idx, emb in zip(valid_indices, merged):
            out[idx] = emb
        return out
    texts = valid_texts
    key = api_key or get_openrouter_api_key()
    if not key:
        raise RuntimeError(
            "OpenRouter fallback unavailable: no API key "
            "(set OPENROUTER_API_KEY or API_KEY)"
        )
    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }
    payload: dict = {"model": model, "input": texts}
    if dimensions:
        payload["dimensions"] = dimensions
    last_err: BaseException | None = None
    for attempt in range(_OPENROUTER_EMBED_MAX_ATTEMPTS):
        if attempt > 0:
            delay = _openrouter_embed_retry_delay(attempt - 1)
            logger.warning(
                "OpenRouter embed retry %d/%d (sleep %.1fs): %s",
                attempt + 1,
                _OPENROUTER_EMBED_MAX_ATTEMPTS,
                delay,
                last_err,
            )
            await asyncio.sleep(delay)
        try:
            from observability import get_http_call_origin

            origin = get_http_call_origin()
            logger.info(
                "Sending OpenRouter batch embedding request of %d texts for model %s (origin: %s)",
                len(texts),
                model,
                origin,
            )
            async with httpx.AsyncClient(timeout=_OR_CLIENT_TIMEOUT) as client:
                resp = await client.post(
                    _OPENROUTER_EMBED_URL,
                    json=payload,
                    headers=headers,
                )
            if resp.status_code == 200:
                logger.info(
                    "OpenRouter batch embedding request completed successfully (status=200) (origin: %s)",
                    origin,
                )
                body_text = resp.text
                try:
                    body_json = resp.json()
                except Exception as json_exc:
                    parse_err = OpenRouterEmbedParseError(
                        f"OpenRouter 200 OK body is not valid JSON ({json_exc}); "
                        f"body={body_text[:500]!r}"
                    )
                    logger.warning(
                        "OpenRouter embed 200 OK but body not JSON: %s",
                        body_text[:300],
                    )
                    _spawn_background_task(
                        publish_http_error_event(
                            http_service="openrouter_embed_batch",
                            http_status=200,
                            endpoint=_OPENROUTER_EMBED_URL[:120],
                            detail=str(parse_err)[:500],
                            error_kind="parse_error",
                        ),
                    )
                    await _auto_clear_openrouter_only_on_parse_error(parse_err)
                    raise parse_err from json_exc
                try:
                    return _parse_openrouter_embed_body(
                        body_json,
                        len(texts),
                        body_text,
                    )
                except OpenRouterEmbedParseError as parse_err:
                    logger.warning(
                        "OpenRouter embed 200 OK but malformed payload "
                        "(non-retriable): %s",
                        str(parse_err)[:500],
                    )
                    _spawn_background_task(
                        publish_http_error_event(
                            http_service="openrouter_embed_batch",
                            http_status=200,
                            endpoint=_OPENROUTER_EMBED_URL[:120],
                            detail=str(parse_err)[:500],
                            error_kind="parse_error",
                        ),
                    )
                    await _auto_clear_openrouter_only_on_parse_error(parse_err)
                    raise
            body = resp.text[:500]
            last_err = RuntimeError(
                f"OpenRouter embed HTTP {resp.status_code}: {body}",
            )
            if _is_openrouter_transient_http(resp.status_code):
                logger.warning(
                    "OpenRouter embed HTTP %s (transient): %s",
                    resp.status_code,
                    body[:300],
                )
                continue
            _spawn_background_task(
                publish_http_error_event(
                    http_service="openrouter_embed_batch",
                    http_status=resp.status_code,
                    endpoint=_OPENROUTER_EMBED_URL[:120],
                    detail=body[:500],
                ),
            )
            # Re-raised by the generic handler below, since it is not a
            # network error.
            raise last_err
        except OpenRouterEmbedParseError:
            raise
        except Exception as exc:
            if not _is_openrouter_network_error(exc):
                raise
            last_err = exc
            if attempt >= _OPENROUTER_EMBED_MAX_ATTEMPTS - 1:
                _spawn_background_task(
                    publish_http_error_event(
                        http_service="openrouter_embed_batch",
                        http_status=0,
                        endpoint=_OPENROUTER_EMBED_URL[:120],
                        detail=str(exc)[:500],
                        error_kind="network",
                    ),
                )
                raise
            logger.warning(
                "OpenRouter embed network error (attempt %d/%d): %s",
                attempt + 1,
                _OPENROUTER_EMBED_MAX_ATTEMPTS,
                exc,
            )
    # Only reached when every attempt ended in a transient HTTP status.
    _spawn_background_task(
        publish_http_error_event(
            http_service="openrouter_embed_batch",
            http_status=0,
            endpoint=_OPENROUTER_EMBED_URL[:120],
            detail=str(last_err)[:500],
            error_kind="exhausted_retries",
        ),
    )
    raise RuntimeError(
        f"OpenRouter embed failed after {_OPENROUTER_EMBED_MAX_ATTEMPTS} attempts",
    ) from last_err
def _publish_http_error_sync(
    *,
    http_service: str,
    http_status: int = 0,
    endpoint: str = "",
    detail: str = "",
    error_kind: str = "",
) -> None:
    """Fire-and-forget sync publish of an ``http_error`` observability event.

    Writes the same Redis HASH (``obs:<uuid>``, 45-day TTL) plus PUBLISH on
    ``stargazer:observability`` that the async ``publish_http_error_event``
    would. It uses the module-level ``_sync_redis_client``, which is created
    lazily from ``Config`` on first use.

    Never raises. When Redis is unavailable the call is a silent no-op. Any
    failure while building or sending the event is logged at debug level.
    Callers invoke this right before raising their real error, and a
    telemetry problem must not replace that error.

    Args:
        http_service: Name of the reporting service.
        http_status: HTTP status of the failure, or 0 for non-HTTP errors.
        endpoint: Endpoint URL or label; falsy values are treated as ``""``.
        detail: Free-form error detail; falsy values are treated as ``""``.
        error_kind: Optional classifier, added to the payload when non-empty.
    """
    global _sync_redis_client
    import json
    import uuid

    if _sync_redis_client is None:
        try:
            from config import Config

            cfg = Config.load()
            if not cfg.redis_url:
                return
            ssl_kwargs = cfg.redis_connection_kwargs_for_url(cfg.redis_url)
            import redis

            _sync_redis_client = redis.Redis.from_url(
                cfg.redis_url, decode_responses=True, **ssl_kwargs
            )
        except Exception:
            return
    # Normalize up front so slicing below cannot raise on None.
    endpoint = endpoint or ""
    detail = detail or ""
    try:
        ts = time.time()
        doc_key = f"obs:{uuid.uuid4().hex}"
        preview = f"{http_service} status={http_status} {endpoint[:80]} {detail[:120]}"[
            :2000
        ]
        # Shared by the stored HASH (as JSON text) and the published message.
        event_payload: dict = {
            "duration_ms": 0,
            "endpoint": endpoint[:500],
            "detail": detail[:800],
        }
        if error_kind:
            event_payload["error_kind"] = error_kind
        mapping = {
            "event_type": "http_error",
            "platform": "-",
            "channel_id": "-",
            "user_id": "-",
            "tool_name": "-",
            "request_id": "-",
            "preview": preview,
            "timestamp": str(ts),
            "http_status": str(int(http_status)),
            "http_service": http_service or "-",
            "payload_json": json.dumps(event_payload),
        }
        pipe = _sync_redis_client.pipeline()
        pipe.hset(doc_key, mapping=mapping)
        pipe.expire(doc_key, 45 * 86400)
        pipe.execute()
        wire = json.dumps(
            {
                "event_type": "http_error",
                "platform": "",
                "channel_id": "",
                "user_id": "",
                "tool_name": "",
                "request_id": "",
                "preview": preview,
                "timestamp": ts,
                "http_status": int(http_status),
                "http_service": http_service,
                "payload": event_payload,
                "doc_id": doc_key,
            }
        )
        _sync_redis_client.publish("stargazer:observability", wire)
    except Exception:
        logger.debug("_publish_http_error_sync failed", exc_info=True)
[docs]
def openrouter_embed_batch_sync(
    texts: list[str],
    *,
    model: str = "google/gemini-embedding-001",
    api_key: str | None = None,
    dimensions: int = EMBED_DIMENSIONS,
) -> list[list[float]]:
    """Blocking counterpart of :func:`openrouter_embed_batch`.

    For callers that cannot ``await`` (e.g. ChromaDB embedding functions).
    Uses the same retry policy as the async version:
    - Transient HTTP statuses and network errors are retried up to
      ``_OPENROUTER_EMBED_MAX_ATTEMPTS`` times.
    - Between attempts it sleeps ``_openrouter_embed_retry_delay`` seconds
      with the blocking ``time.sleep``.
    - Failures are reported through :func:`_publish_http_error_sync`.

    Blank (empty or whitespace-only) inputs are never sent. Their positions
    are filled with zero vectors of length *dimensions*.

    Raises:
        RuntimeError: No API key, a non-transient HTTP error, or retries
            exhausted.
        OpenRouterEmbedParseError: A 200 response whose body is not JSON or
            does not parse as embeddings. This is never retried.
    """
    if not texts:
        return []
    keep = [pos for pos, text in enumerate(texts) if text and text.strip()]
    if not keep:
        return [[0.0] * dimensions for _ in texts]
    if len(keep) != len(texts):
        # Embed the non-blank subset, then scatter it back over zero vectors.
        embedded = openrouter_embed_batch_sync(
            [texts[pos] for pos in keep],
            model=model,
            api_key=api_key,
            dimensions=dimensions,
        )
        padded: list[list[float]] = [[0.0] * dimensions for _ in texts]
        for pos, vector in zip(keep, embedded):
            padded[pos] = vector
        return padded

    batch = list(texts)
    token = api_key or get_openrouter_api_key()
    if not token:
        raise RuntimeError(
            "OpenRouter fallback unavailable: no API key "
            "(set OPENROUTER_API_KEY or API_KEY)"
        )
    request_headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    request_body: dict = {"model": model, "input": batch}
    if dimensions:
        request_body["dimensions"] = dimensions

    service = "openrouter_embed_batch_sync"
    failure: BaseException | None = None
    for attempt in range(_OPENROUTER_EMBED_MAX_ATTEMPTS):
        if attempt:
            pause = _openrouter_embed_retry_delay(attempt - 1)
            logger.warning(
                "OpenRouter embed retry %d/%d (sleep %.1fs): %s",
                attempt + 1,
                _OPENROUTER_EMBED_MAX_ATTEMPTS,
                pause,
                failure,
            )
            time.sleep(pause)
        try:
            from observability import get_http_call_origin

            origin = get_http_call_origin()
            logger.info(
                "Sending OpenRouter batch embedding request of %d texts for model %s (origin: %s)",
                len(batch),
                model,
                origin,
            )
            with httpx.Client(timeout=_OR_CLIENT_TIMEOUT) as client:
                resp = client.post(
                    _OPENROUTER_EMBED_URL,
                    json=request_body,
                    headers=request_headers,
                )
            status = resp.status_code
            if status != 200:
                snippet = resp.text[:500]
                failure = RuntimeError(
                    f"OpenRouter embed HTTP {status}: {snippet}",
                )
                if _is_openrouter_transient_http(status):
                    logger.warning(
                        "OpenRouter embed HTTP %s (transient): %s",
                        status,
                        snippet[:300],
                    )
                    continue
                _publish_http_error_sync(
                    http_service=service,
                    http_status=status,
                    endpoint=_OPENROUTER_EMBED_URL[:120],
                    detail=snippet[:500],
                )
                # Not a network error, so the generic handler below re-raises it.
                raise failure

            logger.info(
                "OpenRouter batch embedding request completed successfully (status=200) (origin: %s)",
                origin,
            )
            raw_body = resp.text
            try:
                decoded = resp.json()
            except Exception as json_exc:
                parse_err = OpenRouterEmbedParseError(
                    f"OpenRouter 200 OK body is not valid JSON ({json_exc}); "
                    f"body={raw_body[:500]!r}"
                )
                logger.warning(
                    "OpenRouter embed 200 OK but body not JSON: %s",
                    raw_body[:300],
                )
                _publish_http_error_sync(
                    http_service=service,
                    http_status=200,
                    endpoint=_OPENROUTER_EMBED_URL[:120],
                    detail=str(parse_err)[:500],
                    error_kind="parse_error",
                )
                _auto_clear_openrouter_only_on_parse_error_sync(parse_err)
                raise parse_err from json_exc
            try:
                return _parse_openrouter_embed_body(decoded, len(batch), raw_body)
            except OpenRouterEmbedParseError as parse_err:
                logger.warning(
                    "OpenRouter embed 200 OK but malformed payload "
                    "(non-retriable): %s",
                    str(parse_err)[:500],
                )
                _publish_http_error_sync(
                    http_service=service,
                    http_status=200,
                    endpoint=_OPENROUTER_EMBED_URL[:120],
                    detail=str(parse_err)[:500],
                    error_kind="parse_error",
                )
                _auto_clear_openrouter_only_on_parse_error_sync(parse_err)
                raise
        except OpenRouterEmbedParseError:
            raise
        except Exception as exc:
            if not _is_openrouter_network_error(exc):
                raise
            failure = exc
            if attempt >= _OPENROUTER_EMBED_MAX_ATTEMPTS - 1:
                _publish_http_error_sync(
                    http_service=service,
                    http_status=0,
                    endpoint=_OPENROUTER_EMBED_URL[:120],
                    detail=str(exc)[:500],
                    error_kind="network",
                )
                raise
            logger.warning(
                "OpenRouter embed network error (attempt %d/%d): %s",
                attempt + 1,
                _OPENROUTER_EMBED_MAX_ATTEMPTS,
                exc,
            )
    # Every attempt ended in a transient HTTP status.
    _publish_http_error_sync(
        http_service=service,
        http_status=0,
        endpoint=_OPENROUTER_EMBED_URL[:120],
        detail=str(failure)[:500],
        error_kind="exhausted_retries",
    )
    raise RuntimeError(
        f"OpenRouter embed failed after {_OPENROUTER_EMBED_MAX_ATTEMPTS} attempts",
    ) from failure
# ---------------------------------------------------------------------------
# Paid-key last resort: only used once both the free pool and OpenRouter fail
# ---------------------------------------------------------------------------
# 60 s for read/write/pool operations, 30 s to establish the connection.
_PAID_FALLBACK_TIMEOUT = httpx.Timeout(timeout=60.0, connect=30.0)
def _build_paid_embed_request(
texts: list[str],
*,
model: str,
dimensions: int,
task_type: str | None,
paid_key: str,
) -> tuple[str, dict[str, Any]]:
"""Build the URL and JSON payload for a native Gemini ``batchEmbedContents`` call.
Assembles the request for the paid tier-3 key path: it normalizes the model
name via :func:`_gemini_model_name`, builds the keyed endpoint URL, and packs
one request per text (each with the model, text part, and
``output_dimensionality``), optionally attaching ``taskType`` when *task_type*
is given. Pure construction only; it performs no I/O.
Called within this module by :func:`gemini_embed_paid_fallback` and
:func:`gemini_embed_paid_fallback_sync` to share request-building logic
between the async and sync paid-key fallbacks.
Args:
texts: The texts to embed (already filtered to non-empty).
model: The model id, optionally ``google/``-prefixed.
dimensions: Output embedding dimensionality.
task_type: Optional Gemini task type to set per request.
paid_key: The paid Gemini API key to embed into the URL.
Returns:
tuple[str, dict[str, Any]]: The request URL and its JSON body
(``{"requests": [...]}``).
"""
gemini_model = _gemini_model_name(model)
url = f"{GEMINI_EMBED_BASE}/{gemini_model}:batchEmbedContents" f"?key={paid_key}"
requests_list: list[dict[str, Any]] = []
for t in texts:
req: dict[str, Any] = {
"model": f"models/{gemini_model}",
"content": {"parts": [{"text": t}]},
"output_dimensionality": dimensions,
}
if task_type:
req["taskType"] = task_type
requests_list.append(req)
return url, {"requests": requests_list}
[docs]
async def gemini_embed_paid_fallback(
    texts: list[str],
    *,
    model: str = "google/gemini-embedding-001",
    dimensions: int = EMBED_DIMENSIONS,
    task_type: str | None = None,
) -> list[list[float]]:
    """Last-resort embedding via the paid tier-3 Gemini key (async).

    Single attempt, no retries; the caller decides what to do on failure.
    Records key usage via :func:`record_key_usage`. On a daily-quota 429 it
    also marks the paid key daily-spent for embeddings.

    Empty/whitespace-only inputs are not sent. They receive zero vectors of
    length *dimensions*.

    Raises:
        RuntimeError: No paid key is configured, the paid key 429s (daily or
            rate-limit), the body is malformed, or the number of returned
            embeddings differs from the number of texts sent. The count check
            prevents vectors from being silently dropped or misaligned.
        httpx.HTTPStatusError: Any other non-2xx response (from
            ``raise_for_status``).
        httpx.RequestError: Transport failures during the POST itself.
    """
    if not texts:
        return []
    valid_indices: list[int] = []
    valid_texts: list[str] = []
    for i, t in enumerate(texts):
        if t and t.strip():
            valid_indices.append(i)
            valid_texts.append(t)
    if not valid_texts:
        return [[0.0] * dimensions for _ in texts]
    paid = get_paid_fallback_key()
    if not paid:
        raise RuntimeError(
            "Paid Gemini fallback unavailable: no key set "
            "(set GEMINI_EMBED_PAID_KEY)",
        )
    url, payload = _build_paid_embed_request(
        valid_texts,
        model=model,
        dimensions=dimensions,
        task_type=task_type,
        paid_key=paid,
    )
    async with httpx.AsyncClient(timeout=_PAID_FALLBACK_TIMEOUT) as client:
        resp = await client.post(url, json=payload)
    await record_key_usage(paid)
    if resp.status_code == 429:
        if is_daily_quota_429(resp):
            await mark_key_daily_spent(paid, "embed")
            raise RuntimeError(
                "Paid Gemini key daily quota exhausted: " f"{resp.text[:300]}",
            )
        raise RuntimeError(
            f"Paid Gemini key returned 429 (rate-limited): {resp.text[:300]}",
        )
    resp.raise_for_status()
    try:
        data = resp.json()
        partial = [item["values"] for item in data["embeddings"]]
    except Exception as exc:
        raise RuntimeError(
            f"Paid Gemini fallback returned malformed body: " f"{resp.text[:300]}",
        ) from exc
    if len(partial) != len(valid_texts):
        # zip() below would silently truncate and leave zero vectors behind.
        raise RuntimeError(
            f"Paid Gemini fallback returned {len(partial)} embeddings for "
            f"{len(valid_texts)} texts: {resp.text[:300]}",
        )
    if len(valid_texts) == len(texts):
        return partial
    out: list[list[float]] = [[0.0] * dimensions for _ in texts]
    for idx, vec in zip(valid_indices, partial):
        out[idx] = vec
    return out
[docs]
def gemini_embed_paid_fallback_sync(
    texts: list[str],
    *,
    model: str = "google/gemini-embedding-001",
    dimensions: int = EMBED_DIMENSIONS,
    task_type: str | None = None,
) -> list[list[float]]:
    """Sync mirror of :func:`gemini_embed_paid_fallback`.

    Used by the ChromaDB sync embedding paths. Makes a single request with the
    paid key and does not retry.

    Redis writes are best-effort: the usage count and the daily-spent flags
    are written only when ``_sync_redis_client`` already exists, and failures
    are logged at debug level. A daily-quota 429 always adds the key to the
    in-memory ``_spent_keys_embed`` set.

    Empty/whitespace-only inputs are not sent. They receive zero vectors of
    length *dimensions*.

    Raises:
        RuntimeError: No paid key is configured, the paid key 429s, the body
            is malformed, or the number of returned embeddings differs from
            the number of texts sent.
        httpx.HTTPStatusError: Any other non-2xx response.
        httpx.RequestError: Transport failures during the POST itself.
    """
    if not texts:
        return []
    valid_indices: list[int] = []
    valid_texts: list[str] = []
    for i, t in enumerate(texts):
        if t and t.strip():
            valid_indices.append(i)
            valid_texts.append(t)
    if not valid_texts:
        return [[0.0] * dimensions for _ in texts]
    paid = get_paid_fallback_key()
    if not paid:
        raise RuntimeError(
            "Paid Gemini fallback unavailable: no key set "
            "(set GEMINI_EMBED_PAID_KEY)",
        )
    url, payload = _build_paid_embed_request(
        valid_texts,
        model=model,
        dimensions=dimensions,
        task_type=task_type,
        paid_key=paid,
    )
    with httpx.Client(timeout=_PAID_FALLBACK_TIMEOUT) as client:
        resp = client.post(url, json=payload)
    suffix = _key_suffix(paid)
    if _sync_redis_client is not None:
        # Per-key usage counter, expiring at midnight Pacific.
        try:
            pipe = _sync_redis_client.pipeline(transaction=False)
            pipe.incr(f"{_REDIS_KEY_PREFIX}:{suffix}:count")
            pipe.expire(
                f"{_REDIS_KEY_PREFIX}:{suffix}:count",
                _seconds_until_midnight_pt(),
            )
            pipe.execute()
        except Exception:
            logger.debug(
                "Failed to record paid Gemini key usage (sync)",
                exc_info=True,
            )
    if resp.status_code == 429:
        if is_daily_quota_429(resp):
            with _spent_keys_lock:
                _spent_keys_embed.add(paid)
            if _sync_redis_client is not None:
                try:
                    ttl = _seconds_until_midnight_pt()
                    pipe = _sync_redis_client.pipeline(transaction=False)
                    pipe.set(
                        f"{_REDIS_KEY_PREFIX}:{suffix}:spent:embed",
                        "1",
                        ex=ttl,
                    )
                    pipe.set(
                        f"{_REDIS_KEY_PREFIX}:{suffix}:spent_at:embed",
                        datetime.now(timezone.utc).isoformat(),
                        ex=ttl,
                    )
                    pipe.execute()
                except Exception:
                    logger.debug(
                        "Failed to persist paid spent flag (sync)",
                        exc_info=True,
                    )
            raise RuntimeError(
                "Paid Gemini key daily quota exhausted (sync): " f"{resp.text[:300]}",
            )
        raise RuntimeError(
            f"Paid Gemini key returned 429 (sync, rate-limited): " f"{resp.text[:300]}",
        )
    resp.raise_for_status()
    try:
        data = resp.json()
        partial = [item["values"] for item in data["embeddings"]]
    except Exception as exc:
        raise RuntimeError(
            f"Paid Gemini fallback returned malformed body (sync): "
            f"{resp.text[:300]}",
        ) from exc
    if len(partial) != len(valid_texts):
        # zip() below would silently truncate and leave zero vectors behind.
        raise RuntimeError(
            f"Paid Gemini fallback returned {len(partial)} embeddings for "
            f"{len(valid_texts)} texts (sync): {resp.text[:300]}",
        )
    if len(valid_texts) == len(texts):
        return partial
    out: list[list[float]] = [[0.0] * dimensions for _ in texts]
    for idx, vec in zip(valid_indices, partial):
        out[idx] = vec
    return out
# ---------------------------------------------------------------------------
# "OpenRouter-only" mode (set when free pool 429s; OpenRouter is preferred,
# but callers must still try the paid key as a last resort if OR fails)
# ---------------------------------------------------------------------------
_OPENROUTER_ONLY_KEY = "embed:openrouter_only"
_OPENROUTER_ONLY_TTL = 30 * 60 # 30 minutes (short blast-radius circuit breaker)
_openrouter_only: bool = False
_sync_redis_client: Any | None = None
[docs]
def is_openrouter_only() -> bool:
    """Report the process-local OpenRouter-only flag without any I/O.

    Returns the ``_openrouter_only`` value as last set by
    :func:`set_openrouter_only`, cleared by :func:`clear_openrouter_only`, or
    refreshed by the ``check_openrouter_only*`` helpers. Redis is never
    consulted. As a result, this cannot observe the flag's TTL expiry or
    activations made by other processes; use :func:`check_openrouter_only`
    for that.

    Returns:
        bool: ``True`` if OpenRouter-only mode is flagged in this process.
    """
    pinned = _openrouter_only
    return pinned
[docs]
async def check_openrouter_only() -> bool:
    """Report whether OpenRouter-only mode is active, refreshing from Redis.

    When the async Redis client is wired (``init_quota_tracking``), the
    in-memory flag is overwritten with whether ``_OPENROUTER_ONLY_KEY``
    exists. The mode therefore ends when the key expires or is deleted.

    When Redis is not wired, or the read fails, the in-memory flag set by
    :func:`set_openrouter_only` is returned unchanged. It is deliberately not
    cleared, because standalone scripts often skip ``init_quota_tracking``
    but still flip the flag.
    """
    global _openrouter_only
    redis_conn = _redis_client
    if redis_conn is not None:
        try:
            flag_value = await redis_conn.get(_OPENROUTER_ONLY_KEY)
        except Exception:
            logger.debug("Failed to check openrouter_only flag", exc_info=True)
        else:
            _openrouter_only = flag_value is not None
            if _openrouter_only:
                logger.debug("OpenRouter-only flag detected in Redis")
    return _openrouter_only
[docs]
def check_openrouter_only_sync() -> bool:
    """Sync variant of :func:`check_openrouter_only`; reads Redis on each call.

    Used by SyncOpenRouterEmbeddings (ChromaDB path) before each embedding
    call. On first use it lazily creates the module-level
    ``_sync_redis_client`` from ``Config``.

    When the ``_OPENROUTER_ONLY_KEY`` Redis key can be read, the in-memory
    flag is refreshed from it. When Redis is not configured, cannot be
    created, or the read fails, the current in-memory flag is returned
    unchanged.

    The "flag detected" message is logged at debug level, matching the async
    variant. This function runs before every embedding call, so an INFO line
    would repeat for every call throughout the pinned window.

    Returns:
        bool: Whether OpenRouter-only mode is active.
    """
    global _openrouter_only, _sync_redis_client
    if _sync_redis_client is None:
        try:
            from config import Config

            cfg = Config.load()
            if not cfg.redis_url:
                return _openrouter_only
            ssl_kwargs = cfg.redis_connection_kwargs_for_url(cfg.redis_url)
            import redis

            _sync_redis_client = redis.Redis.from_url(
                cfg.redis_url, decode_responses=True, **ssl_kwargs
            )
        except Exception:
            logger.debug(
                "Failed to create sync Redis for openrouter_only", exc_info=True
            )
            return _openrouter_only
    try:
        val = _sync_redis_client.get(_OPENROUTER_ONLY_KEY)
    except Exception:
        logger.debug("Failed to check openrouter_only flag (sync)", exc_info=True)
        return _openrouter_only
    _openrouter_only = val is not None
    if _openrouter_only:
        logger.debug("OpenRouter-only flag detected in Redis (sync)")
    return _openrouter_only
[docs]
async def set_openrouter_only() -> None:
    """Turn on OpenRouter-only mode for embedding calls.

    Flips the process-local ``_openrouter_only`` flag immediately. When the
    async Redis client is wired, it also stores ``_OPENROUTER_ONLY_KEY`` with
    a TTL of ``_OPENROUTER_ONLY_TTL`` seconds (30 minutes). That lets
    :func:`check_openrouter_only` see the mode and lets the mode lapse on its
    own. Redis failures are logged at debug level and otherwise ignored.
    """
    global _openrouter_only
    _openrouter_only = True
    logger.warning(
        "OpenRouter-only mode ACTIVATED for embeddings (TTL=%ds)",
        _OPENROUTER_ONLY_TTL,
    )
    client = _redis_client
    if client is not None:
        try:
            await client.set(_OPENROUTER_ONLY_KEY, "1", ex=_OPENROUTER_ONLY_TTL)
        except Exception:
            logger.debug("Failed to persist openrouter_only flag", exc_info=True)
[docs]
async def clear_openrouter_only() -> None:
    """Turn off OpenRouter-only mode so the free Gemini pool is primary again.

    Resets the process-local ``_openrouter_only`` flag. When the async Redis
    client is wired, it also deletes ``_OPENROUTER_ONLY_KEY``. Redis errors
    are logged at debug level and swallowed.

    Counterpart to :func:`set_openrouter_only`. Used by
    :func:`_auto_clear_openrouter_only_on_parse_error`, and by the
    embedding-refresh jobs ``classifiers/update_tool_embeddings.py`` and
    ``classifiers/update_changed_tool_embeddings.py`` at the end of a run.
    """
    global _openrouter_only
    _openrouter_only = False
    logger.info("OpenRouter-only mode CLEARED for embeddings")
    client = _redis_client
    if client is not None:
        try:
            await client.delete(_OPENROUTER_ONLY_KEY)
        except Exception:
            logger.debug("Failed to delete openrouter_only flag", exc_info=True)
[docs]
def clear_openrouter_only_sync() -> None:
    """Blocking counterpart of :func:`clear_openrouter_only`.

    Resets the process-local flag, then deletes ``_OPENROUTER_ONLY_KEY``
    through the module-level ``_sync_redis_client``. If no client exists yet,
    one is created from ``Config``; the method returns quietly when no Redis
    URL is configured. Client-creation and delete failures are logged at
    debug level and otherwise ignored.
    """
    global _openrouter_only, _sync_redis_client
    _openrouter_only = False
    logger.info("OpenRouter-only mode CLEARED for embeddings (sync)")
    if _sync_redis_client is None:
        try:
            from config import Config

            settings = Config.load()
            redis_url = settings.redis_url
            if not redis_url:
                return
            conn_kwargs = settings.redis_connection_kwargs_for_url(redis_url)
            import redis

            _sync_redis_client = redis.Redis.from_url(
                redis_url,
                decode_responses=True,
                **conn_kwargs,
            )
        except Exception:
            logger.debug(
                "Failed to create sync Redis for openrouter_only clear",
                exc_info=True,
            )
            return
    try:
        _sync_redis_client.delete(_OPENROUTER_ONLY_KEY)
    except Exception:
        logger.debug(
            "Failed to delete openrouter_only flag (sync)",
            exc_info=True,
        )
async def _auto_clear_openrouter_only_on_parse_error(
    parse_err: "OpenRouterEmbedParseError",
) -> None:
    """Unpin OpenRouter-only mode after an OpenRouter parse error.

    The flag is set because Gemini direct 429'd. If OpenRouter now also fails
    with a 200-with-error-body, staying pinned keeps destroying vectors.
    Clearing it lets the next call retry the native Gemini path; if that is
    still broken, the existing 429-escalation re-sets the flag.

    When the flag is set, this clears it via :func:`clear_openrouter_only`
    and publishes an ``openrouter_only_cleared`` event. Publish failures are
    logged at debug level. It does nothing when the flag is not set.

    Args:
        parse_err: The triggering parse error, summarized into the event detail.
    """
    if _openrouter_only:
        logger.warning(
            "OpenRouter parse-error while pinned to OpenRouter-only — auto-clearing "
            "flag to allow Gemini-direct retry on next call",
        )
        await clear_openrouter_only()
        try:
            await publish_http_error_event(
                http_service="openrouter_embed_batch",
                http_status=0,
                endpoint=_OPENROUTER_EMBED_URL[:120],
                detail=(
                    "auto-cleared openrouter_only after parse error: "
                    f"{str(parse_err)[:300]}"
                ),
                error_kind="openrouter_only_cleared",
            )
        except Exception:
            logger.debug(
                "Failed to publish openrouter_only_cleared event", exc_info=True
            )
def _auto_clear_openrouter_only_on_parse_error_sync(
    parse_err: "OpenRouterEmbedParseError",
) -> None:
    """Blocking mirror of :func:`_auto_clear_openrouter_only_on_parse_error`.

    If OpenRouter parse-errors while the pool is pinned to OpenRouter-only,
    this clears the flag via :func:`clear_openrouter_only_sync` so the next
    call can retry Gemini-direct. It then reports an
    ``openrouter_only_cleared`` event through :func:`_publish_http_error_sync`.
    Publish failures are logged at debug level. It does nothing when the flag
    is not set.

    Called by :func:`openrouter_embed_batch_sync` on both of its parse-error
    branches.

    Args:
        parse_err: The triggering parse error, summarized into the event detail.
    """
    if _openrouter_only:
        logger.warning(
            "OpenRouter parse-error while pinned to OpenRouter-only (sync) — "
            "auto-clearing flag to allow Gemini-direct retry on next call",
        )
        clear_openrouter_only_sync()
        try:
            _publish_http_error_sync(
                http_service="openrouter_embed_batch_sync",
                http_status=0,
                endpoint=_OPENROUTER_EMBED_URL[:120],
                detail=(
                    "auto-cleared openrouter_only after parse error: "
                    f"{str(parse_err)[:300]}"
                ),
                error_kind="openrouter_only_cleared",
            )
        except Exception:
            logger.debug(
                "Failed to publish openrouter_only_cleared event (sync)",
                exc_info=True,
            )
# Free-pool retry policy used by embed_batch_via_gemini.
# Rate limiting (429) plus transient gateway/server errors are retried.
_RETRIABLE_STATUSES = {429} | {500, 502, 503, 504}
_POOL_MAX_RETRIES = 12  # attempts per chunk
_POOL_RETRY_BASE = 1.0  # seconds; first backoff delay, doubled per attempt
_POOL_RETRY_CAP = 8.0  # seconds; upper bound on a single backoff delay
async def _openrouter_then_paid_fallback(
    batch_texts: list[str],
    *,
    model: str,
    dimensions: int = EMBED_DIMENSIONS,
    task_type: str | None = None,
    pin_openrouter_only: bool,
) -> list[list[float]] | None:
    """Try OpenRouter first, then the paid Gemini key.

    Returns the resolved vectors on success, or ``None`` if both fallbacks
    failed, so the caller can keep retrying its primary path. Failures are
    logged at warning level with tracebacks and are never raised.

    When *pin_openrouter_only* is True and OpenRouter succeeds, OpenRouter-only
    mode is activated via :func:`set_openrouter_only`. Subsequent calls then
    skip the free pool for the remainder of its TTL.

    Pinning runs outside the OpenRouter ``try``. A pinning failure therefore
    cannot discard vectors OpenRouter already returned, be misreported as an
    OpenRouter failure, or spend paid-key quota needlessly.

    Args:
        batch_texts: Texts to embed.
        model: Model id used for both OpenRouter and the paid Gemini call.
        dimensions: Output embedding size for both fallbacks.
        task_type: Optional Gemini task type; only the paid Gemini call uses it.
        pin_openrouter_only: Whether an OpenRouter success should pin
            OpenRouter-only mode.
    """
    try:
        or_vecs = await openrouter_embed_batch(
            batch_texts,
            model=model,
            dimensions=dimensions,
        )
    except Exception:
        logger.warning(
            "OpenRouter embed fallback failed — trying paid Gemini key",
            exc_info=True,
        )
    else:
        if pin_openrouter_only:
            try:
                await set_openrouter_only()
            except Exception:
                logger.warning(
                    "Failed to pin openrouter_only after OpenRouter success",
                    exc_info=True,
                )
        return or_vecs
    try:
        return await gemini_embed_paid_fallback(
            batch_texts,
            model=model,
            dimensions=dimensions,
            task_type=task_type,
        )
    except Exception:
        logger.warning(
            "Paid Gemini key fallback also failed",
            exc_info=True,
        )
        return None
[docs]
async def embed_batch_via_gemini(
    texts: list[str],
    model: str = "google/gemini-embedding-001",
    *,
    chunk_size: int = 50,
) -> list[list[float]]:
    """Embed a batch of texts via the native Gemini API using the shared key pool.

    Returns one embedding vector per input text, in order. Empty or
    whitespace-only texts are replaced with zero vectors. Non-blank texts are
    sent in chunks of *chunk_size*. Each chunk is retried up to
    ``_POOL_MAX_RETRIES`` times on network errors and ``_RETRIABLE_STATUSES``,
    with capped exponential backoff.

    A daily-quota 429 marks the current key daily-spent for embeddings and
    rotates to :func:`next_gemini_embed_key`. The key in use carries over to
    the next chunk, so a key already found spent is not retried (and
    re-counted) at the start of every later chunk.

    Cascade on failure (after ``PAID_KEY_FALLBACK_THRESHOLD`` non-daily 429s
    on the free pool for one chunk):

    1. OpenRouter via :func:`openrouter_embed_batch` (and pin
       ``openrouter_only`` for the next 30 minutes).
    2. Paid tier-3 Gemini key via :func:`gemini_embed_paid_fallback`
       (last resort).

    When ``openrouter_only`` is already pinned (checked up front and again
    before each chunk), OpenRouter is tried first. If it fails on a specific
    batch, the paid Gemini key is tried before raising.

    Args:
        texts: Texts to embed.
        model: Model id, optionally ``google/``-prefixed.
        chunk_size: Maximum number of texts per request; must be >= 1.

    Returns:
        One vector per input text, in input order.

    Raises:
        ValueError: If *chunk_size* is less than 1. Previously a negative
            value silently returned all-zero vectors.
        RuntimeError: A chunk exhausts its retries, the pinned
            OpenRouter/paid path fails, or Gemini returns a different number
            of embeddings than texts sent.
        httpx.HTTPStatusError: For non-retriable HTTP errors from Gemini.
    """
    if not texts:
        return []
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    valid_indices: list[int] = []
    valid_texts: list[str] = []
    for i, t in enumerate(texts):
        if t and t.strip():
            valid_indices.append(i)
            valid_texts.append(t)
    if not valid_texts:
        return [[0.0] * EMBED_DIMENSIONS for _ in texts]

    # Blank inputs keep these zero vectors; each resolved chunk overwrites
    # its own positions.
    result: list[list[float]] = [[0.0] * EMBED_DIMENSIONS for _ in texts]
    chunks = [
        (
            valid_texts[start : start + chunk_size],
            valid_indices[start : start + chunk_size],
        )
        for start in range(0, len(valid_texts), chunk_size)
    ]

    def _fill_result_from(indices: list[int], vecs: list[list[float]]) -> None:
        """Scatter *vecs* into ``result`` at their original input positions."""
        for idx, vec in zip(indices, vecs):
            result[idx] = vec

    async def _embed_chunk_pinned(
        batch_texts: list[str],
        batch_indices: list[int],
        where: str,
    ) -> None:
        """Resolve one chunk via OpenRouter, then the paid key, while pinned.

        *where* is spliced into the error message (``""`` or ``" mid-loop"``).
        """
        vecs = await _openrouter_then_paid_fallback(
            batch_texts,
            model=model,
            pin_openrouter_only=False,  # already pinned
        )
        if vecs is None:
            raise RuntimeError(
                f"embed_batch_via_gemini: openrouter_only pinned{where} but "
                "OpenRouter and paid Gemini key both failed",
            )
        _fill_result_from(batch_indices, vecs)

    if await check_openrouter_only():
        logger.info(
            "OpenRouter-only mode — bypassing Gemini for %d texts",
            len(valid_texts),
        )
        for batch_texts, batch_indices in chunks:
            await _embed_chunk_pinned(batch_texts, batch_indices, "")
        return result

    gemini_model = _gemini_model_name(model)

    def _url_for(key: str) -> str:
        """Build the keyed ``batchEmbedContents`` URL for *key*."""
        return f"{GEMINI_EMBED_BASE}/{gemini_model}:batchEmbedContents?key={key}"

    # Carried across chunks (not reset per chunk): once a key has been
    # rotated away from, later chunks start from its replacement.
    cur_key = next_gemini_embed_key()
    async with httpx.AsyncClient(timeout=60.0) as client:
        for batch_texts, batch_indices in chunks:
            if await check_openrouter_only():
                await _embed_chunk_pinned(batch_texts, batch_indices, " mid-loop")
                continue
            payload = {
                "requests": [
                    {
                        "model": f"models/{gemini_model}",
                        "content": {"parts": [{"text": t}]},
                        "output_dimensionality": EMBED_DIMENSIONS,
                    }
                    for t in batch_texts
                ],
            }
            attempt = 0
            consecutive_429 = 0
            tried_or_then_paid = False
            while attempt < _POOL_MAX_RETRIES:
                if attempt > 0:
                    await asyncio.sleep(
                        min(
                            _POOL_RETRY_BASE * (2 ** (attempt - 1)),
                            _POOL_RETRY_CAP,
                        ),
                    )
                try:
                    resp = await client.post(_url_for(cur_key), json=payload)
                except Exception as exc:
                    # Log only the exception type: the request URL carries
                    # the API key and may appear in exception text.
                    logger.warning(
                        "Gemini embed request error (attempt %d/%d): %s",
                        attempt + 1,
                        _POOL_MAX_RETRIES,
                        type(exc).__name__,
                    )
                    attempt += 1
                    continue
                await record_key_usage(cur_key)
                if resp.status_code == 429:
                    if is_daily_quota_429(resp):
                        await mark_key_daily_spent(cur_key, "embed")
                        cur_key = next_gemini_embed_key()
                        attempt += 1
                        continue
                    consecutive_429 += 1
                    if (
                        consecutive_429 >= PAID_KEY_FALLBACK_THRESHOLD
                        and not tried_or_then_paid
                    ):
                        logger.warning(
                            "Free Gemini pool 429'd %d times — escalating to "
                            "OpenRouter, then paid key",
                            consecutive_429,
                        )
                        tried_or_then_paid = True
                        vecs = await _openrouter_then_paid_fallback(
                            batch_texts,
                            model=model,
                            pin_openrouter_only=True,
                        )
                        if vecs is not None:
                            _fill_result_from(batch_indices, vecs)
                            break
                        # Both fallbacks failed; keep retrying the free pool
                        # for the remaining attempts on this chunk.
                    attempt += 1
                    continue
                if resp.status_code in _RETRIABLE_STATUSES:
                    attempt += 1
                    continue
                resp.raise_for_status()
                embeddings = resp.json().get("embeddings", [])
                if len(embeddings) != len(batch_indices):
                    # Refuse to leave silent zero vectors for missing items.
                    raise RuntimeError(
                        f"embed_batch_via_gemini: Gemini returned "
                        f"{len(embeddings)} embeddings for "
                        f"{len(batch_indices)} texts",
                    )
                for idx, item in zip(batch_indices, embeddings):
                    result[idx] = item.get("values", [])
                break
            else:
                if tried_or_then_paid:
                    cascade = "free pool, OpenRouter, and paid Gemini key all failed"
                else:
                    cascade = (
                        "free pool failed; OpenRouter/paid fallback not "
                        "triggered (fewer than PAID_KEY_FALLBACK_THRESHOLD "
                        "non-daily 429s)"
                    )
                raise RuntimeError(
                    f"embed_batch_via_gemini failed after "
                    f"{_POOL_MAX_RETRIES} attempts ({cascade})",
                )
    return result
[docs]
async def batch_check_keys_usage(
    redis_client: Any, api_keys: list[str]
) -> dict[str, int]:
    """Fetch the daily usage counts for many API keys in one Redis ``MGET``.

    Reads ``spent:api_key:<suffix>`` for every supplied key in a single round
    trip. The suffix is the key's last 8 characters, so the full key never
    appears in the Redis key name. Each full key is mapped to its integer
    count, with missing entries defaulting to 0.

    This is a different key namespace from the ``gemini_key_daily_usage:*``
    counters written by :func:`record_key_usage`, so it reflects a separately
    maintained tally. No production callers were found in the repo; it is
    exercised by ``tests/test_context_whitelisting.py``.

    Args:
        redis_client: An async Redis client supporting ``mget``.
        api_keys: The full API keys to look up. An empty list returns ``{}``
            without contacting Redis, since ``MGET`` requires at least one key.

    Returns:
        dict[str, int]: A mapping from each input API key to its usage count.
    """
    if not api_keys:
        return {}
    redis_keys = [f"spent:api_key:{k[-8:]}" for k in api_keys]
    values = await redis_client.mget(redis_keys)
    return {
        api_key: int(value) if value is not None else 0
        for api_key, value in zip(api_keys, values)
    }