"""Fleet-wide control-ops: restart / git-pull across every Stargazer service.
The monolith is gone, so an admin ``!restart_*`` / ``!bot_pull`` command can no
longer just act on one process. Instead the **gateway** authorizes the command
and broadcasts a small JSON op over the Redis pub/sub channel
:data:`CONTROL_OPS_CHANNEL`. Every service runs a :class:`ControlOpsDaemon`
(wired into :class:`core.service_base.StargazerService`) that decides whether the
op applies to it and self-restarts / self-pulls accordingly.
Why pub/sub (not a stream): a restart op must fan out to **all live** instances
at once and must NOT be replayed by a service that was down (that would cause a
restart storm on recovery). Pub/sub's "live subscribers only, no persistence" is
exactly the right delivery guarantee — and it matches the existing
``sg:channel:cancel`` control channel.
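Ordering: the gateway holds the live Discord connection, so it restarts
**last**. This is timing-based, not enforced: the gateway's self-restart waits
``control_gateway_restart_grace`` (default 8s) versus
``control_service_restart_grace`` (default 2s) for the other tiers, which by
default outlasts the publisher's ``control_reply_timeout`` (default 3s) ACK
and reply-aggregation window.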
"""
from __future__ import annotations
import asyncio
import json
import logging
import socket
import uuid
from typing import Optional
from core.ops_exec import (
_deferred_systemctl_restart,
run_git_pull,
systemctl_restart,
)
from core.service_registry import SERVICE_TTL, heartbeat, register_service
logger = logging.getLogger("control_ops")

# Redis pub/sub channel that every service listens on for control ops.
CONTROL_OPS_CHANNEL = "sg:control:ops"
# Key prefix for the per-instance "nonce already handled" marker (SET NX, short TTL).
_SEEN_PREFIX = "sg:control:seen:"
# Key prefix for per-request reply lists: daemons RPUSH acks, the publisher BLPOPs them.
_REPLY_PREFIX = "sg:control:reply:"
# Service-registry key prefix; mirrors the keys core.service_registry uses.
_REGISTRY_PREFIX = "sg:registry:service:"

# Tiers a control op can target. "all" expands to exactly these tiers; the
# proxy is external and handled separately.
SERVICE_TIERS = ("gateway", "inference", "agents", "consolidation", "web", "tools")

# Authoritative command -> (op, target) table for cluster control commands.
# Keep in sync with the cluster-locus entries in
# message_processor.command_registry, which drive help/detection.
CONTROL_OPS_COMMANDS: dict[str, tuple[str, str]] = {
    # "!restart_all" followed by one "!restart_<tier>" per tier, in tier order.
    **{f"!restart_{tier}": ("restart", tier) for tier in ("all", *SERVICE_TIERS)},
    # Back-compat alias: "!bot_restart" restarts the inference tier.
    "!bot_restart": ("restart", "inference"),
    "!proxy_restart": ("restart", "proxy"),
    "!bot_pull": ("pull", "all"),
}
[docs]
def control_op_for(text: str) -> Optional[tuple[str, str]]:
    """Map a chat message to its cluster ``(op, target)`` pair, or ``None``.

    Only the first whitespace-separated word is looked up, case-insensitively,
    so ``"!RESTART_WEB please"`` resolves the same as ``"!restart_web"``.
    Empty, ``None`` or whitespace-only input yields ``None``.
    """
    words = text.lower().split() if text else []
    if not words:
        return None
    return CONTROL_OPS_COMMANDS.get(words[0])
[docs]
def is_control_ops_command(text: str) -> bool:
    """Return whether *text* begins with a recognised fleet-wide control command."""
    resolved = control_op_for(text)
    return resolved is not None
def _decode(raw) -> str:
return raw.decode() if isinstance(raw, (bytes, bytearray)) else str(raw)
[docs]
def unit_name_for(config, service_name: str) -> str:
    """Resolve the systemd unit name for *service_name*.

    Resolution order:

    1. an explicit entry in ``config.control_unit_names`` (mapping of
       service name -> unit name), then
    2. ``config.control_unit_prefix`` + *service_name*, with the prefix
       defaulting to ``"stargazer-"``.

    *config* may be ``None`` or lack either attribute. An attribute that is
    present but set to ``None`` is treated as unset, so a ``None`` prefix no
    longer yields a bogus ``"None<service>"`` unit. An empty-string prefix is
    honoured as-is and yields the bare service name.

    Args:
        config: Optional Config-like object.
        service_name: Service tier name, e.g. ``"web"``.

    Returns:
        str: The systemd unit name.
    """
    names = getattr(config, "control_unit_names", None) or {}
    if service_name in names:
        return names[service_name]
    prefix = getattr(config, "control_unit_prefix", None)
    if prefix is None:
        prefix = "stargazer-"
    return f"{prefix}{service_name}"
[docs]
def fleet_units(config=None) -> list[str]:
    """Resolve every :data:`SERVICE_TIERS` tier to its systemd unit, in order.

    This is the single source of truth for "all the Stargazer service units",
    meaning the live microservices that replaced the retired ``stargazer`` /
    ``stargazer-swarm`` monolith. Each tier is resolved through
    :func:`unit_name_for`, so the result honours a deployment's
    ``control_unit_prefix`` / ``control_unit_names`` overrides.

    With the defaults it yields ``stargazer-gateway``, ``stargazer-inference``,
    ``stargazer-agents``, ``stargazer-consolidation``, ``stargazer-web`` and
    ``stargazer-tools``.

    *config* may be ``None`` (for callers without a Config in hand), in which
    case the default prefix applies. The function is pure and does no I/O.

    Used by the admin journal tail (``Config.resolved_journal_units``), the
    ``read_service_logs`` tool, the log→RAG ingest task, and the
    ``bot_control`` restart tools, so they all target the same fleet.

    Args:
        config: Optional Config-like object carrying ``control_unit_prefix`` /
            ``control_unit_names``; ``None`` falls back to the default prefix.

    Returns:
        list[str]: Resolved systemd unit names, one per tier, in
        :data:`SERVICE_TIERS` order.
    """
    return [unit_name_for(config, tier) for tier in SERVICE_TIERS]
# ──────────────────────────────────────────────
# Per-service daemon
# ──────────────────────────────────────────────
[docs]
class ControlOpsDaemon:
    """Subscribe to :data:`CONTROL_OPS_CHANNEL` and act on ops aimed at this service.

    One instance runs in every service (started by ``StargazerService.boot``).
    It mirrors the worker cancellation daemon's pub/sub loop, but lives at the
    service-base level so every tier in :data:`SERVICE_TIERS` participates in
    fleet-wide restart / pull.

    The owning service object (*svc*) must expose ``service_name``,
    ``instance_id`` and ``redis``. A ``redis`` of ``None`` disables the daemon.
    ``cfg`` is optional and is read attribute-by-attribute with defaults.
    """

    def __init__(self, svc) -> None:
        self.svc = svc
        # Strong references to fire-and-forget tasks (deferred self-restarts).
        # The event loop keeps only weak references to tasks, so without this
        # set a pending restart task could be garbage-collected before it runs.
        self._background_tasks: set[asyncio.Task] = set()

    @property
    def _cfg(self):
        """The service's config object, or ``None`` if it has none."""
        return getattr(self.svc, "cfg", None)

    def _cfg_val(self, name: str, default):
        """Read config attribute *name*, falling back to *default*."""
        cfg = self._cfg
        if cfg is None:
            return default
        return getattr(cfg, name, default)

    def _op_applies(self, op: str, target: str) -> bool:
        """Return True if an op aimed at *target* should run on this service.

        *op* is currently unused; the target alone decides applicability.
        """
        name = self.svc.service_name
        if target == "proxy":
            # Restart the external proxy from exactly one tier (default: gateway).
            return name == self._cfg_val("control_proxy_handler_service", "gateway")
        if target == "all":
            return name in SERVICE_TIERS
        return target == name

    async def run(self) -> None:
        """Listen on :data:`CONTROL_OPS_CHANNEL` until cancelled or the listener fails.

        Returns immediately when the service has no Redis client. Each pub/sub
        message is JSON-decoded and passed to :meth:`_handle`. A failure while
        processing one message is logged and does not stop the loop.
        """
        redis = self.svc.redis
        if redis is None:
            return
        pubsub = redis.pubsub()
        try:
            await pubsub.subscribe(CONTROL_OPS_CHANNEL)
            logger.info(
                "ControlOpsDaemon subscribed name=%s id=%s",
                self.svc.service_name,
                self.svc.instance_id,
            )
            async for message in pubsub.listen():
                # Skip subscribe/unsubscribe confirmations; only real publishes matter.
                if message.get("type") != "message":
                    continue
                try:
                    payload = json.loads(_decode(message["data"]))
                    await self._handle(payload)
                except Exception:
                    logger.exception("ControlOpsDaemon failed to process op")
        except asyncio.CancelledError:
            # Cancellation is logged and not re-raised, so the task ends normally.
            logger.info("ControlOpsDaemon cancelled name=%s", self.svc.service_name)
        except Exception:
            logger.exception("ControlOpsDaemon listener crashed name=%s", self.svc.service_name)
        finally:
            await self._close_pubsub(pubsub)

    @staticmethod
    async def _close_pubsub(pubsub) -> None:
        """Best-effort unsubscribe and close; each step runs even if the other fails."""
        try:
            await pubsub.unsubscribe(CONTROL_OPS_CHANNEL)
        except Exception:
            logger.debug("ControlOpsDaemon unsubscribe failed", exc_info=True)
        try:
            await pubsub.aclose()
        except Exception:
            logger.debug("ControlOpsDaemon pubsub close failed", exc_info=True)

    async def _reply(self, reply_channel: Optional[str], status: str) -> None:
        """Push ``"<service>:<id[:8]>@<host>: <status>"`` onto *reply_channel*.

        This is a no-op without a reply channel or Redis client, and it is best
        effort: push errors are only logged at debug level. The list is given
        a 30s TTL.
        """
        if not reply_channel or self.svc.redis is None:
            return
        host = socket.gethostname()
        line = f"{self.svc.service_name}:{self.svc.instance_id[:8]}@{host}: {status}"
        try:
            await self.svc.redis.rpush(reply_channel, line)
            await self.svc.redis.expire(reply_channel, 30)
        except Exception:
            logger.debug("ControlOpsDaemon reply push failed", exc_info=True)

    async def _handle(self, payload: dict) -> None:
        """Execute one decoded control op if it targets this service.

        The steps are:

        1. Ignore non-dict payloads and unknown ``op`` values, then check
           targeting.
        2. Enforce per-instance idempotency on ``nonce`` (SET NX, 30s TTL).
        3. Perform a git pull, a proxy restart, or a deferred self-restart,
           replying on ``reply_channel`` when one is given.
        """
        if not isinstance(payload, dict):
            # Valid JSON that is not an object (list, string, number) is ignored.
            return
        op = payload.get("op")
        target = payload.get("target")
        if op not in ("restart", "pull") or not target:
            return
        if not self._op_applies(op, target):
            return

        redis = self.svc.redis
        nonce = payload.get("nonce")
        reply_channel = payload.get("reply_channel")

        # Idempotency: act on a given nonce at most once per instance.
        if nonce and redis is not None:
            seen_key = f"{_SEEN_PREFIX}{self.svc.instance_id}:{nonce}"
            try:
                fresh = await redis.set(seen_key, "1", nx=True, ex=30)
            except Exception:
                # Fail open: if the marker cannot be recorded, act anyway
                # rather than drop the op.
                fresh = True
            if not fresh:
                return

        if op == "pull":
            try:
                result = await run_git_pull(
                    self._cfg_val("bot_repo_path", "") or None,
                    redis=redis,
                    dedupe=True,
                    lock_ttl=int(self._cfg_val("control_pull_lock_ttl", 90)),
                    instance_id=self.svc.instance_id,
                )
            except Exception as exc:
                # Report the failure so the publisher's roster shows it instead
                # of silently counting this instance as unresponsive.
                logger.exception(
                    "ControlOpsDaemon git pull failed name=%s", self.svc.service_name
                )
                await self._reply(reply_channel, f"pull failed: {exc}")
                return
            await self._reply(reply_channel, result.summary)
            return

        # op == "restart"
        if target == "proxy":
            proxy_unit = self._cfg_val("proxy_service_name", "gemini-cli-proxy")
            _ok, summary = await systemctl_restart(proxy_unit)
            await self._reply(reply_channel, summary)
            return

        # Self-restart this service. Gateway goes last (longer grace) so it does
        # not sever its own Discord connection before the ACK is delivered.
        unit = unit_name_for(self._cfg, self.svc.service_name)
        if self.svc.service_name == "gateway":
            grace = float(self._cfg_val("control_gateway_restart_grace", 8.0))
        else:
            grace = float(self._cfg_val("control_service_restart_grace", 2.0))
        await self._reply(reply_channel, f"restarting `{unit}` in {grace:.0f}s")
        task = asyncio.create_task(
            _deferred_systemctl_restart(unit, delay=grace),
            name=f"control-ops-restart-{self.svc.service_name}",
        )
        # Keep the task alive until it finishes, then drop the reference.
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
# ──────────────────────────────────────────────
# Gateway-side publisher
# ──────────────────────────────────────────────
async def _count_live_instances(redis, target: str) -> int:
"""Best-effort count of live registry instances for *target* ("all" → fleet)."""
if redis is None:
return 0
try:
if target == "all":
keys = await redis.keys(f"{_REGISTRY_PREFIX}*")
else:
keys = await redis.keys(f"{_REGISTRY_PREFIX}{target}:*")
except Exception:
return 0
# For "all" only count our service tiers (exclude any stray keys).
if target == "all":
n = 0
for k in keys or []:
ks = _decode(k)
tier = ks[len(_REGISTRY_PREFIX):].split(":", 1)[0]
if tier in SERVICE_TIERS:
n += 1
return n
return len(keys or [])
def _target_label(op: str, target: str, expected: int) -> str:
if op == "pull":
return "every host"
if target == "all":
return f"all services (~{expected})"
if target == "proxy":
return "the proxy"
plural = "s" if expected != 1 else ""
return f"{target}{'' if expected <= 1 else f' ×{expected}'} replica{plural}"
[docs]
async def dispatch_control_op(
    redis,
    config,
    *,
    op: str,
    target: str,
    requested_by: str,
    send,
    trace_id: str = "",
) -> str:
    """Publish a control op, aggregate per-service acks, and report via *send*.

    The steps are:

    1. Estimate the live instances the op will reach (registry lookup) and
       send an immediate "🔄 …" ACK through *send*.
    2. Publish the op as JSON on :data:`CONTROL_OPS_CHANNEL`, tagged with a
       fresh ``nonce`` (used by daemons for per-instance idempotency) and a
       per-request ``reply_channel`` list key.
    3. Drain that reply list for ``config.control_reply_timeout`` seconds
       (default 3.0), then delete it.
    4. Send the roster of acks, or a warning if none arrived, and return it.

    Args:
        redis: Async Redis client used to publish and to read replies.
        config: Config-like object; only ``control_reply_timeout`` is read.
        op: The op name, e.g. ``"restart"`` or ``"pull"``.
        target: A tier name, ``"all"`` or ``"proxy"``.
        requested_by: Identity of the requester, forwarded in the payload.
        send: Awaitable callable that takes one message string.
        trace_id: Optional trace id; a random one is generated when empty.

    Returns:
        str: The last message sent, which is either the roster or the
        broadcast-failure message when publishing raised.
    """
    nonce = uuid.uuid4().hex
    reply_channel = f"{_REPLY_PREFIX}{uuid.uuid4().hex}"
    verb = "Pulling latest code on" if op == "pull" else "Restarting"
    expected = await _count_live_instances(redis, target)
    await send(f"🔄 {verb} {_target_label(op, target, expected)}…")

    payload = {
        "op": op,
        "target": target,
        "requested_by": requested_by,
        "trace_id": trace_id or uuid.uuid4().hex,
        "nonce": nonce,
        "reply_channel": reply_channel,
    }
    try:
        await redis.publish(CONTROL_OPS_CHANNEL, json.dumps(payload))
    except Exception as e:
        msg = f"❌ Failed to broadcast `{op}`/`{target}`: `{e}`"
        await send(msg)
        return msg

    timeout = float(getattr(config, "control_reply_timeout", 3.0))
    replies = await _collect_replies(redis, reply_channel, timeout)
    roster = _format_roster(op, target, replies, timeout)
    await send(roster)
    return roster


async def _collect_replies(redis, reply_channel: str, timeout: float) -> list[str]:
    """Drain *reply_channel* for about *timeout* seconds, then delete it (best effort).

    Each BLPOP blocks for at most 1s so the deadline is re-checked regularly.
    The total wait can therefore overshoot *timeout* by up to about 1s. A
    Redis error stops collection early, and replies gathered so far are kept.
    """
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    replies: list[str] = []
    while loop.time() < deadline:
        try:
            item = await redis.blpop(reply_channel, timeout=1)
        except Exception:
            logger.debug("dispatch_control_op: BLPOP failed; stopping collection", exc_info=True)
            break
        if item:
            # BLPOP yields (key, value); value may be bytes.
            replies.append(_decode(item[1]))
    try:
        await redis.delete(reply_channel)
    except Exception:
        logger.debug("dispatch_control_op: reply list cleanup failed", exc_info=True)
    return replies


def _format_roster(op: str, target: str, replies: list[str], timeout: float) -> str:
    """Render the final report: one bullet per ack, or a no-ack warning."""
    if not replies:
        return (
            f"⚠️ `{op}`/`{target}` broadcast, but no instances acknowledged within "
            f"{timeout:.0f}s (they may restart anyway, or the registry is empty)."
        )
    body = "\n".join(f"• {r}" for r in replies)
    return f"✅ `{op}`/`{target}` acknowledged by {len(replies)} instance(s):\n{body}"