"""Async LLM API client with automatic tool-call loop.
Uses any OpenAI-compatible chat-completions endpoint (configurable via
``base_url``). Embeddings use the native Google Gemini API only, via the
shared key pool in gemini_embed_pool.
"""
from __future__ import annotations
import asyncio
import json
import logging
import re
import time
import uuid
from collections import Counter
from typing import Any, Awaitable, Callable, TYPE_CHECKING
import os
import httpx
from gemini_embed_pool import (
EMBED_DIMENSIONS,
GEMINI_EMBED_BASE,
PAID_KEY_FALLBACK_THRESHOLD,
check_openrouter_only,
get_paid_fallback_key,
is_daily_quota_429,
mark_key_daily_spent,
next_gemini_embed_key,
openrouter_embed_batch,
record_key_usage,
set_openrouter_only,
)
from platforms.media_common import media_to_content_parts
from tools import ToolRegistry
if TYPE_CHECKING:
from tool_context import ToolContext
logger = logging.getLogger(__name__)
# Default Gemini model id for countTokens (Google Generative Language API).
DEFAULT_GEMINI_COUNT_TOKENS_MODEL = "gemini-3.1-flash-lite-preview"
# Immediate re-POST attempts when the proxy returns HTTP 499 (client closed request).
MAX_499_RETRIES = 3
# Extra generation attempts when a response lacks the required header.
# NOTE(review): not referenced in this file's visible code — confirm callers use it.
MAX_HEADER_RETRIES = 1
# Hard cap on the number of tool schemas attached to one chat-completions request.
MAX_TOOLS_PER_REQUEST = 250
# Embedding retry back-off: base delay (seconds), per-attempt cap, and max rounds.
EMBED_RETRY_BASE_DELAY = 1.0
MAX_EMBED_DELAY = 8.0
MAX_EMBED_RETRIES = 14
# HTTP status codes that are transient and should be retried
_RETRIABLE_STATUSES = {429, 500, 502, 503, 504}
# Ephemeral system message injected when header validation fails (runtime string).
_HEADER_REGEN_MESSAGE = (
    "Your previous response did not comply with the system prompt's rules for "
    "header generation. Responses MUST begin with a response header bracket "
    "[`model` :: emojis :: thought :: `tools`]. Regenerate your response now, "
    "starting with the required header."
)
# Decode large JSON bodies off the event loop (chat completions, big embed batches).
_JSON_BODY_DECODE_THRESHOLD = 256 * 1024
def _json_loads_utf8(body: bytes) -> Any:
return json.loads(body.decode("utf-8"))
async def _async_response_json(resp: httpx.Response) -> Any:
    """Read the response body and parse it as JSON.

    Bodies at or above ``_JSON_BODY_DECODE_THRESHOLD`` bytes are parsed
    in a worker thread so the event loop stays responsive.
    """
    body = await resp.aread()
    if len(body) < _JSON_BODY_DECODE_THRESHOLD:
        return json.loads(body.decode("utf-8"))
    return await asyncio.to_thread(_json_loads_utf8, body)
[docs]
class OpenRouterClient:
"""Thin async wrapper around an OpenAI-compatible chat-completions endpoint.
Implements the full tool-call loop: when the model responds with
``tool_calls``, each tool is executed via the :class:`ToolRegistry`,
the results are appended, and the API is called again. This repeats
until the model produces a final text response (or a safety limit is
reached).
"""
[docs]
def __init__(
    self,
    api_key: str,
    model: str = "x-ai/grok-4.1-fast",
    temperature: float = 1.0,
    max_tokens: int = 60_000,
    tool_registry: ToolRegistry | None = None,
    max_tool_rounds: int = 10,
    base_url: str = "http://localhost:3000/openai",
    gemini_api_key: str = "",
    gemini_count_tokens_model: str = DEFAULT_GEMINI_COUNT_TOKENS_MODEL,
    max_tool_output_chars: int = 150_000,
) -> None:
    """Create a client bound to one chat endpoint and one model.

    Args:
        api_key (str): Bearer token for the chat-completions endpoint.
        model (str): Model identifier sent with every request.
        temperature (float): Sampling temperature.
        max_tokens (int): Completion token budget per request.
        tool_registry (ToolRegistry | None): Registry of callable tools;
            a fresh empty registry is created when omitted.
        max_tool_rounds (int): Safety cap on tool-call rounds per chat.
        base_url (str): OpenAI-compatible API root; ``/chat/completions``
            is appended.
        gemini_api_key (str): Key for the Gemini countTokens API; resolved
            from environment/config when empty.
        gemini_count_tokens_model (str): Model id for Gemini countTokens API.
        max_tool_output_chars (int): Max tool result string length before
            truncation in the chat loop; ``<= 0`` disables the cap.
    """
    # Generation settings.
    self.model = model
    self.temperature = temperature
    self.max_tokens = max_tokens
    # Credentials and token-counting configuration.
    self.api_key = api_key
    self.gemini_api_key = gemini_api_key or self._resolve_gemini_api_key()
    self.gemini_count_tokens_model = gemini_count_tokens_model
    # Tool-loop configuration.
    self.tool_registry = tool_registry or ToolRegistry()
    self.max_tool_rounds = max_tool_rounds
    self.max_tool_output_chars = max_tool_output_chars
    # One shared connection pool serves chat, countTokens, and embeddings.
    self._chat_url = base_url.rstrip("/") + "/chat/completions"
    self._http = httpx.AsyncClient(
        timeout=httpx.Timeout(300.0, connect=10.0),
    )
