# Source code for kg_agentic_extraction
"""Agentic knowledge-graph extraction for bulk chat import.
Bulk agentic chat can use native Gemini (pool keys + :mod:`gemini_kg_bulk_client`)
or OpenRouter (:func:`create_kg_bulk_openrouter_client`). A small read-only tool
set is backed by :class:`~knowledge_graph.KnowledgeGraphManager`.
"""
from __future__ import annotations
import hashlib
import json
import logging
from pathlib import Path
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Protocol
from jinja2.sandbox import SandboxedEnvironment
from jinja2 import FileSystemLoader
from prompt_renderer import sanitize_context
from tool_context import ToolContext
from tools import ToolRegistry
from kg_extraction import _parse_llm_json, apply_parsed_extraction
from openrouter_client import OpenRouterClient
if TYPE_CHECKING:
from config import Config
from gemini_kg_bulk_client import GeminiPoolToolChatClient
from knowledge_graph import KnowledgeGraphManager
logger = logging.getLogger(__name__)
class KgBulkLlmClient(Protocol):
    """Structural type for bulk chunking + agentic KG (OpenRouter or native Gemini).

    Any client exposing these two coroutines satisfies the protocol
    structurally — no inheritance required.  (Fix: stray ``[docs]`` Sphinx
    anchor lines inside the class body were expressions that would raise
    ``NameError`` when the class was created; they have been removed.)
    """

    async def chat(
        self,
        messages: list[dict[str, Any]],
        user_id: str = "",
        ctx: ToolContext | None = None,
        tool_names: list[str] | None = None,
        validate_header: bool = False,
        token_count: int | None = None,
        on_intermediate_text: Callable[[str], Awaitable[None]] | None = None,
    ) -> str:
        """Run one tool-capable chat exchange; return the final model text."""
        ...

    async def count_input_tokens(
        self,
        messages: list[dict[str, Any]],
        *,
        gemini_model: str | None = None,
    ) -> int | None:
        """Count prompt tokens for *messages*, or ``None`` when unsupported."""
        ...
