# Source code for kg_extraction (Sphinx viewcode extraction header)
"""Knowledge-graph extraction from conversations.
Replaces ``background_agents/auto_memory_extraction.py``. Uses an LLM
to extract structured entities and relationships from conversation text,
then writes them into the FalkorDB knowledge graph via
:class:`~knowledge_graph.KnowledgeGraphManager`.
"""
from __future__ import annotations
import json
import logging
import re
import time
from typing import Any, TYPE_CHECKING
if TYPE_CHECKING:
import redis.asyncio as aioredis
from knowledge_graph import KnowledgeGraphManager
from openrouter_client import OpenRouterClient
logger = logging.getLogger(__name__)

# Redis key holding the UNIX timestamp of the last completed batch run.
_REDIS_LAST_RUN_KEY = "stargazer:kg_extraction:last_run"
# Per-user rate-limit counters live at f"{_RATE_LIMIT_PREFIX}{user_id}".
_RATE_LIMIT_PREFIX = "stargazer:kg_extract_rate:"
_RATE_LIMIT_WINDOW = 3600  # 1 hour
# Cap on messages scanned per batch-extraction run.
_MESSAGES_PER_RUN = 100
# Messages shorter than this are excluded from batch conversation text.
_MIN_MESSAGE_LENGTH = 20
# Minimum length for a single message to qualify for per-message extraction.
_MIN_SINGLE_MESSAGE_LENGTH = 100
# Heuristic patterns that signal a message likely contains
# extractable personal facts, preferences, or knowledge.
# Used by _has_extraction_signal() as a cheap pre-LLM gate.
_SIGNAL_PATTERNS: list[re.Pattern] = [
    re.compile(p, re.IGNORECASE) for p in [
        r"\bi am\b", r"\bi'm\b",
        r"\bmy (name|job|work|favorite|fav)\b",
        r"\bi (work|live|study|moved|started)\b",
        r"\bi (like|love|hate|prefer|enjoy|use)\b",
        r"\bi (know|learned|believe|think that)\b",
        r"\bwe (decided|agreed|use|switched)\b",
        r"\bour (team|project|company|stack)\b",
        r"\bremember (that|this|when)\b",
        r"\bdon'?t forget\b",
        r"\bimportant:\s",
        r"\bfyi\b",
        r"\bannouncement\b",
        r"\brule:\s",
        r"\bpolicy:\s",
    ]
]
def _has_extraction_signal(text: str) -> bool:
"""Cheap regex check for knowledge-bearing patterns."""
return any(p.search(text) for p in _SIGNAL_PATTERNS)
# Prompt prepended to every extraction request (batch and per-message).
# NOTE(review): this text is sent verbatim to the LLM -- do not reflow or
# re-indent it without re-checking extraction quality.
EXTRACTION_PROMPT = """\
Analyze this conversation and extract a knowledge graph.
Return JSON with:
{
"entities": [
{
"name": "...",
"type": "person|concept|preference|fact|event|location|\
organization|project|technology|rule|directive|role",
"description": "...",
"category": "user|channel|general|basic",
"user_id": "optional, required if category=user"
}
],
"relationships": [
{
"source": "entity_name",
"target": "entity_name",
"relation": "RELATION_TYPE",
"description": "...",
"confidence": 0.0
}
]
}
Entity types:
- person: A human user, contributor, or known individual.
- concept: An abstract idea, topic, or domain of knowledge.
- preference: A stated like, dislike, or preference.
- fact: A concrete piece of information or data point.
- event: Something that happened or will happen.
- location: A physical or virtual place.
- organization: A company, team, group, or institution.
- project: A named project, repo, or initiative.
- technology: A language, framework, tool, or platform.
- rule: An explicit rule, constraint, or policy that governs behavior.
- directive: An instruction or mandate that guides action.
- role: A named role, permission level, or authority designation.
Recommended relation types (you may also use other descriptive \
UPPER_SNAKE_CASE types):
Personal: LIKES, DISLIKES, KNOWS, PREFERS, SKILLED_IN, SAID
Organizational: WORKS_AT, MEMBER_OF, OWNS, CREATED
Structural: PART_OF, RELATED_TO, USES, LOCATED_IN, IS_A, HAS_PROPERTY
Governance: ENFORCES, PERMITS, PROHIBITS, SUPERSEDES, DEPENDS_ON
Temporal: PRECEDED_BY, FOLLOWED_BY, CAUSED
Note: relationships can cross categories. If a user entity relates to an
already-existing entity of any category, just reference it by name.
The system will resolve the target entity across all categories.
Category rules:
- "user": Facts about a specific user (preferences, \
skills, personal info). Requires user_id.
- "channel": Facts relevant to this specific channel's context.
- "general": Facts that apply broadly \
(not user-specific, not channel-specific). Default category.
- "basic": Fundamental, identity-level facts that should always \
be available (e.g. who the owner is, what the system is called). \
Use sparingly -- only for knowledge that is always relevant.
- NEVER use "core" or "guild" -- these are admin-only categories.
Other rules:
- Only extract genuinely important, persistent facts
- Do NOT extract transient chat or greetings
- Default to "general" unless another category clearly fits
- confidence: 1.0 = explicitly stated, 0.5 = implied, 0.3 = inferred
"""

