# Source code for kg_agentic_extraction

"""Agentic knowledge-graph extraction for bulk chat import.

Bulk agentic chat can use native Gemini (pool keys + :mod:`gemini_kg_bulk_client`)
or OpenRouter (:func:`create_kg_bulk_openrouter_client`).  A small read-only tool
set is backed by :class:`~knowledge_graph.KnowledgeGraphManager`.
"""

from __future__ import annotations

import hashlib
import json
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Protocol

from jinja2.sandbox import SandboxedEnvironment
from jinja2 import FileSystemLoader
from prompt_renderer import sanitize_context
from tool_context import ToolContext
from tools import ToolRegistry

from kg_extraction import _parse_llm_json, apply_parsed_extraction
from openrouter_client import OpenRouterClient

if TYPE_CHECKING:
    from config import Config
    from gemini_kg_bulk_client import GeminiPoolToolChatClient
    from knowledge_graph import KnowledgeGraphManager

logger = logging.getLogger(__name__)


class KgBulkLlmClient(Protocol):
    """Structural type for bulk chunking + agentic KG (OpenRouter or native Gemini)."""

    async def chat(
        self,
        messages: list[dict[str, Any]],
        user_id: str = "",
        ctx: ToolContext | None = None,
        tool_names: list[str] | None = None,
        validate_header: bool = False,
        token_count: int | None = None,
        on_intermediate_text: Callable[[str], Awaitable[None]] | None = None,
    ) -> str:
        """Run one (possibly multi-round, tool-calling) chat; return final text.

        ``ctx`` supplies the tool-execution context and ``tool_names``
        restricts which registered tools the model may call.
        """
        ...

    async def count_input_tokens(
        self,
        messages: list[dict[str, Any]],
        *,
        gemini_model: str | None = None,
    ) -> int | None:
        """Token count for *messages*, or ``None`` when counting is unavailable."""
        ...

    async def close(self) -> None:
        """Release underlying client resources (HTTP sessions, etc.)."""
        ...