# OpenRouter production endpoint used by create_kg_bulk_openrouter_client().
KG_BULK_OPENROUTER_BASE = "https://openrouter.ai/api/v1"
# Model id in OpenRouter notation; the "google/" prefix is stripped for the
# native Gemini path (see kg_bulk_native_model_id()).
KG_BULK_CHAT_MODEL = "google/gemini-3.1-flash-lite-preview"
# Read-only KG tools exposed to the bulk extraction agent; registered in
# build_kg_bulk_tool_registry() and passed as tool_names to the chat client.
KG_BULK_TOOL_NAMES = [
    "kg_search_entities",
    "kg_get_entity",
    "kg_inspect_entity",
]
# Rendered system prompts keyed by template mtime + config fingerprint
# (see _system_prompt_cache_key()); entries are never evicted.
_SYSTEM_PROMPT_CACHE: dict[str, str] = {}
def _print_dry_run_llm_output(
*,
channel_id: str,
chunk_index: int,
raw: str,
) -> None:
"""Print final model text to stdout when extraction is not persisted."""
body = (raw or "").strip()
if not body:
return
sep = "=" * 72
print(
f"\n{sep}\n"
f"KG agentic dry-run — LLM output "
f"(chunk {chunk_index}, scope {channel_id})\n"
f"{sep}\n"
f"{body}\n"
f"{sep}\n",
flush=True,
)
def _project_root() -> Path:
return Path(__file__).resolve().parent
def _is_secret_setting_key(key: str) -> bool:
lk = key.lower()
for frag in (
"token", "password", "secret", "credential",
"api_key", "client_secret",
):
if frag in lk:
return True
return False
def _platform_type_is_ncm(ptype: str) -> bool:
"""True if this platform entry is NCM / neurochemical subsystem (not chat)."""
return "ncm" in (ptype or "").strip().lower()
def _settings_key_is_ncm_related(key: str) -> bool:
return "ncm" in str(key).lower()
def build_platform_context_markdown(cfg: Any | None) -> str:
    """Human-readable platform summary for the system prompt (no secrets).

    Reads only ``platforms``, ``homeserver`` and ``user_id`` from *cfg* (via
    ``getattr``, so any duck-typed config works).  Secret-like setting keys
    (tokens, passwords, ...) and NCM-related keys are filtered out before
    rendering.  Always returns non-empty markdown: fallbacks explain the
    "no config", "no platforms", "NCM-only" and "all disabled" cases.
    """
    if cfg is None:
        # No configuration available at all — tell the model ids are opaque.
        return (
            "(No configuration passed — treat `platform` / `channel_id` / "
            "`user_id` in logs as opaque platform-specific identifiers.)"
        )
    platforms = getattr(cfg, "platforms", None) or []
    if not platforms:
        # Legacy single-Matrix deployments: surface the old top-level fields.
        lines_out = [
            "(No `platforms:` entries in config.yaml — if legacy Matrix "
            "fields exist, only the Matrix bot may be active.)",
        ]
        hs = getattr(cfg, "homeserver", "") or ""
        uid = getattr(cfg, "user_id", "") or ""
        if hs:
            lines_out.append(f"- Default Matrix homeserver URL: `{hs}`")
        if uid:
            lines_out.append(
                f"- Legacy Matrix bot user id (format reference): `{uid}`",
            )
        return "\n".join(lines_out)
    lines: list[str] = []
    idx = 0  # numbering counts only the entries actually rendered
    for p in platforms:
        if not getattr(p, "enabled", True):
            continue
        ptype = getattr(p, "type", "") or "unknown"
        if _platform_type_is_ncm(ptype):
            # NCM entries are not chat platforms; omit them entirely.
            continue
        idx += 1
        lines.append(f"### Platform #{idx}: `{ptype}`")
        # Settings minus secret-like and NCM-related keys.
        st = {
            k: v
            for k, v in (p.settings or {}).items()
            if not _is_secret_setting_key(str(k))
            and not _settings_key_is_ncm_related(str(k))
        }
        if ptype == "matrix":
            # Prefer per-platform settings, fall back to legacy top-level.
            hs = st.get("homeserver") or getattr(cfg, "homeserver", "")
            if hs:
                lines.append(f"- Matrix homeserver (public URL): `{hs}`")
            bot_uid = st.get("user_id") or getattr(cfg, "user_id", "")
            if bot_uid:
                lines.append(
                    "- Matrix user ids in logs look like "
                    "`@localpart:domain`."
                )
                lines.append(
                    f"- This deployment's bot Matrix id: `{bot_uid}`",
                )
        elif ptype in ("discord", "discord-self"):
            lines.append(
                "- Discord numeric ids in logs (`user_id`, `channel_id`, "
                "etc.) are **snowflakes** (large integers as strings)."
            )
            # Surface the well-known Discord id settings when present.
            for key in (
                "application_id", "guild_id", "primary_guild_id",
                "default_guild_id",
            ):
                if key in st and st[key]:
                    lines.append(f"- `{key}`: `{st[key]}`")
        # Generic dump of remaining settings; skip keys already rendered
        # above plus path-/credential-ish ones with no prompt value.
        skip_keys = {
            "homeserver", "user_id", "store_path",
            "credentials_file", "password",
        }
        for k, v in sorted(st.items()):
            if k in skip_keys:
                continue
            if isinstance(v, (dict, list)):
                # Structured values could be large; mention the key only.
                lines.append(
                    f"- `{k}`: _(structured value; omitted from prompt)_",
                )
            else:
                lines.append(f"- `{k}`: `{v}`")
    if not lines:
        # Nothing rendered — distinguish "NCM-only" from "all disabled".
        enabled = [x for x in platforms if getattr(x, "enabled", True)]
        if enabled and all(
            _platform_type_is_ncm(getattr(x, "type", "") or "")
            for x in enabled
        ):
            return (
                "(Only NCM / non-chat platform entries are configured — omitted "
                "from KG extraction context; treat ids in logs as opaque.)"
            )
        return "(All platforms disabled in configuration.)"
    return "\n".join(lines)
