# Source code for gemini_kg_bulk_client

"""Native Gemini (google.genai) client for bulk agentic KG extraction.

Uses :func:`gemini_embed_pool.next_gemini_flash_key` for API key rotation,
``client.aio`` for async I/O, and automatic function calling (AFC) with
thin async callables that delegate to :meth:`tools.ToolRegistry.call`.
"""

from __future__ import annotations

import asyncio
import json
import logging
import time
from typing import Any, Callable, Awaitable

from google import genai
from google.genai import types
from google.genai.errors import APIError

from gemini_embed_pool import (
    get_paid_fallback_key,
    mark_key_daily_spent,
    next_gemini_flash_key,
)
from tool_context import ToolContext
from tools import ToolRegistry

logger = logging.getLogger(__name__)

_MAX_KEY_ATTEMPTS = 32


def _error_body_daily_quota(details: Any) -> bool:
    """True if Gemini error details include a *daily* (PerDay) quota violation.

    Avoids relying on ``response.status_code`` / httpx-only parsing: the GenAI
    async client may attach aiohttp responses.
    """
    if details is None:
        return False
    body = details
    if isinstance(body, list) and body:
        body = body[0]
    if not isinstance(body, dict):
        return False
    err = body.get("error", body)
    if not isinstance(err, dict):
        return False
    for detail in err.get("details", []) or []:
        if not isinstance(detail, dict):
            continue
        for violation in detail.get("violations", []) or []:
            qid = str(violation.get("quotaId", "") or "")
            if "PerDay" in qid:
                return True
    return False


_GEN_HTTP_TIMEOUT_MS = 600_000


def _truncate_tool_output(result: str, cap: int) -> str:
    if cap <= 0 or len(result) <= cap:
        return result
    orig = len(result)
    return (
        result[:cap]
        + f"\n\n[TRUNCATED — output was {orig:,} chars, showing first {cap:,}]"
    )


def _openai_message_text(content: Any) -> str:
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return " ".join(
            str(p.get("text", ""))
            for p in content
            if isinstance(p, dict) and p.get("type") == "text"
        )
    return ""


def openai_messages_to_gemini(
    messages: list[dict[str, Any]],
) -> tuple[str | None, list[types.Content]]:
    """Split system instruction and build Gemini ``contents`` (no system in contents).

    Leading ``system`` messages are concatenated (double-newline separated)
    into the returned system instruction. Remaining messages map to Gemini
    turns: ``user`` -> text parts, ``assistant`` -> ``model`` role with text
    and/or function_call parts, ``tool`` -> a ``user`` turn carrying a
    function_response part.

    Returns:
        ``(system_instruction, contents)`` where ``system_instruction`` is
        ``None`` when no system text was found.
    """
    system_chunks: list[str] = []
    rest: list[dict[str, Any]] = []
    seen_non_system = False
    for msg in messages:
        role = msg.get("role", "user")
        # Only system messages *before* the first non-system turn are
        # hoisted; later system messages fall through to ``rest``.
        if role == "system" and not seen_non_system:
            t = _openai_message_text(msg.get("content", ""))
            if t:
                system_chunks.append(t)
            continue
        seen_non_system = True
        rest.append(msg)
    system_instruction = "\n\n".join(system_chunks) if system_chunks else None

    contents: list[types.Content] = []
    for msg in rest:
        role = msg.get("role", "user")
        if role == "user":
            text = _openai_message_text(msg.get("content", ""))
            if text:
                contents.append(
                    types.Content(
                        role="user",
                        parts=[types.Part(text=text)],
                    ),
                )
        elif role == "assistant":
            parts: list[types.Part] = []
            t = _openai_message_text(msg.get("content", ""))
            if t:
                parts.append(types.Part(text=t))
            for tc in msg.get("tool_calls") or []:
                fn = tc.get("function") or {}
                name = str(fn.get("name") or "")
                raw = fn.get("arguments")
                # Arguments may be a JSON string or already a dict; anything
                # unparseable or non-dict degrades to {}.
                if isinstance(raw, str):
                    try:
                        args = json.loads(raw)
                    except json.JSONDecodeError:
                        args = {}
                elif isinstance(raw, dict):
                    args = raw
                else:
                    args = {}
                if not isinstance(args, dict):
                    args = {}
                parts.append(
                    types.Part(
                        function_call=types.FunctionCall(name=name, args=args),
                    ),
                )
            # Skip empty assistant turns entirely (no text, no tool calls).
            if parts:
                contents.append(types.Content(role="model", parts=parts))
        elif role == "tool":
            fn_name = msg.get("name") or ""
            if not fn_name:
                # Gemini function responses require a function name.
                continue
            raw_out = msg.get("content", "")
            if isinstance(raw_out, list):
                text_out = _openai_message_text(raw_out)
            else:
                text_out = str(raw_out)
            # Prefer a parsed JSON object body; otherwise wrap the raw text.
            try:
                parsed: Any = json.loads(text_out)
                resp_body: dict[str, Any] = (
                    parsed if isinstance(parsed, dict) else {"output": text_out}
                )
            except json.JSONDecodeError:
                resp_body = {"output": text_out}
            contents.append(
                types.Content(
                    role="user",
                    parts=[
                        types.Part(
                            function_response=types.FunctionResponse(
                                name=str(fn_name),
                                response=resp_body,
                            ),
                        ),
                    ],
                ),
            )
    return system_instruction, contents
