Source code for build_kg

#!/usr/bin/env python3
"""Standalone script to build knowledge graph entries from channel messages.

Fetches the last N messages (default 1000) from a channel via Redis cache
first, falling back to the platform API.  Sends ALL messages plus the
entire existing knowledge graph to gemini-3-flash-preview in a single
call, then presents the proposed entities/relationships for human approval
before committing to FalkorDB.

Usage:
    python build_kg.py --platform discord --channel 123456789
    python build_kg.py --platform discord --channel 123456789 --guild 987
"""

from __future__ import annotations

import argparse
import asyncio
import json
import logging
import sys
import time
from datetime import datetime, timezone
from typing import Any

import redis.asyncio as aioredis

from config import Config
from kg_extraction import EXTRACTION_PROMPT, _TYPE_MAP, _parse_llm_json
from knowledge_graph import ENTITY_LABELS, KnowledgeGraphManager
from message_cache import CachedMessage, MessageCache
from openrouter_client import OpenRouterClient

logger = logging.getLogger(__name__)

_EXTRACTION_MODEL = "gemini-3-flash-preview"
_VALID_CATEGORIES = {"user", "channel", "general", "basic"}
_DEFAULT_MESSAGE_COUNT = 1000


# ---------------------------------------------------------------------------
# Message retrieval
# ---------------------------------------------------------------------------

async def fetch_messages_redis(
    cache: MessageCache,
    platform: str,
    channel_id: str,
    count: int,
) -> list[CachedMessage]:
    """Pull up to *count* messages from the Redis sorted-set cache.

    Returns messages in chronological (oldest-first) order.  Any cache
    failure is logged and reported as an empty list -- this fetch is
    best-effort and callers fall back to the platform API.
    """
    try:
        recent = await cache.get_recent(platform, channel_id, count=count)
    except Exception:
        logger.warning("Redis message fetch failed", exc_info=True)
        return []
    # get_recent yields newest-first; flip to chronological order.
    return list(reversed(recent))
async def fetch_messages_discord(
    token: str,
    channel_id: str,
    limit: int,
) -> list[dict[str, Any]]:
    """Fetch messages directly from the Discord API using discord.py.

    Spins up a temporary discord.py client, waits for it to become ready,
    reads the channel history, then tears the client down again.

    Args:
        token: Discord bot token used to authenticate.
        channel_id: Numeric channel ID (as a string) to read from.
        limit: Maximum number of messages to fetch.

    Returns:
        Chronologically-ordered (oldest-first) list of dicts with keys:
        user_id, user_name, text, timestamp (float), is_bot.
        Returns an empty list if discord.py is missing or the fetch fails.
    """
    # Import lazily so the script still works when discord.py is absent
    # and Redis alone can satisfy the request.
    try:
        import discord
    except ImportError:
        logger.error("discord.py is not installed -- cannot fall back to API")
        return []

    # message_content is a privileged intent; without it msg.content is empty.
    intents = discord.Intents.default()
    intents.message_content = True
    client = discord.Client(intents=intents)
    messages: list[dict[str, Any]] = []
    ready_event = asyncio.Event()

    @client.event
    async def on_ready():
        """Signal the waiter below once the gateway connection is ready."""
        ready_event.set()

    # client.start() runs the gateway loop forever, so it must live in its
    # own task while we wait for on_ready and do the actual fetch.
    token_task = asyncio.create_task(client.start(token))
    try:
        await asyncio.wait_for(ready_event.wait(), timeout=30)
        # Prefer the local cache; fall back to an API round-trip.
        channel = client.get_channel(int(channel_id))
        if channel is None:
            channel = await client.fetch_channel(int(channel_id))
        # Guard: some channel types (e.g. categories) have no history().
        if hasattr(channel, "history"):
            async for msg in channel.history(limit=limit):
                messages.append({
                    "user_id": str(msg.author.id),
                    "user_name": msg.author.display_name,
                    "text": msg.content,
                    "timestamp": msg.created_at.timestamp(),
                    "is_bot": msg.author.bot or False,
                })
    except asyncio.TimeoutError:
        logger.error("Discord client failed to connect within 30s")
    except Exception:
        logger.error("Discord API fetch failed", exc_info=True)
    finally:
        # Always shut the gateway task down, even on failure, so the
        # event loop can exit cleanly.
        await client.close()
        token_task.cancel()
        try:
            await token_task
        except (asyncio.CancelledError, Exception):
            pass
    messages.reverse()  # history() returns newest-first
    return messages