def _system_template_path() -> Path:
    """Absolute path of the agentic-extraction system prompt template."""
    return _project_root().joinpath("prompts", "kg_agentic_extraction_system.j2")
def render_kg_agentic_system_prompt(cfg: Any | None = None) -> str:
    """Render the Jinja2 system prompt including platform context.

    Falls back to a minimal hard-coded prompt (plus the platform context
    block) when the template file is missing.
    """
    template_path = _system_template_path()
    platform_md = build_platform_context_markdown(cfg)
    if not template_path.is_file():
        logger.warning("Missing %s — using fallback system prompt", template_path)
        fallback = (
            "You extract knowledge graphs from chat. Use tools to search "
            "the graph before creating entities. Final message: JSON only.\n\n"
        )
        return fallback + platform_md
    # Sandboxed environment: the template never needs arbitrary attribute
    # access, and the context is sanitized before rendering.
    env = SandboxedEnvironment(
        loader=FileSystemLoader(str(template_path.parent)),
        autoescape=False,
    )
    tpl = env.get_template(template_path.name)
    render_ctx = sanitize_context({"platform_context": platform_md})
    return tpl.render(**render_ctx)
def _system_prompt_cache_key(cfg: Any | None) -> str:
    """Cache key: template mtime + digest of platform context and channel hints."""
    template = _system_template_path()
    try:
        mtime = template.stat().st_mtime
    except OSError:
        # Missing/unreadable template keys the fallback prompt under 0.0.
        mtime = 0.0
    hints_blob = ""
    if cfg is not None and hasattr(cfg, "kg_extraction_channel_hints"):
        hints_blob = json.dumps(
            getattr(cfg, "kg_extraction_channel_hints", {}),
            sort_keys=True,
        )
    fingerprint_src = build_platform_context_markdown(cfg) + "\n" + hints_blob
    digest = hashlib.sha256(fingerprint_src.encode()).hexdigest()[:24]
    return f"{mtime:.6f}:{digest}"
def load_kg_agentic_system_prompt(config: Any | None = None) -> str:
    """Rendered system prompt (cached per template mtime + config fingerprint)."""
    cache_key = _system_prompt_cache_key(config)
    cached = _SYSTEM_PROMPT_CACHE.get(cache_key)
    if cached is None:
        # First request for this template/config combination: render once.
        cached = render_kg_agentic_system_prompt(config)
        _SYSTEM_PROMPT_CACHE[cache_key] = cached
    return cached
def format_chunk_channels_section(
channel_pairs: list[tuple[str, str]],
cfg: Any | None,
default_channel_scope: str,
channel_metadata: dict[str, dict[str, str]] | None = None,
) -> str:
"""Describe which rooms/sources appear in this chunk."""
meta = channel_metadata or {}
if not channel_pairs:
return (
"## Channels in this chunk\n"
f"- _(none resolved — use default_channel_scope_id)_: "
f"`{default_channel_scope}`"
)
lines = ["## Channels in this chunk"]
hints: dict[str, str] = {}
if cfg is not None and hasattr(cfg, "kg_extraction_channel_hints"):
hints = dict(getattr(cfg, "kg_extraction_channel_hints", {}) or {})
for plat, cid in channel_pairs:
key = f"{plat}:{cid}"
hint = hints.get(key, "")
line = f"- **{plat}** / channel id `{cid}`"
if hint:
line += f" — _{hint}_"
row = meta.get(key) or {}
cname = (row.get("name") or "").strip()
topic = (row.get("topic") or "").strip()
if cname:
line += f"\n - resolved name: **{cname}**"
if topic:
line += f"\n - topic: {topic}"
lines.append(line)
lines.append(
f"- **default_channel_scope_id** (chunk default for scoped facts): "
f"`{default_channel_scope}`",
)
return "\n".join(lines)
def format_chunk_speakers_section(
    speaker_pairs: list[tuple[str, str]],
) -> str:
    """Unique speakers (user_id, display name) in this chunk, sorted by id."""
    empty_section = "## Speakers in this chunk\n- _(none)_"
    if not speaker_pairs:
        return empty_section
    # First non-empty id wins; blank display names render as "?".
    dedup: dict[str, str] = {}
    for raw_id, raw_name in speaker_pairs:
        sid = (raw_id or "").strip()
        if sid and sid not in dedup:
            dedup[sid] = (raw_name or "").strip() or "?"
    if not dedup:
        return empty_section
    bullets = [f"- `{sid}` — **{dedup[sid]}**" for sid in sorted(dedup)]
    return "\n".join(["## Speakers in this chunk", *bullets])
