Source code for gemini_kg_bulk_client
"""Native Gemini (google.genai) client for bulk agentic KG extraction.
Uses :func:`gemini_embed_pool.next_gemini_flash_key` for API key rotation,
``client.aio`` for async I/O, and automatic function calling (AFC) with
thin async callables that delegate to :meth:`tools.ToolRegistry.call`.
"""
from __future__ import annotations
import asyncio
import jsonutil as json
import logging
import time
from typing import Any, Callable, Awaitable
from google import genai
from google.genai import types
from google.genai.errors import APIError
from gemini_embed_pool import (
get_paid_fallback_key,
mark_key_daily_spent,
next_gemini_flash_key,
)
from tool_context import ToolContext
from tools import ToolRegistry
# Module-level logger shared by every helper and method below.
logger: logging.Logger = logging.getLogger(__name__)
# Upper bound on pool-key attempts per request, used by both the countTokens
# and generate_content retry loops.
_MAX_KEY_ATTEMPTS: int = 32
def _error_body_daily_quota(details: Any) -> bool:
    """Return True if Gemini error *details* report a daily (``PerDay``) quota violation.

    Inspects the parsed error payload directly instead of relying on
    ``response.status_code`` / httpx-only parsing, because the GenAI async
    client may attach aiohttp responses.

    The payload may be a dict or a non-empty list whose first element is the
    dict. The error object is ``payload["error"]`` when present, otherwise the
    payload itself. Each entry of its ``details`` list is scanned for
    ``violations`` whose ``quotaId`` contains ``"PerDay"``.

    Malformed pieces are skipped rather than raising: non-list ``details`` or
    ``violations``, and non-dict entries. This function is called from inside
    ``except APIError`` handlers, where a secondary exception would replace the
    original API error.

    Args:
        details: The ``details`` attribute of a
            :class:`google.genai.errors.APIError`, or ``None``.

    Returns:
        bool: True when any violation's ``quotaId`` mentions ``PerDay``.
    """
    if details is None:
        return False
    body = details
    if isinstance(body, list) and body:
        body = body[0]
    if not isinstance(body, dict):
        return False
    err = body.get("error", body)
    if not isinstance(err, dict):
        return False
    detail_list = err.get("details") or []
    if not isinstance(detail_list, (list, tuple)):
        return False
    for detail in detail_list:
        if not isinstance(detail, dict):
            continue
        violations = detail.get("violations") or []
        if not isinstance(violations, (list, tuple)):
            continue
        for violation in violations:
            # Error payloads are untrusted JSON; a non-dict violation must not
            # raise AttributeError inside the caller's except block.
            if not isinstance(violation, dict):
                continue
            qid = str(violation.get("quotaId", "") or "")
            if "PerDay" in qid:
                return True
    return False
# HTTP timeout handed to every per-attempt genai.Client (milliseconds; 10 min).
_GEN_HTTP_TIMEOUT_MS: int = 10 * 60 * 1_000
def _truncate_tool_output(result: str, cap: int) -> str:
    """Bound a tool's text output, appending a notice when it was clipped.

    Keeps oversized KG tool results from blowing up the Gemini context window.
    When *cap* is positive and *result* is longer than *cap*, only the first
    *cap* characters are kept. A ``[TRUNCATED ...]`` notice giving the
    original and shown character counts (thousands-separated) is appended.
    Otherwise *result* is returned untouched.

    Used by the AFC callables built in :func:`_afc_tool_fns` before a raw tool
    result is handed back to Gemini's automatic function calling.

    Args:
        result: Raw tool output string.
        cap: Maximum characters to keep; ``<= 0`` disables the limit.

    Returns:
        str: *result* itself, or its first *cap* characters plus the notice.
    """
    needs_clip = cap > 0 and len(result) > cap
    if not needs_clip:
        return result
    shown = result[:cap]
    notice = (
        f"\n\n[TRUNCATED — output was {len(result):,} chars, "
        f"showing first {cap:,}]"
    )
    return shown + notice