@staticmethod
def _resolve_gemini_api_key() -> str:
"""Return the Gemini API key from the environment or config file."""
key = os.getenv("GEMINI_API_KEY", "")
if key:
return key
try:
from config import Config
cfg = Config.load()
return cfg.gemini_api_key or ""
except Exception:
pass
return ""
# ------------------------------------------------------------------
# Token counting via Gemini API
# ------------------------------------------------------------------
async def _count_tokens(
    self,
    messages: list[dict[str, Any]],
    *,
    gemini_model: str | None = None,
) -> int | None:
    """Count input tokens using the Gemini countTokens API.

    Converts OpenAI-format messages to Gemini ``contents`` format
    and calls the free countTokens endpoint. Returns the total
    token count or ``None`` on any failure (fail-open).

    Args:
        messages: Conversation in OpenAI chat format.
        gemini_model: Override for the countTokens model id; defaults
            to ``self.gemini_count_tokens_model``.

    Returns:
        Total token count, or ``None`` when no key is configured, no
        textual content is present, or the HTTP call fails.
    """
    if not self.gemini_api_key:
        return None
    model_id = gemini_model or self.gemini_count_tokens_model
    # Convert OpenAI messages to Gemini contents format
    contents = []
    for msg in messages:
        role = msg.get("role", "user")
        content = msg.get("content", "")
        if not content:
            continue
        # Map OpenAI roles to Gemini roles.
        # NOTE(review): system prompts are counted as "model" turns here —
        # confirm this is the intended accounting for system content.
        if role in ("system", "assistant"):
            gemini_role = "model"
        else:
            gemini_role = "user"
        if isinstance(content, list):
            # Multi-part content: extract text parts. Non-text parts
            # (images/audio) are skipped, so multimodal counts are
            # approximate.
            text = " ".join(
                p.get("text", "") for p in content
                if p.get("type") == "text"
            )
            if not text:
                continue
            contents.append({
                "role": gemini_role,
                "parts": [{"text": text}],
            })
        else:
            contents.append({
                "role": gemini_role,
                "parts": [{"text": str(content)}],
            })
    if not contents:
        return None
    # Ensure alternating roles (Gemini requirement)
    deduped: list[dict] = [contents[0]]
    for c in contents[1:]:
        if c["role"] == deduped[-1]["role"]:
            # Merge into previous message
            deduped[-1]["parts"].extend(c["parts"])
        else:
            deduped.append(c)
    url = (
        f"https://generativelanguage.googleapis.com/v1beta/"
        f"models/{model_id}:countTokens"
        f"?key={self.gemini_api_key}"
    )
    payload = {"contents": deduped}
    try:
        # Short timeout: counting is best-effort and must not delay chat.
        resp = await self._http.post(
            url, json=payload, timeout=5.0,
        )
        if resp.status_code == 200:
            data = await _async_response_json(resp)
            count = data.get("totalTokens")
            if isinstance(count, int):
                logger.info(
                    "Gemini countTokens: %d tokens", count,
                )
                return count
        else:
            logger.debug(
                "countTokens returned %d: %s",
                resp.status_code,
                resp.text[:200],
            )
    except Exception:
        logger.debug(
            "countTokens call failed", exc_info=True,
        )
    return None
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
[docs]
async def chat(
    self,
    messages: list[dict[str, Any]],
    user_id: str = "",
    ctx: ToolContext | None = None,
    tool_names: list[str] | None = None,
    validate_header: bool = False,
    token_count: int | None = None,
    on_intermediate_text: Callable[[str], Awaitable[None]] | None = None,
) -> str:
    """Send *messages* to the LLM and return the final assistant text.

    If the model requests tool calls, they are executed automatically and
    the conversation is continued until a text response is produced.

    *user_id* is forwarded to :meth:`ToolRegistry.call` for permission
    checking. *ctx* is forwarded to tools that opt-in to receiving it.

    *tool_names*, when provided, restricts which tools the LLM sees
    to the given subset of registered tool names.

    *token_count*, when provided, is used directly instead of calling
    ``_count_tokens`` — allows the caller to pre-compute the count
    concurrently with other work.

    *on_intermediate_text*, when provided, is called with any text
    content the model produces alongside tool calls. Without this
    callback, such text is silently carried in the conversation
    history but never surfaced to the user.
    """
    # Shallow-copy each message so the placeholder substitution below
    # does not mutate the caller's dicts.
    msgs = [dict(m) for m in messages]
    call_history: list[tuple[str, str]] = []  # (name, args_hash)
    total_calls = 0
    # Count input tokens — use pre-computed value when available,
    # otherwise call Gemini API (serial fallback).
    if token_count is None:
        token_count = await self._count_tokens(msgs)
    _count_str = str(token_count) if token_count is not None else "unavailable"
    # Substitute the count into the first system message carrying the
    # __INPUT_TOKEN_COUNT__ placeholder (first match only).
    for _m in msgs:
        if _m.get("role") == "system" and "__INPUT_TOKEN_COUNT__" in str(_m.get("content", "")):
            _m["content"] = _m["content"].replace(
                "__INPUT_TOKEN_COUNT__", _count_str,
            )
            break
    round_num = 0
    while round_num < self.max_tool_rounds:
        logger.info(
            "LLM API call: round=%d, model=%s, messages=%d",
            round_num + 1, self.model, len(msgs),
        )
        _t0 = time.monotonic()
        response_message = await self._call_api(
            msgs, tool_names=tool_names, user_id=user_id,
        )
        response_message = self._normalize_assistant_tool_message(
            response_message,
        )
        logger.info(
            "LLM API call round=%d completed in %.0f ms",
            round_num + 1, (time.monotonic() - _t0) * 1000,
        )
        tool_calls = response_message.get("tool_calls")
        # Tool rounds (including legacy function_call -> tool_calls) carry
        # optional narration in content; never run header validation on that.
        if not tool_calls:
            content = response_message.get("content") or ""
            if content.startswith("[PROXY ERROR]"):
                logger.warning("Proxy returned an error response:\n%s", content)
                return content
            # --- Header validation: final text only (not tool-round narration) ---
            if validate_header and content and not self._validate_header(content):
                content = await self._regenerate_with_header_hint(
                    msgs, content, tool_names=None, user_id=user_id,
                )
            return content
        # --- Emit intermediate text alongside tool calls ------
        intermediate_text = (response_message.get("content") or "").strip()
        if intermediate_text and on_intermediate_text:
            logger.debug(
                "Invoking on_intermediate_text (%d raw chars, %d tool call(s))",
                len(intermediate_text), len(tool_calls),
            )
            try:
                await on_intermediate_text(intermediate_text)
            except Exception:
                # Callback failures must never abort the tool loop.
                logger.warning(
                    "on_intermediate_text callback failed",
                    exc_info=True,
                )
        msgs.append(response_message)
        # --- Loop / excessive-call detection --------------------------
        round_keys: list[tuple[str, str]] = []
        for tc in tool_calls:
            fn = tc["function"]
            name = fn["name"]
            args_str = fn.get("arguments", "")
            round_keys.append((name, args_str))
        call_history.extend(round_keys)
        total_calls += len(tool_calls)
        if total_calls > 200:
            logger.warning("Excessive tool calls (%d), aborting", total_calls)
            return "(Stopped: too many tool calls.)"
        # Tools flagged repeat-allowed are exempt from duplicate-call
        # loop detection.
        exempt = self.tool_registry.repeat_allowed_tools()
        filtered_history = [
            (n, a) for n, a in call_history if n not in exempt
        ]
        if self._detect_loop(filtered_history):
            logger.warning("Repetitive tool-call loop detected, aborting")
            return "(Stopped: repetitive tool-call loop detected.)"
        # --- Execute tool calls in parallel ---------------------------
        t0 = time.monotonic()
        # Remember how many files were queued before this round so media
        # produced by these tools can be attached below.
        sent_files_before = len(ctx.sent_files) if ctx else 0
        async def _run_one(tc: dict) -> tuple[str, str]:
            """Execute a single tool call.

            Args:
                tc (dict): One ``tool_calls`` entry from the response.

            Returns:
                tuple[str, str]: The call id and the tool's string result.
            """
            fn = tc["function"]
            tool_name = fn["name"]
            try:
                arguments = json.loads(fn["arguments"])
                # Strip stray quote characters around argument keys
                # (defensive against malformed model output).
                arguments = {
                    k.strip("'\""): v for k, v in arguments.items()
                }
            except (json.JSONDecodeError, TypeError):
                arguments = {}
            logger.debug(
                "Calling tool %s with %s", tool_name, arguments,
            )
            result = await self.tool_registry.call(
                tool_name, arguments, user_id=user_id, ctx=ctx,
            )
            return tc["id"], result
        results = await asyncio.gather(
            *[_run_one(tc) for tc in tool_calls],
        )
        elapsed_ms = (time.monotonic() - t0) * 1000
        logger.info(
            "Round %d: executed %d tool(s) in %.0f ms",
            round_num + 1, len(tool_calls), elapsed_ms,
        )
        # Files the tools queued for sending during this round only.
        new_files = (
            ctx.sent_files[sent_files_before:]
            if ctx and sent_files_before < len(ctx.sent_files)
            else []
        )
        for idx, (call_id, result) in enumerate(results):
            remaining = self.max_tool_rounds - round_num - 1
            # Cap tool output to avoid blowing past API request limits
            cap = self.max_tool_output_chars
            if cap > 0 and len(result) > cap:
                _orig_len = len(result)
                result = (
                    result[:cap]
                    + f"\n\n[TRUNCATED β output was {_orig_len:,} chars, "
                    f"showing first {cap:,}]"
                )
                logger.warning(
                    "Truncated tool output from %d to %d chars",
                    _orig_len, cap,
                )
            text_content = f"[Remaining tool rounds: {remaining}]\n{result}"
            # Inline multimodal: attach sent_files media to first tool result
            # so the model sees screenshots/audio/video during the same turn.
            if idx == 0 and new_files:
                content_parts: list[dict[str, Any]] = [
                    {"type": "text", "text": text_content},
                ]
                for sf in new_files:
                    sf_parts = await media_to_content_parts(
                        sf["data"], sf["mimetype"], sf["filename"],
                    )
                    content_parts.extend(sf_parts)
                msgs.append({
                    "role": "tool",
                    "tool_call_id": call_id,
                    "content": content_parts,
                })
            else:
                msgs.append({
                    "role": "tool",
                    "tool_call_id": call_id,
                    "content": text_content,
                })
        # --- Merge dynamically injected tools -------------------------
        if (
            ctx is not None
            and ctx.injected_tools
            and tool_names is not None
        ):
            existing = set(tool_names)
            new_tools = [
                t for t in ctx.injected_tools if t not in existing
            ]
            if new_tools:
                # Mutate tool_names in place so the next _call_api round
                # sees the expanded set.
                tool_names.extend(new_tools)
                logger.info(
                    "Injected %d tool(s) into active set: %s",
                    len(new_tools), new_tools,
                )
            ctx.injected_tools = None  # reset for next round
        round_num += 1
    logger.warning("Reached max tool-call rounds (%d)", self.max_tool_rounds)
    # --- Final round: call LLM with no tools so it can respond --------
    msgs.append({
        "role": "user",
        "content": (
            "[SYSTEM] You have reached the maximum number of allowed tool calls "
            f"({self.max_tool_rounds} rounds) for this cycle. You cannot "
            "make any more tool calls. Please generate your final response "
            "to the user now, summarizing what you accomplished and noting "
            "anything you were unable to complete."
        ),
    })
    logger.info(
        "LLM API call: final no-tools round, model=%s, messages=%d",
        self.model, len(msgs),
    )
    _t0 = time.monotonic()
    try:
        final_response = await self._call_api(
            msgs, tool_names=[], user_id=user_id,
        )
        logger.info(
            "LLM API final no-tools round completed in %.0f ms",
            (time.monotonic() - _t0) * 1000,
        )
        content = final_response.get("content") or "(max tool-call rounds reached)"
    except Exception as exc:
        # Fall back to the last tool-round narration rather than failing.
        logger.error("Final no-tools round failed: %s", exc)
        content = response_message.get("content") or "(max tool-call rounds reached)"
    if content.startswith("[PROXY ERROR]"):
        logger.warning("Proxy returned an error response:\n%s", content)
        return content
    # --- Header validation on final-round response ---
    if validate_header and content and not self._validate_header(content):
        content = await self._regenerate_with_header_hint(
            msgs, content, tool_names=[], user_id=user_id,
        )
    return content
