"""Redis-backed message cache with per-channel sorted-set indexes and
RediSearch vector search.
Each message is stored as a Redis hash keyed by ``msg:{uuid}`` with
the embedding as a binary FLOAT32 blob. A per-channel sorted set
(``channel_msgs:{platform}:{channel_id}``) provides O(log N) time-
ordered lookups by channel, while a RediSearch HNSW index
(``idx:messages``) enables semantic KNN similarity search.
All cached messages are represented as :class:`CachedMessage` objects
and are given a 90-day TTL in Redis.
"""
from __future__ import annotations
import jsonutil as json
import logging
import re
import time
import uuid as _uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any
import numpy as np
import redis.asyncio as aioredis
from redis.asyncio.sentinel import Sentinel
from redis.commands.search.query import Query
from classifiers.redis_vector_index import DEFAULT_KNN_EF_RUNTIME
from init_redis_indexes import MSG_INDEX_NAME, VECTOR_DIM
from observability import observability
from openrouter_client import OpenRouterClient
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Lifetime in seconds passed to EXPIRE for message hashes, per-channel sorted
# sets and the per-channel sequence counters written by this module.
TTL_SECONDS = 90 * 86400 # 90 days
def _note_context_read_degraded(section: str) -> None:
    """Emit a metric when a context-critical Redis read degrades to empty.

    Client-level retry (see ``Config.redis_resilience_kwargs``) already rides
    through a transient Sentinel failover window, so this fires only when a
    read *still* failed after retries — making the (intentionally fail-open)
    degraded turn visible instead of silent.

    Args:
        section (str): Name of the context section whose read degraded; sent
            as the ``section`` tag of the ``context_read_degraded`` counter.
    """
    try:
        observability.increment(
            "context_read_degraded", tags={"section": section}
        )
    except Exception:
        # Metrics are best-effort: a failing metrics backend must never turn a
        # fail-open read into a hard error. Leave a debug trace rather than
        # swallowing silently so a broken metrics pipeline stays diagnosable.
        logger.debug(
            "failed to emit context_read_degraded metric for section %s",
            section,
            exc_info=True,
        )
# Prompt-time injections appended to the user turn for the LLM must never be
# persisted in Redis (duplicate noise + wrong embeddings). Stripped in
# :meth:`MessageCache.log_message` before embed + write.
_CHANNEL_SEMANTIC_RECALL_BLOCK_RE = re.compile(
    r"<channel_semantic_recall\b[^>]*>.*?</channel_semantic_recall\s*>",
    re.DOTALL | re.IGNORECASE,
)
# Opening tag of a block that was never closed (malformed / truncated XML).
# ASCII-only case folding: the tag name is plain ASCII.
_CHANNEL_SEMANTIC_RECALL_OPEN_RE = re.compile(
    r"<channel_semantic_recall",
    re.IGNORECASE | re.ASCII,
)


def strip_llm_injection_artifacts_for_cache(text: str) -> str:
    """Remove prompt-injection blobs that must not be stored in the message cache.

    Callers normally pass raw platform text; this guards against accidental
    forwarding of augmented LLM-request content into :meth:`log_message`.

    Well-formed ``<channel_semantic_recall>...</channel_semantic_recall>``
    blocks are removed wherever they occur. If an opening tag remains after
    that (the block was truncated or malformed), everything from that tag
    onward is dropped. Trailing whitespace is stripped from the result.

    Args:
        text (str): The candidate message text.

    Returns:
        str: The cleaned text. Falsy input (``""`` / ``None``) is returned
        unchanged.
    """
    if not text:
        return text
    cleaned = _CHANNEL_SEMANTIC_RECALL_BLOCK_RE.sub("", text)
    # Locate the dangling tag on the original string. Searching a lowercased
    # copy is unsafe: str.lower() can change the length (e.g. "İ" becomes two
    # code points), which shifts the index and leaves part of the tag behind.
    dangling = _CHANNEL_SEMANTIC_RECALL_OPEN_RE.search(cleaned)
    if dangling is not None:
        cleaned = cleaned[: dangling.start()]
    return cleaned.rstrip()
def _embed_to_bytes(embedding: list[float] | np.ndarray) -> bytes:
"""Pack an embedding vector into a little-endian FLOAT32 byte blob.
RediSearch HNSW vector fields store and compare raw FLOAT32 buffers, so
every embedding written to a ``msg:*`` hash or passed as the ``$query_vec``
KNN parameter must first be flattened to this binary form. Pure numpy
conversion with no I/O. Called by :meth:`CachedMessage.to_redis_hash` (when
serialising a message for storage) and by
:meth:`MessageCache.search_messages` (when turning a query embedding into a
search blob).
Args:
embedding (list[float] | np.ndarray): The embedding vector to encode.
Returns:
bytes: The vector as a contiguous FLOAT32 byte buffer.
"""
arr = np.array(embedding, dtype=np.float32)
return arr.tobytes()
def _bytes_to_embed(raw: Any) -> list[float]:
"""Decode a stored embedding back into a Python list of floats.
The inverse of :func:`_embed_to_bytes`, hardened against the several shapes
an embedding can take once it has round-tripped through Redis: a raw FLOAT32
binary buffer (the normal case), a JSON array string, or a corrupted
``b'...'`` / ``{...}`` repr that should be treated as empty. Tolerates bad
input by returning an empty list rather than raising. Pure decoding with no
I/O. Called by :meth:`CachedMessage.from_redis_hash` when reconstructing a
message from an ``HGETALL`` result.
Args:
raw (Any): The stored embedding value — ``bytes`` / ``bytearray`` /
``memoryview`` for a binary blob, a ``str`` for a JSON or
space-separated representation, or ``None``.
Returns:
list[float]: The decoded vector, or an empty list when the input is
missing, malformed, or not a valid FLOAT32-length buffer.
"""
if raw is None:
return []
# 1. Handle string/unicode representation directly
if isinstance(raw, str):
stripped = raw.strip()
if stripped.startswith("b'") or stripped.startswith('b"') or stripped.startswith("{"):
return []
if stripped.startswith("["):
try:
import json
parsed = json.loads(stripped)
if isinstance(parsed, list):
return [float(x) for x in parsed]
except Exception:
pass
# Fallback: space-separated values?
try:
return [float(x) for x in stripped.split()]
except Exception:
return []
# 2. Handle bytes, bytearray, or memoryview
if isinstance(raw, (bytes, bytearray, memoryview)):
# Check if it starts with "[" indicating a stringified JSON bytes representation
stripped = raw.strip()
if stripped.startswith(b"b'") or stripped.startswith(b'b"') or stripped.startswith(b"{"):
return []
if stripped.startswith(b"["):
try:
import json
parsed = json.loads(stripped.decode("utf-8", errors="ignore"))
if isinstance(parsed, list):
return [float(x) for x in parsed]
except Exception:
pass
# Otherwise, process as raw binary FLOAT32 array
if len(raw) < 4 or len(raw) % 4 != 0:
return []
try:
return np.frombuffer(raw, dtype=np.float32).tolist()
except Exception:
return []
return []
def _escape_tag(value: str) -> str:
"""Backslash-escape RediSearch TAG-query metacharacters in a value.
RediSearch treats characters like ``.``, ``|``, ``{`` and ``}`` as query
syntax, so any platform/channel/user identifier injected into a ``@field:
{value}`` filter must be escaped first or it will be mis-parsed (or let a
crafted id alter the query). Pure string transformation with no I/O. Called
by :meth:`MessageCache.search_messages` and
:meth:`MessageCache.get_recent_for_user` when building TAG filter clauses.
Args:
value (str): The raw tag value (e.g. a channel or user id) to escape.
Returns:
str: The value with every reserved character preceded by a backslash.
"""
out: list[str] = []
for ch in value:
if ch in r"\.+*?[^]$(){}=!<>|:-@&~\"'/":
out.append("\\")
out.append(ch)
return "".join(out)
[docs]
def _field_str(value: Any, default: str = "") -> str:
    """Coerce a raw field to ``str``, mapping ``None`` to *default*."""
    return default if value is None else str(value)


def _field_float(value: Any, default: float = 0.0) -> float:
    """Coerce a raw timestamp to ``float``; ``None`` or a blank value gives *default*."""
    if value is None:
        return default
    if isinstance(value, (str, bytes)) and not value.strip():
        return default
    return float(value)