def _openai_message_text(content: Any) -> str:
    """Reduce an OpenAI-style message ``content`` value to plain text.

    A plain string is returned as-is. For a list of typed parts, only dict
    parts whose ``type`` is ``"text"`` contribute: their ``text`` values are
    stringified and joined with single spaces. Image and other non-text parts
    are dropped. Any other value yields ``""``.

    Used by :func:`openai_messages_to_gemini` to flatten system, user,
    assistant and tool content.

    Args:
        content: A message's ``content`` value (string, list of part dicts,
            or anything else).

    Returns:
        str: The extracted text, or ``""`` if there is none.
    """
    if isinstance(content, list):
        pieces: list[str] = []
        for part in content:
            if not isinstance(part, dict) or part.get("type") != "text":
                continue
            pieces.append(str(part.get("text", "")))
        return " ".join(pieces)
    return content if isinstance(content, str) else ""
[docs]
def openai_messages_to_gemini(
    messages: list[dict[str, Any]],
) -> tuple[str | None, list[types.Content]]:
    """Split the system instruction off and build Gemini ``contents``.

    Translates an OpenAI-style chat history into the two-part shape the
    google-genai SDK expects. The first part is a single system instruction,
    taken only from the *leading* ``system`` turns and joined with blank
    lines; Gemini takes system text out of band rather than inside
    ``contents``. The second part is a list of
    :class:`google.genai.types.Content` turns:

    * ``user`` -> a ``user`` turn with one text part (skipped when empty).
    * ``assistant`` -> a ``model`` turn with an optional text part plus one
      :class:`types.FunctionCall` part per ``tool_calls`` entry. Arguments are
      JSON-decoded via the module's ``json`` alias (:mod:`jsonutil`), and
      become ``{}`` when missing, undecodable or not an object. The turn is
      skipped when it would have no parts.
    * ``tool`` -> a :class:`types.FunctionResponse` part, skipped when the
      message has no ``name``. Consecutive tool messages are grouped into ONE
      ``user`` turn, so all responses to a model turn's (possibly parallel)
      function calls arrive together, as Gemini expects.

    Messages with any other role are dropped. This includes ``system``
    messages appearing after the first non-system message.

    Pure transformation with no I/O. Called by
    :meth:`GeminiPoolToolChatClient.chat` to build the ``generate_content``
    contents, and by :meth:`GeminiPoolToolChatClient.count_input_tokens`
    before reshaping via :func:`_contents_for_count_tokens`.

    Args:
        messages: OpenAI-style chat messages with ``role``/``content``. For
            assistant turns, optional ``tool_calls``; for tool turns, ``name``.

    Returns:
        tuple[str | None, list[types.Content]]: The merged system instruction
        (``None`` when no leading system text exists) and the Gemini
        ``contents`` list (empty when nothing convertible remains).
    """
    system_chunks: list[str] = []
    rest: list[dict[str, Any]] = []
    seen_non_system = False
    for msg in messages:
        role = msg.get("role", "user")
        if role == "system" and not seen_non_system:
            text = _openai_message_text(msg.get("content", ""))
            if text:
                system_chunks.append(text)
            continue
        seen_non_system = True
        rest.append(msg)
    system_instruction = "\n\n".join(system_chunks) if system_chunks else None

    contents: list[types.Content] = []
    # Function responses from consecutive ``tool`` messages are buffered and
    # emitted as a single user turn (see docstring).
    pending_responses: list[types.Part] = []

    def flush_responses() -> None:
        # Emit buffered function responses as one turn, preserving order.
        if pending_responses:
            contents.append(
                types.Content(role="user", parts=list(pending_responses)),
            )
            pending_responses.clear()

    for msg in rest:
        role = msg.get("role", "user")
        if role == "user":
            text = _openai_message_text(msg.get("content", ""))
            if text:
                flush_responses()
                contents.append(
                    types.Content(role="user", parts=[types.Part(text=text)]),
                )
        elif role == "assistant":
            parts: list[types.Part] = []
            text = _openai_message_text(msg.get("content", ""))
            if text:
                parts.append(types.Part(text=text))
            for tc in msg.get("tool_calls") or []:
                fn = tc.get("function") or {}
                parts.append(
                    types.Part(
                        function_call=types.FunctionCall(
                            name=str(fn.get("name") or ""),
                            args=_tool_call_args(fn.get("arguments")),
                        ),
                    ),
                )
            if parts:
                flush_responses()
                contents.append(types.Content(role="model", parts=parts))
        elif role == "tool":
            fn_name = msg.get("name") or ""
            if not fn_name:
                continue
            pending_responses.append(
                types.Part(
                    function_response=types.FunctionResponse(
                        name=str(fn_name),
                        response=_tool_response_body(msg.get("content", "")),
                    ),
                ),
            )
        # Any other role (including late ``system`` turns) is dropped.
    flush_responses()
    return system_instruction, contents