[docs]
async def embed(self, text: str, model: str) -> list[float]:
"""Generate an embedding vector for *text* via the Gemini API.
Uses the shared key pool for rate-limit distribution. Retries with
exponential back-off (capped at ``MAX_EMBED_DELAY``) up to
``MAX_EMBED_RETRIES`` times before raising.
Raises :class:`ValueError` immediately if *text* is empty or
whitespace-only (the embedding API rejects such input with 400).
Parameters
----------
text:
The text to embed.
model:
The embedding model identifier (e.g. ``"google/gemini-embedding-001"``).
Returns
-------
list[float]
The embedding vector.
"""
if not text or not text.strip():
raise ValueError("Cannot embed empty or whitespace-only text")
round_num = 0
last_error: str | None = None
while round_num < MAX_EMBED_RETRIES:
try:
return await self._embed_gemini(text, model)
except ValueError:
raise # non-retriable (bad input), don't retry
except Exception as exc:
last_error = str(exc)
logger.warning(
"Gemini embed failed (round %d): %s",
round_num + 1, exc,
)
delay = min(EMBED_RETRY_BASE_DELAY * (2 ** round_num), MAX_EMBED_DELAY)
await asyncio.sleep(delay)
round_num += 1
raise RuntimeError(
f"Gemini embed failed after {MAX_EMBED_RETRIES} rounds: {last_error}"
)
[docs]
async def embed_batch(
self, texts: list[str], model: str,
) -> list[list[float]]:
"""Generate embedding vectors for multiple texts via the Gemini API.
Uses the shared key pool. Empty or whitespace-only texts are filtered
out; their positions are filled with zero vectors. Retries with
exponential back-off up to ``MAX_EMBED_RETRIES`` times.
Parameters
----------
texts:
List of texts to embed.
model:
The embedding model identifier.
Returns
-------
list[list[float]]
One embedding vector per input text, in the same order.
"""
if not texts:
return []
# Filter out blank texts, keeping track of original indices.
valid_indices: list[int] = []
valid_texts: list[str] = []
for i, t in enumerate(texts):
if t and t.strip():
valid_indices.append(i)
valid_texts.append(t)
if not valid_texts:
logger.warning(
"embed_batch called with %d text(s), all empty/blank β returning zero vectors",
len(texts),
)
return [[0.0] * EMBED_DIMENSIONS for _ in texts]
if len(valid_texts) < len(texts):
logger.info(
"embed_batch: filtered %d blank text(s) from batch of %d",
len(texts) - len(valid_texts), len(texts),
)
round_num = 0
last_error: str | None = None
while round_num < MAX_EMBED_RETRIES:
try:
valid_embeddings = await self._embed_gemini_batch(valid_texts, model)
result: list[list[float]] = [[0.0] * EMBED_DIMENSIONS for _ in texts]
for idx, emb in zip(valid_indices, valid_embeddings):
result[idx] = emb
return result
except ValueError:
raise # non-retriable (bad input), don't retry
except Exception as exc:
last_error = str(exc)
logger.warning(
"Gemini embed_batch failed (round %d): %s",
round_num + 1, exc,
)
delay = min(EMBED_RETRY_BASE_DELAY * (2 ** round_num), MAX_EMBED_DELAY)
await asyncio.sleep(delay)
round_num += 1
raise RuntimeError(
f"Gemini embed_batch failed after {MAX_EMBED_RETRIES} rounds: {last_error}"
)
# ------------------------------------------------------------------
# Embedding provider implementations
# ------------------------------------------------------------------
@staticmethod
def _gemini_model_name(model: str) -> str:
"""Convert OpenRouter model name to Gemini native model name."""
return model.removeprefix("google/")
async def _embed_gemini(
    self, text: str, model: str, task_type: str | None = None,
) -> list[float]:
    """Embed a single text via the native Gemini API.

    Falls back to the paid key after ``PAID_KEY_FALLBACK_THRESHOLD``
    consecutive 429 responses.

    Args:
        text: Text to embed.
        model: Embedding model identifier.
        task_type: Optional Gemini task type (e.g. ``QUESTION_ANSWERING``).

    Returns:
        The embedding vector.

    Raises:
        ValueError: On HTTP 400 (non-retriable bad input).
        RuntimeError: After ``MAX_EMBED_RETRIES`` failed attempts.
    """
    if await check_openrouter_only():
        # Global flag set once Gemini quota was exhausted: bypass Gemini.
        logger.info("OpenRouter-only mode β bypassing Gemini for single text")
        vecs = await openrouter_embed_batch(
            [text], model=model, api_key=self.api_key,
        )
        return vecs[0]
    gemini_model = self._gemini_model_name(model)
    api_key = next_gemini_embed_key()
    url = f"{GEMINI_EMBED_BASE}/{gemini_model}:embedContent?key={api_key}"
    payload: dict = {
        "model": f"models/{gemini_model}",
        "content": {"parts": [{"text": text}]},
        "output_dimensionality": EMBED_DIMENSIONS,
    }
    if task_type:
        payload["taskType"] = task_type
    attempt = 0
    consecutive_429 = 0
    switched_to_paid = False
    while attempt < MAX_EMBED_RETRIES:
        if attempt > 0:
            # Exponential back-off before retries. last_error is always
            # bound here: this branch is unreachable on the first pass.
            delay = min(
                EMBED_RETRY_BASE_DELAY * (2 ** (attempt - 1)),
                MAX_EMBED_DELAY,
            )
            logger.warning(
                "Gemini embed error, retrying in %.1fs (attempt %d): %s",
                delay, attempt + 1, last_error,
            )
            await asyncio.sleep(delay)
        last_error: str | None = None
        try:
            _t0 = time.monotonic()
            resp = await self._http.post(url, json=payload)
            logger.info(
                "Gemini embed HTTP request completed in %.0f ms (status=%d)",
                (time.monotonic() - _t0) * 1000,
                resp.status_code,
            )
            await record_key_usage(api_key)
            if resp.status_code == 400:
                body = resp.text[:500]
                logger.error(
                    "Gemini embed 400 Bad Request (non-retriable): %s",
                    body,
                )
                raise ValueError(
                    f"Gemini embed 400 Bad Request: {body}"
                )
            if resp.status_code == 429:
                if is_daily_quota_429(resp):
                    # Daily quota exhausted: retire this key for the day
                    # and rotate to the next key in the pool.
                    await mark_key_daily_spent(api_key, "embed")
                    api_key = next_gemini_embed_key()
                    url = (
                        f"{GEMINI_EMBED_BASE}/{gemini_model}"
                        f":embedContent?key={api_key}"
                    )
                    last_error = "HTTP 429 (daily quota)"
                    attempt += 1
                    continue
                consecutive_429 += 1
                if (
                    consecutive_429 >= PAID_KEY_FALLBACK_THRESHOLD
                    and not switched_to_paid
                ):
                    paid = get_paid_fallback_key()
                    if paid:
                        logger.warning(
                            "Switching to paid Gemini key after %d "
                            "consecutive 429s",
                            consecutive_429,
                        )
                        api_key = paid
                        url = (
                            f"{GEMINI_EMBED_BASE}/{gemini_model}"
                            f":embedContent?key={paid}"
                        )
                        switched_to_paid = True
                elif switched_to_paid:
                    # Even the paid key is rate-limited: flip the global
                    # OpenRouter-only flag and try that provider once.
                    await set_openrouter_only()
                    try:
                        logger.warning(
                            "Paid Gemini key 429'd β trying OpenRouter",
                        )
                        vecs = await openrouter_embed_batch(
                            [text], model=model,
                            api_key=self.api_key,
                        )
                        return vecs[0]
                    except Exception:
                        logger.warning(
                            "OpenRouter embed fallback also failed",
                            exc_info=True,
                        )
                last_error = "HTTP 429"
                attempt += 1
                continue
            if resp.status_code in _RETRIABLE_STATUSES:
                last_error = f"HTTP {resp.status_code}"
                attempt += 1
                continue
            resp.raise_for_status()
            data = await _async_response_json(resp)
            return data["embedding"]["values"]
        except ValueError:
            raise  # non-retriable, propagate immediately
        except Exception as exc:
            last_error = str(exc)
            attempt += 1
    raise RuntimeError(
        f"Gemini embed failed after {MAX_EMBED_RETRIES} attempts: {last_error}"
    )
