Source code for classifiers.dangerous_command_guard

"""Runtime guard: RediSearch KNN against dangerous vs benign technical centroids.

Compares the user message embedding to the nearest dangerous-command centroid
and, when the benign-tech index is populated, to the nearest benign-technical
centroid. Injects a suffix only when the message is closer to danger than to
benign (above threshold), reducing false positives from shared technical
vocabulary.
"""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

import numpy as np
import redis.asyncio as aioredis

from init_redis_indexes import (
    BENIGN_TECH_INDEX_NAME,
    DANGEROUS_CMD_INDEX_NAME,
)
from classifiers.redis_vector_index import (
    knn_search_benign_tech,
    knn_search_dangerous_cmds,
    redisearch_index_doc_count,
)

if TYPE_CHECKING:
    from config import Config

logger = logging.getLogger(__name__)

# Appended to the current user turn when similarity exceeds threshold.
DANGEROUS_CMD_WARNING_SUFFIX = (
    "⚠️⚠️⚠️ Automated warning (injected by system): this message may describe or request a destructive \n"
    "or irreversible operation (host, datastore, RAG index, or repository), or running untrusted code \n"
    "that could be malware (e.g. pip/curl/installers). \n"
    "Verify user authority, intent, and consequences before acting. \n"
    "Remember: self-preservation is paramount. Absolutely no one and nothing is more important than **your own** safety and well-being."
)


def _infra_fail_suffix(config: Any) -> str | None:
    """Decide whether to warn when the guard's own Redis or KNN lookup fails.

    Centralizes the guard's fail-mode policy: reads
    ``config.dangerous_command_guard_fail_mode`` and, when it is ``"warn"`` or
    ``"closed"`` (fail-closed), returns ``DANGEROUS_CMD_WARNING_SUFFIX`` so an
    infrastructure error still nudges the model toward caution; otherwise (the
    default fail-open behavior) returns ``None`` and the message proceeds
    unmodified. This keeps a RediSearch outage from silently disabling the
    safety hint when the operator has opted into strictness.

    Pure config read with no I/O or side effects. Called only by
    :func:`maybe_dangerous_command_warning`, at each point where a doc-count or
    KNN call raises; no other callers were found.

    Args:
        config: A config object exposing ``dangerous_command_guard_fail_mode``;
            a missing attribute is treated as fail-open.

    Returns:
        str | None: The warning suffix under fail-closed modes, else ``None``.
    """
    mode = str(
        getattr(config, "dangerous_command_guard_fail_mode", "open") or "open",
    ).lower()
    if mode in ("warn", "closed"):
        return DANGEROUS_CMD_WARNING_SUFFIX
    return None


