#!/usr/bin/env python3
"""Compute dangerous-command centroid embeddings and store in Redis + RediSearch.
Loads :file:`dangerous_command_index.json` (destructive ops + malware-execution
risk paraphrases), embeds example strings per category
via ``compute_tool_centroids_bulk`` (same as tool embeddings), writes one
``dangerous_cmd_emb:{category_id}`` HASH per category, and removes orphaned
keys no longer present in the JSON.
Run ``python init_redis_indexes.py`` (or bot startup ``ensure_indexes``) once
so ``idx:dangerous_cmds`` exists before first use.
When the message pipeline defers embedding (unaddressed messages with batch
queue), ``query_embedding`` is absent and the guard does not run — same as no
suffix.
Usage::
python -m classifiers.update_dangerous_command_embeddings [--force-all]
Environment:
REDIS_URL, OPENROUTER_API_KEY / config.yaml api_key (for embeddings)
"""
from __future__ import annotations
import argparse
import asyncio
import jsonutil as json
import logging
import os
import sys
from typing import Any
import redis.asyncio as aioredis
# Prepend the parent directory of this file to the import path. The
# project-level imports below (``gemini_embed_pool``, ``config``,
# ``rag_system``, ``classifiers``) follow this statement, which is why they
# carry ``# noqa: E402``; presumably this lets them resolve even when the file
# is executed directly rather than via ``python -m`` from the project root.
sys.path.insert(
    0,
    os.path.abspath(os.path.join(os.path.dirname(__file__), "..")),
)
from gemini_embed_pool import init_quota_tracking # noqa: E402
from config import Config # noqa: E402
from rag_system.openrouter_embeddings import OpenRouterEmbeddings # noqa: E402
from classifiers.tool_embedding_batch import ( # noqa: E402
compute_tool_centroids_bulk,
normalize_synthetic_queries,
)
from classifiers.redis_vector_index import ( # noqa: E402
delete_dangerous_cmd_embedding_hash,
scan_dangerous_cmd_category_ids,
store_dangerous_cmd_embedding_hash,
)
# Configure the root logger at import time: INFO and above, with timestamp,
# logger name and level on every line.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# The curated dangerous-command corpus lives in the same directory as this
# module.
INDEX_PATH = os.path.join(
    os.path.dirname(__file__), "dangerous_command_index.json"
)
[docs]
def load_index() -> dict[str, Any]:
    """Read the dangerous-command corpus from ``INDEX_PATH`` and parse it.

    The file is opened as UTF-8 text and handed to ``json.load`` (the module
    imports ``jsonutil`` under the name ``json``). Nothing else is touched:
    no Redis, no network.

    The refresh routine in this module expects the parsed object to carry a
    top-level ``categories`` mapping (category id -> blob with ``label`` and
    ``examples``) and optionally a ``version`` value.

    Returns:
        The parsed corpus object.

    Raises:
        FileNotFoundError: If ``INDEX_PATH`` does not exist.
        ValueError: Presumably, on malformed JSON — this matches stdlib
            ``json`` (``JSONDecodeError`` subclasses ``ValueError``); confirm
            for ``jsonutil``.
    """
    with open(INDEX_PATH, encoding="utf-8") as handle:
        corpus = json.load(handle)
    return corpus
[docs]
def _collect_category_queries(
    categories: dict[str, Any],
) -> tuple[dict[str, list[str]], dict[str, tuple[str, list[str]]]]:
    """Normalize each category's example strings into embedding queries.

    Entries whose blob is not a dict, or whose ``examples`` normalize to an
    empty list, are skipped with a warning.

    Args:
        categories: The corpus ``categories`` mapping (category id -> blob
            with optional ``label`` and ``examples``).

    Returns:
        A ``(tool_queries, meta_by_cat)`` pair. ``tool_queries`` maps category
        id to its normalized query strings; this is the mapping passed to
        :func:`compute_tool_centroids_bulk`. ``meta_by_cat`` maps category id
        to ``(label, queries)``; the label falls back to the category id.
    """
    tool_queries: dict[str, list[str]] = {}
    meta_by_cat: dict[str, tuple[str, list[str]]] = {}
    for cat_id, blob in categories.items():
        if not isinstance(blob, dict):
            logger.warning(
                "Category %s: malformed entry (expected an object), skipping",
                cat_id,
            )
            continue
        label = str(blob.get("label", cat_id))
        qs = normalize_synthetic_queries(blob.get("examples") or [])
        if not qs:
            logger.warning("Category %s: no examples, skipping", cat_id)
            continue
        tool_queries[cat_id] = qs
        meta_by_cat[cat_id] = (label, qs)
    return tool_queries, meta_by_cat