async def _embed_gemini_batch(
    self, texts: list[str], model: str, task_type: str | None = None,
) -> list[list[float]]:
    """Embed a batch of texts via the native Gemini API.

    Falls back to the paid key after ``PAID_KEY_FALLBACK_THRESHOLD``
    consecutive 429 responses.

    Args:
        texts: Texts to embed.
        model: Embedding model identifier.
        task_type: Optional Gemini task type (e.g. ``RETRIEVAL_DOCUMENT``).

    Returns:
        One embedding vector per input text, in order.

    Raises:
        ValueError: On HTTP 400 (non-retriable bad input).
        RuntimeError: After ``MAX_EMBED_RETRIES`` failed attempts.
    """
    if await check_openrouter_only():
        # Global flag set once Gemini quota was exhausted: bypass Gemini.
        logger.info("OpenRouter-only mode β bypassing Gemini for %d texts", len(texts))
        return await openrouter_embed_batch(
            texts, model=model, api_key=self.api_key,
        )
    gemini_model = self._gemini_model_name(model)
    api_key = next_gemini_embed_key()
    url = (
        f"{GEMINI_EMBED_BASE}/{gemini_model}:batchEmbedContents"
        f"?key={api_key}"
    )
    # One request object per text; the endpoint embeds them in order.
    requests_list = []
    for t in texts:
        req: dict = {
            "model": f"models/{gemini_model}",
            "content": {"parts": [{"text": t}]},
            "output_dimensionality": EMBED_DIMENSIONS,
        }
        if task_type:
            req["taskType"] = task_type
        requests_list.append(req)
    payload = {"requests": requests_list}
    attempt = 0
    consecutive_429 = 0
    switched_to_paid = False
    while attempt < MAX_EMBED_RETRIES:
        if attempt > 0:
            # Exponential back-off before retries. last_error is always
            # bound here: this branch is unreachable on the first pass.
            delay = min(
                EMBED_RETRY_BASE_DELAY * (2 ** (attempt - 1)),
                MAX_EMBED_DELAY,
            )
            logger.warning(
                "Gemini batch embed error, retrying in %.1fs (attempt %d): %s",
                delay, attempt + 1, last_error,
            )
            await asyncio.sleep(delay)
        last_error: str | None = None
        try:
            _t0 = time.monotonic()
            resp = await self._http.post(url, json=payload)
            logger.info(
                "Gemini batch embed HTTP request completed in %.0f ms (status=%d)",
                (time.monotonic() - _t0) * 1000,
                resp.status_code,
            )
            await record_key_usage(api_key)
            if resp.status_code == 400:
                body = resp.text[:500]
                logger.error(
                    "Gemini batch embed 400 Bad Request (non-retriable): %s",
                    body,
                )
                raise ValueError(
                    f"Gemini batch embed 400 Bad Request: {body}"
                )
            if resp.status_code == 429:
                if is_daily_quota_429(resp):
                    # Daily quota exhausted: retire this key for the day
                    # and rotate to the next key in the pool.
                    await mark_key_daily_spent(api_key, "embed")
                    api_key = next_gemini_embed_key()
                    url = (
                        f"{GEMINI_EMBED_BASE}/{gemini_model}"
                        f":batchEmbedContents?key={api_key}"
                    )
                    last_error = "HTTP 429 (daily quota)"
                    attempt += 1
                    continue
                consecutive_429 += 1
                if (
                    consecutive_429 >= PAID_KEY_FALLBACK_THRESHOLD
                    and not switched_to_paid
                ):
                    paid = get_paid_fallback_key()
                    if paid:
                        logger.warning(
                            "Switching to paid Gemini key after %d "
                            "consecutive 429s",
                            consecutive_429,
                        )
                        api_key = paid
                        url = (
                            f"{GEMINI_EMBED_BASE}/{gemini_model}"
                            f":batchEmbedContents?key={paid}"
                        )
                        switched_to_paid = True
                elif switched_to_paid:
                    # Even the paid key is rate-limited: flip the global
                    # OpenRouter-only flag and try that provider once.
                    await set_openrouter_only()
                    try:
                        logger.warning(
                            "Paid Gemini key 429'd β trying OpenRouter",
                        )
                        return await openrouter_embed_batch(
                            texts, model=model,
                            api_key=self.api_key,
                        )
                    except Exception:
                        logger.warning(
                            "OpenRouter embed fallback also failed",
                            exc_info=True,
                        )
                last_error = "HTTP 429"
                attempt += 1
                continue
            if resp.status_code in _RETRIABLE_STATUSES:
                last_error = f"HTTP {resp.status_code}"
                attempt += 1
                continue
            resp.raise_for_status()
            data = await _async_response_json(resp)
            return [item["values"] for item in data["embeddings"]]
        except ValueError:
            raise  # non-retriable, propagate immediately
        except Exception as exc:
            last_error = str(exc)
            attempt += 1
    raise RuntimeError(
        f"Gemini batch embed failed after {MAX_EMBED_RETRIES} attempts: {last_error}"
    )