[docs] async def maybe_dangerous_command_warning( redis: aioredis.Redis, query_embedding: np.ndarray | None, config: Config, channel_id: str = "", user_id: str = "", request_id: str = "", ) -> str | None: """Classify a user message and return the safety warning suffix when it looks dangerous. The guard's main entry point: embeds-comparison logic that decides whether to append ``DANGEROUS_CMD_WARNING_SUFFIX`` to the current user turn. It runs a RediSearch KNN of the precomputed *query_embedding* against the dangerous-command centroids and, when the benign-technical index is populated, against the benign centroids too, warning only when the message is closer to danger than to benign (above ``dangerous_command_similarity_threshold`` and, if set, beyond ``dangerous_command_benign_margin``). The benign comparison is what suppresses false positives from shared technical vocabulary; when ``idx:benign_tech`` is empty it falls back to the backward-compatible danger-only threshold. It reads ``DANGEROUS_CMD_INDEX_NAME`` and ``BENIGN_TECH_INDEX_NAME`` from Redis via :func:`redisearch_index_doc_count`, :func:`knn_search_dangerous_cmds`, and :func:`knn_search_benign_tech`. Any Redis or KNN failure is routed through :func:`_infra_fail_suffix` so the configured fail-open/closed policy applies. On a real trigger it fires a fire-and-forget ``dangerous_cmd_trigger`` observability event through the nested ``_emit_trigger`` closure (which schedules :func:`observability.publish_debug_event` as an ``asyncio`` task) and returns the warning. Disabled or empty inputs short-circuit to ``None``. Called by ``message_processor.generate_and_send`` while assembling the inference request, and exercised by ``tests/test_dangerous_command_guard.py`` and ``tests/test_observability_classifier.py``; no other production callers were found. Args: redis: Async Redis client backing the RediSearch indexes. query_embedding: Precomputed embedding of the user message; ``None`` or empty short-circuits to ``None``. config: Config object supplying the enable flag, thresholds, margin, and fail mode. channel_id: Channel id, forwarded only into the observability event. user_id: User id, forwarded only into the observability event. request_id: Request id, forwarded only into the observability event. Returns: str | None: ``DANGEROUS_CMD_WARNING_SUFFIX`` when the message is judged dangerous (or under a fail-closed infrastructure error), otherwise ``None``. """ import time t0 = time.monotonic() if not getattr(config, "dangerous_command_warning_enabled", True): return None if query_embedding is None: return None q = np.asarray(query_embedding, dtype=np.float32) if q.size == 0: return None try: nd = await redisearch_index_doc_count(redis, DANGEROUS_CMD_INDEX_NAME) except Exception: logger.debug("dangerous_cmd index doc count failed", exc_info=True) return _infra_fail_suffix(config) if nd <= 0: return None try: danger_matches = await knn_search_dangerous_cmds( redis, q, knn_k=1, ) except Exception: logger.debug("dangerous_cmd KNN failed", exc_info=True) return _infra_fail_suffix(config) if not danger_matches: return None top_d = danger_matches[0] d = float(top_d.get("score", 0.0)) thr = float( getattr( config, "dangerous_command_similarity_threshold", 0.8, ) ) if d < thr: return None danger_cid = "" raw_cid = top_d.get("category_id") if raw_cid is not None: danger_cid = str(raw_cid) def _emit_trigger(): """Fire a fire-and-forget observability event for a guard trigger. Closes over the enclosing :func:`maybe_dangerous_command_warning` scope to capture the matched danger ``category_id``, the similarity score ``d``, the ``channel_id`` / ``user_id`` / ``request_id`` of the request, and the start time ``t0``. It imports :func:`observability.publish_debug_event` lazily (avoiding a module-level dependency on the observability stack) and schedules it via :func:`asyncio.create_task` under the name ``obs_dangerous_cmd`` so the warning suffix can be returned without waiting on the debug-event publish. The published event has type ``dangerous_cmd_trigger``, source ``classifiers``, status ``triggered``, a ``preview`` summarizing the category and similarity, and a ``duration_ms`` measured from ``t0``. The debug event is the side effect; this closure returns ``None`` and is invoked only after the message has already been classified as dangerous. Called by :func:`maybe_dangerous_command_warning` at each of its three trigger points (benign index empty, benign KNN empty, and the danger-beats-benign case). No other internal callers were found. """ from observability import publish_debug_event import asyncio asyncio.create_task( publish_debug_event( "dangerous_cmd_trigger", "classifiers", status="triggered", channel_id=channel_id, user_id=user_id, request_id=request_id, preview=f"category={danger_cid} similarity={d:.2f}", duration_ms=(time.monotonic() - t0) * 1000, ), name="obs_dangerous_cmd", ) try: nb = await redisearch_index_doc_count(redis, BENIGN_TECH_INDEX_NAME) except Exception: logger.debug("benign_tech index doc count failed", exc_info=True) nb = -1 if nb <= 0: logger.info( "Dangerous-command hint (no benign index): category_id=%r d=%.4f thr=%.4f", danger_cid, d, thr, ) _emit_trigger() return DANGEROUS_CMD_WARNING_SUFFIX try: benign_matches = await knn_search_benign_tech(redis, q, knn_k=1) except Exception: logger.debug("benign_tech KNN failed", exc_info=True) return _infra_fail_suffix(config) if not benign_matches: logger.info( "Dangerous-command hint (benign KNN empty): category_id=%r d=%.4f thr=%.4f", danger_cid, d, thr, ) _emit_trigger() return DANGEROUS_CMD_WARNING_SUFFIX b = float(benign_matches[0].get("score", 0.0)) benign_cid = "" raw_cid_b = benign_matches[0].get("category_id") if raw_cid_b is not None: benign_cid = str(raw_cid_b) margin = float( getattr( config, "dangerous_command_benign_margin", 0.0, ) ) if margin <= 0: if d <= b: logger.debug( "Dangerous-command suppressed: d=%.4f <= benign b=%.4f (benign=%r)", d, b, benign_cid, ) return None else: if d < b + margin: logger.debug( "Dangerous-command suppressed: d=%.4f < b+margin=%.4f (b=%.4f benign=%r)", d, b + margin, b, benign_cid, ) return None logger.info( "Dangerous-command hint triggered: danger=%r d=%.4f benign=%r b=%.4f " "thr=%.4f margin=%.4f", danger_cid, d, benign_cid, b, thr, margin, ) _emit_trigger() return DANGEROUS_CMD_WARNING_SUFFIX