#!/usr/bin/env python3
"""Standalone script to build knowledge graph entries from channel messages.
Fetches the last N messages (default 1000) from a channel via Redis cache
first, falling back to the platform API. Sends ALL messages plus the
entire existing knowledge graph to gemini-3-flash-preview in a single
call, then presents the proposed entities/relationships for human approval
before committing to FalkorDB.
Usage:
python build_kg.py --platform discord --channel 123456789
python build_kg.py --platform discord --channel 123456789 --guild 987
"""
from __future__ import annotations
import argparse
import asyncio
import json
import logging
import sys
import time
from datetime import datetime, timezone
from typing import Any
import redis.asyncio as aioredis
from config import Config
from kg_extraction import EXTRACTION_PROMPT, _TYPE_MAP, _parse_llm_json
from knowledge_graph import ENTITY_LABELS, KnowledgeGraphManager
from message_cache import CachedMessage, MessageCache
from openrouter_client import OpenRouterClient
logger = logging.getLogger(__name__)
# Model used for the single full-context extraction call (see module docstring).
_EXTRACTION_MODEL = "gemini-3-flash-preview"
# Entity categories accepted at commit time; anything else is skipped
# by commit_entities().
_VALID_CATEGORIES = {"user", "channel", "general", "basic"}
# Default for the --count CLI option.
_DEFAULT_MESSAGE_COUNT = 1000
# ---------------------------------------------------------------------------
# Message retrieval
# ---------------------------------------------------------------------------
async def fetch_messages_redis(
    cache: MessageCache,
    platform: str,
    channel_id: str,
    count: int,
) -> list[CachedMessage]:
    """Pull up to *count* messages from the Redis sorted-set cache.

    Args:
        cache: Message cache backed by a Redis sorted set.
        platform: Platform name (e.g. ``"discord"``).
        channel_id: Channel whose messages to fetch.
        count: Maximum number of messages to return.

    Returns:
        Messages in chronological (oldest-first) order, or an empty list
        if the Redis fetch fails for any reason.
    """
    try:
        msgs = await cache.get_recent(platform, channel_id, count=count)
    except Exception:
        # Best-effort: an unavailable Redis must not abort the script --
        # the caller falls back to the platform API.
        logger.warning("Redis message fetch failed", exc_info=True)
        return []
    msgs.reverse()  # get_recent returns newest-first; we want chronological
    return msgs
async def fetch_messages_discord(
    token: str,
    channel_id: str,
    limit: int,
) -> list[dict[str, Any]]:
    """Fetch messages directly from the Discord API using discord.py.

    Spins up a temporary gateway client, waits for it to become ready,
    reads the channel history, then tears the client down again.

    Args:
        token: Discord bot token.
        channel_id: Numeric channel id (as a string).
        limit: Maximum number of messages to fetch.

    Returns:
        Dicts with keys: user_id, user_name, text, timestamp (float),
        is_bot -- in chronological (oldest-first) order. Empty list on
        any failure (discord.py missing, connect timeout, API error).
    """
    try:
        import discord
    except ImportError:
        logger.error("discord.py is not installed -- cannot fall back to API")
        return []
    # message_content is required to read msg.content from the API.
    intents = discord.Intents.default()
    intents.message_content = True
    client = discord.Client(intents=intents)
    messages: list[dict[str, Any]] = []
    ready_event = asyncio.Event()

    @client.event
    async def on_ready():
        """Signal that the gateway connection is established."""
        ready_event.set()

    # client.start() runs until closed; drive it in the background and
    # cancel it once history has been read.
    token_task = asyncio.create_task(client.start(token))
    try:
        await asyncio.wait_for(ready_event.wait(), timeout=30)
        channel = client.get_channel(int(channel_id))
        if channel is None:
            # Not in the local cache -- fetch explicitly via the HTTP API.
            channel = await client.fetch_channel(int(channel_id))
        if hasattr(channel, "history"):
            async for msg in channel.history(limit=limit):
                messages.append({
                    "user_id": str(msg.author.id),
                    "user_name": msg.author.display_name,
                    "text": msg.content,
                    "timestamp": msg.created_at.timestamp(),
                    "is_bot": msg.author.bot or False,
                })
    except asyncio.TimeoutError:
        logger.error("Discord client failed to connect within 30s")
    except Exception:
        logger.error("Discord API fetch failed", exc_info=True)
    finally:
        await client.close()
        token_task.cancel()
        try:
            await token_task
        except (asyncio.CancelledError, Exception):
            # Swallow both cancellation and any shutdown error -- the
            # messages already collected are still valid.
            pass
    messages.reverse()  # history() returns newest-first
    return messages