def format_speaker_user_id_mapping_markdown(
    speaker_pairs: list[tuple[str, str]],
) -> str:
    """Markdown table: user_id → display name for the current chunk (system prompt)."""
    table: dict[str, str] = {}
    for raw_id, raw_name in speaker_pairs or []:
        sid = (raw_id or "").strip()
        if not sid:
            continue
        # setdefault keeps the first display name seen for each id.
        table.setdefault(sid, (raw_name or "").strip() or "?")
    if not table:
        return ""
    out = [
        "## User ID → display name (this chunk)",
        "",
        "Use this table when interpreting sender ids in the log lines below:",
        "",
        "| `user_id` | display name |",
        "| --- | --- |",
    ]
    for sid in sorted(table):
        # Escape pipes so names cannot break the markdown table.
        safe_name = table[sid].replace("|", "\\|")
        out.append(f"| `{sid}` | {safe_name} |")
    return "\n".join(out)
def augment_system_prompt_with_speaker_mapping(
    system: str,
    speaker_pairs: list[tuple[str, str]] | None,
) -> str:
    """Append the user-id → name mapping table to *system* (no-op when empty)."""
    mapping_md = format_speaker_user_id_mapping_markdown(list(speaker_pairs or []))
    if mapping_md.strip():
        return f"{system.rstrip()}\n\n{mapping_md}\n"
    return system