KG_BULK_OPENROUTER_BASE = "https://openrouter.ai/api/v1" KG_BULK_CHAT_MODEL = "google/gemini-3.1-flash-lite-preview" KG_BULK_TOOL_NAMES = [ "kg_search_entities", "kg_get_entity", "kg_inspect_entity", ] _SYSTEM_PROMPT_CACHE: dict[str, str] = {} def _print_dry_run_llm_output( *, channel_id: str, chunk_index: int, raw: str, ) -> None: """Print final model text to stdout when extraction is not persisted.""" body = (raw or "").strip() if not body: return sep = "=" * 72 print( f"\n{sep}\n" f"KG agentic dry-run — LLM output " f"(chunk {chunk_index}, scope {channel_id})\n" f"{sep}\n" f"{body}\n" f"{sep}\n", flush=True, ) def _project_root() -> Path: return Path(__file__).resolve().parent def _is_secret_setting_key(key: str) -> bool: lk = key.lower() for frag in ( "token", "password", "secret", "credential", "api_key", "client_secret", ): if frag in lk: return True return False def _platform_type_is_ncm(ptype: str) -> bool: """True if this platform entry is NCM / neurochemical subsystem (not chat).""" return "ncm" in (ptype or "").strip().lower() def _settings_key_is_ncm_related(key: str) -> bool: return "ncm" in str(key).lower()
def build_platform_context_markdown(cfg: Any | None) -> str:
    """Human-readable platform summary for the system prompt (no secrets).

    Walks ``cfg.platforms`` (if any), skipping disabled and NCM entries, and
    emits a markdown section per chat platform.  Secret-looking and
    NCM-related settings keys are filtered out before anything is rendered.
    """
    # No config at all: tell the model to treat identifiers as opaque.
    if cfg is None:
        return (
            "(No configuration passed — treat `platform` / `channel_id` / "
            "`user_id` in logs as opaque platform-specific identifiers.)"
        )
    platforms = getattr(cfg, "platforms", None) or []
    if not platforms:
        # Legacy single-Matrix config path: surface homeserver/user id only.
        lines_out = [
            "(No `platforms:` entries in config.yaml — if legacy Matrix "
            "fields exist, only the Matrix bot may be active.)",
        ]
        hs = getattr(cfg, "homeserver", "") or ""
        uid = getattr(cfg, "user_id", "") or ""
        if hs:
            lines_out.append(f"- Default Matrix homeserver URL: `{hs}`")
        if uid:
            lines_out.append(
                f"- Legacy Matrix bot user id (format reference): `{uid}`",
            )
        return "\n".join(lines_out)
    lines: list[str] = []
    idx = 0  # numbering counts only the platforms actually rendered
    for p in platforms:
        if not getattr(p, "enabled", True):
            continue
        ptype = getattr(p, "type", "") or "unknown"
        if _platform_type_is_ncm(ptype):
            continue
        idx += 1
        lines.append(f"### Platform #{idx}: `{ptype}`")
        # Settings with secret-like or NCM-related keys are dropped here.
        st = {
            k: v
            for k, v in (p.settings or {}).items()
            if not _is_secret_setting_key(str(k))
            and not _settings_key_is_ncm_related(str(k))
        }
        if ptype == "matrix":
            # Per-platform setting wins; fall back to the legacy top-level field.
            hs = st.get("homeserver") or getattr(cfg, "homeserver", "")
            if hs:
                lines.append(f"- Matrix homeserver (public URL): `{hs}`")
            bot_uid = st.get("user_id") or getattr(cfg, "user_id", "")
            if bot_uid:
                lines.append(
                    "- Matrix user ids in logs look like "
                    "`@localpart:domain`."
                )
                lines.append(
                    f"- This deployment's bot Matrix id: `{bot_uid}`",
                )
        elif ptype in ("discord", "discord-self"):
            lines.append(
                "- Discord numeric ids in logs (`user_id`, `channel_id`, "
                "etc.) are **snowflakes** (large integers as strings)."
            )
            # Well-known Discord ids worth naming explicitly in the prompt.
            for key in (
                "application_id",
                "guild_id",
                "primary_guild_id",
                "default_guild_id",
            ):
                if key in st and st[key]:
                    lines.append(f"- `{key}`: `{st[key]}`")
        # Keys already rendered above (or sensitive/path-like) are not
        # repeated in the generic settings dump below.
        skip_keys = {
            "homeserver",
            "user_id",
            "store_path",
            "credentials_file",
            "password",
        }
        for k, v in sorted(st.items()):
            if k in skip_keys:
                continue
            if isinstance(v, (dict, list)):
                # Structured values are summarized, not serialized into the prompt.
                lines.append(
                    f"- `{k}`: _(structured value; omitted from prompt)_",
                )
            else:
                lines.append(f"- `{k}`: `{v}`")
    if not lines:
        # Nothing rendered: distinguish "only NCM platforms" from "all disabled".
        enabled = [x for x in platforms if getattr(x, "enabled", True)]
        if enabled and all(
            _platform_type_is_ncm(getattr(x, "type", "") or "")
            for x in enabled
        ):
            return (
                "(Only NCM / non-chat platform entries are configured — omitted "
                "from KG extraction context; treat ids in logs as opaque.)"
            )
        return "(All platforms disabled in configuration.)"
    return "\n".join(lines)
def _system_template_path() -> Path:
    """Filesystem path of the Jinja2 system-prompt template for agentic extraction."""
    return _project_root().joinpath("prompts", "kg_agentic_extraction_system.j2")
def render_kg_agentic_system_prompt(cfg: Any | None = None) -> str:
    """Render the Jinja2 system prompt including platform context."""
    tmpl_path = _system_template_path()
    if not tmpl_path.is_file():
        # Template missing on disk — fall back to a minimal hard-coded prompt
        # so bulk extraction can still run.
        logger.warning("Missing %s — using fallback system prompt", tmpl_path)
        fallback = (
            "You extract knowledge graphs from chat. Use tools to search "
            "the graph before creating entities. Final message: JSON only.\n\n"
        )
        return fallback + build_platform_context_markdown(cfg)
    # Sandboxed environment: the template cannot reach arbitrary attributes.
    sandbox = SandboxedEnvironment(
        loader=FileSystemLoader(str(tmpl_path.parent)),
        autoescape=False,
    )
    tmpl = sandbox.get_template(tmpl_path.name)
    render_ctx = sanitize_context(
        {"platform_context": build_platform_context_markdown(cfg)},
    )
    return tmpl.render(**render_ctx)
