"""Per-operation state machine with crash-safe checkpoints.
:class:`OperationStateMachine` tracks each trace through a validated
lifecycle — ``RECEIVED`` -> ``PRE_INFERENCE_GATHER`` -> ``INFERRING``
-> ``TOOL_EXECUTING`` / ``POSTPROCESSING`` -> ``DELIVERING`` ->
``COMPLETED`` (plus ``ERRORED`` / ``RETRYING`` / ``RECLAIMED``) — and
persists compressed, schema-versioned checkpoints so a reclaimed
message can resume without redoing expensive gather/inference work.
"""
import json
import zlib
import time
import logging
from typing import ClassVar
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Version stamped into every checkpoint hash as ``schema_version``.
# ``recover_checkpoint`` treats any stored version lower than this as stale
# and returns ``None`` instead of decoding the payload.
CHECKPOINT_SCHEMA_VERSION = 1
# Expiry, in seconds (86400 s = 24 h), applied to each
# ``sg:checkpoint:{trace_id}`` key when the checkpoint is written.
ABSOLUTE_FAILSAFE_TTL = 86400
[docs]
class OperationStateMachine:
_VALID_PREDECESSORS: ClassVar[dict[str, list[str]]] = {
"RECEIVED": [],
"PREPROCESSING": ["RECEIVED"],
"QUEUED": ["RECEIVED", "PREPROCESSING", "RETRYING"],
"CLAIMED": ["QUEUED", "RECLAIMED"],
"PRE_INFERENCE_GATHER": ["CLAIMED"],
"INFERRING": ["PRE_INFERENCE_GATHER",
"TOOL_EXECUTING", "POSTPROCESSING"],
"TOOL_EXECUTING": ["INFERRING"],
"POSTPROCESSING": ["INFERRING"],
"DELIVERING": ["POSTPROCESSING"],
"DELIVERED": ["DELIVERING"],
"COMPLETED": ["DELIVERED", "ERRORED"],
"ERRORED": ["RECEIVED", "PREPROCESSING", "QUEUED", "CLAIMED",
"PRE_INFERENCE_GATHER",
"INFERRING", "TOOL_EXECUTING",
"POSTPROCESSING", "DELIVERING", "DELIVERED"],
"RETRYING": ["ERRORED"],
"RECLAIMED": ["CLAIMED", "PRE_INFERENCE_GATHER",
"INFERRING", "TOOL_EXECUTING",
"POSTPROCESSING"],
}
[docs]
@classmethod
def is_valid_transition(cls, current: str, target: str) -> bool:
"""Report whether a move from ``current`` to ``target`` is legal.
Consults the class-level :attr:`_VALID_PREDECESSORS` adjacency map,
which encodes the operation lifecycle, and returns ``True`` only when
``current`` appears in ``target``\\ 's allowed-predecessor list. Used to
guard against out-of-order or skipped lifecycle steps.
This is a pure lookup with no side effects. Within the repo it is
exercised only by ``tests/core/migration/test_state_machine.py``; no
production caller invokes it directly. Note that neither
:meth:`transition` nor :func:`write_checkpoint_and_transition` consults
this guard, and the Lua script in ``scripts/state_transition.lua``
deliberately skips re-validation and just ``HSET``\\ s the new state for
atomicity, so callers needing the lifecycle check must call this method
themselves.
Args:
current: The state the trace is currently in (e.g. ``"CLAIMED"``).
target: The state being transitioned to (e.g. ``"INFERRING"``).
Returns:
``True`` if ``current`` is a valid predecessor of ``target``,
otherwise ``False`` (including when ``target`` is unknown).
"""
return current in cls._VALID_PREDECESSORS.get(target, [])
[docs]
def __init__(self, redis):
"""Bind the state machine to a Redis client for trace updates.
Stores ``redis`` on ``self._redis`` for later use by
:meth:`transition`, which mutates the ``sg:trace:{trace_id}`` hash.
Instances are wired into
:class:`core.stream_consumer.InboundStreamConsumer`, which receives the
state machine via its ``state_machine`` constructor argument (stored as
``self._state_machine``) and calls :meth:`transition` as messages are
claimed (``CLAIMED``), completed (``COMPLETED``), or errored
(``ERRORED``); no internal call site constructing
``OperationStateMachine`` itself was found outside the tests.
Args:
redis: An async Redis client (``redis.asyncio``-style) supporting
``exists`` and ``hset``.
"""
self._redis = redis
[docs]
async def transition(self, trace_id: str, target_state: str, metadata: dict | None = None) -> bool:
"""Move a trace to ``target_state``, recording it for observability.
Advances the lifecycle stamp on a trace by writing ``state`` (plus any
extra ``metadata`` fields) into the trace hash, so operators and the web
dashboard can see where each operation is in the
``RECEIVED`` -> ... -> ``COMPLETED`` pipeline. It first checks the trace
still exists and skips the write if not, which guards against resurrecting
a trace whose hash has already expired or been reaped.
Touches Redis: an ``exists`` probe followed by an ``hset`` on
``sg:trace:{trace_id}`` (on the client passed at construction). Unlike
the Lua-script path in :func:`write_checkpoint_and_transition`, this does
not consult :meth:`is_valid_transition` -- it writes whatever
``target_state`` it is given. Called by
:class:`core.stream_consumer.InboundStreamConsumer` (which holds the
state machine as ``self._state_machine``) as messages are claimed
(``CLAIMED``), completed (``COMPLETED``), and errored (``ERRORED``).
Args:
trace_id: Identifier selecting the ``sg:trace:{trace_id}`` hash.
target_state: New lifecycle state to record (e.g. ``"CLAIMED"``).
metadata: Optional extra fields to merge into the trace hash
alongside ``state`` (e.g. ``{"msg_id": ...}``).
Returns:
bool: ``True`` if the trace existed and was updated, ``False`` if no
trace hash was found (the write is skipped).
"""
mapping = {"state": target_state}
if metadata:
mapping.update(metadata)
exists = await self._redis.exists(f"sg:trace:{trace_id}")
if not exists:
return False
await self._redis.hset(f"sg:trace:{trace_id}", mapping=mapping)
return True
[docs]
async def write_checkpoint_and_transition(
    redis, trace_id: str, gather_output: dict, target_state: str, script_sha: str
):
    """Save a compressed gather checkpoint and move the trace in one batch.

    ``gather_output`` is JSON-encoded, zlib-compressed at level 6, and stored
    in the ``sg:checkpoint:{trace_id}`` hash together with the schema version,
    a wall-clock timestamp, and the original/compressed byte counts. The same
    pipeline then sets that key's expiry to :data:`ABSOLUTE_FAILSAFE_TTL`
    seconds and invokes the caller-supplied script through ``evalsha`` with
    ``sg:trace:{trace_id}`` as its single key and ``target_state`` as its
    single argument. All three commands are queued first and sent with one
    ``execute()`` call. A ``checkpoint_saved`` log line reports the sizes and
    the compressed/original ratio as a percentage.

    NOTE(review): whether the batch is applied atomically depends on how the
    client's ``pipeline()`` is configured, and what the script does to the
    trace hash is defined outside this module -- confirm both against the
    deployment.

    Args:
        redis: An async Redis client whose ``pipeline()`` supports ``hset``,
            ``expire``, ``evalsha`` and an awaitable ``execute``.
        trace_id: Identifier of the operation; selects both the
            ``sg:checkpoint:{trace_id}`` and ``sg:trace:{trace_id}`` keys.
        gather_output: JSON-serializable dict to checkpoint.
        target_state: State name handed to the script as its only argument.
        script_sha: SHA of a script already loaded on the server.

    Returns:
        Whatever ``execute()`` yields for the pipeline -- one entry per
        queued command (``hset``, ``expire``, ``evalsha``).
    """
    serialized = json.dumps(gather_output).encode("utf-8")
    deflated = zlib.compress(serialized, level=6)
    original_size = len(serialized)
    compressed_size = len(deflated)

    checkpoint_key = f"sg:checkpoint:{trace_id}"
    trace_key = f"sg:trace:{trace_id}"

    # Everything except the payload is stored as text; the payload stays raw
    # bytes so it can be fed straight back into zlib on recovery.
    fields = {
        "schema_version": str(CHECKPOINT_SCHEMA_VERSION),
        "checkpoint_ts": str(time.time()),
        "payload_compressed": deflated,
        "payload_size_original": str(original_size),
        "payload_size_compressed": str(compressed_size),
    }

    # Queue all three commands, then flush them in a single round trip.
    batch = redis.pipeline()
    batch.hset(checkpoint_key, mapping=fields)
    batch.expire(checkpoint_key, ABSOLUTE_FAILSAFE_TTL)
    batch.evalsha(script_sha, 1, trace_key, target_state)  # 1 == numkeys
    outcome = await batch.execute()

    if original_size > 0:
        ratio_pct = (compressed_size / original_size) * 100
    else:
        ratio_pct = 0

    logger.info(
        "checkpoint_saved trace_id=%s original_bytes=%d compressed_bytes=%d compression_ratio=%.1f%%",
        trace_id,
        original_size,
        compressed_size,
        ratio_pct,
    )
    return outcome
