# Source code for classifiers.update_dangerous_command_embeddings

#!/usr/bin/env python3
"""Compute dangerous-command centroid embeddings and store in Redis + RediSearch.

Loads :file:`dangerous_command_index.json` (destructive ops + malware-execution
risk paraphrases), embeds example strings per category
via ``compute_tool_centroids_bulk`` (same as tool embeddings), writes one
``dangerous_cmd_emb:{category_id}`` HASH per category, and removes orphaned
keys no longer present in the JSON.

Run ``python init_redis_indexes.py`` (or bot startup ``ensure_indexes``) once
so ``idx:dangerous_cmds`` exists before first use.

When the message pipeline defers embedding (unaddressed messages with batch
queue), ``query_embedding`` is absent and the guard does not run — same as no
suffix.

Usage::

    python -m classifiers.update_dangerous_command_embeddings [--force-all]

Environment:
    REDIS_URL, OPENROUTER_API_KEY / config.yaml api_key (for embeddings)
"""

from __future__ import annotations

import argparse
import asyncio
import jsonutil as json
import logging
import os
import sys
from typing import Any

import redis.asyncio as aioredis

sys.path.insert(
    0,
    os.path.abspath(os.path.join(os.path.dirname(__file__), "..")),
)

from gemini_embed_pool import init_quota_tracking  # noqa: E402

from config import Config  # noqa: E402
from rag_system.openrouter_embeddings import OpenRouterEmbeddings  # noqa: E402
from classifiers.tool_embedding_batch import (  # noqa: E402
    compute_tool_centroids_bulk,
    normalize_synthetic_queries,
)
from classifiers.redis_vector_index import (  # noqa: E402
    delete_dangerous_cmd_embedding_hash,
    scan_dangerous_cmd_category_ids,
    store_dangerous_cmd_embedding_hash,
)

# Script-style logging: INFO and above, with timestamp, logger name and level.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Curated corpus of dangerous-command examples; it lives next to this module.
INDEX_PATH = os.path.join(os.path.dirname(__file__), "dangerous_command_index.json")


def load_index() -> dict[str, Any]:
    """Read and decode the dangerous-command example corpus.

    Opens the module-level ``INDEX_PATH`` (``dangerous_command_index.json``)
    as UTF-8 text and decodes it with this module's ``json`` (``jsonutil``).
    The caller, :func:`update_dangerous_command_embeddings`, expects a
    top-level ``categories`` mapping of category id to an object carrying
    ``label`` and ``examples``, plus an optional ``version``.

    This is a local filesystem read only; Redis and the network are not
    touched.

    Returns:
        dict[str, Any]: The decoded corpus.

    Raises:
        FileNotFoundError: If ``INDEX_PATH`` does not exist.
        ValueError: Presumably, for malformed JSON. This assumes ``jsonutil``
            raises a ``ValueError`` subclass the way the stdlib ``json``
            module does; confirm.
    """
    with open(INDEX_PATH, encoding="utf-8") as handle:
        corpus = json.load(handle)
    return corpus
def _collect_category_queries(
    categories: dict[str, Any],
) -> tuple[dict[str, list[str]], dict[str, str]]:
    """Normalize each corpus category's examples into embedding queries.

    Args:
        categories: The corpus ``categories`` mapping of category id to an
            object expected to carry ``label`` and ``examples``.

    Returns:
        tuple: ``(tool_queries, labels)``. ``tool_queries`` maps every
        category id that has at least one usable example to its list from
        ``normalize_synthetic_queries``, which is the shape
        ``compute_tool_centroids_bulk`` consumes. ``labels`` maps the same
        ids to ``str(label)``, falling back to the id itself when ``label``
        is missing.
    """
    tool_queries: dict[str, list[str]] = {}
    labels: dict[str, str] = {}
    for cat_id, blob in categories.items():
        if not isinstance(blob, dict):
            # Malformed entry. It is left out of the result, so any existing
            # Redis HASH for this id is pruned as orphaned. Say so loudly.
            logger.warning(
                "Category %s: expected an object, got %s; skipping",
                cat_id,
                type(blob).__name__,
            )
            continue
        label = str(blob.get("label", cat_id))
        qs = normalize_synthetic_queries(blob.get("examples") or [])
        if not qs:
            logger.warning("Category %s: no examples, skipping", cat_id)
            continue
        tool_queries[cat_id] = qs
        labels[cat_id] = label
    return tool_queries, labels