# Per-message wrapper: formatted by extract_from_message() and appended
# after EXTRACTION_PROMPT, so "same format as above" refers to it.
SINGLE_MESSAGE_PROMPT = """\
Extract any important knowledge from this single message.
Speaker: {user_name} (ID: {user_id})
Channel: {channel_id}
Message: {message_text}
Return JSON with the same format as above. Only extract if there are
genuinely important facts. Return empty lists if nothing noteworthy.
"""
# Categories the LLM may emit; anything else (e.g. admin-only "core"
# or "guild") is silently dropped during application.
_VALID_CATEGORIES = {"user", "channel", "general", "basic"}
# Lowercase LLM entity type -> graph node label. Unknown types fall
# back to "Fact" at the call sites.
_TYPE_MAP = {
    "person": "Person",
    "concept": "Concept",
    "preference": "Preference",
    "fact": "Fact",
    "event": "Event",
    "location": "Location",
    "organization": "Organization",
    "project": "Project",
    "technology": "Technology",
    "rule": "Rule",
    "directive": "Directive",
    "role": "Role",
}
def _parse_llm_json(raw: str) -> dict:
"""Best-effort JSON parsing from LLM output."""
raw = raw.strip()
# Strip <thinking>...</thinking> blocks (Gemini extended thinking)
thinking_end = raw.find("</thinking>")
if thinking_end != -1:
raw = raw[thinking_end + len("</thinking>"):].strip()
if raw.startswith("```"):
raw = raw.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
# Last resort: find the first '{' if leading text remains
if raw and raw[0] != "{":
brace = raw.find("{")
if brace != -1:
raw = raw[brace:]
return json.loads(raw)
# Placeholder user id used when no concrete user scope applies.
_SENTINEL_USER_ID = "000000000000"
# [docs] (Sphinx viewcode marker; not valid Python -- commented out)
async def apply_parsed_extraction(
    data: dict[str, Any],
    kg_manager: KnowledgeGraphManager,
    channel_id: str,
    user_id: str = _SENTINEL_USER_ID,
    created_by: str = "system:extraction",
) -> dict[str, Any]:
    """Apply LLM extraction JSON (entities + relationships) to FalkorDB.

    Each entity may include ``existing_uuid`` — if set, that UUID is
    registered for relationship resolution and no new node is created.
    Per-entity ``user_id`` overrides the default *user_id* for category
    ``user`` scope resolution.

    Args:
        data: Parsed LLM output with ``entities``/``relationships`` lists.
        kg_manager: Graph manager used to resolve/create nodes and edges.
        channel_id: Scope id applied to category ``channel`` entities.
        user_id: Default scope id for category ``user`` entities.
        created_by: Provenance tag passed to entity creation.

    Returns:
        Counter dict: ``entities_added``, ``relationships_added``, ``errors``.
    """
    entities = data.get("entities", [])
    relationships = data.get("relationships", [])
    stats: dict[str, Any] = {
        "entities_added": 0,
        "relationships_added": 0,
        "errors": 0,
    }
    # Lowercased entity name -> uuid; used to wire relationships below.
    entity_uuid_lookup: dict[str, str] = {}
    # (name, label, description, category, scope_id, user_id) per new entity.
    prepared: list[tuple[str, str, str, str, str, str]] = []
    embed_texts: list[str] = []
    for ent in entities:
        name = ent.get("name", "").strip()
        existing_uuid = str(ent.get("existing_uuid", "") or "").strip()
        if existing_uuid:
            # Already-known entity: register for relationship resolution
            # only; do not create a new node.
            if name:
                entity_uuid_lookup[name.lower()] = existing_uuid
            continue
        action = str(ent.get("action", "create") or "create").lower()
        if action in ("skip", "omit", "reference_only") and not existing_uuid:
            continue
        if not name:
            continue
        raw_type = ent.get("type", "fact").lower()
        # Unknown entity types degrade gracefully to the generic Fact label.
        etype = _TYPE_MAP.get(raw_type, "Fact")
        description = ent.get("description", "")
        category = ent.get("category", "general")
        if category not in _VALID_CATEGORIES:
            # Drops admin-only categories (e.g. "core"/"guild") the prompt forbids.
            continue
        ent_user = str(ent.get("user_id", "") or "").strip()
        # Per-entity user_id wins over the function-level default for user scope.
        uid_for_resolve = ent_user if category == "user" and ent_user else user_id
        if category == "user":
            scope_id = uid_for_resolve or "_"
        elif category == "channel":
            scope_id = channel_id
        else:
            # "general"/"basic" entities are globally scoped.
            scope_id = "_"
        prepared.append(
            (name, etype, description, category, scope_id, uid_for_resolve),
        )
        embed_texts.append(
            f"{name}: {description}" if description else name,
        )
    # Embed all entity texts in one batch; on failure pass None vectors so
    # the manager computes embeddings itself during _resolve_or_create.
    vectors: list[list[float] | None] = [None] * len(prepared)
    if embed_texts:
        try:
            vectors = await kg_manager._embed_batch(embed_texts)
        except Exception:
            logger.warning(
                "Batch embedding failed for %d entities, "
                "falling back to per-entity",
                len(embed_texts), exc_info=True,
            )
            vectors = [None] * len(prepared)
    for (name, etype, description, category, scope_id, uid_res), vec in zip(
        prepared, vectors,
    ):
        try:
            info = await kg_manager._resolve_or_create(
                name, etype, category, scope_id,
                description=description,
                created_by=created_by,
                user_id=uid_res,
                embedding=vec,
            )
            entity_uuid_lookup[name.lower()] = info["uuid"]
            stats["entities_added"] += 1
        except Exception:
            # One failed entity must not abort the rest of the batch.
            logger.warning("Entity extraction error for %s", name, exc_info=True)
            stats["errors"] += 1
    for rel in relationships:
        try:
            src_name = rel.get("source", "").strip()
            tgt_name = rel.get("target", "").strip()
            relation = rel.get("relation", "RELATED_TO").upper()
            desc = rel.get("description", "")
            confidence = float(rel.get("confidence", 0.5))
            # Endpoint resolution order: explicit uuid field -> entities
            # created/registered in this batch -> cross-category graph search.
            src_uuid = str(rel.get("source_uuid", "") or "").strip()
            tgt_uuid = str(rel.get("target_uuid", "") or "").strip()
            if not src_uuid and src_name:
                src_uuid = entity_uuid_lookup.get(src_name.lower())
            if not tgt_uuid and tgt_name:
                tgt_uuid = entity_uuid_lookup.get(tgt_name.lower())
            if not src_uuid and src_name:
                src_uuid = await _resolve_uuid(kg_manager, src_name) or ""
            if not tgt_uuid and tgt_name:
                tgt_uuid = await _resolve_uuid(kg_manager, tgt_name) or ""
            if not src_uuid or not tgt_uuid:
                # Either endpoint unresolved: skip rather than create a
                # dangling edge.
                continue
            await kg_manager.add_relationship(
                src_uuid, tgt_uuid,
                relation, weight=confidence, description=desc,
            )
            stats["relationships_added"] += 1
        except Exception:
            logger.warning(
                "Relationship extraction error for %s",
                rel, exc_info=True,
            )
            stats["errors"] += 1
    return stats