async def gather_messages(
    cache: MessageCache | None,
    platform: str,
    channel_id: str,
    count: int,
    cfg: Config,
) -> list[dict[str, Any]]:
    """Collect up to *count* messages, Redis-first with API fallback.

    Returns a chronologically-ordered list of message dicts with keys:
    user_id, user_name, text, timestamp.
    """
    collected: list[dict[str, Any]] = []

    # Primary source: the Redis message cache.
    if cache is not None:
        for cm in await fetch_messages_redis(cache, platform, channel_id, count):
            collected.append({
                "user_id": cm.user_id,
                "user_name": cm.user_name,
                "text": cm.text,
                "timestamp": cm.timestamp,
            })

    if len(collected) >= count:
        return collected[:count]

    remaining = count - len(collected)
    print(f" Redis returned {len(collected)} messages, "
          f"attempting API fetch for up to {remaining} more...")

    if platform == "discord":
        # Find the token of the first configured discord platform, if any.
        discord_token = next(
            (p.settings.get("token") for p in cfg.platforms if p.type == "discord"),
            None,
        )
        if not discord_token:
            print(" WARNING: No Discord token in config, cannot fall back to API.")
        else:
            api_msgs = await fetch_messages_discord(
                discord_token, channel_id, limit=remaining,
            )
            # De-duplicate against the cached batch by timestamp.
            known_ts = {m["timestamp"] for m in collected}
            collected.extend(
                {
                    "user_id": m["user_id"],
                    "user_name": m["user_name"],
                    "text": m["text"],
                    "timestamp": m["timestamp"],
                }
                for m in api_msgs
                if m["timestamp"] not in known_ts
            )
    else:
        print(f" WARNING: API fallback not implemented for platform '{platform}'.")

    collected.sort(key=lambda m: m["timestamp"])
    return collected[:count]
# --------------------------------------------------------------------------- # Graph dump for LLM context # ---------------------------------------------------------------------------
async def dump_full_graph(kg: KnowledgeGraphManager) -> str:
    """Serialize the entire knowledge graph into a text block for LLM context."""
    entities = await kg.list_entities(limit=10_000)
    relationships = await kg.list_relationships(limit=10_000)

    if not entities and not relationships:
        return "(The knowledge graph is currently empty.)\n"

    parts: list[str] = ["=== EXISTING KNOWLEDGE GRAPH ===\n"]

    if entities:
        parts.append("Entities:")
        for ent in entities:
            scope = ent.get("scope_id", "_")
            # "_" means global scope, which we do not display.
            suffix = "" if scope == "_" else f", scope={scope}"
            parts.append(
                f" - [{ent.get('type', '?')}] {ent.get('name', '?')} "
                f"(category={ent.get('category', '?')}{suffix}): "
                f"\"{ent.get('description', '')}\""
            )

    if relationships:
        parts.append("\nRelationships:")
        parts.extend(
            f" - {r.get('source', '?')} -[{r.get('relation', '?')}]-> "
            f"{r.get('target', '?')} "
            f"(weight={r.get('weight', '?')}): "
            f"\"{r.get('description', '')}\""
            for r in relationships
        )

    return "\n".join(parts) + "\n"
