"""FalkorDB-backed knowledge graph with hybrid vector+graph retrieval.
Replaces the flat Redis memory system (``memory_manager.py``) with a
property graph that stores entities as nodes and relationships as typed
edges. Each node carries a 3072-dim embedding vector indexed via HNSW
for semantic search, while multi-hop Cypher traversal connects related
facts across the graph.
The five-tier authority hierarchy (core > guild > channel > general > user)
is preserved as ``category`` / ``priority`` properties on every node and edge.
"""
from __future__ import annotations
import asyncio
import re
import jsonutil as json
import logging
import time
from typing import Any, TYPE_CHECKING
import numpy as np
from falkordb.asyncio import FalkorDB
# V73: this pattern is compiled a single time, when the module is imported.
# Earlier revisions compiled it inside the hot query() path on every write,
# which cost CPU in proportion to write volume. The \b word boundaries keep
# keywords that are embedded in longer words (ASSET, OFFSET, DATASET) from
# being mistaken for a SET clause.
_CYPHER_WRITE_PATTERN = re.compile(
    r"\b(" + "|".join(("CREATE", "MERGE", "SET", "DELETE", "REMOVE")) + r")\b",
    flags=re.IGNORECASE,
)
# ---------------------------------------------------------------------------
# Monkeypatch FalkorDB cluster detection to avoid TypeError on older redis-py
# where sync_redis.Redis(**kwargs) fails if ssl_context is in kwargs.
# ---------------------------------------------------------------------------
try:
    import falkordb.asyncio.falkordb as _fdb_mod
    import falkordb.asyncio.cluster as _fdb_cluster
    import falkordb.falkordb as _fdb_sync_mod
    import falkordb.cluster as _fdb_sync_cluster
    import redis as sync_redis

    def _safe_is_cluster(conn) -> bool:
        """Return True when *conn* points at a server running in cluster mode.

        Drop-in replacement for FalkorDB's ``Is_Cluster`` helper. Instead of
        forwarding the pool's raw connection kwargs (which may contain
        ``ssl_context``), it rebuilds an explicit kwargs dict. It then opens a
        short-lived synchronous client, reads ``INFO server`` and compares
        ``redis_mode``.

        Any failure yields False, so single-node deployments keep working.
        The probe client's connections are always released afterwards, so
        repeated checks do not leak sockets.
        """
        probe = None
        try:
            pool = conn.connection_pool
            conn_kw = pool.connection_kwargs
            kwargs = {
                "host": conn_kw.get("host", "localhost"),
                "port": conn_kw.get("port", 6379),
                "db": conn_kw.get("db", 0),
                "username": conn_kw.get("username"),
                "password": conn_kw.get("password"),
            }
            conn_class = getattr(pool, "connection_class", None)
            conn_class_name = conn_class.__name__ if conn_class is not None else ""
            # Detect TLS either from the connection class or the explicit flag.
            is_ssl = (
                conn_class is sync_redis.SSLConnection
                or conn_class_name == "SSLConnection"
                or conn_kw.get("ssl") is True
            )
            if is_ssl:
                kwargs["ssl"] = True
                kwargs["ssl_keyfile"] = conn_kw.get("ssl_keyfile")
                kwargs["ssl_certfile"] = conn_kw.get("ssl_certfile")
                kwargs["ssl_ca_certs"] = conn_kw.get("ssl_ca_certs")
                kwargs["ssl_cert_reqs"] = conn_kw.get("ssl_cert_reqs", "required")
                kwargs["ssl_check_hostname"] = conn_kw.get("ssl_check_hostname", True)
            # Keep the probe short so a dead server cannot stall start-up.
            kwargs["socket_timeout"] = 1.0
            probe = sync_redis.Redis(**kwargs)
            info = probe.info(section="server")
            return "redis_mode" in info and info["redis_mode"] == "cluster"
        except Exception:
            # Fall back to False to let the single-node setup continue gracefully.
            return False
        finally:
            if probe is not None:
                try:
                    # Release the probe's sockets; the client is never reused.
                    probe.connection_pool.disconnect()
                except Exception:
                    pass

    _fdb_mod.Is_Cluster = _safe_is_cluster
    _fdb_cluster.Is_Cluster = _safe_is_cluster
    if hasattr(_fdb_sync_mod, "Is_Cluster"):
        _fdb_sync_mod.Is_Cluster = _safe_is_cluster
    if hasattr(_fdb_sync_cluster, "Is_Cluster"):
        _fdb_sync_cluster.Is_Cluster = _safe_is_cluster
except Exception:
    # Best effort: if any FalkorDB/redis module layout differs, keep the stock
    # cluster detection. Leave a debug trace instead of failing silently.
    logging.getLogger(__name__).debug(
        "FalkorDB Is_Cluster monkeypatch not applied",
        exc_info=True,
    )
# ---------------------------------------------------------------------------
from uuid6 import uuid7
from .constants import (
CATEGORY_PRIORITY,
ENTITY_LABELS,
_NO_SCOPE,
_SENTINEL_USER_ID,
)
if TYPE_CHECKING:
import redis.asyncio as aioredis
from openrouter_client import OpenRouterClient
# Module-level logger. The KnowledgeGraphManager methods below report index
# warm-up progress, user_id backfill activity and non-fatal DDL errors here.
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# KnowledgeGraphManager
# ---------------------------------------------------------------------------
[docs]
class KnowledgeGraphManager:
"""Manages the FalkorDB knowledge graph for entity/relationship
CRUD and hybrid retrieval.
"""
GRAPH_NAME = "knowledge"
[docs]
def __init__(
    self,
    redis_client: aioredis.Redis,
    openrouter: OpenRouterClient,
    embedding_model: str = "google/gemini-embedding-001",
    admin_user_ids: set[str] | None = None,
    dedup_threshold: float = 0.90,
) -> None:
    """Bind the manager to the graph named by ``GRAPH_NAME``.

    Args:
        redis_client (aioredis.Redis): Async Redis client. Its connection
            pool is handed to :class:`FalkorDB`, so graph queries share the
            client's connections.
        openrouter (OpenRouterClient): Client whose Gemini embedding helpers
            back :meth:`_embed` and :meth:`_embed_batch`.
        embedding_model (str): Model identifier passed to every embedding call.
        admin_user_ids (set[str] | None): Stored on ``self._admin_user_ids``.
            ``None`` (or an empty set) becomes a fresh empty set.
        dedup_threshold (float): Stored on ``self._dedup_threshold``. Its
            consumer is outside this section; presumably it is a similarity
            cut-off for duplicate detection (TODO confirm).
    """
    self._redis = redis_client
    self._db = FalkorDB(
        connection_pool=redis_client.connection_pool,
    )
    self._graph = self._db.select_graph(self.GRAPH_NAME)
    self._openrouter = openrouter
    self._embedding_model = embedding_model
    self._admin_user_ids: set[str] = admin_user_ids or set()
    self._dedup_threshold = dedup_threshold
    # Set to True once ensure_indexes() completes all phases.
    # Callers (e.g. retrieval) can check this flag to decide whether to
    # attempt a vector KNN search or gracefully degrade.
    self._indexes_ready: bool = False
    # Handle for the fire-and-forget user_id backfill that ensure_indexes()
    # starts. It is None until that happens and is defined here so the
    # attribute always exists.
    self._backfill_task: asyncio.Task[None] | None = None
    # SWORD monitor integration hook; query() notifies it after writes.
    self._sword_monitor: Any = None
    # ---------------------------------------------------------------------------
    # Query serialization and priority scheme
    #
    # FalkorDB can handle concurrent queries, but excessive parallel load can
    # cause slow queries to queue and time out. Overall concurrency is capped
    # with an asyncio.Semaphore (_query_semaphore), which still allows batch
    # concurrency.
    #
    # Priority:
    # • Foreground callers (is_background=False) increment _foreground_waiters
    #   BEFORE acquiring the semaphore and decrement it AFTER releasing.
    # • Background callers call wait_for_foreground_idle() to check that
    #   _foreground_waiters == 0 BEFORE they try to acquire the semaphore.
    # • Background traffic therefore yields to live traffic at each query
    #   boundary, while batches can still run concurrently when idle.
    # ---------------------------------------------------------------------------
    self._query_semaphore = asyncio.Semaphore(20)
    self._foreground_waiters: int = 0
    self._foreground_cond = asyncio.Condition()
@property
def indexes_ready(self) -> bool:
    """Whether :meth:`ensure_indexes` has run through all of its phases.

    :meth:`ensure_indexes` clears the flag when it starts and sets it again
    after its final phase. Retrieval code can therefore consult it before
    issuing HNSW KNN queries, and degrade gracefully while index warm-up is
    still running.
    """
    ready: bool = self._indexes_ready
    return ready
[docs]
async def wait_for_foreground_idle(self) -> None:
    """Suspend the caller until no foreground query is queued or running.

    Background workers invoke this before every individual query. A
    foreground query that arrives while a background batch is in progress
    is therefore honoured at the next query boundary, not after the whole
    batch has finished.
    """
    if self._foreground_waiters <= 0:
        # Nothing in the foreground: return without touching the condition.
        return

    def _foreground_drained() -> bool:
        return self._foreground_waiters == 0

    async with self._foreground_cond:
        await self._foreground_cond.wait_for(_foreground_drained)
# Keep old name for compatibility with already-deployed code.
[docs]
async def wait_for_idle(self) -> None:
    """Backward-compatible alias of :meth:`wait_for_foreground_idle`.

    Kept so already-deployed callers that use the old name keep working.
    """
    return await self.wait_for_foreground_idle()
[docs]
async def query(
    self, q: str, params: dict | None = None, is_background: bool = False, **kwargs
) -> Any:
    """Execute a Cypher query against the knowledge graph.

    Concurrency is capped by ``self._query_semaphore``.

    Foreground callers (``is_background=False``) increment
    ``self._foreground_waiters`` before waiting on the semaphore and
    decrement it afterwards. When the count returns to zero they wake any
    background workers parked in :meth:`wait_for_foreground_idle`.
    Background callers first yield to pending foreground work, then take a
    semaphore slot.

    If a SWORD monitor is attached and *q* contains a Cypher write keyword,
    ``on_kg_write()`` is scheduled as a fire-and-forget task after the query
    succeeds. A strong reference to the task is kept until it finishes, and
    its failure (if any) is logged.

    Args:
        q: Cypher query text.
        params: Optional query parameters.
        is_background: True for batch/maintenance traffic that should yield
            to live (foreground) queries.
        **kwargs: Forwarded unchanged to the FalkorDB graph ``query`` call.

    Returns:
        The FalkorDB query result.
    """
    if is_background:
        # Background: yield to any pending foreground, then acquire semaphore.
        await self.wait_for_foreground_idle()
        async with self._query_semaphore:
            res = await self._graph.query(q, params=params, **kwargs)
    else:
        self._foreground_waiters += 1
        try:
            async with self._query_semaphore:
                res = await self._graph.query(q, params=params, **kwargs)
        finally:
            self._foreground_waiters -= 1
            if self._foreground_waiters == 0:
                async with self._foreground_cond:
                    self._foreground_cond.notify_all()

    # V73: use the module-level compiled pattern instead of re.compile() per call.
    if self._sword_monitor and _CYPHER_WRITE_PATTERN.search(q):
        # The event loop only keeps weak references to tasks, so a discarded
        # create_task() result may be garbage-collected before it runs to
        # completion. Hold it in a per-instance set until it is done.
        pending = getattr(self, "_sword_notify_tasks", None)
        if pending is None:
            pending = self._sword_notify_tasks = set()
        task = asyncio.create_task(self._sword_monitor.on_kg_write())
        pending.add(task)

        def _on_sword_done(t: asyncio.Task) -> None:
            pending.discard(t)
            if t.cancelled():
                return
            exc = t.exception()
            if exc is not None:
                # Retrieve the exception here so it is logged rather than
                # reported later as "Task exception was never retrieved".
                logger.warning(
                    "SWORD monitor on_kg_write hook failed",
                    exc_info=exc,
                )

        task.add_done_callback(_on_sword_done)
    return res
