Source code for message_cache

"""Redis-backed message cache with per-channel sorted-set indexes and
RediSearch vector search.

Each message is stored as a Redis hash keyed by ``msg:{uuid}`` with
the embedding as a binary FLOAT32 blob.  A per-channel sorted set
(``channel_msgs:{platform}:{channel_id}``) provides O(log N) time-
ordered lookups by channel, while a RediSearch HNSW index
(``idx:messages``) enables semantic KNN similarity search.

All cached messages are represented as :class:`CachedMessage` objects
and are given a 90-day TTL in Redis.
"""

from __future__ import annotations

import jsonutil as json
import logging
import re
import time
import uuid as _uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any

import numpy as np
import redis.asyncio as aioredis
from redis.asyncio.sentinel import Sentinel
from redis.commands.search.query import Query

from classifiers.redis_vector_index import DEFAULT_KNN_EF_RUNTIME
from init_redis_indexes import MSG_INDEX_NAME, VECTOR_DIM
from observability import observability
from openrouter_client import OpenRouterClient

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Lifetime, in seconds, passed to the Redis EXPIRE calls made in this module.
TTL_SECONDS = 90 * 86400  # 90 days


def _note_context_read_degraded(section: str) -> None:
    """Emit a metric when a context-critical Redis read degrades to empty.

    Client-level retry (see ``Config.redis_resilience_kwargs``) already rides
    through a transient Sentinel failover window, so this fires only when a read
    *still* failed after retries — making the (intentionally fail-open) degraded
    turn visible instead of silent.

    Args:
        section (str): Label of the context section whose read degraded; sent
            as the ``section`` tag of the ``context_read_degraded`` counter.
    """
    try:
        observability.increment(
            "context_read_degraded", tags={"section": section}
        )
    except Exception:
        # Telemetry is best-effort: a metrics failure must never break the
        # caller's (already degraded) read path, but it should leave a trace.
        logger.debug(
            "failed to emit context_read_degraded metric for section %r",
            section,
            exc_info=True,
        )

# Prompt-time injections appended to the user turn for the LLM must never be
# persisted in Redis (duplicate noise + wrong embeddings). Stripped in
# :meth:`MessageCache.log_message` before embed + write.
_CHANNEL_SEMANTIC_RECALL_BLOCK_RE = re.compile(
    r"<channel_semantic_recall\b[^>]*>.*?</channel_semantic_recall\s*>",
    re.DOTALL | re.IGNORECASE,
)

# Opening tag of a recall block that was never closed (malformed / truncated
# XML).  Matching is done on the original string with ASCII-only case folding,
# so the match offset is always a valid index into that same string.
_CHANNEL_SEMANTIC_RECALL_OPEN_RE = re.compile(
    r"<channel_semantic_recall",
    re.IGNORECASE | re.ASCII,
)


def strip_llm_injection_artifacts_for_cache(text: str) -> str:
    """Remove prompt-injection blobs that must not be stored in the message cache.

    Callers normally pass raw platform text; this guards against accidental
    forwarding of augmented LLM-request content into :meth:`log_message`.

    Two passes are applied: complete ``<channel_semantic_recall>...`` blocks
    are deleted, then everything from a dangling (unclosed) opening tag to the
    end of the string is dropped.  Trailing whitespace is stripped from any
    non-empty input, whether or not something was removed.

    Args:
        text (str): The message text to clean.

    Returns:
        str: The cleaned text.  Falsy input (``""`` / ``None``) is returned
        unchanged.
    """
    if not text:
        return text
    cleaned = _CHANNEL_SEMANTIC_RECALL_BLOCK_RE.sub("", text)
    # Malformed / truncated XML: drop from the opening tag onward.  Searching
    # the original string (instead of a lowercased copy) matters because
    # ``str.lower()`` can change the string's length for some characters,
    # which would shift the cut position.
    dangling = _CHANNEL_SEMANTIC_RECALL_OPEN_RE.search(cleaned)
    if dangling is not None:
        cleaned = cleaned[: dangling.start()]
    return cleaned.rstrip()
def _embed_to_bytes(embedding: list[float] | np.ndarray) -> bytes: """Pack an embedding vector into a little-endian FLOAT32 byte blob. RediSearch HNSW vector fields store and compare raw FLOAT32 buffers, so every embedding written to a ``msg:*`` hash or passed as the ``$query_vec`` KNN parameter must first be flattened to this binary form. Pure numpy conversion with no I/O. Called by :meth:`CachedMessage.to_redis_hash` (when serialising a message for storage) and by :meth:`MessageCache.search_messages` (when turning a query embedding into a search blob). Args: embedding (list[float] | np.ndarray): The embedding vector to encode. Returns: bytes: The vector as a contiguous FLOAT32 byte buffer. """ arr = np.array(embedding, dtype=np.float32) return arr.tobytes() def _bytes_to_embed(raw: Any) -> list[float]: """Decode a stored embedding back into a Python list of floats. The inverse of :func:`_embed_to_bytes`, hardened against the several shapes an embedding can take once it has round-tripped through Redis: a raw FLOAT32 binary buffer (the normal case), a JSON array string, or a corrupted ``b'...'`` / ``{...}`` repr that should be treated as empty. Tolerates bad input by returning an empty list rather than raising. Pure decoding with no I/O. Called by :meth:`CachedMessage.from_redis_hash` when reconstructing a message from an ``HGETALL`` result. Args: raw (Any): The stored embedding value — ``bytes`` / ``bytearray`` / ``memoryview`` for a binary blob, a ``str`` for a JSON or space-separated representation, or ``None``. Returns: list[float]: The decoded vector, or an empty list when the input is missing, malformed, or not a valid FLOAT32-length buffer. """ if raw is None: return [] # 1. 
Handle string/unicode representation directly if isinstance(raw, str): stripped = raw.strip() if stripped.startswith("b'") or stripped.startswith('b"') or stripped.startswith("{"): return [] if stripped.startswith("["): try: import json parsed = json.loads(stripped) if isinstance(parsed, list): return [float(x) for x in parsed] except Exception: pass # Fallback: space-separated values? try: return [float(x) for x in stripped.split()] except Exception: return [] # 2. Handle bytes, bytearray, or memoryview if isinstance(raw, (bytes, bytearray, memoryview)): # Check if it starts with "[" indicating a stringified JSON bytes representation stripped = raw.strip() if stripped.startswith(b"b'") or stripped.startswith(b'b"') or stripped.startswith(b"{"): return [] if stripped.startswith(b"["): try: import json parsed = json.loads(stripped.decode("utf-8", errors="ignore")) if isinstance(parsed, list): return [float(x) for x in parsed] except Exception: pass # Otherwise, process as raw binary FLOAT32 array if len(raw) < 4 or len(raw) % 4 != 0: return [] try: return np.frombuffer(raw, dtype=np.float32).tolist() except Exception: return [] return [] def _escape_tag(value: str) -> str: """Backslash-escape RediSearch TAG-query metacharacters in a value. RediSearch treats characters like ``.``, ``|``, ``{`` and ``}`` as query syntax, so any platform/channel/user identifier injected into a ``@field: {value}`` filter must be escaped first or it will be mis-parsed (or let a crafted id alter the query). Pure string transformation with no I/O. Called by :meth:`MessageCache.search_messages` and :meth:`MessageCache.get_recent_for_user` when building TAG filter clauses. Args: value (str): The raw tag value (e.g. a channel or user id) to escape. Returns: str: The value with every reserved character preceded by a backslash. """ out: list[str] = [] for ch in value: if ch in r"\.+*?[^]$(){}=!<>|:-@&~\"'/": out.append("\\") out.append(ch) return "".join(out)
def _coerce_text(value: Any, default: str = "") -> str:
    """Render a stored field as text: ``None`` becomes *default*, bytes are decoded."""
    if value is None:
        return default
    if isinstance(value, (bytes, bytearray)):
        return bytes(value).decode("utf-8", errors="replace")
    return str(value)


def _coerce_timestamp(value: Any) -> float:
    """Parse a stored timestamp, falling back to ``0.0`` when missing or unparsable."""
    if value is None:
        return 0.0
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0


