# Source code for ncm_variant_cache

"""NCM Cue Variant Cache — LLM-generated rephrasing variants for cascade cues.

On first use of any cue/reason string a background task fires a cheap
OpenRouter LLM call to generate alternative phrasings, then immediately
batch-embeds them using google/gemini-embedding-001. The result is stored in
Redis with a randomized 7-14 day TTL and held in-memory so subsequent turns
pay no I/O cost.
Variant selection uses cosine similarity (dot product on unit vectors) against
the current limbic emotional context. Falls back to random.choice when the
context embedding is not yet available.

Redis key format:  ncm:cue_variant:{sha256_hex_of_original_string}
Redis value (v2):  JSON object — see _CacheEntry below
Redis value (v1):  JSON array of strings (old format — loaded as text-only,
                   re-embedded lazily on next ensure_cached call)
"""

from __future__ import annotations

import asyncio
import hashlib
import json
import logging
import random
import re
from collections import OrderedDict
from typing import Optional

import httpx

from gemini_embed_pool import embed_batch_via_gemini
from openrouter_client import OpenRouterClient

logger = logging.getLogger(__name__)

# Endpoint and model identifiers used by the generation / embedding helpers.
OPENROUTER_CHAT_URL = "https://openrouter.ai/api/v1/chat/completions"  # chat endpoint hit by _generate()
EMBED_MODEL = "google/gemini-embedding-001"  # embedding model for variants and context queries
REDIS_KEY_PREFIX = "ncm:cue_variant:"  # Redis key namespace; full key built in _cache_key()
NUM_VARIANTS = 6  # rephrasings requested per cue (enforced in _system_prompt / _parse_variants)
_TTL_MIN = 7 * 86400   # 7 days
_TTL_MAX = 14 * 86400  # 14 days

def _random_ttl() -> int:
    return random.randint(_TTL_MIN, _TTL_MAX)

# Models tried in order by _generate() until one returns a parseable JSON
# array of variants; rate-limited or erroring models are skipped over.
VARIANT_MODELS = [
    "google/gemini-3.1-flash-lite-preview",
    "stepfun/step-3.5-flash:free",
    "nvidia/nemotron-3-nano-30b-a3b:free",
    "openai/gpt-oss-120b:free",
    "z-ai/glm-4.5-air:free",
]

def _system_prompt() -> str:
    """Build the system prompt sent to the variant-generation LLM.

    Returns:
        str: Instructions demanding a bare JSON array of exactly
        ``NUM_VARIANTS`` terse rephrasings, with no surrounding prose.
    """
    count = NUM_VARIANTS
    sample = ", ".join(f'"variant{idx}"' for idx in range(1, count + 1))
    parts = [
        "You are a neurochemical prose writer for an AI character. ",
        f"Given an internal emotional/somatic cue phrase, write exactly {count} alternative ",
        "phrasings. Preserve the core meaning; vary vocabulary, rhythm, and metaphor. ",
        "Keep each variant terse (5–18 words). ",
        f"Output ONLY a JSON array with exactly {count} strings: ",
        f"[{sample}]\n",
        "No other text, no markdown, no explanation.",
    ]
    return "".join(parts)


# ─────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────

def _cache_key(original: str) -> str:
    """Derive the Redis key under which variants of *original* are stored.

    Args:
        original (str): The raw cue/reason string.

    Returns:
        str: ``ncm:cue_variant:`` followed by the SHA-256 hex digest of
        the UTF-8 encoded string.
    """
    digest = hashlib.sha256(original.encode("utf-8")).hexdigest()
    return REDIS_KEY_PREFIX + digest


def _dot(a: list[float], b: list[float]) -> float:
    """Cosine similarity for unit-normalized vectors = plain dot product."""
    return sum(x * y for x, y in zip(a, b))