def _system_prompt_cache_key(cfg: Any | None) -> str:
    """Cache key: template mtime + short sha256 of platform context and hints."""
    tmpl = _system_template_path()
    try:
        mtime = tmpl.stat().st_mtime
    except OSError:
        # Missing/unreadable template still yields a stable key.
        mtime = 0.0
    platform_ctx = build_platform_context_markdown(cfg)
    hints_json = ""
    if cfg is not None and hasattr(cfg, "kg_extraction_channel_hints"):
        hints_json = json.dumps(
            getattr(cfg, "kg_extraction_channel_hints", {}),
            sort_keys=True,
        )
    digest = hashlib.sha256(
        f"{platform_ctx}\n{hints_json}".encode()
    ).hexdigest()[:24]
    return f"{mtime:.6f}:{digest}"
def load_kg_agentic_system_prompt(config: Any | None = None) -> str:
    """Rendered system prompt (cached per template mtime + config fingerprint)."""
    key = _system_prompt_cache_key(config)
    cached = _SYSTEM_PROMPT_CACHE.get(key)
    if cached is None:
        # Render once per (template, config) fingerprint; values are never None.
        cached = render_kg_agentic_system_prompt(config)
        _SYSTEM_PROMPT_CACHE[key] = cached
    return cached
def format_chunk_channels_section(
    channel_pairs: list[tuple[str, str]],
    cfg: Any | None,
    default_channel_scope: str,
    channel_metadata: dict[str, dict[str, str]] | None = None,
) -> str:
    """Describe which rooms/sources appear in this chunk."""
    metadata = channel_metadata or {}
    if not channel_pairs:
        # Nothing resolved — point the model at the chunk-wide default scope.
        return (
            "## Channels in this chunk\n"
            f"- _(none resolved — use default_channel_scope_id)_: "
            f"`{default_channel_scope}`"
        )
    hint_map: dict[str, str] = {}
    if cfg is not None and hasattr(cfg, "kg_extraction_channel_hints"):
        hint_map = dict(getattr(cfg, "kg_extraction_channel_hints", {}) or {})
    out = ["## Channels in this chunk"]
    for platform, chan_id in channel_pairs:
        lookup = f"{platform}:{chan_id}"
        entry = f"- **{platform}** / channel id `{chan_id}`"
        hint = hint_map.get(lookup, "")
        if hint:
            entry += f" — _{hint}_"
        info = metadata.get(lookup) or {}
        resolved = (info.get("name") or "").strip()
        chan_topic = (info.get("topic") or "").strip()
        if resolved:
            entry += f"\n - resolved name: **{resolved}**"
        if chan_topic:
            entry += f"\n - topic: {chan_topic}"
        out.append(entry)
    out.append(
        f"- **default_channel_scope_id** (chunk default for scoped facts): "
        f"`{default_channel_scope}`",
    )
    return "\n".join(out)
def format_chunk_speakers_section(
    speaker_pairs: list[tuple[str, str]],
) -> str:
    """Unique speakers (user_id, display name) in this chunk, sorted by id."""
    placeholder = "## Speakers in this chunk\n- _(none)_"
    by_id: dict[str, str] = {}
    for raw_uid, raw_name in speaker_pairs or []:
        uid = (raw_uid or "").strip()
        if uid:
            # First display name seen for a uid wins; blanks become "?".
            by_id.setdefault(uid, (raw_name or "").strip() or "?")
    if not by_id:
        return placeholder
    rows = [f"- `{uid}` — **{name}**" for uid, name in sorted(by_id.items())]
    return "\n".join(["## Speakers in this chunk", *rows])
def format_speaker_user_id_mapping_markdown(
    speaker_pairs: list[tuple[str, str]],
) -> str:
    """Markdown table: user_id → display name for the current chunk (system prompt)."""
    names_by_id: dict[str, str] = {}
    for raw_uid, raw_name in speaker_pairs or []:
        uid = (raw_uid or "").strip()
        if not uid:
            continue
        names_by_id.setdefault(uid, (raw_name or "").strip() or "?")
    if not names_by_id:
        return ""
    table = [
        "## User ID → display name (this chunk)",
        "",
        "Use this table when interpreting sender ids in the log lines below:",
        "",
        "| `user_id` | display name |",
        "| --- | --- |",
    ]
    for uid, display in sorted(names_by_id.items()):
        # Escape pipes so the display name cannot break the markdown table.
        table.append(f"| `{uid}` | {display.replace('|', chr(92) + '|')} |")
    return "\n".join(table)
