# Source of module ``kg_extraction``.

"""Knowledge-graph extraction from conversations.

Replaces ``background_agents/auto_memory_extraction.py``.  Uses an LLM
to extract structured entities and relationships from conversation text,
then writes them into the FalkorDB knowledge graph via
:class:`~knowledge_graph.KnowledgeGraphManager`.
"""

from __future__ import annotations

import json
import logging
import re
import time
from typing import Any, TYPE_CHECKING

if TYPE_CHECKING:
    import redis.asyncio as aioredis
    from knowledge_graph import KnowledgeGraphManager
    from openrouter_client import OpenRouterClient

logger = logging.getLogger(__name__)

# Redis key holding the unix timestamp (as str) of the last batch run.
_REDIS_LAST_RUN_KEY = "stargazer:kg_extraction:last_run"
# Prefix for per-user extraction rate-limit counters (suffix = user_id).
_RATE_LIMIT_PREFIX = "stargazer:kg_extract_rate:"
_RATE_LIMIT_WINDOW = 3600  # 1 hour
# Max messages scanned per batch extraction run.
_MESSAGES_PER_RUN = 100
# Messages shorter than this are dropped from batch conversation text.
_MIN_MESSAGE_LENGTH = 20
# Per-message extraction length gate (see extract_from_message).
_MIN_SINGLE_MESSAGE_LENGTH = 100

# Heuristic patterns that signal a message likely contains
# extractable personal facts, preferences, or knowledge.
_SIGNAL_PATTERNS: list[re.Pattern] = [
    re.compile(p, re.IGNORECASE) for p in [
        r"\bi am\b", r"\bi'm\b",
        r"\bmy (name|job|work|favorite|fav)\b",
        r"\bi (work|live|study|moved|started)\b",
        r"\bi (like|love|hate|prefer|enjoy|use)\b",
        r"\bi (know|learned|believe|think that)\b",
        r"\bwe (decided|agreed|use|switched)\b",
        r"\bour (team|project|company|stack)\b",
        r"\bremember (that|this|when)\b",
        r"\bdon'?t forget\b",
        r"\bimportant:\s",
        r"\bfyi\b",
        r"\bannouncement\b",
        r"\brule:\s",
        r"\bpolicy:\s",
    ]
]


def _has_extraction_signal(text: str) -> bool:
    """Return True if *text* matches any knowledge-bearing regex cue."""
    for pattern in _SIGNAL_PATTERNS:
        if pattern.search(text):
            return True
    return False

# Task prompt for conversation-level extraction.  Sent verbatim (followed by
# the conversation text) in extract_knowledge(); also prefixed to the
# per-message prompt in extract_from_message().
EXTRACTION_PROMPT = """\
Analyze this conversation and extract a knowledge graph.

Return JSON with:
{
  "entities": [
    {
      "name": "...",
      "type": "person|concept|preference|fact|event|location|\
organization|project|technology|rule|directive|role",
      "description": "...",
      "category": "user|channel|general|basic",
      "user_id": "optional, required if category=user"
    }
  ],
  "relationships": [
    {
      "source": "entity_name",
      "target": "entity_name",
      "relation": "RELATION_TYPE",
      "description": "...",
      "confidence": 0.0
    }
  ]
}

Entity types:
- person: A human user, contributor, or known individual.
- concept: An abstract idea, topic, or domain of knowledge.
- preference: A stated like, dislike, or preference.
- fact: A concrete piece of information or data point.
- event: Something that happened or will happen.
- location: A physical or virtual place.
- organization: A company, team, group, or institution.
- project: A named project, repo, or initiative.
- technology: A language, framework, tool, or platform.
- rule: An explicit rule, constraint, or policy that governs behavior.
- directive: An instruction or mandate that guides action.
- role: A named role, permission level, or authority designation.

Recommended relation types (you may also use other descriptive \
UPPER_SNAKE_CASE types):
  Personal: LIKES, DISLIKES, KNOWS, PREFERS, SKILLED_IN, SAID
  Organizational: WORKS_AT, MEMBER_OF, OWNS, CREATED
  Structural: PART_OF, RELATED_TO, USES, LOCATED_IN, IS_A, HAS_PROPERTY
  Governance: ENFORCES, PERMITS, PROHIBITS, SUPERSEDES, DEPENDS_ON
  Temporal: PRECEDED_BY, FOLLOWED_BY, CAUSED

Note: relationships can cross categories. If a user entity relates to an
already-existing entity of any category, just reference it by name.
The system will resolve the target entity across all categories.

Category rules:
- "user": Facts about a specific user (preferences, \
skills, personal info). Requires user_id.
- "channel": Facts relevant to this specific channel's context.
- "general": Facts that apply broadly \
(not user-specific, not channel-specific). Default category.
- "basic": Fundamental, identity-level facts that should always \
be available (e.g. who the owner is, what the system is called). \
Use sparingly -- only for knowledge that is always relevant.
- NEVER use "core" or "guild" -- these are admin-only categories.

Other rules:
- Only extract genuinely important, persistent facts
- Do NOT extract transient chat or greetings
- Default to "general" unless another category clearly fits
- confidence: 1.0 = explicitly stated, 0.5 = implied, 0.3 = inferred
"""