async def prefetch_speaker_kg_context(
    kg: Any,
    speakers: list[tuple[str, str]],
    *,
    max_speakers: int = 8,
    hits_per_speaker: int = 3,
    min_score: float = 0.0,
) -> str:
    """Vector search prefetch for chunk speakers; full entity text for user prompt.

    For up to *max_speakers* unique speakers, runs two best-effort searches
    (user-category scoped to the speaker id, then a general query combining
    name and id) and renders the deduplicated hits as a markdown block.

    Args:
        kg: Knowledge-graph manager exposing ``search_entities``
            (project type; duck-typed here — hits are dict rows).
        speakers: ``(user_id, display_name)`` pairs for the chunk.
        max_speakers: Cap on unique speakers queried (clamped to 1..128).
        hits_per_speaker: ``top_k`` for each search (clamped to 1..48).
        min_score: Hits with a lower score are dropped.

    Returns:
        Markdown section, or ``""`` when there are no speakers or no hits.
    """
    if not speakers:
        return ""
    seen_uuid: set[str] = set()  # dedupe entities across speakers & searches
    blocks: list[str] = []
    cap = max(1, min(128, int(max_speakers)))
    hits = max(1, min(48, int(hits_per_speaker)))
    # First occurrence of each non-empty id wins; blank names become "?".
    uniq_speakers: dict[str, str] = {}
    for uid, name in speakers:
        u = (uid or "").strip()
        if not u:
            continue
        uniq_speakers.setdefault(u, (name or "").strip() or "?")
    for uid in sorted(uniq_speakers.keys())[:cap]:
        name = uniq_speakers[uid]
        hits_list: list[dict[str, Any]] = []
        # Both searches are best-effort: a failure is logged at debug level
        # and simply yields fewer candidates.
        try:
            by_scope = await kg.search_entities(
                name,
                category="user",
                scope_id=uid,
                top_k=hits,
            )
            hits_list.extend(by_scope or [])
        except Exception:
            logger.debug("prefetch user scope search failed", exc_info=True)
        try:
            general = await kg.search_entities(
                f"{name} {uid}",
                top_k=hits,
            )
            hits_list.extend(general or [])
        except Exception:
            logger.debug("prefetch general search failed", exc_info=True)
        rows: list[str] = []
        for ent in hits_list:
            sc = float(ent.get("score") or 0.0)
            if sc < min_score:
                continue
            uu = str(ent.get("uuid") or "")
            if not uu or uu in seen_uuid:
                continue
            seen_uuid.add(uu)
            nm = str(ent.get("name") or "")
            et = str(ent.get("type") or "")
            desc = str(ent.get("description") or "")
            cat = str(ent.get("category") or "")
            sid = str(ent.get("scope_id") or "")
            rows.append(
                f" - `{nm}` ({et}, cat={cat}, scope={sid}, "
                f"score={sc:.3f}, uuid={uu})"
                + (f": {desc}" if desc else ""),
            )
        if rows:
            blocks.append(f"**Speaker `{uid}` ({name})**:\n" + "\n".join(rows))
    if not blocks:
        return ""
    # Header warns the model these are heuristic matches, not ground truth.
    header = (
        "## Existing knowledge graph (speakers — prefetch)\n"
        "_Heuristic vector matches only; may be incomplete or noisy — "
        "use kg_search_entities / kg_get_entity to verify._\n\n"
    )
    body = "\n\n".join(blocks)
    return header + body
def build_kg_bulk_user_message(
    conversation_text: str,
    *,
    channel_id: str,
    chunk_index: int,
    time_start_iso: str = "",
    time_end_iso: str = "",
    platforms_channels: str = "",
    config: Any | None = None,
    chunk_channel_pairs: list[tuple[str, str]] | None = None,
    chunk_speaker_pairs: list[tuple[str, str]] | None = None,
    speaker_kg_prefetch: str = "",
    channel_metadata: dict[str, dict[str, str]] | None = None,
) -> str:
    """User message: metadata, channel context, speakers, optional prefetch."""
    unique_pairs = sorted(set(chunk_channel_pairs or []))
    channels_md = format_chunk_channels_section(
        unique_pairs, config, channel_id,
        channel_metadata=channel_metadata,
    )
    speakers_md = format_chunk_speakers_section(
        list(chunk_speaker_pairs or []),
    )
    parts: list[str] = [
        "## Chunk metadata",
        f"- chunk_index: {chunk_index}",
        f"- default_channel_scope_id: {channel_id}",
        f"- time_range_utc: {time_start_iso} → {time_end_iso}",
    ]
    if platforms_channels:
        parts.append(
            f"- all platforms/channels in full corpus: {platforms_channels}",
        )
    parts += ["", channels_md, "", speakers_md]
    prefetch_md = (speaker_kg_prefetch or "").strip()
    if prefetch_md:
        parts += ["", prefetch_md]
    parts += [
        "",
        "## Conversation (chronological)",
        conversation_text,
        "",
        "Now produce the final JSON extraction per system instructions.",
    ]
    return "\n".join(parts)