[docs]
async def ro_query(
    self, q: str, params: dict | None = None, is_background: bool = False, **kwargs
) -> Any:
    """Run a read-only Cypher query through FalkorDB's ``ro_query``.

    Uses the same scheduling as :meth:`query`. Background callers wait for
    the foreground to drain, then take a semaphore slot. Foreground callers
    are counted in ``_foreground_waiters`` for the whole wait-and-run span,
    and wake parked background workers when the count returns to zero.
    """
    if is_background:
        await self.wait_for_foreground_idle()
        async with self._query_semaphore:
            return await self._graph.ro_query(q, params=params, **kwargs)

    self._foreground_waiters += 1
    try:
        async with self._query_semaphore:
            return await self._graph.ro_query(q, params=params, **kwargs)
    finally:
        self._foreground_waiters -= 1
        if not self._foreground_waiters:
            async with self._foreground_cond:
                self._foreground_cond.notify_all()
[docs]
async def embed(self, text: str) -> list[float]:
    """Return the embedding vector for *text*.

    Public entry point; the work is done by :meth:`_embed`.
    """
    vector = await self._embed(text)
    return vector
[docs]
async def embed_batch(self, texts: list[str]) -> list[list[float]]:
    """Embed several texts in one call.

    Public entry point; the work is done by :meth:`_embed_batch`.
    """
    vectors = await self._embed_batch(texts)
    return vectors
# ------------------------------------------------------------------
# Embedding helper
# ------------------------------------------------------------------
async def _embed(self, text: str) -> list[float]:
    """Embed *text* through the client's Gemini-direct helper.

    Always calls ``_embed_gemini`` rather than the generic
    ``openrouter.embed()``. Per the original note, the generic path routes
    through paid OpenRouter roughly half of the time, while the direct
    Gemini path uses the free tier.
    """
    model = self._embedding_model
    return await self._openrouter._embed_gemini(text, model)
async def _embed_batch(self, texts: list[str]) -> list[list[float]]:
    """Embed *texts* with the Gemini-direct batch helper.

    An empty input short-circuits to ``[]`` without any API call.
    """
    if texts:
        model = self._embedding_model
        return await self._openrouter._embed_gemini_batch(texts, model)
    return []
# ------------------------------------------------------------------
# Index management
# ------------------------------------------------------------------
async def _fetch_existing_indexes(
    self,
) -> tuple[set[str], set[tuple[str, str]]]:
    """Query FalkorDB for all existing indexes in one ``CALL db.indexes()``.

    Returns a 2-tuple:

    * ``existing_vector`` — labels whose ``embedding`` property already has
      a VECTOR (HNSW) index.
    * ``existing_range`` — ``(label, property)`` pairs that already have a
      RANGE (B-Tree) index.

    One row can describe several index types for the same label, e.g. a
    VECTOR index on ``embedding`` plus RANGE indexes on ``name`` and
    ``uuid``. Both sets are therefore filled independently from each row.
    When the type column is a ``{property: types}`` map, it is read
    property-by-property. Otherwise a coarse substring check on its string
    form is applied to every listed property.

    Both sets are consumed by :meth:`ensure_indexes` to avoid redundant DDL
    round-trips on warm restarts. Returns ``(set(), set())`` on any error,
    so callers degrade safely to the build-everything behaviour.
    """

    def _split_props(raw: Any) -> list[str]:
        # Property column may be a real list or a bracketed string such as
        # "['name', 'uuid']"; normalise both to clean names.
        items = raw if isinstance(raw, (list, tuple, set)) else str(raw).split(",")
        cleaned = (str(p).strip().strip("[]'\" ") for p in items)
        return [p for p in cleaned if p]

    def _kinds(raw: Any) -> str:
        # Upper-cased text of an index-type value (a single string or a list).
        if isinstance(raw, (list, tuple, set)):
            return " ".join(str(t) for t in raw).upper()
        return str(raw).upper()

    try:
        result = await self._graph.query("CALL db.indexes()")
        existing_vector: set[str] = set()
        existing_range: set[tuple[str, str]] = set()
        for row in result.result_set or []:
            # Row layout (may vary by FalkorDB version):
            # [label, properties, types, ...]
            if not row or len(row) < 3:
                continue
            label = str(row[0])
            types = row[2]
            if isinstance(types, dict):
                # Per-property map: decide each property on its own types.
                for prop, prop_types in types.items():
                    prop_name = str(prop)
                    kinds = _kinds(prop_types)
                    if "VECTOR" in kinds and prop_name == "embedding":
                        existing_vector.add(label)
                    if "RANGE" in kinds or "BTREE" in kinds:
                        existing_range.add((label, prop_name))
                continue
            props = _split_props(row[1])
            kinds = _kinds(types)
            if "VECTOR" in kinds and "embedding" in props:
                existing_vector.add(label)
            # Independent check (not elif): a label with an HNSW index can
            # also carry range indexes on its other properties.
            if "RANGE" in kinds or "BTREE" in kinds:
                for prop in props:
                    existing_range.add((label, prop))
        logger.debug(
            "[idx] Existing HNSW vector indexes: %s",
            sorted(existing_vector) if existing_vector else "(none)",
        )
        logger.debug(
            "[idx] Existing range indexes: %d pairs",
            len(existing_range),
        )
        return existing_vector, existing_range
    except Exception:
        logger.debug(
            "[idx] Could not introspect existing indexes "
            "(db.indexes() unavailable?); will attempt all",
            exc_info=True,
        )
        return set(), set()
async def _count_nodes_by_label(self) -> dict[str, int]:
    """Count nodes per label with a single aggregation query.

    :meth:`ensure_indexes` uses the result to skip HNSW creation for labels
    that currently hold no nodes.

    Each node is attributed to its first label (``labels(e)[0]``). Rows with
    a missing label, or with a count that cannot be converted to ``int``,
    are ignored. Labels without nodes produce no row, so they are simply
    absent from the mapping.

    Returns:
        dict[str, int]: ``{label: node_count}`` for every label seen.
        Returns ``{}`` if the query fails, which callers read as "unknown".
    """
    try:
        result = await self._graph.query(
            "MATCH (e) RETURN labels(e)[0] AS lbl, count(e) AS cnt",
        )
        per_label: dict[str, int] = {}
        for row in result.result_set or []:
            if not row or row[0] is None:
                continue
            try:
                per_label[str(row[0])] = int(row[1])
            except (TypeError, ValueError):
                pass
        entity_view = {
            lbl: n for lbl, n in sorted(per_label.items()) if lbl in ENTITY_LABELS
        }
        logger.debug("[idx] Node counts per label: %s", entity_view)
        return per_label
    except Exception:
        logger.debug(
            "[idx] Could not fetch node counts per label; will attempt all labels",
            exc_info=True,
        )
        return {}