async def gather_messages(
    cache: MessageCache | None,
    platform: str,
    channel_id: str,
    count: int,
    cfg: Config,
) -> list[dict[str, Any]]:
    """Collect up to *count* messages, Redis-first with API fallback.

    Args:
        cache: Optional Redis-backed message cache; skipped when ``None``.
        platform: Platform name; only ``"discord"`` has an API fallback.
        channel_id: Channel to collect messages from.
        count: Maximum number of messages to return.
        cfg: Loaded configuration (read only to locate the Discord token).

    Returns:
        A chronologically-ordered list of message dicts with keys:
        user_id, user_name, text, timestamp.
    """
    results: list[dict[str, Any]] = []
    if cache is not None:
        cached = await fetch_messages_redis(cache, platform, channel_id, count)
        results.extend(
            {
                "user_id": cm.user_id,
                "user_name": cm.user_name,
                "text": cm.text,
                "timestamp": cm.timestamp,
            }
            for cm in cached
        )
    if len(results) >= count:
        return results[:count]
    remaining = count - len(results)
    print(
        f" Redis returned {len(results)} messages, "
        f"attempting API fetch for up to {remaining} more..."
    )
    if platform == "discord":
        discord_token = None
        for p in cfg.platforms:
            if p.type == "discord":
                discord_token = p.settings.get("token")
                break
        if not discord_token:
            print(" WARNING: No Discord token in config, cannot fall back to API.")
        else:
            api_msgs = await fetch_messages_discord(
                discord_token, channel_id, limit=remaining,
            )
            # De-duplicate against cached messages by timestamp.
            # NOTE(review): two distinct messages with identical timestamps
            # would collide here -- acceptable for this offline tool, but
            # worth confirming against the cache's timestamp resolution.
            seen_ts = {m["timestamp"] for m in results}
            for m in api_msgs:
                if m["timestamp"] not in seen_ts:
                    results.append({
                        "user_id": m["user_id"],
                        "user_name": m["user_name"],
                        "text": m["text"],
                        "timestamp": m["timestamp"],
                    })
    else:
        print(f" WARNING: API fallback not implemented for platform '{platform}'.")
    results.sort(key=lambda m: m["timestamp"])
    return results[:count]
# ---------------------------------------------------------------------------
# Graph dump for LLM context
# ---------------------------------------------------------------------------
async def dump_full_graph(kg: KnowledgeGraphManager) -> str:
    """Serialize the entire knowledge graph into a text block for LLM context.

    Args:
        kg: Knowledge graph manager to read from.

    Returns:
        A readable listing of all entities and relationships (capped at
        10,000 each), or a placeholder line when the graph is empty.
    """
    entities = await kg.list_entities(limit=10_000)
    relationships = await kg.list_relationships(limit=10_000)
    if not entities and not relationships:
        return "(The knowledge graph is currently empty.)\n"
    lines: list[str] = ["=== EXISTING KNOWLEDGE GRAPH ===\n"]
    if entities:
        lines.append("Entities:")
        for e in entities:
            scope = e.get("scope_id", "_")
            # Only show the scope when the entity is not globally scoped.
            scope_str = f", scope={scope}" if scope != "_" else ""
            lines.append(
                f" - [{e.get('type', '?')}] {e.get('name', '?')} "
                f"(category={e.get('category', '?')}{scope_str}): "
                f"\"{e.get('description', '')}\""
            )
    if relationships:
        lines.append("\nRelationships:")
        for r in relationships:
            lines.append(
                f" - {r.get('source', '?')} -[{r.get('relation', '?')}]-> "
                f"{r.get('target', '?')} "
                f"(weight={r.get('weight', '?')}): "
                f"\"{r.get('description', '')}\""
            )
    return "\n".join(lines) + "\n"