def augment_system_prompt_with_speaker_mapping(
    system: str,
    speaker_pairs: list[tuple[str, str]] | None,
) -> str:
    """Append the chunk's user-id → display-name table to *system* (no-op if empty)."""
    table = format_speaker_user_id_mapping_markdown(list(speaker_pairs or []))
    if not table.strip():
        return system
    return f"{system.rstrip()}\n\n{table}\n"
async def prefetch_speaker_kg_context(
    kg: Any,
    speakers: list[tuple[str, str]],
    *,
    max_speakers: int = 8,
    hits_per_speaker: int = 3,
    min_score: float = 0.0,
) -> str:
    """Vector search prefetch for chunk speakers; full entity text for user prompt."""
    if not speakers:
        return ""
    # Clamp caller-supplied limits to sane bounds.
    speaker_cap = max(1, min(128, int(max_speakers)))
    per_speaker = max(1, min(48, int(hits_per_speaker)))

    # Deduplicate speakers by stripped user id, first display name wins.
    dedup: dict[str, str] = {}
    for raw_uid, raw_name in speakers:
        uid = (raw_uid or "").strip()
        if uid:
            dedup.setdefault(uid, (raw_name or "").strip() or "?")

    used_uuids: set[str] = set()
    sections: list[str] = []
    for uid in sorted(dedup)[:speaker_cap]:
        display = dedup[uid]
        candidates: list[dict[str, Any]] = []
        # Two best-effort lookups: user-scoped, then a general search.
        try:
            scoped = await kg.search_entities(
                display,
                category="user",
                scope_id=uid,
                top_k=per_speaker,
            )
            candidates.extend(scoped or [])
        except Exception:
            logger.debug("prefetch user scope search failed", exc_info=True)
        try:
            broad = await kg.search_entities(
                f"{display} {uid}",
                top_k=per_speaker,
            )
            candidates.extend(broad or [])
        except Exception:
            logger.debug("prefetch general search failed", exc_info=True)

        bullet_rows: list[str] = []
        for ent in candidates:
            score = float(ent.get("score") or 0.0)
            if score < min_score:
                continue
            ent_uuid = str(ent.get("uuid") or "")
            # Skip unidentified hits and cross-speaker duplicates.
            if not ent_uuid or ent_uuid in used_uuids:
                continue
            used_uuids.add(ent_uuid)
            ent_name = str(ent.get("name") or "")
            ent_type = str(ent.get("type") or "")
            ent_desc = str(ent.get("description") or "")
            ent_cat = str(ent.get("category") or "")
            ent_scope = str(ent.get("scope_id") or "")
            bullet_rows.append(
                f" - `{ent_name}` ({ent_type}, cat={ent_cat}, scope={ent_scope}, "
                f"score={score:.3f}, uuid={ent_uuid})"
                + (f": {ent_desc}" if ent_desc else ""),
            )
        if bullet_rows:
            sections.append(
                f"**Speaker `{uid}` ({display})**:\n" + "\n".join(bullet_rows)
            )
    if not sections:
        return ""
    header = (
        "## Existing knowledge graph (speakers — prefetch)\n"
        "_Heuristic vector matches only; may be incomplete or noisy — "
        "use kg_search_entities / kg_get_entity to verify._\n\n"
    )
    return header + "\n\n".join(sections)