def _normalize(vec: list[float]) -> list[float]:
    """L2-normalize a vector. Returns zero-vector unchanged."""
    norm = sum(x * x for x in vec) ** 0.5
    if norm < 1e-10:
        return vec
    return [x / norm for x in vec]


def _mean_vec(vecs: list[list[float]]) -> list[float]:
    """Compute the L2-normalized centroid of a list of vectors."""
    if not vecs:
        return []
    dim = len(vecs[0])
    centroid = [sum(v[i] for v in vecs) / len(vecs) for i in range(dim)]
    return _normalize(centroid)


def _parse_variants(text: str) -> list[str]:
    """Extract a list of variant strings from a possibly messy LLM response."""
    if not text:
        return []
    text = text.strip()

    if "```" in text:
        m = re.search(r"```(?:json)?\s*(.*?)```", text, re.DOTALL)
        if m:
            text = m.group(1).strip()

    start = text.find("[")
    end = text.rfind("]")
    if start == -1 or end <= start:
        return []

    try:
        arr = json.loads(text[start : end + 1])
        if isinstance(arr, list):
            valid = [s.strip() for s in arr if isinstance(s, str) and s.strip()]
            if valid:
                return valid[:NUM_VARIANTS]
    except (json.JSONDecodeError, ValueError):
        pass

    return []


def _entry_to_redis(original: str, texts: list[str], vecs: list[list[float]]) -> str:
    """Serialize a cache entry to its Redis JSON representation."""
    mean = _mean_vec(vecs) if vecs else []
    variants = [{"text": t, "vec": v} for t, v in zip(texts, vecs)]
    return json.dumps({"original": original, "variants": variants, "mean_vec": mean})


def _entry_from_redis(raw: str) -> dict | None:
    """Deserialize a Redis value.  Returns None on failure.

    Handles both v2 (object with 'variants' key) and v1 (plain list).
    """
    try:
        data = json.loads(raw)
    except Exception:
        return None

    # v1 format — plain list of strings
    if isinstance(data, list):
        texts = [s for s in data if isinstance(s, str)]
        if not texts:
            return None
        return {"original": None, "texts": texts, "vecs": [], "mean_vec": []}

    # v2 format
    if not isinstance(data, dict) or "variants" not in data:
        return None

    texts = []
    vecs = []
    for item in data.get("variants", []):
        if isinstance(item, dict) and "text" in item:
            texts.append(item["text"])
            vecs.append(item.get("vec", []))

    if not texts:
        return None

    return {
        "original": data.get("original"),
        "texts": texts,
        "vecs": vecs,
        "mean_vec": data.get("mean_vec", []),
    }


# ─────────────────────────────────────────────────────────────────────
# Main class
# ─────────────────────────────────────────────────────────────────────