# ---------------------------------------------------------------------------
# LLM extraction
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Human approval UI
# ---------------------------------------------------------------------------
def prompt_approval(
    entities: list[dict],
    relationships: list[dict],
    num_messages: int,
) -> tuple[list[int], list[int]] | None:
    """Display proposed entries and ask the operator what to commit.

    Args:
        entities: Proposed entity dicts.
        relationships: Proposed relationship dicts.
        num_messages: Number of source messages (header display only).

    Returns:
        ``(entity_indices, relationship_indices)`` as 0-based lists, or
        ``None`` when the operator chooses to quit without committing.
    """
    print(f"\n{'=' * 60}")
    print(f" Extraction results from {num_messages} messages")
    print(f"{'=' * 60}")
    if not entities and not relationships:
        print(" (no entities or relationships extracted)")
        return [], []
    if entities:
        print(f"\n Proposed Entities ({len(entities)}):")
        for i, e in enumerate(entities):
            print(format_entity(i, e))
    if relationships:
        print(f"\n Proposed Relationships ({len(relationships)}):")
        for i, r in enumerate(relationships):
            print(format_relationship(i, r))
    print()
    print(" Options:")
    print(" y / Enter = approve all")
    print(" n = reject all")
    print(" e1,e3,r2 = approve only selected (e=entity, r=relationship)")
    print(" q = quit without committing")
    print()
    choice = input(" Your choice: ").strip().lower()
    if choice == "q":
        return None
    if choice in ("n", "no"):
        return [], []
    if choice in ("y", "yes", ""):
        return (
            list(range(len(entities))),
            list(range(len(relationships))),
        )
    # Selective approval, e.g. "e1,e3,r2".
    # NOTE(review): tokens are interpreted as 1-based here; confirm that
    # format_entity/format_relationship number their output the same way.
    approved_ents: list[int] = []
    approved_rels: list[int] = []
    for token in choice.split(","):
        token = token.strip()
        if not token:
            continue
        if token.startswith("e") and token[1:].isdigit():
            idx = int(token[1:]) - 1  # 1-based -> 0-based
            if 0 <= idx < len(entities):
                approved_ents.append(idx)
            else:
                # Previously dropped silently, unlike other bad tokens.
                print(f" WARNING: Entity index '{token}' out of range, ignoring.")
        elif token.startswith("r") and token[1:].isdigit():
            idx = int(token[1:]) - 1
            if 0 <= idx < len(relationships):
                approved_rels.append(idx)
            else:
                print(f" WARNING: Relationship index '{token}' out of range, ignoring.")
        else:
            print(f" WARNING: Unrecognised token '{token}', ignoring.")
    return approved_ents, approved_rels
# ---------------------------------------------------------------------------
# Commit approved entries
# ---------------------------------------------------------------------------
async def commit_entities(
    kg: KnowledgeGraphManager,
    entities: list[dict],
    approved_indices: list[int],
    channel_id: str,
    entity_uuid_lookup: dict[str, str],
) -> int:
    """Resolve-or-create approved entities. Returns count committed.

    Populates *entity_uuid_lookup* with lower-cased name -> uuid mappings
    for later relationship resolution.

    Args:
        kg: Knowledge graph manager used for embedding and upserts.
        entities: All proposed entity dicts from the extraction step.
        approved_indices: 0-based indices into *entities* to commit.
        channel_id: Scope id applied to 'channel'-category entities.
        entity_uuid_lookup: Output mapping, mutated in place.

    Returns:
        Number of entities successfully committed.
    """
    prepared: list[tuple[str, str, str, str, str]] = []
    embed_texts: list[str] = []
    for idx in approved_indices:
        ent = entities[idx]
        # Tolerate a missing OR explicitly-None name/type from the LLM.
        name = (ent.get("name") or "").strip()
        if not name:
            continue
        raw_type = (ent.get("type") or "fact").lower()
        etype = _TYPE_MAP.get(raw_type, "Fact")
        description = ent.get("description", "")
        category = ent.get("category", "general")
        user_id = ent.get("user_id", "")
        if category not in _VALID_CATEGORIES:
            print(f" Skipping entity '{name}': invalid category '{category}'")
            continue
        # Scope: user entities are scoped to the user, channel entities to
        # this channel, everything else is global ("_").
        if category == "user":
            scope_id = user_id or "_"
        elif category == "channel":
            scope_id = channel_id
        else:
            scope_id = "_"
        prepared.append((name, etype, description, category, scope_id))
        # prepared and embed_texts stay index-aligned for the zip below.
        embed_texts.append(
            f"{name}: {description}" if description else name,
        )
    vectors: list[list[float] | None] = [None] * len(prepared)
    if embed_texts:
        try:
            vectors = await kg._embed_batch(embed_texts)
        except Exception:
            # Embeddings are best-effort: still commit the entities,
            # just without vectors. (Message fixed: there is no
            # per-entity embedding fallback here.)
            print(f" WARNING: Batch embedding failed for {len(embed_texts)} entities, committing without embeddings")
            vectors = [None] * len(prepared)
    committed = 0
    for (name, etype, description, category, scope_id), vec in zip(
        prepared, vectors,
    ):
        try:
            info = await kg._resolve_or_create(
                name, etype, category, scope_id,
                description=description,
                created_by="system:build_kg_script",
                embedding=vec,
            )
            entity_uuid_lookup[name.lower()] = info["uuid"]
            committed += 1
        except Exception as exc:
            print(f" ERROR committing entity: {exc}")
    return committed