# [docs] (Sphinx viewcode marker; not valid Python -- commented out)
async def extract_knowledge(
    conversation: str,
    openrouter: OpenRouterClient,
    kg_manager: KnowledgeGraphManager,
    channel_id: str,
    guild_id: str | None = None,
    user_id: str = _SENTINEL_USER_ID,
    conversation_char_limit: int | None = 4000,
) -> dict[str, Any]:
    """Full extraction pipeline for a block of conversation text.

    1. Call LLM with extraction prompt
    2. Parse structured JSON
    3. Validate categories (reject core/guild)
    4. Resolve or create each entity
    5. Create/reinforce each relationship
    6. Return stats

    Args:
        conversation_char_limit: If set, truncate *conversation* to this
            many characters before sending to the LLM. ``None`` means no
            truncation.

    Returns:
        Stats dict from :func:`apply_parsed_extraction`, or a one-error
        stats dict when the LLM call/parse fails.
    """
    body = conversation
    if conversation_char_limit is not None:
        body = conversation[:conversation_char_limit]
    prompt = (
        EXTRACTION_PROMPT + "\n\nConversation:\n"
        + body + "\n\nJSON:"
    )
    sys_msg = (
        "You extract structured knowledge graphs "
        "from conversations. Output only valid JSON."
    )
    msgs = [
        {"role": "system", "content": sys_msg},
        {"role": "user", "content": prompt},
    ]
    try:
        raw = await openrouter.chat(msgs)
        data = _parse_llm_json(raw)
    # FIX: the original `except (json.JSONDecodeError, Exception)` was
    # redundant -- JSONDecodeError is already an Exception subclass, so the
    # tuple was equivalent to (and misleadingly narrower-looking than) this.
    except Exception:
        logger.warning(
            "KG extraction LLM parse failed",
            exc_info=True,
        )
        return {
            "entities_added": 0,
            "relationships_added": 0,
            "errors": 1,
        }
    return await apply_parsed_extraction(
        data, kg_manager, channel_id, user_id=user_id,
        created_by="system:extraction",
    )