def build_kg_bulk_user_message(
    conversation_text: str,
    *,
    channel_id: str,
    chunk_index: int,
    time_start_iso: str = "",
    time_end_iso: str = "",
    platforms_channels: str = "",
    config: Any | None = None,
    chunk_channel_pairs: list[tuple[str, str]] | None = None,
    chunk_speaker_pairs: list[tuple[str, str]] | None = None,
    speaker_kg_prefetch: str = "",
    channel_metadata: dict[str, dict[str, str]] | None = None,
) -> str:
    """User message: metadata, channel context, speakers, optional prefetch.

    Assembles (in order) the chunk metadata block, the channel and speaker
    sections, the optional KG prefetch block, and finally the chronological
    conversation text with the closing instruction.
    """
    pairs = sorted(set(chunk_channel_pairs or []))
    chan_block = format_chunk_channels_section(
        pairs,
        config,
        channel_id,
        channel_metadata=channel_metadata,
    )
    sp_block = format_chunk_speakers_section(
        list(chunk_speaker_pairs or []),
    )
    lines = [
        "## Chunk metadata",
        f"- chunk_index: {chunk_index}",
        f"- default_channel_scope_id: {channel_id}",
        # BUG FIX: start/end timestamps were concatenated with no separator,
        # producing an unreadable "startend" value in the prompt.
        f"- time_range_utc: {time_start_iso} → {time_end_iso}",
    ]
    if platforms_channels:
        lines.append(
            f"- all platforms/channels in full corpus: {platforms_channels}",
        )
    lines.extend([
        "",
        chan_block,
        "",
        sp_block,
    ])
    if (speaker_kg_prefetch or "").strip():
        lines.extend(["", (speaker_kg_prefetch or "").strip()])
    lines.extend([
        "",
        "## Conversation (chronological)",
        conversation_text,
        "",
        "Now produce the final JSON extraction per system instructions.",
    ])
    return "\n".join(lines)
def messages_for_agentic_token_estimate(
    conversation_text: str,
    *,
    channel_id: str,
    chunk_index: int = 0,
    time_start_iso: str = "",
    time_end_iso: str = "",
    platforms_channels: str = "",
    config: Any | None = None,
    chunk_channel_pairs: list[tuple[str, str]] | None = None,
    chunk_speaker_pairs: list[tuple[str, str]] | None = None,
    speaker_kg_prefetch: str = "",
    channel_metadata: dict[str, dict[str, str]] | None = None,
) -> list[dict[str, str]]:
    """OpenAI-shaped messages for Gemini countTokens (same shape as a run)."""
    # Build the exact system + user payload a real extraction run would send.
    system_msg = augment_system_prompt_with_speaker_mapping(
        load_kg_agentic_system_prompt(config),
        chunk_speaker_pairs,
    )
    user_msg = build_kg_bulk_user_message(
        conversation_text,
        channel_id=channel_id,
        chunk_index=chunk_index,
        time_start_iso=time_start_iso,
        time_end_iso=time_end_iso,
        platforms_channels=platforms_channels,
        config=config,
        chunk_channel_pairs=chunk_channel_pairs,
        chunk_speaker_pairs=chunk_speaker_pairs,
        speaker_kg_prefetch=speaker_kg_prefetch,
        channel_metadata=channel_metadata,
    )
    messages = [
        {"role": "system", "content": system_msg},
        {"role": "user", "content": user_msg},
    ]
    return messages