async def update_dangerous_command_embeddings(*, force_all: bool = False) -> bool:
    """Recompute every dangerous-command centroid and sync it to Redis.

    Loads the corpus via :func:`load_index` and normalizes each category's
    examples (see :func:`_collect_category_queries`). It then computes one
    centroid per category with ``compute_tool_centroids_bulk`` using an
    :class:`OpenRouterEmbeddings` client. Each centroid is written with its
    metadata (``category_id``, ``label``, ``query_count``, ``version``) via
    ``store_dangerous_cmd_embedding_hash``. Category ids found in Redis by
    ``scan_dangerous_cmd_category_ids`` that no longer have usable examples
    in the JSON are removed via ``delete_dangerous_cmd_embedding_hash``.

    Side effects: opens an async Redis connection (``Config.redis_url``,
    else ``REDIS_URL``, else ``redis://localhost:6379/0``) and always closes
    it, including when setup or embedding raises. Also calls
    ``init_quota_tracking`` on that connection, reads the JSON corpus and
    makes embedding API calls. Invoked by :func:`main`.

    Args:
        force_all: Kept for symmetry with the sibling refresh scripts. It has
            no effect on the work done, because every run already rewrites
            all categories from the JSON; only an informational line is
            logged.

    Returns:
        bool: ``True`` if at least one centroid was stored or at least one
        orphan was removed. ``False`` on a hard failure: the corpus has no
        ``categories`` object, there is no API key, or no category has
        usable examples. Also ``False`` when every centroid computation
        failed and there was nothing to prune.
    """
    cfg = Config.load()
    redis_url = cfg.redis_url or os.environ.get(
        "REDIS_URL",
        "redis://localhost:6379/0",
    )
    redis_client = aioredis.from_url(
        redis_url,
        decode_responses=True,
        **cfg.redis_connection_kwargs_for_url(redis_url),
    )
    try:
        # Inside the try so the connection is closed even if this raises.
        init_quota_tracking(redis_client)

        raw = load_index()
        # Guard the shape: a non-object corpus or non-object ``categories``
        # is treated like a missing one rather than crashing on .get/.items.
        categories = raw.get("categories") if isinstance(raw, dict) else None
        if not isinstance(categories, dict) or not categories:
            logger.error("No categories in %s", INDEX_PATH)
            return False

        api_key = cfg.api_key or os.environ.get("OPENROUTER_API_KEY", "")
        if not api_key:
            logger.error(
                "No API key for embeddings (config.api_key / OPENROUTER_API_KEY)"
            )
            return False
        emb_client = OpenRouterEmbeddings(
            api_key=api_key,
            model=cfg.embedding_model,
        )

        tool_queries, labels = _collect_category_queries(categories)
        if not tool_queries:
            logger.error("No usable categories with examples")
            return False

        # Snapshot Redis before writing so orphans are computed against the
        # pre-run state.
        existing = set(await scan_dangerous_cmd_category_ids(redis_client))
        registered = set(tool_queries)

        if force_all:
            logger.info("--force-all: refreshing all categories from JSON")

        centroids = await compute_tool_centroids_bulk(emb_client, tool_queries)
        version = raw.get("version", 1)

        n_ok = 0
        for cat_id, centroid in centroids.items():
            if centroid is None:
                logger.warning("Category %s: failed to compute centroid", cat_id)
                continue
            qs = tool_queries[cat_id]
            meta = {
                "category_id": cat_id,
                "label": labels[cat_id],
                "query_count": len(qs),
                "version": version,
            }
            await store_dangerous_cmd_embedding_hash(
                redis_client,
                cat_id,
                centroid,
                meta,
            )
            n_ok += 1
            logger.info(
                "Stored %s (%d examples)",
                cat_id,
                len(qs),
            )

        orphaned = existing - registered
        for oid in sorted(orphaned):
            logger.info("Removing orphaned dangerous_cmd embedding: %s", oid)
            await delete_dangerous_cmd_embedding_hash(redis_client, oid)

        logger.info(
            "Dangerous-command embeddings update complete: %d stored, %d orphaned removed",
            n_ok,
            len(orphaned),
        )
        return n_ok > 0 or bool(orphaned)
    finally:
        await redis_client.aclose()
async def main() -> None:
    """Parse the command line and run the dangerous-command embedding refresh.

    Exposes a single ``--force-all`` flag and awaits
    :func:`update_dangerous_command_embeddings` with its value. Every run
    already rewrites all categories, so the flag does not change the work
    done. The refresh opens and closes its own Redis connection and writes
    the ``dangerous_cmd_emb:{category_id}`` hashes. It does not create the
    ``idx:dangerous_cmds`` RediSearch index, which must already exist (see
    ``init_redis_indexes.py``).

    Used as the script entry point via ``asyncio.run(main())``.

    Raises:
        SystemExit: Always. Code ``0`` when the refresh reports success (at
            least one centroid stored or one orphan removed), ``1`` otherwise.
            ``argparse`` itself exits with code ``2`` on invalid arguments.
    """
    cli = argparse.ArgumentParser(
        description="Refresh dangerous-command centroid embeddings in Redis",
    )
    cli.add_argument(
        "--force-all",
        action="store_true",
        help="Same full refresh (default already rewrites all categories from JSON)",
    )
    options = cli.parse_args()

    succeeded = await update_dangerous_command_embeddings(
        force_all=options.force_all,
    )
    if succeeded:
        sys.exit(0)
    sys.exit(1)
if __name__ == "__main__":
    # Run the async CLI entry point on a fresh event loop.
    asyncio.run(main())