Source code for cascade_engine

"""NCM Cascade Engine — Multi-turn neurochemical event sequences.

╔═══════════════════════════════════════════════════════════════════════════════╗
║  🌀 CASCADE ENGINE                                                            ║
╠═══════════════════════════════════════════════════════════════════════════════╣
║  Loads cascade definitions from ncm_cascades.yaml                            ║
║  Checks triggers against current NCM vector + active emotions                ║
║  Advances active cascades one stage-tick per turn                            ║
║  Handles interrupts (abort / pause / skip_to_stage / trigger_cascade)        ║
║  Applies synergy bonuses when cascades co-activate                           ║
║  Persists state in Redis: ncm:cascades:{channel_id}                          ║
╠═══════════════════════════════════════════════════════════════════════════════╣
║  Called during exhale() after metabolic decay, stimulus delta                ║
║  stacking, and antagonist suppression.                                       ║
╚═══════════════════════════════════════════════════════════════════════════════╝
"""

from __future__ import annotations

import asyncio
import jsonutil as json
import logging
import os
import random
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING

import yaml

if TYPE_CHECKING:
    from ncm_variant_cache import CueVariantCache

# Module-level logger shared by every function and class in this file.
logger = logging.getLogger(__name__)

# ─────────────────────────────────────────────────────────────────────
# Constants
# ─────────────────────────────────────────────────────────────────────
# Cap on cascades active at once per channel. NOTE(review): enforcement
# presumably lives in CascadeEngine.tick, which is outside this view — confirm.
MAX_CONCURRENT_CASCADES = 3
# Redis key template for a channel's serialized CascadeState; the
# ``{channel_id}`` placeholder is presumably filled via str.format — confirm.
REDIS_CASCADE_KEY = "ncm:cascades:{channel_id}"
# Expiry for that key. NOTE(review): assumed to be seconds (86400 s = 24h);
# confirm against the Redis call site.
REDIS_CASCADE_TTL = 86400  # 24h

# Delta parsing: prefer the shared ncm_delta_parser implementation and fall
# back to a minimal local parser when that module cannot be imported.
try:
    from ncm_delta_parser import parse_delta_string
except ImportError:
    import re

    _NUM_RE = re.compile(r"^([A-Za-z0-9_]+)([+-])(\d+\.?\d*)$")
    _ARR_RE = re.compile(r"^([A-Za-z0-9_]+)(↑|↓)$")
    _REV_RE = re.compile(r"^([A-Za-z0-9_]+)\.reversed$")

    def parse_delta_string(ds: str) -> Dict[str, float]:
        """Fallback parser turning an NCM delta string into summed node magnitudes.

        Only defined when ``ncm_delta_parser`` cannot be imported. The input is
        split on whitespace and each token is matched against three shapes:

        * signed numeric deltas such as ``DA+0.4`` or ``5HT-0.2``
          (``_NUM_RE``) contribute their signed value;
        * arrow nudges such as ``ENT1↑`` / ``ENT1↓`` (``_ARR_RE``) contribute
          +0.15 / -0.15;
        * reversal markers such as ``SERT.reversed`` (``_REV_RE``) contribute
          +0.2.

        Tokens matching none of these are ignored; repeated nodes are summed.
        No external state is read or written.

        Args:
            ds: Raw delta string, e.g. ``"DA+0.4 5HT-0.2 ENT1↑"``.

        Returns:
            Mapping of node name to its accumulated float magnitude; empty
            when ``ds`` is falsy or contains no recognizable tokens.
        """
        if not ds:
            return {}
        totals: Dict[str, float] = {}
        for token in ds.strip().split():
            numeric = _NUM_RE.match(token)
            if numeric is not None:
                node, sign, magnitude = numeric.groups()
                amount = -float(magnitude) if sign == "-" else float(magnitude)
            else:
                arrow = _ARR_RE.match(token)
                if arrow is not None:
                    node, direction = arrow.groups()
                    amount = 0.15 if direction == "↑" else -0.15
                else:
                    reversal = _REV_RE.match(token)
                    if reversal is None:
                        # Unrecognized token: skip it silently.
                        continue
                    node = reversal.group(1)
                    amount = 0.2
            totals[node] = totals.get(node, 0.0) + amount
        return totals


# ─────────────────────────────────────────────────────────────────────
# Cascade Definition Loader
# ─────────────────────────────────────────────────────────────────────
_cascade_defs: Optional[Dict[str, Dict[str, Any]]] = None