@dataclass
class CachedMessage:
    """A single cached chat message and its embedding.

    Holds the speaker identity, the platform/channel the message belongs to,
    the text, a timestamp, and the embedding vector.  Build one from JSON via
    :meth:`from_json` / :meth:`from_dict` or from a Redis hash mapping via
    :meth:`from_redis_hash`; serialise with :meth:`to_dict`, :meth:`to_json`,
    or :meth:`to_redis_hash`.
    """

    user_id: str
    user_name: str
    platform: str
    channel_id: str
    text: str
    timestamp: float
    embedding: list[float] = field(default_factory=list, repr=False)
    message_key: str = ""
    message_id: str = ""
    """Platform-specific message identifier (Discord message ID, Matrix event ID, etc.)."""
    reply_to_id: str = ""
    """Platform-specific ID of the message this one replies to, if any."""
    kind: str = "user_in"
    """``user_in`` for messages from humans, ``assistant_out`` for bot replies."""
    turn_summary_id: str = ""
    """Link to the Stargazer-System-Log summary for this assistant turn."""

    # -- Serialisation -----------------------------------------------------

    def to_dict(self) -> dict[str, Any]:
        """Serialise this message to a plain JSON-friendly dict.

        Unlike :meth:`to_redis_hash`, the embedding stays a list of floats and
        ``message_key`` is omitted.

        Returns:
            dict[str, Any]: Field/value mapping with ``embedding`` as a list.
        """
        return {
            "user_id": self.user_id,
            "user_name": self.user_name,
            "platform": self.platform,
            "channel_id": self.channel_id,
            "text": self.text,
            "timestamp": self.timestamp,
            "embedding": self.embedding,
            "message_id": self.message_id,
            "reply_to_id": self.reply_to_id,
            "kind": self.kind,
            "turn_summary_id": self.turn_summary_id,
        }

    def to_json(self) -> str:
        """Serialise this message to a JSON string (``json.dumps`` of :meth:`to_dict`).

        Returns:
            str: The JSON-encoded message.
        """
        return json.dumps(self.to_dict())

    def to_redis_hash(self) -> dict[str, str | bytes | float]:
        """Produce the field mapping written to this message's Redis hash.

        The timestamp is stringified, the embedding is packed via
        :func:`_embed_to_bytes`, and ``None`` values are coerced to empty
        strings.  An empty ``message_id`` is dropped from the mapping so a
        later write cannot overwrite an id that was filled in separately.

        Returns:
            dict[str, str | bytes | float]: The hash field mapping, with the
            embedding as ``bytes`` and ``message_id`` omitted when empty.
        """
        raw = {
            "user_id": self.user_id,
            "user_name": self.user_name,
            "platform": self.platform,
            "channel_id": self.channel_id,
            "text": self.text,
            "timestamp": str(self.timestamp),
            "embedding": _embed_to_bytes(self.embedding),
            "message_id": self.message_id,
            "reply_to_id": self.reply_to_id,
            "kind": self.kind,
            "turn_summary_id": self.turn_summary_id,
        }
        sanitized = {k: (v if v is not None else "") for k, v in raw.items()}
        # Omit message_id if empty/falsy to prevent overwriting gateway-synced message_ids
        if not sanitized.get("message_id"):
            sanitized.pop("message_id", None)
        return sanitized

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> CachedMessage:
        """Reconstruct a :class:`CachedMessage` from a plain dict.

        The inverse of :meth:`to_dict`, with permissive defaults: missing or
        ``None`` text fields become empty strings, a missing/``None``/
        unparsable timestamp becomes ``0.0``, a missing embedding becomes an
        empty list, and an empty ``kind`` becomes ``user_in``.

        Args:
            data (dict[str, Any]): A dict produced by :meth:`to_dict` (or a
                close equivalent).

        Returns:
            CachedMessage: The reconstructed message.
        """
        return cls(
            user_id=_coerce_text(data.get("user_id")),
            user_name=_coerce_text(data.get("user_name")),
            platform=_coerce_text(data.get("platform")),
            channel_id=_coerce_text(data.get("channel_id")),
            text=_coerce_text(data.get("text")),
            timestamp=_coerce_timestamp(data.get("timestamp")),
            embedding=data.get("embedding") or [],
            message_id=_coerce_text(data.get("message_id")),
            reply_to_id=_coerce_text(data.get("reply_to_id")),
            kind=_coerce_text(data.get("kind")) or "user_in",
            turn_summary_id=_coerce_text(data.get("turn_summary_id")),
        )

    @classmethod
    def from_json(cls, raw: str) -> CachedMessage:
        """Reconstruct a :class:`CachedMessage` from a JSON string.

        Parses *raw* with ``json.loads`` and delegates to :meth:`from_dict`.

        Args:
            raw (str): A JSON string produced by :meth:`to_json`.

        Returns:
            CachedMessage: The reconstructed message.
        """
        return cls.from_dict(json.loads(raw))

    @classmethod
    def from_redis_hash(
        cls,
        mapping: dict[str, Any],
        key: str = "",
    ) -> CachedMessage:
        """Reconstruct a :class:`CachedMessage` from a Redis hash mapping.

        The inverse of :meth:`to_redis_hash`.  Text fields are coerced to
        ``str`` (bytes values are UTF-8 decoded, ``None`` becomes ``""``), the
        timestamp falls back to ``0.0`` when missing or unparsable, and the
        embedding is decoded via :func:`_bytes_to_embed` only when the stored
        value is ``bytes`` longer than 100 bytes; otherwise it is left empty.

        Args:
            mapping (dict[str, Any]): The field/value mapping read from the
                message hash.
            key (str): The Redis key the hash was read from, stored on the
                result as ``message_key``.

        Returns:
            CachedMessage: The reconstructed message.
        """
        emb_raw = mapping.get("embedding")
        if isinstance(emb_raw, bytes) and len(emb_raw) > 100:
            embedding = _bytes_to_embed(emb_raw)
        else:
            embedding = []
        return cls(
            user_id=_coerce_text(mapping.get("user_id")),
            user_name=_coerce_text(mapping.get("user_name")),
            platform=_coerce_text(mapping.get("platform")),
            channel_id=_coerce_text(mapping.get("channel_id")),
            text=_coerce_text(mapping.get("text")),
            timestamp=_coerce_timestamp(mapping.get("timestamp")),
            embedding=embedding,
            message_key=key,
            message_id=_coerce_text(mapping.get("message_id")),
            reply_to_id=_coerce_text(mapping.get("reply_to_id")),
            kind=_coerce_text(mapping.get("kind")) or "user_in",
            turn_summary_id=_coerce_text(mapping.get("turn_summary_id")),
        )

    # -- Representation ----------------------------------------------------

    @property
    def repr(self) -> str:
        """Render a human-readable ``[time] user: text`` line for this message.

        The timestamp is formatted in UTC and text longer than 120 characters
        is cut to 117 characters plus ``...``.

        Returns:
            str: A one-line ``[YYYY-MM-DD HH:MM:SS] user_name: preview`` string.
        """
        dt = datetime.fromtimestamp(self.timestamp, tz=timezone.utc)
        ts_str = dt.strftime("%Y-%m-%d %H:%M:%S")
        preview = self.text if len(self.text) <= 120 else self.text[:117] + "..."
        return f"[{ts_str}] {self.user_name}: {preview}"

    def __repr__(self) -> str:
        """Return a debug representation showing the embedding length, not the vector.

        Returns:
            str: A ``CachedMessage(...)`` debug string.
        """
        return (
            f"CachedMessage(user={self.user_name!r}, platform={self.platform!r}, "
            f"channel={self.channel_id!r}, ts={self.timestamp}, "
            f"embedding_len={len(self.embedding)})"
        )
[docs] class MessageCache: """Async Redis message cache with automatic embedding generation and RediSearch-backed vector search. Parameters ---------- redis_url: Redis connection URL (e.g. ``"redis://localhost:6379/0"``). openrouter_client: Shared :class:`OpenRouterClient` used to call the embeddings API. embedding_model: Model identifier for the embeddings endpoint (e.g. ``"google/gemini-embedding-001"``). """
def __init__(
    self,
    redis_url: str,
    openrouter_client: OpenRouterClient,
    embedding_model: str = "google/gemini-embedding-001",
    ssl_kwargs: dict | None = None,
    redis_sentinels: list[str] | None = None,
    redis_sentinel_master: str = "falkordb",
    resilience_kwargs: dict | None = None,
) -> None:
    """Create the two Redis clients (decoded and raw) and store collaborators.

    Args:
        redis_url (str): Redis connection URL; used only when no sentinels
            are configured.
        openrouter_client (OpenRouterClient): Client used for embeddings.
        embedding_model (str): Model identifier passed to the embeddings call.
        ssl_kwargs (dict | None): Optional SSL keyword arguments.  When
            non-empty they are copied, ``ssl=True`` is added, and the result
            is forwarded to ``from_url`` or ``Sentinel``.
        redis_sentinels (list[str] | None): Optional ``host[:port]`` sentinel
            entries; the port defaults to 26379.
        redis_sentinel_master (str): The sentinel master name.
        resilience_kwargs (dict | None): Optional extra client keyword
            arguments (see ``Config.redis_resilience_kwargs``).
    """
    tls_opts = {**ssl_kwargs, "ssl": True} if ssl_kwargs else {}
    retry_opts = dict(resilience_kwargs or {})
    client_opts = {**tls_opts, **retry_opts}

    if redis_sentinels:
        endpoints = []
        for entry in redis_sentinels:
            pieces = entry.split(":")
            port = int(pieces[1]) if len(pieces) == 2 else 26379
            endpoints.append((pieces[0], port))
        sentinel = Sentinel(
            endpoints,
            sentinel_kwargs=tls_opts,
            **client_opts,
        )
        # Two masters: one decoding replies to str, one returning raw bytes
        # (needed for binary embedding blobs).
        self._redis = sentinel.master_for(
            redis_sentinel_master,
            decode_responses=True,
            **retry_opts,
        )
        self._redis_raw = sentinel.master_for(
            redis_sentinel_master,
            decode_responses=False,
            **retry_opts,
        )
    else:
        self._redis: aioredis.Redis = aioredis.from_url(
            redis_url,
            decode_responses=True,
            **client_opts,
        )
        self._redis_raw: aioredis.Redis = aioredis.from_url(
            redis_url,
            decode_responses=False,
            **client_opts,
        )

    self._openrouter = openrouter_client
    self._embedding_model = embedding_model
    # Lazily filled by ``_load_log_lua``; ``None`` means "not loaded yet".
    self._lua_log_script: str | None = None
async def _monotonic_channel_score(
    self,
    platform: str,
    channel_id: str,
    timestamp: float,
) -> float:
    """Compute the sorted-set score for a new channel message.

    The score is the millisecond timestamp widened by a factor of 1000 plus
    a per-channel ``INCR`` sequence taken modulo 1000, so messages that share
    a millisecond still receive distinct scores (as long as the sequence does
    not wrap within that millisecond).

    Issues an ``INCR`` and an ``EXPIRE`` on the counter key returned by
    ``self._channel_seq_key`` on every call.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        channel_id (str): Channel identifier.
        timestamp (float): The message's Unix timestamp in seconds.

    Returns:
        float: ``int(timestamp * 1000) * 1000 + (seq % 1000)``.
    """
    seq_key = self._channel_seq_key(platform, channel_id)
    seq = int(await self._redis.incr(seq_key))
    await self._redis.expire(seq_key, TTL_SECONDS)
    millis = int(timestamp * 1000)
    return float(millis * 1000 + (seq % 1000))


async def _load_log_lua(self) -> str:
    """Load and memoize the atomic message-log Lua script source.

    Reads ``atomic_log_message.lua`` (relative to the process working
    directory) on first call and caches the text in
    ``self._lua_log_script``.  If the file cannot be read, a warning is
    logged once and an empty string is cached so the caller falls back to
    the non-atomic pipeline.

    Returns:
        str: The Lua script source, or an empty string when the file is
        missing or unreadable.
    """
    if self._lua_log_script is None:
        try:
            with open("atomic_log_message.lua", "r", encoding="utf-8") as handle:
                self._lua_log_script = handle.read()
        except (OSError, ValueError):
            # ValueError covers UnicodeDecodeError.  Surface the downgrade:
            # the pipeline fallback does not honour idempotency keys.
            logger.warning(
                "atomic_log_message.lua could not be read; falling back to "
                "the non-atomic pipeline",
                exc_info=True,
            )
            self._lua_log_script = ""
    return self._lua_log_script