[docs]
async def ensure_indexes(self, *, skip_empty_labels: bool = False) -> None:
    """Create vector + range indexes for every entity label.

    Pre-flight checks avoid unnecessary FalkorDB round-trips on warm
    restarts:

    * **Strategy 1 — existence pre-check**: ``CALL db.indexes()`` is issued
      once up front. Labels whose HNSW index already exists are skipped, and
      known ``(label, property)`` range indexes are not re-issued.
    * **Strategy 2 — zero-node skip (opt-in)**: with
      ``skip_empty_labels=True``, one ``MATCH`` aggregation counts nodes per
      label, and labels with no nodes skip HNSW creation. This is off by
      default. :meth:`add_entity` does not create indexes, so entities added
      to a skipped label after start-up would have no vector index until
      this method runs again, even though ``indexes_ready`` reports True.
    * **Strategy 4 — readiness flag**: ``self._indexes_ready`` is set to
      True only after every phase has run. Individual DDL failures are
      logged as warnings by the helper methods; they are not raised.

    Both pre-checks fail safe. If an introspection query errors (e.g. an
    older FalkorDB without ``db.indexes()``), the corresponding skip-set is
    empty and all labels are attempted.

    Args:
        skip_empty_labels: Skip HNSW creation for labels that currently have
            no nodes. This saves one round-trip per such label, at the cost
            described above.
    """
    self._indexes_ready = False  # reset in case of re-entrant call
    logger.info(
        "[idx] ensure_indexes: starting for %d entity labels",
        len(ENTITY_LABELS),
    )
    # Range-indexed properties for every entity label. Keep this in sync
    # with the tuple used by _create_range_indexes_for_label.
    range_props = ("name", "category", "scope_id", "pinned", "uuid", "user_id")

    # --- Pre-flight ---
    # _fetch_existing_indexes() returns BOTH vector and range index sets, so
    # Phase 1 (HNSW) and Phase 2 (range) can both skip pre-existing indexes
    # without extra round-trips. The full-graph node count is only paid for
    # when empty labels are to be skipped.
    node_counts: dict[str, int] = {}
    if skip_empty_labels:
        logger.info("[idx] Pre-flight: querying existing indexes and node counts...")
        (existing_vector, existing_range), node_counts = await asyncio.gather(
            self._fetch_existing_indexes(),
            self._count_nodes_by_label(),
        )
    else:
        logger.info("[idx] Pre-flight: querying existing indexes...")
        existing_vector, existing_range = await self._fetch_existing_indexes()

    # Decide which labels actually need HNSW construction.
    labels_need_build: list[str] = []
    labels_skip_exists: list[str] = []
    labels_skip_empty: list[str] = []
    for label in ENTITY_LABELS:
        if label in existing_vector:
            labels_skip_exists.append(label)
        elif node_counts and node_counts.get(label, 0) == 0:
            # Counts are known and this label has no nodes. An empty dict
            # means "unknown" (query failed, graph empty, or counting not
            # requested), in which case the index is built.
            labels_skip_empty.append(label)
        else:
            labels_need_build.append(label)
    logger.info(
        "[idx] Pre-flight summary: %d to build, %d already indexed (skip), "
        "%d empty (skip). Labels to build: %s",
        len(labels_need_build),
        len(labels_skip_exists),
        len(labels_skip_empty),
        labels_need_build or "(none)",
    )
    if labels_skip_exists:
        logger.info(
            "[idx] Skipping HNSW creation for already-indexed labels: %s",
            labels_skip_exists,
        )
    if labels_skip_empty:
        logger.info(
            "[idx] Skipping HNSW creation for zero-node labels: %s",
            labels_skip_empty,
        )

    # --- Watchdog: tracks per-label step so hangs are precisely located ---
    # Every label stays pending until its Phase 2 range indexes finish.
    # Phase 2 runs for all labels, including those whose HNSW step was
    # skipped, so dropping them early would hide a hang in their range DDL.
    _pending: set[str] = set(ENTITY_LABELS)
    # Maps label -> current operation string, updated at every await boundary.
    _label_state: dict[str, str] = {lbl: "queued" for lbl in ENTITY_LABELS}
    for lbl in labels_skip_exists:
        _label_state[lbl] = "hnsw_exists"
    for lbl in labels_skip_empty:
        _label_state[lbl] = "hnsw_skipped_empty"
    _WATCHDOG_INTERVAL = 3  # seconds between watchdog ticks

    async def _watchdog() -> None:
        elapsed = 0
        while True:
            await asyncio.sleep(_WATCHDOG_INTERVAL)
            elapsed += _WATCHDOG_INTERVAL
            if not _pending:
                return
            state_lines = ", ".join(
                f"{lbl}={_label_state.get(lbl, '?')}" for lbl in sorted(_pending)
            )
            logger.warning(
                "ensure_indexes watchdog [%ds elapsed]: "
                "%d label(s) still in flight: %s",
                elapsed,
                len(_pending),
                state_lines,
            )

    watchdog_task = asyncio.create_task(_watchdog())
    try:
        # --- Phase 1: HNSW vector indexes — sequential ---
        # FalkorDB is single-threaded; concurrent CREATE VECTOR INDEX
        # commands all queue server-side while each connection blocks.
        # Sequential issuance is equally fast and avoids connection pool
        # saturation.
        if labels_need_build:
            logger.info(
                "[idx] Phase 1: HNSW vector indexes (sequential, %d/%d labels)",
                len(labels_need_build),
                len(ENTITY_LABELS),
            )
            for label in labels_need_build:
                await self._create_hnsw_index(label, _label_state)
        else:
            logger.info(
                "[idx] Phase 1: all HNSW indexes already present or labels empty — "
                "skipping construction entirely",
            )
        # --- Phase 2: Range indexes — concurrent across all labels ---
        # existing_range from the pre-flight call lets each label skip
        # already-present (label, prop) pairs without a server round-trip.
        n_range_skip = sum(
            1
            for lbl in ENTITY_LABELS
            for prop in range_props
            if (lbl, prop) in existing_range
        )
        n_range_total = len(ENTITY_LABELS) * len(range_props)
        logger.info(
            "[idx] Phase 2: range indexes (concurrent, %d labels × %d props, "
            "%d/%d pre-existing — skipping)",
            len(ENTITY_LABELS),
            len(range_props),
            n_range_skip,
            n_range_total,
        )
        await asyncio.gather(
            *[
                self._create_range_indexes_for_label(
                    label,
                    _pending,
                    _label_state,
                    existing_range=existing_range,
                )
                for label in ENTITY_LABELS
            ]
        )
    finally:
        watchdog_task.cancel()

    # Fire and forget: do not block indexes_ready on the data backfill.
    # The task is stored on self (asyncio holds tasks weakly), and a second
    # backfill is not started while a previous one is still running.
    previous = getattr(self, "_backfill_task", None)
    if previous is None or previous.done():
        logger.info(
            "[idx] Starting knowledge graph backfill of user_id in background..."
        )
        self._backfill_task = asyncio.create_task(self._backfill_user_ids())
    else:
        logger.info(
            "[idx] user_id backfill from a previous call still running; "
            "not starting another"
        )
    logger.info("[idx] Verifying structural indexes and constraints...")
    # Pass existing_range so structural DDL also skips pre-existing pairs.
    await self._ensure_structural_indexes(existing_range=existing_range)
    # Strategy 4: mark indexes as ready for retrieval code.
    self._indexes_ready = True
    logger.info(
        "[idx] ensure_indexes complete: %d labels processed "
        "(%d built, %d pre-existing, %d empty). "
        "KG HNSW indexes fully warmed and ready.",
        len(ENTITY_LABELS),
        len(labels_need_build),
        len(labels_skip_exists),
        len(labels_skip_empty),
    )
async def _create_hnsw_index(
    self,
    label: str,
    _label_state: dict[str, str] | None = None,
) -> None:
    """Build the 3072-dim cosine HNSW index on ``<label>.embedding``.

    Call this sequentially, never concurrently. FalkorDB serialises
    ``CREATE VECTOR INDEX`` server-side, so concurrent calls just park one
    pooled connection per label and starve the pool, which looks like a hang.

    Errors whose text mentions "already indexed" or "already exists" count
    as success. Any other error is logged as a warning and swallowed.

    Args:
        label: Entity label to index.
        _label_state: Optional watchdog map; ``_label_state[label]`` is
            updated with the current step.
    """

    def _mark(step: str) -> None:
        if _label_state is None:
            return
        _label_state[label] = step

    _mark("hnsw_create")
    logger.info("[idx] %s: awaiting HNSW vector index creation...", label)
    try:
        await self._graph.create_node_vector_index(
            label,
            "embedding",
            dim=3072,
            similarity_function="cosine",
        )
    except Exception as exc:
        text = str(exc).lower()
        benign = "already indexed" in text or "already exists" in text
        if benign:
            _mark("hnsw_exists")
            logger.info(
                "[idx] %s: HNSW vector index already exists, skipping",
                label,
            )
        else:
            _mark("hnsw_error")
            logger.warning(
                "[idx] %s: HNSW vector index error: %s",
                label,
                exc,
            )
    else:
        _mark("hnsw_done")
        logger.info("[idx] %s: HNSW vector index OK", label)
async def _create_range_indexes_for_label(
    self,
    label: str,
    _pending: set[str] | None = None,
    _label_state: dict[str, str] | None = None,
    *,
    existing_range: set[tuple[str, str]] | None = None,
) -> None:
    """Ensure the six standard range indexes exist on one entity label.

    The per-property DDLs run concurrently. :meth:`ensure_indexes` invokes
    this for every label at once in its Phase 2, after HNSW creation.

    Args:
        label: Entity label to index.
        _pending: Optional watchdog set; *label* is removed once done.
        _label_state: Optional watchdog map updated with progress.
        existing_range: ``(label, property)`` pairs already known to be
            indexed (from the pre-flight ``CALL db.indexes()``). Matching
            pairs are skipped without contacting the server.
    """
    props = ("name", "category", "scope_id", "pinned", "uuid", "user_id")
    if _label_state is not None:
        _label_state[label] = "range_gather"
    logger.debug("[idx] %s: ensuring 6 range indexes concurrently...", label)
    jobs = [
        self._create_range_index_safe(
            label, prop, _label_state, existing_range=existing_range
        )
        for prop in props
    ]
    await asyncio.gather(*jobs)
    if _label_state is not None:
        _label_state[label] = "done"
    logger.debug("[idx] %s: all range indexes done", label)
    if _pending is not None:
        _pending.discard(label)
async def _create_label_indexes(
    self,
    label: str,
    _pending: set[str] | None = None,
    _label_state: dict[str, str] | None = None,
) -> None:
    """Build the HNSW index, then all range indexes, for one entity label.

    .. deprecated::
        :meth:`ensure_indexes` now calls :meth:`_create_hnsw_index` and
        :meth:`_create_range_indexes_for_label` directly. This wrapper is
        kept only for external callers such as tests.
    """
    # Same order as ensure_indexes' phases: vector index first, then ranges.
    await self._create_hnsw_index(label, _label_state=_label_state)
    await self._create_range_indexes_for_label(
        label, _pending=_pending, _label_state=_label_state
    )
async def _create_range_index_safe(
    self,
    label: str,
    prop: str,
    _label_state: dict[str, str] | None = None,
    *,
    existing_range: set[tuple[str, str]] | None = None,
) -> None:
    """Create one range index and treat "already exists" errors as success.

    Other errors are logged as warnings and swallowed.

    Args:
        label: Node label to index.
        prop: Property to index.
        _label_state: Optional watchdog map; ``_label_state[label]`` is set
            to ``"<step>:<prop>"`` as the call progresses.
        existing_range: ``(label, property)`` pairs known from the pre-flight
            ``CALL db.indexes()``. If the pair is present, no DDL is sent.
    """

    def _mark(step: str) -> None:
        if _label_state is not None:
            _label_state[label] = f"{step}:{prop}"

    known = existing_range is not None and (label, prop) in existing_range
    if known:
        _mark("range_exists")
        logger.debug(
            "[idx] %s.%s: range index pre-existing, skipping DDL",
            label,
            prop,
        )
        return

    _mark("range_await")
    logger.debug("[idx] %s.%s: creating range index...", label, prop)
    try:
        await self._graph.create_node_range_index(label, prop)
    except Exception as exc:
        lowered = str(exc).lower()
        if "already exists" in lowered or "already indexed" in lowered:
            _mark("range_exists")
            logger.debug(
                "[idx] %s.%s: range index already exists, skipping",
                label,
                prop,
            )
        else:
            _mark("range_error")
            logger.warning(
                "[idx] %s.%s: range index error: %s",
                label,
                prop,
                exc,
            )
        return
    _mark("range_done")
    logger.info("[idx] %s.%s: range index OK", label, prop)
async def _ensure_structural_indexes(
    self,
    existing_range: set[tuple[str, str]] | None = None,
) -> None:
    """Create range indexes and uniqueness constraints for the Limbic
    Anchoring structural hub nodes.

    Range (B-Tree) indexes are issued concurrently across every
    ``(label, property)`` pair in the table below. Uniqueness constraints
    are then applied one at a time.

    The method is idempotent: "already exists" / "already indexed" errors
    count as success. Other constraint-related errors are tolerated but
    logged at debug level. Anything else is logged as a warning.

    Args:
        existing_range: Pre-populated ``(label, property)`` pairs from the
            pre-flight ``CALL db.indexes()`` call. Pairs found here are
            skipped without a server round-trip.
    """
    from knowledge_graph.constants import (
        STRUCTURAL_LABELS,
    )  # local import avoids circular

    # --- Range indexes: fan out concurrently across all labels × props ---
    _range_index_props: dict[str, list[str]] = {
        "Message": ["timestamp", "redis_key"],
        "Channel": ["channel_id", "platform"],
        "ChannelEpoch": ["epoch_seq", "channel_id"],
        "DailySummary": ["date", "channel_id"],
        "EpochSummary": ["start_ts", "end_ts", "channel_id"],
    }

    # Inline helper keeps this method self-contained (no dependency on
    # _create_range_index_safe), so it can be called in isolation during
    # tests that bind only _graph to the instance.
    async def _structural_range_index(lbl: str, prp: str) -> None:
        # Pre-check: skip DDL if the index is known to already exist.
        if existing_range is not None and (lbl, prp) in existing_range:
            logger.debug(
                "[idx] Structural %s.%s: pre-existing, skipping DDL",
                lbl,
                prp,
            )
            return
        try:
            await self._graph.create_node_range_index(lbl, prp)
        except Exception as exc:
            msg = str(exc).lower()
            if "already exists" not in msg and "already indexed" not in msg:
                logger.warning(
                    "Structural range index %s.%s: %s",
                    lbl,
                    prp,
                    exc,
                )
        logger.debug("Finished structural range index: %s.%s", lbl, prp)

    range_tasks = [
        _structural_range_index(label, prop)
        for label, props in _range_index_props.items()
        for prop in props
    ]
    await asyncio.gather(*range_tasks)
    logger.info("Finished all structural range indexes.")

    # --- Uniqueness constraints: sequential (write ordering safety) ---
    _constraints: list[tuple[str, str]] = [
        ("Channel", "channel_id"),
        ("ChannelEpoch", "epoch_id"),
    ]
    for label, prop in _constraints:
        logger.debug("Ensuring uniqueness constraint on %s.%s", label, prop)
        cypher = f"CREATE CONSTRAINT ON (n:{label}) ASSERT n.{prop} IS UNIQUE"
        try:
            await self._graph.query(cypher)
        except Exception as exc:
            msg = str(exc).lower()
            if "already exists" in msg or "already indexed" in msg:
                continue
            if "constraint" in msg:
                # Deliberately tolerated (best effort), but previously dropped
                # without a trace. Keep a debug record so a constraint that
                # never gets created can be diagnosed.
                # NOTE(review): FalkorDB documents constraint management via
                # GRAPH.CONSTRAINT CREATE, and unique constraints need a range
                # index on the property (ChannelEpoch.epoch_id has none in the
                # table above). Confirm this Cypher form is accepted by the
                # deployed server rather than always landing here.
                logger.debug(
                    "Uniqueness constraint %s.%s not applied: %s",
                    label,
                    prop,
                    exc,
                )
                continue
            logger.warning("Uniqueness constraint %s.%s: %s", label, prop, exc)
    logger.info(
        "Structural indexes and constraints ensured for %d hub labels",
        len(STRUCTURAL_LABELS),
    )