@dataclass
class CachedMessage:
    """A single cached chat message and its embedding, as stored in Redis.

    The in-memory representation of one row in the message cache: speaker
    identity, the channel/platform it belongs to, the text, a timestamp, and
    the FLOAT32 embedding used for semantic search. Each instance maps to a
    ``msg:{uuid}`` Redis hash and a membership in the per-channel
    ``channel_msgs:{platform}:{channel_id}`` sorted set.

    Construct one from raw JSON via :meth:`from_json` / :meth:`from_dict`, from
    a Redis ``HGETALL`` mapping via :meth:`from_redis_hash`, or let
    :class:`MessageCache` build and persist one in
    :meth:`MessageCache.log_message`. Serialise back out with :meth:`to_dict`,
    :meth:`to_json`, or :meth:`to_redis_hash`.
    """

    user_id: str
    user_name: str
    platform: str
    channel_id: str
    text: str
    timestamp: float
    embedding: list[float] = field(default_factory=list, repr=False)
    message_key: str = ""
    message_id: str = ""
    """Platform-specific message identifier (Discord message ID, Matrix event ID, etc.)."""
    reply_to_id: str = ""
    """Platform-specific ID of the message this one replies to, if any."""
    kind: str = "user_in"
    """``user_in`` for messages from humans, ``assistant_out`` for bot replies."""
    turn_summary_id: str = ""
    """Link to the Stargazer-System-Log summary for this assistant turn."""

    # -- Serialisation -----------------------------------------------------

    def to_dict(self) -> dict[str, Any]:
        """Serialise this message to a plain JSON-friendly dict.

        Unlike :meth:`to_redis_hash`, this keeps the embedding as a list of
        floats (not a binary blob) and omits the transient ``message_key``.
        Pure in-memory transform with no side effects.

        Returns:
            dict[str, Any]: Field/value mapping with ``embedding`` as a list of
            floats.
        """
        return {
            "user_id": self.user_id,
            "user_name": self.user_name,
            "platform": self.platform,
            "channel_id": self.channel_id,
            "text": self.text,
            "timestamp": self.timestamp,
            "embedding": self.embedding,
            "message_id": self.message_id,
            "reply_to_id": self.reply_to_id,
            "kind": self.kind,
            "turn_summary_id": self.turn_summary_id,
        }

    def to_json(self) -> str:
        """Serialise this message to a JSON string.

        Thin wrapper that ``json.dumps`` the output of :meth:`to_dict`. The
        inverse is :meth:`from_json`.

        Returns:
            str: The JSON-encoded message.
        """
        return json.dumps(self.to_dict())

    def to_redis_hash(self) -> dict[str, str | bytes | float]:
        """Produce the field mapping written to this message's Redis hash.

        Renders the message into the shape stored under ``msg:{uuid}``: the
        timestamp becomes a string, the embedding is packed to a FLOAT32 blob
        via :func:`_embed_to_bytes`, and ``None`` values are coerced to empty
        strings. An empty ``message_id`` is dropped from the mapping so a later
        write cannot clobber a gateway-synced id that was filled in out of
        band. Pure in-memory transform with no I/O.

        Returns:
            dict[str, str | bytes | float]: The hash field mapping, with the
            embedding as ``bytes`` and ``message_id`` omitted when empty.
        """
        raw = {
            "user_id": self.user_id,
            "user_name": self.user_name,
            "platform": self.platform,
            "channel_id": self.channel_id,
            "text": self.text,
            "timestamp": str(self.timestamp),
            "embedding": _embed_to_bytes(self.embedding),
            "message_id": self.message_id,
            "reply_to_id": self.reply_to_id,
            "kind": self.kind,
            "turn_summary_id": self.turn_summary_id,
        }
        sanitized = {k: (v if v is not None else "") for k, v in raw.items()}
        # Omit message_id if empty/falsy to prevent overwriting gateway-synced message_ids
        if not sanitized.get("message_id"):
            sanitized.pop("message_id", None)
        return sanitized

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> CachedMessage:
        """Reconstruct a :class:`CachedMessage` from a plain dict.

        The inverse of :meth:`to_dict`: coerces each field to its expected type
        with permissive defaults. Missing keys *and* explicit ``None`` values
        become empty strings (never the literal ``"None"``), a missing/blank
        timestamp becomes ``0.0``, a missing embedding becomes ``[]``, and
        ``kind`` defaults to ``user_in``. Pure in-memory transform with no I/O.

        Args:
            data (dict[str, Any]): A dict produced by :meth:`to_dict` (or a
                close equivalent).

        Returns:
            CachedMessage: The reconstructed message.

        Raises:
            ValueError: If ``timestamp`` is present but not numeric.
        """
        return cls(
            user_id=_field_str(data.get("user_id")),
            user_name=_field_str(data.get("user_name")),
            platform=_field_str(data.get("platform")),
            channel_id=_field_str(data.get("channel_id")),
            text=_field_str(data.get("text")),
            timestamp=_field_float(data.get("timestamp")),
            embedding=data.get("embedding") or [],
            message_id=_field_str(data.get("message_id")),
            reply_to_id=_field_str(data.get("reply_to_id")),
            kind=str(data.get("kind", "user_in") or "user_in"),
            turn_summary_id=_field_str(data.get("turn_summary_id")),
        )

    @classmethod
    def from_json(cls, raw: str) -> CachedMessage:
        """Reconstruct a :class:`CachedMessage` from a JSON string.

        The inverse of :meth:`to_json`: parses the string with ``json.loads``
        and delegates field coercion to :meth:`from_dict`.

        Args:
            raw (str): A JSON string produced by :meth:`to_json`.

        Returns:
            CachedMessage: The reconstructed message.

        Raises:
            json.JSONDecodeError: If *raw* is not valid JSON.
        """
        return cls.from_dict(json.loads(raw))

    @classmethod
    def from_redis_hash(
        cls,
        mapping: dict[str, Any],
        key: str = "",
    ) -> CachedMessage:
        """Reconstruct a :class:`CachedMessage` from a Redis ``HGETALL`` mapping.

        The inverse of :meth:`to_redis_hash`: coerces each stored field back to
        its native type and decodes the binary embedding via
        :func:`_bytes_to_embed`, but only when the raw value is ``bytes``
        longer than 100 bytes; anything else (e.g. a decoded-string artifact)
        leaves the embedding empty. ``None`` field values are treated like
        missing ones. The originating Redis key is preserved on the returned
        object as ``message_key``. Pure in-memory transform with no I/O.

        Args:
            mapping (dict[str, Any]): The raw field/value mapping returned by an
                ``HGETALL`` on a ``msg:{uuid}`` hash.
            key (str): The Redis key the hash was read from, stored on the
                result as ``message_key``.

        Returns:
            CachedMessage: The reconstructed message.
        """
        emb_raw = mapping.get("embedding")
        if isinstance(emb_raw, bytes) and len(emb_raw) > 100:
            embedding = _bytes_to_embed(emb_raw)
        else:
            embedding = []
        return cls(
            user_id=_field_str(mapping.get("user_id")),
            user_name=_field_str(mapping.get("user_name")),
            platform=_field_str(mapping.get("platform")),
            channel_id=_field_str(mapping.get("channel_id")),
            text=_field_str(mapping.get("text")),
            timestamp=_field_float(mapping.get("timestamp")),
            embedding=embedding,
            message_key=key,
            message_id=_field_str(mapping.get("message_id")),
            reply_to_id=_field_str(mapping.get("reply_to_id")),
            kind=str(mapping.get("kind", "user_in") or "user_in"),
            turn_summary_id=_field_str(mapping.get("turn_summary_id")),
        )

    # -- Representation ----------------------------------------------------

    @property
    def repr(self) -> str:
        """Render a human-readable ``[time] user: text`` line for this message.

        Formats the timestamp in UTC and truncates bodies longer than 120
        characters to 117 characters plus ``...``. Distinct from
        :meth:`__repr__`, which is the developer-facing debug form.

        Returns:
            str: A one-line ``[YYYY-MM-DD HH:MM:SS] user_name: preview`` string.
        """
        dt = datetime.fromtimestamp(self.timestamp, tz=timezone.utc)
        ts_str = dt.strftime("%Y-%m-%d %H:%M:%S")
        preview = self.text if len(self.text) <= 120 else self.text[:117] + "..."
        return f"[{ts_str}] {self.user_name}: {preview}"

    def __repr__(self) -> str:
        """Return a developer-facing debug representation of this message.

        Summarises the identifying fields plus the embedding length (rather
        than the full vector) so log lines and debugger output stay readable.

        Returns:
            str: A ``CachedMessage(...)`` debug string.
        """
        return (
            f"CachedMessage(user={self.user_name!r}, platform={self.platform!r}, "
            f"channel={self.channel_id!r}, ts={self.timestamp}, "
            f"embedding_len={len(self.embedding)})"
        )