# Per-message prompt template; str.format()-ed with user/channel/message
# fields and appended after EXTRACTION_PROMPT in extract_from_message().
SINGLE_MESSAGE_PROMPT = """\
Extract any important knowledge from this single message.

Speaker: {user_name} (ID: {user_id})
Channel: {channel_id}
Message: {message_text}

Return JSON with the same format as above. Only extract if there are
genuinely important facts. Return empty lists if nothing noteworthy.
"""

_VALID_CATEGORIES = {"user", "channel", "general", "basic"}
_TYPE_MAP = {
    "person": "Person",
    "concept": "Concept",
    "preference": "Preference",
    "fact": "Fact",
    "event": "Event",
    "location": "Location",
    "organization": "Organization",
    "project": "Project",
    "technology": "Technology",
    "rule": "Rule",
    "directive": "Directive",
    "role": "Role",
}


def _parse_llm_json(raw: str) -> dict:
    """Best-effort JSON parsing from LLM output."""
    raw = raw.strip()
    # Strip <thinking>...</thinking> blocks (Gemini extended thinking)
    thinking_end = raw.find("</thinking>")
    if thinking_end != -1:
        raw = raw[thinking_end + len("</thinking>"):].strip()
    if raw.startswith("```"):
        raw = raw.split("\n", 1)[-1].rsplit("```", 1)[0].strip()
    # Last resort: find the first '{' if leading text remains
    if raw and raw[0] != "{":
        brace = raw.find("{")
        if brace != -1:
            raw = raw[brace:]
    return json.loads(raw)


# Placeholder user ID used as the default when no real user is associated
# with an extraction call.
_SENTINEL_USER_ID = "000000000000"