def _load_cascade_defs() -> Dict[str, Dict[str, Any]]:
    """Load, pre-parse, and process-cache cascade definitions from YAML.

    Reads ``ncm_cascades.yaml`` (sibling to this module) once and memoizes the
    result in the module-global ``_cascade_defs``, so later calls return the
    cached mapping without touching the filesystem.

    Entries are skipped when:

    * they are ``_meta_``-prefixed system configs (those are served by
      :func:`_load_meta_configs`);
    * their value is not a mapping or their key is not a string;
    * their ``stages`` value is not a list.

    For every kept cascade, each stage and synergy ``delta`` string is
    pre-parsed into a numeric vector via :func:`parse_delta_string`. This
    includes the optional ``delta_sativa`` / ``delta_indica`` pole deltas used
    for strain-gradient interpolation.

    Failure is non-fatal. A missing file logs a warning. An error is logged
    when the file fails to read or parse, or when its top level is not a
    mapping (an empty file parses as ``None``). In each case an empty mapping
    is cached and returned.

    Returns:
        Mapping of cascade id to its definition dict, with ``delta_vector`` (and
        any pole-specific vectors) injected into each stage and synergy dict.
    """
    global _cascade_defs
    if _cascade_defs is not None:
        return _cascade_defs

    yaml_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "ncm_cascades.yaml"
    )
    if not os.path.exists(yaml_path):
        logger.warning("ncm_cascades.yaml not found at %s", yaml_path)
        _cascade_defs = {}
        return _cascade_defs

    try:
        with open(yaml_path, "r", encoding="utf-8") as f:
            raw = yaml.safe_load(f)
    except Exception as e:
        logger.error("Failed to load ncm_cascades.yaml: %s", e)
        _cascade_defs = {}
        return _cascade_defs

    # An empty file parses to None and a top-level list/scalar has no
    # ``.items()``; without this guard either would crash engine construction.
    if not isinstance(raw, dict):
        logger.error(
            "ncm_cascades.yaml must contain a top-level mapping, got %s",
            type(raw).__name__,
        )
        _cascade_defs = {}
        return _cascade_defs

    # (source delta key, pre-parsed vector key) per stage; the sativa/indica
    # poles feed the strain-gradient lerp.
    stage_delta_keys = (
        ("delta", "delta_vector"),
        ("delta_sativa", "delta_vector_sativa"),
        ("delta_indica", "delta_vector_indica"),
    )

    defs: Dict[str, Dict[str, Any]] = {}
    meta_count = 0
    for cascade_id, cdef in raw.items():
        if not isinstance(cascade_id, str) or not isinstance(cdef, dict):
            continue
        # Meta-system configs start with _meta_ and are not cascades.
        if cascade_id.startswith("_meta_"):
            meta_count += 1
            continue
        stages = cdef.get("stages")
        # A missing or null/non-list ``stages`` value cannot be advanced.
        if not isinstance(stages, list):
            continue
        for stage in stages:
            if not isinstance(stage, dict):
                continue
            for src_key, vec_key in stage_delta_keys:
                if src_key in stage:
                    stage[vec_key] = parse_delta_string(stage[src_key])
        # Synergy entries carry their own bonus delta.
        synergy = cdef.get("synergy")
        if isinstance(synergy, list):
            for syn in synergy:
                if isinstance(syn, dict) and "delta" in syn:
                    syn["delta_vector"] = parse_delta_string(syn["delta"])
        defs[cascade_id] = cdef

    _cascade_defs = defs
    logger.info(
        "Loaded %d cascade definitions, %d meta configs", len(defs), meta_count
    )
    return _cascade_defs


def _load_meta_configs() -> Dict[str, Dict[str, Any]]:
    """Load the ``_meta_``-prefixed meta-system configs from the cascade YAML.

    Re-reads ``ncm_cascades.yaml`` (sibling to this module; the result is not
    cached). It returns only entries whose key is a string starting with
    ``_meta_`` and whose value is a mapping, e.g. ``_meta_habituation`` or
    ``_meta_gaba_polarity``. These tune cross-cascade behaviors such as
    tolerance, sensitization, and GABA polarity rather than describing a
    cascade themselves.

    This is best-effort by design. Each of the following yields an empty
    mapping without logging:

    * a missing file;
    * a read or parse error;
    * a document whose top level is not a mapping (an empty file parses to
      ``None``).

    :func:`_load_cascade_defs` reads the same file and is the function that
    reports problems with it.

    Returns:
        Mapping of meta-config id to its config dict (empty if none found).
    """
    yaml_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "ncm_cascades.yaml"
    )
    if not os.path.exists(yaml_path):
        return {}
    try:
        with open(yaml_path, "r", encoding="utf-8") as f:
            raw = yaml.safe_load(f)
    except Exception:
        # Deliberately quiet: the sibling cascade loader logs this failure.
        return {}
    # Guard against an empty document (None) or a non-mapping top level,
    # which would otherwise raise on ``.items()``.
    if not isinstance(raw, dict):
        return {}
    return {
        k: v
        for k, v in raw.items()
        if isinstance(k, str) and k.startswith("_meta_") and isinstance(v, dict)
    }