async def _atomic_log_to_redis(
    self,
    *,
    key: str,
    zset_key: str,
    idem_key: str,
    score: float,
    mapping: dict[str, str | bytes | float],
) -> str:
    """Write a message hash and index it in the channel sorted set.

    When the Lua script from :meth:`_load_log_lua` is available it is run via
    ``EVAL`` with 3 KEYS (*key*, *zset_key*, *idem_key* or ``""``) followed by
    the TTL, the score, the field count, and the flattened field/value pairs
    (``bytes`` values are passed through unchanged, everything else is
    stringified).  The script itself is not part of this module; it is
    expected to perform the write atomically and to return the key the
    message ended up under — TODO confirm against ``atomic_log_message.lua``.

    Without the script, a pipeline performs ``HSET`` + ``EXPIRE`` on the
    hash, ``ZADD`` + ``EXPIRE`` on the zset, and ``SADD`` + ``EXPIRE`` on
    ``sg:active_channels``.  This fallback ignores *idem_key*.

    Args:
        key (str): Target Redis hash key for the message.
        zset_key (str): Per-channel sorted-set key the message is added to.
        idem_key (str): Idempotency key, or an empty string for none.
        score (float): Sorted-set score for the message.
        mapping (dict[str, str | bytes | float]): Field/value pairs to write.

    Returns:
        str: *key* on the pipeline path; otherwise the script's result,
        decoded to ``str``.
    """
    script = await self._load_log_lua()
    if not script:
        pipe = self._redis.pipeline()
        pipe.hset(key, mapping=mapping)
        pipe.expire(key, TTL_SECONDS)
        pipe.zadd(zset_key, {key: score})
        pipe.expire(zset_key, TTL_SECONDS)
        # Register active channel in the cache set
        parts = zset_key.split(":", 2)
        if len(parts) >= 3:
            pipe.sadd("sg:active_channels", f"{parts[1]}:{parts[2]}")
            pipe.expire("sg:active_channels", TTL_SECONDS)
        await pipe.execute()
        return key

    flat: list[str | bytes] = []
    for name, value in mapping.items():
        flat.append(str(name))
        # Binary values (the embedding blob) must not be stringified.
        flat.append(value if isinstance(value, bytes) else str(value))

    result = await self._redis.eval(
        script,
        3,
        key,
        zset_key,
        idem_key or "",
        str(TTL_SECONDS),
        str(score),
        str(len(mapping)),
        *flat,
    )
    if isinstance(result, bytes):
        result = result.decode("utf-8")
    return str(result)
async def log_message(
    self,
    platform: str,
    channel_id: str,
    user_id: str,
    user_name: str,
    text: str,
    timestamp: float | None = None,
    embedding: list[float] | None = None,
    defer_embedding: bool = False,
    message_id: str = "",
    reply_to_id: str = "",
    kind: str = "user_in",
    turn_summary_id: str = "",
    message_key: str = "",
) -> CachedMessage:
    """Log a message to Redis together with its embedding.

    *text* is first passed through
    :func:`strip_llm_injection_artifacts_for_cache`.  If that changes the
    text, any caller-supplied *embedding* is discarded and recomputed from
    the cleaned text.  The message is then written through
    :meth:`_atomic_log_to_redis` under a ``msg:{uuid}`` key (or the
    caller-supplied *message_key*).

    Args:
        platform (str): Platform identifier.
        channel_id (str): Channel identifier.
        user_id (str): Author id.
        user_name (str): Author display name.
        text (str): Message text.
        timestamp (float | None): Unix seconds; defaults to ``time.time()``.
        embedding (list[float] | None): Pre-computed embedding.  When given
            (and the text needed no cleaning) the embeddings call is skipped.
        defer_embedding (bool): When ``True`` and no usable *embedding* is
            available, store a zero vector of length ``VECTOR_DIM`` instead
            of calling the embeddings API.  Empty or whitespace-only text
            gets the same placeholder.
        message_id (str): Platform message id.  When set (and no
            *message_key* is given) an idempotency key
            ``msgid:{platform}:{channel_id}:{message_id}`` is passed along.
        reply_to_id (str): Platform id of the message being replied to.
        kind (str): Message kind; falsy values become ``user_in``.
        turn_summary_id (str): Optional link to a turn summary.
        message_key (str): Explicit Redis key to write to.

    Returns:
        CachedMessage: The message object that was logged.  Its
        ``message_key`` is the key returned by the storage step, and its
        ``embedding`` is the real vector or the zero-vector placeholder.
    """
    cleaned = strip_llm_injection_artifacts_for_cache(text)
    if cleaned != text:
        # The supplied embedding (if any) was computed from uncleaned text.
        embedding = None
    text = cleaned

    if timestamp is None:
        timestamp = time.time()

    if embedding is None:
        has_content = bool(text and text.strip())
        if defer_embedding or not has_content:
            embedding = [0.0] * VECTOR_DIM
        else:
            embedding = await self._openrouter.embed(
                text,
                self._embedding_model,
            )

    message = CachedMessage(
        user_id=user_id,
        user_name=user_name,
        platform=platform,
        channel_id=channel_id,
        text=text,
        timestamp=timestamp,
        embedding=embedding,
        message_id=message_id,
        reply_to_id=reply_to_id,
        kind=kind or "user_in",
        turn_summary_id=turn_summary_id,
    )

    # Idempotency: when a platform message_id is known (and the caller did
    # not pin an explicit message_key), pass a per-message-id key so
    # concurrent workers logging the same delivery can converge on one row.
    claim_by_id = bool(message_id) and not message_key
    idem_key = f"msgid:{platform}:{channel_id}:{message_id}" if claim_by_id else ""
    key = message_key or f"msg:{_uuid.uuid4()}"

    message.message_key = key
    mapping = message.to_redis_hash()
    zset_key = self._channel_zset_key(platform, channel_id)
    score = await self._monotonic_channel_score(platform, channel_id, timestamp)
    message.message_key = await self._atomic_log_to_redis(
        key=key,
        zset_key=zset_key,
        idem_key=idem_key,
        score=score,
        mapping=mapping,
    )
    return message
async def get_recent(
    self,
    platform: str,
    channel_id: str,
    count: int = 50,
) -> list[CachedMessage]:
    """Return the *count* most recent cached messages for a channel.

    Runs ``ZREVRANGE 0 count-1`` on the per-channel sorted set and hydrates
    the returned keys via ``self._fetch_hashes``.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        channel_id (str): Channel identifier.
        count (int): Maximum number of messages to return.  Values ``<= 0``
            yield an empty list without touching Redis.

    Returns:
        list[CachedMessage]: Up to *count* messages in the order produced by
        ``self._fetch_hashes`` for the newest-first key list; empty when the
        channel has no cached messages.
    """
    # Guard: with count == 0 the stop index would be -1, which ZREVRANGE
    # interprets as "through the last element" and returns the whole channel.
    if count <= 0:
        return []
    zset_key = self._channel_zset_key(platform, channel_id)
    msg_keys: list[str] = await self._redis.zrevrange(zset_key, 0, count - 1)
    if not msg_keys:
        return []
    return await self._fetch_hashes(zset_key, msg_keys)
async def get_by_timerange(
    self,
    platform: str,
    channel_id: str,
    start: float,
    end: float,
) -> list[CachedMessage]:
    """Return cached messages whose sorted-set score falls within a range.

    Bounds below ``1e12`` are treated as plain Unix seconds and multiplied by
    1_000_000 to match the widened zset scores; larger bounds are used as-is.
    Runs ``ZRANGEBYSCORE`` on the per-channel sorted set and hydrates the
    keys via ``self._fetch_hashes``.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        channel_id (str): Channel identifier.
        start (float): Lower bound, Unix seconds or a pre-scaled score.
        end (float): Upper bound, Unix seconds or a pre-scaled score.

    Returns:
        list[CachedMessage]: The hydrated matches; empty when none fall in
        the range.
    """

    def _as_score(bound: float) -> float:
        # Scale raw Unix timestamps (seconds) to zset scores if unscaled.
        if bound is not None and bound < 1e12:
            return bound * 1_000_000
        return bound

    lower, upper = _as_score(start), _as_score(end)
    zset_key = self._channel_zset_key(platform, channel_id)
    msg_keys: list[str] = await self._redis.zrangebyscore(
        zset_key,
        min=lower,
        max=upper,
    )
    if msg_keys:
        return await self._fetch_hashes(zset_key, msg_keys)
    return []
async def get_messages_after(
    self,
    platform: str,
    channel_id: str,
    after_ts_exclusive: float | None,
    *,
    zset_batch: int = 5000,
) -> list[CachedMessage]:
    """Return channel messages with a score strictly above a floor.

    Pages through ``ZRANGEBYSCORE`` on the per-channel sorted set in
    *zset_batch*-sized windows, using the exclusive lower-bound syntax
    ``(score``, then hydrates the collected keys via ``self._fetch_hashes``
    in chunks of 400.  A floor below ``1e12`` is treated as Unix seconds and
    multiplied by 1_000_000; ``None`` means "from the beginning".

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        channel_id (str): Channel identifier.
        after_ts_exclusive (float | None): Exclusive lower bound in Unix
            seconds (or a pre-scaled score); ``None`` returns all messages.
        zset_batch (int): Page size for each ``ZRANGEBYSCORE`` call.

    Returns:
        list[CachedMessage]: The hydrated matches; empty when none qualify.
    """
    floor = after_ts_exclusive
    if floor is not None and floor < 1e12:
        # Scale a raw Unix timestamp (seconds) to a zset score.
        floor = floor * 1_000_000

    zset_key = self._channel_zset_key(platform, channel_id)
    lower_bound = f"({floor}" if floor is not None else "-inf"

    collected: list[str] = []
    cursor = 0
    while True:
        page: list[Any] = await self._redis.zrangebyscore(
            zset_key,
            lower_bound,
            "+inf",
            start=cursor,
            num=zset_batch,
        )
        if not page:
            break
        for raw_key in page:
            collected.append(
                raw_key.decode() if isinstance(raw_key, bytes) else str(raw_key)
            )
        cursor += len(page)
        if len(page) < zset_batch:
            # A short page means the range is exhausted.
            break

    if not collected:
        return []

    hydrate_chunk = 400
    messages: list[CachedMessage] = []
    for lo in range(0, len(collected), hydrate_chunk):
        chunk = collected[lo : lo + hydrate_chunk]
        messages.extend(await self._fetch_hashes(zset_key, chunk))
    return messages
async def update_text_by_message_id(
    self,
    platform: str,
    channel_id: str,
    message_id: str,
    new_text: str,
) -> str | None:
    """Overwrite a cached message's text, located by platform message id.

    Resolves the message's Redis key via :meth:`find_key_by_message_id` and,
    on a hit, ``HSET``s the ``text`` field of that hash.  No other field
    (including the embedding) is written.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        channel_id (str): Channel identifier.
        message_id (str): Platform-specific id of the message to update.
        new_text (str): The replacement text to store.

    Returns:
        str | None: The Redis key that was updated, or ``None`` if no
        matching message was found.
    """
    target_key = await self.find_key_by_message_id(
        platform,
        channel_id,
        message_id,
    )
    if target_key is not None:
        await self._redis.hset(target_key, "text", new_text)
    return target_key