async def apply_parsed_extraction(
    data: dict[str, Any],
    kg_manager: KnowledgeGraphManager,
    channel_id: str,
    user_id: str = _SENTINEL_USER_ID,
    created_by: str = "system:extraction",
) -> dict[str, Any]:
    """Apply LLM extraction JSON (entities + relationships) to FalkorDB.

    Each entity may include ``existing_uuid`` — if set, that UUID is
    registered for relationship resolution and no new node is created.
    Per-entity ``user_id`` overrides the default *user_id* for category
    ``user`` scope resolution.

    Args:
        data: Parsed LLM output with ``entities`` and ``relationships`` lists.
        kg_manager: Graph manager used to embed, resolve, and link entities.
        channel_id: Scope ID used for entities in category ``channel``.
        user_id: Default scope/owner for category ``user`` entities.
        created_by: Provenance tag stored on created entities.

    Returns:
        Stats dict: ``entities_added``, ``relationships_added``, ``errors``.
    """
    entities = data.get("entities", [])
    relationships = data.get("relationships", [])
    stats: dict[str, Any] = {
        "entities_added": 0,
        "relationships_added": 0,
        "errors": 0,
    }
    # name (lowercased) -> uuid, used to resolve relationship endpoints.
    entity_uuid_lookup: dict[str, str] = {}

    # Pass 1: validate/normalize entities and collect embedding texts so the
    # embeddings can be computed in a single batch call.
    prepared: list[tuple[str, str, str, str, str, str]] = []
    embed_texts: list[str] = []
    for ent in entities:
        name = ent.get("name", "").strip()
        existing_uuid = str(ent.get("existing_uuid", "") or "").strip()
        if existing_uuid:
            # Reference-only entity: register for relationships, skip create.
            if name:
                entity_uuid_lookup[name.lower()] = existing_uuid
            continue
        action = str(ent.get("action", "create") or "create").lower()
        if action in ("skip", "omit", "reference_only") and not existing_uuid:
            continue
        if not name:
            continue
        raw_type = ent.get("type", "fact").lower()
        etype = _TYPE_MAP.get(raw_type, "Fact")
        description = ent.get("description", "")
        category = ent.get("category", "general")
        # Silently drop invalid (incl. admin-only) categories.
        if category not in _VALID_CATEGORIES:
            continue
        ent_user = str(ent.get("user_id", "") or "").strip()
        # Per-entity user_id wins over the call-level default for user scope.
        uid_for_resolve = ent_user if category == "user" and ent_user else user_id
        if category == "user":
            scope_id = uid_for_resolve or "_"
        elif category == "channel":
            scope_id = channel_id
        else:
            # "general" and "basic" share the global "_" scope.
            scope_id = "_"
        prepared.append(
            (name, etype, description, category, scope_id, uid_for_resolve),
        )
        embed_texts.append(
            f"{name}: {description}" if description else name,
        )

    # Batch-embed all entity texts; on failure fall back to per-entity
    # embedding inside _resolve_or_create (vec=None).
    vectors: list[list[float] | None] = [None] * len(prepared)
    if embed_texts:
        try:
            vectors = await kg_manager._embed_batch(embed_texts)
        except Exception:
            logger.warning(
                "Batch embedding failed for %d entities, "
                "falling back to per-entity",
                len(embed_texts),
                exc_info=True,
            )
            vectors = [None] * len(prepared)

    # Pass 2: resolve-or-create each prepared entity.
    for (name, etype, description, category, scope_id, uid_res), vec in zip(
        prepared, vectors,
    ):
        try:
            info = await kg_manager._resolve_or_create(
                name, etype, category, scope_id,
                description=description,
                created_by=created_by,
                user_id=uid_res,
                embedding=vec,
            )
            entity_uuid_lookup[name.lower()] = info["uuid"]
            stats["entities_added"] += 1
        except Exception:
            logger.warning("Entity extraction error for %s", name, exc_info=True)
            stats["errors"] += 1

    # Pass 3: relationships.  Endpoint resolution order: explicit UUID from
    # the LLM -> entity created/registered above -> cross-category lookup.
    for rel in relationships:
        try:
            src_name = rel.get("source", "").strip()
            tgt_name = rel.get("target", "").strip()
            relation = rel.get("relation", "RELATED_TO").upper()
            desc = rel.get("description", "")
            confidence = float(rel.get("confidence", 0.5))
            src_uuid = str(rel.get("source_uuid", "") or "").strip()
            tgt_uuid = str(rel.get("target_uuid", "") or "").strip()
            if not src_uuid and src_name:
                src_uuid = entity_uuid_lookup.get(src_name.lower())
            if not tgt_uuid and tgt_name:
                tgt_uuid = entity_uuid_lookup.get(tgt_name.lower())
            if not src_uuid and src_name:
                src_uuid = await _resolve_uuid(kg_manager, src_name) or ""
            if not tgt_uuid and tgt_name:
                tgt_uuid = await _resolve_uuid(kg_manager, tgt_name) or ""
            # Unresolvable endpoints are skipped, not counted as errors.
            if not src_uuid or not tgt_uuid:
                continue
            await kg_manager.add_relationship(
                src_uuid, tgt_uuid, relation,
                weight=confidence, description=desc,
            )
            stats["relationships_added"] += 1
        except Exception:
            logger.warning(
                "Relationship extraction error for %s", rel, exc_info=True,
            )
            stats["errors"] += 1
    return stats
