# Source code for core.control_ops

"""Fleet-wide control-ops: restart / git-pull across every Stargazer service.

The monolith is gone, so an admin ``!restart_*`` / ``!bot_pull`` command can no
longer just act on one process. Instead the **gateway** authorizes the command
and broadcasts a small JSON op over the Redis pub/sub channel
:data:`CONTROL_OPS_CHANNEL`. Every service runs a :class:`ControlOpsDaemon`
(wired into :class:`core.service_base.StargazerService`) that decides whether the
op applies to it and self-restarts / self-pulls accordingly.

Why pub/sub (not a stream): a restart op must fan out to **all live** instances
at once and must NOT be replayed by a service that was down (that would cause a
restart storm on recovery). Pub/sub's "live subscribers only, no persistence" is
exactly the right delivery guarantee — and it matches the existing
``sg:channel:cancel`` control channel.

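Ordering guarantee: the gateway holds the live Discord connection, so it
restarts **last**. Its self-restart grace (``control_gateway_restart_grace``,
default 8s) is longer than the other tiers' grace
(``control_service_restart_grace``, default 2s) and than the publisher's
reply-aggregation window (``control_reply_timeout``, default 3s). With those
defaults its own daemon self-restarts only after the publisher has sent the
ACK and the aggregated reply roster.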
"""

from __future__ import annotations

import asyncio
import json
import logging
import socket
import uuid
from typing import Optional

from core.ops_exec import (
    _deferred_systemctl_restart,
    run_git_pull,
    systemctl_restart,
)
from core.service_registry import SERVICE_TTL, heartbeat, register_service

# Module logger, named "control_ops" (not __name__).
logger = logging.getLogger("control_ops")

# Pub/sub channel every service subscribes to. Live delivery only; see the
# module docstring for why this is pub/sub rather than a stream.
CONTROL_OPS_CHANNEL = "sg:control:ops"
# Per-instance idempotency marker prefix (SETNX, short TTL). The full key is
# ``<prefix><instance_id>:<nonce>``.
_SEEN_PREFIX = "sg:control:seen:"
# Reply-list prefix for ack aggregation (RPUSH by daemons, BLPOP by publisher).
# The publisher appends a fresh random hex id for every request.
_REPLY_PREFIX = "sg:control:reply:"
# Registry key prefix (mirrors core.service_registry). Keys under it are read
# in this module as ``<prefix><tier>:<instance_id>``.
_REGISTRY_PREFIX = "sg:registry:service:"

# The service tiers we can target (proxy is external; "all" = these tiers).
# Order is significant: fleet_units() resolves units in exactly this order.
SERVICE_TIERS = ("gateway", "inference", "agents", "consolidation", "web", "tools")

# Authoritative command -> (op, target) map for cluster control commands. Kept in
# sync with the cluster-locus entries in message_processor.command_registry (which
# drive help/detection). ``!bot_restart`` is a back-compat alias for inference.
# Keys are lowercase because control_op_for() lowercases the first token before
# the lookup. "proxy" is a valid target even though it is not in SERVICE_TIERS.
CONTROL_OPS_COMMANDS: dict[str, tuple[str, str]] = {
    "!restart_all": ("restart", "all"),
    "!restart_gateway": ("restart", "gateway"),
    "!restart_inference": ("restart", "inference"),
    "!restart_agents": ("restart", "agents"),
    "!restart_consolidation": ("restart", "consolidation"),
    "!restart_web": ("restart", "web"),
    "!restart_tools": ("restart", "tools"),
    "!bot_restart": ("restart", "inference"),
    "!proxy_restart": ("restart", "proxy"),
    "!bot_pull": ("pull", "all"),
}


def control_op_for(text: str) -> Optional[tuple[str, str]]:
    """Look up the ``(op, target)`` pair for a cluster control command.

    Only the first whitespace-separated token of *text* is considered, and the
    match is case-insensitive. For example, ``"!RESTART_WEB now"`` resolves
    exactly like ``"!restart_web"``.

    Args:
        text: Raw message text; may be empty or ``None``.

    Returns:
        The ``(op, target)`` tuple from :data:`CONTROL_OPS_COMMANDS`, or
        ``None`` when *text* is blank or its first token is not a known
        control command.
    """
    tokens = (text or "").split()
    if not tokens:
        return None
    return CONTROL_OPS_COMMANDS.get(tokens[0].lower())
def is_control_ops_command(text: str) -> bool:
    """Report whether *text* begins with a fleet-wide control-ops command.

    A thin predicate over :func:`control_op_for`, which defines the matching
    rules (first token only, case-insensitive).
    """
    resolved = control_op_for(text)
    return resolved is not None
def _decode(raw) -> str: return raw.decode() if isinstance(raw, (bytes, bytearray)) else str(raw)
def unit_name_for(config, service_name: str) -> str:
    """Map *service_name* to its systemd unit name.

    An explicit entry in ``config.control_unit_names`` wins. Otherwise the
    unit is ``config.control_unit_prefix`` (default ``"stargazer-"``) followed
    by the service name. *config* may be ``None``, in which case only the
    default prefix applies.
    """
    overrides = getattr(config, "control_unit_names", None) or {}
    if service_name not in overrides:
        prefix = getattr(config, "control_unit_prefix", "stargazer-")
        return f"{prefix}{service_name}"
    return overrides[service_name]
def fleet_units(config=None) -> list[str]:
    """Resolve every :data:`SERVICE_TIERS` tier to its systemd unit, in order.

    The single source of truth for "all the Stargazer service units" — the
    live microservices that replaced the retired ``stargazer`` /
    ``stargazer-swarm`` monolith. Each tier is resolved through
    :func:`unit_name_for`, so the result honours a deployment's
    ``control_unit_prefix`` / ``control_unit_names`` overrides. With the
    defaults it yields one unit per tier: ``stargazer-gateway``,
    ``stargazer-inference``, ``stargazer-agents``,
    ``stargazer-consolidation``, ``stargazer-web`` and ``stargazer-tools``.
    *config* may be ``None`` (callers without a Config in hand), in which
    case the default prefix applies. Pure; no I/O.

    Used by the admin journal tail (``Config.resolved_journal_units``), the
    ``read_service_logs`` tool, the log→RAG ingest task, and the
    ``bot_control`` restart tools so they all target the same fleet.

    Args:
        config: Optional Config-like object carrying ``control_unit_prefix`` /
            ``control_unit_names``; ``None`` falls back to the default prefix.

    Returns:
        list[str]: Resolved systemd unit names, one per tier, in
        :data:`SERVICE_TIERS` order.
    """
    return [unit_name_for(config, tier) for tier in SERVICE_TIERS]
# ──────────────────────────────────────────────
# Per-service daemon
# ──────────────────────────────────────────────
class ControlOpsDaemon:
    """Subscribe to :data:`CONTROL_OPS_CHANNEL` and act on matching ops.

    One runs in every service (started by ``StargazerService.boot``). It
    mirrors the worker cancellation daemon's pub/sub loop, but lives at the
    service-base level so every tier in :data:`SERVICE_TIERS` participates
    in fleet-wide restart / pull.

    Args:
        svc: The owning service. The daemon reads ``svc.redis``,
            ``svc.service_name``, ``svc.instance_id`` and, if present,
            ``svc.cfg``.
    """

    def __init__(self, svc) -> None:
        self.svc = svc
        # Strong references to pending deferred-restart tasks. asyncio keeps
        # only weak references to tasks, so an unreferenced task could be
        # garbage-collected before it fires. Each task removes itself from
        # this set when it finishes.
        self._restart_tasks: set[asyncio.Task] = set()

    @property
    def _cfg(self):
        """The service's config object, or ``None`` when it has none."""
        return getattr(self.svc, "cfg", None)

    def _cfg_val(self, name: str, default):
        """Return ``cfg.<name>``, falling back to *default* (also when no cfg)."""
        cfg = self._cfg
        if cfg is None:
            return default
        return getattr(cfg, name, default)

    def _op_applies(self, op: str, target: str) -> bool:
        """Decide whether this service should act on an op aimed at *target*.

        ``"proxy"`` is handled by exactly one tier
        (``control_proxy_handler_service``, default ``"gateway"``).
        ``"all"`` matches any tier in :data:`SERVICE_TIERS`. Any other target
        must equal this service's name. *op* is currently not consulted.
        """
        name = self.svc.service_name
        if target == "proxy":
            # Restart the external proxy from exactly one tier (default: gateway).
            return name == self._cfg_val("control_proxy_handler_service", "gateway")
        if target == "all":
            return name in SERVICE_TIERS
        return target == name

    async def run(self) -> None:
        """Listen on the control channel until cancelled or the listener fails.

        Each pub/sub ``message`` is JSON-decoded and passed to :meth:`_handle`.
        A failure on one op is logged and the loop keeps listening. Returns
        immediately when the service has no Redis client. The pubsub is always
        unsubscribed and closed on exit (best effort).
        """
        redis = self.svc.redis
        if redis is None:
            return
        pubsub = redis.pubsub()
        try:
            await pubsub.subscribe(CONTROL_OPS_CHANNEL)
            logger.info(
                "ControlOpsDaemon subscribed name=%s id=%s",
                self.svc.service_name,
                self.svc.instance_id,
            )
            async for message in pubsub.listen():
                # Skip subscribe/unsubscribe confirmations and other non-data frames.
                if message.get("type") != "message":
                    continue
                try:
                    payload = json.loads(_decode(message["data"]))
                    await self._handle(payload)
                except Exception:
                    logger.exception("ControlOpsDaemon failed to process op")
        except asyncio.CancelledError:
            # NOTE(review): cancellation is logged and not re-raised, which
            # preserves the existing shutdown contract with the caller.
            logger.info("ControlOpsDaemon cancelled name=%s", self.svc.service_name)
        except Exception:
            logger.exception("ControlOpsDaemon listener crashed name=%s", self.svc.service_name)
        finally:
            await self._close_pubsub(pubsub)

    @staticmethod
    async def _close_pubsub(pubsub) -> None:
        """Unsubscribe, then close *pubsub*; each step is independent best effort.

        The steps are separated so that a failing ``unsubscribe`` (e.g. a dead
        connection) still lets ``aclose`` release the connection.
        """
        try:
            await pubsub.unsubscribe(CONTROL_OPS_CHANNEL)
        except Exception:
            logger.debug("ControlOpsDaemon unsubscribe failed", exc_info=True)
        try:
            await pubsub.aclose()
        except Exception:
            logger.debug("ControlOpsDaemon pubsub close failed", exc_info=True)

    async def _reply(self, reply_channel: Optional[str], status: str) -> None:
        """Push one ack line onto *reply_channel* for the publisher (best effort).

        The line is ``<service>:<first 8 chars of instance id>@<host>: <status>``,
        and the reply list is given a 30s expiry. A missing channel or Redis
        client makes this a no-op, and push failures are logged at debug level.
        """
        if not reply_channel or self.svc.redis is None:
            return
        host = socket.gethostname()
        line = f"{self.svc.service_name}:{self.svc.instance_id[:8]}@{host}: {status}"
        try:
            await self.svc.redis.rpush(reply_channel, line)
            await self.svc.redis.expire(reply_channel, 30)
        except Exception:
            logger.debug("ControlOpsDaemon reply push failed", exc_info=True)

    async def _handle(self, payload: dict) -> None:
        """Act on one decoded control op if it targets this service.

        Payloads that are not dicts, or whose ``op`` is not ``"restart"`` /
        ``"pull"``, or that have no ``target``, are ignored. A ``nonce`` is
        honoured at most once per instance via a 30s SET NX marker; if the
        marker write fails, the op is processed anyway (fail-open).
        """
        # json.loads can yield lists/strings/numbers; only dicts are valid ops.
        if not isinstance(payload, dict):
            return
        op = payload.get("op")
        target = payload.get("target")
        if op not in ("restart", "pull") or not target:
            return
        if not self._op_applies(op, target):
            return

        redis = self.svc.redis
        nonce = payload.get("nonce")
        reply_channel = payload.get("reply_channel")

        # Idempotency: act on a given nonce at most once per instance.
        if nonce and redis is not None:
            seen_key = f"{_SEEN_PREFIX}{self.svc.instance_id}:{nonce}"
            try:
                fresh = await redis.set(seen_key, "1", nx=True, ex=30)
            except Exception:
                fresh = True
            if not fresh:
                return

        if op == "pull":
            result = await run_git_pull(
                self._cfg_val("bot_repo_path", "") or None,
                redis=redis,
                dedupe=True,
                lock_ttl=int(self._cfg_val("control_pull_lock_ttl", 90)),
                instance_id=self.svc.instance_id,
            )
            await self._reply(reply_channel, result.summary)
            return

        # op == "restart"
        if target == "proxy":
            proxy_unit = self._cfg_val("proxy_service_name", "gemini-cli-proxy")
            _ok, summary = await systemctl_restart(proxy_unit)
            await self._reply(reply_channel, summary)
            return

        # Self-restart this service. Gateway goes last (longer grace) so it does
        # not sever its own Discord connection before the ACK is delivered.
        unit = unit_name_for(self._cfg, self.svc.service_name)
        if self.svc.service_name == "gateway":
            grace = float(self._cfg_val("control_gateway_restart_grace", 8.0))
        else:
            grace = float(self._cfg_val("control_service_restart_grace", 2.0))
        await self._reply(reply_channel, f"restarting `{unit}` in {grace:.0f}s")
        task = asyncio.create_task(
            _deferred_systemctl_restart(unit, delay=grace),
            name=f"control-ops-restart-{self.svc.service_name}",
        )
        # Keep the task alive until it completes (see __init__).
        self._restart_tasks.add(task)
        task.add_done_callback(self._restart_tasks.discard)
# ────────────────────────────────────────────── # Gateway-side publisher # ────────────────────────────────────────────── async def _count_live_instances(redis, target: str) -> int: """Best-effort count of live registry instances for *target* ("all" → fleet).""" if redis is None: return 0 try: if target == "all": keys = await redis.keys(f"{_REGISTRY_PREFIX}*") else: keys = await redis.keys(f"{_REGISTRY_PREFIX}{target}:*") except Exception: return 0 # For "all" only count our service tiers (exclude any stray keys). if target == "all": n = 0 for k in keys or []: ks = _decode(k) tier = ks[len(_REGISTRY_PREFIX):].split(":", 1)[0] if tier in SERVICE_TIERS: n += 1 return n return len(keys or []) def _target_label(op: str, target: str, expected: int) -> str: if op == "pull": return "every host" if target == "all": return f"all services (~{expected})" if target == "proxy": return "the proxy" plural = "s" if expected != 1 else "" return f"{target}{'' if expected <= 1 else f' ×{expected}'} replica{plural}"
async def dispatch_control_op(
    redis,
    config,
    *,
    op: str,
    target: str,
    requested_by: str,
    send,
    trace_id: str = "",
) -> str:
    """Publish a control op, aggregate per-service acks, and report via *send*.

    Sends an immediate ACK and broadcasts the op on
    :data:`CONTROL_OPS_CHANNEL`. It then collects the ack lines that daemons
    push onto a per-request reply list for about ``control_reply_timeout``
    seconds (default 3.0), and finally sends a roster string.

    Args:
        redis: Async Redis client, used to count instances, publish the op,
            and ``BLPOP`` replies.
        config: Config-like object; only ``control_reply_timeout`` is read.
        op: Expected to be ``"restart"`` or ``"pull"`` (daemons ignore
            anything else).
        target: A tier name, ``"all"``, or ``"proxy"``.
        requested_by: Identity of the requester, forwarded in the payload.
        send: Async callable that takes one message string.
        trace_id: Optional trace id; a random one is generated when empty.

    Returns:
        str: The final message passed to *send*. This is the ack roster, the
        "no instances acknowledged" warning, or the broadcast-failure error.
    """
    nonce = uuid.uuid4().hex
    reply_channel = f"{_REPLY_PREFIX}{uuid.uuid4().hex}"
    verb = "Pulling latest code on" if op == "pull" else "Restarting"
    expected = await _count_live_instances(redis, target)
    await send(f"🔄 {verb} {_target_label(op, target, expected)}…")

    payload = {
        "op": op,
        "target": target,
        "requested_by": requested_by,
        "trace_id": trace_id or uuid.uuid4().hex,
        "nonce": nonce,
        "reply_channel": reply_channel,
    }
    try:
        await redis.publish(CONTROL_OPS_CHANNEL, json.dumps(payload))
    except Exception as e:
        msg = f"❌ Failed to broadcast `{op}`/`{target}`: `{e}`"
        await send(msg)
        return msg

    # Aggregate acks until the timeout budget is exhausted. Each BLPOP blocks
    # for at most 1s (an integer timeout, accepted by every Redis version), so
    # the window can overrun the budget by up to about a second.
    timeout = float(getattr(config, "control_reply_timeout", 3.0))
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    replies: list[str] = []
    while loop.time() < deadline:
        try:
            item = await redis.blpop(reply_channel, timeout=1)
        except Exception:
            logger.debug("dispatch_control_op reply BLPOP failed", exc_info=True)
            break
        if not item:
            continue
        # item is (key, value); value may be bytes
        replies.append(_decode(item[1]))

    try:
        await redis.delete(reply_channel)
    except Exception:
        # Best effort: daemons also set a 30s expiry on the reply list.
        logger.debug("dispatch_control_op reply cleanup failed", exc_info=True)

    if replies:
        body = "\n".join(f"• {r}" for r in replies)
        roster = f"✅ `{op}`/`{target}` acknowledged by {len(replies)} instance(s):\n{body}"
    else:
        roster = (
            f"⚠️ `{op}`/`{target}` broadcast, but no instances acknowledged within "
            f"{timeout:.0f}s (they may restart anyway, or the registry is empty)."
        )
    await send(roster)
    return roster
async def format_service_roster(redis) -> str:
    """Build the ``!services`` reply from the live service registry.

    Registry entries are grouped by tier (tiers and instances both sorted).
    Each instance is shown as the first 8 characters of its id. When the
    stored JSON provides them, ``status`` and ``host`` are appended in
    parentheses as ``status@host``. Errors reading an individual entry are
    ignored; failing to list the registry returns an error line instead.
    """
    if redis is None:
        return "⚠️ Redis unavailable — cannot list services."
    try:
        keys = await redis.keys(f"{_REGISTRY_PREFIX}*")
    except Exception as e:
        return f"❌ Failed to read service registry: `{e}`"
    if not keys:
        return "No live Stargazer services found in the registry."

    groups: dict[str, list[str]] = {}
    cut = len(_REGISTRY_PREFIX)
    for key in keys:
        tier, _, instance = _decode(key)[cut:].partition(":")
        detail = ""
        try:
            raw = await redis.get(key)
            if raw:
                meta = json.loads(_decode(raw))
                detail = meta.get("status", "")
                host = meta.get("host")
                if host:
                    detail = f"{detail}@{host}" if detail else host
        except Exception:
            # Metadata is decorative; a missing or odd entry still gets listed.
            pass
        label = instance[:8]
        if detail:
            label += f" ({detail})"
        groups.setdefault(tier, []).append(label)

    out = ["**🛰️ Live Stargazer services**", ""]
    out.extend(
        f"**{tier}** ({len(groups[tier])}): {', '.join(sorted(groups[tier]))}"
        for tier in sorted(groups)
    )
    return "\n".join(out)