def build_kg_bulk_tool_registry() -> ToolRegistry:
    """Read-only KG tools for the bulk extraction agent.

    Registers three tools — search, single-entity lookup, and deep inspect —
    each returning a JSON string with a ``success`` flag.  All graph access
    goes through ``ctx.kg_manager``; nothing here writes to the graph.
    """
    reg = ToolRegistry(task_manager=None)

    @reg.tool(
        name="kg_search_entities",
        description=(
            "Semantic search over knowledge-graph entities. "
            "Use short queries (names, projects, topics from the chat)."
        ),
        parameters={
            "type": "object",
            "properties": {
                "query": {"type": "string"},
                "category": {
                    "type": "string",
                    "description": "Optional: user, channel, general, basic",
                },
                "scope_id": {"type": "string"},
                "top_k": {"type": "integer", "description": "Max hits (default 12)"},
            },
            "required": ["query"],
        },
    )
    async def kg_search_entities(
        query: str,
        category: str = "",
        scope_id: str = "",
        top_k: int = 12,
        ctx: ToolContext | None = None,
    ) -> str:
        # Tools degrade gracefully: missing KG manager yields a JSON error,
        # never an exception back into the agent loop.
        if ctx is None or ctx.kg_manager is None:
            return json.dumps({"success": False, "error": "kg unavailable"})
        kg = ctx.kg_manager
        # Clamp the model-supplied top_k to [1, 24].
        tk = max(1, min(24, int(top_k)))
        try:
            results = await kg.search_entities(
                query,
                category=category or None,
                scope_id=scope_id or None,
                top_k=tk,
            )
            # default=str: entity rows may contain non-JSON types (e.g. datetimes).
            return json.dumps(
                {"success": True, "count": len(results), "results": results},
                default=str,
            )
        except Exception as e:
            return json.dumps({"success": False, "error": str(e)})

    @reg.tool(
        name="kg_get_entity",
        description=(
            "Look up one entity by name and/or uuid; includes immediate "
            "relationship summaries."
        ),
        parameters={
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "uuid": {"type": "string"},
            },
            "required": [],
        },
    )
    async def kg_get_entity(
        name: str = "",
        uuid: str = "",
        ctx: ToolContext | None = None,
    ) -> str:
        if ctx is None or ctx.kg_manager is None:
            return json.dumps({"success": False, "error": "kg unavailable"})
        # At least one identifier (name or uuid) is required.
        if not (name or "").strip() and not (uuid or "").strip():
            return json.dumps({
                "success": False,
                "error": "Provide name or uuid",
            })
        kg = ctx.kg_manager
        try:
            ent = await kg.get_entity(
                name=(name or "").strip() or None,
                uuid=(uuid or "").strip() or None,
            )
            if ent:
                return json.dumps({"success": True, "entity": ent}, default=str)
            return json.dumps({"success": False, "error": "Entity not found"})
        except Exception as e:
            return json.dumps({"success": False, "error": str(e)})

    @reg.tool(
        name="kg_inspect_entity",
        description=(
            "Deep inspection: entity plus inbound/outbound edges, "
            "optional 2-hop neighborhood. max_depth 1 or 2."
        ),
        parameters={
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "uuid": {"type": "string"},
                "max_depth": {"type": "integer"},
            },
            "required": [],
        },
    )
    async def kg_inspect_entity(
        name: str = "",
        uuid: str = "",
        max_depth: int = 2,
        ctx: ToolContext | None = None,
    ) -> str:
        if ctx is None or ctx.kg_manager is None:
            return json.dumps({"success": False, "error": "kg unavailable"})
        if not (name or "").strip() and not (uuid or "").strip():
            return json.dumps({
                "success": False,
                "error": "Provide name or uuid",
            })
        kg = ctx.kg_manager
        # Depth clamped to 1..2 regardless of what the model asks for.
        depth = max(1, min(2, int(max_depth)))
        try:
            result = await kg.inspect_entity(
                name=(name or "").strip() or None,
                uuid=(uuid or "").strip() or None,
                max_depth=depth,
                neighbor_limit=30,
            )
            if result:
                return json.dumps({"success": True, **result}, default=str)
            return json.dumps({"success": False, "error": "Entity not found"})
        except Exception as e:
            return json.dumps({"success": False, "error": str(e)})

    return reg
def kg_bulk_native_model_id() -> str:
    """Gemini API model id (strip OpenRouter ``google/`` prefix)."""
    prefix = "google/"
    model = KG_BULK_CHAT_MODEL
    if model.startswith(prefix):
        return model[len(prefix):]
    return model
def create_kg_bulk_gemini_pool_client(
    *,
    max_tool_rounds: int = 48,
    max_tokens: int = 60_000,
    max_tool_output_chars: int = 3_000_000,
    temperature: float = 0.25,
) -> GeminiPoolToolChatClient:
    """Native Gemini via embed pool keys + AFC (:mod:`gemini_kg_bulk_client`)."""
    # Imported here so the native-Gemini dependency stays optional.
    from gemini_kg_bulk_client import GeminiPoolToolChatClient

    registry = build_kg_bulk_tool_registry()
    model = kg_bulk_native_model_id()
    return GeminiPoolToolChatClient(
        tool_registry=registry,
        model_id=model,
        max_tool_rounds=max_tool_rounds,
        max_tokens=max_tokens,
        max_tool_output_chars=max_tool_output_chars,
        temperature=temperature,
    )