async def find_key_by_message_id(
    self,
    platform: str,
    channel_id: str,
    message_id: str,
) -> str | None:
    """Resolve the Redis key of a cached message by its platform message id.

    Fast path: ``GET`` the ``msgid:{platform}:{channel_id}:{message_id}``
    index key and return its value if the key it points at still ``EXISTS``.
    Any exception on this path is logged at debug level and ignored.

    Slow path: scan the per-channel sorted set newest-first with
    ``ZREVRANGE`` in windows of 500 keys (offsets 0 to 4500, so at most 5000
    keys), pipelining ``HGET key message_id`` for each window and returning
    the first key whose stored id equals *message_id*.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        channel_id (str): Channel identifier.
        message_id (str): Platform-specific message id to resolve.

    Returns:
        str | None: The matching Redis key, or ``None`` if not found.
    """
    index_key = f"msgid:{platform}:{channel_id}:{message_id}"
    try:
        pointer = await self._redis.get(index_key)
        if pointer:
            if isinstance(pointer, bytes):
                pointer = pointer.decode("utf-8")
            if await self._redis.exists(pointer):
                return pointer
    except Exception:
        logger.debug(
            "msgid index lookup failed for %s",
            message_id,
            exc_info=True,
        )

    zset_key = self._channel_zset_key(platform, channel_id)
    window = 500
    for offset in range(0, 5000, window):
        candidates: list[str] = await self._redis.zrevrange(
            zset_key,
            offset,
            offset + window - 1,
        )
        if not candidates:
            break

        pipe = self._redis.pipeline()
        for candidate in candidates:
            pipe.hget(candidate, "message_id")
        stored_ids = await pipe.execute()

        for candidate, stored_id in zip(candidates, stored_ids):
            if stored_id is not None and str(stored_id) == message_id:
                return candidate

        if len(candidates) < window:
            # Short window: the sorted set has been fully scanned.
            break

    return None
async def find_keys_by_message_ids(
    self,
    platform: str,
    channel_id: str,
    message_ids: list[str],
) -> dict[str, str]:
    """Batch version of :meth:`find_key_by_message_id`.

    Returns a mapping of ``{message_id: redis_key}`` for every *message_id*
    that was found in Redis.  Missing IDs are omitted.

    The lookup is genuinely batched:

    1. One pipelined ``GET`` over the ``msgid:{platform}:{channel_id}:{id}``
       idempotency pointers, followed by one pipelined ``EXISTS`` to drop
       pointers whose hash is gone.  A failure in this phase is logged at
       debug level and every id falls through to the scan.
    2. A single bounded ``ZREVRANGE`` walk of the per-channel sorted set
       (pages of 500, at most 5000 keys) resolves all ids that are still
       missing, instead of repeating that scan once per id.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        channel_id (str): Channel identifier.
        message_ids (list[str]): Platform-specific message ids to resolve.
            Duplicates are tolerated.

    Returns:
        dict[str, str]: ``{message_id: redis_key}`` for the ids that were
        found, in the order the ids first appear in *message_ids*.
    """
    if not message_ids:
        return {}

    # De-duplicate while keeping first-seen order.
    pending = list(dict.fromkeys(message_ids))
    found: dict[str, str] = {}

    # Phase 1: idempotency-index fast path, two round trips for all ids.
    try:
        pipe = self._redis.pipeline()
        for mid in pending:
            pipe.get(f"msgid:{platform}:{channel_id}:{mid}")
        pointers = await pipe.execute()

        candidates: list[tuple[str, str]] = []
        for mid, pointer in zip(pending, pointers):
            if not pointer:
                continue
            if isinstance(pointer, bytes):
                pointer = pointer.decode("utf-8")
            candidates.append((mid, pointer))

        if candidates:
            pipe = self._redis.pipeline()
            for _, pointer in candidates:
                pipe.exists(pointer)
            alive_flags = await pipe.execute()
            for (mid, pointer), alive in zip(candidates, alive_flags):
                if alive:
                    found[mid] = pointer
    except Exception:
        logger.debug(
            "msgid index batch lookup failed for %d ids",
            len(pending),
            exc_info=True,
        )

    missing = {mid for mid in pending if mid not in found}

    # Phase 2: one shared newest-first scan for whatever is still missing.
    if missing:
        zset_key = self._channel_zset_key(platform, channel_id)
        window = 500
        for offset in range(0, 5000, window):
            msg_keys: list[str] = await self._redis.zrevrange(
                zset_key, offset, offset + window - 1,
            )
            if not msg_keys:
                break
            pipe = self._redis.pipeline()
            for key in msg_keys:
                pipe.hget(key, "message_id")
            stored_ids = await pipe.execute()
            for key, stored_id in zip(msg_keys, stored_ids):
                if stored_id is None:
                    continue
                sid = str(stored_id)
                if sid in missing:
                    # Newest match wins, as in the single-id lookup.
                    found[sid] = key
                    missing.discard(sid)
            if not missing or len(msg_keys) < window:
                break

    # Re-emit in caller order so the result does not depend on which phase
    # resolved each id.
    return {mid: found[mid] for mid in pending if mid in found}
async def has_real_embedding(self, redis_key: str) -> bool:
    """Tell whether a message hash carries a real (non-placeholder) embedding.

    Messages logged with ``defer_embedding=True`` (or with empty text) are
    stored with a zero-vector placeholder to be backfilled later; this
    check lets the backfill workers find the ``msg:*`` rows that still need
    a vector.

    The ``embedding`` field is read with ``HGET`` over ``self._redis_raw``
    (the non-``decode_responses`` connection) so the binary FLOAT32 bytes
    are not mangled by UTF-8 decoding.  Read-only.  See
    :meth:`has_real_embedding_many` for the pipelined batch form.

    Args:
        redis_key (str): The ``msg:{uuid}`` key to inspect.

    Returns:
        bool: ``True`` when the blob is at least ``VECTOR_DIM * 4`` bytes
        long and not all zero bytes; ``False`` when it is missing, too
        short, or an all-zero placeholder.
    """
    blob: bytes | None = await self._redis_raw.hget(redis_key, "embedding")
    expected_len = VECTOR_DIM * 4
    if blob is not None and len(blob) >= expected_len:
        # ``bytes(n)`` is n zero bytes, i.e. the placeholder of this length.
        return blob != bytes(len(blob))
    return False
async def has_real_embedding_many(
    self,
    redis_keys: list[str],
) -> list[bool]:
    """Batch-check which message hashes hold a real embedding.

    The pipelined form of :meth:`has_real_embedding`, used by the embedding
    backfill paths to filter a whole batch of keys in one Redis round trip.
    Pipelines ``HGET ... embedding`` for every key over the raw,
    non-``decode_responses`` connection (``self._redis_raw``) so binary
    FLOAT32 data survives intact.  Read-only.

    Each blob is compared against a zero blob of *its own* length, exactly
    as the single-key form does.  Comparing against a fixed
    ``VECTOR_DIM * 4`` zero blob would misreport an over-long all-zero blob
    as a real embedding.

    Called from ``background_tasks`` and
    ``message_processor.history_backfill``.

    Args:
        redis_keys (list[str]): The ``msg:{uuid}`` keys to inspect.

    Returns:
        list[bool]: One flag per input key (positionally aligned), ``True``
        where the embedding is real and ``False`` where it is missing,
        short, or a zero placeholder.
    """
    if not redis_keys:
        return []

    pipe = self._redis_raw.pipeline()
    for key in redis_keys:
        pipe.hget(key, "embedding")
    blobs = await pipe.execute()

    min_len = VECTOR_DIM * 4
    # Common case: blobs are exactly min_len, so reuse one zero blob.
    zero_blob = bytes(min_len)

    flags: list[bool] = []
    for blob in blobs:
        if blob is None or len(blob) < min_len:
            flags.append(False)
            continue
        zeros = zero_blob if len(blob) == min_len else bytes(len(blob))
        flags.append(blob != zeros)
    return flags
async def mark_deleted_by_message_id(
    self,
    platform: str,
    channel_id: str,
    message_id: str,
    deleted_at_iso: str,
) -> bool:
    """Tombstone a cached message by prepending a deletion marker to its text.

    Serves the platform delete path: the original text is kept, but is
    prefixed with ``[deleted at <iso>] `` so cached or recalled history
    shows that the message was removed.

    Only the 200 most recent keys of the per-channel sorted set are
    inspected (``ZREVRANGE 0 199``), with ``HGET ... message_id`` pipelined
    to find the match.  The text is rewritten only when it is non-empty and
    does not already contain this exact marker.

    Called by ``message_processor.processor`` on delete events.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        channel_id (str): Channel identifier.
        message_id (str): Platform-specific id of the deleted message.
        deleted_at_iso (str): ISO-8601 deletion timestamp embedded in the
            marker.

    Returns:
        bool: ``True`` if the message was found in the recent window
        (whether or not the text needed rewriting); ``False`` otherwise.
    """
    zset_key = self._channel_zset_key(platform, channel_id)
    recent_keys: list[str] = await self._redis.zrevrange(zset_key, 0, 199)
    if not recent_keys:
        return False

    batch = self._redis.pipeline()
    for member in recent_keys:
        batch.hget(member, "message_id")
    stored_ids = await batch.execute()

    # First (newest) key whose stored message_id matches.
    match: str | None = next(
        (
            member
            for member, stored in zip(recent_keys, stored_ids)
            if stored is not None and str(stored) == message_id
        ),
        None,
    )
    if match is None:
        return False

    existing = await self._redis.hget(match, "text")
    marker = f"[deleted at {deleted_at_iso}] "
    if not existing or marker in str(existing):
        # Nothing to tag, or already tagged with this marker.
        return True
    await self._redis.hset(match, "text", marker + str(existing))
    return True