# ─────────────────────────────────────────────────────────────────────
# Cascade State (per-channel, persisted in Redis)
# ─────────────────────────────────────────────────────────────────────
class CascadeState:
    """Per-channel cascade bookkeeping, serializable to and from Redis.

    Holds everything the engine needs to advance a channel's cascades across
    turns:

    * the currently ``active`` cascades and their stage progress;
    * per-cascade ``cooldowns``;
    * a rolling ``history`` of trigger/abort/complete events;
    * rolling windows of recent node values, delta counts, and emotion sets,
      used for sustained-condition checks;
    * the habituation ``fire_counts`` / ``last_fired_turn``, used for
      tolerance and recovery.

    Instances are loaded and persisted by :meth:`CascadeEngine._load_state`
    and :meth:`CascadeEngine._save_state` under the
    ``ncm:cascades:{channel_id}`` Redis key. They round-trip through
    :meth:`to_dict` and :meth:`from_dict`. Every rolling window, including
    ``node_history``, is part of that round trip, so sustained (``for_turns``)
    conditions can span turns.
    """

    # Samples kept per node when serializing ``node_history``. Bounds the Redis
    # payload; must be >= the largest ``for_turns`` used in ncm_cascades.yaml.
    NODE_HISTORY_WINDOW = 20

    def __init__(self):
        """Initialize an empty state: nothing active, no cooldowns, empty windows.

        Touches no external state. Used by :meth:`from_dict`, and by the engine
        when Redis holds no prior state for a channel.
        """
        self.active: Dict[str, Dict[str, Any]] = {}
        self.cooldowns: Dict[str, int] = {}
        self.history: List[str] = []
        self.turn_count: int = 0
        # Rolling trackers for sustained-condition checks
        self.node_history: Dict[str, List[float]] = {}
        self.delta_counts: List[int] = []
        self.last_emotions: List[Set[str]] = []
        # Habituation tracking
        self.fire_counts: Dict[str, int] = {}
        # Last turn each cascade completed (for tolerance reversal)
        self.last_fired_turn: Dict[str, int] = {}

    def to_dict(self) -> dict:
        """Serialize to a JSON-friendly dict, trimming the rolling windows.

        Trimming keeps:

        * the last 50 ``history`` entries;
        * the last 10 ``delta_counts``;
        * the last 5 turns of ``last_emotions`` (sets become lists);
        * the last :attr:`NODE_HISTORY_WINDOW` samples of each
          ``node_history`` series.

        The ``active``, ``cooldowns``, ``fire_counts`` and ``last_fired_turn``
        mappings are returned by reference, not copied.

        Returns:
            dict: Payload accepted by :meth:`from_dict`.
        """
        window = self.NODE_HISTORY_WINDOW
        return {
            "active": self.active,
            "cooldowns": self.cooldowns,
            "history": self.history[-50:],  # keep last 50
            "turn_count": self.turn_count,
            # Persisted so for_turns checks survive the per-turn reload.
            "node_history": {
                node: list(values[-window:])
                for node, values in self.node_history.items()
            },
            "delta_counts": self.delta_counts[-10:],
            "last_emotions": [list(s) for s in self.last_emotions[-5:]],
            "fire_counts": self.fire_counts,
            "last_fired_turn": self.last_fired_turn,
        }

    @classmethod
    def from_dict(cls, d: dict) -> "CascadeState":
        """Rebuild a state from a :meth:`to_dict` payload.

        Missing keys, and keys stored as ``null``, fall back to the empty
        defaults of a fresh instance. Payloads written before ``node_history``
        was persisted therefore still load. Emotion lists are turned back into
        sets.

        Args:
            d: Decoded payload, e.g. parsed from the Redis JSON.

        Returns:
            CascadeState: The reconstructed state.
        """
        s = cls()
        s.active = d.get("active") or {}
        s.cooldowns = d.get("cooldowns") or {}
        s.history = d.get("history") or []
        s.turn_count = d.get("turn_count") or 0
        s.node_history = {
            node: list(values or [])
            for node, values in (d.get("node_history") or {}).items()
        }
        s.delta_counts = d.get("delta_counts") or []
        s.last_emotions = [set(e) for e in (d.get("last_emotions") or [])]
        s.fire_counts = d.get("fire_counts") or {}
        s.last_fired_turn = d.get("last_fired_turn") or {}
        return s
# ─────────────────────────────────────────────────────────────────────
# Condition Evaluator
# ─────────────────────────────────────────────────────────────────────
def _compare(lhs: float, op: str, rhs: float) -> bool:
    """Apply a YAML comparison operator (``>``, ``<``, ``>=``, ``<=``).

    Unknown operators fall back to ``>``, the default for node thresholds.
    Shared by node-threshold and delta-count checks so they agree.
    """
    if op == "<":
        return lhs < rhs
    if op == ">=":
        return lhs >= rhs
    if op == "<=":
        return lhs <= rhs
    return lhs > rhs


def _eval_node_condition(
    cond: dict,
    vector: Dict[str, float],
    state: CascadeState,
) -> bool:
    """Evaluate one node-threshold condition against the current NCM vector.

    Compares the named node's current value to ``cond["val"]`` using the
    condition's operator. A node absent from the vector defaults to the 0.5
    baseline. Supported operators are ``>``, ``<``, ``>=`` and ``<=``;
    anything else falls back to ``>``.

    When the condition carries ``for_turns``, the same comparison must also
    hold for each of the most recent ``for_turns`` samples in
    ``state.node_history``. This stops a transient spike from satisfying a
    sustained gate. The function is pure: it reads ``vector`` and ``state``
    and mutates nothing.

    Args:
        cond: Condition dict with ``node``, ``op``, ``val`` and optional
            ``for_turns``.
        vector: Current NCM state vector (node name to float).
        state: Channel cascade state supplying ``node_history``.

    Returns:
        ``True`` when the comparison (and any sustained requirement) holds.
    """
    node = cond.get("node", "")
    op = cond.get("op", ">")
    val = cond.get("val", 0.0)
    current = vector.get(node, 0.5)  # baseline 0.5
    if not _compare(current, op, val):
        return False

    for_turns = cond.get("for_turns")
    if for_turns:
        # node_history is populated by the engine each turn. Use the same
        # operator as the live check so ">=" / "<=" accept on-threshold samples.
        window = int(for_turns)
        hist = state.node_history.get(node, [])
        if len(hist) < window:
            return False
        if not all(_compare(h, op, val) for h in hist[-window:]):
            return False
    return True


def _eval_any_condition(
    cond: Any,
    vector: Dict[str, float],
    active_emotions: Set[str],
    state: CascadeState,
) -> bool:
    """Return whether a single ``any`` sub-condition of a trigger passes.

    A dict with ``node`` passes when its node threshold holds. A dict with
    ``emotion`` (or a nested ``emotion_and`` block) passes when that emotion
    is active this turn and its optional node ``condition`` also holds.
    Non-dict entries never pass.
    """
    if not isinstance(cond, dict):
        return False
    if "node" in cond and _eval_node_condition(cond, vector, state):
        return True
    if "emotion_and" in cond or "emotion" in cond:
        nested = cond.get("emotion_and")
        if not isinstance(nested, dict):
            nested = {}
        emo = cond.get("emotion") or nested.get("emotion", "")
        if emo in active_emotions:
            sub_cond = cond.get("condition") or nested.get("condition")
            if not sub_cond:
                return True
            if _eval_node_condition(sub_cond, vector, state):
                return True
    return False