async def extract_knowledge(
    conversation: str,
    openrouter: OpenRouterClient,
    kg_manager: KnowledgeGraphManager,
    channel_id: str,
    guild_id: str | None = None,
    user_id: str = _SENTINEL_USER_ID,
    conversation_char_limit: int | None = 4000,
) -> dict[str, Any]:
    """Full extraction pipeline for a block of conversation text.

    1. Call LLM with extraction prompt
    2. Parse structured JSON
    3. Validate categories (reject core/guild)
    4. Resolve or create each entity
    5. Create/reinforce each relationship
    6. Return stats

    Args:
        conversation: Raw conversation text to analyze.
        openrouter: LLM client used for the extraction call.
        kg_manager: Graph manager that entities/relationships are written to.
        channel_id: Scope ID for channel-category entities.
        guild_id: Accepted for interface compatibility; not used here.
        user_id: Default scope/owner for user-category entities.
        conversation_char_limit: If set, truncate *conversation* to this
            many characters before sending to the LLM.  ``None`` means no
            truncation.

    Returns:
        Stats dict from :func:`apply_parsed_extraction`, or an all-zero
        dict with ``errors=1`` if the LLM call or JSON parse fails.
    """
    body = conversation
    if conversation_char_limit is not None:
        body = conversation[:conversation_char_limit]
    prompt = (
        EXTRACTION_PROMPT + "\n\nConversation:\n" + body + "\n\nJSON:"
    )
    sys_msg = (
        "You extract structured knowledge graphs "
        "from conversations. Output only valid JSON."
    )
    msgs = [
        {"role": "system", "content": sys_msg},
        {"role": "user", "content": prompt},
    ]
    try:
        raw = await openrouter.chat(msgs)
        data = _parse_llm_json(raw)
    # Was `except (json.JSONDecodeError, Exception)` — JSONDecodeError is a
    # subclass of Exception, so the tuple was redundant.  Any LLM/network or
    # parse failure degrades to an error stat rather than propagating.
    except Exception:
        logger.warning(
            "KG extraction LLM parse failed",
            exc_info=True,
        )
        return {
            "entities_added": 0,
            "relationships_added": 0,
            "errors": 1,
        }
    return await apply_parsed_extraction(
        data,
        kg_manager,
        channel_id,
        user_id=user_id,
        created_by="system:extraction",
    )
# Labels probed (in priority order) when resolving a bare entity name that
# the current extraction did not itself create.
_CROSS_LABELS = [
    "Fact", "Person", "Organization", "Technology", "Project",
    "Concept", "Rule", "Directive", "Role",
]


async def _resolve_uuid(
    kg_manager: KnowledgeGraphManager,
    name: str,
) -> str | None:
    """Find an entity by name across all labels and return its UUID.

    Tries the manager's cross-category resolver for each label in
    ``_CROSS_LABELS`` first, then falls back to a direct graph query via
    :func:`_guess_uuid`.  Returns ``None`` when nothing matches.
    """
    name_lower = name.strip().lower()
    for label in _CROSS_LABELS:
        cross = (
            await kg_manager.resolve_entity_cross_category(
                name_lower, label,
            )
        )
        if cross and cross.get("uuid"):
            return cross["uuid"]
    return await _guess_uuid(kg_manager, name_lower)


async def _guess_uuid(
    kg_manager: KnowledgeGraphManager,
    name: str,
) -> str | None:
    """Try to find an entity's UUID by exact name match across all labels.

    Last-resort lookup: runs one Cypher query per label in ENTITY_LABELS;
    per-label failures are swallowed so one bad label cannot abort the scan.
    Returns ``None`` if no label yields a match.
    """
    # Imported lazily to avoid a circular import at module load time.
    from knowledge_graph import ENTITY_LABELS

    name_lower = name.strip().lower()
    for label in ENTITY_LABELS:
        try:
            q = (
                f"MATCH (e:{label}) "
                f"WHERE e.name = $name "
                f"RETURN e.uuid LIMIT 1"
            )
            result = await kg_manager._graph.query(
                q, params={"name": name_lower},
            )
            if result.result_set and result.result_set[0][0]:
                return result.result_set[0][0]
        except Exception:
            continue
    return None
