# Source code for openrouter_client

"""Async LLM API client with automatic tool-call loop.

Uses any OpenAI-compatible chat-completions endpoint (configurable via
``base_url``).  Embeddings use the native Google Gemini API only, via the
shared key pool in gemini_embed_pool.
"""

from __future__ import annotations

import asyncio
import json
import logging
import re
import time
import uuid
from collections import Counter
from typing import Any, Awaitable, Callable, TYPE_CHECKING

import os

import httpx

from gemini_embed_pool import (
    EMBED_DIMENSIONS,
    GEMINI_EMBED_BASE,
    PAID_KEY_FALLBACK_THRESHOLD,
    check_openrouter_only,
    get_paid_fallback_key,
    is_daily_quota_429,
    mark_key_daily_spent,
    next_gemini_embed_key,
    openrouter_embed_batch,
    record_key_usage,
    set_openrouter_only,
)
from platforms.media_common import media_to_content_parts
from tools import ToolRegistry

if TYPE_CHECKING:
    from tool_context import ToolContext

# Module-level logger; configured by the application, not here.
logger = logging.getLogger(__name__)

# Default Gemini model id for countTokens (Google Generative Language API).
DEFAULT_GEMINI_COUNT_TOKENS_MODEL = "gemini-3.1-flash-lite-preview"

# Immediate-retry budget for HTTP 499 responses in _call_api.
MAX_499_RETRIES = 3
# NOTE(review): not referenced in this module — header regeneration is
# hard-coded to a single retry in _regenerate_with_header_hint; confirm
# external users before removing.
MAX_HEADER_RETRIES = 1
# Hard cap on the number of tool schemas attached to one chat request.
MAX_TOOLS_PER_REQUEST = 250
# Embedding retry policy: base back-off delay (s), delay cap (s), attempts.
EMBED_RETRY_BASE_DELAY = 1.0
MAX_EMBED_DELAY = 8.0
MAX_EMBED_RETRIES = 14

# HTTP status codes that are transient and should be retried
_RETRIABLE_STATUSES = {429, 500, 502, 503, 504}

# Ephemeral system message appended when a final response is missing the
# required "[`model` :: ...]" header bracket (see _validate_header).
_HEADER_REGEN_MESSAGE = (
    "Your previous response did not comply with the system prompt's rules for "
    "header generation. Responses MUST begin with a response header bracket "
    "[`model` :: emojis :: thought :: `tools`]. Regenerate your response now, "
    "starting with the required header."
)

# Decode large JSON bodies off the event loop (chat completions, big embed batches).
_JSON_BODY_DECODE_THRESHOLD = 256 * 1024


def _json_loads_utf8(body: bytes) -> Any:
    return json.loads(body.decode("utf-8"))


async def _async_response_json(resp: httpx.Response) -> Any:
    """Asynchronously read *resp*'s body and parse it as JSON.

    Payloads at or above ``_JSON_BODY_DECODE_THRESHOLD`` bytes are parsed in
    a worker thread so a long ``json.loads`` does not block the event loop;
    smaller payloads are decoded inline.
    """
    raw = await resp.aread()
    if len(raw) < _JSON_BODY_DECODE_THRESHOLD:
        return json.loads(raw.decode("utf-8"))
    return await asyncio.to_thread(_json_loads_utf8, raw)