def _eval_trigger(
    trigger: dict,
    vector: Dict[str, float],
    active_emotions: Set[str],
    state: CascadeState,
    active_cascade_ids: Set[str],
) -> bool:
    """Decide whether a cascade's trigger block fires this turn.

    Evaluates the all-or-nothing trigger grammar of ``ncm_cascades.yaml``.
    ``manual_only`` triggers are rejected outright. Otherwise each of the
    following checks applies when its key is present, in this order:

    * ``cascade_just_completed``: the latest ``state.history`` entry must
      record a completion.
    * ``cascade_requires``: each prerequisite must be active at or past its
      ``min_stage``. Stage names are resolved through
      :func:`_load_cascade_defs`.
    * ``condition``: the last turn's delta count must satisfy its
      ``op``/``val``. The defaults are ``>`` and 4; the supported operators
      match node conditions.
    * ``peak_last_turn``: one of ``any_emotion_last_turn`` must have fired
      last turn.
    * ``all``: every sub-condition must pass.
    * ``any``: at least one sub-condition must pass.
    * ``any_emotion``: one of these emotions must be active now.

    The function reads definitions and ``state`` but mutates nothing.

    Args:
        trigger: The cascade's ``trigger`` block from YAML.
        vector: Current NCM state vector (the caller may pre-adjust it, e.g.
            for sensitization).
        active_emotions: Emotions that fired this turn.
        state: Channel cascade state supplying history and rolling windows.
        active_cascade_ids: Ids of cascades already active this turn
            (accepted for interface compatibility; not consulted here).

    Returns:
        ``True`` only when every applicable condition passes.
    """
    if trigger.get("manual_only"):
        return False

    # A cascade must have completed on the most recent history entry.
    if trigger.get("cascade_just_completed"):
        last_history = state.history[-1] if state.history else ""
        if ":completed:" not in last_history:
            return False

    # Prerequisite cascades must be active, optionally at/after a named stage.
    for req in trigger.get("cascade_requires") or []:
        req_cascade = req.get("cascade", "")
        min_stage = req.get("min_stage", "")
        if req_cascade not in state.active:
            return False
        if min_stage:
            stages = _load_cascade_defs().get(req_cascade, {}).get("stages", [])
            # Keep positions aligned with stage indices even for odd entries.
            stage_names = [
                s.get("name") if isinstance(s, dict) else None for s in stages
            ]
            current_stage_idx = state.active[req_cascade].get("stage", 0)
            if min_stage in stage_names and current_stage_idx < stage_names.index(
                min_stage
            ):
                return False

    # Gate on how many deltas were stacked on the most recent turn.
    dc_cond = trigger.get("condition", {})
    if isinstance(dc_cond, dict) and "delta_count_this_turn" in dc_cond:
        last_dc = state.delta_counts[-1] if state.delta_counts else 0
        if not _compare(last_dc, dc_cond.get("op", ">"), dc_cond.get("val", 4)):
            return False

    # Simplified peak check: any listed emotion fired on the previous turn.
    if trigger.get("peak_last_turn"):
        last_emos = state.last_emotions[-1] if state.last_emotions else set()
        peak_emos = set(trigger.get("any_emotion_last_turn", []))
        if not (last_emos & peak_emos):
            return False

    # 'all' conditions: every one must pass.
    for cond in trigger.get("all") or []:
        if not isinstance(cond, dict):
            continue
        if "node" in cond and not _eval_node_condition(cond, vector, state):
            return False
        if "emotion_variance_window" in cond:
            # Low emotional variance heuristic: more than 3 distinct emotions
            # across the last N turns counts as too varied. NOTE: the
            # condition's ``val`` is not consulted by this heuristic.
            window = cond.get("emotion_variance_window", 3)
            if len(state.last_emotions) < window:
                return False
            recent = state.last_emotions[-window:]
            if len(recent) >= 2 and len(set().union(*recent)) > 3:
                return False

    # 'any' conditions: at least one must pass.
    any_conds = trigger.get("any") or []
    if any_conds and not any(
        _eval_any_condition(c, vector, active_emotions, state) for c in any_conds
    ):
        return False

    # 'any_emotion': at least one must be active this turn.
    any_emo = trigger.get("any_emotion", [])
    if any_emo and not (set(any_emo) & active_emotions):
        return False

    return True