def _tool_call_args(raw: Any) -> dict[str, Any]:
    """Decode an OpenAI tool-call ``arguments`` value into an args dict.

    Dicts pass through. Strings are JSON-decoded and kept only if they decode
    to an object. Everything else, including undecodable JSON, yields ``{}``.
    """
    if isinstance(raw, dict):
        return raw
    if isinstance(raw, str):
        try:
            args = json.loads(raw)
        except json.JSONDecodeError:
            return {}
        return args if isinstance(args, dict) else {}
    return {}


def _tool_response_body(raw_out: Any) -> dict[str, Any]:
    """Build a FunctionResponse ``response`` dict from a tool message's content.

    List content is flattened with :func:`_openai_message_text`; other values
    are stringified. JSON objects are used as-is. Any other text, including
    non-object JSON or undecodable text, is wrapped as ``{"output": text}``.
    """
    text_out = (
        _openai_message_text(raw_out) if isinstance(raw_out, list) else str(raw_out)
    )
    try:
        parsed: Any = json.loads(text_out)
    except json.JSONDecodeError:
        return {"output": text_out}
    return parsed if isinstance(parsed, dict) else {"output": text_out}
def _contents_for_count_tokens(
    sys_inst: str | None,
    contents: list[types.Content],
) -> list[types.Content]:
    """Build the ``contents`` payload for :meth:`Models.count_tokens`.

    The Gemini developer API rejects ``system_instruction`` on
    ``countTokens``. The system text is therefore injected as the first part
    of the opening turn, so the total stays close to what ``generate_content``
    with ``system_instruction`` would count.

    A missing or whitespace-only *sys_inst* returns *contents* unchanged (the
    same list object). With empty *contents*, a single ``user`` turn carrying
    the system text is returned.

    Args:
        sys_inst: System instruction text, or ``None``.
        contents: Gemini contents produced by :func:`openai_messages_to_gemini`.

    Returns:
        list[types.Content]: Contents with the system text folded in.
    """
    has_system_text = bool(sys_inst) and bool(str(sys_inst).strip())
    if not has_system_text:
        return contents
    system_part = types.Part(text=str(sys_inst))
    if contents:
        head, *tail = contents
        return [
            types.Content(role=head.role, parts=[system_part, *(head.parts or [])]),
            *tail,
        ]
    return [types.Content(role="user", parts=[system_part])]