async def search_messages(
    self,
    query: str,
    limit: int = 10,
    channel_id: str | None = None,
    platform: str | None = None,
    user_id: str | None = None,
    *,
    channel_ids: list[str] | None = None,
    min_timestamp: float | None = None,
    query_embedding: list[float] | np.ndarray | None = None,
) -> list[dict[str, Any]]:
    """Semantic search across cached messages using RediSearch KNN.

    Generates an embedding for *query* unless *query_embedding* is
    provided, then runs vector similarity search filtered by optional
    *channel_id*, *platform*, and/or *user_id*.

    When *channel_ids* is provided, the filter becomes an OR over the
    listed channel IDs (RediSearch ``@channel_id:{c1|c2|c3}`` syntax) and
    *channel_id* must be ``None``; if both are given a warning is logged
    and *channel_ids* is ignored.  Each id is stripped, de-duplicated, and
    escaped individually.

    When *min_timestamp* is set (Unix seconds), only messages with
    ``timestamp >= min_timestamp`` are considered.  In that case
    *platform* must be set, plus either *channel_id* or *channel_ids*
    (strict per-channel(s) recall).

    Args:
        query (str): Free text to embed; ignored when *query_embedding*
            is given.
        limit (int): KNN ``K`` and page size.
        channel_id (str | None): Restrict to a single channel.
        platform (str | None): Restrict to a platform.
        user_id (str | None): Restrict to a single author.
        channel_ids (list[str] | None): Restrict to any of these channels.
        min_timestamp (float | None): Lower timestamp bound (inclusive).
        query_embedding (list[float] | np.ndarray | None): Precomputed
            query vector; must flatten to ``VECTOR_DIM`` elements.

    Returns:
        list[dict[str, Any]]: One dict per hit with ``redis_key``,
        ``text``, ``user_name``, ``user_id``, ``channel_id``,
        ``platform``, ``timestamp`` (float) and ``similarity`` (rounded to
        4 places).  An empty list is returned when the query is blank and
        no embedding was supplied, when the supplied embedding has the
        wrong size, or when *min_timestamp* is set without the required
        platform/channel scope.
    """
    # channel_id wins over channel_ids when both are supplied.
    if channel_id is not None and channel_ids:
        logger.warning(
            "search_messages: channel_id and channel_ids are mutually "
            "exclusive; ignoring channel_ids",
        )
        channel_ids = None
    if channel_ids:
        # Normalise: stringify, strip, drop blanks, de-duplicate in order.
        _seen: set[str] = set()
        _normalized: list[str] = []
        for _cid in channel_ids:
            _c = str(_cid).strip()
            if not _c or _c in _seen:
                continue
            _seen.add(_c)
            _normalized.append(_c)
        channel_ids = _normalized or None
    if query_embedding is None:
        # No vector supplied: a blank query has nothing to embed.
        if not query or not query.strip():
            return []
        embedding = await self._openrouter.embed(
            query, self._embedding_model,
        )
        query_blob = _embed_to_bytes(embedding)
    else:
        # Flatten to 1-D float32 and reject vectors of the wrong size.
        arr = np.asarray(query_embedding, dtype=np.float32).reshape(-1)
        if arr.size != VECTOR_DIM:
            logger.warning(
                "search_messages: query_embedding dim %d != %d",
                arr.size,
                VECTOR_DIM,
            )
            return []
        query_blob = _embed_to_bytes(arr)
    if min_timestamp is not None:
        # Time-bounded recall is only allowed within an explicit
        # platform + channel scope.
        if not platform or not (channel_id or channel_ids):
            logger.warning(
                "search_messages: min_timestamp requires platform and "
                "(channel_id or channel_ids)",
            )
            return []
    # "*" means "no pre-filter"; otherwise the clauses are space-joined.
    pre_filter = "*"
    parts: list[str] = []
    if platform:
        parts.append(f"@platform:{{{_escape_tag(platform)}}}")
    if channel_id:
        parts.append(f"@channel_id:{{{_escape_tag(channel_id)}}}")
    elif channel_ids:
        # TAG OR-list: each id escaped on its own, then joined with "|".
        _joined = "|".join(_escape_tag(c) for c in channel_ids)
        parts.append(f"@channel_id:{{{_joined}}}")
    if user_id:
        parts.append(f"@user_id:{{{_escape_tag(user_id)}}}")
    if min_timestamp is not None:
        parts.append(f"@timestamp:[{min_timestamp} +inf]")
    if parts:
        pre_filter = " ".join(parts)
    knn = (
        f"({pre_filter})=>"
        f"[KNN {limit} @embedding $query_vec "
        f"EF_RUNTIME {DEFAULT_KNN_EF_RUNTIME} AS score]"
    )
    q = (
        Query(knn)
        .sort_by("score")
        .return_fields(
            "user_id",
            "user_name",
            "platform",
            "channel_id",
            "text",
            "timestamp",
            "score",
        )
        .paging(0, limit)
        .dialect(2)
    )
    result = await self._redis.ft(MSG_INDEX_NAME).search(
        q, query_params={"query_vec": query_blob},
    )
    out: list[dict[str, Any]] = []
    for doc in result.docs:
        # A missing score is treated as distance 1.0, i.e. similarity 0.0.
        # NOTE(review): "similarity = 1 - score" assumes the index uses a
        # cosine distance metric -- confirm against the index definition.
        score_raw = getattr(doc, "score", None)
        cosine_dist = float(score_raw) if score_raw is not None else 1.0
        similarity = 1.0 - cosine_dist
        # The document id is the Redis key of the message hash.
        _id = getattr(doc, "id", None)
        redis_key = (
            _id.decode("utf-8", errors="replace")
            if isinstance(_id, bytes)
            else str(_id or "")
        )
        out.append(
            {
                "redis_key": redis_key,
                "text": self._doc_str(doc, "text"),
                "user_name": self._doc_str(doc, "user_name"),
                "user_id": self._doc_str(doc, "user_id"),
                "channel_id": self._doc_str(doc, "channel_id"),
                "platform": self._doc_str(doc, "platform"),
                "timestamp": float(self._doc_str(doc, "timestamp") or 0),
                "similarity": round(similarity, 4),
            }
        )
    return out
async def get_messages_around_key(
    self,
    platform: str,
    channel_id: str,
    redis_key: str,
    before: int = 5,
    after: int = 5,
) -> list[CachedMessage]:
    """Fetch the conversation window surrounding one cached message.

    Used for semantic-recall context expansion: a vector hit names a single
    message, and this returns its neighbours so the recalled snippet reads
    coherently.

    The key's ``ZRANK`` in the per-channel sorted set gives its position;
    ``ZRANGE`` over ``[rank - before, rank + after]`` (start clamped at 0)
    yields the window, which is hydrated through :meth:`_fetch_hashes`.
    If the key is not in the sorted set but its hash still ``EXISTS``, only
    that one message is returned.

    Called by ``message_processor.channel_semantic_recall``.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        channel_id (str): Channel identifier.
        redis_key (str): The ``msg:{uuid}`` key to centre the window on.
        before (int): Number of preceding messages to include.
        after (int): Number of following messages to include.

    Returns:
        list[CachedMessage]: The hydrated window, just the single message,
        or an empty list when *redis_key* is empty or no longer exists.
    """
    if not redis_key:
        return []

    zset_key = self._channel_zset_key(platform, channel_id)
    position = await self._redis.zrank(zset_key, redis_key)
    if position is None:
        # Not indexed for this channel; fall back to the lone hash if any.
        if await self._redis.exists(redis_key):
            return await self._fetch_hashes(zset_key, [redis_key])
        return []

    centre = int(position)
    first = max(0, centre - before)
    last = centre + after
    raw_members = await self._redis.zrange(zset_key, first, last)

    window_keys: list[str] = []
    for member in raw_members or []:
        if isinstance(member, bytes):
            window_keys.append(member.decode())
        else:
            window_keys.append(str(member))

    if window_keys:
        return await self._fetch_hashes(zset_key, window_keys)
    return []
async def get_recent_for_user(
    self,
    platform: str,
    user_id: str,
    limit: int = 20,
) -> list[CachedMessage]:
    """Fetch a user's newest cached messages regardless of channel.

    Unlike the channel-scoped reads, no channel needs to be known: an
    ``FT.SEARCH`` is run on the message index with TAG filters on platform
    and user id (both escaped through :func:`_escape_tag`), sorted by
    ``timestamp`` descending.  Each hit is converted with
    :meth:`_doc_to_cached_message`; the embedding field is not requested.

    Backs :meth:`get_recent_speaker_channels` and several tools
    (``tools.xray_tool``, ``tools.dm_history``, ``tools.chat_analytics``,
    ``tools.gravimetric_telescope``).

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        user_id (str): The speaker's platform user id.
        limit (int): Maximum number of messages to return.

    Returns:
        list[CachedMessage]: The user's messages, newest first.
    """
    wanted_fields = (
        "user_id",
        "user_name",
        "platform",
        "channel_id",
        "text",
        "timestamp",
        "message_id",
        "reply_to_id",
        "kind",
    )
    platform_clause = f"@platform:{{{_escape_tag(platform)}}}"
    user_clause = f"@user_id:{{{_escape_tag(user_id)}}}"

    search = Query(f"{platform_clause} {user_clause}")
    search = search.sort_by("timestamp", asc=False)
    search = search.return_fields(*wanted_fields)
    search = search.paging(0, limit)
    search = search.dialect(2)

    hits = await self._redis.ft(MSG_INDEX_NAME).search(search)
    messages: list[CachedMessage] = []
    for doc in hits.docs:
        messages.append(self._doc_to_cached_message(doc))
    return messages
async def get_recent_speaker_channels(
    self,
    platform: str,
    user_id: str,
    limit: int = 10,
    lookback: int = 300,
) -> list[str]:
    """Return the speaker's most-recently-used channel IDs (MRU order).

    Reads the last *lookback* messages sent by *user_id* on *platform*
    (newest first, via :meth:`get_recent_for_user`) and collects up to
    *limit* distinct channel IDs in the order they are first encountered.

    Privacy-by-construction: only channels in which *user_id* has actually
    sent a message can appear, i.e. channels the speaker demonstrably had
    access to at write time.  Other users' DMs and channels the speaker
    only lurked in cannot leak through this function.

    A failure of the underlying search is logged at debug level and yields
    an empty list.

    Args:
        platform (str): Platform identifier.
        user_id (str): The speaker's platform user id.
        limit (int): Maximum number of channel IDs to return.
        lookback (int): How many recent messages to inspect (at least 1).

    Returns:
        list[str]: Distinct channel IDs, most recent first; empty when
        *platform* or *user_id* is empty or *limit* is not positive.
    """
    if not platform or not user_id or limit <= 0:
        return []

    try:
        history = await self.get_recent_for_user(
            platform=platform,
            user_id=user_id,
            limit=max(1, int(lookback)),
        )
    except Exception:
        logger.debug(
            "get_recent_speaker_channels: get_recent_for_user failed for "
            "%s:%s",
            platform,
            user_id,
            exc_info=True,
        )
        return []

    # An insertion-ordered dict doubles as an ordered set of channel ids.
    channels: dict[str, None] = {}
    for message in history:
        cid = (message.channel_id or "").strip()
        if cid and cid not in channels:
            channels[cid] = None
            if len(channels) >= limit:
                break
    return list(channels)
# ------------------------------------------------------------------
# Channel metadata cache
#
# Lightweight ``channel_meta:{platform}:{channel_id}`` HASH carrying
# human-readable channel/guild names so cross-channel recall windows
# can be tagged "this is FROM #design in MyGuild, not the current
# room". Population is fire-and-forget at message-ingest time; reads
# are batched at recall time. TTL matches the message cache (90d).
# ------------------------------------------------------------------
@staticmethod
def _channel_meta_key(platform: str, channel_id: str) -> str:
    """Format the Redis hash key of a channel's metadata record.

    Shared by :meth:`record_channel_metadata` (writes) and
    :meth:`get_channel_metadata_many` (batched reads) so both address the
    same ``channel_meta:{platform}:{channel_id}`` hash.  Pure string
    formatting; Redis is not touched.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        channel_id (str): Channel identifier.

    Returns:
        str: The ``channel_meta:{platform}:{channel_id}`` key.
    """
    return "channel_meta:{}:{}".format(platform, channel_id)