def messages_for_agentic_token_estimate(
    conversation_text: str,
    *,
    channel_id: str,
    chunk_index: int = 0,
    time_start_iso: str = "",
    time_end_iso: str = "",
    platforms_channels: str = "",
    config: Any | None = None,
    chunk_channel_pairs: list[tuple[str, str]] | None = None,
    chunk_speaker_pairs: list[tuple[str, str]] | None = None,
    speaker_kg_prefetch: str = "",
    channel_metadata: dict[str, dict[str, str]] | None = None,
) -> list[dict[str, str]]:
    """OpenAI-shaped messages for Gemini countTokens (same shape as a run)."""
    # Build the same system/user pair a real extraction run would send.
    rendered = {
        "system": augment_system_prompt_with_speaker_mapping(
            load_kg_agentic_system_prompt(config),
            chunk_speaker_pairs,
        ),
        "user": build_kg_bulk_user_message(
            conversation_text,
            channel_id=channel_id,
            chunk_index=chunk_index,
            time_start_iso=time_start_iso,
            time_end_iso=time_end_iso,
            platforms_channels=platforms_channels,
            config=config,
            chunk_channel_pairs=chunk_channel_pairs,
            chunk_speaker_pairs=chunk_speaker_pairs,
            speaker_kg_prefetch=speaker_kg_prefetch,
            channel_metadata=channel_metadata,
        ),
    }
    # dict preserves insertion order: system first, then user.
    return [{"role": role, "content": text} for role, text in rendered.items()]
def build_kg_bulk_tool_registry() -> ToolRegistry:
    """Read-only KG tools for the bulk extraction agent.

    Registers three query tools on a fresh :class:`ToolRegistry`:

    * ``kg_search_entities`` — semantic search (``top_k`` clamped to 1..24)
    * ``kg_get_entity`` — single entity by name and/or uuid
    * ``kg_inspect_entity`` — entity plus edges (``max_depth`` clamped to 1..2)

    Each tool returns a JSON string with a ``success`` flag and reports all
    failures in-band (never raises), and requires ``ctx.kg_manager`` from
    the per-call :class:`ToolContext`.
    """
    reg = ToolRegistry(task_manager=None)

    @reg.tool(
        name="kg_search_entities",
        description=(
            "Semantic search over knowledge-graph entities. "
            "Use short queries (names, projects, topics from the chat)."
        ),
        parameters={
            "type": "object",
            "properties": {
                "query": {"type": "string"},
                "category": {
                    "type": "string",
                    "description": "Optional: user, channel, general, basic",
                },
                "scope_id": {"type": "string"},
                "top_k": {"type": "integer", "description": "Max hits (default 12)"},
            },
            "required": ["query"],
        },
    )
    async def kg_search_entities(
        query: str,
        category: str = "",
        scope_id: str = "",
        top_k: int = 12,
        ctx: ToolContext | None = None,
    ) -> str:
        if ctx is None or ctx.kg_manager is None:
            return json.dumps({"success": False, "error": "kg unavailable"})
        kg = ctx.kg_manager
        # Clamp the model-chosen top_k to a sane range.
        tk = max(1, min(24, int(top_k)))
        try:
            results = await kg.search_entities(
                query,
                category=category or None,
                scope_id=scope_id or None,
                top_k=tk,
            )
            # default=str keeps non-JSON-native values serializable.
            return json.dumps(
                {"success": True, "count": len(results), "results": results},
                default=str,
            )
        except Exception as e:
            return json.dumps({"success": False, "error": str(e)})

    @reg.tool(
        name="kg_get_entity",
        description=(
            "Look up one entity by name and/or uuid; includes immediate "
            "relationship summaries."
        ),
        parameters={
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "uuid": {"type": "string"},
            },
            "required": [],
        },
    )
    async def kg_get_entity(
        name: str = "",
        uuid: str = "",
        ctx: ToolContext | None = None,
    ) -> str:
        if ctx is None or ctx.kg_manager is None:
            return json.dumps({"success": False, "error": "kg unavailable"})
        # At least one selector is required.
        if not (name or "").strip() and not (uuid or "").strip():
            return json.dumps({
                "success": False,
                "error": "Provide name or uuid",
            })
        kg = ctx.kg_manager
        try:
            ent = await kg.get_entity(
                name=(name or "").strip() or None,
                uuid=(uuid or "").strip() or None,
            )
            if ent:
                return json.dumps({"success": True, "entity": ent}, default=str)
            return json.dumps({"success": False, "error": "Entity not found"})
        except Exception as e:
            return json.dumps({"success": False, "error": str(e)})

    @reg.tool(
        name="kg_inspect_entity",
        description=(
            "Deep inspection: entity plus inbound/outbound edges, "
            "optional 2-hop neighborhood. max_depth 1 or 2."
        ),
        parameters={
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "uuid": {"type": "string"},
                "max_depth": {"type": "integer"},
            },
            "required": [],
        },
    )
    async def kg_inspect_entity(
        name: str = "",
        uuid: str = "",
        max_depth: int = 2,
        ctx: ToolContext | None = None,
    ) -> str:
        if ctx is None or ctx.kg_manager is None:
            return json.dumps({"success": False, "error": "kg unavailable"})
        if not (name or "").strip() and not (uuid or "").strip():
            return json.dumps({
                "success": False,
                "error": "Provide name or uuid",
            })
        kg = ctx.kg_manager
        # Depth limited to 1..2 regardless of the model's request.
        depth = max(1, min(2, int(max_depth)))
        try:
            result = await kg.inspect_entity(
                name=(name or "").strip() or None,
                uuid=(uuid or "").strip() or None,
                max_depth=depth,
                neighbor_limit=30,
            )
            if result:
                return json.dumps({"success": True, **result}, default=str)
            return json.dumps({"success": False, "error": "Entity not found"})
        except Exception as e:
            return json.dumps({"success": False, "error": str(e)})

    return reg