async def _backfill_user_ids(self) -> None:
    """Stamp the sentinel ``user_id`` on entity nodes that have none.

    Read phase: one aggregation counts ``user_id IS NULL`` nodes grouped by
    first label, keeping only entity labels. Write phase: one ``SET`` per
    label that actually has null entries.

    If the aggregation itself fails, the method falls back to
    :meth:`_backfill_user_ids_sequential`, so correctness never depends on
    it. Per-label write failures are logged at debug level and do not stop
    the remaining labels. Queries go straight to ``self._graph`` rather than
    through :meth:`query`.
    """
    entity_labels = frozenset(ENTITY_LABELS)

    # --- Single aggregation read ---
    try:
        agg = await self._graph.query(
            "MATCH (e) WHERE e.user_id IS NULL RETURN labels(e)[0] AS lbl, count(e) AS cnt",
        )
        missing: dict[str, int] = {}
        for row in agg.result_set or []:
            if row and row[0] is not None and str(row[0]) in entity_labels:
                missing[str(row[0])] = int(row[1])
    except Exception:
        logger.debug(
            "user_id backfill: aggregation query failed; "
            "falling back to per-label reads",
            exc_info=True,
        )
        await self._backfill_user_ids_sequential()
        return

    if not missing:
        logger.debug(
            "user_id backfill: all entities have user_id set — nothing to do",
        )
        return

    # --- Per-label write, only where null entries exist ---
    for label, count in missing.items():
        try:
            logger.info(
                "[backfill] user_id: %d %s node(s) missing user_id — patching...",
                count,
                label,
            )
            await self._graph.query(
                f"MATCH (e:{label}) WHERE e.user_id IS NULL SET e.user_id = $uid",
                params={"uid": _SENTINEL_USER_ID},
            )
            logger.info(
                "[backfill] user_id: patched %d %s node(s)",
                count,
                label,
            )
        except Exception:
            logger.debug(
                "[backfill] user_id write failed for %s",
                label,
                exc_info=True,
            )
async def _backfill_user_ids_sequential(self) -> None:
    """Per-label fallback for :meth:`_backfill_user_ids`.

    Used only when the single aggregation query fails. For each entity
    label, it counts ``user_id IS NULL`` nodes and, if any exist, sets them
    to the sentinel user id. A failure on one label is logged at debug level
    and does not stop the others.
    """
    for label in ENTITY_LABELS:
        try:
            probe = await self._graph.query(
                f"MATCH (e:{label}) WHERE e.user_id IS NULL RETURN count(e)",
            )
            rows = probe.result_set
            missing = rows[0][0] if rows else 0
            if missing:
                logger.info(
                    "Backfilling user_id on %d %s node(s)...",
                    missing,
                    label,
                )
                await self._graph.query(
                    f"MATCH (e:{label}) WHERE e.user_id IS NULL SET e.user_id = $uid",
                    params={"uid": _SENTINEL_USER_ID},
                )
                logger.info(
                    "Backfilled user_id on %d %s node(s)",
                    missing,
                    label,
                )
            else:
                logger.debug(
                    "user_id backfill: no null entries for %s, skipping write",
                    label,
                )
        except Exception:
            logger.debug(
                "user_id backfill for %s failed",
                label,
                exc_info=True,
            )
# ------------------------------------------------------------------
# Entity CRUD
# ------------------------------------------------------------------
[docs]
async def add_entity(
    self,
    name: str,
    entity_type: str,
    description: str,
    category: str = "general",
    scope_id: str = "_",
    created_by: str = "unknown",
    pinned: bool = False,
    metadata: str = "{}",
    user_id: str = _SENTINEL_USER_ID,
    embedding: list[float] | None = None,
) -> dict[str, str]:
    """Create an entity node, or update it if it already exists.

    The node is MERGEd on ``(name, scope_id, category)``. The name is
    stripped and lower-cased; an empty *scope_id* falls back to
    ``_NO_SCOPE``.

    * On create, every field is written, ``mention_count`` starts at 1 and
      a fresh UUIDv7 is assigned.
    * On match, only the description, embedding and ``updated_at`` are
      refreshed, and ``mention_count`` is incremented.

    If *embedding* is not supplied, ``"<name>: <description>"`` (or just
    *name* when the description is empty) is embedded first.

    Returns:
        ``{"name": ..., "uuid": ...}`` for the stored node.

    Raises:
        ValueError: If *entity_type* or *category* is unknown, or if *name*
            is empty or whitespace-only.
    """
    if entity_type not in ENTITY_LABELS:
        raise ValueError(f"Unknown entity type: {entity_type}")
    if category not in CATEGORY_PRIORITY:
        raise ValueError(f"Unknown category: {category}")
    name_lower = name.strip().lower()
    if not name_lower:
        # A blank key would MERGE unrelated entities into one nameless node.
        raise ValueError("Entity name must not be empty")
    priority = CATEGORY_PRIORITY[category]
    now = time.time()
    new_uuid = str(uuid7())
    if embedding is not None:
        vec = embedding
    else:
        embed_text = f"{name}: {description}" if description else name
        vec = await self._embed(embed_text)
    # entity_type is validated above, so interpolating it as the label is safe.
    q = (
        f"MERGE (e:{entity_type} {{name: $name, "
        f"scope_id: $sid, category: $cat}}) "
        f"ON CREATE SET "
        f"e.uuid = $uuid, "
        f"e.description = $desc, "
        f"e.priority = $pri, "
        f"e.embedding = vecf32($vec), "
        f"e.pinned = $pinned, "
        f"e.mention_count = 1, "
        f"e.created_at = $now, "
        f"e.updated_at = $now, "
        f"e.created_by = $creator, "
        f"e.user_id = $user_id, "
        f"e.metadata = $meta "
        f"ON MATCH SET "
        f"e.description = $desc, "
        f"e.embedding = vecf32($vec), "
        f"e.mention_count = e.mention_count + 1, "
        f"e.updated_at = $now "
        f"RETURN e.name, e.uuid"
    )
    params = {
        "name": name_lower,
        "sid": scope_id or _NO_SCOPE,
        "cat": category,
        "desc": description,
        "pri": priority,
        "vec": vec,
        "pinned": pinned,
        "now": now,
        "creator": created_by,
        "uuid": new_uuid,
        "meta": metadata,
        "user_id": user_id or _SENTINEL_USER_ID,
    }
    result = await self._graph.query(q, params=params)
    row = result.result_set[0] if result.result_set else None
    return {
        "name": row[0] if row else name_lower,
        "uuid": row[1] if row else new_uuid,
    }
[docs]
async def update_entity_description(
    self,
    name: str,
    entity_type: str,
    new_description: str,
    category: str | None = None,
    scope_id: str | None = None,
) -> bool:
    """Replace an entity's description and refresh its embedding.

    Matches ``(e:<entity_type>)`` by the stripped, lower-cased *name*,
    optionally narrowed by *category* and *scope_id* (each applied only when
    truthy). Every matching node gets:

    * the new description;
    * a fresh embedding of ``"<name>: <description>"`` (or just *name* when
      the description is empty);
    * ``updated_at`` set to now.

    Returns:
        True if at least one node was updated. Returns False when nothing
        matched, or when *entity_type* is not a known entity label. The
        label is interpolated into the Cypher text (labels cannot be query
        parameters), so unknown values are rejected before any embedding
        call or query, consistent with :meth:`add_entity`'s validation.
    """
    if entity_type not in ENTITY_LABELS:
        logger.warning(
            "update_entity_description: unknown entity type %r; not querying",
            entity_type,
        )
        return False
    name_lower = name.strip().lower()
    embed_text = f"{name}: {new_description}" if new_description else name
    vec = await self._embed(embed_text)
    now = time.time()
    where_parts = ["e.name = $name"]
    params: dict[str, Any] = {
        "name": name_lower,
        "desc": new_description,
        "vec": vec,
        "now": now,
    }
    if category:
        where_parts.append("e.category = $cat")
        params["cat"] = category
    if scope_id:
        where_parts.append("e.scope_id = $sid")
        params["sid"] = scope_id
    where = " AND ".join(where_parts)
    q = (
        f"MATCH (e:{entity_type}) WHERE {where} "
        f"SET e.description = $desc, "
        f"e.embedding = vecf32($vec), "
        f"e.updated_at = $now "
        f"RETURN e.name"
    )
    result = await self._graph.query(
        q,
        params=params,
    )
    # result_set may be empty or None when nothing matched.
    return bool(result.result_set)