async def _guess_uuid(kg: KnowledgeGraphManager, name: str) -> str | None:
    """Look up an entity's UUID by name, scanning every known label.

    NOTE(review): matches on the lower-cased name, which assumes entity
    names are stored lower-cased -- confirm against _resolve_or_create.

    Returns:
        The first matching UUID, or ``None`` when no label matches or
        every query errors out.
    """
    normalized = name.strip().lower()
    for entity_label in ENTITY_LABELS:
        query = (
            f"MATCH (e:{entity_label}) "
            "WHERE e.name = $name "
            "RETURN e.uuid LIMIT 1"
        )
        try:
            res = await kg._graph.query(query, params={"name": normalized})
            rows = res.result_set
            if rows and rows[0][0]:
                return rows[0][0]
        except Exception:
            # A failing label (e.g. missing index) must not stop the scan.
            continue
    return None
async def commit_relationships(
    kg: KnowledgeGraphManager,
    relationships: list[dict],
    approved_indices: list[int],
    entity_uuid_lookup: dict[str, str],
) -> int:
    """Create/reinforce approved relationships. Returns count committed.

    Endpoint UUIDs come from *entity_uuid_lookup* first, then from a
    name-based graph lookup; unresolvable endpoints skip the entry.
    """
    committed = 0
    for idx in approved_indices:
        entry = relationships[idx]
        try:
            source_name = entry.get("source", "").strip()
            target_name = entry.get("target", "").strip()
            relation_type = entry.get("relation", "RELATED_TO").upper()
            description = entry.get("description", "")
            weight = float(entry.get("confidence", 0.5))
            if not (source_name and target_name):
                continue
            source_uuid = (
                entity_uuid_lookup.get(source_name.lower())
                or await _guess_uuid(kg, source_name)
            )
            target_uuid = (
                entity_uuid_lookup.get(target_name.lower())
                or await _guess_uuid(kg, target_name)
            )
            if not (source_uuid and target_uuid):
                print(f" Skipping relationship: could not resolve "
                      f"UUID for '{source_name}' or '{target_name}'")
                continue
            await kg.add_relationship(
                source_uuid, target_uuid,
                relation_type, weight=weight, description=description,
            )
            committed += 1
        except Exception as exc:
            # One bad entry must not abort the whole batch.
            print(f" ERROR committing relationship: {exc}")
    return committed