def create_kg_bulk_openrouter_client(
    api_key: str,
    *,
    gemini_api_key: str = "",
    max_tool_rounds: int = 48,
    max_tokens: int = 60_000,
    max_tool_output_chars: int = 3_000_000,
    temperature: float = 0.25,
) -> OpenRouterClient:
    """OpenRouter client: production endpoint + gemini-3.1-flash-lite-preview."""
    registry = build_kg_bulk_tool_registry()
    client = OpenRouterClient(
        api_key=api_key,
        model=KG_BULK_CHAT_MODEL,
        temperature=temperature,
        max_tokens=max_tokens,
        tool_registry=registry,
        max_tool_rounds=max_tool_rounds,
        base_url=KG_BULK_OPENROUTER_BASE,
        gemini_api_key=gemini_api_key,
        max_tool_output_chars=max_tool_output_chars,
    )
    return client
async def run_agentic_kg_extraction_chunk(
    *,
    conversation_text: str,
    channel_id: str,
    kg_manager: KnowledgeGraphManager,
    bulk_client: KgBulkLlmClient,
    user_id: str = "000000000000",
    chunk_index: int = 0,
    time_start_iso: str = "",
    time_end_iso: str = "",
    platforms_channels: str = "",
    config: Any | None = None,
    chunk_channel_pairs: list[tuple[str, str]] | None = None,
    chunk_speaker_pairs: list[tuple[str, str]] | None = None,
    speaker_kg_prefetch: str = "",
    channel_metadata: dict[str, dict[str, str]] | None = None,
    persist_extraction: bool = True,
) -> dict[str, Any]:
    """One agentic extraction pass over *conversation_text*.

    When *persist_extraction* is false, the model still runs (including
    read-only KG tools); parsed JSON is not applied to the graph.

    Returns a stats dict with ``entities_added`` / ``relationships_added`` /
    ``errors`` counters and a ``parse_error`` flag; successful runs also
    carry ``persisted`` (and, for dry runs, ``proposed_*`` counts).
    """

    def _failure() -> dict[str, Any]:
        # Fresh dict each call so callers may mutate the result safely.
        return {
            "entities_added": 0,
            "relationships_added": 0,
            "errors": 1,
            "parse_error": True,
        }

    system = augment_system_prompt_with_speaker_mapping(
        load_kg_agentic_system_prompt(config),
        chunk_speaker_pairs,
    )
    user = build_kg_bulk_user_message(
        conversation_text,
        channel_id=channel_id,
        chunk_index=chunk_index,
        time_start_iso=time_start_iso,
        time_end_iso=time_end_iso,
        platforms_channels=platforms_channels,
        config=config,
        chunk_channel_pairs=chunk_channel_pairs,
        chunk_speaker_pairs=chunk_speaker_pairs,
        speaker_kg_prefetch=speaker_kg_prefetch,
        channel_metadata=channel_metadata,
    )
    msgs: list[dict[str, Any]] = [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]
    ctx = ToolContext(
        kg_manager=kg_manager,
        user_id=user_id,
        channel_id=channel_id,
    )
    try:
        raw = await bulk_client.chat(
            msgs,
            user_id=user_id,
            ctx=ctx,
            tool_names=list(KG_BULK_TOOL_NAMES),
            validate_header=False,
        )
    except Exception:
        logger.warning("Agentic KG extraction LLM call failed", exc_info=True)
        return _failure()
    if not raw or not raw.strip():
        # Model returned no final text — treat as a parse failure.
        return _failure()
    if not persist_extraction:
        _print_dry_run_llm_output(
            channel_id=channel_id,
            chunk_index=chunk_index,
            raw=raw,
        )
    try:
        data = _parse_llm_json(raw)
    except Exception:
        # FIX: previously `except (json.JSONDecodeError, Exception)` — the
        # tuple was redundant because JSONDecodeError subclasses Exception.
        logger.warning(
            "Agentic KG extraction JSON parse failed; raw=%r",
            raw,
            exc_info=True,
        )
        return _failure()
    if not persist_extraction:
        # Dry run: report what would have been written without touching the graph.
        return {
            "entities_added": 0,
            "relationships_added": 0,
            "errors": 0,
            "parse_error": False,
            "persisted": False,
            "proposed_entities": len(data.get("entities") or []),
            "proposed_relationships": len(data.get("relationships") or []),
        }
    stats = await apply_parsed_extraction(
        data,
        kg_manager,
        channel_id,
        user_id=user_id,
        created_by="system:kg_agentic_bulk",
    )
    stats["parse_error"] = False
    stats["persisted"] = True
    return stats