async def extract_from_message(
    message_text: str,
    user_id: str,
    user_name: str,
    channel_id: str,
    guild_id: str | None,
    openrouter: OpenRouterClient,
    kg_manager: KnowledgeGraphManager,
    redis: aioredis.Redis | None = None,
    per_user_limit: int = 5,
) -> None:
    """Per-message extraction with cheap pre-filtering.

    This function is called fire-and-forget but is gated by three layers
    *before* any LLM call is made:

    1. **Length gate** -- message must be >= 100 chars.
    2. **Heuristic gate** -- message must match at least one regex
       pattern that signals knowledge content.
    3. **Rate limit** -- max *per_user_limit* extractions per user per
       hour (via Redis INCR + EXPIRE).
    """
    if len(message_text) < _MIN_SINGLE_MESSAGE_LENGTH:
        return
    if not _has_extraction_signal(message_text):
        return
    # No redis -> no rate limiting; extraction proceeds unthrottled.
    if redis is not None:
        if not await _check_rate_limit(
            redis, user_id, per_user_limit,
        ):
            return

    prompt = SINGLE_MESSAGE_PROMPT.format(
        user_name=user_name,
        user_id=user_id,
        channel_id=channel_id,
        # Cap the message excerpt sent to the LLM.
        message_text=message_text[:2000],
    )
    full_prompt = EXTRACTION_PROMPT + "\n\n" + prompt + "\n\nJSON:"
    sys_content = (
        "You extract structured knowledge graphs "
        "from messages. Output only valid JSON."
    )
    msgs = [
        {"role": "system", "content": sys_content},
        {"role": "user", "content": full_prompt},
    ]
    try:
        raw = await openrouter.chat(msgs)
        data = _parse_llm_json(raw)
    except Exception:
        logger.warning("Per-message extraction parse failed", exc_info=True)
        return

    entities = data.get("entities", [])
    relationships = data.get("relationships", [])
    # name (lowercased) -> uuid for relationship endpoint resolution.
    entity_uuid_lookup: dict[str, str] = {}

    # Validate/normalize entities and gather texts for one batch embed call.
    prepared: list[tuple[str, str, str, str, str]] = []
    embed_texts: list[str] = []
    for ent in entities:
        name = ent.get("name", "").strip()
        if not name:
            continue
        raw_type = ent.get("type", "fact").lower()
        etype = _TYPE_MAP.get(raw_type, "Fact")
        description = ent.get("description", "")
        category = ent.get("category", "general")
        if category not in _VALID_CATEGORIES:
            continue
        if category == "user":
            scope_id = user_id
        elif category == "channel":
            scope_id = channel_id
        else:
            scope_id = "_"
        prepared.append((name, etype, description, category, scope_id))
        embed_texts.append(
            f"{name}: {description}" if description else name,
        )

    vectors: list[list[float] | None] = [None] * len(prepared)
    if embed_texts:
        try:
            vectors = await kg_manager._embed_batch(embed_texts)
        except Exception:
            logger.warning(
                "Per-message batch embedding failed for %d entities, "
                "falling back to per-entity",
                len(embed_texts),
                exc_info=True,
            )
            vectors = [None] * len(prepared)

    for (name, etype, description, category, scope_id), vec in zip(
        prepared, vectors,
    ):
        try:
            info = await kg_manager._resolve_or_create(
                name, etype, category, scope_id,
                description=description,
                created_by=f"system:msg_extraction:{user_id}",
                user_id=user_id,
                embedding=vec,
            )
            entity_uuid_lookup[name.lower()] = info["uuid"]
        except Exception:
            logger.warning("Per-message entity error", exc_info=True)

    # Relationships: resolve endpoints from this run's entities first, then
    # fall back to a cross-category graph lookup.
    for rel in relationships:
        try:
            src_name = rel.get("source", "").strip()
            tgt_name = rel.get("target", "").strip()
            relation = rel.get("relation", "RELATED_TO").upper()
            desc = rel.get("description", "")
            confidence = float(rel.get("confidence", 0.5))
            if not src_name or not tgt_name:
                continue
            src_uuid = entity_uuid_lookup.get(src_name.lower())
            tgt_uuid = entity_uuid_lookup.get(tgt_name.lower())
            if not src_uuid:
                src_uuid = await _resolve_uuid(
                    kg_manager, src_name,
                )
            if not tgt_uuid:
                tgt_uuid = await _resolve_uuid(
                    kg_manager, tgt_name,
                )
            if not src_uuid or not tgt_uuid:
                continue
            await kg_manager.add_relationship(
                src_uuid, tgt_uuid, relation,
                weight=confidence, description=desc,
            )
        except Exception:
            logger.warning("Per-message relationship error", exc_info=True)