# ─────────────────────────────────────────────────────────────────────
# CASCADE ENGINE
# ─────────────────────────────────────────────────────────────────────
[docs] class CascadeEngine: """Drives multi-turn neurochemical cascades for the limbic system. Owns the per-channel cascade lifecycle: each turn it loads :class:`CascadeState` from Redis, checks interrupts on active cascades, advances their stages and accumulates stage deltas, applies synergy bonuses and the meta-systems (habituation/tolerance, sensitization, GABA polarity inversion), evaluates triggers for new cascades, and persists state back to Redis under ``ncm:cascades:{channel_id}``. Cascade definitions and meta configs are loaded once via :func:`_load_cascade_defs` and :func:`_load_meta_configs`; stage cue text is humanized through an optional ``CueVariantCache``, and the returned delta vector feeds back into the NCM state. Instantiated once by ``limbic_system.coordinator.LimbicCoordinator`` (its ``cascade_engine`` attribute), which calls :meth:`tick` and :meth:`get_active_cascades` from ``exhale()`` each turn. """
[docs] def __init__( self, redis_client=None, variant_cache: "CueVariantCache | None" = None ): """Initialize the instance. Args: redis_client: Redis connection client. variant_cache ('CueVariantCache | None'): The variant cache value. """ self._redis = redis_client self._variant_cache = variant_cache self._defs = _load_cascade_defs() self._meta = _load_meta_configs() # Pre-compute meta-system settings self._hab = self._meta.get("_meta_habituation", {}) self._gaba = self._meta.get("_meta_gaba_polarity", {})
def _pick(self, raw) -> str: """Return a variant of *raw*, scheduling generation if not yet cached. Handles both plain strings and YAML list values transparently. On first use the original text is returned immediately while LLM generation runs in the background. """ s: str = random.choice(raw) if isinstance(raw, list) else (raw or "") if not s: return "" if self._variant_cache: asyncio.create_task(self._variant_cache.ensure_cached(s)) return self._variant_cache.get_variant(s) return s def _habituation_multiplier(self, cascade_id: str, state: CascadeState) -> float: """Compute the delta-magnitude scaling for a habituated cascade. Models tolerance: once a cascade has fired at least ``habituation_onset`` times (tracked in ``state.fire_counts``), each further fire shrinks its emitted deltas by an exponential ``habituation_decay`` curve, floored at ``habituation_floor``, so repeated cascades hit progressively softer. Reads only the ``_meta_habituation`` config (``self._hab``) and ``state``; returns ``1.0`` unchanged when habituation is disabled or the cascade is still below onset. Called by :meth:`tick` while accumulating each active cascade's stage delta. Args: cascade_id: Cascade whose fire history is consulted. state: Channel cascade state supplying ``fire_counts``. Returns: A multiplier in ``[habituation_floor, 1.0]`` applied to stage deltas. """ if not self._hab.get("enabled"): return 1.0 fires = state.fire_counts.get(cascade_id, 0) onset = self._hab.get("habituation_onset", 5) if fires < onset: return 1.0 decay = self._hab.get("habituation_decay", 0.85) floor = self._hab.get("habituation_floor", 0.4) excess = fires - onset return max(floor, decay**excess) def _sensitization_adjustment(self, cascade_id: str, state: CascadeState) -> float: """Compute the trigger-threshold easing from sensitization. 
Models the complementary effect to tolerance: the more a cascade has fired (``state.fire_counts``), the easier it becomes to re-trigger, scaled linearly by ``sensitization_rate`` and clamped to ``sensitization_cap``. :meth:`tick` adds the returned amount to the NCM vector values before evaluating that cascade's trigger, effectively lowering its thresholds. Reads only ``self._hab`` and ``state`` and returns ``0.0`` when habituation is disabled. Args: cascade_id: Cascade whose fire history is consulted. state: Channel cascade state supplying ``fire_counts``. Returns: A non-negative offset in ``[0.0, sensitization_cap]`` added to vector values for trigger evaluation. """ if not self._hab.get("enabled"): return 0.0 fires = state.fire_counts.get(cascade_id, 0) rate = self._hab.get("sensitization_rate", 0.02) cap = self._hab.get("sensitization_cap", 0.30) return min(cap, fires * rate) def _habituation_should_skip(self, cascade_id: str, state: CascadeState) -> bool: """Probabilistic tolerance gate — habituated cascades may not fire. After habituation_onset fires, each subsequent fire has a decreasing probability of actually triggering. Uses the same decay curve as delta habituation so the two stay in sync. """ if not self._hab.get("enabled"): return False fires = state.fire_counts.get(cascade_id, 0) onset = self._hab.get("habituation_onset", 5) if fires < onset: return False decay = self._hab.get("habituation_decay", 0.85) floor = self._hab.get("habituation_floor", 0.4) excess = fires - onset fire_prob = max(floor, decay**excess) if random.random() > fire_prob: logger.debug( "Habituation skip: %s (fires=%d, prob=%.2f)", cascade_id, fires, fire_prob, ) return True return False def _check_gaba_inversion(self, vector: Dict[str, float]) -> bool: """Test whether the chloride gradient flips GABA from inhibitory to excitatory. 
Models developmental/stress GABA polarity reversal: when the NKCC1 chloride-importer node exceeds the KCC2 exporter node by more than ``inversion_threshold``, GABAergic signaling is treated as depolarizing for this turn. Reads node names and the threshold from the ``_meta_gaba_polarity`` config (``self._gaba``) and the passed ``vector``; returns ``False`` when GABA polarity meta is disabled. Called once per turn by :meth:`tick`, whose result gates :meth:`_apply_gaba_inversion`. Args: vector: Current NCM state vector (queried for the NKCC1/KCC2 nodes). Returns: ``True`` when the chloride differential exceeds the inversion threshold. """ if not self._gaba.get("enabled"): return False nkcc1 = vector.get(self._gaba.get("nkcc1_node", "NKCC1_CHLORIDE"), 0.5) kcc2 = vector.get(self._gaba.get("kcc2_node", "KCC2_CHLORIDE"), 0.5) threshold = self._gaba.get("inversion_threshold", 0.15) return (nkcc1 - kcc2) > threshold def _apply_gaba_inversion( self, delta: Dict[str, float], cascade_id: str ) -> Dict[str, float]: """Sign-flip GABA-related tokens in a stage delta when polarity is inverted. Applied only after :meth:`_check_gaba_inversion` reports a flip and only to cascades listed in the meta config's ``affected_cascades``: for those, any token in ``inverted_tokens`` has its delta negated (excitatory becomes inhibitory and vice versa) while all other tokens pass through unchanged. Reads ``self._gaba`` and returns the input ``delta`` untouched when the cascade is not affected. Called by :meth:`tick` while building each active cascade's stage delta. Args: delta: The stage's node-to-magnitude delta mapping. cascade_id: Cascade the delta belongs to, checked against ``affected_cascades``. Returns: A new delta mapping with affected tokens sign-flipped, or the original mapping when the cascade is unaffected. 
""" affected = self._gaba.get("affected_cascades", []) if cascade_id not in affected: return delta tokens = set(self._gaba.get("inverted_tokens", [])) inverted = {} for k, v in delta.items(): if k in tokens: inverted[k] = -v # Sign flip else: inverted[k] = v return inverted async def _load_state(self, channel_id: str) -> CascadeState: """Fetch and deserialize a channel's cascade state from Redis. Reads the JSON blob at ``ncm:cascades:{channel_id}`` via the injected async Redis client and rebuilds a :class:`CascadeState` through :meth:`CascadeState.from_dict`. Returns a fresh empty state when no Redis client is configured, no value exists, or deserialization fails (the error is logged at debug level and swallowed so a turn never crashes on bad state). Called by :meth:`tick`, :meth:`get_active_cascades`, and :meth:`force_trigger`. Args: channel_id: Channel whose state key is read. Returns: The persisted :class:`CascadeState`, or a new empty one on miss or error. """ if self._redis: try: key = REDIS_CASCADE_KEY.format(channel_id=channel_id) raw = await self._redis.get(key) if raw: return CascadeState.from_dict(json.loads(raw)) except Exception as e: logger.debug("Cascade state load error: %s", e) return CascadeState() async def _save_state(self, channel_id: str, state: CascadeState): """Serialize and persist a channel's cascade state to Redis. Writes ``state`` (via :meth:`CascadeState.to_dict`, which already trims its rolling histories) as JSON to ``ncm:cascades:{channel_id}`` with a ``REDIS_CASCADE_TTL`` (24h) expiry. No-ops when no Redis client is configured and logs-and-swallows any write error at debug level so a turn is never lost to a transient Redis failure. Called by :meth:`tick` and :meth:`force_trigger` after they mutate state. Args: channel_id: Channel whose state key is written. state: The cascade state to persist. 
""" if self._redis: try: key = REDIS_CASCADE_KEY.format(channel_id=channel_id) await self._redis.set( key, json.dumps(state.to_dict()), ex=REDIS_CASCADE_TTL ) except Exception as e: logger.debug("Cascade state save error: %s", e)
[docs] async def tick( self, channel_id: str, vector: Dict[str, float], active_emotions: Set[str], delta_count: int = 0, ) -> Dict[str, float]: """Execute one turn of cascade processing. Called during exhale(). Returns combined delta vector from all active cascade stages this turn. Parameters ---------- channel_id : str The channel being processed. vector : Dict[str, float] Current NCM state vector (post-emotion-deltas). active_emotions : Set[str] Emotions that fired this turn. delta_count : int Number of emotion deltas applied this turn. Returns ------- Dict[str, float] Combined delta vector from cascade processing this turn. """ state = await self._load_state(channel_id) state.turn_count += 1 state.delta_counts.append(delta_count) state.last_emotions.append(active_emotions.copy()) # Trim rolling history state.delta_counts = state.delta_counts[-10:] state.last_emotions = state.last_emotions[-5:] # Update node history for sustained-condition checks for node, val in vector.items(): if node not in state.node_history: state.node_history[node] = [] state.node_history[node].append(val) state.node_history[node] = state.node_history[node][-10:] combined_delta: Dict[str, float] = {} cascades_to_trigger: List[str] = [] completed_this_turn: Set[str] = set() # 💀 grace period tracking # ── Pre-compute GABA polarity for this turn (local: no cross-tick bleed) ── gaba_inverted = self._check_gaba_inversion(vector) # ── Phase 1: Check interrupts on active cascades ────────── for cid in list(state.active.keys()): info = state.active[cid] cdef = self._defs.get(cid) if not cdef: del state.active[cid] continue # Skip paused cascades if info.get("paused"): pause_remaining = info.get("pause_remaining", 0) if pause_remaining > 0: info["pause_remaining"] = pause_remaining - 1 continue else: info["paused"] = False interrupted = False for intr in cdef.get("interrupt", []): icond = intr.get("condition", {}) triggered = False # Emotion-based interrupt if "emotion" in icond: if icond["emotion"] 
in active_emotions: triggered = True # Node-based interrupt elif "node" in icond: triggered = _eval_node_condition(icond, vector, state) # Delta count sustained elif "delta_count_sustained" in icond: dc_op = icond.get("op", "<") dc_val = icond.get("val", 2) dc_for = icond.get("for_turns", 3) recent = state.delta_counts[-dc_for:] if len(recent) >= dc_for: if dc_op == "<" and all(d < dc_val for d in recent): triggered = True # Cascade triggered elif icond.get("cascade_triggered"): # Will be checked after new trigger phase pass if triggered: action = intr.get("action", "abort") if action == "abort": logger.info( "Cascade %s ABORTED: %s", cid, self._pick(intr.get("reason", "")), ) state.history.append(f"{cid}:aborted:turn_{state.turn_count}") del state.active[cid] interrupted = True break elif action == "pause": pt = intr.get("pause_turns", 2) info["paused"] = True info["pause_remaining"] = pt logger.info( "Cascade %s PAUSED for %d turns: %s", cid, pt, self._pick(intr.get("reason", "")), ) interrupted = True break elif action == "skip_to_stage": target = intr.get("target", "") stages = cdef.get("stages", []) stage_names = [s["name"] for s in stages] if target in stage_names: info["stage"] = stage_names.index(target) info["turns_in_stage"] = 0 logger.info( "Cascade %s SKIP to %s: %s", cid, target, self._pick(intr.get("reason", "")), ) interrupted = True break elif action == "trigger_cascade": target = intr.get("target", "") cascades_to_trigger.append(target) logger.info( "Cascade %s triggers %s: %s", cid, target, self._pick(intr.get("reason", "")), ) # Also abort or complete the current cascade state.history.append(f"{cid}:completed:turn_{state.turn_count}") del state.active[cid] interrupted = True break if interrupted: continue # ── Phase 2: Advance active cascades & collect deltas ───── for cid in list(state.active.keys()): info = state.active[cid] if info.get("paused"): continue cdef = self._defs.get(cid) if not cdef: continue stages = cdef.get("stages", []) 
stage_idx = info.get("stage", 0) if stage_idx >= len(stages): # Cascade complete logger.info("Cascade %s COMPLETED at turn %d", cid, state.turn_count) state.history.append(f"{cid}:completed:turn_{state.turn_count}") state.cooldowns[cid] = cdef.get("cooldown", 0) # Track fire count for habituation + tolerance reversal state.fire_counts[cid] = state.fire_counts.get(cid, 0) + 1 state.last_fired_turn[cid] = state.turn_count completed_this_turn.add(cid) # 💀 grace period del state.active[cid] continue stage = stages[stage_idx] hold = stage.get("hold_turns", 1) turns_in = info.get("turns_in_stage", 0) # Skip conditional stages unless explicitly jumped to if stage.get("conditional") and turns_in == 0: info["stage"] = stage_idx + 1 info["turns_in_stage"] = 0 continue # Apply this stage's delta # 🌿 Sativa/Indica gradient interpolation # If the stage provides delta_vector_sativa and # delta_vector_indica, lerp between them using the # cascade's strain_gradient (0.0=indica, 1.0=sativa). # Falls back to base delta_vector when poles are absent. 
sativa_d = stage.get("delta_vector_sativa") indica_d = stage.get("delta_vector_indica") strain_g = info.get( "strain_gradient", cdef.get("strain_gradient"), ) if sativa_d and indica_d and strain_g is not None: g = max(0.0, min(1.0, float(strain_g))) all_nodes = set(indica_d.keys()) | set(sativa_d.keys()) stage_delta = {} for node in all_nodes: iv = indica_d.get(node, 0.0) sv = sativa_d.get(node, 0.0) stage_delta[node] = iv + g * (sv - iv) else: stage_delta = dict(stage.get("delta_vector", {})) # ── Meta: GABA polarity inversion ────────────── if gaba_inverted: stage_delta = self._apply_gaba_inversion(stage_delta, cid) # ── Meta: Habituation scaling ────────────────── hab_mult = self._habituation_multiplier(cid, state) if hab_mult < 1.0: stage_delta = {k: v * hab_mult for k, v in stage_delta.items()} for k, v in stage_delta.items(): combined_delta[k] = combined_delta.get(k, 0.0) + v # Log the cue -- select pole-appropriate variant if strain_g is not None and float(strain_g) > 0.65: raw_cue = stage.get("cue_sativa", stage.get("cue", "")) elif strain_g is not None and float(strain_g) < 0.35: raw_cue = stage.get("cue_indica", stage.get("cue", "")) else: raw_cue = stage.get("cue", "") if turns_in == 0 and raw_cue: logger.info( "Cascade %s [%s]: %s", cid, stage["name"], self._pick(raw_cue), ) # Advance info["turns_in_stage"] = turns_in + 1 info["total_turns"] = info.get("total_turns", 0) + 1 # Check if stage is done if info["turns_in_stage"] >= hold: info["stage"] = stage_idx + 1 info["turns_in_stage"] = 0 # Check max duration max_dur = cdef.get("max_duration", 30) if info["total_turns"] >= max_dur: logger.info( "Cascade %s hit max_duration (%d), completing", cid, max_dur ) state.history.append(f"{cid}:completed:turn_{state.turn_count}") state.cooldowns[cid] = cdef.get("cooldown", 0) # Track fire count and last-fired turn for habituation state.fire_counts[cid] = state.fire_counts.get(cid, 0) + 1 state.last_fired_turn[cid] = state.turn_count completed_this_turn.add(cid) # 
💀 grace period del state.active[cid] # ── Phase 3: Check synergies ────────────────────────────── active_ids = set(state.active.keys()) for cid in list(active_ids): cdef = self._defs.get(cid, {}) for syn in cdef.get("synergy", []): partner = syn.get("with", "") if partner in active_ids: syn_delta = syn.get("delta_vector", {}) for k, v in syn_delta.items(): combined_delta[k] = combined_delta.get(k, 0.0) + v logger.debug( "Synergy: %s + %s%s", cid, partner, self._pick(syn.get("reason", "")), ) # ── Phase 4: Decay cooldowns + tolerance reversal ────────── for cid in list(state.cooldowns.keys()): state.cooldowns[cid] -= 1 if state.cooldowns[cid] <= 0: del state.cooldowns[cid] # Tolerance reversal: decay fire_counts for cascades that # haven't fired in a while. recovery_turns_per_count turns # of abstinence = 1 fire_count recovered. if self._hab.get("enabled"): recovery_rate = self._hab.get("recovery_turns_per_count", 15) for cid in list(state.fire_counts.keys()): if state.fire_counts[cid] <= 0: del state.fire_counts[cid] state.last_fired_turn.pop(cid, None) continue # Skip if cascade is currently active if cid in state.active: continue last_fired = state.last_fired_turn.get(cid, 0) turns_idle = state.turn_count - last_fired if turns_idle > 0 and recovery_rate > 0: recovered = turns_idle // recovery_rate if recovered > 0: old = state.fire_counts[cid] state.fire_counts[cid] = max(0, old - recovered) # Reset the clock so we don't double-count state.last_fired_turn[cid] = ( last_fired + recovered * recovery_rate ) if state.fire_counts[cid] < old: logger.debug( "Tolerance reversal: %s fire_count %d%d" " (%d turns idle)", cid, old, state.fire_counts[cid], turns_idle, ) # ── Phase 5: Check triggers for new cascades ────────────── if len(state.active) < MAX_CONCURRENT_CASCADES: for cid, cdef in self._defs.items(): if cid in state.active: continue if cid in state.cooldowns: continue # 💀 Grace period: don't re-trigger a cascade that # just completed THIS turn — let metabolic 
decay and # other systems process first. if cid in completed_this_turn: continue # 🌀 Habituation skip: heavily-fired cascades may # probabilistically fail to trigger (true tolerance). if self._habituation_should_skip(cid, state): continue trigger = cdef.get("trigger", {}) # ── Meta: Sensitization threshold adjustment ── sens = self._sensitization_adjustment(cid, state) adjusted_vector = vector if sens > 0.0: # Effectively lower all trigger thresholds by reducing # the comparison point (make current values appear higher) adjusted_vector = {k: min(3.0, v + sens) for k, v in vector.items()} if _eval_trigger( trigger, adjusted_vector, active_emotions, state, active_ids ): state.active[cid] = { "stage": 0, "turns_in_stage": 0, "started_at_turn": state.turn_count, "total_turns": 0, "paused": False, } logger.info( "Cascade %s TRIGGERED at turn %d", cid, state.turn_count ) state.history.append(f"{cid}:triggered:turn_{state.turn_count}") if len(state.active) >= MAX_CONCURRENT_CASCADES: break # Handle manually triggered cascades (from interrupt actions) for cid in cascades_to_trigger: if cid in state.active or cid in state.cooldowns: continue if len(state.active) >= MAX_CONCURRENT_CASCADES: break if cid in self._defs: state.active[cid] = { "stage": 0, "turns_in_stage": 0, "started_at_turn": state.turn_count, "total_turns": 0, "paused": False, } logger.info( "Cascade %s CHAIN-TRIGGERED at turn %d", cid, state.turn_count ) state.history.append(f"{cid}:chain_triggered:turn_{state.turn_count}") # ── Save state ──────────────────────────────────────────── await self._save_state(channel_id, state) # Clamp combined deltas for k in combined_delta: combined_delta[k] = max(-3.0, min(3.0, combined_delta[k])) return combined_delta
[docs] async def get_active_cascades(self, channel_id: str) -> Dict[str, Dict[str, Any]]: """Summarize a channel's in-flight cascades for prompt-context injection. Loads the channel's :class:`CascadeState` (read-only; it does not advance anything) and, for each active cascade, resolves the current stage name and a humanized cue line through :meth:`_pick` (the variant cache), plus how many turns it has run and whether it is paused. The coordinator's ``exhale()`` calls this right after :meth:`tick` and forwards the non-paused cues into the limbic context so active cascades can color the bot's reply. Args: channel_id: Channel whose active cascades are summarized. Returns: Mapping of cascade id to a dict with ``stage``, ``turn``, ``cue``, and ``paused`` (empty when nothing is active). """ state = await self._load_state(channel_id) result = {} for cid, info in state.active.items(): cdef = self._defs.get(cid, {}) stages = cdef.get("stages", []) stage_idx = info.get("stage", 0) stage_name = ( stages[stage_idx]["name"] if stage_idx < len(stages) else "done" ) raw_cue = ( stages[stage_idx].get("cue", "") if stage_idx < len(stages) else "" ) result[cid] = { "stage": stage_name, "turn": info.get("total_turns", 0), "cue": self._pick(raw_cue), "paused": info.get("paused", False), } return result
[docs] async def force_trigger( self, channel_id: str, cascade_id: str, strain_gradient: float | None = None, ) -> bool: """Manually trigger a cascade (e.g. from a tool call). Args: channel_id: Target channel. cascade_id: Cascade to trigger. strain_gradient: Optional sativa/indica gradient (0.0-1.0) for ENDOCANNABINOID_DRIFT bipolar interpolation. """ if cascade_id not in self._defs: return False state = await self._load_state(channel_id) if cascade_id in state.active: return False if len(state.active) >= MAX_CONCURRENT_CASCADES: return False active_info = { "stage": 0, "turns_in_stage": 0, "started_at_turn": state.turn_count, "total_turns": 0, "paused": False, } # 🌿 Attach strain gradient if provided if strain_gradient is not None: active_info["strain_gradient"] = max(0.0, min(1.0, strain_gradient)) state.active[cascade_id] = active_info state.history.append(f"{cascade_id}:force_triggered:turn_{state.turn_count}") await self._save_state(channel_id, state) logger.info( "Cascade %s FORCE-TRIGGERED (gradient=%s)", cascade_id, strain_gradient ) return True