def kg_bulk_native_model_id() -> str:
    """Gemini API model id (strip OpenRouter ``google/`` prefix)."""
    model = KG_BULK_CHAT_MODEL
    prefix = "google/"
    if model.startswith(prefix):
        model = model[len(prefix):]
    return model
def create_kg_bulk_gemini_pool_client(
    *,
    max_tool_rounds: int = 48,
    max_tokens: int = 60_000,
    max_tool_output_chars: int = 3_000_000,
    temperature: float = 0.25,
) -> GeminiPoolToolChatClient:
    """Native Gemini via embed pool keys + AFC (:mod:`gemini_kg_bulk_client`)."""
    # Runtime import: the module-level import is TYPE_CHECKING-only.
    from gemini_kg_bulk_client import GeminiPoolToolChatClient

    registry = build_kg_bulk_tool_registry()
    model = kg_bulk_native_model_id()
    return GeminiPoolToolChatClient(
        tool_registry=registry,
        model_id=model,
        max_tool_rounds=max_tool_rounds,
        max_tokens=max_tokens,
        max_tool_output_chars=max_tool_output_chars,
        temperature=temperature,
    )
def create_kg_bulk_openrouter_client(
    api_key: str,
    *,
    gemini_api_key: str = "",
    max_tool_rounds: int = 48,
    max_tokens: int = 60_000,
    max_tool_output_chars: int = 3_000_000,
    temperature: float = 0.25,
) -> OpenRouterClient:
    """OpenRouter client: production endpoint + gemini-3.1-flash-lite-preview."""
    registry = build_kg_bulk_tool_registry()
    client = OpenRouterClient(
        api_key=api_key,
        model=KG_BULK_CHAT_MODEL,
        base_url=KG_BULK_OPENROUTER_BASE,
        temperature=temperature,
        max_tokens=max_tokens,
        max_tool_rounds=max_tool_rounds,
        max_tool_output_chars=max_tool_output_chars,
        tool_registry=registry,
        gemini_api_key=gemini_api_key,
    )
    return client