[docs]
class MessageCache:
"""Async Redis message cache with automatic embedding generation
and RediSearch-backed vector search.
Parameters
----------
redis_url:
Redis connection URL (e.g. ``"redis://localhost:6379/0"``).
openrouter_client:
Shared :class:`OpenRouterClient` used to call the embeddings API.
embedding_model:
Model identifier for the embeddings endpoint
(e.g. ``"google/gemini-embedding-001"``).
"""
[docs]
def __init__(
    self,
    redis_url: str,
    openrouter_client: OpenRouterClient,
    embedding_model: str = "google/gemini-embedding-001",
    ssl_kwargs: dict | None = None,
    redis_sentinels: list[str] | None = None,
    redis_sentinel_master: str = "falkordb",
    resilience_kwargs: dict | None = None,
) -> None:
    """Create the two Redis clients and remember the embedding settings.

    Two clients are built against the same server: ``self._redis`` with
    ``decode_responses=True`` and ``self._redis_raw`` with
    ``decode_responses=False``. When *redis_sentinels* is non-empty both are
    obtained from a ``Sentinel`` via ``master_for``; otherwise both come
    from ``redis.asyncio.from_url``.

    Args:
        redis_url (str): Redis connection URL; used only when no sentinels
            are configured.
        openrouter_client (OpenRouterClient): Client used for embedding calls.
        embedding_model (str): Model identifier passed to the embeddings call.
        ssl_kwargs (dict | None): Optional SSL/mTLS keyword arguments. When
            non-empty they are copied, ``ssl=True`` is forced, and the result
            is forwarded to ``from_url`` or to ``Sentinel`` (both as
            connection kwargs and as ``sentinel_kwargs``).
        redis_sentinels (list[str] | None): Optional ``"host:port"`` strings.
            An entry that does not split into exactly two ``:``-separated
            parts uses its first part as host with port 26379.
        redis_sentinel_master (str): The sentinel master name.
        resilience_kwargs (dict | None): Optional retry/backoff/health-check
            keyword arguments (see ``Config.redis_resilience_kwargs``),
            forwarded to every client constructor.
    """
    tls_opts: dict = {**ssl_kwargs, "ssl": True} if ssl_kwargs else {}
    retry_opts = dict(resilience_kwargs or {})
    conn_opts = {**tls_opts, **retry_opts}

    if not redis_sentinels:
        self._redis: aioredis.Redis = aioredis.from_url(
            redis_url,
            decode_responses=True,
            **conn_opts,
        )
        self._redis_raw: aioredis.Redis = aioredis.from_url(
            redis_url,
            decode_responses=False,
            **conn_opts,
        )
    else:
        endpoints = []
        for spec in redis_sentinels:
            pieces = spec.split(":")
            # Only an exact "host:port" pair carries a port; anything else
            # falls back to the default sentinel port.
            port = int(pieces[1]) if len(pieces) == 2 else 26379
            endpoints.append((pieces[0], port))
        sentinel = Sentinel(
            endpoints,
            sentinel_kwargs=tls_opts,
            **conn_opts,
        )
        self._redis = sentinel.master_for(
            redis_sentinel_master,
            decode_responses=True,
            **retry_opts,
        )
        self._redis_raw = sentinel.master_for(
            redis_sentinel_master,
            decode_responses=False,
            **retry_opts,
        )

    self._openrouter = openrouter_client
    self._embedding_model = embedding_model
    # Lazily filled by _load_log_lua: None = not loaded yet, "" = unavailable.
    self._lua_log_script: str | None = None
async def _monotonic_channel_score(
    self,
    platform: str,
    channel_id: str,
    timestamp: float,
) -> float:
    """Build the sorted-set score for one message in a channel.

    The score is the message time in whole milliseconds multiplied by 1000,
    plus a per-channel sequence number taken modulo 1000. The sequence comes
    from ``INCR`` on the key returned by :meth:`_channel_seq_key`, whose TTL
    is refreshed to ``TTL_SECONDS`` on every call, so messages that share a
    millisecond still receive distinct scores.

    NOTE(review): because the sequence is reduced modulo 1000, two messages
    in the same millisecond whose sequence numbers straddle a multiple of
    1000 sort in reverse arrival order — confirm this is acceptable.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        channel_id (str): Channel identifier.
        timestamp (float): The message's Unix timestamp in seconds.

    Returns:
        float: The zset score for this message.
    """
    counter_key = self._channel_seq_key(platform, channel_id)
    sequence = int(await self._redis.incr(counter_key))
    await self._redis.expire(counter_key, TTL_SECONDS)
    millis = int(timestamp * 1000)
    return float(millis * 1000 + sequence % 1000)
async def _load_log_lua(self) -> str:
"""Load and memoize the atomic message-log Lua script source.
Reads ``atomic_log_message.lua`` from disk on first call and caches the
text in ``self._lua_log_script`` so subsequent calls avoid further file
I/O. If the file cannot be read, an empty string is cached so the caller
can detect the absence and fall back to a non-atomic pipeline.
Reads the on-disk ``atomic_log_message.lua`` file (the same script that
``_atomic_log_to_redis`` later ``EVAL``\\ s in Redis). Called only by
:meth:`_atomic_log_to_redis`, which in turn is invoked from
:meth:`log_message`.
Returns:
str: The Lua script source, or an empty string when the file is
missing or unreadable.
"""
if self._lua_log_script is None:
try:
with open("atomic_log_message.lua", "r") as f:
self._lua_log_script = f.read()
except Exception:
self._lua_log_script = ""
return self._lua_log_script
async def _atomic_log_to_redis(
    self,
    *,
    key: str,
    zset_key: str,
    idem_key: str,
    score: float,
    mapping: dict[str, str],
) -> str:
    """Write a message hash and index it in the channel sorted set.

    Preferred path: the Lua source from :meth:`_load_log_lua` is run with
    ``EVAL`` using three KEYS (*key*, *zset_key*, *idem_key* or ``""``)
    followed by the TTL, the score, the number of hash fields, and the
    flattened field/value pairs. ``bytes`` values are passed through
    untouched; everything else is stringified. The script's reply is
    returned as a string. The script itself lives outside this module;
    presumably it performs the write and the idempotency claim in one step.

    Fallback path (no Lua source available): a pipeline issues ``HSET`` and
    ``EXPIRE`` on the hash, ``ZADD`` and ``EXPIRE`` on the zset, and, when
    *zset_key* has the ``prefix:platform:channel`` shape, ``SADD`` and
    ``EXPIRE`` on ``sg:active_channels``. This path ignores *idem_key* and
    always returns *key*.

    Args:
        key (str): Target Redis hash key for the message (``msg:{uuid}``).
        zset_key (str): Per-channel sorted-set key the message is added to.
        idem_key (str): Idempotency key; an empty string is passed to the
            script as ``""``.
        score (float): Zset score positioning the message in time.
        mapping (dict[str, str]): Field/value pairs to write to the hash;
            ``bytes`` values (e.g. the embedding blob) are preserved.

    Returns:
        str: The key reported by the script, or *key* on the fallback path.
    """
    lua_source = await self._load_log_lua()

    if lua_source:
        argv: list[str | bytes] = []
        for name, val in mapping.items():
            argv.append(str(name))
            # Binary values (the embedding blob) must not be stringified.
            argv.append(val if isinstance(val, bytes) else str(val))
        outcome = await self._redis.eval(
            lua_source,
            3,
            key,
            zset_key,
            idem_key or "",
            str(TTL_SECONDS),
            str(score),
            str(len(mapping)),
            *argv,
        )
        if isinstance(outcome, bytes):
            outcome = outcome.decode("utf-8")
        return str(outcome)

    # No script available: best-effort, non-atomic write.
    batch = self._redis.pipeline()
    batch.hset(key, mapping=mapping)
    batch.expire(key, TTL_SECONDS)
    batch.zadd(zset_key, {key: score})
    batch.expire(zset_key, TTL_SECONDS)
    # Register active channel in the cache set
    segments = zset_key.split(":", 2)
    if len(segments) >= 3:
        batch.sadd("sg:active_channels", f"{segments[1]}:{segments[2]}")
        batch.expire("sg:active_channels", TTL_SECONDS)
    await batch.execute()
    return key
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
[docs]
async def log_message(
    self,
    platform: str,
    channel_id: str,
    user_id: str,
    user_name: str,
    text: str,
    timestamp: float | None = None,
    embedding: list[float] | None = None,
    defer_embedding: bool = False,
    message_id: str = "",
    reply_to_id: str = "",
    kind: str = "user_in",
    turn_summary_id: str = "",
    message_key: str = "",
) -> CachedMessage:
    """Persist one chat message (hash + channel index) and return it.

    *text* is first cleaned with
    :func:`strip_llm_injection_artifacts_for_cache` so prompt-only blocks
    are never stored or embedded. If cleaning changed the text, a
    caller-supplied *embedding* is discarded because it was computed from
    the uncleaned text.

    The message is written through :meth:`_atomic_log_to_redis` under a
    ``msg:{uuid}`` key (or *message_key* when given), scored with
    :meth:`_monotonic_channel_score`.

    Args:
        platform (str): Platform identifier.
        channel_id (str): Channel identifier.
        user_id (str): Author id.
        user_name (str): Author display name.
        text (str): Message body.
        timestamp (float | None): Unix seconds; defaults to ``time.time()``.
        embedding (list[float] | None): Pre-computed embedding; when kept,
            the embeddings API call is skipped.
        defer_embedding (bool): When true and no embedding is available,
            store a zero vector of length ``VECTOR_DIM`` instead of calling
            the embeddings API. A zero vector is also used when the cleaned
            text is empty or whitespace-only.
        message_id (str): Platform message id. When set and *message_key*
            is empty, an idempotency key
            ``msgid:{platform}:{channel_id}:{message_id}`` is passed along
            so concurrent writers of the same delivery can converge on one
            row.
        reply_to_id (str): Platform id of the replied-to message, if any.
        kind (str): Message kind; a falsy value is stored as ``user_in``.
        turn_summary_id (str): Optional turn-summary link.
        message_key (str): Explicit Redis key to write to; disables the
            idempotency claim.

    Returns:
        CachedMessage: The message that was written, with ``message_key``
        set to the key reported by the write and ``embedding`` holding the
        stored vector (the zero vector when deferred).
    """
    raw_text = text
    text = strip_llm_injection_artifacts_for_cache(raw_text)
    if text != raw_text and embedding is not None:
        # The supplied vector describes text we are not going to store.
        embedding = None

    if timestamp is None:
        timestamp = time.time()

    if embedding is None:
        has_content = bool(text and text.strip())
        if defer_embedding or not has_content:
            embedding = [0.0] * VECTOR_DIM
        else:
            embedding = await self._openrouter.embed(
                text,
                self._embedding_model,
            )

    message = CachedMessage(
        user_id=user_id,
        user_name=user_name,
        platform=platform,
        channel_id=channel_id,
        text=text,
        timestamp=timestamp,
        embedding=embedding,
        message_id=message_id,
        reply_to_id=reply_to_id,
        kind=kind or "user_in",
        turn_summary_id=turn_summary_id,
    )

    # Idempotency: only when the caller did not pin a key and a platform
    # message id is known do we ask the writer to claim a per-message-id
    # key, so concurrent workers converge on ONE msg:* row.
    idem_key = ""
    if message_key:
        key = message_key
    else:
        key = f"msg:{_uuid.uuid4()}"
        if message_id:
            idem_key = f"msgid:{platform}:{channel_id}:{message_id}"
    message.message_key = key

    hash_fields = message.to_redis_hash()
    zset_key = self._channel_zset_key(platform, channel_id)
    score = await self._monotonic_channel_score(platform, channel_id, timestamp)
    message.message_key = await self._atomic_log_to_redis(
        key=key,
        zset_key=zset_key,
        idem_key=idem_key,
        score=score,
        mapping=hash_fields,
    )
    return message
