"""NCM Cue Variant Cache — LLM-generated rephrasing variants for cascade cues.
On first use of any cue/reason string a background task fires a cheap
OpenRouter LLM call to generate alternative phrasings, then immediately
batch-embeds them using google/gemini-embedding-001. The result is stored in
Redis with a randomized 7–14 day TTL and held in-memory so subsequent turns pay no I/O cost.
Variant selection uses cosine similarity (dot product on unit vectors) against
the current limbic emotional context. Falls back to random.choice when the
context embedding is not yet available.
Redis key format: ncm:cue_variant:{sha256_hex_of_original_string}
Redis value (v2): JSON object — see _entry_to_redis / _entry_from_redis below
Redis value (v1): JSON array of strings (old format — loaded as text-only,
re-embedded lazily on next ensure_cached call)
"""
from __future__ import annotations
import asyncio
import hashlib
import json
import logging
import random
import re
from collections import OrderedDict
from typing import Optional
import httpx
from gemini_embed_pool import embed_batch_via_gemini
from openrouter_client import OpenRouterClient
logger = logging.getLogger(__name__)

# OpenRouter chat-completions endpoint used by CueVariantCache._generate().
OPENROUTER_CHAT_URL = "https://openrouter.ai/api/v1/chat/completions"
# Embedding model name passed to both embedding backends (see _embed_texts).
EMBED_MODEL = "google/gemini-embedding-001"
# Redis key namespace; full key is this prefix + sha256 hex of the original string.
REDIS_KEY_PREFIX = "ncm:cue_variant:"
# How many rephrasings the LLM is asked to produce per cue string.
NUM_VARIANTS = 6
_TTL_MIN = 7 * 86400  # 7 days
_TTL_MAX = 14 * 86400  # 14 days


def _random_ttl() -> int:
    """Return a Redis TTL in seconds, uniform over 7–14 days.

    Randomized — presumably to stagger expiry so cached entries do not
    all need regeneration at the same moment.
    """
    return random.randint(_TTL_MIN, _TTL_MAX)


# Model rotation for variant generation; _generate() tries these in order
# until one returns a parseable result.
VARIANT_MODELS = [
    "google/gemini-3.1-flash-lite-preview",
    "stepfun/step-3.5-flash:free",
    "nvidia/nemotron-3-nano-30b-a3b:free",
    "openai/gpt-oss-120b:free",
    "z-ai/glm-4.5-air:free",
]
def _system_prompt() -> str:
    """Build the system prompt sent to the variant-generation LLM.

    Returns:
        str: Instruction text demanding exactly NUM_VARIANTS terse
        rephrasings, output as a bare JSON array of strings.
    """
    count = NUM_VARIANTS
    placeholder = ", ".join('"variant%d"' % idx for idx in range(1, count + 1))
    parts = [
        "You are a neurochemical prose writer for an AI character. ",
        f"Given an internal emotional/somatic cue phrase, write exactly {count} alternative ",
        "phrasings. Preserve the core meaning; vary vocabulary, rhythm, and metaphor. ",
        "Keep each variant terse (5–18 words). ",
        f"Output ONLY a JSON array with exactly {count} strings: ",
        f"[{placeholder}]\n",
        "No other text, no markdown, no explanation.",
    ]
    return "".join(parts)
# ─────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────
def _cache_key(original: str) -> str:
    """Build the Redis key for *original*: prefix + SHA-256 hex digest.

    Args:
        original (str): The raw cue/reason string to hash.

    Returns:
        str: Namespaced Redis key.
    """
    digest = hashlib.sha256(original.encode("utf-8")).hexdigest()
    return REDIS_KEY_PREFIX + digest
def _dot(a: list[float], b: list[float]) -> float:
"""Cosine similarity for unit-normalized vectors = plain dot product."""
return sum(x * y for x, y in zip(a, b))
def _normalize(vec: list[float]) -> list[float]:
"""L2-normalize a vector. Returns zero-vector unchanged."""
norm = sum(x * x for x in vec) ** 0.5
if norm < 1e-10:
return vec
return [x / norm for x in vec]
def _mean_vec(vecs: list[list[float]]) -> list[float]:
    """Compute the L2-normalized centroid of a list of vectors."""
    if not vecs:
        return []
    count = len(vecs)
    dim = len(vecs[0])
    centroid = []
    for axis in range(dim):
        centroid.append(sum(v[axis] for v in vecs) / count)
    return _normalize(centroid)