async def record_channel_metadata(
    self,
    platform: str,
    channel_id: str,
    *,
    channel_name: str = "",
    guild_id: str = "",
    guild_name: str = "",
    is_dm: bool = False,
) -> None:
    """Upsert human-readable metadata for a channel.

    Every field is written even when empty, so that a later read can
    distinguish "channel was never seen" (``HGETALL`` returns an empty
    mapping) from "channel was seen but has no guild" (DM-style:
    ``guild_id`` / ``guild_name`` explicitly empty).

    ``is_dm`` is stored as ``"1"`` / ``"0"``; the cross-channel recall path
    consults it to avoid ever leaking 1-on-1 DM (Discord) or 2-member room
    (Matrix) context into another channel's prompt.

    The write is best-effort: the hash is set and its TTL refreshed in one
    pipeline, and any failure is logged at debug level and swallowed.
    Nothing is written when *platform* or *channel_id* is empty.

    Args:
        platform (str): Platform identifier.
        channel_id (str): Channel identifier.
        channel_name (str): Human-readable channel name.
        guild_id (str): Guild/server identifier, if any.
        guild_name (str): Human-readable guild/server name, if any.
        is_dm (bool): Whether the channel is a direct-message channel.
    """
    if not (platform and channel_id):
        return

    meta_key = self._channel_meta_key(platform, channel_id)
    fields = dict(
        channel_name=channel_name or "",
        guild_id=guild_id or "",
        guild_name=guild_name or "",
        is_dm="1" if is_dm else "0",
        updated_at=str(time.time()),
    )
    try:
        batch = self._redis.pipeline()
        batch.hset(meta_key, mapping=fields)
        batch.expire(meta_key, TTL_SECONDS)
        await batch.execute()
    except Exception:
        logger.debug(
            "record_channel_metadata: write failed for %s:%s",
            platform,
            channel_id,
            exc_info=True,
        )
async def get_channel_metadata_many(
    self,
    platform: str,
    channel_ids: list[str],
) -> dict[str, dict[str, Any]]:
    """Batch-fetch channel metadata. Missing channels are omitted.

    Returns ``{channel_id: {"channel_name": str, "guild_id": str,
    "guild_name": str, "is_dm": bool, "updated_at": str}}``.

    Channel ids are stringified, stripped, and de-duplicated (first
    occurrence wins) before one pipelined ``HGETALL`` per id is issued.
    ``is_dm`` is decoded back to a Python ``bool``; rows written before the
    ``is_dm`` field existed default to ``False`` (legacy rows are unlikely
    to be DMs because there is no pre-existing deploy that would have
    populated this cache).

    A pipeline failure is logged at debug level and yields an empty dict.

    Args:
        platform (str): Platform identifier.
        channel_ids (list[str]): Channel identifiers to look up.

    Returns:
        dict[str, dict[str, Any]]: Metadata keyed by channel id, containing
        only the channels that have a stored record.
    """
    if not platform or not channel_ids:
        return {}

    stripped = (str(raw_id).strip() for raw_id in channel_ids)
    unique = list(dict.fromkeys(cid for cid in stripped if cid))
    if not unique:
        return {}

    try:
        batch = self._redis.pipeline()
        for cid in unique:
            batch.hgetall(self._channel_meta_key(platform, cid))
        rows = await batch.execute()
    except Exception:
        logger.debug(
            "get_channel_metadata_many: pipeline failed for %s",
            platform,
            exc_info=True,
        )
        return {}

    def _text(row: dict[str, Any], name: str, default: str = "") -> str:
        # Missing or falsy values collapse to the default, then stringify.
        return str(row.get(name, default) or default)

    return {
        cid: {
            "channel_name": _text(row, "channel_name"),
            "guild_id": _text(row, "guild_id"),
            "guild_name": _text(row, "guild_name"),
            "is_dm": _text(row, "is_dm", "0") == "1",
            "updated_at": _text(row, "updated_at"),
        }
        for cid, row in zip(unique, rows)
        if row
    }
# ------------------------------------------------------------------ # Thought-summary storage # ------------------------------------------------------------------
async def log_thought_summary(
    self,
    channel_id: str,
    thought_text: str,
    timestamp: float | None = None,
) -> None:
    """Append a thought summary to the channel's Redis sorted set.

    The entry is a JSON object with ``content`` and ``timestamp``, scored
    by the timestamp in ``thought_summaries:{channel_id}``; the set's TTL
    is refreshed on every write.  If the newest stored entry already has
    the same content the write is skipped.  A failure of that duplicate
    check is logged at debug level and the entry is stored anyway.

    Parameters
    ----------
    channel_id:
        Platform-agnostic channel identifier.
    thought_text:
        The raw text extracted from ``<thought>`` tags.  Empty or
        whitespace-only text is ignored.
    timestamp:
        Unix timestamp; defaults to ``time.time()``.
    """
    if not (thought_text and thought_text.strip()):
        return

    when = time.time() if timestamp is None else timestamp
    zset_key = f"thought_summaries:{channel_id}"

    # Compare only against the single newest entry.
    try:
        newest = await self._redis.zrevrange(zset_key, 0, 0)
        if newest:
            previous = json.loads(newest[0])
            if previous.get("content") == thought_text:
                logger.debug(
                    "Skipping duplicate thought summary for channel %s",
                    channel_id,
                )
                return
    except Exception:
        logger.debug(
            "Thought summary deduplication check failed, storing anyway",
            exc_info=True,
        )

    payload = json.dumps({"content": thought_text, "timestamp": when})
    batch = self._redis.pipeline()
    batch.zadd(zset_key, {payload: when})
    batch.expire(zset_key, TTL_SECONDS)
    await batch.execute()
    logger.debug(
        "Stored thought summary in Redis for channel %s", channel_id,
    )
async def get_recent_thought_summaries(
    self,
    channel_id: str,
    limit: int = 30,
) -> list[dict[str, Any]]:
    """Retrieve unique thought summaries for a channel (newest first).

    Reads the newest entries of ``thought_summaries:{channel_id}`` with
    ``ZREVRANGE`` (over-fetching about three times *limit* so duplicates
    can be dropped), parses each JSON entry, and keeps the first occurrence
    of every distinct ``content``.  Entries that are not valid JSON, or
    that do not decode to an object, are skipped.

    On a Redis read error a degraded-read metric is recorded and an empty
    list is returned.

    Parameters
    ----------
    channel_id:
        Platform-agnostic channel identifier.
    limit:
        Maximum number of unique summaries to return.  A value of zero or
        less yields an empty list.

    Returns
    -------
    list[dict]
        Each dict has ``content`` (str) and ``timestamp`` (float).
    """
    # Without this guard ZREVRANGE 0..0 still returns one entry, which the
    # loop below would append before noticing the limit was reached.
    if limit <= 0:
        return []

    zset_key = f"thought_summaries:{channel_id}"
    try:
        raw_entries: list[str] = await self._redis.zrevrange(
            zset_key, 0, limit * 3,
        )
    except Exception:
        logger.debug(
            "Failed to retrieve thought summaries for channel %s",
            channel_id,
            exc_info=True,
        )
        _note_context_read_degraded("thought_summaries")
        return []

    summaries: list[dict[str, Any]] = []
    seen_contents: set[str] = set()
    for entry_json in raw_entries:
        try:
            data = json.loads(entry_json)
        except json.JSONDecodeError:
            continue
        # Valid JSON that is not an object (e.g. a bare string or list) has
        # no ``content`` field to read; treat it like a corrupt entry.
        if not isinstance(data, dict):
            continue
        content = data.get("content", "")
        if content in seen_contents:
            continue
        seen_contents.add(content)
        summaries.append(data)
        if len(summaries) >= limit:
            break
    return summaries
async def backfill_channel_indexes(self, batch_size: int = 500) -> int:
    """Rebuild the per-channel sorted-set indexes from stored message hashes.

    A maintenance/migration routine for when the ``channel_msgs:*`` indexes
    are missing or incomplete but the underlying ``msg:*`` hashes still
    exist.  ``SCAN`` walks all ``msg:*`` keys in pages of roughly
    *batch_size*; for each page the routing fields (``platform``,
    ``channel_id``, ``timestamp``) are read with pipelined ``HMGET``, and
    every usable row is ``ZADD``-ed into its
    ``channel_msgs:{platform}:{channel_id}`` sorted set, the set's TTL is
    refreshed, and the channel is registered in ``sg:active_channels``.

    Rows with a missing routing field, or with a timestamp that cannot be
    parsed as a finite-or-infinite float (garbage or NaN), are skipped so
    a single corrupt hash cannot abort the whole rebuild.

    Re-running is safe: ``ZADD`` simply re-scores existing members.

    Args:
        batch_size (int): ``COUNT`` hint passed to ``SCAN`` per page.

    Returns:
        int: The number of messages (re)indexed.
    """
    cursor: int | str = 0
    total = 0
    while True:
        cursor, keys = await self._redis.scan(
            cursor, match="msg:*", count=batch_size,
        )
        if keys:
            read_pipe = self._redis.pipeline()
            for key in keys:
                read_pipe.hmget(key, "platform", "channel_id", "timestamp")
            rows = await read_pipe.execute()

            write_pipe = self._redis.pipeline()
            added = 0
            for key, (plat, chan, ts) in zip(keys, rows):
                if not plat or not chan or not ts:
                    continue
                try:
                    score = float(ts)
                except (TypeError, ValueError):
                    logger.debug(
                        "backfill_channel_indexes: skipping %s, bad "
                        "timestamp %r",
                        key,
                        ts,
                    )
                    continue
                if score != score:
                    # NaN is not a valid sorted-set score; queuing it would
                    # make the whole pipeline fail.
                    continue
                zset_key = self._channel_zset_key(plat, chan)
                write_pipe.zadd(zset_key, {key: score})
                write_pipe.expire(zset_key, TTL_SECONDS)
                write_pipe.sadd("sg:active_channels", f"{plat}:{chan}")
                added += 1
            if added:
                # One TTL refresh per page is enough for the shared set.
                write_pipe.expire("sg:active_channels", TTL_SECONDS)
                await write_pipe.execute()
                total += added
        # Tolerate a cursor that comes back as a string ("0").
        if int(cursor) == 0:
            break
    logger.info("backfill_channel_indexes: indexed %d messages", total)
    return total