[docs]
async def get_recent(
    self,
    platform: str,
    channel_id: str,
    count: int = 50,
) -> list[CachedMessage]:
    """Return the *count* most recent cached messages for a channel.

    Runs ``ZREVRANGE`` over the per-channel sorted set (key built by
    :meth:`_channel_zset_key`) for the newest keys, then hydrates them via
    :meth:`_fetch_hashes`. This path does not use the RediSearch index.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        channel_id (str): Channel identifier.
        count (int): Maximum number of messages to return. Zero or a
            negative value returns an empty list without contacting Redis.

    Returns:
        list[CachedMessage]: Up to *count* messages, newest first; empty
        when the channel has no cached messages.
    """
    # ZREVRANGE uses inclusive indexes where -1 means "last element", so
    # count=0 would become (0, -1) and return the entire channel.
    if count <= 0:
        return []
    zset_key = self._channel_zset_key(platform, channel_id)
    msg_keys: list[str] = await self._redis.zrevrange(zset_key, 0, count - 1)
    if not msg_keys:
        return []
    return await self._fetch_hashes(zset_key, msg_keys)
[docs]
async def get_by_timerange(
    self,
    platform: str,
    channel_id: str,
    start: float | None,
    end: float | None,
) -> list[CachedMessage]:
    """Return cached messages whose zset score falls within a range.

    Runs ``ZRANGEBYSCORE`` over the per-channel sorted set and hydrates the
    keys via :meth:`_fetch_hashes`. Scores in that set are the widened
    values produced by :meth:`_monotonic_channel_score`, so a bound below
    ``1e12`` is taken to be plain Unix seconds and multiplied by 1_000_000;
    larger values are used as-is. A ``None`` bound means open-ended
    (``-inf`` / ``+inf``) instead of being sent to Redis as an invalid
    value.

    NOTE(review): scores carry a 0-999 sequence suffix and only millisecond
    resolution, so a message whose timestamp equals a bound exactly may
    fall just outside the scaled bound — confirm whether callers need
    millisecond-aligned bounds.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        channel_id (str): Channel identifier.
        start (float | None): Inclusive lower bound, Unix seconds (or a
            pre-scaled zset score); ``None`` for no lower bound.
        end (float | None): Inclusive upper bound, Unix seconds (or a
            pre-scaled zset score); ``None`` for no upper bound.

    Returns:
        list[CachedMessage]: Matching messages in ascending score order;
        empty when none fall in the range.
    """

    def _to_score(bound, open_ended):
        # None -> unbounded; seconds -> widened score; else already scaled.
        if bound is None:
            return open_ended
        return bound * 1_000_000 if bound < 1e12 else bound

    zset_key = self._channel_zset_key(platform, channel_id)
    msg_keys: list[str] = await self._redis.zrangebyscore(
        zset_key,
        min=_to_score(start, "-inf"),
        max=_to_score(end, "+inf"),
    )
    if not msg_keys:
        return []
    return await self._fetch_hashes(zset_key, msg_keys)
[docs]
async def get_messages_after(
    self,
    platform: str,
    channel_id: str,
    after_ts_exclusive: float | None,
    *,
    zset_batch: int = 5000,
) -> list[CachedMessage]:
    """Return channel messages scored strictly above a floor.

    Pages through ``ZRANGEBYSCORE`` on the per-channel sorted set in
    *zset_batch*-sized windows, using Redis' exclusive ``(score`` lower
    bound, then hydrates the collected keys via :meth:`_fetch_hashes` in
    slices of 400. A floor below ``1e12`` is treated as Unix seconds and
    multiplied by 1_000_000 to match the stored scores; ``None`` means no
    lower bound.

    NOTE(review): stored scores include a 0-999 sequence suffix, so a
    message whose own timestamp is passed as the floor may still be
    returned — confirm against callers.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        channel_id (str): Channel identifier.
        after_ts_exclusive (float | None): Exclusive lower bound in Unix
            seconds (or a pre-scaled score); ``None`` returns all messages.
        zset_batch (int): Page size for each ``ZRANGEBYSCORE`` call.

    Returns:
        list[CachedMessage]: Matching messages in ascending score order;
        empty when none qualify.
    """
    floor = after_ts_exclusive
    if floor is not None and floor < 1e12:
        # Plain Unix seconds: widen to the stored score scale.
        floor = floor * 1_000_000

    zset_key = self._channel_zset_key(platform, channel_id)
    lower_bound = f"({floor}" if floor is not None else "-inf"

    collected: list[str] = []
    cursor = 0
    while True:
        page: list[Any] = await self._redis.zrangebyscore(
            zset_key,
            lower_bound,
            "+inf",
            start=cursor,
            num=zset_batch,
        )
        if not page:
            break
        for raw_key in page:
            collected.append(
                raw_key.decode() if isinstance(raw_key, bytes) else str(raw_key)
            )
        cursor += len(page)
        # A short page means the range is exhausted.
        if len(page) < zset_batch:
            break

    if not collected:
        return []

    slice_size = 400
    hydrated: list[CachedMessage] = []
    for begin in range(0, len(collected), slice_size):
        chunk = collected[begin : begin + slice_size]
        hydrated.extend(await self._fetch_hashes(zset_key, chunk))
    return hydrated
[docs]
async def update_text_by_message_id(
    self,
    platform: str,
    channel_id: str,
    message_id: str,
    new_text: str,
) -> str | None:
    """Replace the stored text of the message with the given platform id.

    The Redis key is resolved with :meth:`find_key_by_message_id`; when one
    is found, its ``text`` hash field is overwritten with ``HSET``. No
    other field — including the embedding — is touched.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        channel_id (str): Channel identifier.
        message_id (str): Platform-specific id of the edited message.
        new_text (str): The replacement text to store.

    Returns:
        str | None: The Redis key that was updated, or ``None`` if no
        matching message was found.
    """
    resolved = await self.find_key_by_message_id(
        platform,
        channel_id,
        message_id,
    )
    if resolved is not None:
        await self._redis.hset(resolved, "text", new_text)
    return resolved
[docs]
async def find_key_by_message_id(
    self,
    platform: str,
    channel_id: str,
    message_id: str,
) -> str | None:
    """Look up the ``msg:*`` key that stores a given platform message id.

    Two strategies are tried in order:

    1. Fast path: ``GET`` the idempotency pointer
       ``msgid:{platform}:{channel_id}:{message_id}`` and accept it only
       if the hash it names still ``EXISTS``. Any exception on this path
       is logged at debug level and ignored.
    2. Fallback: walk the per-channel sorted set newest-first with
       ``ZREVRANGE`` in pages of 500 (at most 5000 members), pipelining
       ``HGET <key> message_id`` for each page and returning the first
       key whose stored id equals *message_id*.

    The method only reads from Redis.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        channel_id (str): Channel identifier.
        message_id (str): Platform-specific message id to resolve.

    Returns:
        str | None: The matching Redis key (e.g. ``"msg:abc-123"``), or
        ``None`` when neither strategy finds it.
    """
    pointer_key = f"msgid:{platform}:{channel_id}:{message_id}"
    try:
        pointer = await self._redis.get(pointer_key)
        if pointer:
            if isinstance(pointer, bytes):
                pointer = pointer.decode("utf-8")
            # A pointer may outlive the hash it names; verify before trusting it.
            if await self._redis.exists(pointer):
                return pointer
    except Exception:
        logger.debug(
            "msgid index lookup failed for %s",
            message_id,
            exc_info=True,
        )

    zset_key = self._channel_zset_key(platform, channel_id)
    page_size = 500
    for first in range(0, 5000, page_size):
        page: list[str] = await self._redis.zrevrange(
            zset_key,
            first,
            first + page_size - 1,
        )
        if not page:
            return None
        pipe = self._redis.pipeline()
        for candidate in page:
            pipe.hget(candidate, "message_id")
        stored_ids = await pipe.execute()
        for candidate, stored in zip(page, stored_ids):
            if stored is not None and str(stored) == message_id:
                return candidate
        # A short page means the sorted set has been exhausted.
        if len(page) < page_size:
            return None
    return None