[docs]
async def edit_entity(
    self,
    uuid: str,
    description: str | None = None,
    append_text: str | None = None,
    pinned: bool | None = None,
    category: str | None = None,
    metadata_updates: dict | None = None,
) -> dict | None:
    """Selectively update fields on an existing entity.

    Looks up by *uuid*. Only the provided fields are changed;
    everything else is preserved.

    *description* replaces the text entirely.
    *append_text* is appended to the existing description, separated
    by a newline. The two are meant to be mutually exclusive; if both
    are given, *description* wins and *append_text* is ignored.

    *metadata_updates* is shallow-merged into the existing metadata
    JSON (new keys added, existing overwritten, unmentioned preserved).
    If the stored metadata is missing, unparseable, or not a JSON
    object, the merge starts from ``{}``.

    Changing the description re-embeds the entity. Changing the
    category also rewrites ``priority`` from ``CATEGORY_PRIORITY``.

    Returns:
        The full updated entity dict via :meth:`get_entity`, or
        ``None`` if the UUID was not found.

    Raises:
        ValueError: If *category* is not a key of ``CATEGORY_PRIORITY``.
    """
    # -- Read current state -------------------------------------------
    read_q = (
        "MATCH (e) WHERE e.uuid = $uuid "
        "RETURN e.name, labels(e)[0], "
        "e.description, e.metadata"
    )
    result = await self._graph.query(read_q, params={"uuid": uuid})
    if not result.result_set:
        return None
    cur_name, entity_type, cur_desc, raw_meta = result.result_set[0][:4]
    cur_desc = cur_desc or ""

    cur_meta: dict[str, Any] = {}
    if isinstance(raw_meta, str):
        try:
            parsed = json.loads(raw_meta)
        except (json.JSONDecodeError, TypeError):
            parsed = None
        # Valid JSON that is not an object (list, null, scalar) cannot be
        # dict-merged; treat it like missing metadata instead of crashing.
        if isinstance(parsed, dict):
            cur_meta = parsed

    # -- Compute new values -------------------------------------------
    new_desc: str | None = None
    if description is not None:
        new_desc = description
    elif append_text:
        new_desc = (cur_desc + "\n" + append_text) if cur_desc else append_text

    new_meta: str | None = None
    if metadata_updates:
        merged = {**cur_meta, **metadata_updates}
        new_meta = json.dumps(merged, ensure_ascii=False)

    new_cat = category
    if new_cat and new_cat not in CATEGORY_PRIORITY:
        raise ValueError(f"Unknown category: {new_cat}")

    # -- Build SET clause dynamically ---------------------------------
    now = time.time()
    set_parts: list[str] = ["e.updated_at = $now"]
    params: dict[str, Any] = {"uuid": uuid, "now": now}
    if new_desc is not None:
        set_parts.append("e.description = $desc")
        params["desc"] = new_desc
        # Keep the vector in sync with the new text.
        embed_text = f"{cur_name}: {new_desc}" if new_desc else cur_name
        vec = await self._embed(embed_text)
        set_parts.append("e.embedding = vecf32($vec)")
        params["vec"] = vec
    if pinned is not None:
        set_parts.append("e.pinned = $pinned")
        params["pinned"] = pinned
    if new_cat:
        set_parts.append("e.category = $cat")
        params["cat"] = new_cat
        set_parts.append("e.priority = $pri")
        params["pri"] = CATEGORY_PRIORITY[new_cat]
    if new_meta is not None:
        set_parts.append("e.metadata = $meta")
        params["meta"] = new_meta

    set_clause = ", ".join(set_parts)
    write_q = (
        f"MATCH (e:{entity_type}) "
        f"WHERE e.uuid = $uuid "
        f"SET {set_clause} "
        f"RETURN e.uuid"
    )
    await self._graph.query(write_q, params=params)
    return await self.get_entity(uuid=uuid)
[docs]
async def delete_entity(
    self,
    name: str,
    entity_type: str,
    category: str,
    scope_id: str = "_",
) -> bool:
    """Detach-delete the entity identified by name, category and scope.

    Args:
        name (str): Human-readable name (matched lower-cased and stripped).
        entity_type (str): Node label of the entity.
        category (str): Category the entity belongs to.
        scope_id (str): Scope the entity belongs to.

    Returns:
        bool: ``True`` if the reported deletion count is positive,
        ``False`` otherwise.
    """
    cypher = (
        f"MATCH (e:{entity_type} {{name: $name, scope_id: $sid, category: $cat}}) "
        "DETACH DELETE e "
        "RETURN count(e) AS deleted"
    )
    outcome = await self._graph.query(
        cypher,
        params={
            "name": name.strip().lower(),
            "sid": scope_id,
            "cat": category,
        },
    )
    rows = outcome.result_set
    return bool(rows) and rows[0][0] > 0
[docs]
async def delete_entity_by_uuid(
    self,
    uuid: str,
) -> bool:
    """Detach-delete the node whose ``uuid`` property equals *uuid*.

    All relationships touching the node are removed with it. Returns
    ``True`` when the reported deletion count is positive.
    """
    outcome = await self._graph.query(
        "MATCH (e {uuid: $uuid}) DETACH DELETE e RETURN count(e) AS deleted",
        params={"uuid": uuid},
    )
    rows = outcome.result_set
    if not rows:
        return False
    return rows[0][0] > 0
[docs]
async def pin_entity(
    self,
    name: str,
    entity_type: str,
    pinned: bool = True,
    category: str | None = None,
    scope_id: str | None = None,
) -> bool:
    """Set or clear the ``pinned`` flag on an entity.

    The match is on the lower-cased *name* under *entity_type*.
    *category* and *scope_id* narrow the match whenever they are not
    ``None``, so a name that exists in several scopes is not pinned
    everywhere by accident. An empty *scope_id* is mapped to
    ``_NO_SCOPE``.

    Returns:
        bool: ``True`` if at least one node was updated.
    """
    params: dict[str, Any] = {
        "name": name.strip().lower(),
        "pinned": pinned,
        "now": time.time(),
    }
    filters = ["e.name = $name"]
    if category is not None:
        filters.append("e.category = $cat")
        params["cat"] = category
    if scope_id is not None:
        filters.append("e.scope_id = $sid")
        # Empty scope strings map to the shared no-scope marker.
        params["sid"] = scope_id or _NO_SCOPE
    cypher = (
        f"MATCH (e:{entity_type}) "
        f"WHERE {' AND '.join(filters)} "
        "SET e.pinned = $pinned, "
        "e.updated_at = $now "
        "RETURN e.name"
    )
    outcome = await self._graph.query(cypher, params=params)
    return len(outcome.result_set) > 0
[docs]
async def get_entity(
    self,
    name: str = "",
    entity_type: str | None = None,
    category: str | None = None,
    scope_id: str | None = None,
    uuid: str | None = None,
) -> dict | None:
    """Fetch one entity together with its direct connections.

    The lookup uses *uuid* when it is given, otherwise the lower-cased
    *name*. *entity_type* restricts the label. *category* and
    *scope_id* add equality filters when truthy.

    Returns:
        A dict of the entity's properties, with ``metadata`` parsed
        from JSON (``{}`` on failure), ``user_id`` defaulted to
        ``_SENTINEL_USER_ID``, and ``connections`` listing neighbor
        edges in either direction. Returns ``None`` when nothing
        matches.
    """
    conditions: list[str] = []
    params: dict[str, Any] = {}
    if uuid:
        conditions.append("e.uuid = $uuid")
        params["uuid"] = uuid
    else:
        conditions.append("e.name = $name")
        params["name"] = name.strip().lower()
    if category:
        conditions.append("e.category = $cat")
        params["cat"] = category
    if scope_id:
        conditions.append("e.scope_id = $sid")
        params["sid"] = scope_id

    label = f":{entity_type}" if entity_type else ""
    cypher = (
        f"MATCH (e{label}) WHERE {' AND '.join(conditions)} "
        "OPTIONAL MATCH (e)-[r]-(neighbor) "
        "RETURN e.name AS name, "
        "labels(e)[0] AS type, "
        "e.description AS description, "
        "e.category AS category, "
        "e.priority AS priority, "
        "e.scope_id AS scope_id, "
        "e.mention_count AS mention_count, "
        "e.created_at AS created_at, "
        "e.updated_at AS updated_at, "
        "e.created_by AS created_by, "
        "e.uuid AS uuid, "
        "e.metadata AS metadata, "
        "e.user_id AS user_id, "
        "collect(DISTINCT {rel: type(r), "
        "target: neighbor.name, "
        "target_uuid: neighbor.uuid, "
        "target_type: labels(neighbor)[0], "
        "target_category: neighbor.category, "
        "weight: r.weight, "
        "description: r.description, "
        "priority: r.priority}) AS connections "
        "LIMIT 1"
    )
    outcome = await self._graph.query(cypher, params=params)
    if not outcome.result_set:
        return None
    record = outcome.result_set[0]

    stored_meta = record[11] or "{}"
    try:
        if isinstance(stored_meta, str):
            parsed_meta = json.loads(stored_meta)
        else:
            parsed_meta = {}
    except (json.JSONDecodeError, TypeError):
        parsed_meta = {}

    columns = (
        "name",
        "type",
        "description",
        "category",
        "priority",
        "scope_id",
        "mention_count",
        "created_at",
        "updated_at",
        "created_by",
        "uuid",
    )
    entity = dict(zip(columns, record[:11]))
    entity["metadata"] = parsed_meta
    entity["user_id"] = record[12] or _SENTINEL_USER_ID
    # OPTIONAL MATCH with no neighbors yields a map of nulls; drop it.
    entity["connections"] = [
        edge for edge in (record[13] or []) if edge.get("rel") is not None
    ]
    return entity