def _parse_variants(text: str) -> list[str]:
"""Extract a list of variant strings from a possibly messy LLM response."""
if not text:
return []
text = text.strip()
if "```" in text:
m = re.search(r"```(?:json)?\s*(.*?)```", text, re.DOTALL)
if m:
text = m.group(1).strip()
start = text.find("[")
end = text.rfind("]")
if start == -1 or end <= start:
return []
try:
arr = json.loads(text[start : end + 1])
if isinstance(arr, list):
valid = [s.strip() for s in arr if isinstance(s, str) and s.strip()]
if valid:
return valid[:NUM_VARIANTS]
except (json.JSONDecodeError, ValueError):
pass
return []
def _entry_to_redis(original: str, texts: list[str], vecs: list[list[float]]) -> str:
"""Serialize a cache entry to its Redis JSON representation."""
mean = _mean_vec(vecs) if vecs else []
variants = [{"text": t, "vec": v} for t, v in zip(texts, vecs)]
return json.dumps({"original": original, "variants": variants, "mean_vec": mean})
def _entry_from_redis(raw: str) -> dict | None:
"""Deserialize a Redis value. Returns None on failure.
Handles both v2 (object with 'variants' key) and v1 (plain list).
"""
try:
data = json.loads(raw)
except Exception:
return None
# v1 format — plain list of strings
if isinstance(data, list):
texts = [s for s in data if isinstance(s, str)]
if not texts:
return None
return {"original": None, "texts": texts, "vecs": [], "mean_vec": []}
# v2 format
if not isinstance(data, dict) or "variants" not in data:
return None
texts = []
vecs = []
for item in data.get("variants", []):
if isinstance(item, dict) and "text" in item:
texts.append(item["text"])
vecs.append(item.get("vec", []))
if not texts:
return None
return {
"original": data.get("original"),
"texts": texts,
"vecs": vecs,
"mean_vec": data.get("mean_vec", []),
}
# ─────────────────────────────────────────────────────────────────────
# Main class
# ─────────────────────────────────────────────────────────────────────
class CueVariantCache:
    """Lazy LLM-backed variant cache for cascade cue and reason strings.

    Variant selection is context-aware: before the cascade engine runs,
    call ``set_context(emotion_text)`` to register the current dominant-
    emotion string. ``get_variant()`` then picks the variant whose
    embedding has the highest cosine similarity to the context embedding.
    Falls back to ``random.choice`` until the context embedding is ready.

    Parameters
    ----------
    redis_client:
        A ``redis.asyncio.Redis`` instance. May be None.
    api_key:
        OpenRouter API key. When None the cache is a no-op.
    """

    def __init__(
        self,
        redis_client=None,
        api_key: Optional[str] = None,
        openrouter_client: OpenRouterClient | None = None,
        variant_models: Optional[list[str]] = None,
    ) -> None:
        """Initialize the instance.

        Args:
            redis_client: Redis connection client (``redis.asyncio``).
                May be None, in which case only the in-memory layer
                is used.
            api_key (Optional[str]): OpenRouter API key. When falsy,
                set_context/ensure_cached become no-ops.
            openrouter_client: Shared OpenRouterClient for connection
                pooling and batch embedding. Falls back to direct HTTP
                when None.
            variant_models: Override models for LLM generation. When None,
                uses VARIANT_MODELS. Use e.g. ["google/gemini-3.1-flash-lite-preview"]
                for paid-only, high-throughput pregeneration.
        """
        self._redis = redis_client
        self._api_key = api_key
        self._openrouter = openrouter_client
        self._variant_models = variant_models if variant_models is not None else VARIANT_MODELS
        # original → {"texts": [...], "vecs": [...], "mean_vec": [...]}
        self._mem: dict[str, dict] = {}
        # Strings currently being generated / embedded (dedupe guard)
        self._pending: set[str] = set()
        # context_text → embedding vector (LRU-evicted, max 128)
        self._query_cache: OrderedDict[str, list[float]] = OrderedDict()
        self._query_cache_max = 128
        # Strings whose context embedding is currently being fetched
        self._embed_pending: set[str] = set()
        # The query vector for the current turn (set by set_context)
        self._current_query_vec: list[float] | None = None
        # Track background tasks so drain() can await a clean shutdown
        self._background_tasks: set[asyncio.Task] = set()

    # ------------------------------------------------------------------
    # Public hot-path (sync)
    # ------------------------------------------------------------------

    def set_context(self, context_text: str) -> None:
        """Register the current turn's emotional context for variant selection.

        Sync — returns immediately. If the context embedding is already in
        ``_query_cache``, the query vector is updated synchronously. Otherwise
        the query vector is cleared (falling back to random this turn) and a
        background embed task is scheduled.
        """
        if not context_text or not self._api_key:
            self._current_query_vec = None
            return
        cached = self._query_cache.get(context_text)
        if cached is not None:
            self._current_query_vec = cached
            return
        # Not yet embedded — clear for this turn and schedule background fetch
        self._current_query_vec = None
        if context_text not in self._embed_pending:
            self._embed_pending.add(context_text)
            try:
                task = asyncio.get_running_loop().create_task(
                    self._embed_context(context_text)
                )
                self._background_tasks.add(task)
                task.add_done_callback(self._background_tasks.discard)
            except RuntimeError:
                # No running event loop (sync caller) — drop the reservation
                # so a later call inside a loop can retry.
                self._embed_pending.discard(context_text)

    def get_variant(self, s: str) -> str:
        """Return the most contextually resonant variant, or the original.

        Uses cosine similarity (dot product on unit vectors) against
        ``_current_query_vec`` when both the query vector and per-variant
        embeddings are available. Falls back to ``random.choice`` otherwise.
        """
        entry = self._mem.get(s)
        if not entry:
            return s
        texts = entry.get("texts", [])
        if not texts:
            return s
        vecs = entry.get("vecs", [])
        qvec = self._current_query_vec
        if qvec and vecs and len(vecs) == len(texts):
            try:
                # Empty per-variant vectors are skipped; if any were skipped
                # the score/text indices no longer align, so fall back.
                scores = [_dot(qvec, v) for v in vecs if v]
                if len(scores) == len(texts):
                    best_idx = scores.index(max(scores))
                    return texts[best_idx]
            except Exception:
                pass
        return random.choice(texts)

    # ------------------------------------------------------------------
    # Public warm-up paths (async)
    # ------------------------------------------------------------------

    async def ensure_cached(self, s: str) -> None:
        """Generate, embed, and cache variants for *s* if not already present.

        Fire-and-forget safe — call via asyncio.create_task().
        """
        if not s or not self._api_key:
            return
        # Already in memory and has embeddings — nothing to do
        entry = self._mem.get(s)
        if entry and entry.get("vecs"):
            return
        if s in self._pending:
            return
        # Check Redis first
        if self._redis:
            try:
                raw = await self._redis.get(_cache_key(s))
                if raw:
                    parsed = _entry_from_redis(raw)
                    if parsed:
                        if parsed["vecs"]:
                            # Full v2 entry — load directly
                            self._mem[s] = parsed
                            return
                        else:
                            # v1 or embedding-less entry — use texts but
                            # fall through to (re-)embed below
                            self._mem[s] = parsed
            except Exception as e:
                logger.debug("Variant cache Redis read error: %s", e)
        self._pending.add(s)
        try:
            # Do we already have texts (from v1 Redis hit) or need to generate?
            existing = self._mem.get(s)
            if existing and existing.get("texts") and not existing.get("vecs"):
                texts = existing["texts"]
                # Need to embed only
                vecs = await self._embed_texts(texts)
                if vecs:
                    entry = {"original": s, "texts": texts, "vecs": vecs,
                             "mean_vec": _mean_vec(vecs)}
                    self._mem[s] = entry
                    await self._write_redis(s, texts, vecs)
            else:
                # Full generation pipeline
                texts = await self._generate(s)
                if texts:
                    vecs = await self._embed_texts(texts)
                    entry = {"original": s, "texts": texts,
                             "vecs": vecs if vecs else [],
                             "mean_vec": _mean_vec(vecs) if vecs else []}
                    self._mem[s] = entry
                    await self._write_redis(s, texts, vecs or [])
        except Exception as e:
            logger.warning("ensure_cached failed for %r: %s", s[:60], e)
        finally:
            self._pending.discard(s)

    async def load_all_from_redis(self) -> None:
        """Warm the in-memory layer by scanning all cached entries in Redis.

        Now that the v2 format stores the original string inside the value,
        this can populate ``_mem`` for every previously generated cue.
        """
        if not self._redis:
            return
        try:
            pattern = f"{REDIS_KEY_PREFIX}*"
            cursor = 0
            loaded = 0
            while True:
                cursor, keys = await self._redis.scan(
                    cursor, match=pattern, count=200
                )
                for key in keys:
                    try:
                        raw = await self._redis.get(key)
                        if not raw:
                            continue
                        parsed = _entry_from_redis(raw)
                        # v1 entries have no stored original, so only v2
                        # values can be keyed back into _mem here.
                        if parsed and parsed.get("original") and parsed.get("texts"):
                            self._mem[parsed["original"]] = parsed
                            loaded += 1
                    except Exception:
                        pass
                if cursor == 0:
                    break
            logger.info(
                "Variant cache warm-up: loaded %d entries from Redis", loaded
            )
        except Exception as e:
            logger.debug("Variant cache warm-up scan failed: %s", e)

    # ------------------------------------------------------------------
    # Internal embedding helpers
    # ------------------------------------------------------------------

    async def _embed_context(self, context_text: str) -> None:
        """Embed *context_text* and store in ``_query_cache``."""
        try:
            vecs = await self._embed_texts([context_text])
            if vecs:
                # LRU eviction
                if len(self._query_cache) >= self._query_cache_max:
                    self._query_cache.popitem(last=False)
                self._query_cache[context_text] = vecs[0]
                # If no vector is live yet, promote this one.
                # NOTE(review): this only checks that the slot is empty —
                # if set_context() switched to a different uncached context
                # meanwhile, this vector may briefly mismatch it; confirm
                # whether that race matters to callers.
                if self._current_query_vec is None:
                    self._current_query_vec = vecs[0]
                    logger.debug(
                        "Context embedding ready for: %.60s", context_text
                    )
        except Exception as e:
            logger.debug("Context embed failed for %r: %s", context_text[:60], e)
        finally:
            self._embed_pending.discard(context_text)

    async def drain(self) -> None:
        """Wait for all background embedding/generation tasks to complete."""
        if self._background_tasks:
            await asyncio.gather(*self._background_tasks, return_exceptions=True)
            self._background_tasks.clear()

    async def _embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Batch-embed *texts* using ``google/gemini-embedding-001``.

        When an ``OpenRouterClient`` was provided at init time, delegates
        to ``client.embed_batch()`` for shared connection pooling,
        dual-provider failover, and the batch embedding API. Otherwise
        falls back to direct HTTP (original implementation).

        Returns an empty list on any failure.
        """
        if not texts:
            return []
        # ── Preferred path: shared OpenRouterClient ──────────────
        if self._openrouter is not None:
            try:
                raw = await self._openrouter.embed_batch(
                    texts, EMBED_MODEL,
                )
                return [_normalize(v) for v in raw if v]
            except Exception as e:
                logger.warning(
                    "Shared client embed_batch failed, "
                    "falling back to direct HTTP: %s", e,
                )
                # Fall through to direct HTTP below
        # ── Fallback: Gemini API via shared key pool ─────────────
        try:
            raw = await embed_batch_via_gemini(texts, EMBED_MODEL)
            return [_normalize(v) for v in raw]
        except Exception as e:
            logger.debug("Embedding request failed: %s", e)
            return []

    async def _write_redis(
        self, original: str, texts: list[str], vecs: list[list[float]]
    ) -> None:
        """Persist a v2 entry to Redis with a randomized 7–14 day TTL."""
        if not self._redis:
            return
        try:
            payload = _entry_to_redis(original, texts, vecs)
            await self._redis.set(
                _cache_key(original), payload, ex=_random_ttl()
            )
        except Exception as e:
            logger.debug("Variant cache Redis write error: %s", e)

    # ------------------------------------------------------------------
    # Internal LLM generation
    # ------------------------------------------------------------------

    async def _generate(self, original: str) -> list[str]:
        """Call OpenRouter with model rotation to produce variant strings.

        Tries each model in ``self._variant_models`` in order, returning
        the first successfully parsed variant list. Returns [] when every
        model fails (rate limit, API error, or unparseable output).
        """
        if not self._api_key:
            return []
        headers = {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
            "HTTP-Referer": "https://github.com/matrix-llm-bot",
            "X-Title": "Matrix LLM Bot",
        }
        user_msg = f'Original: "{original}"'
        last_error: str = ""
        async with httpx.AsyncClient(
            timeout=httpx.Timeout(30.0, connect=10.0)
        ) as client:
            for model in self._variant_models:
                payload = {
                    "model": model,
                    "messages": [
                        {"role": "system", "content": _system_prompt()},
                        {"role": "user", "content": user_msg},
                    ],
                    "temperature": 0.9,
                    "max_tokens": 1024,
                }
                try:
                    resp = await client.post(
                        OPENROUTER_CHAT_URL, json=payload, headers=headers
                    )
                    if resp.status_code == 429:
                        logger.debug(
                            "Variant gen rate-limited on %s, waiting 5s then next",
                            model,
                        )
                        last_error = "rate limited"
                        await asyncio.sleep(5)
                        continue
                    resp.raise_for_status()
                    data = resp.json()
                    # OpenRouter can return HTTP 200 with an error body.
                    if "error" in data:
                        last_error = str(data["error"])
                        logger.debug(
                            "Variant gen error from %s: %s", model, last_error
                        )
                        continue
                    text = (
                        data.get("choices", [{}])[0]
                        .get("message", {})
                        .get("content", "")
                    )
                    variants = _parse_variants(text)
                    if variants:
                        logger.info(
                            "Generated %d variants for %r via %s",
                            len(variants),
                            original[:50],
                            model,
                        )
                        return variants
                    logger.debug(
                        "Empty variant parse from %s (raw: %.200s)", model, text
                    )
                    last_error = "empty parse"
                except Exception as e:
                    last_error = str(e)
                    logger.debug("Variant gen exception on %s: %s", model, e)
                    continue
        logger.warning(
            "All variant models failed for %r — last error: %s",
            original[:60],
            last_error,
        )
        return []