[docs]
async def find_keys_by_message_ids(
    self,
    platform: str,
    channel_id: str,
    message_ids: list[str],
) -> dict[str, str]:
    """Batch version of :meth:`find_key_by_message_id`.

    Resolves every id in *message_ids* with the same two strategies as the
    single-id lookup, but shares the Redis round trips across the batch:

    1. One pipeline of ``GET msgid:{platform}:{channel_id}:{id}`` for all
       (de-duplicated) ids, followed by one pipeline of ``EXISTS`` on the
       keys those pointers name. A failure here is logged at debug level
       and the affected ids fall through to the scan.
    2. A single newest-first walk of the per-channel sorted set
       (``ZREVRANGE`` in pages of 500, at most 5000 members) that
       pipelines ``HGET <key> message_id`` per page and matches against
       all still-unresolved ids at once. The walk stops as soon as every
       id has been found.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        channel_id (str): Channel identifier.
        message_ids (list[str]): Platform message ids to resolve.
            Duplicates are tolerated and looked up once.

    Returns:
        dict[str, str]: ``{message_id: redis_key}`` for each id that was
        found, in first-occurrence order of *message_ids*. Ids that could
        not be resolved are omitted.
    """
    if not message_ids:
        return {}

    # De-duplicate while keeping the caller's order.
    unique_ids = list(dict.fromkeys(message_ids))
    found: dict[str, str] = {}

    # Strategy 1: idempotency pointers, verified with EXISTS.
    try:
        pipe = self._redis.pipeline()
        for mid in unique_ids:
            pipe.get(f"msgid:{platform}:{channel_id}:{mid}")
        pointers = await pipe.execute()

        candidates: list[tuple[str, str]] = []
        for mid, pointer in zip(unique_ids, pointers):
            if not pointer:
                continue
            if isinstance(pointer, bytes):
                pointer = pointer.decode("utf-8")
            candidates.append((mid, pointer))

        if candidates:
            pipe = self._redis.pipeline()
            for _mid, pointer in candidates:
                pipe.exists(pointer)
            alive_flags = await pipe.execute()
            for (mid, pointer), alive in zip(candidates, alive_flags):
                if alive:
                    found[mid] = pointer
    except Exception:
        logger.debug(
            "msgid index batch lookup failed for %d ids",
            len(unique_ids),
            exc_info=True,
        )

    # Strategy 2: one shared scan of the channel zset for whatever is left.
    missing = {mid for mid in unique_ids if mid not in found}
    if missing:
        zset_key = self._channel_zset_key(platform, channel_id)
        page_size = 500
        for first in range(0, 5000, page_size):
            page: list[str] = await self._redis.zrevrange(
                zset_key,
                first,
                first + page_size - 1,
            )
            if not page:
                break
            pipe = self._redis.pipeline()
            for candidate in page:
                pipe.hget(candidate, "message_id")
            stored_ids = await pipe.execute()
            for candidate, stored in zip(page, stored_ids):
                if stored is None:
                    continue
                stored_id = str(stored)
                # Keep the newest match only, like the single-id lookup.
                if stored_id in missing:
                    found[stored_id] = candidate
                    missing.discard(stored_id)
            if not missing or len(page) < page_size:
                break

    return {mid: found[mid] for mid in unique_ids if mid in found}
[docs]
async def has_real_embedding(self, redis_key: str) -> bool:
    """Tell whether a message hash stores a real embedding vector.

    Reads the ``embedding`` field of *redis_key* with ``HGET`` on
    ``self._redis_raw`` and classifies the blob: it counts as real only
    when it is at least ``VECTOR_DIM * 4`` bytes long and contains at
    least one non-zero byte. An all-zero blob is treated as a placeholder
    that has not been backfilled yet. The method does not write anything.
    :meth:`has_real_embedding_many` is the pipelined batch form.

    Args:
        redis_key (str): The ``msg:{uuid}`` key to inspect.

    Returns:
        bool: ``True`` for a full-length, non-zero embedding; ``False``
        when the field is missing, too short, or entirely zero bytes.
    """
    blob: bytes | None = await self._redis_raw.hget(redis_key, "embedding")
    if blob is None:
        return False
    if len(blob) < VECTOR_DIM * 4:
        return False
    # bytes(n) is n zero bytes, i.e. the placeholder of the same length.
    return blob != bytes(len(blob))
[docs]
async def has_real_embedding_many(
    self,
    redis_keys: list[str],
) -> list[bool]:
    """Batch-check which message hashes hold a real embedding.

    Pipelined form of :meth:`has_real_embedding`: queues one
    ``HGET <key> embedding`` per key on ``self._redis_raw`` and executes
    them in a single round trip. Each returned blob is judged by the same
    rule as the single-key method: real only if it is at least
    ``VECTOR_DIM * 4`` bytes long and not made up entirely of zero bytes.
    The zero test is applied over the blob's own length, so an all-zero
    blob that is longer than the minimum is still a placeholder, matching
    :meth:`has_real_embedding`. The method does not write anything.

    Args:
        redis_keys (list[str]): The ``msg:{uuid}`` keys to inspect.

    Returns:
        list[bool]: One flag per input key, positionally aligned.
        ``True`` where the embedding is real; ``False`` where it is
        missing, short, or all zeros. Empty input yields an empty list.
    """
    if not redis_keys:
        return []
    pipe = self._redis_raw.pipeline()
    for key in redis_keys:
        pipe.hget(key, "embedding")
    blobs = await pipe.execute()
    min_len = VECTOR_DIM * 4
    return [
        blob is not None
        and len(blob) >= min_len
        # Compare against zeros of the blob's OWN length (not min_len) so
        # oversized all-zero blobs are not mistaken for real vectors.
        and blob != bytes(len(blob))
        for blob in blobs
    ]
[docs]
async def mark_deleted_by_message_id(
    self,
    platform: str,
    channel_id: str,
    message_id: str,
    deleted_at_iso: str,
) -> bool:
    """Tombstone a cached message by prepending a deletion marker to its text.

    This is the deleted-message path. The 200 newest members of the
    channel's sorted set are fetched with ``ZREVRANGE``, their
    ``message_id`` fields are read in one pipeline, and the first key
    whose stored id equals *message_id* is taken as the target. If the
    target's text is non-empty and does not already contain the exact
    marker ``[deleted at {deleted_at_iso}] ``, the marker is prepended and
    written back with ``HSET``. The original text stays after the marker.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        channel_id (str): Channel identifier.
        message_id (str): Platform-specific id of the deleted message.
        deleted_at_iso (str): Deletion timestamp embedded in the marker.

    Returns:
        bool: ``True`` when a matching message was found in the scanned
        window (whether or not its text needed rewriting); ``False`` when
        no match was found there.
    """
    zset_key = self._channel_zset_key(platform, channel_id)
    recent_keys: list[str] = await self._redis.zrevrange(zset_key, 0, 199)
    if not recent_keys:
        return False

    pipe = self._redis.pipeline()
    for candidate in recent_keys:
        pipe.hget(candidate, "message_id")
    stored_ids = await pipe.execute()

    target_key: str | None = next(
        (
            candidate
            for candidate, stored in zip(recent_keys, stored_ids)
            if stored is not None and str(stored) == message_id
        ),
        None,
    )
    if target_key is None:
        return False

    existing = await self._redis.hget(target_key, "text")
    marker = f"[deleted at {deleted_at_iso}] "
    if not existing:
        # Nothing to prefix; the message still counts as found.
        return True
    existing_text = str(existing)
    if marker not in existing_text:
        await self._redis.hset(target_key, "text", marker + existing_text)
    return True