class CueVariantCache:
    """Lazy LLM-backed variant cache for cascade cue and reason strings.

    Variant selection is context-aware: before the cascade engine runs,
    call ``set_context(emotion_text)`` to register the current dominant-
    emotion string.  ``get_variant()`` then picks the variant whose
    embedding has the highest cosine similarity to the context embedding.
    Falls back to ``random.choice`` until the context embedding is ready.

    Parameters
    ----------
    redis_client:
        A ``redis.asyncio.Redis`` instance.  May be None.
    api_key:
        OpenRouter API key.  When None the cache is a no-op.
    """

    def __init__(
        self,
        redis_client=None,
        api_key: Optional[str] = None,
        openrouter_client: OpenRouterClient | None = None,
        variant_models: Optional[list[str]] = None,
    ) -> None:
        """Initialize the cache.

        Args:
            redis_client: Redis connection client (``redis.asyncio.Redis``).
            api_key (Optional[str]): OpenRouter API key; None disables
                generation and context embedding entirely.
            openrouter_client: Shared OpenRouterClient for connection pooling
                and batch embedding.  Falls back to direct HTTP when None.
            variant_models: Override models for LLM generation.  When None,
                uses VARIANT_MODELS.  Use e.g.
                ["google/gemini-3.1-flash-lite-preview"] for paid-only,
                high-throughput pregeneration.
        """
        self._redis = redis_client
        self._api_key = api_key
        self._openrouter = openrouter_client
        self._variant_models = (
            variant_models if variant_models is not None else VARIANT_MODELS
        )
        # original → {"texts": [...], "vecs": [...], "mean_vec": [...]}
        self._mem: dict[str, dict] = {}
        # Strings currently being generated / embedded
        self._pending: set[str] = set()
        # context_text → embedding vector (LRU-evicted, max 128)
        self._query_cache: OrderedDict[str, list[float]] = OrderedDict()
        self._query_cache_max = 128
        # Strings whose context embedding is currently being fetched
        self._embed_pending: set[str] = set()
        # The query vector for the current turn (set by set_context)
        self._current_query_vec: list[float] | None = None
        # The context most recently registered via set_context; used by
        # _embed_context so a late-arriving embedding for an OLDER context
        # cannot overwrite the query vector of a newer turn.
        self._last_context: str | None = None
        # L2: Track background tasks for clean shutdown
        self._background_tasks: set[asyncio.Task] = set()

    # ------------------------------------------------------------------
    # Public hot-path (sync)
    # ------------------------------------------------------------------

    def set_context(self, context_text: str) -> None:
        """Register the current turn's emotional context for variant selection.

        Sync — returns immediately.  If the context embedding is already in
        ``_query_cache``, the query vector is updated synchronously.
        Otherwise the query vector is cleared (falling back to random this
        turn) and a background embed task is scheduled.
        """
        if not context_text or not self._api_key:
            self._current_query_vec = None
            self._last_context = None
            return

        self._last_context = context_text
        cached = self._query_cache.get(context_text)
        if cached is not None:
            self._current_query_vec = cached
            return

        # Not yet embedded — clear for this turn and schedule background fetch
        self._current_query_vec = None
        if context_text not in self._embed_pending:
            self._embed_pending.add(context_text)
            try:
                task = asyncio.get_running_loop().create_task(
                    self._embed_context(context_text)
                )
                self._background_tasks.add(task)
                task.add_done_callback(self._background_tasks.discard)
            except RuntimeError:
                # No running event loop (sync caller) — give up quietly.
                self._embed_pending.discard(context_text)

    def get_variant(self, s: str) -> str:
        """Return the most contextually resonant variant, or the original.

        Uses cosine similarity (dot product on unit vectors) against
        ``_current_query_vec`` when both the query vector and per-variant
        embeddings are available.  Falls back to ``random.choice`` otherwise.
        """
        entry = self._mem.get(s)
        if not entry:
            return s
        texts = entry.get("texts", [])
        if not texts:
            return s

        vecs = entry.get("vecs", [])
        qvec = self._current_query_vec
        if qvec and vecs and len(vecs) == len(texts):
            try:
                # Skip empty vectors; if any were skipped, the score list is
                # shorter than texts and we fall through to random choice.
                scores = [_dot(qvec, v) for v in vecs if v]
                if len(scores) == len(texts):
                    best_idx = scores.index(max(scores))
                    return texts[best_idx]
            except Exception:
                pass
        return random.choice(texts)

    # ------------------------------------------------------------------
    # Public warm-up paths (async)
    # ------------------------------------------------------------------

    async def ensure_cached(self, s: str) -> None:
        """Generate, embed, and cache variants for *s* if not already present.

        Fire-and-forget safe — call via asyncio.create_task().
        """
        if not s or not self._api_key:
            return

        # Already in memory and has embeddings — nothing to do
        entry = self._mem.get(s)
        if entry and entry.get("vecs"):
            return
        if s in self._pending:
            return

        # Check Redis first
        if self._redis:
            try:
                raw = await self._redis.get(_cache_key(s))
                if raw:
                    parsed = _entry_from_redis(raw)
                    if parsed:
                        if parsed["vecs"]:
                            # Full v2 entry — load directly
                            self._mem[s] = parsed
                            return
                        # v1 or embedding-less entry — use texts but
                        # fall through to (re-)embed below
                        self._mem[s] = parsed
            except Exception as e:
                logger.debug("Variant cache Redis read error: %s", e)

        self._pending.add(s)
        try:
            # Do we already have texts (from v1 Redis hit) or need to generate?
            existing = self._mem.get(s)
            if existing and existing.get("texts") and not existing.get("vecs"):
                # Need to embed only
                texts = existing["texts"]
                vecs = await self._embed_texts(texts)
                if vecs:
                    self._mem[s] = {
                        "original": s,
                        "texts": texts,
                        "vecs": vecs,
                        "mean_vec": _mean_vec(vecs),
                    }
                    await self._write_redis(s, texts, vecs)
            else:
                # Full generation pipeline
                texts = await self._generate(s)
                if texts:
                    vecs = await self._embed_texts(texts)
                    self._mem[s] = {
                        "original": s,
                        "texts": texts,
                        "vecs": vecs if vecs else [],
                        "mean_vec": _mean_vec(vecs) if vecs else [],
                    }
                    await self._write_redis(s, texts, vecs or [])
        except Exception as e:
            logger.warning("ensure_cached failed for %r: %s", s[:60], e)
        finally:
            self._pending.discard(s)

    async def load_all_from_redis(self) -> None:
        """Warm the in-memory layer by scanning all cached entries in Redis.

        Now that the v2 format stores the original string inside the value,
        this can populate ``_mem`` for every previously generated cue.
        """
        if not self._redis:
            return
        try:
            pattern = f"{REDIS_KEY_PREFIX}*"
            cursor = 0
            loaded = 0
            while True:
                cursor, keys = await self._redis.scan(
                    cursor, match=pattern, count=200
                )
                for key in keys:
                    try:
                        raw = await self._redis.get(key)
                        if not raw:
                            continue
                        parsed = _entry_from_redis(raw)
                        if parsed and parsed.get("original") and parsed.get("texts"):
                            self._mem[parsed["original"]] = parsed
                            loaded += 1
                    except Exception:
                        pass
                # SCAN returns cursor 0 when the iteration is complete.
                if cursor == 0:
                    break
            logger.info(
                "Variant cache warm-up: loaded %d entries from Redis", loaded
            )
        except Exception as e:
            logger.debug("Variant cache warm-up scan failed: %s", e)

    # ------------------------------------------------------------------
    # Internal embedding helpers
    # ------------------------------------------------------------------

    async def _embed_context(self, context_text: str) -> None:
        """Embed *context_text* and store in ``_query_cache``."""
        try:
            vecs = await self._embed_texts([context_text])
            if vecs:
                # LRU eviction
                if len(self._query_cache) >= self._query_cache_max:
                    self._query_cache.popitem(last=False)
                self._query_cache[context_text] = vecs[0]
                # Promote to the live query vector only when this is still
                # the most recently registered context — a newer set_context
                # call may have superseded it while we awaited the embedding.
                if (
                    self._current_query_vec is None
                    and context_text == self._last_context
                ):
                    self._current_query_vec = vecs[0]
                logger.debug(
                    "Context embedding ready for: %.60s", context_text
                )
        except Exception as e:
            logger.debug("Context embed failed for %r: %s", context_text[:60], e)
        finally:
            self._embed_pending.discard(context_text)

    async def drain(self) -> None:
        """Wait for all background embedding/generation tasks to complete."""
        if self._background_tasks:
            await asyncio.gather(*self._background_tasks, return_exceptions=True)
        self._background_tasks.clear()

    async def _embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Batch-embed *texts* using ``google/gemini-embedding-001``.

        When an ``OpenRouterClient`` was provided at init time, delegates to
        ``client.embed_batch()`` for shared connection pooling, dual-provider
        failover, and the batch embedding API.  Otherwise falls back to
        direct HTTP (original implementation).
        """
        if not texts:
            return []

        # ── Preferred path: shared OpenRouterClient ──────────────
        if self._openrouter is not None:
            try:
                raw = await self._openrouter.embed_batch(
                    texts,
                    EMBED_MODEL,
                )
                return [_normalize(v) for v in raw if v]
            except Exception as e:
                logger.warning(
                    "Shared client embed_batch failed, "
                    "falling back to direct HTTP: %s",
                    e,
                )
                # Fall through to direct HTTP below

        # ── Fallback: Gemini API via shared key pool ─────────────
        try:
            raw = await embed_batch_via_gemini(texts, EMBED_MODEL)
            return [_normalize(v) for v in raw]
        except Exception as e:
            logger.debug("Embedding request failed: %s", e)
            return []

    async def _write_redis(
        self, original: str, texts: list[str], vecs: list[list[float]]
    ) -> None:
        """Persist a v2 entry to Redis with a randomized 7–14 day TTL."""
        if not self._redis:
            return
        try:
            payload = _entry_to_redis(original, texts, vecs)
            await self._redis.set(
                _cache_key(original), payload, ex=_random_ttl()
            )
        except Exception as e:
            logger.debug("Variant cache Redis write error: %s", e)

    # ------------------------------------------------------------------
    # Internal LLM generation
    # ------------------------------------------------------------------

    async def _generate(self, original: str) -> list[str]:
        """Call OpenRouter with model rotation to produce variant strings.

        Tries each model in ``self._variant_models`` in order; a model is
        skipped on rate-limit (after a 5 s pause), API error, exception, or
        an unparseable response.  Returns [] when every model fails.
        """
        if not self._api_key:
            return []

        headers = {
            "Authorization": f"Bearer {self._api_key}",
            "Content-Type": "application/json",
            "HTTP-Referer": "https://github.com/matrix-llm-bot",
            "X-Title": "Matrix LLM Bot",
        }
        user_msg = f'Original: "{original}"'
        last_error: str = ""

        async with httpx.AsyncClient(
            timeout=httpx.Timeout(30.0, connect=10.0)
        ) as client:
            for model in self._variant_models:
                payload = {
                    "model": model,
                    "messages": [
                        {"role": "system", "content": _system_prompt()},
                        {"role": "user", "content": user_msg},
                    ],
                    "temperature": 0.9,
                    "max_tokens": 1024,
                }
                try:
                    resp = await client.post(
                        OPENROUTER_CHAT_URL, json=payload, headers=headers
                    )
                    if resp.status_code == 429:
                        logger.debug(
                            "Variant gen rate-limited on %s, waiting 5s then next",
                            model,
                        )
                        last_error = "rate limited"
                        await asyncio.sleep(5)
                        continue
                    resp.raise_for_status()
                    data = resp.json()

                    # OpenRouter can return HTTP 200 with an error object.
                    if "error" in data:
                        last_error = str(data["error"])
                        logger.debug(
                            "Variant gen error from %s: %s", model, last_error
                        )
                        continue

                    text = (
                        data.get("choices", [{}])[0]
                        .get("message", {})
                        .get("content", "")
                    )
                    variants = _parse_variants(text)
                    if variants:
                        logger.info(
                            "Generated %d variants for %r via %s",
                            len(variants),
                            original[:50],
                            model,
                        )
                        return variants
                    logger.debug(
                        "Empty variant parse from %s (raw: %.200s)", model, text
                    )
                    last_error = "empty parse"
                except Exception as e:
                    last_error = str(e)
                    logger.debug("Variant gen exception on %s: %s", model, e)
                    continue

        logger.warning(
            "All variant models failed for %r — last error: %s",
            original[:60],
            last_error,
        )
        return []