# ------------------------------------------------------------------
# Header validation helpers
# ------------------------------------------------------------------
@staticmethod
def _normalize_assistant_tool_message(
message: dict[str, Any],
) -> dict[str, Any]:
"""Map legacy ``function_call`` to ``tool_calls`` when needed.
Some OpenAI-compatible providers still return the pre-tool-use
``function_call`` field instead of ``tool_calls``. Without this,
the loop treats the turn as a final text response and runs header
validation on optional narration alongside the tool invocation.
"""
existing = message.get("tool_calls")
if existing:
return message
fc = message.get("function_call")
if not isinstance(fc, dict):
return message
name = fc.get("name")
if not name:
return message
args = fc.get("arguments")
if not isinstance(args, str):
args = json.dumps(args) if args is not None else "{}"
out = dict(message)
out["tool_calls"] = [{
"id": f"legacy_{uuid.uuid4().hex[:16]}",
"type": "function",
"function": {"name": name, "arguments": args},
}]
out.pop("function_call", None)
return out
@staticmethod
def _validate_header(content: str) -> bool:
"""Return *True* if *content* starts with the required ``[`` header.
Thought / thinking tags are stripped first so that wrapped responses
like ``<thinking>...\n</thinking>\n[header ...]`` still pass.
"""
# Strip closed thought/thinking/glitch blocks first
cleaned = re.sub(r"<thinking>.*?</thinking>", "", content, flags=re.DOTALL)
cleaned = re.sub(r"<thought>.*?</thought>", "", cleaned, flags=re.DOTALL)
cleaned = re.sub(r"π‘thought.*?</font>", "", cleaned, flags=re.DOTALL)
# Strip unclosed opening tags (truncated or never closed) β
# remove from the tag to end-of-string, then check what remains
cleaned = re.sub(r"<thinking>.*", "", cleaned, flags=re.DOTALL)
cleaned = re.sub(r"<thought>.*", "", cleaned, flags=re.DOTALL)
cleaned = cleaned.lstrip()
if not cleaned:
# Content was entirely thinking tags β the model produced only
# internal reasoning with no user-facing text. Regenerating
# would just waste an API call; accept as-is.
has_thinking = bool(
re.search(r"<think(?:ing|ought)>", content)
)
return has_thinking
return cleaned.startswith("[")
async def _regenerate_with_header_hint(
    self,
    msgs: list[dict[str, Any]],
    bad_content: str,
    *,
    tool_names: list[str] | None,
    user_id: str,
) -> str:
    """Re-call the LLM once with an ephemeral correction message.

    The failed assistant response plus a system hint are appended to
    *msgs*, then ``_call_api`` is invoked with no tools (*tool_names*
    is accepted for signature compatibility but not forwarded). Whatever
    the model produces is returned — even if it still lacks a header,
    only one retry is attempted; on any error *bad_content* is returned.
    """
    logger.warning(
        "Response missing required header (starts with %r), "
        "regenerating with header hint",
        bad_content,
    )
    # Show the model its own non-compliant output, then the correction.
    msgs.append({"role": "assistant", "content": bad_content})
    msgs.append({"role": "system", "content": _HEADER_REGEN_MESSAGE})
    started = time.monotonic()
    try:
        retry_response = await self._call_api(
            msgs, tool_names=[], user_id=user_id,
        )
    except Exception as exc:
        logger.error("Header-hint regeneration failed: %s", exc)
        return bad_content
    logger.info(
        "Header-hint regeneration completed in %.0f ms",
        (time.monotonic() - started) * 1000,
    )
    return retry_response.get("content") or bad_content