[docs]
async def search_messages(
    self,
    query: str,
    limit: int = 10,
    channel_id: str | None = None,
    platform: str | None = None,
    user_id: str | None = None,
    *,
    channel_ids: list[str] | None = None,
    min_timestamp: float | None = None,
    query_embedding: list[float] | np.ndarray | None = None,
) -> list[dict[str, Any]]:
    """Semantic search across cached messages using RediSearch KNN.

    Embeds *query* (unless *query_embedding* is supplied) and runs a
    vector KNN search over the message index, pre-filtered by whichever
    of *platform*, *channel_id* / *channel_ids*, *user_id* and
    *min_timestamp* are given.

    All argument checks run before the embedding call, so a request that
    is going to be rejected never costs an embedding round trip.

    Args:
        query (str): Text to embed. Ignored when *query_embedding* is
            given; a blank query without an embedding returns ``[]``.
        limit (int): Number of neighbours to request. A value ``<= 0``
            returns ``[]`` without querying.
        channel_id (str | None): Restrict to one channel.
        platform (str | None): Restrict to one platform.
        user_id (str | None): Restrict to one author.
        channel_ids (list[str] | None): Restrict to any of several
            channels (``@channel_id:{c1|c2}``; each id escaped
            individually). Mutually exclusive with *channel_id*: if both
            are given a warning is logged and *channel_ids* is ignored.
            Entries are stringified, stripped, de-duplicated and blanks
            dropped.
        min_timestamp (float | None): Lower bound for the ``timestamp``
            field (inclusive, Unix seconds). Requires *platform* plus
            *channel_id* or *channel_ids*; otherwise a warning is logged
            and ``[]`` is returned.
        query_embedding (list[float] | np.ndarray | None): Precomputed
            query vector. Must flatten to exactly ``VECTOR_DIM`` values,
            otherwise a warning is logged and ``[]`` is returned.

    Returns:
        list[dict[str, Any]]: One dict per hit with ``redis_key``,
        ``text``, ``user_name``, ``user_id``, ``channel_id``,
        ``platform``, ``timestamp`` (float) and ``similarity``
        (``1 - score``, rounded to 4 places).
    """
    # Asking for zero (or fewer) neighbours can only ever yield nothing.
    if limit <= 0:
        return []

    if channel_id is not None and channel_ids:
        logger.warning(
            "search_messages: channel_id and channel_ids are mutually "
            "exclusive; ignoring channel_ids",
        )
        channel_ids = None
    if channel_ids:
        stripped = (str(cid).strip() for cid in channel_ids)
        # dict.fromkeys de-duplicates while preserving first-seen order.
        channel_ids = list(dict.fromkeys(c for c in stripped if c)) or None

    # Cheap argument validation first: do not pay for an embedding call
    # when the request is going to be rejected anyway.
    if min_timestamp is not None and not (
        platform and (channel_id or channel_ids)
    ):
        logger.warning(
            "search_messages: min_timestamp requires platform and "
            "(channel_id or channel_ids)",
        )
        return []

    if query_embedding is None:
        if not query or not query.strip():
            return []
        embedding = await self._openrouter.embed(
            query,
            self._embedding_model,
        )
        query_blob = _embed_to_bytes(embedding)
    else:
        vector = np.asarray(query_embedding, dtype=np.float32).reshape(-1)
        if vector.size != VECTOR_DIM:
            logger.warning(
                "search_messages: query_embedding dim %d != %d",
                vector.size,
                VECTOR_DIM,
            )
            return []
        query_blob = _embed_to_bytes(vector)

    # Build the pre-filter; "*" (match everything) when nothing is set.
    clauses: list[str] = []
    if platform:
        clauses.append(f"@platform:{{{_escape_tag(platform)}}}")
    if channel_id:
        clauses.append(f"@channel_id:{{{_escape_tag(channel_id)}}}")
    elif channel_ids:
        joined = "|".join(_escape_tag(c) for c in channel_ids)
        clauses.append(f"@channel_id:{{{joined}}}")
    if user_id:
        clauses.append(f"@user_id:{{{_escape_tag(user_id)}}}")
    if min_timestamp is not None:
        clauses.append(f"@timestamp:[{min_timestamp} +inf]")
    pre_filter = " ".join(clauses) if clauses else "*"

    knn = (
        f"({pre_filter})=>"
        f"[KNN {limit} @embedding $query_vec "
        f"EF_RUNTIME {DEFAULT_KNN_EF_RUNTIME} AS score]"
    )
    q = (
        Query(knn)
        .sort_by("score")
        .return_fields(
            "user_id",
            "user_name",
            "platform",
            "channel_id",
            "text",
            "timestamp",
            "score",
        )
        .paging(0, limit)
        .dialect(2)
    )
    result = await self._redis.ft(MSG_INDEX_NAME).search(
        q,
        query_params={"query_vec": query_blob},
    )

    hits: list[dict[str, Any]] = []
    for doc in result.docs:
        score_raw = getattr(doc, "score", None)
        # A missing score is treated as distance 1.0, i.e. similarity 0.
        distance = float(score_raw) if score_raw is not None else 1.0
        doc_id = getattr(doc, "id", None)
        redis_key = (
            doc_id.decode("utf-8", errors="replace")
            if isinstance(doc_id, bytes)
            else str(doc_id or "")
        )
        hits.append(
            {
                "redis_key": redis_key,
                "text": self._doc_str(doc, "text"),
                "user_name": self._doc_str(doc, "user_name"),
                "user_id": self._doc_str(doc, "user_id"),
                "channel_id": self._doc_str(doc, "channel_id"),
                "platform": self._doc_str(doc, "platform"),
                "timestamp": float(self._doc_str(doc, "timestamp") or 0),
                "similarity": round(1.0 - distance, 4),
            }
        )
    return hits
[docs]
async def get_messages_around_key(
    self,
    platform: str,
    channel_id: str,
    redis_key: str,
    before: int = 5,
    after: int = 5,
) -> list[CachedMessage]:
    """Fetch the messages surrounding one key in its channel's timeline.

    Finds the key's position in the per-channel sorted set with
    ``ZRANK``, reads the rank window ``[rank - before, rank + after]``
    (lower bound clamped at 0) with ``ZRANGE``, and hydrates those keys
    through :meth:`_fetch_hashes`.

    When the key is not a member of the sorted set, the method falls back
    to ``EXISTS`` on the hash itself: if the hash is still there, just
    that one message is returned; otherwise the result is empty.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        channel_id (str): Channel identifier.
        redis_key (str): The ``msg:{uuid}`` key to centre the window on.
        before (int): Number of preceding members to include.
        after (int): Number of following members to include.

    Returns:
        list[CachedMessage]: The hydrated window in sorted-set order, a
        single message for an un-indexed key, or ``[]`` when *redis_key*
        is empty or no longer exists.
    """
    if not redis_key:
        return []

    zset_key = self._channel_zset_key(platform, channel_id)
    position = await self._redis.zrank(zset_key, redis_key)
    if position is None:
        # Not indexed in this channel; serve the lone hash if it survives.
        if await self._redis.exists(redis_key):
            return await self._fetch_hashes(zset_key, [redis_key])
        return []

    centre = int(position)
    window_raw = await self._redis.zrange(
        zset_key,
        max(0, centre - before),
        centre + after,
    )
    window_keys: list[str] = [
        member.decode() if isinstance(member, bytes) else str(member)
        for member in (window_raw or [])
    ]
    if not window_keys:
        return []
    return await self._fetch_hashes(zset_key, window_keys)
[docs]
async def get_recent_for_user(
    self,
    platform: str,
    user_id: str,
    limit: int = 20,
) -> list[CachedMessage]:
    """Fetch one user's newest cached messages, regardless of channel.

    Runs a search on the ``MSG_INDEX_NAME`` index with TAG filters on
    ``platform`` and ``user_id`` (both escaped with :func:`_escape_tag`),
    sorted by ``timestamp`` descending and limited to *limit* rows. Each
    returned document is converted with :meth:`_doc_to_cached_message`,
    so the resulting messages carry no embedding.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        user_id (str): The speaker's platform user id.
        limit (int): Maximum number of messages to return.

    Returns:
        list[CachedMessage]: The user's messages, newest first.
    """
    platform_clause = f"@platform:{{{_escape_tag(platform)}}}"
    user_clause = f"@user_id:{{{_escape_tag(user_id)}}}"
    wanted_fields = (
        "user_id",
        "user_name",
        "platform",
        "channel_id",
        "text",
        "timestamp",
        "message_id",
        "reply_to_id",
        "kind",
    )

    search = Query(f"{platform_clause} {user_clause}")
    search = search.sort_by("timestamp", asc=False)
    search = search.return_fields(*wanted_fields)
    search = search.paging(0, limit).dialect(2)

    reply = await self._redis.ft(MSG_INDEX_NAME).search(search)
    messages: list[CachedMessage] = []
    for doc in reply.docs:
        messages.append(self._doc_to_cached_message(doc))
    return messages
[docs]
async def get_recent_speaker_channels(
    self,
    platform: str,
    user_id: str,
    limit: int = 10,
    lookback: int = 300,
) -> list[str]:
    """List the channels a user spoke in most recently (MRU order).

    Pulls the user's last *lookback* messages (at least 1) through
    :meth:`get_recent_for_user`, which returns them newest first, and
    collects distinct non-blank channel ids in the order they first
    appear, stopping once *limit* ids have been gathered.

    Because the list is built only from messages authored by *user_id*,
    it can only contain channels that user has actually posted in.

    A failure of the underlying search is logged at debug level and
    results in an empty list.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        user_id (str): The speaker's platform user id.
        limit (int): Maximum number of channel ids to return.
        lookback (int): How many recent messages to inspect.

    Returns:
        list[str]: Up to *limit* unique channel ids, most recently used
        first; ``[]`` for a blank platform/user, a non-positive *limit*,
        or a failed lookup.
    """
    if not platform or not user_id or limit <= 0:
        return []

    try:
        history = await self.get_recent_for_user(
            platform=platform,
            user_id=user_id,
            limit=max(1, int(lookback)),
        )
    except Exception:
        logger.debug(
            "get_recent_speaker_channels: get_recent_for_user failed for %s:%s",
            platform,
            user_id,
            exc_info=True,
        )
        return []

    # An insertion-ordered dict doubles as "seen" set and ordered result.
    channels: dict[str, None] = {}
    for message in history:
        channel = (message.channel_id or "").strip()
        if channel and channel not in channels:
            channels[channel] = None
            if len(channels) >= limit:
                break
    return list(channels)