# ---------------------------------------------------------------------------
# Batch extraction (periodic background task)
# ---------------------------------------------------------------------------
async def run_batch_extraction(
    redis: Any,
    kg_manager: KnowledgeGraphManager,
    openrouter: OpenRouterClient,
    messages_limit: int = _MESSAGES_PER_RUN,
) -> dict[str, Any]:
    """Scan recent messages and extract KG entities.

    Called periodically by the background scheduler.  Fetches messages newer
    than the stored last-run timestamp, groups them per channel, runs
    :func:`extract_knowledge` on each channel's conversation, and advances
    the last-run marker only on success (so a failed fetch is retried).

    Returns:
        A status dict; ``status`` is ``disabled``, ``error``, or
        ``completed``, plus scan/extraction counters.
    """
    # All three dependencies are required; otherwise the task is a no-op.
    if not all((redis, kg_manager, openrouter)):
        return {
            "status": "disabled",
            "reason": "missing dependencies",
        }

    last_run = await _get_last_run(redis)
    try:
        messages = await _fetch_recent_messages(
            redis, since=last_run, limit=messages_limit,
        )
    except Exception:
        # Deliberately do NOT advance last_run so these messages are
        # retried on the next scheduled run.
        logger.warning(
            "KG batch extraction: fetch failed, not advancing last_run",
            exc_info=True,
        )
        return {
            "status": "error",
            "reason": "fetch_failed",
            "messages_scanned": 0,
            "entities_added": 0,
            "relationships_added": 0,
        }
    if not messages:
        await _set_last_run(redis)
        return {
            "status": "completed",
            "messages_scanned": 0,
            "entities_added": 0,
            "relationships_added": 0,
        }

    # Group messages per channel so each channel gets one coherent
    # conversation blob (and one LLM call).
    by_channel: dict[str, list[dict[str, Any]]] = {}
    for m in messages:
        cid = m.get("channel_id", "unknown")
        by_channel.setdefault(cid, []).append(m)

    total_entities = 0
    total_rels = 0
    for channel_id, channel_msgs in by_channel.items():
        conversation_text = "\n".join(
            f"[{m.get('user_name', '?')} "
            f"({m.get('user_id', '?')})] "
            f"{m.get('text', '')}"
            for m in channel_msgs
            # Drop very short messages (greetings etc.) from the blob.
            if len(m.get("text", "")) >= _MIN_MESSAGE_LENGTH
        )
        if not conversation_text.strip():
            continue
        try:
            stats = await extract_knowledge(
                conversation_text,
                openrouter,
                kg_manager,
                channel_id=channel_id,
            )
            total_entities += stats.get("entities_added", 0)
            total_rels += stats.get("relationships_added", 0)
        except Exception:
            # One bad channel must not abort the rest of the batch.
            logger.warning(
                "Batch extraction failed for channel %s",
                channel_id,
                exc_info=True,
            )

    await _set_last_run(redis)
    return {
        "status": "completed",
        "messages_scanned": len(messages),
        "channels_processed": len(by_channel),
        "entities_added": total_entities,
        "relationships_added": total_rels,
    }
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