# --------------------------------------------------------------------------- # LLM extraction # ---------------------------------------------------------------------------
def build_extraction_prompt(
    conversation_text: str,
    graph_context: str,
) -> list[dict[str, str]]:
    """Build the messages list for the extraction LLM call.

    The existing graph is placed first so the model can reference entities
    that already exist instead of duplicating them.
    """
    system_prompt = (
        "You extract structured knowledge graphs from conversations. "
        "Output only valid JSON. Do NOT duplicate entities that already "
        "exist in the graph below -- instead reference them by their "
        "existing name when creating relationships."
    )
    user_prompt = (
        f"{graph_context}\n\n"
        f"{EXTRACTION_PROMPT}\n\n"
        f"Conversation:\n{conversation_text}\n\nJSON:"
    )
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
async def run_extraction(
    openrouter: OpenRouterClient,
    conversation_text: str,
    graph_context: str,
) -> dict[str, list[dict]]:
    """Call the LLM to extract entities and relationships.

    Args:
        openrouter: Client used for the chat completion call.
        conversation_text: Formatted conversation transcript.
        graph_context: Serialized existing graph (see dump_full_graph).

    Returns:
        {"entities": [...], "relationships": [...]}; both lists are empty
        when the call or JSON parsing fails (best-effort, never raises).
    """
    msgs = build_extraction_prompt(conversation_text, graph_context)
    try:
        raw = await openrouter.chat(msgs)
        data = _parse_llm_json(raw)
        return {
            "entities": data.get("entities", []),
            "relationships": data.get("relationships", []),
        }
    # json.JSONDecodeError is a ValueError and therefore already an
    # Exception subclass -- the original (JSONDecodeError, Exception)
    # tuple was redundant.
    except Exception as exc:
        logger.warning("LLM extraction failed: %s", exc)
        return {"entities": [], "relationships": []}
# --------------------------------------------------------------------------- # Human approval UI # ---------------------------------------------------------------------------
def format_entity(idx: int, ent: dict) -> str:
    """Render a proposed entity as a one-line display row.

    Args:
        idx: 0-based position; displayed 1-based as "e1", "e2", ...
        ent: Proposed entity dict from the LLM extraction.

    Returns:
        A single formatted line for the approval prompt.
    """
    category = ent.get("category", "general")
    user_id = ent.get("user_id", "")
    scope = f", user_id={user_id}" if user_id else ""
    label = ent.get("type", "?")
    name = ent.get("name", "?")
    desc = ent.get("description", "")
    return f" e{idx + 1}. [{label}] {name} (category={category}{scope}): \"{desc}\""
def format_relationship(idx: int, rel: dict) -> str:
    """Render a proposed relationship as a one-line display row.

    Args:
        idx: 0-based position; displayed 1-based as "r1", "r2", ...
        rel: Proposed relationship dict from the LLM extraction.

    Returns:
        A single formatted line for the approval prompt.
    """
    source = rel.get("source", "?")
    target = rel.get("target", "?")
    relation = rel.get("relation", "RELATED_TO")
    confidence = rel.get("confidence", "?")
    desc = rel.get("description", "")
    return (
        f" r{idx + 1}. {source} -[{relation}]-> {target} "
        f"(confidence={confidence}): \"{desc}\""
    )