[docs]
async def inspect_entity(
    self,
    name: str = "",
    uuid: str | None = None,
    max_depth: int = 2,
    neighbor_limit: int = 50,
) -> dict | None:
    """Deep inspection of an entity and its neighborhood.

    The entity is looked up by *uuid* when given, otherwise by its
    lower-cased *name* (first match, any label). Returns ``None`` when
    neither is supplied or nothing matches.

    The result contains:

    * ``entity``: the node's properties (``metadata`` parsed from JSON,
      ``pinned`` coerced to ``bool``);
    * ``outgoing`` / ``incoming``: direct relationships, each list
      capped at *neighbor_limit* and ordered by edge weight descending;
    * ``second_hop``: neighbors of neighbors, fetched only when
      ``max_depth >= 2``. Depth is capped at two hops, so values above
      2 behave like 2 and values below 2 skip this query. The query is
      best-effort: it runs with a 100 s timeout and any failure leaves
      the list empty;
    * ``summary``: the three list lengths.
    """
    params: dict[str, Any] = {}
    if uuid:
        where = "e.uuid = $uuid"
        params["uuid"] = uuid
    elif name:
        where = "e.name = $name"
        params["name"] = name.strip().lower()
    else:
        return None

    entity_q = (
        f"MATCH (e) WHERE {where} "
        "RETURN e.name, labels(e)[0], e.description, "
        "e.category, e.priority, e.scope_id, "
        "e.mention_count, e.created_at, e.updated_at, "
        "e.created_by, e.uuid, e.metadata, e.user_id, "
        "e.pinned "
        "LIMIT 1"
    )
    result = await self._graph.query(entity_q, params=params)
    if not result.result_set:
        return None
    row = result.result_set[0]
    raw_meta = row[11] or "{}"
    try:
        meta = json.loads(raw_meta) if isinstance(raw_meta, str) else {}
    except (json.JSONDecodeError, TypeError):
        meta = {}
    entity = {
        "name": row[0],
        "type": row[1],
        "description": row[2],
        "category": row[3],
        "priority": row[4],
        "scope_id": row[5],
        "mention_count": row[6],
        "created_at": row[7],
        "updated_at": row[8],
        "created_by": row[9],
        "uuid": row[10],
        "metadata": meta,
        "user_id": row[12] or _SENTINEL_USER_ID,
        "pinned": bool(row[13]),
    }
    entity_uuid = entity["uuid"]

    out_q = (
        "MATCH (e {uuid: $uuid})-[r]->(t) "
        "RETURN type(r) AS rel, r.weight AS weight, "
        "r.description AS rdesc, r.priority AS rpri, "
        "t.name AS tname, labels(t)[0] AS ttype, "
        "t.uuid AS tuuid, t.category AS tcat, "
        "t.description AS tdesc "
        "ORDER BY r.weight DESC "
        "LIMIT $lim"
    )
    out_result = await self._graph.query(
        out_q,
        params={"uuid": entity_uuid, "lim": neighbor_limit},
    )
    in_q = (
        "MATCH (s)-[r]->(e {uuid: $uuid}) "
        "RETURN type(r) AS rel, r.weight AS weight, "
        "r.description AS rdesc, r.priority AS rpri, "
        "s.name AS sname, labels(s)[0] AS stype, "
        "s.uuid AS suuid, s.category AS scat, "
        "s.description AS sdesc "
        "ORDER BY r.weight DESC "
        "LIMIT $lim"
    )
    in_result = await self._graph.query(
        in_q,
        params={"uuid": entity_uuid, "lim": neighbor_limit},
    )

    def _edges(rows: Any, side: str) -> list[dict]:
        # Both directional queries return the same 9 columns; only the
        # key prefix for the far endpoint ("target"/"source") differs.
        return [
            {
                "relation": r[0],
                "weight": r[1],
                "rel_description": r[2],
                "rel_priority": r[3],
                f"{side}_name": r[4],
                f"{side}_type": r[5],
                f"{side}_uuid": r[6],
                f"{side}_category": r[7],
                f"{side}_description": r[8],
            }
            for r in (rows or [])
        ]

    outgoing = _edges(out_result.result_set, "target")
    incoming = _edges(in_result.result_set, "source")

    # Optional: 2-hop neighbors (neighbors of neighbors), direction-agnostic.
    second_hop: list[dict] = []
    if max_depth >= 2:
        hop2_q = (
            "MATCH (e {uuid: $uuid})-[r1]-(n1)-[r2]-(n2) "
            "WHERE n2.uuid <> $uuid "
            "RETURN DISTINCT n2.name AS name, "
            "labels(n2)[0] AS type, "
            "n2.uuid AS uuid, n2.category AS cat, "
            "type(r2) AS via_rel, "
            "n1.name AS via_node "
            "LIMIT $lim"
        )
        try:
            hop2 = await self._graph.query(
                hop2_q,
                params={
                    "uuid": entity_uuid,
                    "lim": neighbor_limit,
                },
                timeout=100_000,
            )
            second_hop = [
                {
                    "name": r[0],
                    "type": r[1],
                    "uuid": r[2],
                    "category": r[3],
                    "via_relation": r[4],
                    "via_node": r[5],
                }
                for r in (hop2.result_set or [])
            ]
        except Exception:
            # Best-effort enrichment: failure (timeout or otherwise)
            # must not hide the 1-hop results already gathered.
            logger.debug(
                "inspect_entity: 2-hop query failed (timeout or query error)",
                exc_info=True,
            )

    return {
        "entity": entity,
        "outgoing": outgoing,
        "incoming": incoming,
        "second_hop": second_hop,
        "summary": {
            "outgoing_count": len(outgoing),
            "incoming_count": len(incoming),
            "second_hop_count": len(second_hop),
        },
    }
[docs]
async def list_entities(
    self,
    entity_type: str | None = None,
    category: str | None = None,
    scope_id: str | None = None,
    limit: int = 50,
    offset: int = 0,
    search: str | None = None,
) -> list[dict]:
    """Page through entities, most recently updated first.

    Args:
        entity_type: Restrict to this node label when given.
        category: Exact ``category`` filter (applied when truthy).
        scope_id: Exact ``scope_id`` filter (applied when truthy).
        limit: Page size, clamped to ``1..500``.
        offset: Rows to skip, clamped to ``>= 0``.
        search: Case-insensitive substring matched against the name
            or the description. Blank strings are ignored.

    Returns:
        One dict per entity. ``metadata`` is parsed from JSON (``{}``
        on failure), ``pinned`` is coerced to ``bool``, and
        ``user_id`` falls back to ``_SENTINEL_USER_ID``.
    """
    params: dict[str, Any] = {
        "lim": max(1, min(int(limit), 500)),
        "off": max(0, int(offset)),
    }
    filters: list[str] = []
    if category:
        filters.append("e.category = $cat")
        params["cat"] = category
    if scope_id:
        filters.append("e.scope_id = $sid")
        params["sid"] = scope_id
    needle = search.strip() if search else ""
    if needle:
        filters.append(
            "(toLower(e.name) CONTAINS toLower($q) "
            "OR toLower(e.description) CONTAINS toLower($q))"
        )
        params["q"] = needle

    label = f":{entity_type}" if entity_type else ""
    where = f" WHERE {' AND '.join(filters)}" if filters else ""
    cypher = (
        f"MATCH (e{label}){where} "
        "RETURN e.name AS name, "
        "labels(e)[0] AS type, "
        "e.description AS description, "
        "e.category AS category, "
        "e.priority AS priority, "
        "e.scope_id AS scope_id, "
        "e.mention_count AS mention_count, "
        "e.pinned AS pinned, "
        "e.updated_at AS updated_at, "
        "e.uuid AS uuid, "
        "e.metadata AS metadata, "
        "e.user_id AS user_id "
        "ORDER BY e.updated_at DESC "
        "SKIP $off "
        "LIMIT $lim"
    )
    outcome = await self._graph.query(cypher, params=params)

    def _decode_meta(stored: Any) -> Any:
        stored = stored or "{}"
        if not isinstance(stored, str):
            return {}
        try:
            return json.loads(stored)
        except (json.JSONDecodeError, TypeError):
            return {}

    return [
        {
            "name": rec[0],
            "type": rec[1],
            "description": rec[2],
            "category": rec[3],
            "priority": rec[4],
            "scope_id": rec[5],
            "mention_count": rec[6],
            "pinned": bool(rec[7]),
            "updated_at": rec[8],
            "uuid": rec[9],
            "metadata": _decode_meta(rec[10]),
            "user_id": rec[11] or _SENTINEL_USER_ID,
        }
        for rec in outcome.result_set
    ]
# ------------------------------------------------------------------
# Relationship CRUD
# ------------------------------------------------------------------
[docs]
async def add_relationship(
    self,
    source_uuid: str,
    target_uuid: str,
    relation_type: str,
    weight: float = 0.5,
    description: str = "",
    evidence: str = "",
) -> bool:
    """Create or reinforce a relationship between two entities identified by UUID.

    *relation_type* is normalised (stripped, upper-cased, spaces turned
    into ``_``). It must then be a plain identifier (letters, digits and
    underscores, not starting with a digit). Relationship types cannot
    be passed as Cypher parameters, so the value is spliced into the
    query text and must not be able to alter the query.

    On creation the edge gets *weight*, *description*, *evidence*, the
    endpoint UUIDs and timestamps. It inherits ``priority``,
    ``category`` and ``scope_id`` from the endpoint with the lower
    ``priority`` value (ties go to the target). Cross-category edges
    are fully supported.

    If the edge already exists, its weight is bumped by 0.1 (capped at
    1.0) and ``updated_at`` is refreshed. The passed *weight*,
    *description* and *evidence* are then ignored.

    Returns:
        bool: ``True`` if the edge was created or reinforced, which
        requires both endpoints to exist.

    Raises:
        ValueError: If *relation_type* is empty or not a valid identifier.
    """
    relation_type = relation_type.strip().upper().replace(" ", "_")
    if not relation_type:
        raise ValueError("Relation type cannot be empty")
    # Reject anything that is not a bare identifier: brackets, quotes,
    # dashes, etc. would yield a syntax error or a Cypher injection.
    if not re.fullmatch(r"[^\W\d]\w*", relation_type):
        raise ValueError(f"Invalid relation type: {relation_type!r}")
    now = time.time()
    q = (
        f"MATCH (a {{uuid: $src_uuid}}) "
        f"MATCH (b {{uuid: $tgt_uuid}}) "
        f"MERGE (a)-[r:{relation_type}]->(b) "
        f"ON CREATE SET "
        f"r.src_uuid = $src_uuid, "
        f"r.tgt_uuid = $tgt_uuid, "
        f"r.weight = $w, r.description = $desc, "
        f"r.priority = CASE "
        f"WHEN a.priority < b.priority "
        f"THEN a.priority "
        f"ELSE b.priority END, "
        f"r.category = CASE "
        f"WHEN a.priority < b.priority "
        f"THEN a.category "
        f"ELSE b.category END, "
        f"r.scope_id = CASE "
        f"WHEN a.priority < b.priority "
        f"THEN a.scope_id "
        f"ELSE b.scope_id END, "
        f"r.evidence = $ev, "
        f"r.created_at = $now, "
        f"r.updated_at = $now "
        f"ON MATCH SET "
        f"r.weight = CASE "
        f"WHEN r.weight + 0.1 > 1.0 "
        f"THEN 1.0 "
        f"ELSE r.weight + 0.1 END, "
        f"r.updated_at = $now "
        f"RETURN type(r)"
    )
    params = {
        "src_uuid": source_uuid,
        "tgt_uuid": target_uuid,
        "w": weight,
        "desc": description,
        "ev": evidence,
        "now": now,
    }
    result = await self._graph.query(q, params=params)
    return len(result.result_set) > 0
[docs]
async def delete_relationship(
    self,
    source_uuid: str,
    target_uuid: str,
    relation_type: str,
) -> bool:
    """Delete the directed relationship ``source -[relation_type]-> target``.

    *relation_type* is normalised exactly like in
    :meth:`add_relationship` (stripped, upper-cased, spaces turned into
    ``_``), so the same spelling used to create an edge removes it. It
    must then be a plain identifier, because it is spliced into the
    query text.

    Args:
        source_uuid (str): UUID of the source node.
        target_uuid (str): UUID of the target node.
        relation_type (str): Relationship type to remove.

    Returns:
        bool: ``True`` if at least one edge was deleted, ``False`` otherwise.

    Raises:
        ValueError: If *relation_type* is empty or not a valid identifier.
    """
    relation_type = relation_type.strip().upper().replace(" ", "_")
    if not re.fullmatch(r"[^\W\d]\w*", relation_type):
        raise ValueError(f"Invalid relation type: {relation_type!r}")
    q = (
        f"MATCH (a {{uuid: $src_uuid}})"
        f"-[r:{relation_type}]->"
        f"(b {{uuid: $tgt_uuid}}) "
        f"DELETE r RETURN count(r) AS deleted"
    )
    params = {
        "src_uuid": source_uuid,
        "tgt_uuid": target_uuid,
    }
    result = await self._graph.query(q, params=params)
    if result.result_set:
        return result.result_set[0][0] > 0
    return False