def _afc_tool_fns(
    registry: ToolRegistry,
    *,
    user_id: str,
    ctx: ToolContext | None,
    max_tool_output_chars: int,
    tool_names: list[str] | None,
) -> list[Any]:
    """Build the thin async callables Gemini AFC invokes for KG read tools.

    Defines three nested async wrappers whose names and signatures are
    introspected by the google-genai SDK to generate function-call schemas:
    ``kg_search_entities``, ``kg_get_entity`` and ``kg_inspect_entity``. Each
    forwards to :meth:`tools.ToolRegistry.call` with the captured
    *user_id*/*ctx* and caps the result via :func:`_truncate_tool_output`.

    NOTE: google-genai also uses each callable's ``__doc__`` as the tool
    description sent to the model. The nested docstrings are therefore
    written for Gemini, not for maintainers; keep implementation notes in
    ``#`` comments.

    When *tool_names* is non-empty, the returned list is filtered to those
    names (matched on ``__name__``); ``None`` or an empty list returns all
    three.

    Called by :meth:`GeminiPoolToolChatClient.chat` to assemble the
    ``tools=`` list passed to ``GenerateContentConfig``.

    Args:
        registry: The tool registry whose ``call`` executes the real KG tools.
        user_id: Identity threaded into every ``registry.call`` for scoping.
        ctx: Optional :class:`tool_context.ToolContext` passed to each tool
            invocation.
        max_tool_output_chars: Per-call character cap applied to tool output.
        tool_names: Optional allow-list of tool names to expose.

    Returns:
        list[Any]: The async callables to register as Gemini AFC tools.
    """
    allowed = set(tool_names) if tool_names else None

    async def _invoke(tool: str, args: dict[str, Any]) -> str:
        # Shared path for every wrapper: run the real tool with the captured
        # identity/context, then bound the text so one result cannot flood
        # the model's context window.
        raw = await registry.call(tool, args, user_id=user_id, ctx=ctx)
        return _truncate_tool_output(raw, max_tool_output_chars)

    async def kg_search_entities(
        query: str,
        category: str = "",
        scope_id: str = "",
        top_k: int = 12,
    ) -> str:
        """Search knowledge-graph entities with a free-text query.

        Args:
            query: Free-text entity search query.
            category: Optional entity category filter.
            scope_id: Optional scope (e.g. channel/user) to constrain results.
            top_k: Maximum number of matches to return.
        """
        return await _invoke(
            "kg_search_entities",
            {
                "query": query,
                "category": category,
                "scope_id": scope_id,
                "top_k": top_k,
            },
        )

    async def kg_get_entity(
        name: str = "",
        uuid: str = "",
    ) -> str:
        """Fetch one knowledge-graph entity by UUID or by name.

        Args:
            name: Entity name to look up (used when uuid is empty).
            uuid: Entity UUID to look up (takes precedence over name).
        """
        return await _invoke("kg_get_entity", {"name": name, "uuid": uuid})

    async def kg_inspect_entity(
        name: str = "",
        uuid: str = "",
        max_depth: int = 2,
    ) -> str:
        """Inspect a knowledge-graph entity together with its nearby relationships.

        Args:
            name: Entity name to inspect (used when uuid is empty).
            uuid: Entity UUID to inspect (takes precedence over name).
            max_depth: Relationship traversal depth around the entity.
        """
        return await _invoke(
            "kg_inspect_entity",
            {"name": name, "uuid": uuid, "max_depth": max_depth},
        )

    all_tools = [kg_search_entities, kg_get_entity, kg_inspect_entity]
    if allowed is None:
        return all_tools
    return [f for f in all_tools if f.__name__ in allowed]
async def _safe_aclose(client: genai.Client) -> None:
    """Release a per-attempt genai client's async transport, never raising.

    Each attempt in the retry loops builds a fresh :class:`genai.Client`.
    This awaits ``client.aio.aclose()`` from the callers' ``finally`` blocks.
    Any ``Exception`` raised while closing (including while reaching
    ``client.aio``) is logged at debug level and suppressed, so cleanup
    cannot mask the real call result or exception.

    Called from :meth:`GeminiPoolToolChatClient.count_input_tokens` and
    :meth:`GeminiPoolToolChatClient.chat`.

    Args:
        client: The per-attempt genai client to close.
    """
    try:
        closing = client.aio.aclose()
        await closing
    except Exception:  # best-effort cleanup: swallow, but leave a trace
        logger.debug("genai client aclose failed", exc_info=True)
    return None