# ------------------------------------------------------------------
# Channel metadata cache
#
# Lightweight ``channel_meta:{platform}:{channel_id}`` HASH carrying
# human-readable channel/guild names so cross-channel recall windows
# can be tagged "this is FROM #design in MyGuild, not the current
# room". Population is fire-and-forget at message-ingest time; reads
# are batched at recall time. TTL matches the message cache (90d).
# ------------------------------------------------------------------
@staticmethod
def _channel_meta_key(platform: str, channel_id: str) -> str:
    """Return the Redis key of a channel's metadata hash.

    Pure string formatting; no Redis access.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        channel_id (str): Channel identifier.

    Returns:
        str: ``channel_meta:{platform}:{channel_id}``.
    """
    return "channel_meta:{}:{}".format(platform, channel_id)
# ------------------------------------------------------------------
# Thought-summary storage
# ------------------------------------------------------------------
[docs]
async def log_thought_summary(
    self,
    channel_id: str,
    thought_text: str,
    timestamp: float | None = None,
) -> None:
    """Append a thought summary to the channel's sorted set.

    Blank text is ignored. Before writing, the newest existing entry is
    read; if its ``content`` equals *thought_text* the write is skipped.
    A failure of that duplicate check is logged at debug level and the
    entry is stored anyway. The entry is a JSON object with ``content``
    and ``timestamp``, scored by the timestamp, and the sorted set's TTL
    is reset to ``TTL_SECONDS`` on every write.

    Parameters
    ----------
    channel_id:
        Channel identifier; the key is ``thought_summaries:{channel_id}``.
    thought_text:
        The summary text to store.
    timestamp:
        Unix timestamp; defaults to ``time.time()``.
    """
    if not (thought_text and thought_text.strip()):
        return

    when = time.time() if timestamp is None else timestamp
    zset_key = f"thought_summaries:{channel_id}"

    # Only the single newest entry is compared, not the whole history.
    try:
        newest = await self._redis.zrevrange(zset_key, 0, 0)
        if newest and json.loads(newest[0]).get("content") == thought_text:
            logger.debug(
                "Skipping duplicate thought summary for channel %s",
                channel_id,
            )
            return
    except Exception:
        logger.debug(
            "Thought summary deduplication check failed, storing anyway",
            exc_info=True,
        )

    payload = json.dumps({"content": thought_text, "timestamp": when})
    pipe = self._redis.pipeline()
    pipe.zadd(zset_key, {payload: when})
    pipe.expire(zset_key, TTL_SECONDS)
    await pipe.execute()
    logger.debug(
        "Stored thought summary in Redis for channel %s",
        channel_id,
    )
[docs]
async def get_recent_thought_summaries(
    self,
    channel_id: str,
    limit: int = 30,
) -> list[dict[str, Any]]:
    """Retrieve unique thought summaries for a channel (newest first).

    Reads the newest members of ``thought_summaries:{channel_id}`` with
    ``ZREVRANGE`` (over-fetching roughly three times *limit* to leave room
    for duplicates), decodes each JSON entry, and keeps the first
    occurrence of every distinct ``content`` until *limit* summaries have
    been collected. Entries that are not valid JSON, or that decode to
    something other than a JSON object, are skipped.

    A failed Redis read is logged at debug level, reported through
    :func:`_note_context_read_degraded`, and yields an empty list.

    Parameters
    ----------
    channel_id:
        Channel identifier.
    limit:
        Maximum number of unique summaries to return. A value ``<= 0``
        returns an empty list without touching Redis.

    Returns
    -------
    list[dict]
        The decoded entries, newest first, at most *limit* of them.
    """
    # Without this guard ZREVRANGE 0 0 still yields one entry for limit=0,
    # and a negative stop index would pull in almost the whole sorted set.
    if limit <= 0:
        return []

    zset_key = f"thought_summaries:{channel_id}"
    try:
        raw_entries: list[str] = await self._redis.zrevrange(
            zset_key,
            0,
            limit * 3,
        )
    except Exception:
        logger.debug(
            "Failed to retrieve thought summaries for channel %s",
            channel_id,
            exc_info=True,
        )
        _note_context_read_degraded("thought_summaries")
        return []

    summaries: list[dict[str, Any]] = []
    seen_contents: set[str] = set()
    for entry_json in raw_entries:
        try:
            data = json.loads(entry_json)
        except json.JSONDecodeError:
            continue
        # Valid JSON that is not an object (list, number, ...) has no
        # "content" to read; skip it rather than crash on .get().
        if not isinstance(data, dict):
            continue
        content = data.get("content", "")
        if content in seen_contents:
            continue
        seen_contents.add(content)
        summaries.append(data)
        if len(summaries) >= limit:
            break
    return summaries
[docs]
async def backfill_channel_indexes(self, batch_size: int = 500) -> int:
    """Rebuild the per-channel sorted-set indexes from stored message hashes.

    Walks every ``msg:*`` key with ``SCAN`` (page hint *batch_size*). For
    each page it pipelines ``HMGET <key> platform channel_id timestamp``,
    then in a second pipeline ``ZADD``s each usable key into its
    ``channel_msgs:{platform}:{channel_id}`` sorted set (scored by the
    stored timestamp) and ``SADD``s ``{platform}:{channel_id}`` into
    ``sg:active_channels``. TTLs on the touched sorted sets and on
    ``sg:active_channels`` are refreshed once per page.

    Rows missing any of the three fields, or whose timestamp cannot be
    parsed as a float, are skipped (the latter with a debug log) so that
    a single bad hash does not abort the whole run. Re-running is safe:
    ``ZADD`` on an existing member just sets its score again.

    NOTE(review): another reader in this class rescales second-resolution
    timestamps (``< 1e12``) by ``1_000_000`` before comparing them with
    zset scores, which suggests live scores may not be raw seconds.
    Confirm that scoring with ``float(timestamp)`` here is the intended
    scale before relying on this for already-populated channels.

    Args:
        batch_size (int): ``COUNT`` hint for each ``SCAN`` page.

    Returns:
        int: The number of messages written to an index.
    """
    cursor: int | str = 0
    total = 0
    while True:
        cursor, keys = await self._redis.scan(
            cursor,
            match="msg:*",
            count=batch_size,
        )
        if keys:
            read_pipe = self._redis.pipeline()
            for key in keys:
                read_pipe.hmget(key, "platform", "channel_id", "timestamp")
            rows = await read_pipe.execute()

            write_pipe = self._redis.pipeline()
            touched_zsets: set[str] = set()
            added = 0
            for key, (plat, chan, ts) in zip(keys, rows):
                if not plat or not chan or not ts:
                    continue
                try:
                    score = float(ts)
                except (TypeError, ValueError):
                    logger.debug(
                        "backfill_channel_indexes: skipping %s with "
                        "unparsable timestamp %r",
                        key,
                        ts,
                    )
                    continue
                zset_key = self._channel_zset_key(plat, chan)
                write_pipe.zadd(zset_key, {key: score})
                write_pipe.sadd("sg:active_channels", f"{plat}:{chan}")
                touched_zsets.add(zset_key)
                added += 1

            if added:
                # One EXPIRE per key per page is enough; the end state is
                # the same as refreshing it after every single ZADD.
                for zset_key in touched_zsets:
                    write_pipe.expire(zset_key, TTL_SECONDS)
                write_pipe.expire("sg:active_channels", TTL_SECONDS)
                await write_pipe.execute()
                total += added

        # SCAN signals completion with cursor 0; coerce in case the client
        # hands the cursor back as a string/bytes rather than an int.
        if int(cursor) == 0:
            break

    logger.info("backfill_channel_indexes: indexed %d messages", total)
    return total
# ------------------------------------------------------------------
# Tool-call logging (Stargazer-System-Log pipeline)
# ------------------------------------------------------------------
# ------------------------------------------------------------------
# Context break (per-channel timestamp floor)
# ------------------------------------------------------------------
@staticmethod
def _ctxbreak_key(platform: str, channel_id: str) -> str:
    """Return the Redis key that stores a channel's context-break floor.

    Pure string formatting; no Redis access. Shared by
    :meth:`set_ctxbreak_ts` and :meth:`get_ctxbreak_ts`.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        channel_id (str): Channel identifier.

    Returns:
        str: ``ctxbreak:{platform}:{channel_id}``.
    """
    return "ctxbreak:{}:{}".format(platform, channel_id)
[docs]
async def set_ctxbreak_ts(
    self,
    platform: str,
    channel_id: str,
    ts: float,
) -> None:
    """Store the context-break timestamp floor for a channel.

    Writes ``str(ts)`` with ``SET`` to the key produced by
    :meth:`_ctxbreak_key`. :meth:`get_ctxbreak_ts` reads it back.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        channel_id (str): Channel identifier.
        ts (float): The Unix-timestamp floor; per the reader's contract,
            messages at or before it are excluded from context.
    """
    floor_key = self._ctxbreak_key(platform, channel_id)
    floor_value = str(ts)
    await self._redis.set(floor_key, floor_value)