class OpenRouterClient:
    """Thin async wrapper around an OpenAI-compatible chat-completions endpoint.

    Implements the full tool-call loop: when the model responds with
    ``tool_calls``, each tool is executed via the :class:`ToolRegistry`, the
    results are appended, and the API is called again.  This repeats until
    the model produces a final text response (or a safety limit is reached).
    """
    def __init__(
        self,
        api_key: str,
        model: str = "x-ai/grok-4.1-fast",
        temperature: float = 1.0,
        max_tokens: int = 60_000,
        tool_registry: ToolRegistry | None = None,
        max_tool_rounds: int = 10,
        base_url: str = "http://localhost:3000/openai",
        gemini_api_key: str = "",
        gemini_count_tokens_model: str = DEFAULT_GEMINI_COUNT_TOKENS_MODEL,
        max_tool_output_chars: int = 150_000,
    ) -> None:
        """Initialize the client and its shared async HTTP connection pool.

        Args:
            api_key (str): Bearer token for the chat-completions endpoint.
            model (str): Chat model identifier sent with every request.
            temperature (float): Sampling temperature for completions.
            max_tokens (int): ``max_tokens`` forwarded to the chat API.
            tool_registry (ToolRegistry | None): Registry used to execute
                tool calls; an empty registry is created when omitted.
            max_tool_rounds (int): Safety cap on tool-call rounds per chat.
            base_url (str): OpenAI-compatible API root; ``/chat/completions``
                is appended after stripping any trailing slash.
            gemini_api_key (str): Key for the Gemini countTokens API; when
                empty, resolved from the environment or config file.
            gemini_count_tokens_model (str): Model id for Gemini countTokens.
            max_tool_output_chars (int): Max tool result string length before
                truncation in the chat loop; ``<= 0`` disables the cap.
        """
        self.api_key = api_key
        # Explicit key wins; otherwise fall back to env var / config file.
        self.gemini_api_key = gemini_api_key or self._resolve_gemini_api_key()
        self.gemini_count_tokens_model = gemini_count_tokens_model
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.tool_registry = tool_registry or ToolRegistry()
        self.max_tool_rounds = max_tool_rounds
        self.max_tool_output_chars = max_tool_output_chars
        self._chat_url = base_url.rstrip("/") + "/chat/completions"
        # Generous read timeout (long completions); fast connect timeout.
        self._http = httpx.AsyncClient(
            timeout=httpx.Timeout(300.0, connect=10.0),
        )
    @staticmethod
    def _resolve_gemini_api_key() -> str:
        """Return the Gemini API key from the environment or config file.

        Checks ``GEMINI_API_KEY`` first, then falls back to the project's
        ``Config``; returns ``""`` when neither source yields a key
        (best-effort — any config error is swallowed).
        """
        key = os.getenv("GEMINI_API_KEY", "")
        if key:
            return key
        try:
            # Imported lazily to avoid a hard dependency at module import.
            from config import Config
            cfg = Config.load()
            return cfg.gemini_api_key or ""
        except Exception:
            pass
        return ""

    # ------------------------------------------------------------------
    # Token counting via Gemini API
    # ------------------------------------------------------------------

    async def _count_tokens(
        self,
        messages: list[dict[str, Any]],
        *,
        gemini_model: str | None = None,
    ) -> int | None:
        """Count input tokens using the Gemini countTokens API.

        Converts OpenAI-format *messages* to Gemini ``contents`` format and
        calls the free countTokens endpoint.  Returns the total token count
        or ``None`` on any failure (fail-open).

        Args:
            messages: OpenAI-shaped chat messages.
            gemini_model: Optional override for the countTokens model id.
        """
        if not self.gemini_api_key:
            return None
        model_id = gemini_model or self.gemini_count_tokens_model

        # Convert OpenAI messages to Gemini contents format
        contents = []
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            if not content:
                continue
            # Map OpenAI roles to Gemini roles ("system" has no Gemini
            # equivalent, so it is folded into "model").
            if role in ("system", "assistant"):
                gemini_role = "model"
            else:
                gemini_role = "user"
            if isinstance(content, list):
                # Multi-part content: extract text parts only (media parts
                # are not counted here).
                text = " ".join(
                    p.get("text", "")
                    for p in content
                    if p.get("type") == "text"
                )
                if not text:
                    continue
                contents.append({
                    "role": gemini_role,
                    "parts": [{"text": text}],
                })
            else:
                contents.append({
                    "role": gemini_role,
                    "parts": [{"text": str(content)}],
                })
        if not contents:
            return None

        # Ensure alternating roles (Gemini requirement)
        deduped: list[dict] = [contents[0]]
        for c in contents[1:]:
            if c["role"] == deduped[-1]["role"]:
                # Merge into previous message
                deduped[-1]["parts"].extend(c["parts"])
            else:
                deduped.append(c)

        url = (
            f"https://generativelanguage.googleapis.com/v1beta/"
            f"models/{model_id}:countTokens"
            f"?key={self.gemini_api_key}"
        )
        payload = {"contents": deduped}
        try:
            # Short timeout: token counting is an optimization, not critical.
            resp = await self._http.post(
                url,
                json=payload,
                timeout=5.0,
            )
            if resp.status_code == 200:
                data = await _async_response_json(resp)
                count = data.get("totalTokens")
                if isinstance(count, int):
                    logger.info(
                        "Gemini countTokens: %d tokens",
                        count,
                    )
                    return count
            else:
                logger.debug(
                    "countTokens returned %d: %s",
                    resp.status_code,
                    resp.text[:200],
                )
        except Exception:
            logger.debug(
                "countTokens call failed",
                exc_info=True,
            )
        return None