[docs]
async def list_relationships(
    self,
    entity_uuid: str | None = None,
    relation_type: str | None = None,
    category: str | None = None,
    limit: int = 50,
    order_by: bool = True,
    timeout: int | None = None,
) -> list[dict]:
    """List directed relationships with optional filters.

    Args:
        entity_uuid (str | None): Keep only edges that touch this node
            (as source or target).
        relation_type (str | None): Keep only this relationship type.
            It is normalised like in :meth:`add_relationship` (stripped,
            upper-cased, spaces turned into ``_``). A blank value means
            no filter.
        category (str | None): Keep only edges with this ``category``.
        limit (int): Maximum number of items.
        order_by (bool): Sort by updated_at DESC. Disable for
            large-graph visualization queries where the sort
            dominates query time.
        timeout (int | None): Per-query timeout in ms. When
            ``None`` the server default is used.

    Returns:
        list[dict]: One dict per edge with source/target names, UUIDs
        and labels, plus the edge's type, weight, description,
        category and priority.

    Raises:
        ValueError: If *relation_type* is not a valid identifier. It is
            spliced into the query text, so it must not be able to
            alter the query.
    """
    where_parts: list[str] = []
    params: dict[str, Any] = {"lim": limit}
    if entity_uuid:
        where_parts.append("(a.uuid = $euuid OR b.uuid = $euuid)")
        params["euuid"] = entity_uuid
    if category:
        where_parts.append("r.category = $cat")
        params["cat"] = category

    rel_pattern = ""
    normalized = (relation_type or "").strip().upper().replace(" ", "_")
    if normalized:
        if not re.fullmatch(r"[^\W\d]\w*", normalized):
            raise ValueError(f"Invalid relation type: {relation_type!r}")
        rel_pattern = f":{normalized}"

    where = (" WHERE " + " AND ".join(where_parts)) if where_parts else ""
    order_clause = "ORDER BY r.updated_at DESC " if order_by else ""
    q = (
        f"MATCH (a)-[r{rel_pattern}]->(b){where} "
        f"RETURN a.name AS source, "
        f"a.uuid AS source_uuid, "
        f"labels(a)[0] AS source_type, "
        f"type(r) AS relation, "
        f"b.name AS target, "
        f"b.uuid AS target_uuid, "
        f"labels(b)[0] AS target_type, "
        f"r.weight AS weight, "
        f"r.description AS description, "
        f"r.category AS category, "
        f"r.priority AS priority "
        f"{order_clause}"
        f"LIMIT $lim"
    )
    query_kwargs: dict[str, Any] = {"params": params}
    if timeout is not None:
        query_kwargs["timeout"] = timeout
    result = await self._graph.query(q, **query_kwargs)
    return [
        {
            "source": row[0],
            "source_uuid": row[1],
            "source_type": row[2],
            "relation": row[3],
            "target": row[4],
            "target_uuid": row[5],
            "target_type": row[6],
            "weight": row[7],
            "description": row[8],
            "category": row[9],
            "priority": row[10],
        }
        for row in result.result_set
    ]
# ------------------------------------------------------------------
# Entity resolution
# ------------------------------------------------------------------
async def _resolve_or_create(
    self,
    name: str,
    entity_type: str,
    category: str,
    scope_id: str,
    description: str = "",
    created_by: str = "unknown",
    user_id: str = _SENTINEL_USER_ID,
    embedding: list[float] | None = None,
) -> dict[str, str]:
    """Resolve an existing entity or create a new one.

    Returns ``{"name": ..., "uuid": ...}``.

    Resolution strategy (always within the same label, category and
    scope; an empty *scope_id* is mapped to ``_NO_SCOPE``):

    1. Case-insensitive exact name match.
    2. Vector search over the top 5 candidates, keeping the best one
       whose score exceeds ``self._dedup_threshold``. A failure here is
       logged and falls through to step 3.
    3. Create a new entity via :meth:`add_entity`, reusing the
       embedding computed for step 2.

    A match in step 1 or 2 is reinforced by UUID, so only the resolved
    node is touched and same-named nodes in other categories or scopes
    are left alone.
    """
    name_lower = name.strip().lower()
    sid = scope_id or _NO_SCOPE

    # Strategy 1: exact name match
    q = (
        f"MATCH (e:{entity_type} "
        f"{{category: $cat, scope_id: $sid}}) "
        f"WHERE toLower(e.name) = $name "
        f"RETURN e.name, e.uuid LIMIT 1"
    )
    params = {
        "cat": category,
        "sid": sid,
        "name": name_lower,
    }
    result = await self._graph.query(q, params=params)
    if result.result_set:
        existing = result.result_set[0][0]
        existing_uuid = result.result_set[0][1]
        logger.debug(
            "Resolved entity %r to existing node %s via exact name match",
            name_lower,
            existing_uuid,
        )
        await self._reinforce_entity(
            existing,
            entity_type,
            description,
            uuid=existing_uuid,
        )
        return {
            "name": existing,
            "uuid": existing_uuid,
        }

    # Strategy 2: vector similarity
    if embedding is not None:
        vec = embedding
    else:
        embed_text = f"{name}: {description}" if description else name
        vec = await self._embed(embed_text)
    try:
        # NOTE(review): this treats a higher score as a closer match;
        # confirm that FalkorDB's queryNodes score is a similarity (not a
        # distance) for this index's metric.
        vec_q = (
            f"CALL db.idx.vector.queryNodes("
            f"'{entity_type}', 'embedding', "
            f"5, vecf32($vec)) "
            f"YIELD node, score "
            f"WHERE node.category = $cat "
            f"AND node.scope_id = $sid "
            f"AND score > $threshold "
            f"RETURN node.name, node.uuid, score "
            f"ORDER BY score DESC LIMIT 1"
        )
        vec_params = {
            "vec": vec,
            "cat": category,
            "sid": sid,
            "threshold": self._dedup_threshold,
        }
        vec_result = await self._graph.query(vec_q, params=vec_params)
        if vec_result.result_set:
            existing = vec_result.result_set[0][0]
            existing_uuid = vec_result.result_set[0][1]
            logger.debug(
                "Resolved entity %r to existing node %s via vector similarity (score=%.3f)",
                name_lower,
                existing_uuid,
                vec_result.result_set[0][2],
            )
            await self._reinforce_entity(
                existing,
                entity_type,
                description,
                uuid=existing_uuid,
            )
            return {
                "name": existing,
                "uuid": existing_uuid,
            }
    except Exception:
        logger.debug(
            "Vector resolution failed, creating new entity",
            exc_info=True,
        )

    # Strategy 3: create new (reuse vector to avoid re-embedding)
    logger.debug("No match found for entity %r; creating new node", name_lower)
    info = await self.add_entity(
        name,
        entity_type,
        description,
        category=category,
        scope_id=sid,
        created_by=created_by,
        user_id=user_id,
        embedding=vec,
    )
    return info

async def _reinforce_entity(
    self,
    name: str,
    entity_type: str,
    new_description: str = "",
    uuid: str | None = None,
) -> None:
    """Increment ``mention_count`` and optionally merge descriptions.

    The node is matched by *uuid* when given, otherwise by *name* under
    *entity_type*. A name-only match touches every same-named node of
    that label, across all categories and scopes. ``updated_at`` is
    refreshed. When *new_description* is non-empty, the longer of the
    stored and the new description is kept.

    NOTE(review): the description may change here without re-embedding,
    so the stored vector can drift from the text.
    """
    now = time.time()
    if uuid:
        match_clause = f"MATCH (e:{entity_type} {{uuid: $uuid}}) "
        params: dict[str, Any] = {"uuid": uuid, "now": now}
    else:
        match_clause = f"MATCH (e:{entity_type} {{name: $name}}) "
        params = {"name": name, "now": now}
    set_clause = "SET e.mention_count = e.mention_count + 1, e.updated_at = $now"
    if new_description:
        set_clause += (
            ", e.description = CASE "
            "WHEN size(e.description) > size($desc) "
            "THEN e.description "
            "ELSE $desc END"
        )
        params["desc"] = new_description
    await self._graph.query(
        match_clause + set_clause + " RETURN e.name",
        params=params,
    )
[docs]
async def resolve_entity_cross_category(
    self,
    name: str,
    entity_type: str,
) -> dict | None:
    """Look an entity up by name regardless of its category.

    Used for cross-category linking. If several nodes of
    *entity_type* share the (case-insensitive) name, the one with the
    highest ``priority`` value is returned.

    Returns:
        ``{"name", "category", "scope_id", "priority", "uuid"}`` or
        ``None`` when no node matches.
    """
    cypher = (
        f"MATCH (e:{entity_type}) "
        "WHERE toLower(e.name) = $name "
        "RETURN e.name, e.category, "
        "e.scope_id, e.priority, e.uuid "
        "ORDER BY e.priority DESC LIMIT 1"
    )
    outcome = await self._graph.query(
        cypher,
        params={"name": name.strip().lower()},
    )
    if not outcome.result_set:
        return None
    fields = ("name", "category", "scope_id", "priority", "uuid")
    return dict(zip(fields, outcome.result_set[0]))
# ------------------------------------------------------------------
# Retrieval: hybrid vector + graph traversal
# ------------------------------------------------------------------
[docs]
async def retrieve_context(
    self,
    query: str,
    query_embedding: list[float] | np.ndarray | None = None,
    user_ids: list[str] | None = None,
    channel_id: str | None = None,
    guild_id: str | None = None,
    max_hops: int = 2,
    max_per_user: int = 60,
    max_channel: int = 15,
    max_guild: int = 15,
    max_general: int = 30,
    max_per_lore: int = 20,
    seed_top_k: int = 64,
    seed_similarity_threshold: float = 0.38,
    seed_limit: int = 15,
    min_edge_weight: float = 0.0,
    default_edge_weight: float = 0.8,
    semantic_hop_decay: float = 0.8,
    expansion_neighbor_limit: int = 500,
    dynamic_threshold_enabled: bool = True,
    dynamic_threshold_target_ratio: float = 0.10,
    dynamic_threshold_min: float = 0.20,
    dynamic_threshold_min_stored: int = 5,
    full_user_memory_ids: list[str] | None = None,
    user_seed_min: int = 10,
    user_candidate_limit: int = 100,
    lore_candidate_limit: int = 40,
    lore_seed_min: int = 5,
    lore_amplified: bool = False,
    max_per_meta: int = 20,
    meta_candidate_limit: int = 40,
    meta_seed_min: int = 5,
    meta_amplified: bool = False,
) -> dict[str, list[dict]]:
    """Run hybrid vector + graph retrieval for *query*.

    Thin wrapper around ``retrieval.run_retrieve_context`` (imported
    lazily). All arguments are forwarded unchanged; their semantics
    are defined there.

    NOTE(review): the first 28 arguments after ``self`` are passed
    POSITIONALLY, so their order here must stay in lock-step with
    ``run_retrieve_context``'s signature. Only the lore/meta options
    at the end are passed by keyword.

    Returns:
        A mapping of result-group name to a list of entity dicts, as
        produced by ``run_retrieve_context``.
    """
    from .retrieval import run_retrieve_context
    return await run_retrieve_context(
        self,
        query,
        query_embedding,
        user_ids,
        channel_id,
        guild_id,
        max_hops,
        max_per_user,
        max_channel,
        max_guild,
        max_general,
        max_per_lore,
        seed_top_k,
        seed_similarity_threshold,
        seed_limit,
        min_edge_weight,
        default_edge_weight,
        semantic_hop_decay,
        expansion_neighbor_limit,
        dynamic_threshold_enabled,
        dynamic_threshold_target_ratio,
        dynamic_threshold_min,
        dynamic_threshold_min_stored,
        full_user_memory_ids,
        user_seed_min,
        user_candidate_limit,
        lore_candidate_limit,
        lore_seed_min,
        lore_amplified=lore_amplified,
        max_per_meta=max_per_meta,
        meta_candidate_limit=meta_candidate_limit,
        meta_seed_min=meta_seed_min,
        meta_amplified=meta_amplified,
    )