# ------------------------------------------------------------------ # Tool-call logging (Stargazer-System-Log pipeline) # ------------------------------------------------------------------
async def log_tool_call_record(
    self,
    *,
    record_id: str,
    tool_name: str,
    raw_arguments_json: str,
    result_output: str,
    success: bool,
    execution_start: float,
    execution_end: float,
    duration_ms: float,
    order_index: int,
    round_number: int,
    channel_id: str,
    platform: str,
) -> str:
    """Store the execution trace of a single tool call in Redis.

    Part of the Stargazer-System-Log pipeline: each tool invocation in a
    turn is recorded so its arguments, output, timing, and outcome can be
    inspected later.  In one pipeline the trace is written as the
    ``toolcall:{record_id}`` hash, given the cache TTL, and added to the
    per-channel ``tool_call_log:{platform}:{channel_id}`` sorted set scored
    by *execution_start*.

    The trace is kept out of ``channel_msgs:*`` so the prompt context
    builder never surfaces it directly; it is reached through its summary
    (see :meth:`log_tool_call_summary`), which later fills in the
    ``turn_summary_id`` field left blank here.

    Called by ``message_processor.generate_and_send``.

    Args:
        record_id (str): Unique id for this trace.
        tool_name (str): Name of the invoked tool.
        raw_arguments_json (str): JSON-encoded tool arguments as called.
        result_output (str): The tool's output; stored truncated to
            50_000 characters.
        success (bool): Whether the call succeeded (stored as ``"1"`` /
            ``"0"``).
        execution_start (float): Unix start time; also the zset score.
        execution_end (float): Unix end time.
        duration_ms (float): Duration in milliseconds; stored rounded to
            one decimal place.
        order_index (int): Position of this call within the turn.
        round_number (int): Tool-calling round the call belongs to.
        channel_id (str): Channel identifier.
        platform (str): Platform identifier (e.g. ``"discord"``).

    Returns:
        str: The Redis key of the stored record
        (``toolcall:{record_id}``).
    """
    record_key = f"toolcall:{record_id}"
    channel_log_key = f"tool_call_log:{platform}:{channel_id}"

    fields: dict[str, str] = dict(
        record_id=record_id,
        tool_name=tool_name,
        raw_arguments_json=raw_arguments_json,
        result_output=result_output[:50_000],  # safety cap
        success="1" if success else "0",
        execution_start=str(execution_start),
        execution_end=str(execution_end),
        duration_ms=str(round(duration_ms, 1)),
        order_index=str(order_index),
        round_number=str(round_number),
        channel_id=channel_id,
        platform=platform,
        turn_summary_id="",  # back-patched after summary write
    )

    batch = self._redis.pipeline()
    batch.hset(record_key, mapping=fields)
    batch.expire(record_key, TTL_SECONDS)
    batch.zadd(channel_log_key, {record_key: execution_start})
    batch.expire(channel_log_key, TTL_SECONDS)
    await batch.execute()

    logger.debug(
        "Stored tool-call record %s (%s) for %s:%s",
        record_id,
        tool_name,
        platform,
        channel_id,
    )
    return record_key
async def log_tool_call_summary(
    self,
    *,
    summary_id: str,
    record_ids: list[str],
    summary_text: str,
    channel_id: str,
    platform: str,
    timestamp: float | None = None,
) -> str:
    """Persist the turn-level Stargazer-System-Log summary linking its traces.

    The visible counterpart to the hidden :meth:`log_tool_call_record`
    traces: one human-readable summary of what the tools did this turn.
    In a single pipeline this writes the ``toolcall_summary:{summary_id}``
    hash (with *record_ids* stored as a JSON array), sets its TTL, adds it
    to the per-channel ``tool_call_summaries:{platform}:{channel_id}``
    sorted set, and back-patches ``turn_summary_id`` on every referenced
    ``toolcall:{rid}`` record.

    Each back-patched record also has its TTL (re)applied.  ``HSET`` on a
    record that has already expired re-creates the key; without the
    accompanying ``EXPIRE`` that stub would never expire.

    Called by ``message_processor.generate_and_send`` and
    ``tools.librarian_tool``.

    Args:
        summary_id (str): Unique id for the summary.
        record_ids (list[str]): Ids of the ``toolcall:*`` records this
            summary covers.
        summary_text (str): The human-readable turn summary.
        channel_id (str): Channel identifier.
        platform (str): Platform identifier (e.g. ``"discord"``).
        timestamp (float | None): Unix timestamp and zset score; defaults
            to ``time.time()``.

    Returns:
        str: The Redis key of the stored summary
        (``toolcall_summary:{summary_id}``).
    """
    if timestamp is None:
        timestamp = time.time()

    key = f"toolcall_summary:{summary_id}"
    mapping: dict[str, str] = {
        "summary_id": summary_id,
        "record_ids": json.dumps(record_ids),
        "summary_text": summary_text,
        "channel_id": channel_id,
        "platform": platform,
        "timestamp": str(timestamp),
        "user_id": "0000000000000000000000",
    }
    zset_key = f"tool_call_summaries:{platform}:{channel_id}"

    pipe = self._redis.pipeline()
    pipe.hset(key, mapping=mapping)
    pipe.expire(key, TTL_SECONDS)
    pipe.zadd(zset_key, {key: timestamp})
    pipe.expire(zset_key, TTL_SECONDS)
    # Back-patch each record's turn_summary_id, keeping the record on a TTL
    # even if the HSET had to re-create an already-expired key.
    for rid in record_ids:
        rec_key = f"toolcall:{rid}"
        pipe.hset(rec_key, "turn_summary_id", summary_id)
        pipe.expire(rec_key, TTL_SECONDS)
    await pipe.execute()

    logger.debug(
        "Stored tool-call summary %s (%d records) for %s:%s",
        summary_id,
        len(record_ids),
        platform,
        channel_id,
    )
    return key
async def get_tool_call_records_by_summary(
    self,
    summary_id: str,
) -> list[dict[str, str]]:
    """Fetch all hidden tool-call records belonging to a summary.

    The expansion read for the Stargazer-System-Log: given a summary id,
    return the per-call traces it links.  The
    ``toolcall_summary:{summary_id}`` hash is read with ``HGETALL``, its
    ``record_ids`` JSON array is parsed, and ``HGETALL`` is pipelined for
    each ``toolcall:{rid}`` hash.  Expired/empty rows are dropped and the
    rest are sorted by ``order_index``.

    A record whose ``order_index`` is missing or not an integer sorts as
    ``0`` instead of failing the whole read.

    Called by ``tools.retrieve_tool_call_log``.

    Args:
        summary_id (str): The summary id whose linked records to fetch.

    Returns:
        list[dict[str, str]]: The record field mappings ordered by
        ``order_index``; empty if the summary is unknown, its
        ``record_ids`` field is not a JSON list, or no record survives.
    """
    summary_key = f"toolcall_summary:{summary_id}"
    raw = await self._redis.hgetall(summary_key)
    if not raw:
        return []

    record_ids_json = raw.get("record_ids", "[]")
    try:
        record_ids = json.loads(record_ids_json)
    except (json.JSONDecodeError, TypeError):
        return []
    # Anything other than a non-empty list cannot name records.
    if not isinstance(record_ids, list) or not record_ids:
        return []

    # Pipeline-fetch all record hashes
    pipe = self._redis.pipeline()
    for rid in record_ids:
        pipe.hgetall(f"toolcall:{rid}")
    results = await pipe.execute()

    records: list[dict[str, str]] = [
        dict(raw_hash)
        for raw_hash in results
        if raw_hash and not all(v is None for v in raw_hash.values())
    ]

    def _order(record: dict[str, str]) -> int:
        # Stub or corrupt records may lack a usable order_index.
        try:
            return int(record.get("order_index", 0))
        except (TypeError, ValueError):
            return 0

    records.sort(key=_order)
    return records
async def get_recent_tool_call_summaries(
    self,
    platform: str,
    channel_id: str,
    limit: int = 10,
) -> list[dict[str, str]]:
    """Fetch the most recent tool-call summaries for a channel.

    Feeds the context-injection and observability surfaces with a
    channel's recent Stargazer-System-Log summaries.  The newest *limit*
    keys of ``tool_call_summaries:{platform}:{channel_id}`` are read with
    ``ZREVRANGE`` and hydrated with pipelined ``HGETALL``; expired/empty
    rows are dropped.

    The whole read is fail-open: an error in either the ``ZREVRANGE`` or
    the ``HGETALL`` pipeline records a degraded-read metric via
    :func:`_note_context_read_degraded` and returns an empty list, so a
    failed read never blocks a turn.

    Called by ``message_processor.context_injections``,
    ``message_processor.generate_and_send``, and ``web.obs``.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        channel_id (str): Channel identifier.
        limit (int): Maximum number of summaries to return.  Zero or less
            yields an empty list.

    Returns:
        list[dict[str, str]]: Summary field mappings newest-first; empty
        on miss or degraded read.
    """
    # limit == 0 would become ZREVRANGE 0 -1, i.e. "everything".
    if limit <= 0:
        return []

    zset_key = f"tool_call_summaries:{platform}:{channel_id}"
    try:
        summary_keys: list[str] = await self._redis.zrevrange(
            zset_key, 0, limit - 1,
        )
        if not summary_keys:
            return []
        pipe = self._redis.pipeline()
        for sk in summary_keys:
            pipe.hgetall(sk)
        results = await pipe.execute()
    except Exception:
        logger.debug(
            "Failed to retrieve tool-call summaries for %s:%s",
            platform,
            channel_id,
            exc_info=True,
        )
        _note_context_read_degraded("tool_call_summaries")
        return []

    return [
        dict(raw_hash)
        for raw_hash in results
        if raw_hash and not all(v is None for v in raw_hash.values())
    ]
# ------------------------------------------------------------------
# Context break (per-channel timestamp floor)
# ------------------------------------------------------------------
@staticmethod
def _ctxbreak_key(platform: str, channel_id: str) -> str:
    """Format the Redis string key that stores a channel's context-break floor.

    The value kept under ``ctxbreak:{platform}:{channel_id}`` is a Unix
    timestamp floor written by :meth:`set_ctxbreak_ts` and read by
    :meth:`get_ctxbreak_ts`.  Pure string formatting; Redis is not touched.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        channel_id (str): Channel identifier.

    Returns:
        str: The ``ctxbreak:{platform}:{channel_id}`` key.
    """
    return "ctxbreak:{}:{}".format(platform, channel_id)
async def set_ctxbreak_ts(
    self,
    platform: str,
    channel_id: str,
    ts: float,
) -> None:
    """Store a context-break timestamp floor for a channel.

    Implements the user-facing "context break" / fresh-start command:
    once stored, context-building pathways are expected to drop messages
    whose timestamp is ``<=`` this value.  The stringified timestamp is
    ``SET`` under the key built by :meth:`_ctxbreak_key`; the matching
    reader is :meth:`get_ctxbreak_ts`.

    Called by ``message_processor.processor``.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        channel_id (str): Channel identifier.
        ts (float): The Unix-timestamp floor.
    """
    floor_key = self._ctxbreak_key(platform, channel_id)
    floor_value = str(ts)
    await self._redis.set(floor_key, floor_value)
