# Source code for gemini_kg_bulk_client (extracted from Sphinx viewcode page)
"""Native Gemini (google.genai) client for bulk agentic KG extraction.
Uses :func:`gemini_embed_pool.next_gemini_flash_key` for API key rotation,
``client.aio`` for async I/O, and automatic function calling (AFC) with
thin async callables that delegate to :meth:`tools.ToolRegistry.call`.
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
from typing import Any, Callable, Awaitable
from google import genai
from google.genai import types
from google.genai.errors import APIError
from gemini_embed_pool import (
get_paid_fallback_key,
mark_key_daily_spent,
next_gemini_flash_key,
)
from tool_context import ToolContext
from tools import ToolRegistry
logger = logging.getLogger(__name__)

# Upper bound on pool-key rotation attempts per API call before giving up.
_MAX_KEY_ATTEMPTS = 32
def _error_body_daily_quota(details: Any) -> bool:
"""True if Gemini error details include a *daily* (PerDay) quota violation.
Avoids relying on ``response.status_code`` / httpx-only parsing: the GenAI
async client may attach aiohttp responses.
"""
if details is None:
return False
body = details
if isinstance(body, list) and body:
body = body[0]
if not isinstance(body, dict):
return False
err = body.get("error", body)
if not isinstance(err, dict):
return False
for detail in err.get("details", []) or []:
if not isinstance(detail, dict):
continue
for violation in detail.get("violations", []) or []:
qid = str(violation.get("quotaId", "") or "")
if "PerDay" in qid:
return True
return False
# Per-request HTTP timeout handed to each genai.Client (milliseconds; 10 min).
_GEN_HTTP_TIMEOUT_MS = 600_000
def _truncate_tool_output(result: str, cap: int) -> str:
if cap <= 0 or len(result) <= cap:
return result
orig = len(result)
return (
result[:cap]
+ f"\n\n[TRUNCATED — output was {orig:,} chars, showing first {cap:,}]"
)
def _openai_message_text(content: Any) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
return " ".join(
str(p.get("text", ""))
for p in content
if isinstance(p, dict) and p.get("type") == "text"
)
return ""
# [docs]  (stray Sphinx viewcode anchor)
def openai_messages_to_gemini(
    messages: list[dict[str, Any]],
) -> tuple[str | None, list[types.Content]]:
    """Split system instruction and build Gemini ``contents`` (no system in contents).

    Leading ``system`` messages are concatenated (double-newline separated)
    into the returned system instruction. Any ``system`` message appearing
    *after* the first non-system message is silently dropped: it lands in
    ``rest`` but matches no role branch below. Only ``user``, ``assistant``,
    and ``tool`` messages are converted to :class:`types.Content` turns.
    """
    system_chunks: list[str] = []
    rest: list[dict[str, Any]] = []
    seen_non_system = False
    for msg in messages:
        role = msg.get("role", "user")
        # Only a *leading* run of system messages becomes the instruction.
        if role == "system" and not seen_non_system:
            t = _openai_message_text(msg.get("content", ""))
            if t:
                system_chunks.append(t)
            continue
        seen_non_system = True
        rest.append(msg)
    system_instruction = "\n\n".join(system_chunks) if system_chunks else None
    contents: list[types.Content] = []
    for msg in rest:
        role = msg.get("role", "user")
        if role == "user":
            text = _openai_message_text(msg.get("content", ""))
            # Empty user turns are skipped entirely.
            if text:
                contents.append(
                    types.Content(
                        role="user",
                        parts=[types.Part(text=text)],
                    ),
                )
        elif role == "assistant":
            parts: list[types.Part] = []
            t = _openai_message_text(msg.get("content", ""))
            if t:
                parts.append(types.Part(text=t))
            # OpenAI tool_calls -> Gemini function_call parts on the model turn.
            for tc in msg.get("tool_calls") or []:
                fn = tc.get("function") or {}
                name = str(fn.get("name") or "")
                raw = fn.get("arguments")
                # Arguments may arrive as a JSON string or an already-parsed
                # dict; anything unparseable degrades to empty args.
                if isinstance(raw, str):
                    try:
                        args = json.loads(raw)
                    except json.JSONDecodeError:
                        args = {}
                elif isinstance(raw, dict):
                    args = raw
                else:
                    args = {}
                if not isinstance(args, dict):
                    args = {}
                parts.append(
                    types.Part(
                        function_call=types.FunctionCall(name=name, args=args),
                    ),
                )
            if parts:
                contents.append(types.Content(role="model", parts=parts))
        elif role == "tool":
            # Tool results must be tied to a function name; nameless ones
            # cannot be represented and are dropped.
            fn_name = msg.get("name") or ""
            if not fn_name:
                continue
            raw_out = msg.get("content", "")
            if isinstance(raw_out, list):
                text_out = _openai_message_text(raw_out)
            else:
                text_out = str(raw_out)
            # FunctionResponse wants a dict body: pass through JSON objects,
            # otherwise wrap the raw text under "output".
            try:
                parsed: Any = json.loads(text_out)
                resp_body: dict[str, Any] = (
                    parsed if isinstance(parsed, dict) else {"output": text_out}
                )
            except json.JSONDecodeError:
                resp_body = {"output": text_out}
            # Gemini models function responses as a "user"-role turn.
            contents.append(
                types.Content(
                    role="user",
                    parts=[
                        types.Part(
                            function_response=types.FunctionResponse(
                                name=str(fn_name),
                                response=resp_body,
                            ),
                        ),
                    ],
                ),
            )
    return system_instruction, contents
def _contents_for_count_tokens(
    sys_inst: str | None,
    contents: list[types.Content],
) -> list[types.Content]:
    """Fold the system text into ``contents`` for :meth:`Models.count_tokens`.

    The Gemini developer API rejects ``system_instruction`` on ``countTokens``,
    so the system text is prepended as the first part of the opening turn,
    keeping totals close to ``generate_content`` with the instruction set.
    """
    sys_text = str(sys_inst) if sys_inst else ""
    if not sys_text.strip():
        return contents
    sys_part = types.Part(text=sys_text)
    if not contents:
        # No conversation yet: the system text becomes a lone user turn.
        return [types.Content(role="user", parts=[sys_part])]
    head, *tail = contents
    combined: list[types.Part] = [sys_part, *(head.parts or [])]
    return [types.Content(role=head.role, parts=combined), *tail]
def _afc_tool_fns(
    registry: ToolRegistry,
    *,
    user_id: str,
    ctx: ToolContext | None,
    max_tool_output_chars: int,
    tool_names: list[str] | None,
) -> list[Any]:
    """Build async callables for Gemini automatic function calling (AFC).

    The SDK derives each tool's declaration from the callable's name,
    parameter names, and defaults, so these signatures are part of the
    model-facing contract. Each wrapper closes over *registry*, *user_id*,
    and *ctx*, delegates to :meth:`ToolRegistry.call`, and truncates the
    result to *max_tool_output_chars*.

    ``tool_names=None`` exposes every tool; otherwise only the named ones
    are returned.
    """
    allowed = set(tool_names) if tool_names else None
    async def kg_search_entities(
        query: str,
        category: str = "",
        scope_id: str = "",
        top_k: int = 12,
    ) -> str:
        raw = await registry.call(
            "kg_search_entities",
            {
                "query": query,
                "category": category,
                "scope_id": scope_id,
                "top_k": top_k,
            },
            user_id=user_id,
            ctx=ctx,
        )
        return _truncate_tool_output(raw, max_tool_output_chars)
    async def kg_get_entity(
        name: str = "",
        uuid: str = "",
    ) -> str:
        raw = await registry.call(
            "kg_get_entity",
            {"name": name, "uuid": uuid},
            user_id=user_id,
            ctx=ctx,
        )
        return _truncate_tool_output(raw, max_tool_output_chars)
    async def kg_inspect_entity(
        name: str = "",
        uuid: str = "",
        max_depth: int = 2,
    ) -> str:
        raw = await registry.call(
            "kg_inspect_entity",
            {"name": name, "uuid": uuid, "max_depth": max_depth},
            user_id=user_id,
            ctx=ctx,
        )
        return _truncate_tool_output(raw, max_tool_output_chars)
    all_tools = [kg_search_entities, kg_get_entity, kg_inspect_entity]
    if allowed is None:
        return all_tools
    # Filter by the wrapper function names (which match the registry names).
    return [f for f in all_tools if f.__name__ in allowed]
async def _safe_aclose(client: genai.Client) -> None:
try:
await client.aio.aclose()
except Exception:
logger.debug("genai client aclose failed", exc_info=True)
# [docs]  (stray Sphinx viewcode anchor)
class GeminiPoolToolChatClient:
    """Async Gemini chat with pool keys, countTokens, and AFC tool execution."""

    def __init__(
        self,
        *,
        tool_registry: ToolRegistry,
        model_id: str,
        max_tool_rounds: int = 48,
        max_tokens: int = 60_000,
        max_tool_output_chars: int = 3_000_000,
        temperature: float = 0.25,
    ) -> None:
        """Store generation settings; no I/O happens until :meth:`chat`.

        ``max_tool_rounds`` caps AFC remote calls per request;
        ``max_tool_output_chars`` truncates each tool result fed back to
        the model.
        """
        self.tool_registry = tool_registry
        self.model_id = model_id
        self.max_tool_rounds = max_tool_rounds
        self.max_tokens = max_tokens
        self.max_tool_output_chars = max_tool_output_chars
        self.temperature = temperature

    async def count_input_tokens(
        self,
        messages: list[dict[str, Any]],
        *,
        gemini_model: str | None = None,
    ) -> int | None:
        """Best-effort prompt token count via ``countTokens``; None on failure.

        System text is folded into the opening turn (see
        :func:`_contents_for_count_tokens`) because the developer API rejects
        ``system_instruction`` on ``countTokens``. Rotates pool keys with
        backoff; never raises.
        """
        model = gemini_model or self.model_id
        sys_inst, contents = openai_messages_to_gemini(messages)
        if not contents:
            return None
        count_contents = _contents_for_count_tokens(sys_inst, contents)
        paid_used = False
        last_exc: BaseException | None = None
        for attempt in range(_MAX_KEY_ATTEMPTS):
            key = next_gemini_flash_key()
            # On the final two attempts, switch once to the paid fallback key.
            if attempt >= _MAX_KEY_ATTEMPTS - 2 and not paid_used:
                pk = get_paid_fallback_key()
                if pk:
                    key = pk
                    paid_used = True
            if attempt > 0:
                # Flat 1 s for the first 12 retries, then exponential capped at 30 s.
                delay = 1.0 if attempt <= 12 else min(2.0 ** (attempt - 12), 30.0)
                await asyncio.sleep(delay)
            # Fresh client per attempt so each request uses the rotated key.
            client = genai.Client(
                api_key=key,
                http_options=types.HttpOptions(timeout=_GEN_HTTP_TIMEOUT_MS),
            )
            try:
                resp = await client.aio.models.count_tokens(
                    model=model,
                    contents=count_contents,
                )
                n = resp.total_tokens
                if isinstance(n, int):
                    logger.info("Gemini countTokens: %d tokens", n)
                    return n
                # Non-int total_tokens: fall through and retry with another key.
            except APIError as e:
                last_exc = e
                # Daily-quota 429: mark this key spent so the pool skips it today.
                if e.code == 429 and _error_body_daily_quota(e.details):
                    await mark_key_daily_spent(key, "generate")
                # Only 429/5xx are retryable; other codes end the loop with None.
                if e.code not in (429, 500, 502, 503, 504):
                    logger.debug(
                        "count_tokens APIError %s: %s", e.code, e.message,
                    )
                    return None
            except Exception as e:
                # Unknown failure (e.g. transport): give up immediately, quietly.
                last_exc = e
                logger.debug("count_tokens failed", exc_info=True)
                return None
            finally:
                await _safe_aclose(client)
        if last_exc:
            logger.debug("count_tokens exhausted retries: %s", last_exc)
        return None

    async def chat(
        self,
        messages: list[dict[str, Any]],
        user_id: str = "",
        ctx: ToolContext | None = None,
        tool_names: list[str] | None = None,
        validate_header: bool = False,
        token_count: int | None = None,
        on_intermediate_text: Callable[[str], Awaitable[None]] | None = None,
    ) -> str:
        """Run one generate_content call with AFC tools; return the final text.

        Replaces a ``__INPUT_TOKEN_COUNT__`` placeholder in the first system
        message that contains it (counting tokens first if *token_count* is
        None). Rotates pool keys with backoff on retryable APIErrors
        (429/5xx); re-raises non-retryable APIErrors, and raises
        :class:`RuntimeError` once all attempts are exhausted.
        ``validate_header`` and ``on_intermediate_text`` are accepted for
        interface compatibility but unused on this bulk path.
        """
        del validate_header, on_intermediate_text  # bulk path unused
        # Shallow-copy messages so the placeholder substitution below does
        # not mutate the caller's dicts.
        msgs = [dict(m) for m in messages]
        if token_count is None:
            token_count = await self.count_input_tokens(msgs)
        _count_str = str(token_count) if token_count is not None else "unavailable"
        for _m in msgs:
            if _m.get("role") == "system" and "__INPUT_TOKEN_COUNT__" in str(
                _m.get("content", ""),
            ):
                _m["content"] = str(_m["content"]).replace(
                    "__INPUT_TOKEN_COUNT__", _count_str,
                )
                break
        sys_inst, contents = openai_messages_to_gemini(msgs)
        if not contents:
            return ""
        tools = _afc_tool_fns(
            self.tool_registry,
            user_id=user_id,
            ctx=ctx,
            max_tool_output_chars=self.max_tool_output_chars,
            tool_names=tool_names,
        )
        afc_limit = max(1, int(self.max_tool_rounds))
        config = types.GenerateContentConfig(
            system_instruction=sys_inst,
            temperature=self.temperature,
            max_output_tokens=int(self.max_tokens),
            tools=tools,
            tool_config=types.ToolConfig(
                function_calling_config=types.FunctionCallingConfig(
                    mode=types.FunctionCallingConfigMode.AUTO,
                ),
            ),
            # The SDK loops tool-call rounds internally up to this cap.
            automatic_function_calling=types.AutomaticFunctionCallingConfig(
                maximum_remote_calls=afc_limit,
            ),
        )
        paid_used = False
        last_exc: BaseException | None = None
        for attempt in range(_MAX_KEY_ATTEMPTS):
            key = next_gemini_flash_key()
            # Same key-rotation scheme as count_input_tokens: paid fallback
            # once near the end, 1 s flat then exponential backoff.
            if attempt >= _MAX_KEY_ATTEMPTS - 2 and not paid_used:
                pk = get_paid_fallback_key()
                if pk:
                    key = pk
                    paid_used = True
            if attempt > 0:
                delay = 1.0 if attempt <= 12 else min(2.0 ** (attempt - 12), 30.0)
                await asyncio.sleep(delay)
            client = genai.Client(
                api_key=key,
                http_options=types.HttpOptions(timeout=_GEN_HTTP_TIMEOUT_MS),
            )
            t0 = time.monotonic()
            try:
                logger.info(
                    "Gemini bulk AFC: attempt=%d model=%s contents=%d",
                    attempt + 1, self.model_id, len(contents),
                )
                resp = await client.aio.models.generate_content(
                    model=self.model_id,
                    contents=contents,
                    config=config,
                )
                logger.info(
                    "Gemini bulk AFC completed in %.0f ms",
                    (time.monotonic() - t0) * 1000,
                )
                text = (getattr(resp, "text", None) or "").strip()
                # A proxy-style prefix should be impossible on the native
                # client; log it but still return the text.
                if text.startswith("[PROXY ERROR]"):
                    logger.warning("Unexpected proxy-style prefix in Gemini response")
                return text
            except APIError as e:
                last_exc = e
                if e.code == 429 and _error_body_daily_quota(e.details):
                    await mark_key_daily_spent(key, "generate")
                # Non-retryable API errors propagate to the caller.
                if e.code not in (429, 500, 502, 503, 504):
                    raise
            finally:
                await _safe_aclose(client)
        msg = f"Gemini bulk chat failed after {_MAX_KEY_ATTEMPTS} attempts"
        if last_exc:
            msg += f": {last_exc}"
        logger.error(msg)
        raise RuntimeError(msg) from last_exc