def prompt_approval(
    entities: list[dict],
    relationships: list[dict],
    num_messages: int,
) -> tuple[list[int], list[int]] | None:
    """Display proposed entries and ask the operator which to approve.

    Returns:
        (entity_indices, relationship_indices), both 0-based, or None when
        the operator chooses to quit without committing anything.
    """
    banner = "=" * 60
    print(f"\n{banner}")
    print(f" Extraction results from {num_messages} messages")
    print(f"{banner}")

    if not entities and not relationships:
        print(" (no entities or relationships extracted)")
        return [], []

    if entities:
        print(f"\n Proposed Entities ({len(entities)}):")
        for i, ent in enumerate(entities):
            print(format_entity(i, ent))
    if relationships:
        print(f"\n Proposed Relationships ({len(relationships)}):")
        for i, rel in enumerate(relationships):
            print(format_relationship(i, rel))

    print()
    print(" Options:")
    print(" y / Enter = approve all")
    print(" n = reject all")
    print(" e1,e3,r2 = approve only selected (e=entity, r=relationship)")
    print(" q = quit without committing")
    print()
    choice = input(" Your choice: ").strip().lower()

    if choice == "q":
        return None
    if choice in ("n", "no"):
        return [], []
    if choice in ("y", "yes", ""):
        return list(range(len(entities))), list(range(len(relationships)))

    # Selective approval: comma-separated tokens like "e2" or "r1".
    ents_ok: list[int] = []
    rels_ok: list[int] = []
    for raw in choice.split(","):
        tok = raw.strip()
        if not tok:
            continue
        prefix, digits = tok[:1], tok[1:]
        if prefix == "e" and digits.isdigit():
            pos = int(digits) - 1  # display is 1-based
            if 0 <= pos < len(entities):
                ents_ok.append(pos)
        elif prefix == "r" and digits.isdigit():
            pos = int(digits) - 1
            if 0 <= pos < len(relationships):
                rels_ok.append(pos)
        else:
            print(f" WARNING: Unrecognised token '{tok}', ignoring.")
    return ents_ok, rels_ok
# --------------------------------------------------------------------------- # Commit approved entries # ---------------------------------------------------------------------------
async def commit_entities(
    kg: KnowledgeGraphManager,
    entities: list[dict],
    approved_indices: list[int],
    channel_id: str,
    entity_uuid_lookup: dict[str, str],
) -> int:
    """Resolve-or-create approved entities. Returns count committed.

    Args:
        kg: Knowledge-graph manager used for embedding and upserts.
        entities: Full list of proposed entity dicts.
        approved_indices: 0-based indices into *entities* to commit.
        channel_id: Used as scope for category == "channel" entities.
        entity_uuid_lookup: Out-param; populated with name(lower) -> uuid
            for every committed entity, for relationship resolution.

    Returns:
        Number of entities successfully committed.
    """
    # First pass: validate and normalise everything we intend to commit,
    # keeping `prepared` and `embed_texts` index-aligned for the zip below.
    prepared: list[tuple[str, str, str, str, str]] = []
    embed_texts: list[str] = []
    for idx in approved_indices:
        ent = entities[idx]
        name = ent.get("name", "").strip()
        if not name:
            continue
        raw_type = ent.get("type", "fact").lower()
        etype = _TYPE_MAP.get(raw_type, "Fact")
        description = ent.get("description", "")
        category = ent.get("category", "general")
        user_id = ent.get("user_id", "")
        if category not in _VALID_CATEGORIES:
            print(f" Skipping entity '{name}': invalid category '{category}'")
            continue
        # Scope: user entities scope to the user, channel entities to the
        # channel, everything else is global ("_").
        if category == "user":
            scope_id = user_id or "_"
        elif category == "channel":
            scope_id = channel_id
        else:
            scope_id = "_"
        prepared.append((name, etype, description, category, scope_id))
        embed_texts.append(
            f"{name}: {description}" if description else name,
        )

    # Batch-embed all entity texts; on failure commit without embeddings
    # (there is no per-entity embedding retry -- the old message saying
    # "falling back to per-entity" was inaccurate).
    vectors: list[list[float] | None] = [None] * len(prepared)
    if embed_texts:
        try:
            vectors = await kg._embed_batch(embed_texts)
        except Exception:
            print(f" WARNING: Batch embedding failed for {len(embed_texts)} entities, committing without embeddings")
            vectors = [None] * len(prepared)

    # Second pass: upsert each entity; a single failure must not abort
    # the rest of the batch.
    committed = 0
    for (name, etype, description, category, scope_id), vec in zip(
        prepared, vectors,
    ):
        try:
            info = await kg._resolve_or_create(
                name, etype, category, scope_id,
                description=description,
                created_by="system:build_kg_script",
                embedding=vec,
            )
            entity_uuid_lookup[name.lower()] = info["uuid"]
            committed += 1
        except Exception as exc:
            print(f" ERROR committing entity: {exc}")
    return committed