[docs]
async def get_ctxbreak_ts(
    self,
    platform: str,
    channel_id: str,
) -> float | None:
    """Read the context-break timestamp floor for a channel.

    ``GET``s the key produced by :meth:`_ctxbreak_key` and converts the
    stored value to a float. An absent key, or a value that cannot be
    parsed as a float, is reported as ``None``.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        channel_id (str): Channel identifier.

    Returns:
        float | None: The stored floor, or ``None`` if none is set or the
        stored value is unparseable.
    """
    stored = await self._redis.get(self._ctxbreak_key(platform, channel_id))
    if stored is None:
        return None
    try:
        floor = float(stored)
    except (ValueError, TypeError):
        return None
    else:
        return floor
@property
def redis_client(self) -> aioredis.Redis:
    """Return the cache's primary async Redis connection (``self._redis``).

    Lets other subsystems reuse this connection instead of reaching into
    the cache's private attributes.
    """
    connection = self._redis
    return connection
@property
def redis_raw_client(self) -> aioredis.Redis:
    """Return the cache's raw async Redis connection (``self._redis_raw``).

    This is the connection used where responses must stay as undecoded
    bytes, for callers that need binary payloads.
    """
    connection = self._redis_raw
    return connection
# ------------------------------------------------------------------
# Lifecycle
# ------------------------------------------------------------------
[docs]
async def close(self) -> None:
    """Close both Redis connections held by this cache.

    Calls ``aclose()`` on ``self._redis`` and then on ``self._redis_raw``.
    The second close runs in a ``finally`` block, so the raw connection
    is released even when closing the first one raises; that exception
    still propagates to the caller afterwards.
    """
    try:
        await self._redis.aclose()
    finally:
        # Must run regardless, otherwise a failure above leaks this one.
        await self._redis_raw.aclose()
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
@staticmethod
def _channel_zset_key(platform: str, channel_id: str) -> str:
    """Return the Redis key of a channel's time-ordered message index.

    The per-channel reads and lookups in :class:`MessageCache` all address
    the channel's sorted set through this key. Pure string formatting; no
    Redis access.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        channel_id (str): Channel identifier.

    Returns:
        str: ``channel_msgs:{platform}:{channel_id}``.
    """
    return "channel_msgs:{}:{}".format(platform, channel_id)
@staticmethod
def _channel_seq_key(platform: str, channel_id: str) -> str:
    """Return the Redis key of a channel's sequence counter.

    Pure string formatting; no Redis access. The counter itself is
    consumed by :meth:`_monotonic_channel_score`.

    Args:
        platform (str): Platform identifier (e.g. ``"discord"``).
        channel_id (str): Channel identifier.

    Returns:
        str: ``sg:channel_seq:{platform}:{channel_id}``.
    """
    return "sg:channel_seq:{}:{}".format(platform, channel_id)
# Scalar hash fields read back by ``_fetch_hashes`` via ``HMGET``. The order
# matters: values come back positionally and are zipped against this tuple.
# The ``embedding`` field is not listed, so the vector blob is never fetched
# through this path.
_HASH_FIELDS = (
"user_id",
"user_name",
"platform",
"channel_id",
"text",
"timestamp",
"message_id",
"reply_to_id",
"kind",
)
async def _fetch_hashes(
    self,
    zset_key: str,
    msg_keys: list[str],
) -> list[CachedMessage]:
    """Turn message keys into :class:`CachedMessage` objects, pruning dead keys.

    Pipelines one ``HMGET`` of ``_HASH_FIELDS`` per key. A key whose
    result is empty or all ``None`` is treated as stale (its hash is
    gone): it is left out of the result and removed from *zset_key* with
    ``ZREM`` in a second pipeline, which keeps the sorted-set index free
    of dangling members. Live rows become :class:`CachedMessage` objects
    with an empty ``embedding`` (the vector field is not fetched),
    ``message_key`` set to the Redis key, missing string fields defaulting
    to ``""``, a missing timestamp to ``0``, and a missing ``kind`` to
    ``"user_in"``.

    Args:
        zset_key (str): The sorted set the keys were read from; stale
            members are removed from it.
        msg_keys (list[str]): The ``msg:{uuid}`` keys to hydrate.

    Returns:
        list[CachedMessage]: One message per live key, in the order of
        *msg_keys*.
    """
    if not msg_keys:
        return []

    read_pipe = self._redis.pipeline()
    for key in msg_keys:
        read_pipe.hmget(key, *self._HASH_FIELDS)
    rows = await read_pipe.execute()

    live: list[CachedMessage] = []
    dangling: list[str] = []
    for key, row in zip(msg_keys, rows):
        if not row or all(value is None for value in row):
            dangling.append(key)
            continue
        fields = dict(zip(self._HASH_FIELDS, row))

        def text_of(name: str, fallback: str = "") -> str:
            # Treat both a missing field and an empty value as the fallback.
            return str(fields.get(name) or fallback)

        live.append(
            CachedMessage(
                user_id=text_of("user_id"),
                user_name=text_of("user_name"),
                platform=text_of("platform"),
                channel_id=text_of("channel_id"),
                text=text_of("text"),
                timestamp=float(fields.get("timestamp") or 0),
                embedding=[],
                message_key=key,
                message_id=text_of("message_id"),
                reply_to_id=text_of("reply_to_id"),
                kind=text_of("kind", "user_in"),
            )
        )

    if dangling:
        prune_pipe = self._redis.pipeline()
        for key in dangling:
            prune_pipe.zrem(zset_key, key)
        await prune_pipe.execute()
    return live
@staticmethod
def _doc_str(doc: Any, field_name: str) -> str:
    """Read one attribute of a search-result document as a string.

    Uses ``getattr`` with a ``None`` default, so an attribute that is
    absent (or explicitly ``None``) yields ``""`` instead of raising. Any
    other value is passed through ``str``. No I/O.

    Args:
        doc (Any): A search-result document object.
        field_name (str): Name of the attribute to read.

    Returns:
        str: The attribute as a string, or ``""`` when missing or ``None``.
    """
    value = getattr(doc, field_name, None)
    return "" if value is None else str(value)
@classmethod
def _doc_to_cached_message(cls, doc: Any) -> CachedMessage:
    """Convert a search-result document into a :class:`CachedMessage`.

    Every field is read through :meth:`_doc_str`, so attributes the
    document does not carry become ``""`` (``timestamp`` then falls back
    to ``0`` and ``kind`` to ``"user_in"``). The document's ``id`` is
    carried over as ``message_key`` and the embedding is left empty.
    No I/O.

    Args:
        doc (Any): A search-result document for one message.

    Returns:
        CachedMessage: The materialised message with an empty embedding.
    """

    def read(name: str) -> str:
        return cls._doc_str(doc, name)

    return CachedMessage(
        user_id=read("user_id"),
        user_name=read("user_name"),
        platform=read("platform"),
        channel_id=read("channel_id"),
        text=read("text"),
        timestamp=float(read("timestamp") or 0),
        embedding=[],
        message_key=getattr(doc, "id", ""),
        message_id=read("message_id"),
        reply_to_id=read("reply_to_id"),
        kind=read("kind") or "user_in",
    )
# ------------------------------------------------------------------
# Shared utility — active channel discovery
# ------------------------------------------------------------------
[docs]
async def get_active_channels(
    redis: Any,
    limit: int = 10,
) -> list[tuple[str, str]]:
    """Find the channels whose newest indexed message is most recent.

    Iterates all ``channel_msgs:*`` keys with ``scan_iter`` (decoding
    byte keys), pipelines ``ZREVRANGE <key> 0 0 WITHSCORES`` to obtain
    each sorted set's highest-scored member, splits the key into
    ``platform`` and ``channel_id`` (at most two splits, so a channel id
    may itself contain ``:``), and ranks channels by that top score in
    descending order. Empty sorted sets and keys without both parts are
    ignored.

    Read-only and fail-open: any exception is logged, reported through
    :func:`_note_context_read_degraded`, and turned into an empty list.

    Args:
        redis (Any): An async Redis client; keys may come back as ``str``
            or ``bytes``.
        limit (int): Maximum number of channels to return.

    Returns:
        list[tuple[str, str]]: ``(platform, channel_id)`` tuples, most
        recently active first; empty on error or when nothing is indexed.
    """
    try:
        zset_keys: list[str] = [
            raw.decode() if isinstance(raw, bytes) else raw
            async for raw in redis.scan_iter(
                match="channel_msgs:*",
                count=500,
            )
        ]
        if not zset_keys:
            return []

        pipe = redis.pipeline()
        for zset_key in zset_keys:
            pipe.zrevrange(zset_key, 0, 0, withscores=True)
        newest_entries = await pipe.execute()

        ranked: list[tuple[float, str, str]] = []
        for zset_key, newest in zip(zset_keys, newest_entries):
            if not newest:
                continue
            _member, score = newest[0]
            pieces = zset_key.split(":", 2)
            if len(pieces) < 3:
                continue
            ranked.append((float(score), pieces[1], pieces[2]))

        top = sorted(ranked, reverse=True)[:limit]
        return [(platform, channel) for _score, platform, channel in top]
    except Exception:
        logger.exception("Failed to discover active channels")
        _note_context_read_degraded("active_channels")
        return []