"""Redis-backed message cache with per-channel sorted-set indexes and
RediSearch vector search.
Each message is stored as a Redis hash keyed by ``msg:{uuid}`` with
the embedding as a binary FLOAT32 blob. A per-channel sorted set
(``channel_msgs:{platform}:{channel_id}``) provides O(log N) time-
ordered lookups by channel, while a RediSearch HNSW index
(``idx:messages``) enables semantic KNN similarity search.
All cached messages are represented as :class:`CachedMessage` objects
and are given a 45-day TTL in Redis.
"""
from __future__ import annotations
import json
import logging
import time
import uuid as _uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any
import numpy as np
import redis.asyncio as aioredis
from redis.commands.search.query import Query
from init_redis_indexes import MSG_INDEX_NAME, VECTOR_DIM
from openrouter_client import OpenRouterClient
logger = logging.getLogger(__name__)
TTL_SECONDS = 45 * 86400 # 45 days
def _embed_to_bytes(embedding: list[float] | np.ndarray) -> bytes:
"""Convert an embedding vector to a FLOAT32 byte blob for Redis."""
arr = np.array(embedding, dtype=np.float32)
return arr.tobytes()
def _bytes_to_embed(raw: bytes) -> list[float]:
"""Convert a FLOAT32 byte blob back to a Python list."""
return np.frombuffer(raw, dtype=np.float32).tolist()
def _escape_tag(value: str) -> str:
"""Escape special characters for RediSearch TAG queries."""
out: list[str] = []
for ch in value:
if ch in r"\.+*?[^]$(){}=!<>|:-@&~\"'/":
out.append("\\")
out.append(ch)
return "".join(out)
@dataclass
class CachedMessage:
    """A single cached message stored in Redis.

    Construct from raw JSON via :meth:`from_json` / :meth:`from_dict`,
    or let :class:`MessageCache` build one for you.
    """

    user_id: str
    user_name: str
    platform: str
    channel_id: str
    text: str
    # Unix timestamp (seconds) of the message.
    timestamp: float
    # Embedding vector; excluded from the generated repr (it is long).
    embedding: list[float] = field(default_factory=list, repr=False)
    # Redis hash key ("msg:{uuid}") assigned when the message is stored.
    message_key: str = ""
    # Platform-specific message identifier (Discord message ID, Matrix event ID, etc.).
    message_id: str = ""
    # Platform-specific ID of the message this one replies to, if any.
    reply_to_id: str = ""

    # -- Serialisation -----------------------------------------------------

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serialisable dict of the persisted fields.

        ``message_key`` is intentionally omitted: it is a Redis storage
        detail rather than message content.

        Returns:
            dict[str, Any]: The result.
        """
        return {
            "user_id": self.user_id,
            "user_name": self.user_name,
            "platform": self.platform,
            "channel_id": self.channel_id,
            "text": self.text,
            "timestamp": self.timestamp,
            "embedding": self.embedding,
            "message_id": self.message_id,
            "reply_to_id": self.reply_to_id,
        }

    def to_json(self) -> str:
        """Serialise :meth:`to_dict` as a JSON string.

        Returns:
            str: Result string.
        """
        return json.dumps(self.to_dict())

    def to_redis_hash(self) -> dict[str, str | bytes | float]:
        """Produce the mapping written to a Redis hash key.

        The timestamp is stringified and the embedding is packed as a
        binary FLOAT32 blob (the layout indexed by RediSearch).
        """
        return {
            "user_id": self.user_id,
            "user_name": self.user_name,
            "platform": self.platform,
            "channel_id": self.channel_id,
            "text": self.text,
            "timestamp": str(self.timestamp),
            "embedding": _embed_to_bytes(self.embedding),
            "message_id": self.message_id,
            "reply_to_id": self.reply_to_id,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> CachedMessage:
        """Build a message from a plain dict; missing keys get defaults.

        Args:
            data (dict[str, Any]): Input data payload.

        Returns:
            CachedMessage: The result.
        """
        return cls(
            user_id=str(data.get("user_id", "")),
            user_name=str(data.get("user_name", "")),
            platform=str(data.get("platform", "")),
            channel_id=str(data.get("channel_id", "")),
            text=str(data.get("text", "")),
            timestamp=float(data.get("timestamp", 0.0)),
            embedding=data.get("embedding") or [],
            message_id=str(data.get("message_id", "")),
            reply_to_id=str(data.get("reply_to_id", "")),
        )

    @classmethod
    def from_json(cls, raw: str) -> CachedMessage:
        """Build a message from a JSON string (see :meth:`from_dict`).

        Args:
            raw (str): The raw value.

        Returns:
            CachedMessage: The result.
        """
        return cls.from_dict(json.loads(raw))

    @classmethod
    def from_redis_hash(
        cls,
        mapping: dict[str, Any],
        key: str = "",
    ) -> CachedMessage:
        """Construct from a Redis HGETALL result.

        Args:
            mapping: Field/value pairs as returned by HGETALL.
            key: Redis key the hash was read from; stored as ``message_key``.
        """
        emb_raw = mapping.get("embedding")
        # A real float32 vector is hundreds of bytes; short or non-bytes
        # values are treated as "no embedding".
        if isinstance(emb_raw, bytes) and len(emb_raw) > 100:
            embedding = _bytes_to_embed(emb_raw)
        else:
            embedding = []
        return cls(
            user_id=str(mapping.get("user_id", "")),
            user_name=str(mapping.get("user_name", "")),
            platform=str(mapping.get("platform", "")),
            channel_id=str(mapping.get("channel_id", "")),
            text=str(mapping.get("text", "")),
            timestamp=float(mapping.get("timestamp", 0)),
            embedding=embedding,
            message_key=key,
            message_id=str(mapping.get("message_id", "")),
            reply_to_id=str(mapping.get("reply_to_id", "")),
        )

    # -- Representation ----------------------------------------------------

    @property
    def repr(self) -> str:
        """Human-readable one-liner suitable for prompt injection or logs."""
        dt = datetime.fromtimestamp(self.timestamp, tz=timezone.utc)
        ts_str = dt.strftime("%Y-%m-%d %H:%M:%S")
        # Truncate long texts to 120 chars with an ellipsis.
        preview = self.text if len(self.text) <= 120 else self.text[:117] + "..."
        return f"[{ts_str}] {self.user_name}: {preview}"

    def __repr__(self) -> str:
        """Compact debug representation (omits the embedding contents).

        Returns:
            str: Result string.
        """
        return (
            f"CachedMessage(user={self.user_name!r}, platform={self.platform!r}, "
            f"channel={self.channel_id!r}, ts={self.timestamp}, "
            f"embedding_len={len(self.embedding)})"
        )
class MessageCache:
    """Async Redis message cache with automatic embedding generation
    and RediSearch-backed vector search.

    Parameters
    ----------
    redis_url:
        Redis connection URL (e.g. ``"redis://localhost:6379/0"``).
    openrouter_client:
        Shared :class:`OpenRouterClient` used to call the embeddings API.
    embedding_model:
        Model identifier for the embeddings endpoint
        (e.g. ``"google/gemini-embedding-001"``).
    """

    def __init__(
        self,
        redis_url: str,
        openrouter_client: OpenRouterClient,
        embedding_model: str = "google/gemini-embedding-001",
        ssl_kwargs: dict | None = None,
    ) -> None:
        """Create the two Redis connections and remember embedding config.

        Args:
            redis_url (str): Redis connection URL.
            openrouter_client (OpenRouterClient): Client for embeddings calls.
            embedding_model (str): Embeddings model identifier.
            ssl_kwargs (dict | None): Optional SSL/mTLS keyword arguments
                forwarded to ``redis.asyncio.from_url``.
        """
        conn_kwargs = dict(ssl_kwargs) if ssl_kwargs else {}
        # Text connection: decodes responses to str for normal fields.
        self._redis: aioredis.Redis = aioredis.from_url(
            redis_url, decode_responses=True, **conn_kwargs,
        )
        # Binary connection: embeddings are raw float32 blobs that
        # decode_responses=True would corrupt.
        self._redis_raw: aioredis.Redis = aioredis.from_url(
            redis_url, decode_responses=False, **conn_kwargs,
        )
        self._openrouter = openrouter_client
        self._embedding_model = embedding_model
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
[docs]
async def log_message(
self,
platform: str,
channel_id: str,
user_id: str,
user_name: str,
text: str,
timestamp: float | None = None,
embedding: list[float] | None = None,
defer_embedding: bool = False,
message_id: str = "",
reply_to_id: str = "",
) -> CachedMessage:
"""Log a message to Redis with an embedding and a 45-day TTL.
The message is stored as a Redis hash (``msg:{uuid}``) with the
embedding as a binary FLOAT32 blob, automatically indexed by the
``idx:messages`` RediSearch index.
Parameters
----------
embedding:
Pre-computed embedding vector. When provided the internal
embedding API call is skipped, saving a round-trip.
defer_embedding:
When ``True`` **and** no pre-computed *embedding* is given,
store a zero-vector placeholder instead of calling the
embeddings API. The caller is responsible for enqueuing the
returned ``message_key`` in an :class:`EmbeddingBatchQueue`
so the real embedding is written later.
Returns
-------
CachedMessage
The message object that was written to Redis (includes the
embedding vector, or an empty list when deferred).
"""
if timestamp is None:
timestamp = time.time()
if embedding is None:
if defer_embedding or not text or not text.strip():
embedding = [0.0] * VECTOR_DIM
else:
embedding = await self._openrouter.embed(
text, self._embedding_model,
)
message = CachedMessage(
user_id=user_id,
user_name=user_name,
platform=platform,
channel_id=channel_id,
text=text,
timestamp=timestamp,
embedding=embedding,
message_id=message_id,
reply_to_id=reply_to_id,
)
key = f"msg:{_uuid.uuid4()}"
message.message_key = key
mapping = message.to_redis_hash()
zset_key = self._channel_zset_key(platform, channel_id)
pipe = self._redis.pipeline()
pipe.hset(key, mapping=mapping)
pipe.expire(key, TTL_SECONDS)
pipe.zadd(zset_key, {key: timestamp})
pipe.expire(zset_key, TTL_SECONDS)
await pipe.execute()
return message
[docs]
async def get_recent(
self,
platform: str,
channel_id: str,
count: int = 50,
) -> list[CachedMessage]:
"""Return the *count* most recent cached messages (newest first).
Uses the per-channel sorted set for O(log N) lookups without
depending on the RediSearch module.
"""
zset_key = self._channel_zset_key(platform, channel_id)
msg_keys: list[str] = await self._redis.zrevrange(zset_key, 0, count - 1)
if not msg_keys:
return []
return await self._fetch_hashes(zset_key, msg_keys)
[docs]
async def get_by_timerange(
self,
platform: str,
channel_id: str,
start: float,
end: float,
) -> list[CachedMessage]:
"""Return cached messages within a Unix-timestamp range (inclusive).
Uses the per-channel sorted set with ZRANGEBYSCORE for O(log N)
range lookups. Results are ordered ascending by timestamp.
"""
zset_key = self._channel_zset_key(platform, channel_id)
msg_keys: list[str] = await self._redis.zrangebyscore(
zset_key, min=start, max=end,
)
if not msg_keys:
return []
return await self._fetch_hashes(zset_key, msg_keys)
[docs]
async def get_messages_after(
self,
platform: str,
channel_id: str,
after_ts_exclusive: float | None,
*,
zset_batch: int = 5000,
) -> list[CachedMessage]:
"""Messages with zset score strictly greater than *after_ts_exclusive*.
When *after_ts_exclusive* is ``None``, returns all messages in the
channel zset (ascending by time). Uses Redis exclusive-interval
syntax ``(score`` on the lower bound.
"""
zset_key = self._channel_zset_key(platform, channel_id)
min_score = "-inf" if after_ts_exclusive is None else f"({after_ts_exclusive}"
msg_keys: list[str] = []
offset = 0
while True:
batch: list[Any] = await self._redis.zrangebyscore(
zset_key,
min_score,
"+inf",
start=offset,
num=zset_batch,
)
if not batch:
break
msg_keys.extend(
k.decode() if isinstance(k, bytes) else str(k)
for k in batch
)
offset += len(batch)
if len(batch) < zset_batch:
break
if not msg_keys:
return []
hm_batch = 400
messages: list[CachedMessage] = []
for i in range(0, len(msg_keys), hm_batch):
part = msg_keys[i: i + hm_batch]
messages.extend(await self._fetch_hashes(zset_key, part))
return messages
[docs]
async def update_text_by_message_id(
self,
platform: str,
channel_id: str,
message_id: str,
new_text: str,
) -> str | None:
"""Update the ``text`` field of a cached message identified by its
platform message ID.
Scans the most recent entries in the channel ZSET to find the
matching hash. Returns the Redis key if the message was found
and updated, ``None`` otherwise.
"""
target_key = await self.find_key_by_message_id(
platform, channel_id, message_id,
)
if target_key is None:
return None
await self._redis.hset(target_key, "text", new_text)
return target_key
[docs]
async def find_key_by_message_id(
self,
platform: str,
channel_id: str,
message_id: str,
) -> str | None:
"""Find the Redis key for a cached message by its platform message ID.
Returns the key string (e.g. ``"msg:abc-123"``) or ``None``.
"""
zset_key = self._channel_zset_key(platform, channel_id)
msg_keys: list[str] = await self._redis.zrevrange(zset_key, 0, 199)
if not msg_keys:
return None
pipe = self._redis.pipeline()
for key in msg_keys:
pipe.hget(key, "message_id")
results = await pipe.execute()
for key, stored_id in zip(msg_keys, results):
if stored_id is not None and str(stored_id) == message_id:
return key
return None
[docs]
async def find_keys_by_message_ids(
self,
platform: str,
channel_id: str,
message_ids: list[str],
) -> dict[str, str]:
"""Batch version of :meth:`find_key_by_message_id`.
Returns a mapping of ``{message_id: redis_key}`` for every
*message_id* that was found in Redis. Missing IDs are omitted.
Fetches the channel zset once and builds the full lookup map in
a single pipeline, instead of repeating the scan per message.
"""
if not message_ids:
return {}
zset_key = self._channel_zset_key(platform, channel_id)
msg_keys: list[str] = await self._redis.zrevrange(zset_key, 0, 199)
if not msg_keys:
return {}
pipe = self._redis.pipeline()
for key in msg_keys:
pipe.hget(key, "message_id")
results = await pipe.execute()
# Build reverse map: stored_message_id -> redis_key
id_to_key: dict[str, str] = {}
for key, stored_id in zip(msg_keys, results):
if stored_id is not None:
id_to_key[str(stored_id)] = key
# Filter to just the requested IDs
wanted = set(message_ids)
return {mid: id_to_key[mid] for mid in wanted if mid in id_to_key}
[docs]
async def has_real_embedding(self, redis_key: str) -> bool:
"""Return ``True`` if *redis_key* has a non-zero embedding blob.
A zero-vector placeholder (all zeros) is treated as missing.
Uses the raw (non-decoded) Redis connection to avoid UTF-8
corruption of binary float32 data.
"""
raw: bytes | None = await self._redis_raw.hget(redis_key, "embedding")
if raw is None or len(raw) < VECTOR_DIM * 4:
return False
return raw != b"\x00" * len(raw)
[docs]
async def has_real_embedding_many(
self, redis_keys: list[str],
) -> list[bool]:
"""Pipelined version of :meth:`has_real_embedding`.
Returns a list of booleans, one per key, indicating
whether each key has a non-zero embedding. Uses the raw
(non-decoded) Redis connection to correctly handle binary data.
"""
if not redis_keys:
return []
pipe = self._redis_raw.pipeline()
for key in redis_keys:
pipe.hget(key, "embedding")
results = await pipe.execute()
min_len = VECTOR_DIM * 4
zero_blob = b"\x00" * min_len
out: list[bool] = []
for raw in results:
if raw is None or len(raw) < min_len:
out.append(False)
else:
out.append(raw != zero_blob)
return out
[docs]
async def mark_deleted_by_message_id(
self,
platform: str,
channel_id: str,
message_id: str,
deleted_at_iso: str,
) -> bool:
"""Prepend a ``[deleted at TIMESTAMP]`` marker to a cached message's text.
The original text is preserved. Returns ``True`` if the message
was found and updated.
"""
zset_key = self._channel_zset_key(platform, channel_id)
msg_keys: list[str] = await self._redis.zrevrange(zset_key, 0, 199)
if not msg_keys:
return False
pipe = self._redis.pipeline()
for key in msg_keys:
pipe.hget(key, "message_id")
results = await pipe.execute()
target_key: str | None = None
for key, stored_id in zip(msg_keys, results):
if stored_id is not None and str(stored_id) == message_id:
target_key = key
break
if target_key is None:
return False
current_text = await self._redis.hget(target_key, "text")
tag = f"[deleted at {deleted_at_iso}] "
if current_text and tag not in str(current_text):
new_text = tag + str(current_text)
await self._redis.hset(target_key, "text", new_text)
return True
[docs]
async def search_messages(
self,
query: str,
limit: int = 10,
channel_id: str | None = None,
platform: str | None = None,
user_id: str | None = None,
) -> list[dict[str, Any]]:
"""Semantic search across cached messages using RediSearch KNN.
Generates an embedding for *query*, then runs a vector
similarity search filtered by optional *channel_id*,
*platform*, and/or *user_id*.
Returns a list of dicts with ``text``, ``user_name``,
``timestamp``, ``score``, ``channel_id``, and ``platform``.
"""
if not query or not query.strip():
return []
embedding = await self._openrouter.embed(
query, self._embedding_model,
)
query_blob = _embed_to_bytes(embedding)
pre_filter = "*"
parts: list[str] = []
if platform:
parts.append(f"@platform:{{{_escape_tag(platform)}}}")
if channel_id:
parts.append(f"@channel_id:{{{_escape_tag(channel_id)}}}")
if user_id:
parts.append(f"@user_id:{{{_escape_tag(user_id)}}}")
if parts:
pre_filter = " ".join(parts)
knn = (
f"({pre_filter})=>"
f"[KNN {limit} @embedding $query_vec AS score]"
)
q = (
Query(knn)
.sort_by("score")
.return_fields("user_id", "user_name", "platform",
"channel_id", "text", "timestamp", "score")
.paging(0, limit)
.dialect(2)
)
result = await self._redis.ft(MSG_INDEX_NAME).search(
q, query_params={"query_vec": query_blob},
)
out: list[dict[str, Any]] = []
for doc in result.docs:
score_raw = getattr(doc, "score", None)
cosine_dist = float(score_raw) if score_raw is not None else 1.0
similarity = 1.0 - cosine_dist
out.append({
"text": self._doc_str(doc, "text"),
"user_name": self._doc_str(doc, "user_name"),
"user_id": self._doc_str(doc, "user_id"),
"channel_id": self._doc_str(doc, "channel_id"),
"platform": self._doc_str(doc, "platform"),
"timestamp": float(self._doc_str(doc, "timestamp") or 0),
"similarity": round(similarity, 4),
})
return out
[docs]
async def get_recent_for_user(
self,
platform: str,
user_id: str,
limit: int = 20,
) -> list[CachedMessage]:
"""Return the most recent cached messages sent by *user_id*.
Uses a RediSearch ``FT.SEARCH`` query filtered by platform and
user_id, sorted by timestamp descending. This works across all
channels (including DMs) without needing to know the channel ID.
"""
tag_filter = (
f"@platform:{{{_escape_tag(platform)}}} "
f"@user_id:{{{_escape_tag(user_id)}}}"
)
q = (
Query(tag_filter)
.sort_by("timestamp", asc=False)
.return_fields(
"user_id", "user_name", "platform",
"channel_id", "text", "timestamp",
"message_id", "reply_to_id",
)
.paging(0, limit)
.dialect(2)
)
result = await self._redis.ft(MSG_INDEX_NAME).search(q)
return [self._doc_to_cached_message(doc) for doc in result.docs]
# ------------------------------------------------------------------
# Thought-summary storage
# ------------------------------------------------------------------
[docs]
async def log_thought_summary(
self,
channel_id: str,
thought_text: str,
timestamp: float | None = None,
) -> None:
"""Store a thought summary in a per-channel Redis sorted set.
Deduplicates by comparing against the most recent entry — if
it has the same content the new entry is silently skipped.
Parameters
----------
channel_id:
Platform-agnostic channel identifier.
thought_text:
The raw text extracted from ``<thought>`` tags.
timestamp:
Unix timestamp; defaults to ``time.time()``.
"""
if not thought_text or not thought_text.strip():
return
if timestamp is None:
timestamp = time.time()
zset_key = f"thought_summaries:{channel_id}"
# Deduplicate against the most recent entry
try:
recent = await self._redis.zrevrange(zset_key, 0, 0)
if recent:
last_data = json.loads(recent[0])
if last_data.get("content") == thought_text:
logger.debug(
"Skipping duplicate thought summary for channel %s",
channel_id,
)
return
except Exception:
logger.debug(
"Thought summary deduplication check failed, storing anyway",
exc_info=True,
)
entry = json.dumps({
"content": thought_text,
"timestamp": timestamp,
})
pipe = self._redis.pipeline()
pipe.zadd(zset_key, {entry: timestamp})
pipe.expire(zset_key, TTL_SECONDS)
await pipe.execute()
logger.debug(
"Stored thought summary in Redis for channel %s", channel_id,
)
[docs]
async def get_recent_thought_summaries(
self,
channel_id: str,
limit: int = 30,
) -> list[dict[str, Any]]:
"""Retrieve unique thought summaries for a channel (newest first).
Parameters
----------
channel_id:
Platform-agnostic channel identifier.
limit:
Maximum number of unique summaries to return.
Returns
-------
list[dict]
Each dict has ``content`` (str) and ``timestamp`` (float).
"""
zset_key = f"thought_summaries:{channel_id}"
try:
raw_entries: list[str] = await self._redis.zrevrange(
zset_key, 0, limit * 3,
)
except Exception:
logger.debug(
"Failed to retrieve thought summaries for channel %s",
channel_id, exc_info=True,
)
return []
summaries: list[dict[str, Any]] = []
seen_contents: set[str] = set()
for entry_json in raw_entries:
try:
data = json.loads(entry_json)
except json.JSONDecodeError:
continue
content = data.get("content", "")
if content in seen_contents:
continue
seen_contents.add(content)
summaries.append(data)
if len(summaries) >= limit:
break
return summaries
[docs]
async def backfill_channel_indexes(self, batch_size: int = 500) -> int:
"""Populate per-channel ZSETs from existing ``msg:*`` hashes.
Scans all ``msg:*`` keys and inserts them into the appropriate
``channel_msgs:{platform}:{channel_id}`` sorted set. Safe to
run multiple times -- ZADD is idempotent for existing members.
Returns the number of messages indexed.
"""
cursor: int | str = 0
total = 0
while True:
cursor, keys = await self._redis.scan(
cursor, match="msg:*", count=batch_size,
)
if keys:
pipe = self._redis.pipeline()
for key in keys:
pipe.hmget(key, "platform", "channel_id", "timestamp")
results = await pipe.execute()
pipe = self._redis.pipeline()
added = 0
for key, values in zip(keys, results):
plat, chan, ts = values
if not plat or not chan or not ts:
continue
zset_key = self._channel_zset_key(plat, chan)
pipe.zadd(zset_key, {key: float(ts)})
pipe.expire(zset_key, TTL_SECONDS)
added += 1
if added:
await pipe.execute()
total += added
if cursor == 0:
break
logger.info("backfill_channel_indexes: indexed %d messages", total)
return total
@property
def redis_client(self) -> aioredis.Redis:
"""The underlying async Redis connection.
Exposed so that other subsystems (e.g. :class:`ToolContext`)
can share the same connection without reaching into private
state.
"""
return self._redis
# ------------------------------------------------------------------
# Lifecycle
# ------------------------------------------------------------------
[docs]
async def close(self) -> None:
"""Gracefully close the Redis connections."""
await self._redis.aclose()
await self._redis_raw.aclose()
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
@staticmethod
def _channel_zset_key(platform: str, channel_id: str) -> str:
"""Internal helper: channel zset key.
Args:
platform (str): Platform adapter instance.
channel_id (str): Discord/Matrix channel identifier.
Returns:
str: Result string.
"""
return f"channel_msgs:{platform}:{channel_id}"
_HASH_FIELDS = ("user_id", "user_name", "platform", "channel_id", "text", "timestamp", "message_id", "reply_to_id")
async def _fetch_hashes(
self,
zset_key: str,
msg_keys: list[str],
) -> list[CachedMessage]:
"""Pipeline HMGET for *msg_keys*, prune any that have expired.
Uses HMGET (not HGETALL) to skip the binary embedding blob,
which is incompatible with ``decode_responses=True``.
"""
if not msg_keys:
return []
pipe = self._redis.pipeline()
for key in msg_keys:
pipe.hmget(key, *self._HASH_FIELDS)
results = await pipe.execute()
messages: list[CachedMessage] = []
stale_keys: list[str] = []
for key, values in zip(msg_keys, results):
if not values or all(v is None for v in values):
stale_keys.append(key)
continue
mapping = dict(zip(self._HASH_FIELDS, values))
messages.append(CachedMessage(
user_id=str(mapping.get("user_id") or ""),
user_name=str(mapping.get("user_name") or ""),
platform=str(mapping.get("platform") or ""),
channel_id=str(mapping.get("channel_id") or ""),
text=str(mapping.get("text") or ""),
timestamp=float(mapping.get("timestamp") or 0),
embedding=[],
message_key=key,
message_id=str(mapping.get("message_id") or ""),
reply_to_id=str(mapping.get("reply_to_id") or ""),
))
if stale_keys:
pipe = self._redis.pipeline()
for key in stale_keys:
pipe.zrem(zset_key, key)
await pipe.execute()
return messages
@staticmethod
def _doc_str(doc: Any, field_name: str) -> str:
"""Internal helper: doc str.
Args:
doc (Any): The doc value.
field_name (str): The field name value.
Returns:
str: Result string.
"""
val = getattr(doc, field_name, None)
if val is None:
return ""
return str(val)
@classmethod
def _doc_to_cached_message(cls, doc: Any) -> CachedMessage:
"""Internal helper: doc to cached message.
Args:
doc (Any): The doc value.
Returns:
CachedMessage: The result.
"""
return CachedMessage(
user_id=cls._doc_str(doc, "user_id"),
user_name=cls._doc_str(doc, "user_name"),
platform=cls._doc_str(doc, "platform"),
channel_id=cls._doc_str(doc, "channel_id"),
text=cls._doc_str(doc, "text"),
timestamp=float(cls._doc_str(doc, "timestamp") or 0),
embedding=[],
message_key=getattr(doc, "id", ""),
message_id=cls._doc_str(doc, "message_id"),
reply_to_id=cls._doc_str(doc, "reply_to_id"),
)
# ------------------------------------------------------------------
# Shared utility — active channel discovery
# ------------------------------------------------------------------

async def get_active_channels(
    redis: Any,
    limit: int = 10,
) -> list[tuple[str, str]]:
    """Find the most recently active channels by checking ZSET scores.

    Scans ``channel_msgs:*`` keys and returns the *limit* channels with
    the highest max score (most recent message timestamp).

    Returns a list of ``(platform, channel_id)`` tuples.
    """
    try:
        zset_keys: list[str] = []
        async for raw_key in redis.scan_iter(
            match="channel_msgs:*", count=500,
        ):
            zset_keys.append(
                raw_key.decode() if isinstance(raw_key, bytes) else raw_key
            )
        if not zset_keys:
            return []
        pipe = redis.pipeline()
        for zkey in zset_keys:
            pipe.zrevrange(zkey, 0, 0, withscores=True)
        newest_entries = await pipe.execute()
        ranked: list[tuple[float, str, str]] = []
        for zkey, entry in zip(zset_keys, newest_entries):
            if not entry:
                continue
            _member, score = entry[0]
            # Key layout: channel_msgs:{platform}:{channel_id}.
            pieces = zkey.split(":", 2)
            if len(pieces) < 3:
                continue
            ranked.append((float(score), pieces[1], pieces[2]))
        ranked.sort(reverse=True)
        return [(plat, chan) for _score, plat, chan in ranked[:limit]]
    except Exception:
        logger.exception("Failed to discover active channels")
        return []