async def _guess_uuid(kg: KnowledgeGraphManager, name: str) -> str | None:
    """Try to find an entity's UUID by name across all labels.

    NOTE(review): matches on the lower-cased, stripped name -- this assumes
    entity names are stored lower-cased in the graph; confirm against the
    KnowledgeGraphManager write path.
    """
    needle = name.strip().lower()
    for label in ENTITY_LABELS:
        query = (
            f"MATCH (e:{label}) "
            f"WHERE e.name = $name "
            f"RETURN e.uuid LIMIT 1"
        )
        # A failing label query is not fatal; just try the next label.
        try:
            result = await kg._graph.query(query, params={"name": needle})
            rows = result.result_set
            if rows and rows[0][0]:
                return rows[0][0]
        except Exception:
            continue
    return None
async def commit_relationships(
    kg: KnowledgeGraphManager,
    relationships: list[dict],
    approved_indices: list[int],
    entity_uuid_lookup: dict[str, str],
) -> int:
    """Create/reinforce approved relationships. Returns count committed."""
    committed = 0
    for idx in approved_indices:
        rel = relationships[idx]
        # Per-item try/except: one malformed relationship must not abort
        # the remainder of the batch.
        try:
            source = rel.get("source", "").strip()
            target = rel.get("target", "").strip()
            relation = rel.get("relation", "RELATED_TO").upper()
            note = rel.get("description", "")
            weight = float(rel.get("confidence", 0.5))
            if not source or not target:
                continue

            # Prefer the UUIDs of entities committed this session; fall
            # back to a by-name graph search for pre-existing entities.
            src_uuid = entity_uuid_lookup.get(source.lower()) or await _guess_uuid(kg, source)
            tgt_uuid = entity_uuid_lookup.get(target.lower()) or await _guess_uuid(kg, target)
            if not src_uuid or not tgt_uuid:
                print(f" Skipping relationship: could not resolve "
                      f"UUID for '{source}' or '{target}'")
                continue

            await kg.add_relationship(
                src_uuid, tgt_uuid, relation,
                weight=weight,
                description=note,
            )
            committed += 1
        except Exception as exc:
            print(f" ERROR committing relationship: {exc}")
    return committed
# --------------------------------------------------------------------------- # Main pipeline # ---------------------------------------------------------------------------
def format_conversation(messages: list[dict[str, Any]]) -> str:
    """Format all messages into conversation text for the LLM.

    Each message becomes one line: "[iso-utc-timestamp] name (id): text".
    """
    def _render(msg: dict[str, Any]) -> str:
        stamp = datetime.fromtimestamp(msg["timestamp"], tz=timezone.utc).isoformat()
        return (
            f"[{stamp}] {msg.get('user_name', '?')} "
            f"({msg.get('user_id', '?')}): {msg.get('text', '')}"
        )

    return "\n".join(_render(m) for m in messages)
