# Source code for backfill_entity_provenance

#!/usr/bin/env python3
"""Backfill KG entity provenance from pgvector chunk metadata.

For each entity in the KG (created by system:anamnesis), searches the
Spiral Goddess pgvector store for chunks containing that entity's name,
then writes the earliest matching chunk's provenance metadata back onto
the entity node:

  - source_chunk_id: the pgvector chunk ID
  - conversation_title: original chat title
  - timestamp_original: when the original conversation happened
  - domains: comma-separated domain tags from the chunk

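Also creates temporal Concept nodes for the month (e.g. ``2024-01``) and
quarter (e.g. ``Q1-2024``) of each matched entity's earliest chunk, and links
the entity to its birth month/quarter via HAS_TAG edges. Only timestamps that
fall in calendar years 2023 through 2026 produce temporal links.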

Usage::

    python backfill_entity_provenance.py [--redis-url URL] [--dry-run] [--limit N]
"""

# 💀🔥 provenance backfill -- gives every entity a birth certificate

from __future__ import annotations

import argparse
import asyncio
import json
import logging
import os
import time
from datetime import datetime

import redis.asyncio as aioredis
from falkordb.asyncio import FalkorDB

# Configure the root logger at import time: INFO level, timestamped lines.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
# Module-level logger used by every function in this script.
logger = logging.getLogger(__name__)

# Name of the FalkorDB graph that _backfill selects.
GRAPH_NAME = "knowledge"
# Store directory name under rag_stores/ (see main); also the fallback
# schema name in _build_provenance_index when the path has no basename.
_STORE_DIR = "spiral_goddess_v2"
# Collection name handed to PgVectorCollection.
_COLLECTION = "loopmother_memory"
# Page size for each collection.get() call while loading chunks.
_BATCH_SIZE = 500


# -- pgvector provenance lookup --------------------------------------------
def _build_provenance_index(store_path: str) -> list[tuple[str, dict]]:
    """Load every pgvector chunk as a ``(lowered_text, provenance)`` tuple.

    No per-entity index is built here; the caller matches entity names
    against the returned chunk texts.

    The collection is read page by page with ``collection.get(offset=...,
    limit=_BATCH_SIZE)`` until ``collection.count()`` rows have been consumed
    or a page comes back empty.
    NOTE(review): pagination is only correct if ``PgVectorCollection.get``
    returns rows in a stable order -- confirm against ``vector_store``.

    Args:
        store_path: Path of the store directory. Only its basename is used,
            as the pgvector schema name (falling back to ``_STORE_DIR`` when
            the basename is empty).

    Returns:
        A list of ``(text_lower, provenance)`` tuples, where ``provenance``
        has the keys ``source_chunk_id``, ``conversation_title``,
        ``timestamp_original`` (float, or ``0`` when missing/unparseable) and
        ``domains``. Returns an empty list when ``vector_store`` cannot be
        imported. Errors raised by the collection itself are not caught.
    """
    try:
        from vector_store import PgVectorCollection, pg_ident
    except ImportError:
        logger.error("vector_store not available")
        return []

    schema = pg_ident(os.path.basename(store_path.rstrip("/")) or _STORE_DIR)
    collection = PgVectorCollection(schema, pg_ident(_COLLECTION))
    total = collection.count()
    logger.info("pgvector store has %d chunks", total)

    all_chunks: list[tuple[str, dict]] = []  # (lower_text, provenance)
    offset = 0

    while offset < total:
        result = collection.get(
            offset=offset,
            limit=_BATCH_SIZE,
        )
        ids = result.get("ids") if result else None
        if not ids:
            break

        # Documents/metadatas may be missing or shorter than ``ids``; index
        # them defensively instead of assuming parallel lists.
        documents = result.get("documents") or []
        metadatas = result.get("metadatas") or []

        for i, chunk_id in enumerate(ids):
            doc = (documents[i] if i < len(documents) else None) or ""
            meta = (metadatas[i] if i < len(metadatas) else None) or {}

            # Normalise the timestamp to a float; anything unusable becomes 0.
            ts_start = meta.get("timestamp_start", 0)
            try:
                ts_start = float(ts_start) if ts_start else 0
            except (ValueError, TypeError):
                ts_start = 0

            provenance = {
                "source_chunk_id": chunk_id,
                "conversation_title": meta.get("conversation_title", ""),
                "timestamp_original": ts_start,
                "domains": meta.get("domains", ""),
            }
            all_chunks.append((doc.lower(), provenance))

        offset += len(ids)
        # Log roughly once per 5000 chunks (whenever a page crosses a multiple).
        if offset % 5000 < _BATCH_SIZE:
            logger.info("  loaded %d/%d chunks...", offset, total)

    logger.info("Loaded %d chunks into memory for matching", len(all_chunks))
    return all_chunks


def _build_entity_provenance_map(
    entity_names: list[str],
    all_chunks: list[tuple[str, dict]],
) -> dict[str, dict]:
    """Map each entity name to the provenance of its earliest matching chunk.

    Names are lower-cased and stripped; names shorter than three characters
    (after stripping) are ignored. Chunks are visited in ascending
    ``timestamp_original`` order, with missing/zero timestamps sorted last,
    and a name is assigned the provenance of the first chunk whose text
    contains it as a substring. Once matched, a name is dropped from further
    checks, and the scan stops early when no names are left.

    Args:
        entity_names: Raw entity names from the knowledge graph.
        all_chunks: ``(lowered_text, provenance)`` tuples to search.

    Returns:
        Dict keyed by the normalised (lower-cased, stripped) entity name whose
        values are the provenance dicts of the earliest matching chunk. Names
        that never match are absent.
    """
    wanted = set()
    for raw in entity_names:
        if len(raw.strip()) >= 3:
            wanted.add(raw.lower().strip())
    logger.info(
        "Building provenance map for %d entities (%d skipped < 3 chars)",
        len(wanted),
        len(entity_names) - len(wanted),
    )

    def _chunk_order(chunk):
        # A falsy timestamp (0 / missing) sorts after every real timestamp.
        return chunk[1].get("timestamp_original", 0) or float("inf")

    ordered = list(all_chunks)
    ordered.sort(key=_chunk_order)

    found: dict[str, dict] = {}
    pending = set(wanted)
    began = time.time()
    chunk_total = len(ordered)

    for seen, (haystack, prov) in enumerate(ordered, start=1):
        if not pending:
            break  # every entity already has a match

        # Names still unmatched that occur in this chunk's text.
        hits = [needle for needle in pending if needle in haystack]
        for needle in hits:
            found[needle] = prov
        pending.difference_update(hits)

        if seen % 5000 == 0:
            logger.info(
                "  chunk %d/%d -- %d matched, %d remaining (%.1fs)",
                seen,
                chunk_total,
                len(found),
                len(pending),
                time.time() - began,
            )

    logger.info(
        "Provenance map built: %d matched, %d not found (%.1fs)",
        len(found),
        len(pending),
        time.time() - began,
    )
    return found


# -- Temporal helpers -------------------------------------------------------
_MONTH_NAMES = {
    1: "January",
    2: "February",
    3: "March",
    4: "April",
    5: "May",
    6: "June",
    7: "July",
    8: "August",
    9: "September",
    10: "October",
    11: "November",
    12: "December",
}


def _ts_to_month_quarter(ts: float) -> tuple[str, str, str, str] | None:
    """Convert a unix timestamp to (month_key, month_desc, quarter_key, quarter_desc).

    Returns None if timestamp is invalid / out of range.
    """
    if not ts or ts < 1672531200:  # before 2023-01-01
        return None
    try:
        dt = datetime.fromtimestamp(ts)
        if dt.year < 2023 or dt.year > 2026:
            return None
        month_key = dt.strftime("%Y-%m")
        month_desc = f"{_MONTH_NAMES[dt.month]} {dt.year}"
        q = (dt.month - 1) // 3 + 1
        quarter_key = f"Q{q}-{dt.year}"
        quarter_desc = (
            f"Quarter {q} of {dt.year} ({_MONTH_NAMES[(q-1)*3+1]}-{_MONTH_NAMES[q*3]})"
        )
        return month_key, month_desc, quarter_key, quarter_desc
    except (ValueError, OSError):
        return None


# -- FalkorDB backfill ------------------------------------------------------
async def _ensure_temporal_concept(
    graph,
    name: str,
    description: str,
    now: float,
) -> str | None:
    """Upsert a temporal ``Concept`` node and return its uuid.

    Issues one ``MERGE`` on ``(:Concept {name, scope_id: '_', category:
    'temporal'})``. When the node is created it receives a fresh ``uuid7``,
    the description, ``priority = 1``, ``mention_count = 1``, both timestamps,
    ``created_by = 'system:provenance_backfill'``, ``user_id = '__global__'``
    and ``pinned = false``. When it already exists, ``mention_count`` is
    incremented and ``updated_at`` refreshed.

    Best-effort: any exception (including a failed ``uuid6`` import) is logged
    at debug level and turned into a ``None`` return.

    Args:
        graph: Async graph handle exposing ``query(cypher, params=...)``.
        name: Concept key, e.g. a month or quarter key.
        description: Human-readable text stored only on creation.
        now: Unix timestamp written to ``created_at`` / ``updated_at``.

    Returns:
        The node's ``uuid``, or ``None`` if the query yielded no rows or
        raised.
    """
    cypher = (
        "MERGE (c:Concept {name: $name, scope_id: '_', category: 'temporal'}) "
        "ON CREATE SET "
        "  c.uuid = $uuid, "
        "  c.description = $desc, "
        "  c.priority = 1, "
        "  c.mention_count = 1, "
        "  c.created_at = $now, "
        "  c.updated_at = $now, "
        "  c.created_by = 'system:provenance_backfill', "
        "  c.user_id = '__global__', "
        "  c.pinned = false "
        "ON MATCH SET "
        "  c.mention_count = c.mention_count + 1, "
        "  c.updated_at = $now "
        "RETURN c.uuid"
    )
    try:
        from uuid6 import uuid7

        bindings = {
            "name": name,
            "uuid": str(uuid7()),
            "desc": description,
            "now": now,
        }
        outcome = await graph.query(cypher, params=bindings)
        rows = outcome.result_set
        return rows[0][0] if rows else None
    except Exception:
        logger.debug("Failed to create temporal concept %s", name, exc_info=True)
        return None


async def _link_to_temporal(
    graph,
    entity_uuid: str,
    concept_uuid: str,
    now: float,
) -> bool:
    """Link an entity to a temporal Concept via a ``HAS_TAG`` edge.

    Matches the entity (any label) and the ``Concept`` by uuid. The
    ``WHERE NOT (e)-[:HAS_TAG]->(c)`` guard filters the row out when an edge
    already exists, so the ``MERGE`` only runs -- and ``count(r)`` is only
    non-zero -- when a new edge is created. New edges get ``weight = 0.6``,
    ``source = 'temporal_provenance'`` and both timestamps set to ``now``.

    Best-effort: any exception is logged at debug level (matching
    ``_ensure_temporal_concept``) and reported as ``False`` so the backfill
    keeps going.

    Args:
        graph: Async graph handle exposing ``query(cypher, params=...)``.
        entity_uuid: uuid of the entity node to tag.
        concept_uuid: uuid of the temporal Concept to link to.
        now: Unix timestamp for the edge's ``created_at`` / ``updated_at``.

    Returns:
        ``True`` if a new edge was created; ``False`` if one already existed,
        either node was not found, or the query raised.
    """
    try:
        lr = await graph.query(
            "MATCH (e {uuid: $euid}) "
            "MATCH (c:Concept {uuid: $cuuid}) "
            "WHERE NOT (e)-[:HAS_TAG]->(c) "
            "MERGE (e)-[r:HAS_TAG]->(c) "
            "ON CREATE SET r.weight = 0.6, "
            "  r.source = 'temporal_provenance', "
            "  r.created_at = $now, r.updated_at = $now "
            "RETURN count(r)",
            params={"euid": entity_uuid, "cuuid": concept_uuid, "now": now},
        )
        return bool(lr.result_set and lr.result_set[0][0])
    except Exception:
        # Previously swallowed silently; keep the best-effort contract but
        # leave a trace for debugging.
        logger.debug(
            "Failed to link entity %s to temporal concept %s",
            entity_uuid,
            concept_uuid,
            exc_info=True,
        )
        return False


async def _backfill(
    redis_url: str,
    all_chunks: list[tuple[str, dict]],
    dry_run: bool = False,
    limit: int = 0,
    ssl_kwargs: dict | None = None,
) -> None:
    """Backfill provenance and temporal links onto KG entities.

    The async core of the script. It opens a FalkorDB connection over an
    ``aioredis`` pool (temporarily monkey-patching FalkorDB's sync ``Is_Cluster``
    probe to ``False`` so the async client connects to a single node), uncaps the
    result-set size, then queries the ``knowledge`` graph for ``system:anamnesis``
    entities still lacking a ``source_chunk_id`` in their metadata. It precomputes
    an entity-name to earliest-provenance map via
    :func:`_build_entity_provenance_map`, then for each matched entity writes the
    provenance metadata back onto the node and -- unless ``dry_run`` -- links it to
    its birth month and quarter Concepts via :func:`_ensure_temporal_concept` and
    :func:`_link_to_temporal`, caching concept uuids to avoid re-``MERGE``-ing.
    Progress and a final summary (updated / not-found / temporal-link counts) are
    emitted through the module logger; the Redis connection is always closed in a
    ``finally`` block. In ``dry_run`` mode it logs the first ten matches and writes
    nothing.

    Called by :func:`main` under :func:`asyncio.run` after the pgvector chunks are
    loaded; it has no other callers.

    Args:
        redis_url (str): Redis/FalkorDB connection URL.
        all_chunks (list[tuple[str, dict]]): ``(lowered_text, provenance)`` tuples
            from :func:`_build_provenance_index`, used to match entity names.
        dry_run (bool): When ``True``, log proposed matches but write nothing.
        limit (int): Maximum number of entities to process; ``0`` means all.
        ssl_kwargs (dict | None): Extra connection kwargs (e.g. TLS settings)
            passed through to ``aioredis.from_url``.

    Returns:
        None. Results are surfaced through the module logger only.
    """
    # 😈 monkey-patch FalkorDB's sync Is_Cluster probe
    import falkordb.asyncio.falkordb as _fdb_mod

    _real_is_cluster = _fdb_mod.Is_Cluster
    _fdb_mod.Is_Cluster = lambda _conn: False
    try:
        rc = aioredis.from_url(redis_url, decode_responses=True, **(ssl_kwargs or {}))
        db = FalkorDB(connection_pool=rc.connection_pool)
        graph = db.select_graph(GRAPH_NAME)
    finally:
        _fdb_mod.Is_Cluster = _real_is_cluster

    try:
        # 💀 uncap result set
        try:
            await rc.execute_command("GRAPH.CONFIG", "SET", "RESULTSET_SIZE", -1)
        except Exception:
            pass

        # 🔥 Fetch entities that haven't been backfilled yet
        # (metadata is still "{}" or doesn't contain source_chunk_id)
        logger.info("Fetching entities needing provenance...")
        q = (
            "MATCH (e) WHERE e.uuid IS NOT NULL AND e.name IS NOT NULL "
            "AND e.created_by = 'system:anamnesis' "
            "AND (e.metadata IS NULL OR e.metadata = '{}' "
            "     OR NOT e.metadata CONTAINS 'source_chunk_id') "
            "RETURN e.uuid, e.name, labels(e)[0]"
        )
        if limit:
            q += f" LIMIT {limit}"

        result = await graph.query(q, timeout=120_000)
        entities = [(row[0], row[1], row[2]) for row in (result.result_set or [])]
        logger.info("Found %d entities needing provenance", len(entities))

        if not entities:
            logger.info("Nothing to backfill!")
            return

        # 🔥 Pre-build the provenance map (single pass through chunks!)
        entity_names = [name for _, name, _ in entities]
        prov_map = _build_entity_provenance_map(entity_names, all_chunks)

        updated = 0
        not_found = 0
        temporal_links = 0
        now = time.time()
        start_time = time.time()

        # 😈 Cache temporal concept UUIDs so we don't MERGE 45k times
        temporal_cache: dict[str, str] = {}  # concept_name -> uuid

        for i, (uuid, name, label) in enumerate(entities):
            # 🕷️ O(1) dict lookup instead of O(chunks) scan
            prov = prov_map.get(name.lower().strip())

            if not prov:
                not_found += 1
                continue

            meta_json = json.dumps(
                {
                    "source_chunk_id": prov["source_chunk_id"],
                    "conversation_title": prov["conversation_title"],
                    "timestamp_original": prov["timestamp_original"],
                    "domains": prov["domains"],
                }
            )

            if dry_run:
                if updated < 10:  # show first 10
                    ts = prov["timestamp_original"]
                    date = (
                        datetime.fromtimestamp(ts).strftime("%Y-%m-%d") if ts else "?"
                    )
                    temporal = _ts_to_month_quarter(ts)
                    month = temporal[0] if temporal else "?"
                    logger.info(
                        "  [DRY] %-35s -> %s (%s, %s)",
                        name[:35],
                        prov["conversation_title"][:40],
                        date,
                        month,
                    )
                updated += 1
                continue

            # 🔥 Write provenance metadata to entity
            try:
                await graph.query(
                    f"MATCH (e:{label} {{uuid: $uuid}}) "
                    "SET e.metadata = $meta, e.updated_at = $now "
                    "RETURN e.name",
                    params={"uuid": uuid, "meta": meta_json, "now": now},
                )
                updated += 1
            except Exception:
                logger.debug("Failed to update %s", name, exc_info=True)
                continue

            # 🕷️ Link entity to temporal Concepts (month + quarter)
            temporal = _ts_to_month_quarter(prov["timestamp_original"])
            if temporal:
                month_key, month_desc, quarter_key, quarter_desc = temporal

                # Month concept
                if month_key not in temporal_cache:
                    cuuid = await _ensure_temporal_concept(
                        graph,
                        month_key,
                        month_desc,
                        now,
                    )
                    if cuuid:
                        temporal_cache[month_key] = cuuid
                if month_key in temporal_cache:
                    if await _link_to_temporal(
                        graph, uuid, temporal_cache[month_key], now
                    ):
                        temporal_links += 1

                # Quarter concept
                if quarter_key not in temporal_cache:
                    cuuid = await _ensure_temporal_concept(
                        graph,
                        quarter_key,
                        quarter_desc,
                        now,
                    )
                    if cuuid:
                        temporal_cache[quarter_key] = cuuid
                if quarter_key in temporal_cache:
                    if await _link_to_temporal(
                        graph, uuid, temporal_cache[quarter_key], now
                    ):
                        temporal_links += 1

            # Progress logging
            if (i + 1) % 1000 == 0:
                elapsed = time.time() - start_time
                rate = (i + 1) / elapsed
                eta = (len(entities) - i - 1) / rate if rate else 0
                logger.info(
                    "  %d/%d updated (%d not found, %d temporal links) "
                    "-- %.0f/s, ETA %.0fs",
                    updated,
                    i + 1,
                    not_found,
                    temporal_links,
                    rate,
                    eta,
                )

        logger.info(
            "Done. updated=%d, not_found=%d, temporal_links=%d, "
            "temporal_concepts=%d, total=%d",
            updated,
            not_found,
            temporal_links,
            len(temporal_cache),
            len(entities),
        )
    finally:
        await rc.aclose()


# -- CLI -------------------------------------------------------------------
def _redact_url(url: str) -> str:
    """Return *url* with any ``user:password@`` userinfo replaced by ``***@``."""
    from urllib.parse import urlsplit

    try:
        parts = urlsplit(url)
    except ValueError:
        return "<unparseable url>"
    if "@" not in parts.netloc:
        return url
    # Everything after the last "@" is host[:port]; passwords may contain "@".
    host = parts.netloc.rpartition("@")[2]
    return parts._replace(netloc=f"***@{host}").geturl()


def main() -> None:
    """Command-line entry point for the provenance backfill script.

    Parses ``--redis-url``, ``--dry-run`` and ``--limit``, then runs two
    phases:

    1. :func:`_build_provenance_index` loads the pgvector chunks for the
       store at ``<script dir>/rag_stores/<_STORE_DIR>`` into memory. If
       nothing is loaded the function logs that and returns.
    2. :func:`_backfill` is run under :func:`asyncio.run` to write provenance
       metadata and temporal links onto the graph entities.

    The Redis URL is taken from, in order: the ``--redis-url`` flag,
    ``Config.load().redis_url`` (when the project ``config`` module loads),
    the ``REDIS_URL`` environment variable, then
    ``redis://localhost:6379/0``. When a config is available its
    ``redis_connection_kwargs_for_url`` supplies extra connection kwargs.

    Returns:
        None. Progress and results are reported through the module logger.
    """
    # The project config is optional; fall back to CLI/env/defaults without it.
    try:
        from config import Config

        cfg = Config.load()
    except Exception:
        logger.debug("Config not loaded; using CLI/env/defaults", exc_info=True)
        cfg = None

    parser = argparse.ArgumentParser(
        description="Backfill KG entity provenance from pgvector",
    )
    parser.add_argument("--redis-url", default=None)
    parser.add_argument(
        "--dry-run", action="store_true", help="Show matches without writing"
    )
    parser.add_argument(
        "--limit", type=int, default=0, help="Max entities to process (0=all)"
    )
    args = parser.parse_args()

    redis_url = (
        args.redis_url
        or (cfg.redis_url if cfg else None)
        or os.environ.get("REDIS_URL")
        or "redis://localhost:6379/0"
    )
    _ssl = cfg.redis_connection_kwargs_for_url(redis_url) if cfg else {}

    # Step 1: load ALL pgvector chunks into memory
    project_root = os.path.dirname(os.path.abspath(__file__))
    store_path = os.path.join(project_root, "rag_stores", _STORE_DIR)

    logger.info("Loading pgvector chunks for provenance matching...")
    all_chunks = _build_provenance_index(store_path)
    if not all_chunks:
        logger.info("No chunks loaded. Nothing to do.")
        return

    # Step 2: backfill entities. Credentials embedded in the URL are masked
    # before it is written to the log.
    logger.info("Connecting to %s (ssl=%s)", _redact_url(redis_url)[:40], bool(_ssl))
    asyncio.run(
        _backfill(
            redis_url,
            all_chunks,
            dry_run=args.dry_run,
            limit=args.limit,
            ssl_kwargs=_ssl,
        )
    )


if __name__ == "__main__":
    main()