[docs] async def count_input_tokens( self, messages: list[dict[str, Any]], *, gemini_model: str | None = None, ) -> int | None: """Public wrapper for Gemini ``countTokens`` on OpenAI-shaped *messages*. *gemini_model* overrides :attr:`gemini_count_tokens_model` for this call. """ return await self._count_tokens( messages, gemini_model=gemini_model, )
# ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------
    async def chat(
        self,
        messages: list[dict[str, Any]],
        user_id: str = "",
        ctx: ToolContext | None = None,
        tool_names: list[str] | None = None,
        validate_header: bool = False,
        token_count: int | None = None,
        on_intermediate_text: Callable[[str], Awaitable[None]] | None = None,
    ) -> str:
        """Send *messages* to the LLM and return the final assistant text.

        If the model requests tool calls, they are executed automatically and
        the conversation is continued until a text response is produced.

        *user_id* is forwarded to :meth:`ToolRegistry.call` for permission
        checking.  *ctx* is forwarded to tools that opt-in to receiving it.
        *tool_names*, when provided, restricts which tools the LLM sees to
        the given subset of registered tool names.  *token_count*, when
        provided, is used directly instead of calling ``_count_tokens`` —
        allows the caller to pre-compute the count concurrently with other
        work.  *on_intermediate_text*, when provided, is called with any text
        content the model produces alongside tool calls.  Without this
        callback, such text is silently carried in the conversation history
        but never surfaced to the user.
        """
        # Shallow-copy messages so placeholder substitution below does not
        # mutate the caller's list entries.
        msgs = [dict(m) for m in messages]
        call_history: list[tuple[str, str]] = []  # (name, args_hash)
        total_calls = 0

        # Count input tokens — use pre-computed value when available,
        # otherwise call Gemini API (serial fallback).
        if token_count is None:
            token_count = await self._count_tokens(msgs)
        _count_str = str(token_count) if token_count is not None else "unavailable"
        # Substitute the placeholder in the first system message that has it.
        for _m in msgs:
            if _m.get("role") == "system" and "__INPUT_TOKEN_COUNT__" in str(_m.get("content", "")):
                _m["content"] = _m["content"].replace(
                    "__INPUT_TOKEN_COUNT__",
                    _count_str,
                )
                break

        round_num = 0
        while round_num < self.max_tool_rounds:
            logger.info(
                "LLM API call: round=%d, model=%s, messages=%d",
                round_num + 1,
                self.model,
                len(msgs),
            )
            _t0 = time.monotonic()
            response_message = await self._call_api(
                msgs,
                tool_names=tool_names,
                user_id=user_id,
            )
            # Map legacy "function_call" replies into "tool_calls" shape.
            response_message = self._normalize_assistant_tool_message(
                response_message,
            )
            logger.info(
                "LLM API call round=%d completed in %.0f ms",
                round_num + 1,
                (time.monotonic() - _t0) * 1000,
            )
            tool_calls = response_message.get("tool_calls")

            # Tool rounds (including legacy function_call -> tool_calls) carry
            # optional narration in content; never run header validation on that.
            if not tool_calls:
                content = response_message.get("content") or ""
                if content.startswith("[PROXY ERROR]"):
                    logger.warning("Proxy returned an error response:\n%s", content)
                    return content
                # --- Header validation: final text only (not tool-round narration) ---
                if validate_header and content and not self._validate_header(content):
                    content = await self._regenerate_with_header_hint(
                        msgs,
                        content,
                        tool_names=None,
                        user_id=user_id,
                    )
                return content

            # --- Emit intermediate text alongside tool calls ------
            intermediate_text = (response_message.get("content") or "").strip()
            if intermediate_text and on_intermediate_text:
                logger.debug(
                    "Invoking on_intermediate_text (%d raw chars, %d tool call(s))",
                    len(intermediate_text),
                    len(tool_calls),
                )
                try:
                    await on_intermediate_text(intermediate_text)
                except Exception:
                    # Callback failures must not abort the tool loop.
                    logger.warning(
                        "on_intermediate_text callback failed",
                        exc_info=True,
                    )

            msgs.append(response_message)

            # --- Loop / excessive-call detection --------------------------
            round_keys: list[tuple[str, str]] = []
            for tc in tool_calls:
                fn = tc["function"]
                name = fn["name"]
                args_str = fn.get("arguments", "")
                round_keys.append((name, args_str))
            call_history.extend(round_keys)
            total_calls += len(tool_calls)
            if total_calls > 200:
                logger.warning("Excessive tool calls (%d), aborting", total_calls)
                return "(Stopped: too many tool calls.)"
            # Tools explicitly allowed to repeat are excluded from loop checks.
            exempt = self.tool_registry.repeat_allowed_tools()
            filtered_history = [
                (n, a) for n, a in call_history if n not in exempt
            ]
            if self._detect_loop(filtered_history):
                logger.warning("Repetitive tool-call loop detected, aborting")
                return "(Stopped: repetitive tool-call loop detected.)"

            # --- Execute tool calls in parallel ---------------------------
            t0 = time.monotonic()
            # Remember how many files were already sent so we can detect
            # files produced by this round's tools.
            sent_files_before = len(ctx.sent_files) if ctx else 0

            async def _run_one(tc: dict) -> tuple[str, str]:
                """Execute one tool call; return (call_id, result text)."""
                fn = tc["function"]
                tool_name = fn["name"]
                try:
                    arguments = json.loads(fn["arguments"])
                    # Some models wrap argument keys in stray quotes.
                    arguments = {
                        k.strip("'\""): v for k, v in arguments.items()
                    }
                except (json.JSONDecodeError, TypeError):
                    arguments = {}
                logger.debug(
                    "Calling tool %s with %s",
                    tool_name,
                    arguments,
                )
                result = await self.tool_registry.call(
                    tool_name,
                    arguments,
                    user_id=user_id,
                    ctx=ctx,
                )
                return tc["id"], result

            results = await asyncio.gather(
                *[_run_one(tc) for tc in tool_calls],
            )
            elapsed_ms = (time.monotonic() - t0) * 1000
            logger.info(
                "Round %d: executed %d tool(s) in %.0f ms",
                round_num + 1,
                len(tool_calls),
                elapsed_ms,
            )
            # Files added to ctx.sent_files during this round, if any.
            new_files = (
                ctx.sent_files[sent_files_before:]
                if ctx and sent_files_before < len(ctx.sent_files)
                else []
            )
            for idx, (call_id, result) in enumerate(results):
                remaining = self.max_tool_rounds - round_num - 1
                # Cap tool output to avoid blowing past API request limits
                cap = self.max_tool_output_chars
                if cap > 0 and len(result) > cap:
                    _orig_len = len(result)
                    result = (
                        result[:cap]
                        + f"\n\n[TRUNCATED — output was {_orig_len:,} chars, "
                        f"showing first {cap:,}]"
                    )
                    logger.warning(
                        "Truncated tool output from %d to %d chars",
                        _orig_len,
                        cap,
                    )
                text_content = f"[Remaining tool rounds: {remaining}]\n{result}"
                # Inline multimodal: attach sent_files media to first tool result
                # so the model sees screenshots/audio/video during the same turn.
                if idx == 0 and new_files:
                    content_parts: list[dict[str, Any]] = [
                        {"type": "text", "text": text_content},
                    ]
                    for sf in new_files:
                        sf_parts = await media_to_content_parts(
                            sf["data"],
                            sf["mimetype"],
                            sf["filename"],
                        )
                        content_parts.extend(sf_parts)
                    msgs.append({
                        "role": "tool",
                        "tool_call_id": call_id,
                        "content": content_parts,
                    })
                else:
                    msgs.append({
                        "role": "tool",
                        "tool_call_id": call_id,
                        "content": text_content,
                    })

            # --- Merge dynamically injected tools -------------------------
            if (
                ctx is not None
                and ctx.injected_tools
                and tool_names is not None
            ):
                existing = set(tool_names)
                new_tools = [
                    t for t in ctx.injected_tools if t not in existing
                ]
                if new_tools:
                    tool_names.extend(new_tools)
                    logger.info(
                        "Injected %d tool(s) into active set: %s",
                        len(new_tools),
                        new_tools,
                    )
                ctx.injected_tools = None  # reset for next round

            round_num += 1

        logger.warning("Reached max tool-call rounds (%d)", self.max_tool_rounds)

        # --- Final round: call LLM with no tools so it can respond --------
        msgs.append({
            "role": "user",
            "content": (
                "[SYSTEM] You have reached the maximum number of allowed tool calls "
                f"({self.max_tool_rounds} rounds) for this cycle. You cannot "
                "make any more tool calls. Please generate your final response "
                "to the user now, summarizing what you accomplished and noting "
                "anything you were unable to complete."
            ),
        })
        logger.info(
            "LLM API call: final no-tools round, model=%s, messages=%d",
            self.model,
            len(msgs),
        )
        _t0 = time.monotonic()
        try:
            final_response = await self._call_api(
                msgs,
                tool_names=[],
                user_id=user_id,
            )
            logger.info(
                "LLM API final no-tools round completed in %.0f ms",
                (time.monotonic() - _t0) * 1000,
            )
            content = final_response.get("content") or "(max tool-call rounds reached)"
        except Exception as exc:
            # Fall back to the last assistant message from the tool loop.
            logger.error("Final no-tools round failed: %s", exc)
            content = response_message.get("content") or "(max tool-call rounds reached)"
        if content.startswith("[PROXY ERROR]"):
            logger.warning("Proxy returned an error response:\n%s", content)
            return content
        # --- Header validation on final-round response ---
        if validate_header and content and not self._validate_header(content):
            content = await self._regenerate_with_header_hint(
                msgs,
                content,
                tool_names=[],
                user_id=user_id,
            )
        return content