async def _fetch_recent_messages(
    redis: Any,
    since: float = 0,
    limit: int = 100,
) -> list[dict[str, Any]]:
    """Fetch recent messages from channel_msgs ZSETs.

    Uses per-channel sorted sets (channel_msgs:{platform}:{channel_id})
    with ZRANGEBYSCORE to get messages with timestamp > since in
    chronological order. Ensures messages are fetched reliably and none
    are skipped.
    """
    # Imported lazily to avoid a circular import at module load time.
    from message_cache import get_active_channels

    # Enumerate all per-channel ZSET keys.
    all_keys: list[str] = []
    async for key in redis.scan_iter(match="channel_msgs:*", count=500):
        k = key.decode() if isinstance(key, bytes) else key
        all_keys.append(k)
    if not all_keys:
        return []

    # Prefer most recently active channels
    active = await get_active_channels(redis, limit=20)
    if active:
        zset_keys = [
            f"channel_msgs:{platform}:{channel_id}"
            for platform, channel_id in active
        ]
        # Include any channel_msgs keys not in active (e.g. older channels)
        seen = set(zset_keys)
        for k in all_keys:
            if k not in seen:
                zset_keys.append(k)
    else:
        zset_keys = all_keys

    # Collect (message hash key, timestamp score) candidates per channel.
    candidates: list[tuple[str, float]] = []
    fields = (
        "user_id", "user_name", "platform", "channel_id",
        "text", "timestamp", "message_id", "reply_to_id",
    )
    for zset_key in zset_keys:
        try:
            # NOTE(review): min=since is INCLUSIVE; docstring says "> since".
            # A message scored exactly at `since` may be fetched twice —
            # confirm whether exclusive "(since" is intended.
            results = await redis.zrangebyscore(
                zset_key, min=since, max="+inf",
                withscores=True, start=0, num=limit,
            )
        except Exception:
            continue
        for member, score in results:
            msg_key = member.decode() if isinstance(member, bytes) else member
            candidates.append((msg_key, float(score)))

    # Keep only the `limit` newest messages overall.
    candidates.sort(key=lambda x: x[1], reverse=True)
    candidates = candidates[:limit]
    if not candidates:
        return []

    # Pipeline all HMGETs into a single round trip.
    pipe = redis.pipeline()
    for msg_key, _ in candidates:
        pipe.hmget(msg_key, *fields)
    results = await pipe.execute()

    messages: list[dict[str, Any]] = []
    for (msg_key, _), values in zip(candidates, results):
        # Skip hashes that have expired or were deleted.
        if not values or all(v is None for v in values):
            continue
        mapping = dict(zip(fields, values))
        msg = {
            k: (v.decode() if isinstance(v, bytes) else v)
            for k, v in mapping.items()
        }
        msg["timestamp"] = float(msg.get("timestamp") or 0)
        messages.append(msg)
    return messages


async def _check_rate_limit(
    redis: Any,
    user_id: str,
    limit: int,
) -> bool:
    """Return True if the user is under the extraction limit.

    Classic INCR-then-EXPIRE counter: the first increment in a window
    arms the TTL.  Fails open (returns True) if Redis errors out.
    """
    key = f"{_RATE_LIMIT_PREFIX}{user_id}"
    try:
        count = await redis.incr(key)
        if count == 1:
            await redis.expire(key, _RATE_LIMIT_WINDOW)
        return count <= limit
    except Exception:
        return True


async def _get_last_run(redis: Any) -> float:
    """Return the stored last-run unix timestamp, or 0.0 on any failure.

    Args:
        redis (Any): Async Redis client.

    Returns:
        float: Timestamp of the last batch run; 0.0 if unset or on error.
    """
    try:
        val = await redis.get(_REDIS_LAST_RUN_KEY)
        return float(val) if val else 0.0
    except Exception:
        return 0.0


async def _set_last_run(redis: Any) -> None:
    """Store the current time as the last-run marker (best effort).

    Args:
        redis (Any): Async Redis client.
    """
    try:
        await redis.set(_REDIS_LAST_RUN_KEY, str(time.time()))
    except Exception:
        # Best-effort: a missed marker only causes messages to be rescanned.
        pass