[docs]
async def recover_checkpoint(redis, trace_id: str) -> dict | None:
    """Fetch and decode the gather checkpoint for ``trace_id``, if usable.

    Reads the ``sg:checkpoint:{trace_id}`` hash with ``hgetall``. An empty
    result means there is nothing to resume from and ``None`` is returned.
    Field names that arrive as ``bytes`` are decoded to ``str`` before lookup.
    A checkpoint whose ``schema_version`` is lower than
    :data:`CHECKPOINT_SCHEMA_VERSION` (a missing field counts as version 0)
    is logged as ``checkpoint_schema_stale`` and also yields ``None``, which
    tells the caller to regather. Otherwise ``payload_compressed`` is
    zlib-decompressed and parsed as JSON, a ``checkpoint_recovered`` line is
    logged with the checkpoint's age in seconds, and the parsed value is
    returned.

    Missing ``payload_compressed`` / ``checkpoint_ts`` fields, undecodable
    payloads, and non-numeric versions are not handled here; the resulting
    exceptions propagate to the caller.

    NOTE(review): ``zlib.decompress`` needs a bytes-like payload, so this
    presumably expects a client that does not decode hash values to ``str``
    -- confirm against how the Redis client is constructed.

    Args:
        redis: An async Redis client supporting ``hgetall``.
        trace_id: Identifier of the operation whose checkpoint to load.

    Returns:
        dict | None: The decoded gather output, or ``None`` when the
        checkpoint is absent or older than the current schema version.
    """
    stored = await redis.hgetall(f"sg:checkpoint:{trace_id}")
    if not stored:
        return None

    # Normalize field names to ``str``; values are left exactly as received.
    fields = {
        (name.decode("utf-8") if isinstance(name, bytes) else str(name)): value
        for name, value in stored.items()
    }

    # A missing version defaults to b"0", i.e. older than any real schema.
    raw_version = fields.get("schema_version", b"0")
    if isinstance(raw_version, bytes):
        raw_version = raw_version.decode()
    stored_version = int(raw_version)

    if stored_version < CHECKPOINT_SCHEMA_VERSION:
        logger.warning(
            "checkpoint_schema_stale trace_id=%s checkpoint_version=%d current_version=%d",
            trace_id,
            stored_version,
            CHECKPOINT_SCHEMA_VERSION,
        )
        return None  # Caller must regather.

    inflated = zlib.decompress(fields["payload_compressed"])
    gather_output = json.loads(inflated)

    age_seconds = time.time() - float(fields["checkpoint_ts"])
    logger.info(
        "checkpoint_recovered trace_id=%s checkpoint_age_s=%.2f",
        trace_id,
        age_seconds,
    )
    return gather_output