# Labels searched (in priority order) when resolving a relationship
# endpoint that was not created in the current batch.
_CROSS_LABELS = [
    "Fact", "Person", "Organization",
    "Technology", "Project", "Concept",
    "Rule", "Directive", "Role",
]
async def _resolve_uuid(
    kg_manager: KnowledgeGraphManager,
    name: str,
) -> str | None:
    """Find an entity by name across all labels and return its UUID."""
    normalized = name.strip().lower()
    for label in _CROSS_LABELS:
        match = await kg_manager.resolve_entity_cross_category(
            normalized, label,
        )
        if match and match.get("uuid"):
            return match["uuid"]
    # Cross-category resolution found nothing; fall back to a direct
    # name lookup over every label.
    return await _guess_uuid(kg_manager, normalized)
async def _guess_uuid(
    kg_manager: KnowledgeGraphManager, name: str,
) -> str | None:
    """Try to find an entity's UUID by name across all labels."""
    from knowledge_graph import ENTITY_LABELS
    lowered = name.strip().lower()
    # Label names come from the trusted ENTITY_LABELS constant, so the
    # string interpolation below is not an injection vector.
    query_template = (
        "MATCH (e:{label}) "
        "WHERE e.name = $name "
        "RETURN e.uuid LIMIT 1"
    )
    for label in ENTITY_LABELS:
        try:
            result = await kg_manager._graph.query(
                query_template.format(label=label),
                params={"name": lowered},
            )
            rows = result.result_set
            if rows and rows[0][0]:
                return rows[0][0]
        except Exception:
            # A failing label query is non-fatal; try the next label.
            continue
    return None