async def get_ctxbreak_ts(
    self,
    platform: str,
    channel_id: str,
) -> float | None:
    """Read a channel's context-break timestamp floor.

    The reader paired with :meth:`set_ctxbreak_ts`.  ``GET`` is issued for
    the key built by :meth:`_ctxbreak_key` and the stored value is parsed
    as a float.  An unset key or an unparseable value both mean "no break
    in effect" and yield ``None``.

    Called by :mod:`prompt_context`, ``message_processor.proactive_gates``,
    ``message_processor.generate_and_send``,
    ``message_processor.history_backfill``, and ``web.obs``.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        channel_id (str): Channel identifier.

    Returns:
        float | None: The floor in Unix seconds, or ``None`` if none is
        set.
    """
    stored = await self._redis.get(self._ctxbreak_key(platform, channel_id))
    try:
        return None if stored is None else float(stored)
    except (ValueError, TypeError):
        return None
@property
def redis_client(self) -> aioredis.Redis:
    """Return the async Redis client this cache uses for its own reads/writes.

    Public accessor for ``self._redis`` so that collaborating subsystems
    (e.g. :class:`ToolContext`) can reuse the same connection instead of
    reaching into a private attribute.
    """
    return self._redis

@property
def redis_raw_client(self) -> aioredis.Redis:
    """Return the async Redis client configured with ``decode_responses=False``.

    Public accessor for ``self._redis_raw``, intended for consumers that
    need undecoded binary payloads (e.g. RedisEventBus).
    """
    return self._redis_raw

# ------------------------------------------------------------------
# Lifecycle
# ------------------------------------------------------------------
async def close(self) -> None:
    """Close both Redis connections held by this cache.

    Calls ``aclose()`` on the decoded client (``self._redis``) and then
    on the raw-binary client (``self._redis_raw``).  Intended for
    shutdown / teardown so pooled connections are not leaked.

    The second close runs inside ``finally``: if closing the decoded
    client raises, the raw client is still closed before the exception
    propagates, so one failing ``aclose()`` can no longer leak the other
    connection pool.

    Raises:
        Exception: Whatever ``aclose()`` raises is propagated unchanged.
            If both calls fail, the raw client's error is raised with the
            decoded client's error attached as its context.
    """
    try:
        await self._redis.aclose()
    finally:
        # Always release the raw client, even when the first close failed.
        await self._redis_raw.aclose()
# ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ @staticmethod def _channel_zset_key(platform: str, channel_id: str) -> str: """Build the Redis key for a channel's time-ordered message index. Names the ``channel_msgs:{platform}:{channel_id}`` sorted set that is the backbone of every per-channel read in this class — recent history, time ranges, "since", and window expansion all operate on it. Pure key formatting with no Redis or other side effects. Used throughout :class:`MessageCache` (e.g. by :meth:`log_message`, :meth:`get_recent`, :meth:`get_messages_after`, :meth:`get_messages_around_key`, :meth:`find_key_by_message_id`). Args: platform (str): Platform identifier (e.g. ``"discord"``). channel_id (str): Channel identifier. Returns: str: The ``channel_msgs:{platform}:{channel_id}`` key. """ return f"channel_msgs:{platform}:{channel_id}" @staticmethod def _channel_seq_key(platform: str, channel_id: str) -> str: """Build the Redis key for a channel's monotonic sequence counter. Names the ``INCR`` counter used by :meth:`_monotonic_channel_score` (its only caller) to break ties between messages that share the same millisecond timestamp, guaranteeing strictly increasing zset scores. Args: platform (str): Platform identifier (e.g. ``"discord"``). channel_id (str): Channel identifier. Returns: str: The ``sg:channel_seq:{platform}:{channel_id}`` key. """ return f"sg:channel_seq:{platform}:{channel_id}" _HASH_FIELDS = ( "user_id", "user_name", "platform", "channel_id", "text", "timestamp", "message_id", "reply_to_id", "kind", ) async def _fetch_hashes( self, zset_key: str, msg_keys: list[str], ) -> list[CachedMessage]: """Hydrate message keys into :class:`CachedMessage` objects, self-healing. The shared materialiser behind the per-channel reads (:meth:`get_recent`, :meth:`get_by_timerange`, :meth:`get_messages_after`, :meth:`get_messages_around_key`). 
Pipelines ``HMGET`` of the scalar ``_HASH_FIELDS`` for each key — deliberately not ``HGETALL``, so the binary FLOAT32 embedding blob (incompatible with this ``decode_responses=True`` connection) is skipped and the returned messages carry an empty embedding. Keys whose hash has expired or been evicted (all-``None`` result) are collected and ``ZREM``\\ ed from *zset_key* in a second pipeline, so reads keep the sorted-set index self-healing rather than accumulating dangling members. Internal; not called outside :class:`MessageCache`. Args: zset_key (str): The per-channel sorted-set key the keys came from, used to prune stale members. msg_keys (list[str]): The ``msg:{uuid}`` keys to hydrate. Returns: list[CachedMessage]: One message per live key (embedding empty), positionally following *msg_keys* minus any pruned entries. """ if not msg_keys: return [] pipe = self._redis.pipeline() for key in msg_keys: pipe.hmget(key, *self._HASH_FIELDS) results = await pipe.execute() messages: list[CachedMessage] = [] stale_keys: list[str] = [] for key, values in zip(msg_keys, results): if not values or all(v is None for v in values): stale_keys.append(key) continue mapping = dict(zip(self._HASH_FIELDS, values)) messages.append( CachedMessage( user_id=str(mapping.get("user_id") or ""), user_name=str(mapping.get("user_name") or ""), platform=str(mapping.get("platform") or ""), channel_id=str(mapping.get("channel_id") or ""), text=str(mapping.get("text") or ""), timestamp=float(mapping.get("timestamp") or 0), embedding=[], message_key=key, message_id=str(mapping.get("message_id") or ""), reply_to_id=str(mapping.get("reply_to_id") or ""), kind=str(mapping.get("kind") or "user_in"), ) ) if stale_keys: pipe = self._redis.pipeline() for key in stale_keys: pipe.zrem(zset_key, key) await pipe.execute() return messages @staticmethod def _doc_str(doc: Any, field_name: str) -> str: """Safely read one field off a RediSearch result document as a string. 
RediSearch ``Document`` objects only carry the fields a query asked to return and raise on absent attributes, so this getattr-with-default helper normalises a possibly-missing field to a plain string (``""`` when absent or ``None``). Pure attribute access with no I/O. Used by :meth:`search_messages` and :meth:`_doc_to_cached_message` to read fields off ``FT.SEARCH`` hits. Args: doc (Any): A RediSearch result document. field_name (str): The field to read. Returns: str: The field value as a string, or ``""`` if missing. """ val = getattr(doc, field_name, None) if val is None: return "" return str(val) @classmethod def _doc_to_cached_message(cls, doc: Any) -> CachedMessage: """Build a :class:`CachedMessage` from a RediSearch result document. Bridges the ``FT.SEARCH`` result shape to the cache's domain object, reading each field through :meth:`_doc_str` (so missing fields degrade to sensible defaults) and carrying the document id across as ``message_key``. The embedding is left empty because search documents do not return the vector blob. Pure transform with no I/O. Called by :meth:`get_recent_for_user`. Args: doc (Any): A RediSearch result document for one ``msg:*`` row. Returns: CachedMessage: The materialised message (with an empty embedding). """ return CachedMessage( user_id=cls._doc_str(doc, "user_id"), user_name=cls._doc_str(doc, "user_name"), platform=cls._doc_str(doc, "platform"), channel_id=cls._doc_str(doc, "channel_id"), text=cls._doc_str(doc, "text"), timestamp=float(cls._doc_str(doc, "timestamp") or 0), embedding=[], message_key=getattr(doc, "id", ""), message_id=cls._doc_str(doc, "message_id"), reply_to_id=cls._doc_str(doc, "reply_to_id"), kind=cls._doc_str(doc, "kind") or "user_in", )
# ------------------------------------------------------------------ # Shared utility — active channel discovery # ------------------------------------------------------------------
async def get_active_channels(
    redis: Any,
    limit: int = 10,
) -> list[tuple[str, str]]:
    """Discover the most recently active channels across the whole cache.

    A module-level utility (not a :class:`MessageCache` method) so it can
    run against any async Redis handle.  It:

    1. ``SCAN``\\ s every ``channel_msgs:*`` key, decoding ``bytes`` keys
       and de-duplicating them (``SCAN`` may return the same key more
       than once, which would otherwise list a channel twice and push
       distinct channels out of the top *limit*);
    2. pipelines ``ZREVRANGE key 0 0 WITHSCORES`` to read each channel's
       newest message score;
    3. parses ``platform`` and ``channel_id`` back out of the key name
       (splitting on the first two ``:`` only, so a channel id may itself
       contain ``:``);
    4. returns the top *limit* channels, highest score first.

    Read-only.  On any failure it logs the exception, records a
    degraded-read metric via :func:`_note_context_read_degraded`, and
    fails open with an empty list.

    Args:
        redis (Any): An async Redis client (decoded or raw) to scan.
        limit (int): Maximum number of channels to return.  Values
            ``<= 0`` return an empty list without touching Redis.

    Returns:
        list[tuple[str, str]]: Up to *limit* ``(platform, channel_id)``
        tuples, most-recently-active first; empty on error or when no
        channels exist.
    """
    if limit <= 0:
        # A negative slice bound would silently drop channels from the
        # tail instead of returning "nothing"; short-circuit instead.
        return []

    try:
        # dict keys give order-preserving de-duplication of SCAN results.
        unique_keys: dict[str, None] = {}
        async for raw_key in redis.scan_iter(
            match="channel_msgs:*", count=500,
        ):
            key = raw_key.decode() if isinstance(raw_key, bytes) else raw_key
            unique_keys[key] = None

        if not unique_keys:
            return []

        channel_keys = list(unique_keys)
        pipe = redis.pipeline()
        for key in channel_keys:
            pipe.zrevrange(key, 0, 0, withscores=True)
        newest_entries = await pipe.execute()

        scored: list[tuple[float, str, str]] = []
        for key, newest in zip(channel_keys, newest_entries):
            if not newest:
                # Empty sorted set: no message to rank this channel by.
                continue
            _member, score = newest[0]
            parts = key.split(":", 2)
            if len(parts) < 3:
                # Key does not follow channel_msgs:{platform}:{channel_id}.
                continue
            scored.append((float(score), parts[1], parts[2]))

        scored.sort(reverse=True)
        return [(platform, cid) for _, platform, cid in scored[:limit]]
    except Exception:
        logger.exception("Failed to discover active channels")
        _note_context_read_degraded("active_channels")
        return []