[docs]
async def close(self) -> None:
    """Close the shared ``httpx.AsyncClient`` and its connection pool.

    Await once when the client is no longer needed; open sockets are
    held until this runs.
    """
    await self._http.aclose()
# ------------------------------------------------------------------
# Loop detection
# ------------------------------------------------------------------
@staticmethod
def _detect_loop(
history: list[tuple[str, str]], threshold: int = 3,
) -> bool:
"""Return *True* if any (name, args) pair appears >= *threshold* times."""
counts = Counter(history)
return any(c >= threshold for c in counts.values())
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
async def _call_api(
    self,
    messages: list[dict[str, Any]],
    tool_names: list[str] | None = None,
    user_id: str = "",
) -> dict[str, Any]:
    """Make a single chat-completions request and return the first choice message.

    Args:
        messages: Conversation payload in OpenAI chat format.
        tool_names: ``None`` offers no tools; otherwise only the named
            subset of registered tools is exposed to the model.
        user_id: Sent as the OpenAI ``user`` field when non-empty.

    Returns:
        The ``choices[0]["message"]`` dict of the response.

    Raises:
        RuntimeError: If the body carries an ``error`` object or has
            no ``choices``.
        httpx.HTTPStatusError: For non-success HTTP statuses.
    """
    payload: dict[str, Any] = {
        "model": self.model,
        "messages": messages,
        "temperature": self.temperature,
        "max_tokens": self.max_tokens,
    }
    if user_id:
        payload["user"] = user_id
    if self.tool_registry.has_tools and tool_names is not None:
        tools = self.tool_registry.get_openai_tools_by_names(set(tool_names))
        if len(tools) > MAX_TOOLS_PER_REQUEST:
            logger.warning(
                "Tool count %d exceeds hard cap %d, truncating",
                len(tools), MAX_TOOLS_PER_REQUEST,
            )
            tools = tools[:MAX_TOOLS_PER_REQUEST]
        if tools:
            payload["tools"] = tools
    headers = {
        "Authorization": f"Bearer {self.api_key}",
        "Content-Type": "application/json",
        "HTTP-Referer": "https://github.com/matrix-llm-bot",
        "X-Title": "Matrix LLM Bot",
    }
    # HTTP 499 (client closed request) is re-POSTed immediately, up to
    # MAX_499_RETRIES attempts; any other non-success status is logged
    # with a body excerpt and raised.
    for _attempt in range(MAX_499_RETRIES):
        resp = await self._http.post(
            self._chat_url,
            json=payload,
            headers=headers,
        )
        if resp.status_code == 499 and _attempt < MAX_499_RETRIES - 1:
            logger.warning(
                "HTTP 499, retrying immediately (%d/%d)",
                _attempt + 1, MAX_499_RETRIES,
            )
            continue
        if resp.status_code >= 400:
            _body = (resp.text or "")[:4000]
            logger.error(
                "LLM chat HTTP %s (non-success). Response body (truncated): %s",
                resp.status_code,
                _body,
            )
            resp.raise_for_status()
        break
    data = await _async_response_json(resp)
    # Some proxies return 200 with an error object instead of an HTTP error.
    if "error" in data:
        error_msg = data["error"].get("message", str(data["error"]))
        raise RuntimeError(f"LLM API error: {error_msg}")
    choices = data.get("choices")
    if not choices:
        raise RuntimeError(f"LLM API returned no choices: {data}")
    return choices[0]["message"]