# [docs] (Sphinx viewcode marker; not valid Python -- commented out)
async def extract_from_message(
    message_text: str,
    user_id: str,
    user_name: str,
    channel_id: str,
    guild_id: str | None,
    openrouter: OpenRouterClient,
    kg_manager: KnowledgeGraphManager,
    redis: aioredis.Redis | None = None,
    per_user_limit: int = 5,
) -> None:
    """Per-message extraction with cheap pre-filtering.

    This function is called fire-and-forget but is gated by
    three layers *before* any LLM call is made:

    1. **Length gate** -- message must be >= 100 chars
       (``_MIN_SINGLE_MESSAGE_LENGTH``).
    2. **Heuristic gate** -- message must match at least
       one regex pattern that signals knowledge content.
    3. **Rate limit** -- max *per_user_limit* extractions
       per user per hour (via Redis INCR + EXPIRE); skipped
       when *redis* is None.

    REFACTOR: the entity/relationship application previously duplicated
    ~75 lines of :func:`apply_parsed_extraction`; it now delegates to that
    helper (which is a superset: it additionally honors ``existing_uuid``,
    ``action`` and per-entity ``user_id`` fields in the LLM output).
    """
    # Gate 1: too short to contain durable knowledge.
    if len(message_text) < _MIN_SINGLE_MESSAGE_LENGTH:
        return
    # Gate 2: cheap regex screen before spending an LLM call.
    if not _has_extraction_signal(message_text):
        return
    # Gate 3: per-user hourly budget (fail-open when redis is absent).
    if redis is not None:
        if not await _check_rate_limit(
            redis, user_id, per_user_limit,
        ):
            return
    prompt = SINGLE_MESSAGE_PROMPT.format(
        user_name=user_name,
        user_id=user_id,
        channel_id=channel_id,
        message_text=message_text[:2000],
    )
    full_prompt = EXTRACTION_PROMPT + "\n\n" + prompt + "\n\nJSON:"
    sys_content = (
        "You extract structured knowledge graphs "
        "from messages. Output only valid JSON."
    )
    msgs = [
        {"role": "system", "content": sys_content},
        {"role": "user", "content": full_prompt},
    ]
    try:
        raw = await openrouter.chat(msgs)
        data = _parse_llm_json(raw)
    except Exception:
        logger.warning("Per-message extraction parse failed", exc_info=True)
        return
    # Delegate graph writes to the shared helper; the per-message
    # provenance tag records which user's message triggered the write.
    await apply_parsed_extraction(
        data, kg_manager, channel_id,
        user_id=user_id,
        created_by=f"system:msg_extraction:{user_id}",
    )