async def _fetch_core_knowledge(self) -> list[dict]:
    """Delegate to ``run_fetch_core_knowledge`` in the sibling ``retrieval`` module."""
    from .retrieval import run_fetch_core_knowledge as _impl
    entries = await _impl(self)
    return entries
async def _fetch_basic_knowledge(
    self,
    max_entities: int = 50,
) -> list[dict]:
    """Delegate to ``run_fetch_basic_knowledge`` in the ``retrieval`` module.

    *max_entities* is forwarded positionally.
    """
    from .retrieval import run_fetch_basic_knowledge as _impl
    entries = await _impl(self, max_entities)
    return entries
async def _fetch_pinned_entities(
    self,
    allowed: set[tuple[str, str]],
) -> list[dict]:
    """Delegate to ``run_fetch_pinned_entities`` in the ``retrieval`` module.

    *allowed* is forwarded unchanged. Presumably it holds
    ``(category, scope_id)`` pairs; confirm this in ``retrieval``.
    """
    from .retrieval import run_fetch_pinned_entities as _impl
    entries = await _impl(self, allowed)
    return entries
async def _seed_vector_search(
    self,
    q_vec: list[float],
    allowed: set[tuple[str, str]],
    top_k: int,
) -> list[dict]:
    """Delegate to ``run_seed_vector_search`` in the ``retrieval`` module.

    Arguments are forwarded positionally in the order
    ``(q_vec, allowed, top_k)``.
    """
    from .retrieval import run_seed_vector_search as _impl
    seeds = await _impl(self, q_vec, allowed, top_k)
    return seeds
async def _expand_graph(
    self,
    seed_uuids: list[str],
    max_hops: int,
    min_edge_weight: float = 0.0,
    default_edge_weight: float = 0.8,
    neighbor_limit: int = 500,
) -> list[dict]:
    """Delegate graph expansion from *seed_uuids* to ``retrieval.run_expand_graph``.

    *seed_uuids* and *max_hops* are forwarded positionally. The tuning
    knobs are forwarded by keyword.
    """
    from .retrieval import run_expand_graph as _expand
    tuning = {
        "min_edge_weight": min_edge_weight,
        "default_edge_weight": default_edge_weight,
        "neighbor_limit": neighbor_limit,
    }
    return await _expand(self, seed_uuids, max_hops, **tuning)
def _deconflict(
    self,
    entities: list[dict],
) -> list[dict]:
    """Delegate to the synchronous ``retrieval.run_deconflict`` helper."""
    from .retrieval import run_deconflict as _impl
    resolved = _impl(entities)
    return resolved
[docs]
async def get_core_knowledge(self) -> list[dict]:
    """Return the result of ``run_get_core_knowledge`` from the ``retrieval`` module."""
    from .retrieval import run_get_core_knowledge as _impl
    entries = await _impl(self)
    return entries
[docs]
async def search_entities(
    self,
    query: str,
    query_embedding: list[float] | None = None,
    category: str | None = None,
    scope_id: str | None = None,
    entity_type: str | None = None,
    top_k: int = 10,
) -> list[dict]:
    """Search entities via ``retrieval.run_search_entities``.

    All arguments are forwarded positionally in declaration order, so
    this signature must stay aligned with ``run_search_entities``.
    """
    from .retrieval import run_search_entities as _search
    forwarded = (query, query_embedding, category, scope_id, entity_type, top_k)
    return await _search(self, *forwarded)
@staticmethod
def _coerce_embedding_list(raw: Any) -> list[float] | None:
    """Normalise a stored embedding into a plain ``list[float]``.

    * ``list``: accepted only if every element is an ``int`` or
      ``float`` (``bool`` counts, being an ``int`` subclass). Each
      element is converted with ``float``. An empty list gives ``[]``.
    * ``numpy.ndarray``: cast to float64, flattened to 1-D and
      converted to Python floats.
    * Anything else, including ``None`` and tuples: ``None``.
    """
    if isinstance(raw, np.ndarray):
        return [float(v) for v in raw.astype(np.float64).flatten().tolist()]
    if not isinstance(raw, list):
        return None
    if not all(isinstance(v, (int, float)) for v in raw):
        return None
    return [float(v) for v in raw]
[docs]
async def reconsolidate_embeddings_on_recall(
    self,
    targets: list[dict[str, str]],
    query_embedding: list[float] | np.ndarray,
    *,
    learning_rate: float = 0.03,
    max_step: float = 0.25,
) -> None:
    """Nudge recalled entities' embeddings toward the query with a tanh clamp.

    For each target (a dict with ``"uuid"`` and ``"type"``, where
    ``type`` must be in ``ENTITY_LABELS``):

    1. The stored embedding is unit-normalised.
    2. It is moved by ``learning_rate * (q - emb)``, where ``q`` is the
       unit-normalised query.
    3. The step length is soft-capped at *max_step* via
       ``max_step * tanh(|step| / max_step)``.
    4. The result is re-normalised and written back with a fresh
       ``updated_at``.

    *query_embedding* may be a list or a NumPy array of any shape; it
    is flattened. Targets whose stored embedding is missing,
    non-numeric or of a different dimension are skipped. Nothing is
    done for empty input or a non-positive *max_step*.

    Fire-and-forget from prompt build; per-entity read/write failures
    are logged at debug.
    """
    # ``not query_embedding`` would raise for a multi-element ndarray, so
    # test for None explicitly and let the size check catch empty input.
    if not targets or query_embedding is None:
        return
    q = np.asarray(query_embedding, dtype=np.float64).flatten()
    if q.size == 0:
        return
    ms = float(max_step)
    if ms <= 0.0:
        # The tanh clamp divides by max_step; a non-positive cap means no step.
        return
    lr = float(learning_rate)
    qn = float(np.linalg.norm(q))
    if qn > 1e-12:
        q = q / qn
    dim = int(q.size)
    now = time.time()
    for t in targets:
        uuid = t.get("uuid")
        label = t.get("type")
        # The label is spliced into Cypher, so only known labels are allowed.
        if not uuid or not label or label not in ENTITY_LABELS:
            continue
        read_q = f"MATCH (e:{label} {{uuid: $uuid}}) RETURN e.embedding AS emb"
        try:
            res = await self._graph.query(read_q, params={"uuid": uuid})
        except Exception:
            logger.debug(
                "reconsolidate read failed for %s",
                uuid,
                exc_info=True,
            )
            continue
        if not res.result_set or res.result_set[0][0] is None:
            continue
        emb_list = self._coerce_embedding_list(res.result_set[0][0])
        if not emb_list or len(emb_list) != dim:
            continue
        emb = np.asarray(emb_list, dtype=np.float64)
        en = float(np.linalg.norm(emb))
        if en > 1e-12:
            emb = emb / en
        delta = lr * (q - emb)
        step_norm = float(np.linalg.norm(delta))
        if step_norm < 1e-12:
            continue
        allowed = ms * float(np.tanh(step_norm / ms))
        delta = delta * (allowed / step_norm)
        new_emb = emb + delta
        nn = float(np.linalg.norm(new_emb))
        if nn > 1e-12:
            new_emb = new_emb / nn
        vec_lit = new_emb.astype(np.float32).tolist()
        write_q = (
            f"MATCH (e:{label} {{uuid: $uuid}}) "
            f"SET e.embedding = vecf32($vec), e.updated_at = $now"
        )
        try:
            await self._graph.query(
                write_q,
                params={"uuid": uuid, "vec": vec_lit, "now": now},
            )
        except Exception:
            logger.debug(
                "reconsolidate write failed for %s",
                uuid,
                exc_info=True,
            )
[docs]
async def get_graph_stats(self) -> dict:
    """Return high-level graph statistics.

    Each sub-query runs independently with a 300 000 ms timeout, so a
    failure or timeout in one (e.g. the per-relationship-type counts)
    does not prevent the other stats from being returned. Failed
    sections keep their zero/empty defaults and are logged.

    Returns:
        dict with ``node_count``, ``label_count``,
        ``relationship_count``, ``entities_by_label``,
        ``entities_by_category`` and ``relationships_by_type``. The
        last one is sorted by count descending and omits zero counts.
    """
    _TIMEOUT = 300_000
    stats: dict[str, Any] = {
        "node_count": 0,
        "label_count": 0,
        "relationship_count": 0,
        "entities_by_label": {},
        "entities_by_category": {},
        "relationships_by_type": {},
    }
    try:
        nr = await self._graph.query(
            "MATCH (n) "
            "RETURN count(n) AS node_count, "
            "count(DISTINCT labels(n)[0]) AS label_count",
            timeout=_TIMEOUT,
        )
        if nr.result_set:
            stats["node_count"] = nr.result_set[0][0]
            stats["label_count"] = nr.result_set[0][1]
    except Exception:
        logger.warning("get_graph_stats: node count query failed", exc_info=True)
    try:
        rr = await self._graph.query(
            "MATCH ()-[r]->() RETURN count(r) AS rel_count",
            timeout=_TIMEOUT,
        )
        if rr.result_set:
            stats["relationship_count"] = rr.result_set[0][0]
    except Exception:
        logger.warning("get_graph_stats: rel count query failed", exc_info=True)
    try:
        rl = await self._graph.query(
            "MATCH (n) "
            "RETURN labels(n)[0] AS label, count(n) AS cnt "
            "ORDER BY cnt DESC",
            timeout=_TIMEOUT,
        )
        stats["entities_by_label"] = {
            str(row[0]): int(row[1]) for row in (rl.result_set or [])
        }
    except Exception:
        logger.warning("get_graph_stats: per-label query failed", exc_info=True)
    try:
        rc = await self._graph.query(
            "MATCH (n) "
            "RETURN coalesce(n.category, 'unknown') AS category, "
            "count(n) AS cnt "
            "ORDER BY cnt DESC",
            timeout=_TIMEOUT,
        )
        stats["entities_by_category"] = {
            str(row[0]): int(row[1]) for row in (rc.result_set or [])
        }
    except Exception:
        logger.warning("get_graph_stats: per-category query failed", exc_info=True)
    try:
        types_result = await self._graph.query(
            "CALL db.relationshiptypes()",
            timeout=_TIMEOUT,
        )
        rel_types = [str(row[0]) for row in (types_result.result_set or [])]
        counts: dict[str, int] = {}
        # One typed count per relationship type rather than a single
        # grouped scan over every edge.
        for rt in rel_types:
            try:
                cr = await self._graph.query(
                    f"MATCH ()-[r:{rt}]->() RETURN count(r) AS cnt",
                    timeout=_TIMEOUT,
                )
                cnt = int(cr.result_set[0][0]) if cr.result_set else 0
                if cnt > 0:
                    counts[rt] = cnt
            except Exception:
                # Best-effort: skip this type, but leave a trace instead
                # of silently dropping it from the breakdown.
                logger.debug(
                    "get_graph_stats: count for relationship type %r failed",
                    rt,
                    exc_info=True,
                )
        stats["relationships_by_type"] = dict(
            sorted(counts.items(), key=lambda x: -x[1])
        )
    except Exception:
        logger.warning("get_graph_stats: per-rel-type query failed", exc_info=True)
    return stats