# ---------------------------------------------------------------------------
# Main pipeline
# ---------------------------------------------------------------------------
async def run(args: argparse.Namespace) -> None:
    """Execute the full pipeline: fetch, extract, approve, commit.

    Args:
        args: Parsed command-line arguments (platform, channel, count, ...).
    """
    cfg = Config.load()
    if not cfg.redis_url:
        print("ERROR: No redis_url configured. Cannot proceed.")
        sys.exit(1)
    if not cfg.api_key:
        print("ERROR: No api_key configured. Cannot proceed.")
        sys.exit(1)
    # -- Initialise components --
    extraction_client = OpenRouterClient(
        api_key=cfg.api_key,
        model=_EXTRACTION_MODEL,
        temperature=0.3,
        max_tokens=60_000,
        base_url=cfg.llm_base_url,
    )
    cache = MessageCache(
        redis_url=cfg.redis_url,
        openrouter_client=extraction_client,
        embedding_model=cfg.embedding_model,
    )
    # Everything after cache creation runs under try/finally so the Redis
    # connection is released on every exit path, including unexpected errors.
    try:
        kg = KnowledgeGraphManager(
            redis_client=cache.redis_client,
            openrouter=extraction_client,
            embedding_model=cfg.embedding_model,
            admin_user_ids=set(cfg.admin_user_ids) if cfg.admin_user_ids else None,
        )
        await kg.ensure_indexes()
        # -- Gather messages --
        print(f"\nFetching up to {args.count} messages from "
              f"{args.platform}:{args.channel} ...")
        messages = await gather_messages(
            cache, args.platform, args.channel, args.count, cfg,
        )
        if not messages:
            print("No messages found. Nothing to do.")
            return
        print(f" Collected {len(messages)} messages "
              f"(oldest: {datetime.fromtimestamp(messages[0]['timestamp'], tz=timezone.utc).isoformat()}, "
              f"newest: {datetime.fromtimestamp(messages[-1]['timestamp'], tz=timezone.utc).isoformat()})")
        # -- Dump existing graph --
        print("\nLoading existing knowledge graph for LLM context...")
        graph_context = await dump_full_graph(kg)
        stats = await kg.get_graph_stats()
        print(f" Graph has {stats.get('node_count', 0)} nodes, "
              f"{stats.get('relationship_count', 0)} relationships")
        # -- Format all messages (no chunking -- full context) --
        conversation_text = format_conversation(messages)
        print(f"\n Conversation text: {len(conversation_text):,} chars")
        print(f" Graph context: {len(graph_context):,} chars")
        print(f" Total prompt size: ~{len(conversation_text) + len(graph_context):,} chars")
        # -- Single LLM extraction call --
        print("\n Calling LLM for full extraction (this may take a while)...",
              end="", flush=True)
        extracted = await run_extraction(
            extraction_client, conversation_text, graph_context,
        )
        entities = extracted["entities"]
        relationships = extracted["relationships"]
        print(" done.")
        print(f" Extracted {len(entities)} entities, {len(relationships)} relationships")
        if not entities and not relationships:
            print("\n Nothing extracted. Exiting.")
            return
        # -- Human approval --
        result = prompt_approval(entities, relationships, len(messages))
        if result is None:
            print("\n Quit -- nothing committed.")
            return
        approved_ents, approved_rels = result
        if not approved_ents and not approved_rels:
            print(" All rejected -- nothing committed.")
            return
        # -- Commit --
        print(f"\n Committing {len(approved_ents)} entities, "
              f"{len(approved_rels)} relationships...")
        entity_uuid_lookup: dict[str, str] = {}
        n_ent = await commit_entities(
            kg, entities, approved_ents, args.channel,
            entity_uuid_lookup,
        )
        n_rel = await commit_relationships(
            kg, relationships, approved_rels, entity_uuid_lookup,
        )
        # -- Summary --
        print(f"\n{'=' * 60}")
        print(" DONE")
        print(f" Entities committed: {n_ent}")
        print(f" Relationships committed: {n_rel}")
        print(f"{'=' * 60}\n")
    finally:
        await cache.close()
def main() -> None:
    """Parse CLI arguments, configure logging, and run the async pipeline."""
    arg_parser = argparse.ArgumentParser(
        description="Build knowledge graph entries from channel messages with human approval.",
    )
    arg_parser.add_argument(
        "--platform",
        required=True,
        help="Platform name (e.g. 'discord', 'matrix')",
    )
    arg_parser.add_argument(
        "--channel",
        required=True,
        help="Channel ID to process",
    )
    arg_parser.add_argument(
        "--guild",
        help="Guild/server ID (optional, for scoping)",
    )
    arg_parser.add_argument(
        "--count",
        type=int,
        default=_DEFAULT_MESSAGE_COUNT,
        help=f"Number of messages to fetch (default: {_DEFAULT_MESSAGE_COUNT})",
    )
    arg_parser.add_argument(
        "--verbose",
        "-v",
        action="store_true",
        help="Enable debug logging",
    )
    cli_args = arg_parser.parse_args()
    log_level = logging.DEBUG if cli_args.verbose else logging.WARNING
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    )
    asyncio.run(run(cli_args))


if __name__ == "__main__":
    main()