def _contents_for_count_tokens(
    sys_inst: str | None,
    contents: list[types.Content],
) -> list[types.Content]:
    """Build ``contents`` for :meth:`Models.count_tokens`.

    The Gemini developer API rejects ``system_instruction`` on ``countTokens``.
    Prepend the system text as the first part of the opening turn so totals
    are close to ``generate_content`` with ``system_instruction`` set.
    """
    if not sys_inst or not str(sys_inst).strip():
        return contents
    if not contents:
        # No turns at all: the system text becomes a lone user turn.
        return [
            types.Content(
                role="user",
                parts=[types.Part(text=str(sys_inst))],
            ),
        ]
    first = contents[0]
    merged_parts: list[types.Part] = [types.Part(text=str(sys_inst))]
    for p in first.parts or []:
        merged_parts.append(p)
    return [
        types.Content(role=first.role, parts=merged_parts),
        *contents[1:],
    ]


def _afc_tool_fns(
    registry: ToolRegistry,
    *,
    user_id: str,
    ctx: ToolContext | None,
    max_tool_output_chars: int,
    tool_names: list[str] | None,
) -> list[Any]:
    """Build async callables for Gemini automatic function calling (AFC).

    Each callable delegates to ``registry.call`` and truncates oversized
    output. When *tool_names* is given, only those tools are exposed.

    NOTE: the inner functions deliberately carry no docstrings — the GenAI
    SDK derives tool declarations from the callables, so adding docstrings
    would change what the model sees.
    """
    allowed = set(tool_names) if tool_names else None

    async def kg_search_entities(
        query: str,
        category: str = "",
        scope_id: str = "",
        top_k: int = 12,
    ) -> str:
        raw = await registry.call(
            "kg_search_entities",
            {
                "query": query,
                "category": category,
                "scope_id": scope_id,
                "top_k": top_k,
            },
            user_id=user_id,
            ctx=ctx,
        )
        return _truncate_tool_output(raw, max_tool_output_chars)

    async def kg_get_entity(
        name: str = "",
        uuid: str = "",
    ) -> str:
        raw = await registry.call(
            "kg_get_entity",
            {"name": name, "uuid": uuid},
            user_id=user_id,
            ctx=ctx,
        )
        return _truncate_tool_output(raw, max_tool_output_chars)

    async def kg_inspect_entity(
        name: str = "",
        uuid: str = "",
        max_depth: int = 2,
    ) -> str:
        raw = await registry.call(
            "kg_inspect_entity",
            {"name": name, "uuid": uuid, "max_depth": max_depth},
            user_id=user_id,
            ctx=ctx,
        )
        return _truncate_tool_output(raw, max_tool_output_chars)

    all_tools = [kg_search_entities, kg_get_entity, kg_inspect_entity]
    if allowed is None:
        return all_tools
    return [f for f in all_tools if f.__name__ in allowed]