# ---------------------------------------------------------------------------
# Batch extraction (periodic background task)
# ---------------------------------------------------------------------------
# [docs] (Sphinx viewcode marker; not valid Python -- commented out)
async def run_batch_extraction(
    redis: Any,
    kg_manager: KnowledgeGraphManager,
    openrouter: OpenRouterClient,
    messages_limit: int = _MESSAGES_PER_RUN,
) -> dict[str, Any]:
    """Scan recent messages and extract KG entities.

    Called periodically by the background scheduler. Groups the fetched
    messages per channel, runs :func:`extract_knowledge` on each channel's
    conversation text, and advances the last-run marker only when the
    fetch itself succeeded.
    """
    # All three dependencies are mandatory; bail out quietly otherwise.
    if not (redis and kg_manager and openrouter):
        return {
            "status": "disabled",
            "reason": "missing dependencies",
        }
    last_run = await _get_last_run(redis)
    try:
        messages = await _fetch_recent_messages(
            redis, since=last_run, limit=messages_limit,
        )
    except Exception:
        # Leave last_run untouched so these messages are retried next run.
        logger.warning(
            "KG batch extraction: fetch failed, not advancing last_run",
            exc_info=True,
        )
        return {
            "status": "error",
            "reason": "fetch_failed",
            "messages_scanned": 0,
            "entities_added": 0,
            "relationships_added": 0,
        }
    if not messages:
        await _set_last_run(redis)
        return {
            "status": "completed",
            "messages_scanned": 0,
            "entities_added": 0,
            "relationships_added": 0,
        }
    # Bucket messages per channel so each extraction sees one conversation.
    grouped: dict[str, list[dict[str, Any]]] = {}
    for msg in messages:
        grouped.setdefault(msg.get("channel_id", "unknown"), []).append(msg)
    entity_count = 0
    rel_count = 0
    for cid, bucket in grouped.items():
        lines = [
            f"[{m.get('user_name', '?')} "
            f"({m.get('user_id', '?')})] "
            f"{m.get('text', '')}"
            for m in bucket
            if len(m.get("text", "")) >= _MIN_MESSAGE_LENGTH
        ]
        conversation_text = "\n".join(lines)
        if not conversation_text.strip():
            continue
        try:
            stats = await extract_knowledge(
                conversation_text, openrouter, kg_manager,
                channel_id=cid,
            )
            entity_count += stats.get("entities_added", 0)
            rel_count += stats.get("relationships_added", 0)
        except Exception:
            # One bad channel must not abort the whole run.
            logger.warning(
                "Batch extraction failed for channel %s",
                cid, exc_info=True,
            )
    await _set_last_run(redis)
    return {
        "status": "completed",
        "messages_scanned": len(messages),
        "channels_processed": len(grouped),
        "entities_added": entity_count,
        "relationships_added": rel_count,
    }
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
async def _fetch_recent_messages(
    redis: Any, since: float = 0, limit: int = 100,
) -> list[dict[str, Any]]:
    """Fetch recent messages from channel_msgs ZSETs.

    Uses per-channel sorted sets (channel_msgs:{platform}:{channel_id})
    with ZRANGEBYSCORE to collect messages with timestamp > since, then
    returns the newest *limit* of them (newest first).

    Args:
        redis: Async redis client.
        since: Lower timestamp bound; messages scored strictly greater
            than this are returned (0 means "from the beginning").
        limit: Cap on the number of messages returned.
    """
    from message_cache import get_active_channels
    all_keys: list[str] = []
    async for key in redis.scan_iter(match="channel_msgs:*", count=500):
        k = key.decode() if isinstance(key, bytes) else key
        all_keys.append(k)
    if not all_keys:
        return []
    # Prefer most recently active channels.
    active = await get_active_channels(redis, limit=20)
    if active:
        zset_keys = [
            f"channel_msgs:{platform}:{channel_id}"
            for platform, channel_id in active
        ]
        # Include any channel_msgs keys not in active (e.g. older channels)
        seen = set(zset_keys)
        for k in all_keys:
            if k not in seen:
                zset_keys.append(k)
    else:
        zset_keys = all_keys
    # FIX: the docstring (and the caller's last_run semantics) require a
    # strictly-greater-than bound, but ZRANGEBYSCORE's min is inclusive --
    # a message scored exactly at last_run would be re-processed every run.
    # Redis accepts "(<score>" for an exclusive bound; since == 0 keeps the
    # original inclusive behavior for the very first run.
    min_score: float | str = f"({since}" if since else since
    candidates: list[tuple[str, float]] = []
    fields = (
        "user_id", "user_name", "platform", "channel_id",
        "text", "timestamp", "message_id", "reply_to_id",
    )
    for zset_key in zset_keys:
        try:
            results = await redis.zrangebyscore(
                zset_key, min=min_score, max="+inf",
                withscores=True, start=0, num=limit,
            )
        except Exception:
            # Key may not be a zset (or transient error); skip it.
            continue
        for member, score in results:
            msg_key = member.decode() if isinstance(member, bytes) else member
            candidates.append((msg_key, float(score)))
    # Keep only the newest `limit` messages across all channels.
    candidates.sort(key=lambda x: x[1], reverse=True)
    candidates = candidates[:limit]
    if not candidates:
        return []
    # One pipelined round trip for all message hashes.
    pipe = redis.pipeline()
    for msg_key, _ in candidates:
        pipe.hmget(msg_key, *fields)
    results = await pipe.execute()
    messages: list[dict[str, Any]] = []
    for (msg_key, _), values in zip(candidates, results):
        if not values or all(v is None for v in values):
            # Hash expired or was deleted between the scan and the fetch.
            continue
        mapping = dict(zip(fields, values))
        msg = {
            k: (v.decode() if isinstance(v, bytes) else v)
            for k, v in mapping.items()
        }
        msg["timestamp"] = float(msg.get("timestamp") or 0)
        messages.append(msg)
    return messages
async def _check_rate_limit(
redis: Any, user_id: str, limit: int,
) -> bool:
"""Return True if the user is under the extraction limit."""
key = f"{_RATE_LIMIT_PREFIX}{user_id}"
try:
count = await redis.incr(key)
if count == 1:
await redis.expire(key, _RATE_LIMIT_WINDOW)
return count <= limit
except Exception:
return True
async def _get_last_run(redis: Any) -> float:
"""Internal helper: get last run.
Args:
redis (Any): The redis value.
Returns:
float: The result.
"""
try:
val = await redis.get(_REDIS_LAST_RUN_KEY)
return float(val) if val else 0.0
except Exception:
return 0.0
async def _set_last_run(redis: Any) -> None:
"""Internal helper: set last run.
Args:
redis (Any): The redis value.
"""
try:
await redis.set(_REDIS_LAST_RUN_KEY, str(time.time()))
except Exception:
pass