[docs] async def embed(self, text: str, model: str) -> list[float]: """Generate an embedding vector for *text* via the Gemini API. Uses the shared key pool for rate-limit distribution. Retries with exponential back-off (capped at ``MAX_EMBED_DELAY``) up to ``MAX_EMBED_RETRIES`` times before raising. Raises :class:`ValueError` immediately if *text* is empty or whitespace-only (the embedding API rejects such input with 400). Parameters ---------- text: The text to embed. model: The embedding model identifier (e.g. ``"google/gemini-embedding-001"``). Returns ------- list[float] The embedding vector. """ if not text or not text.strip(): raise ValueError("Cannot embed empty or whitespace-only text") round_num = 0 last_error: str | None = None while round_num < MAX_EMBED_RETRIES: try: return await self._embed_gemini(text, model) except ValueError: raise # non-retriable (bad input), don't retry except Exception as exc: last_error = str(exc) logger.warning( "Gemini embed failed (round %d): %s", round_num + 1, exc, ) delay = min(EMBED_RETRY_BASE_DELAY * (2 ** round_num), MAX_EMBED_DELAY) await asyncio.sleep(delay) round_num += 1 raise RuntimeError( f"Gemini embed failed after {MAX_EMBED_RETRIES} rounds: {last_error}" )
[docs] async def embed_batch( self, texts: list[str], model: str, ) -> list[list[float]]: """Generate embedding vectors for multiple texts via the Gemini API. Uses the shared key pool. Empty or whitespace-only texts are filtered out; their positions are filled with zero vectors. Retries with exponential back-off up to ``MAX_EMBED_RETRIES`` times. Parameters ---------- texts: List of texts to embed. model: The embedding model identifier. Returns ------- list[list[float]] One embedding vector per input text, in the same order. """ if not texts: return [] # Filter out blank texts, keeping track of original indices. valid_indices: list[int] = [] valid_texts: list[str] = [] for i, t in enumerate(texts): if t and t.strip(): valid_indices.append(i) valid_texts.append(t) if not valid_texts: logger.warning( "embed_batch called with %d text(s), all empty/blank β€” returning zero vectors", len(texts), ) return [[0.0] * EMBED_DIMENSIONS for _ in texts] if len(valid_texts) < len(texts): logger.info( "embed_batch: filtered %d blank text(s) from batch of %d", len(texts) - len(valid_texts), len(texts), ) round_num = 0 last_error: str | None = None while round_num < MAX_EMBED_RETRIES: try: valid_embeddings = await self._embed_gemini_batch(valid_texts, model) result: list[list[float]] = [[0.0] * EMBED_DIMENSIONS for _ in texts] for idx, emb in zip(valid_indices, valid_embeddings): result[idx] = emb return result except ValueError: raise # non-retriable (bad input), don't retry except Exception as exc: last_error = str(exc) logger.warning( "Gemini embed_batch failed (round %d): %s", round_num + 1, exc, ) delay = min(EMBED_RETRY_BASE_DELAY * (2 ** round_num), MAX_EMBED_DELAY) await asyncio.sleep(delay) round_num += 1 raise RuntimeError( f"Gemini embed_batch failed after {MAX_EMBED_RETRIES} rounds: {last_error}" )
# ------------------------------------------------------------------ # Embedding provider implementations # ------------------------------------------------------------------ @staticmethod def _gemini_model_name(model: str) -> str: """Convert OpenRouter model name to Gemini native model name.""" return model.removeprefix("google/") async def _embed_gemini( self, text: str, model: str, task_type: str | None = None, ) -> list[float]: """Embed a single text via the native Gemini API. Falls back to the paid key after ``PAID_KEY_FALLBACK_THRESHOLD`` consecutive 429 responses. Args: text: Text to embed. model: Embedding model identifier. task_type: Optional Gemini task type (e.g. ``QUESTION_ANSWERING``). """ if await check_openrouter_only(): logger.info("OpenRouter-only mode β€” bypassing Gemini for single text") vecs = await openrouter_embed_batch( [text], model=model, api_key=self.api_key, ) return vecs[0] gemini_model = self._gemini_model_name(model) api_key = next_gemini_embed_key() url = f"{GEMINI_EMBED_BASE}/{gemini_model}:embedContent?key={api_key}" payload: dict = { "model": f"models/{gemini_model}", "content": {"parts": [{"text": text}]}, "output_dimensionality": EMBED_DIMENSIONS, } if task_type: payload["taskType"] = task_type attempt = 0 consecutive_429 = 0 switched_to_paid = False while attempt < MAX_EMBED_RETRIES: if attempt > 0: delay = min( EMBED_RETRY_BASE_DELAY * (2 ** (attempt - 1)), MAX_EMBED_DELAY, ) logger.warning( "Gemini embed error, retrying in %.1fs (attempt %d): %s", delay, attempt + 1, last_error, ) await asyncio.sleep(delay) last_error: str | None = None try: _t0 = time.monotonic() resp = await self._http.post(url, json=payload) logger.info( "Gemini embed HTTP request completed in %.0f ms (status=%d)", (time.monotonic() - _t0) * 1000, resp.status_code, ) await record_key_usage(api_key) if resp.status_code == 400: body = resp.text[:500] logger.error( "Gemini embed 400 Bad Request (non-retriable): %s", body, ) raise ValueError( f"Gemini embed 
400 Bad Request: {body}" ) if resp.status_code == 429: if is_daily_quota_429(resp): await mark_key_daily_spent(api_key, "embed") api_key = next_gemini_embed_key() url = ( f"{GEMINI_EMBED_BASE}/{gemini_model}" f":embedContent?key={api_key}" ) last_error = "HTTP 429 (daily quota)" attempt += 1 continue consecutive_429 += 1 if ( consecutive_429 >= PAID_KEY_FALLBACK_THRESHOLD and not switched_to_paid ): paid = get_paid_fallback_key() if paid: logger.warning( "Switching to paid Gemini key after %d " "consecutive 429s", consecutive_429, ) api_key = paid url = ( f"{GEMINI_EMBED_BASE}/{gemini_model}" f":embedContent?key={paid}" ) switched_to_paid = True elif switched_to_paid: await set_openrouter_only() try: logger.warning( "Paid Gemini key 429'd β€” trying OpenRouter", ) vecs = await openrouter_embed_batch( [text], model=model, api_key=self.api_key, ) return vecs[0] except Exception: logger.warning( "OpenRouter embed fallback also failed", exc_info=True, ) last_error = "HTTP 429" attempt += 1 continue if resp.status_code in _RETRIABLE_STATUSES: last_error = f"HTTP {resp.status_code}" attempt += 1 continue resp.raise_for_status() data = await _async_response_json(resp) return data["embedding"]["values"] except ValueError: raise # non-retriable, propagate immediately except Exception as exc: last_error = str(exc) attempt += 1 raise RuntimeError( f"Gemini embed failed after {MAX_EMBED_RETRIES} attempts: {last_error}" ) async def _embed_gemini_batch( self, texts: list[str], model: str, task_type: str | None = None, ) -> list[list[float]]: """Embed a batch of texts via the native Gemini API. Falls back to the paid key after ``PAID_KEY_FALLBACK_THRESHOLD`` consecutive 429 responses. Args: texts: Texts to embed. model: Embedding model identifier. task_type: Optional Gemini task type (e.g. ``RETRIEVAL_DOCUMENT``). 
""" if await check_openrouter_only(): logger.info("OpenRouter-only mode β€” bypassing Gemini for %d texts", len(texts)) return await openrouter_embed_batch( texts, model=model, api_key=self.api_key, ) gemini_model = self._gemini_model_name(model) api_key = next_gemini_embed_key() url = ( f"{GEMINI_EMBED_BASE}/{gemini_model}:batchEmbedContents" f"?key={api_key}" ) requests_list = [] for t in texts: req: dict = { "model": f"models/{gemini_model}", "content": {"parts": [{"text": t}]}, "output_dimensionality": EMBED_DIMENSIONS, } if task_type: req["taskType"] = task_type requests_list.append(req) payload = {"requests": requests_list} attempt = 0 consecutive_429 = 0 switched_to_paid = False while attempt < MAX_EMBED_RETRIES: if attempt > 0: delay = min( EMBED_RETRY_BASE_DELAY * (2 ** (attempt - 1)), MAX_EMBED_DELAY, ) logger.warning( "Gemini batch embed error, retrying in %.1fs (attempt %d): %s", delay, attempt + 1, last_error, ) await asyncio.sleep(delay) last_error: str | None = None try: _t0 = time.monotonic() resp = await self._http.post(url, json=payload) logger.info( "Gemini batch embed HTTP request completed in %.0f ms (status=%d)", (time.monotonic() - _t0) * 1000, resp.status_code, ) await record_key_usage(api_key) if resp.status_code == 400: body = resp.text[:500] logger.error( "Gemini batch embed 400 Bad Request (non-retriable): %s", body, ) raise ValueError( f"Gemini batch embed 400 Bad Request: {body}" ) if resp.status_code == 429: if is_daily_quota_429(resp): await mark_key_daily_spent(api_key, "embed") api_key = next_gemini_embed_key() url = ( f"{GEMINI_EMBED_BASE}/{gemini_model}" f":batchEmbedContents?key={api_key}" ) last_error = "HTTP 429 (daily quota)" attempt += 1 continue consecutive_429 += 1 if ( consecutive_429 >= PAID_KEY_FALLBACK_THRESHOLD and not switched_to_paid ): paid = get_paid_fallback_key() if paid: logger.warning( "Switching to paid Gemini key after %d " "consecutive 429s", consecutive_429, ) api_key = paid url = ( 
f"{GEMINI_EMBED_BASE}/{gemini_model}" f":batchEmbedContents?key={paid}" ) switched_to_paid = True elif switched_to_paid: await set_openrouter_only() try: logger.warning( "Paid Gemini key 429'd β€” trying OpenRouter", ) return await openrouter_embed_batch( texts, model=model, api_key=self.api_key, ) except Exception: logger.warning( "OpenRouter embed fallback also failed", exc_info=True, ) last_error = "HTTP 429" attempt += 1 continue if resp.status_code in _RETRIABLE_STATUSES: last_error = f"HTTP {resp.status_code}" attempt += 1 continue resp.raise_for_status() data = await _async_response_json(resp) return [item["values"] for item in data["embeddings"]] except ValueError: raise # non-retriable, propagate immediately except Exception as exc: last_error = str(exc) attempt += 1 raise RuntimeError( f"Gemini batch embed failed after {MAX_EMBED_RETRIES} attempts: {last_error}" ) # ------------------------------------------------------------------ # Header validation helpers # ------------------------------------------------------------------ @staticmethod def _normalize_assistant_tool_message( message: dict[str, Any], ) -> dict[str, Any]: """Map legacy ``function_call`` to ``tool_calls`` when needed. Some OpenAI-compatible providers still return the pre-tool-use ``function_call`` field instead of ``tool_calls``. Without this, the loop treats the turn as a final text response and runs header validation on optional narration alongside the tool invocation. 
""" existing = message.get("tool_calls") if existing: return message fc = message.get("function_call") if not isinstance(fc, dict): return message name = fc.get("name") if not name: return message args = fc.get("arguments") if not isinstance(args, str): args = json.dumps(args) if args is not None else "{}" out = dict(message) out["tool_calls"] = [{ "id": f"legacy_{uuid.uuid4().hex[:16]}", "type": "function", "function": {"name": name, "arguments": args}, }] out.pop("function_call", None) return out @staticmethod def _validate_header(content: str) -> bool: """Return *True* if *content* starts with the required ``[`` header. Thought / thinking tags are stripped first so that wrapped responses like ``<thinking>...\n</thinking>\n[header ...]`` still pass. """ # Strip closed thought/thinking/glitch blocks first cleaned = re.sub(r"<thinking>.*?</thinking>", "", content, flags=re.DOTALL) cleaned = re.sub(r"<thought>.*?</thought>", "", cleaned, flags=re.DOTALL) cleaned = re.sub(r"πŸ’‘thought.*?</font>", "", cleaned, flags=re.DOTALL) # Strip unclosed opening tags (truncated or never closed) β€” # remove from the tag to end-of-string, then check what remains cleaned = re.sub(r"<thinking>.*", "", cleaned, flags=re.DOTALL) cleaned = re.sub(r"<thought>.*", "", cleaned, flags=re.DOTALL) cleaned = cleaned.lstrip() if not cleaned: # Content was entirely thinking tags β€” the model produced only # internal reasoning with no user-facing text. Regenerating # would just waste an API call; accept as-is. has_thinking = bool( re.search(r"<think(?:ing|ought)>", content) ) return has_thinking return cleaned.startswith("[") async def _regenerate_with_header_hint( self, msgs: list[dict[str, Any]], bad_content: str, *, tool_names: list[str] | None, user_id: str, ) -> str: """Re-call the LLM once with an ephemeral correction message. Appends the failed assistant response and a system hint to *msgs*, then calls ``_call_api`` with no tools. 
Returns whatever the model produces (even if it still lacks a header β€” we only retry once). """ logger.warning( "Response missing required header (starts with %r), " "regenerating with header hint", bad_content, ) # Append the bad response so the model sees what it did wrong msgs.append({"role": "assistant", "content": bad_content}) msgs.append({"role": "system", "content": _HEADER_REGEN_MESSAGE}) try: _t0 = time.monotonic() retry_response = await self._call_api( msgs, tool_names=[], user_id=user_id, ) logger.info( "Header-hint regeneration completed in %.0f ms", (time.monotonic() - _t0) * 1000, ) return retry_response.get("content") or bad_content except Exception as exc: logger.error("Header-hint regeneration failed: %s", exc) return bad_content
[docs] async def close(self) -> None: """Close. """ await self._http.aclose()
    # ------------------------------------------------------------------
    # Loop detection
    # ------------------------------------------------------------------

    @staticmethod
    def _detect_loop(
        history: list[tuple[str, str]],
        threshold: int = 3,
    ) -> bool:
        """Return *True* if any (name, args) pair appears >= *threshold* times."""
        counts = Counter(history)
        return any(c >= threshold for c in counts.values())

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    async def _call_api(
        self,
        messages: list[dict[str, Any]],
        tool_names: list[str] | None = None,
        user_id: str = "",
    ) -> dict[str, Any]:
        """Make a single chat-completions request and return the first choice message.

        Args:
            messages: Full conversation in OpenAI chat format.
            tool_names: Registered tool names to expose; ``None`` sends no
                tools, ``[]`` explicitly disables them for this call.
            user_id: Forwarded as the OpenAI ``user`` field when non-empty.

        Raises:
            RuntimeError: When the API returns an ``error`` payload or an
                empty ``choices`` list.
            httpx.HTTPStatusError: On a non-retriable HTTP error status.
        """
        payload: dict[str, Any] = {
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
        }
        if user_id:
            payload["user"] = user_id
        if self.tool_registry.has_tools and tool_names is not None:
            tools = self.tool_registry.get_openai_tools_by_names(set(tool_names))
            if len(tools) > MAX_TOOLS_PER_REQUEST:
                logger.warning(
                    "Tool count %d exceeds hard cap %d, truncating",
                    len(tools),
                    MAX_TOOLS_PER_REQUEST,
                )
                tools = tools[:MAX_TOOLS_PER_REQUEST]
            if tools:
                payload["tools"] = tools
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
            "HTTP-Referer": "https://github.com/matrix-llm-bot",
            "X-Title": "Matrix LLM Bot",
        }
        # 499 (client closed request, proxy-specific) is retried immediately
        # without back-off; any other non-2xx raises via raise_for_status.
        for _attempt in range(MAX_499_RETRIES):
            resp = await self._http.post(
                self._chat_url,
                json=payload,
                headers=headers,
            )
            if resp.status_code == 499 and _attempt < MAX_499_RETRIES - 1:
                logger.warning(
                    "HTTP 499, retrying immediately (%d/%d)",
                    _attempt + 1,
                    MAX_499_RETRIES,
                )
                continue
            if resp.status_code >= 400:
                _body = (resp.text or "")[:4000]
                logger.error(
                    "LLM chat HTTP %s (non-success). Response body (truncated): %s",
                    resp.status_code,
                    _body,
                )
                resp.raise_for_status()
            break
        data = await _async_response_json(resp)
        if "error" in data:
            # Some providers return 200 with an error payload in the body.
            error_msg = data["error"].get("message", str(data["error"]))
            raise RuntimeError(f"LLM API error: {error_msg}")
        choices = data.get("choices")
        if not choices:
            raise RuntimeError(f"LLM API returned no choices: {data}")
        return choices[0]["message"]