async def _safe_aclose(client: genai.Client) -> None:
    """Close the client's async transport; log (never raise) on failure."""
    try:
        await client.aio.aclose()
    except Exception:
        logger.debug("genai client aclose failed", exc_info=True)
class GeminiPoolToolChatClient:
    """Async Gemini chat with pool keys, countTokens, and AFC tool execution."""

    def __init__(
        self,
        *,
        tool_registry: ToolRegistry,
        model_id: str,
        max_tool_rounds: int = 48,
        max_tokens: int = 60_000,
        max_tool_output_chars: int = 3_000_000,
        temperature: float = 0.25,
    ) -> None:
        self.tool_registry = tool_registry
        self.model_id = model_id
        # Upper bound on automatic function-calling rounds per request.
        self.max_tool_rounds = max_tool_rounds
        self.max_tokens = max_tokens
        # Per-tool-call output cap before truncation.
        self.max_tool_output_chars = max_tool_output_chars
        self.temperature = temperature

    async def close(self) -> None:
        """No persistent resources; clients are created per attempt."""
        return

    async def count_input_tokens(
        self,
        messages: list[dict[str, Any]],
        *,
        gemini_model: str | None = None,
    ) -> int | None:
        """Count prompt tokens via the Gemini ``countTokens`` endpoint.

        Rotates pool keys with backoff; falls back to the paid key on the
        last attempts. Returns ``None`` when counting is unavailable (empty
        contents, non-retryable error, or retries exhausted) — callers treat
        the count as best-effort.
        """
        model = gemini_model or self.model_id
        sys_inst, contents = openai_messages_to_gemini(messages)
        if not contents:
            return None
        count_contents = _contents_for_count_tokens(sys_inst, contents)
        paid_used = False
        last_exc: BaseException | None = None
        for attempt in range(_MAX_KEY_ATTEMPTS):
            key = next_gemini_flash_key()
            # Final two attempts may switch to the paid fallback key (once).
            if attempt >= _MAX_KEY_ATTEMPTS - 2 and not paid_used:
                pk = get_paid_fallback_key()
                if pk:
                    key = pk
                    paid_used = True
            if attempt > 0:
                # Flat 1s for the first 12 retries, then exponential, capped at 30s.
                delay = 1.0 if attempt <= 12 else min(2.0 ** (attempt - 12), 30.0)
                await asyncio.sleep(delay)
            client = genai.Client(
                api_key=key,
                http_options=types.HttpOptions(timeout=_GEN_HTTP_TIMEOUT_MS),
            )
            try:
                resp = await client.aio.models.count_tokens(
                    model=model,
                    contents=count_contents,
                )
                n = resp.total_tokens
                if isinstance(n, int):
                    logger.info("Gemini countTokens: %d tokens", n)
                    return n
            except APIError as e:
                last_exc = e
                if e.code == 429 and _error_body_daily_quota(e.details):
                    await mark_key_daily_spent(key, "generate")
                # Only 429 and 5xx are retried; anything else gives up.
                if e.code not in (429, 500, 502, 503, 504):
                    logger.debug(
                        "count_tokens APIError %s: %s",
                        e.code,
                        e.message,
                    )
                    return None
            except Exception as e:
                last_exc = e
                logger.debug("count_tokens failed", exc_info=True)
                return None
            finally:
                await _safe_aclose(client)
        if last_exc:
            logger.debug("count_tokens exhausted retries: %s", last_exc)
        return None

    async def chat(
        self,
        messages: list[dict[str, Any]],
        user_id: str = "",
        ctx: ToolContext | None = None,
        tool_names: list[str] | None = None,
        validate_header: bool = False,
        token_count: int | None = None,
        on_intermediate_text: Callable[[str], Awaitable[None]] | None = None,
    ) -> str:
        """Run one AFC-enabled generate_content call and return the final text.

        The SDK executes tool calls automatically (up to ``max_tool_rounds``).
        Key rotation/backoff mirrors :meth:`count_input_tokens`; non-retryable
        API errors are re-raised, and exhausting all attempts raises
        ``RuntimeError`` chained to the last error.
        """
        del validate_header, on_intermediate_text  # bulk path unused
        msgs = [dict(m) for m in messages]
        if token_count is None:
            token_count = await self.count_input_tokens(msgs)
        _count_str = str(token_count) if token_count is not None else "unavailable"
        # Substitute the token-count placeholder in the first matching system turn.
        for _m in msgs:
            if _m.get("role") == "system" and "__INPUT_TOKEN_COUNT__" in str(
                _m.get("content", ""),
            ):
                _m["content"] = str(_m["content"]).replace(
                    "__INPUT_TOKEN_COUNT__",
                    _count_str,
                )
                break
        sys_inst, contents = openai_messages_to_gemini(msgs)
        if not contents:
            return ""
        tools = _afc_tool_fns(
            self.tool_registry,
            user_id=user_id,
            ctx=ctx,
            max_tool_output_chars=self.max_tool_output_chars,
            tool_names=tool_names,
        )
        afc_limit = max(1, int(self.max_tool_rounds))
        config = types.GenerateContentConfig(
            system_instruction=sys_inst,
            temperature=self.temperature,
            max_output_tokens=int(self.max_tokens),
            tools=tools,
            tool_config=types.ToolConfig(
                function_calling_config=types.FunctionCallingConfig(
                    mode=types.FunctionCallingConfigMode.AUTO,
                ),
            ),
            automatic_function_calling=types.AutomaticFunctionCallingConfig(
                maximum_remote_calls=afc_limit,
            ),
        )
        paid_used = False
        last_exc: BaseException | None = None
        for attempt in range(_MAX_KEY_ATTEMPTS):
            key = next_gemini_flash_key()
            if attempt >= _MAX_KEY_ATTEMPTS - 2 and not paid_used:
                pk = get_paid_fallback_key()
                if pk:
                    key = pk
                    paid_used = True
            if attempt > 0:
                delay = 1.0 if attempt <= 12 else min(2.0 ** (attempt - 12), 30.0)
                await asyncio.sleep(delay)
            client = genai.Client(
                api_key=key,
                http_options=types.HttpOptions(timeout=_GEN_HTTP_TIMEOUT_MS),
            )
            t0 = time.monotonic()
            try:
                logger.info(
                    "Gemini bulk AFC: attempt=%d model=%s contents=%d",
                    attempt + 1,
                    self.model_id,
                    len(contents),
                )
                resp = await client.aio.models.generate_content(
                    model=self.model_id,
                    contents=contents,
                    config=config,
                )
                logger.info(
                    "Gemini bulk AFC completed in %.0f ms",
                    (time.monotonic() - t0) * 1000,
                )
                text = (getattr(resp, "text", None) or "").strip()
                if text.startswith("[PROXY ERROR]"):
                    logger.warning("Unexpected proxy-style prefix in Gemini response")
                return text
            except APIError as e:
                last_exc = e
                if e.code == 429 and _error_body_daily_quota(e.details):
                    await mark_key_daily_spent(key, "generate")
                # Non-retryable API errors propagate immediately.
                if e.code not in (429, 500, 502, 503, 504):
                    raise
            finally:
                await _safe_aclose(client)
        msg = f"Gemini bulk chat failed after {_MAX_KEY_ATTEMPTS} attempts"
        if last_exc:
            msg += f": {last_exc}"
        logger.error(msg)
        raise RuntimeError(msg) from last_exc