async def update_dangerous_command_embeddings(*, force_all: bool = False) -> bool:
    """Recompute every dangerous-command centroid and sync it to Redis.

    Loads the corpus via :func:`load_index`. Each category's examples are
    normalized with
    :func:`classifiers.tool_embedding_batch.normalize_synthetic_queries`.
    One centroid per category is then computed through
    :func:`classifiers.tool_embedding_batch.compute_tool_centroids_bulk`,
    using an :class:`OpenRouterEmbeddings` client.

    Writing and pruning work as follows:

    - Each centroid that was computed successfully is written to a
      ``dangerous_cmd_emb:{category_id}`` HASH via
      :func:`classifiers.redis_vector_index.store_dangerous_cmd_embedding_hash`.
      The HASH also holds the category id, label, query count and corpus
      ``version`` (default ``1``).
    - Category ids found in Redis that have no usable entry in the JSON are
      deleted via
      :func:`classifiers.redis_vector_index.delete_dangerous_cmd_embedding_hash`.
      This covers ids removed from the file, malformed entries, and entries
      left without examples.
    - Categories whose centroid comes back ``None`` are neither rewritten nor
      deleted, so any previously stored vector for them stays in place.

    Side effects:

    - Opens an async Redis connection. The URL is taken from
      ``Config.redis_url``, then ``REDIS_URL``, then
      ``redis://localhost:6379/0``.
    - Closes that connection in ``finally`` on every exit path. This includes
      a failure inside :func:`gemini_embed_pool.init_quota_tracking`.
    - Reads the JSON corpus.
    - Issues embedding requests through the OpenRouter client.

    Args:
        force_all: Accepted for CLI parity. It is effectively a no-op, because
            every run already rewrites all categories from the JSON; it only
            emits an informational log line.

    Returns:
        ``True`` if at least one centroid was stored or one orphan removed.
        ``False`` in any of these cases: the corpus has no ``categories``
        mapping, or a malformed one; no API key is configured; no category
        has usable examples; or every centroid failed and there was nothing
        to prune.

    Raises:
        Exceptions from :func:`load_index`, the embedding client or the
        Redis helpers propagate unchanged (e.g. ``FileNotFoundError`` when
        the corpus file is missing). The connection is still closed.
    """
    cfg = Config.load()
    redis_url = cfg.redis_url or os.environ.get(
        "REDIS_URL",
        "redis://localhost:6379/0",
    )
    redis_client = aioredis.from_url(
        redis_url,
        decode_responses=True,
        **cfg.redis_connection_kwargs_for_url(redis_url),
    )
    try:
        # Inside the try so the client is still closed if quota setup raises.
        init_quota_tracking(redis_client)

        raw = load_index()
        # Guard against a corpus whose top level is not an object (e.g. a
        # JSON array), which would otherwise crash on ``.get``.
        categories = raw.get("categories") if isinstance(raw, dict) else None
        if not categories:
            logger.error("No categories in %s", INDEX_PATH)
            return False
        if not isinstance(categories, dict):
            logger.error(
                "'categories' in %s must map category ids to entries",
                INDEX_PATH,
            )
            return False

        api_key = cfg.api_key or os.environ.get("OPENROUTER_API_KEY", "")
        if not api_key:
            logger.error(
                "No API key for embeddings (config.api_key / OPENROUTER_API_KEY)"
            )
            return False
        emb_client = OpenRouterEmbeddings(
            api_key=api_key,
            model=cfg.embedding_model,
        )

        tool_queries, meta_by_cat = _collect_category_queries(categories)
        if not tool_queries:
            logger.error("No usable categories with examples")
            return False

        # Snapshot what is stored now so stale ids can be pruned afterwards.
        existing = set(await scan_dangerous_cmd_category_ids(redis_client))
        registered = set(tool_queries)
        if force_all:
            logger.info("--force-all: refreshing all categories from JSON")

        centroids = await compute_tool_centroids_bulk(emb_client, tool_queries)
        version = raw.get("version", 1)
        n_ok = 0
        for cat_id, centroid in centroids.items():
            if centroid is None:
                logger.warning("Category %s: failed to compute centroid", cat_id)
                continue
            label, qs = meta_by_cat[cat_id]
            meta = {
                "category_id": cat_id,
                "label": label,
                "query_count": len(qs),
                "version": version,
            }
            await store_dangerous_cmd_embedding_hash(
                redis_client,
                cat_id,
                centroid,
                meta,
            )
            n_ok += 1
            logger.info(
                "Stored %s (%d examples)",
                cat_id,
                len(qs),
            )

        # Ids in Redis without a usable JSON entry (removed, malformed, or
        # example-less) are pruned. Failed centroids are still "registered",
        # so their previous vectors are kept.
        orphaned = existing - registered
        for oid in sorted(orphaned):
            logger.info("Removing orphaned dangerous_cmd embedding: %s", oid)
            await delete_dangerous_cmd_embedding_hash(redis_client, oid)

        logger.info(
            "Dangerous-command embeddings update complete: %d stored, %d orphaned removed",
            n_ok,
            len(orphaned),
        )
        return n_ok > 0 or len(orphaned) > 0
    finally:
        await redis_client.aclose()
[docs]
async def main() -> None:
    """CLI entry point: read ``--force-all`` and run the embedding refresh.

    The parsed flag is forwarded to :func:`update_dangerous_command_embeddings`.
    Its boolean result is turned into the process exit status.

    Raises:
        SystemExit: Always. The code is ``0`` when the refresh reported
            success and ``1`` otherwise. ``argparse`` itself exits with ``2``
            on invalid arguments and with ``0`` for ``--help``.
    """
    cli = argparse.ArgumentParser(
        description="Refresh dangerous-command centroid embeddings in Redis",
    )
    cli.add_argument(
        "--force-all",
        action="store_true",
        help="Same full refresh (default already rewrites all categories from JSON)",
    )
    options = cli.parse_args()
    succeeded = await update_dangerous_command_embeddings(
        force_all=options.force_all,
    )
    status = 0 if succeeded else 1
    sys.exit(status)
if __name__ == "__main__":
    # Script entry: run the async CLI on a fresh event loop. main() always
    # ends by raising SystemExit with the refresh result code.
    asyncio.run(main())