[docs]
class GeminiPoolToolChatClient:
    """Async Gemini chat with pool keys, countTokens, and AFC tool execution.

    Drop-in chat client for bulk knowledge-graph extraction. It talks to the
    native Gemini API (``google.genai``) instead of OpenRouter and exposes the
    same ``chat``/``count_input_tokens``/``close`` surface as the pooled
    OpenRouter client, so :mod:`kg_agentic_extraction` can swap backends.

    Each request attempt does the following:

    * picks a key from :mod:`gemini_embed_pool`;
    * constructs a short-lived :class:`google.genai.Client`;
    * runs the API call (for ``chat``, ``generate_content`` with automatic
      function calling, so Gemini drives the KG read tools built by
      :func:`_afc_tool_fns`);
    * tears the client down through :func:`_safe_aclose`.

    The KG side effects come from those tools; this class holds no long-lived
    session.

    Instances are built by
    :func:`kg_agentic_extraction.create_kg_bulk_gemini_pool_client`, which
    passes the bulk tool registry and native Gemini model id.
    """

    # APIError codes treated as transient: back off and retry with another key.
    _RETRYABLE_CODES = frozenset({429, 500, 502, 503, 504})

    def __init__(
        self,
        *,
        tool_registry: ToolRegistry,
        model_id: str,
        max_tool_rounds: int = 48,
        max_tokens: int = 60_000,
        max_tool_output_chars: int = 3_000_000,
        temperature: float = 0.25,
    ) -> None:
        """Store the tool registry, model id, and generation/AFC limits.

        Holds configuration only; no network/SDK clients are created here.

        Args:
            tool_registry: Registry whose tools the AFC wrappers delegate to.
            model_id: Native Gemini model id (no ``google/`` prefix).
            max_tool_rounds: Cap on AFC remote function-call rounds per request.
            max_tokens: ``max_output_tokens`` for generation.
            max_tool_output_chars: Per-call character cap for tool output.
            temperature: Sampling temperature for generation.
        """
        self.tool_registry = tool_registry
        self.model_id = model_id
        self.max_tool_rounds = max_tool_rounds
        self.max_tokens = max_tokens
        self.max_tool_output_chars = max_tool_output_chars
        self.temperature = temperature

    async def close(self) -> None:
        """No-op async close for interface parity with the OpenRouter client.

        genai clients are created and closed per request via
        :func:`_safe_aclose`, so there is nothing to tear down. Callers such
        as :mod:`kg_bulk_runner` (``bulk_client.close()`` /
        ``token_counter.close()``) can treat this like the pooled client.
        """
        return

    # ------------------------------------------------------------------
    # Shared retry plumbing for count_input_tokens() and chat().
    # ------------------------------------------------------------------

    @staticmethod
    def _retry_delay(attempt: int) -> float:
        """Return the backoff in seconds before 0-based *attempt*.

        The first attempt gets no delay, attempts 1-12 wait a flat 1 s, and
        later attempts wait ``2 ** (attempt - 12)`` s, capped at 30 s.
        """
        if attempt <= 0:
            return 0.0
        if attempt <= 12:
            return 1.0
        return min(2.0 ** (attempt - 12), 30.0)

    @staticmethod
    def _select_key(attempt: int, paid_used: bool) -> tuple[str, bool]:
        """Pick the API key for *attempt*; return ``(key, paid_used)``.

        Always advances the free-pool rotation via
        :func:`gemini_embed_pool.next_gemini_flash_key`. From the
        second-to-last attempt onward, the paid fallback key from
        :func:`gemini_embed_pool.get_paid_fallback_key` is substituted when
        one is configured. It is used at most once per request (tracked by
        *paid_used*).
        """
        key = next_gemini_flash_key()
        if attempt >= _MAX_KEY_ATTEMPTS - 2 and not paid_used:
            paid_key = get_paid_fallback_key()
            if paid_key:
                return paid_key, True
        return key, paid_used

    @staticmethod
    def _new_client(key: str) -> genai.Client:
        """Construct the short-lived per-attempt client for *key*."""
        return genai.Client(
            api_key=key,
            http_options=types.HttpOptions(timeout=_GEN_HTTP_TIMEOUT_MS),
        )

    @classmethod
    async def _record_api_error(cls, key: str, err: APIError) -> bool:
        """Book-keep an API error and report whether it is retryable.

        A 429 whose details show a daily (``PerDay``) quota violation marks
        *key* as spent via :func:`gemini_embed_pool.mark_key_daily_spent`.

        Returns:
            bool: True when ``err.code`` is in :attr:`_RETRYABLE_CODES`.
        """
        if err.code == 429 and _error_body_daily_quota(err.details):
            await mark_key_daily_spent(key, "generate")
        return err.code in cls._RETRYABLE_CODES

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def count_input_tokens(
        self,
        messages: list[dict[str, Any]],
        *,
        gemini_model: str | None = None,
    ) -> int | None:
        """Count Gemini input tokens for OpenAI-style *messages* via countTokens.

        Converts *messages* with :func:`openai_messages_to_gemini`, then folds
        the system text into the opening turn with
        :func:`_contents_for_count_tokens`, because the developer API rejects
        ``system_instruction`` on ``countTokens``.

        Keys rotate per attempt (see :meth:`_select_key`) with backoff from
        :meth:`_retry_delay`. Retryable API codes (429/500/502/503/504) move
        on to the next key. Other ``APIError`` codes and unexpected exceptions
        abort with ``None``. A successful response whose ``total_tokens`` is
        not an integer also returns ``None`` immediately, since retrying
        cannot fix it.

        Called by :meth:`chat` (when its ``token_count`` is ``None``) and by
        :mod:`kg_bulk_runner` as a standalone token counter.

        Args:
            messages: OpenAI-style chat messages to measure.
            gemini_model: Optional model id override; defaults to
                ``self.model_id``.

        Returns:
            int | None: The total input-token count, or ``None`` when there
            are no non-system contents, the total is not an integer, a
            non-retryable error occurs, or every key attempt is exhausted.
        """
        model = gemini_model or self.model_id
        sys_inst, contents = openai_messages_to_gemini(messages)
        if not contents:
            return None
        count_contents = _contents_for_count_tokens(sys_inst, contents)
        paid_used = False
        last_exc: BaseException | None = None
        for attempt in range(_MAX_KEY_ATTEMPTS):
            key, paid_used = self._select_key(attempt, paid_used)
            if attempt > 0:
                await asyncio.sleep(self._retry_delay(attempt))
            client = self._new_client(key)
            try:
                resp = await client.aio.models.count_tokens(
                    model=model,
                    contents=count_contents,
                )
            except APIError as e:
                last_exc = e
                if await self._record_api_error(key, e):
                    continue  # transient: back off and try the next key
                logger.debug(
                    "count_tokens APIError %s: %s",
                    e.code,
                    e.message,
                )
                return None
            except Exception as e:
                last_exc = e
                logger.debug("count_tokens failed", exc_info=True)
                return None
            finally:
                await _safe_aclose(client)
            n = resp.total_tokens
            if isinstance(n, int):
                logger.info("Gemini countTokens: %d tokens", n)
                return n
            # A successful call without an integer total will not improve on
            # retry; give up instead of burning every remaining key attempt.
            logger.debug("count_tokens returned non-integer total: %r", n)
            return None
        if last_exc:
            logger.debug("count_tokens exhausted retries: %s", last_exc)
        return None

    async def chat(
        self,
        messages: list[dict[str, Any]],
        user_id: str = "",
        ctx: ToolContext | None = None,
        tool_names: list[str] | None = None,
        validate_header: bool = False,
        token_count: int | None = None,
        on_intermediate_text: Callable[[str], Awaitable[None]] | None = None,
    ) -> str:
        """Run one bulk agentic generation with Gemini automatic function calling.

        Steps:

        1. Shallow-copy *messages* and resolve the input-token count, calling
           :meth:`count_input_tokens` when *token_count* is ``None``.
        2. Substitute the count into the first system message containing the
           ``__INPUT_TOKEN_COUNT__`` placeholder (``"unavailable"`` when the
           count is unknown).
        3. Convert to Gemini contents.
        4. Issue a single ``generate_content`` call with AFC enabled, limited
           to ``max_tool_rounds`` remote calls, so Gemini drives the KG read
           tools from :func:`_afc_tool_fns` itself.

        Keys rotate per attempt (see :meth:`_select_key`) with backoff from
        :meth:`_retry_delay`, each attempt on its own client closed through
        :func:`_safe_aclose`. Retryable API codes (429/500/502/503/504) move
        on to the next key; other ``APIError`` codes re-raise immediately.
        Non-``APIError`` exceptions are not caught and propagate.

        *validate_header* and *on_intermediate_text* are accepted for
        interface parity with the OpenRouter client but unused here.

        Called by :func:`kg_agentic_extraction.run_agentic_kg_extraction_chunk`
        for each conversation chunk during bulk KG extraction.

        Args:
            messages: OpenAI-style chat messages (system/user/assistant/tool).
            user_id: Identity threaded into AFC tool calls.
            ctx: Optional :class:`tool_context.ToolContext` for tool execution.
            tool_names: Optional allow-list restricting which KG tools are
                exposed to AFC.
            validate_header: Ignored on this bulk path (interface parity).
            token_count: Precomputed input-token count; counted on demand when
                ``None``.
            on_intermediate_text: Ignored on this bulk path (interface parity).

        Returns:
            str: The model's final response text (stripped), or ``""`` when
            there are no non-system contents to send.

        Raises:
            APIError: For non-retryable Gemini API error codes.
            RuntimeError: When all ``_MAX_KEY_ATTEMPTS`` attempts fail with
                retryable errors.
        """
        del validate_header, on_intermediate_text  # bulk path unused
        # Copy each message so the placeholder substitution below never
        # mutates the caller's dicts.
        msgs = [dict(m) for m in messages]
        if token_count is None:
            token_count = await self.count_input_tokens(msgs)
        count_str = str(token_count) if token_count is not None else "unavailable"
        for m in msgs:
            if m.get("role") == "system" and "__INPUT_TOKEN_COUNT__" in str(
                m.get("content", ""),
            ):
                m["content"] = str(m["content"]).replace(
                    "__INPUT_TOKEN_COUNT__",
                    count_str,
                )
                break  # only the first system message carrying the placeholder
        sys_inst, contents = openai_messages_to_gemini(msgs)
        if not contents:
            return ""
        tools = _afc_tool_fns(
            self.tool_registry,
            user_id=user_id,
            ctx=ctx,
            max_tool_output_chars=self.max_tool_output_chars,
            tool_names=tool_names,
        )
        afc_limit = max(1, int(self.max_tool_rounds))
        config = types.GenerateContentConfig(
            system_instruction=sys_inst,
            temperature=self.temperature,
            max_output_tokens=int(self.max_tokens),
            tools=tools,
            tool_config=types.ToolConfig(
                function_calling_config=types.FunctionCallingConfig(
                    mode=types.FunctionCallingConfigMode.AUTO,
                ),
            ),
            automatic_function_calling=types.AutomaticFunctionCallingConfig(
                maximum_remote_calls=afc_limit,
            ),
        )
        paid_used = False
        last_exc: BaseException | None = None
        for attempt in range(_MAX_KEY_ATTEMPTS):
            key, paid_used = self._select_key(attempt, paid_used)
            if attempt > 0:
                await asyncio.sleep(self._retry_delay(attempt))
            client = self._new_client(key)
            t0 = time.monotonic()
            try:
                logger.info(
                    "Gemini bulk AFC: attempt=%d model=%s contents=%d",
                    attempt + 1,
                    self.model_id,
                    len(contents),
                )
                resp = await client.aio.models.generate_content(
                    model=self.model_id,
                    contents=contents,
                    config=config,
                )
                logger.info(
                    "Gemini bulk AFC completed in %.0f ms",
                    (time.monotonic() - t0) * 1000,
                )
            except APIError as e:
                last_exc = e
                if not await self._record_api_error(key, e):
                    raise
                continue  # transient: back off and try the next key
            finally:
                await _safe_aclose(client)
            text = (getattr(resp, "text", None) or "").strip()
            if text.startswith("[PROXY ERROR]"):
                logger.warning("Unexpected proxy-style prefix in Gemini response")
            return text
        msg = f"Gemini bulk chat failed after {_MAX_KEY_ATTEMPTS} attempts"
        if last_exc:
            msg += f": {last_exc}"
        logger.error(msg)
        raise RuntimeError(msg) from last_exc