# Source code for core.state_machine

"""Per-operation state machine with crash-safe checkpoints.

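:class:`OperationStateMachine` records each trace's position in the
operation lifecycle — ``RECEIVED`` -> (optionally ``PREPROCESSING`` ->)
``QUEUED`` -> ``CLAIMED`` -> ``PRE_INFERENCE_GATHER`` -> ``INFERRING``
(looping through ``TOOL_EXECUTING`` as needed) -> ``POSTPROCESSING`` ->
``DELIVERING`` -> ``DELIVERED`` -> ``COMPLETED``, plus the ``ERRORED`` /
``RETRYING`` / ``RECLAIMED`` side paths. The legal moves are encoded in
``OperationStateMachine._VALID_PREDECESSORS``. The module-level checkpoint
helpers persist compressed, schema-versioned gather output so a reclaimed
message can resume without redoing the expensive gather phase.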
"""

import json
import zlib
import time
import logging
from typing import ClassVar

logger = logging.getLogger(__name__)

CHECKPOINT_SCHEMA_VERSION = 1
ABSOLUTE_FAILSAFE_TTL = 86400

[docs] class OperationStateMachine: _VALID_PREDECESSORS: ClassVar[dict[str, list[str]]] = { "RECEIVED": [], "PREPROCESSING": ["RECEIVED"], "QUEUED": ["RECEIVED", "PREPROCESSING", "RETRYING"], "CLAIMED": ["QUEUED", "RECLAIMED"], "PRE_INFERENCE_GATHER": ["CLAIMED"], "INFERRING": ["PRE_INFERENCE_GATHER", "TOOL_EXECUTING", "POSTPROCESSING"], "TOOL_EXECUTING": ["INFERRING"], "POSTPROCESSING": ["INFERRING"], "DELIVERING": ["POSTPROCESSING"], "DELIVERED": ["DELIVERING"], "COMPLETED": ["DELIVERED", "ERRORED"], "ERRORED": ["RECEIVED", "PREPROCESSING", "QUEUED", "CLAIMED", "PRE_INFERENCE_GATHER", "INFERRING", "TOOL_EXECUTING", "POSTPROCESSING", "DELIVERING", "DELIVERED"], "RETRYING": ["ERRORED"], "RECLAIMED": ["CLAIMED", "PRE_INFERENCE_GATHER", "INFERRING", "TOOL_EXECUTING", "POSTPROCESSING"], }
[docs] @classmethod def is_valid_transition(cls, current: str, target: str) -> bool: """Report whether a move from ``current`` to ``target`` is legal. Consults the class-level :attr:`_VALID_PREDECESSORS` adjacency map, which encodes the operation lifecycle, and returns ``True`` only when ``current`` appears in ``target``\\ 's allowed-predecessor list. Used to guard against out-of-order or skipped lifecycle steps. This is a pure lookup with no side effects. Within the repo it is exercised only by ``tests/core/migration/test_state_machine.py``; no production caller invokes it directly. Note that neither :meth:`transition` nor :func:`write_checkpoint_and_transition` consults this guard, and the Lua script in ``scripts/state_transition.lua`` deliberately skips re-validation and just ``HSET``\\ s the new state for atomicity, so callers needing the lifecycle check must call this method themselves. Args: current: The state the trace is currently in (e.g. ``"CLAIMED"``). target: The state being transitioned to (e.g. ``"INFERRING"``). Returns: ``True`` if ``current`` is a valid predecessor of ``target``, otherwise ``False`` (including when ``target`` is unknown). """ return current in cls._VALID_PREDECESSORS.get(target, [])
[docs] def __init__(self, redis): """Bind the state machine to a Redis client for trace updates. Stores ``redis`` on ``self._redis`` for later use by :meth:`transition`, which mutates the ``sg:trace:{trace_id}`` hash. Instances are wired into :class:`core.stream_consumer.InboundStreamConsumer`, which receives the state machine via its ``state_machine`` constructor argument (stored as ``self._state_machine``) and calls :meth:`transition` as messages are claimed (``CLAIMED``), completed (``COMPLETED``), or errored (``ERRORED``); no internal call site constructing ``OperationStateMachine`` itself was found outside the tests. Args: redis: An async Redis client (``redis.asyncio``-style) supporting ``exists`` and ``hset``. """ self._redis = redis
[docs] async def transition(self, trace_id: str, target_state: str, metadata: dict | None = None) -> bool: """Move a trace to ``target_state``, recording it for observability. Advances the lifecycle stamp on a trace by writing ``state`` (plus any extra ``metadata`` fields) into the trace hash, so operators and the web dashboard can see where each operation is in the ``RECEIVED`` -> ... -> ``COMPLETED`` pipeline. It first checks the trace still exists and skips the write if not, which guards against resurrecting a trace whose hash has already expired or been reaped. Touches Redis: an ``exists`` probe followed by an ``hset`` on ``sg:trace:{trace_id}`` (on the client passed at construction). Unlike the Lua-script path in :func:`write_checkpoint_and_transition`, this does not consult :meth:`is_valid_transition` -- it writes whatever ``target_state`` it is given. Called by :class:`core.stream_consumer.InboundStreamConsumer` (which holds the state machine as ``self._state_machine``) as messages are claimed (``CLAIMED``), completed (``COMPLETED``), and errored (``ERRORED``). Args: trace_id: Identifier selecting the ``sg:trace:{trace_id}`` hash. target_state: New lifecycle state to record (e.g. ``"CLAIMED"``). metadata: Optional extra fields to merge into the trace hash alongside ``state`` (e.g. ``{"msg_id": ...}``). Returns: bool: ``True`` if the trace existed and was updated, ``False`` if no trace hash was found (the write is skipped). """ mapping = {"state": target_state} if metadata: mapping.update(metadata) exists = await self._redis.exists(f"sg:trace:{trace_id}") if not exists: return False await self._redis.hset(f"sg:trace:{trace_id}", mapping=mapping) return True
async def write_checkpoint_and_transition(
    redis, trace_id: str, gather_output: dict, target_state: str, script_sha: str
):
    """Save a gather checkpoint and advance the trace state in one round trip.

    The gather output is JSON-encoded and zlib-compressed (level 6). It is
    stored in the hash ``sg:checkpoint:{trace_id}`` together with the schema
    version, a ``time.time()`` timestamp, and the original and compressed
    sizes. That key's TTL is set to :data:`ABSOLUTE_FAILSAFE_TTL` seconds.

    In the same pipeline, the pre-loaded Lua script ``script_sha`` is invoked
    with one key (``sg:trace:{trace_id}``) and one argument
    (``target_state``). A ``checkpoint_saved`` line with the compression
    ratio is logged afterwards.

    NOTE(review): the three commands are atomic only if ``redis.pipeline()``
    returns a transactional (MULTI/EXEC) pipeline, which is redis-py's
    default. Whether the Lua script validates the transition is not visible
    here.

    Args:
        redis: Async Redis client whose ``pipeline()`` supports ``hset``,
            ``expire``, ``evalsha`` and an awaitable ``execute``.
        trace_id: Operation identifier; selects both the checkpoint and
            trace keys.
        gather_output: JSON-serializable gather results to checkpoint.
        target_state: State passed as the sole ARGV to the Lua script.
        script_sha: SHA1 of a Lua script previously loaded via SCRIPT LOAD.

    Returns:
        list: The pipeline results, one per queued command
        (``hset``, ``expire``, ``evalsha``).
    """
    checkpoint_key = f"sg:checkpoint:{trace_id}"
    trace_key = f"sg:trace:{trace_id}"

    raw_payload = json.dumps(gather_output).encode('utf-8')
    compressed = zlib.compress(raw_payload, level=6)
    raw_len = len(raw_payload)
    packed_len = len(compressed)

    record = {
        "schema_version": str(CHECKPOINT_SCHEMA_VERSION),
        "checkpoint_ts": str(time.time()),
        "payload_compressed": compressed,
        "payload_size_original": str(raw_len),
        "payload_size_compressed": str(packed_len),
    }

    pipe = redis.pipeline()
    pipe.hset(checkpoint_key, mapping=record)
    pipe.expire(checkpoint_key, ABSOLUTE_FAILSAFE_TTL)
    # numkeys=1: KEYS[1] is the trace hash, ARGV[1] is the new state.
    pipe.evalsha(script_sha, 1, trace_key, target_state)
    results = await pipe.execute()

    if raw_len > 0:
        ratio_pct = (packed_len / raw_len) * 100
    else:
        ratio_pct = 0
    logger.info(
        "checkpoint_saved trace_id=%s original_bytes=%d compressed_bytes=%d compression_ratio=%.1f%%",
        trace_id,
        raw_len,
        packed_len,
        ratio_pct,
    )
    return results
def _decode_checkpoint_field(value):
    """Return ``value`` as ``str`` if it is UTF-8 ``bytes``, else unchanged."""
    return value.decode('utf-8') if isinstance(value, bytes) else value


async def recover_checkpoint(redis, trace_id: str) -> dict | None:
    """Load and decompress a previously saved gather checkpoint, if usable.

    Reads ``sg:checkpoint:{trace_id}`` via ``hgetall``, the hash written by
    ``write_checkpoint_and_transition``. It checks the schema version and
    returns the decompressed, JSON-decoded gather output. Any unusable
    checkpoint yields ``None``, so the caller falls back to a fresh
    regather:

    * no checkpoint hash: returns ``None`` silently;
    * schema older than :data:`CHECKPOINT_SCHEMA_VERSION`: logs
      ``checkpoint_schema_stale``;
    * schema newer than this code understands: logs
      ``checkpoint_schema_unsupported``;
    * non-numeric schema version, missing or corrupt payload: logs
      ``checkpoint_corrupt``.

    A successful recovery logs ``checkpoint_recovered`` with the checkpoint
    age. The age is ``nan`` if the timestamp is missing or unparsable; a
    bad timestamp does not discard an otherwise valid checkpoint.

    Args:
        redis: An async Redis client supporting ``hgetall``. Values must be
            returned as raw bytes (no ``decode_responses``) so that the
            binary payload survives.
        trace_id: Identifier of the operation whose checkpoint to load.

    Returns:
        dict | None: The recovered gather output, or ``None`` if the
        checkpoint is missing or unusable.
    """
    checkpoint = await redis.hgetall(f"sg:checkpoint:{trace_id}")
    if not checkpoint:
        return None

    # Normalise keys to str (clients without decode_responses return bytes
    # keys) but keep values raw, because payload_compressed is binary.
    fields = {
        (k.decode('utf-8') if isinstance(k, bytes) else str(k)): v
        for k, v in checkpoint.items()
    }

    try:
        schema_ver = int(_decode_checkpoint_field(fields.get("schema_version", "0")))
    except (TypeError, ValueError):
        logger.warning(
            "checkpoint_corrupt trace_id=%s field=schema_version", trace_id
        )
        return None

    if schema_ver < CHECKPOINT_SCHEMA_VERSION:
        logger.warning(
            "checkpoint_schema_stale trace_id=%s checkpoint_version=%d current_version=%d",
            trace_id, schema_ver, CHECKPOINT_SCHEMA_VERSION
        )
        return None  # Require regather
    if schema_ver > CHECKPOINT_SCHEMA_VERSION:
        # Written by newer code; its layout cannot be trusted here.
        logger.warning(
            "checkpoint_schema_unsupported trace_id=%s checkpoint_version=%d current_version=%d",
            trace_id, schema_ver, CHECKPOINT_SCHEMA_VERSION
        )
        return None  # Require regather

    try:
        gather_output = json.loads(zlib.decompress(fields["payload_compressed"]))
    except (KeyError, TypeError, ValueError, zlib.error):
        # Missing/corrupt payload: fall back to regather instead of crashing
        # the reclaim path. JSON and UTF-8 decode errors are ValueErrors.
        logger.warning("checkpoint_corrupt trace_id=%s", trace_id, exc_info=True)
        return None

    try:
        age_s = time.time() - float(_decode_checkpoint_field(fields["checkpoint_ts"]))
    except (KeyError, TypeError, ValueError):
        age_s = float("nan")
    logger.info(
        "checkpoint_recovered trace_id=%s checkpoint_age_s=%.2f",
        trace_id,
        age_s,
    )
    return gather_output