async def run(args: argparse.Namespace) -> None:
    """Run the full pipeline: fetch, extract, human-approve, commit.

    Args:
        args: Parsed CLI arguments (platform, channel, count, guild, verbose).

    Exits the process with status 1 when required configuration is missing.
    """
    cfg = Config.load()
    if not cfg.redis_url:
        print("ERROR: No redis_url configured. Cannot proceed.")
        sys.exit(1)
    if not cfg.api_key:
        print("ERROR: No api_key configured. Cannot proceed.")
        sys.exit(1)

    # -- Initialise components --
    extraction_client = OpenRouterClient(
        api_key=cfg.api_key,
        model=_EXTRACTION_MODEL,
        temperature=0.3,
        max_tokens=60_000,
        base_url=cfg.llm_base_url,
    )
    cache = MessageCache(
        redis_url=cfg.redis_url,
        openrouter_client=extraction_client,
        embedding_model=cfg.embedding_model,
    )
    kg = KnowledgeGraphManager(
        redis_client=cache.redis_client,
        openrouter=extraction_client,
        embedding_model=cfg.embedding_model,
        admin_user_ids=set(cfg.admin_user_ids) if cfg.admin_user_ids else None,
    )

    # try/finally ensures the Redis connection is always released, even
    # when a step below raises (the original leaked it on exceptions).
    try:
        await kg.ensure_indexes()

        # -- Gather messages --
        print(f"\nFetching up to {args.count} messages from "
              f"{args.platform}:{args.channel} ...")
        messages = await gather_messages(
            cache, args.platform, args.channel, args.count, cfg,
        )
        if not messages:
            print("No messages found. Nothing to do.")
            return
        print(f" Collected {len(messages)} messages "
              f"(oldest: {datetime.fromtimestamp(messages[0]['timestamp'], tz=timezone.utc).isoformat()}, "
              f"newest: {datetime.fromtimestamp(messages[-1]['timestamp'], tz=timezone.utc).isoformat()})")

        # -- Dump existing graph --
        print("\nLoading existing knowledge graph for LLM context...")
        graph_context = await dump_full_graph(kg)
        stats = await kg.get_graph_stats()
        print(f" Graph has {stats.get('node_count', 0)} nodes, "
              f"{stats.get('relationship_count', 0)} relationships")

        # -- Format all messages (no chunking -- full context) --
        conversation_text = format_conversation(messages)
        print(f"\n Conversation text: {len(conversation_text):,} chars")
        print(f" Graph context: {len(graph_context):,} chars")
        print(f" Total prompt size: ~{len(conversation_text) + len(graph_context):,} chars")

        # -- Single LLM extraction call --
        print("\n Calling LLM for full extraction (this may take a while)...",
              end="", flush=True)
        extracted = await run_extraction(
            extraction_client, conversation_text, graph_context,
        )
        entities = extracted["entities"]
        relationships = extracted["relationships"]
        print(" done.")
        print(f" Extracted {len(entities)} entities, {len(relationships)} relationships")
        if not entities and not relationships:
            print("\n Nothing extracted. Exiting.")
            return

        # -- Human approval --
        result = prompt_approval(entities, relationships, len(messages))
        if result is None:
            print("\n Quit -- nothing committed.")
            return
        approved_ents, approved_rels = result
        if not approved_ents and not approved_rels:
            print(" All rejected -- nothing committed.")
            return

        # -- Commit --
        print(f"\n Committing {len(approved_ents)} entities, "
              f"{len(approved_rels)} relationships...")
        entity_uuid_lookup: dict[str, str] = {}
        n_ent = await commit_entities(
            kg, entities, approved_ents, args.channel, entity_uuid_lookup,
        )
        n_rel = await commit_relationships(
            kg, relationships, approved_rels, entity_uuid_lookup,
        )

        # -- Summary --
        print(f"\n{'=' * 60}")
        print(" DONE")
        print(f" Entities committed: {n_ent}")
        print(f" Relationships committed: {n_rel}")
        print(f"{'=' * 60}\n")
    finally:
        await cache.close()
def main() -> None:
    """CLI entry point: parse arguments, configure logging, run the pipeline."""
    parser = argparse.ArgumentParser(
        description="Build knowledge graph entries from channel messages with human approval.",
    )
    parser.add_argument(
        "--platform", required=True,
        help="Platform name (e.g. 'discord', 'matrix')",
    )
    parser.add_argument(
        "--channel", required=True,
        help="Channel ID to process",
    )
    parser.add_argument(
        "--guild",
        help="Guild/server ID (optional, for scoping)",
    )
    parser.add_argument(
        "--count", type=int, default=_DEFAULT_MESSAGE_COUNT,
        help=f"Number of messages to fetch (default: {_DEFAULT_MESSAGE_COUNT})",
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true",
        help="Enable debug logging",
    )
    args = parser.parse_args()

    log_level = logging.DEBUG if args.verbose else logging.WARNING
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    )
    asyncio.run(run(args))
# Script entry point when invoked directly (not on import).
if __name__ == "__main__":
    main()