async def run_agentic_kg_extraction_chunk(
    *,
    conversation_text: str,
    channel_id: str,
    kg_manager: KnowledgeGraphManager,
    bulk_client: KgBulkLlmClient,
    user_id: str = "000000000000",
    chunk_index: int = 0,
    time_start_iso: str = "",
    time_end_iso: str = "",
    platforms_channels: str = "",
    config: Any | None = None,
    chunk_channel_pairs: list[tuple[str, str]] | None = None,
    chunk_speaker_pairs: list[tuple[str, str]] | None = None,
    speaker_kg_prefetch: str = "",
    channel_metadata: dict[str, dict[str, str]] | None = None,
    persist_extraction: bool = True,
) -> dict[str, Any]:
    """One agentic extraction pass over *conversation_text*.

    Builds the system/user messages, runs the tool-capable chat loop with
    the read-only KG tools, parses the final JSON, and (when persisting)
    applies the extraction to the graph.

    When *persist_extraction* is false, the model still runs (including
    read-only KG tools); parsed JSON is not applied to the graph.

    Returns:
        Stats dict.  On any failure (LLM call, empty output, JSON parse) a
        zeroed dict with ``errors=1`` / ``parse_error=True``; otherwise the
        stats from :func:`apply_parsed_extraction` (or a ``proposed_*``
        summary in dry-run mode).  Never raises.
    """
    def _failure_stats() -> dict[str, Any]:
        # Fresh dict per failure so callers can safely mutate the result.
        return {
            "entities_added": 0,
            "relationships_added": 0,
            "errors": 1,
            "parse_error": True,
        }

    system = augment_system_prompt_with_speaker_mapping(
        load_kg_agentic_system_prompt(config),
        chunk_speaker_pairs,
    )
    user = build_kg_bulk_user_message(
        conversation_text,
        channel_id=channel_id,
        chunk_index=chunk_index,
        time_start_iso=time_start_iso,
        time_end_iso=time_end_iso,
        platforms_channels=platforms_channels,
        config=config,
        chunk_channel_pairs=chunk_channel_pairs,
        chunk_speaker_pairs=chunk_speaker_pairs,
        speaker_kg_prefetch=speaker_kg_prefetch,
        channel_metadata=channel_metadata,
    )
    msgs: list[dict[str, Any]] = [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]
    ctx = ToolContext(
        kg_manager=kg_manager,
        user_id=user_id,
        channel_id=channel_id,
    )
    try:
        raw = await bulk_client.chat(
            msgs,
            user_id=user_id,
            ctx=ctx,
            tool_names=list(KG_BULK_TOOL_NAMES),
            validate_header=False,
        )
    except Exception:
        # Best-effort bulk job: log and report, never propagate.
        logger.warning("Agentic KG extraction LLM call failed", exc_info=True)
        return _failure_stats()
    if not raw or not raw.strip():
        return _failure_stats()
    if not persist_extraction:
        # Dry run: surface the raw model output for inspection.
        _print_dry_run_llm_output(
            channel_id=channel_id,
            chunk_index=chunk_index,
            raw=raw,
        )
    try:
        data = _parse_llm_json(raw)
    # Fix: was `except (json.JSONDecodeError, Exception)` — the tuple was
    # redundant since JSONDecodeError is itself an Exception subclass.
    except Exception:
        logger.warning(
            "Agentic KG extraction JSON parse failed; raw=%r",
            raw,
            exc_info=True,
        )
        return _failure_stats()
    if not persist_extraction:
        return {
            "entities_added": 0,
            "relationships_added": 0,
            "errors": 0,
            "parse_error": False,
            "persisted": False,
            "proposed_entities": len(data.get("entities") or []),
            "proposed_relationships": len(data.get("relationships") or []),
        }
    stats = await apply_parsed_extraction(
        data,
        kg_manager,
        channel_id,
        user_id=user_id,
        created_by="system:kg_agentic_bulk",
    )
    stats["parse_error"] = False
    stats["persisted"] = True
    return stats