Source code for classifiers.vector_classifier

"""Vector-based classifier for tool selection.

Lightweight semantic vector classifier that replaces sending all
tools to the LLM with deterministic vector retrieval.  Pre-computed
centroid embeddings are stored in Redis (legacy monolithic hashes and/or
per-tool ``tool_emb:*`` / per-skill ``skill_emb:*`` HASH documents indexed
by RediSearch).  At query time, RediSearch KNN is used when
``idx:tool_embeddings`` / ``idx:skill_embeddings`` have documents;
otherwise embeddings are loaded and scored in-process (``cosine_batch``).
"""

from __future__ import annotations

import asyncio
import functools
import jsonutil as json
import logging
import os
import re
from typing import Any, Iterable, Mapping

import numpy as np
import redis.asyncio as aioredis

from rag_system.openrouter_embeddings import OpenRouterEmbeddings
from utils.cosine import cosine_batch
from classifiers.tool_embedding_batch import compute_tool_centroids_bulk
from classifiers.tool_prefix_groups import TOOL_NAMED_GROUPS, TOOL_PREFIX_GROUPS
from classifiers.redis_vector_index import (
    knn_search_skills,
    knn_search_tools,
    redisearch_index_doc_count,
    scan_tool_names,
)
from init_redis_indexes import SKILL_INDEX_NAME, TOOL_INDEX_NAME
from observability import publish_classifier_event


# Strong references to in-flight observability publishes.  The event loop only
# keeps weak references to tasks, so a fire-and-forget task that nothing else
# refers to can be garbage-collected before it finishes.
_PENDING_OBSERVABILITY_TASKS: set[asyncio.Task[Any]] = set()


def _emit_classifier_observability(
    result: dict[str, Any],
    matches: list[dict[str, Any]] | None,
    *,
    phase: str,
    extra: Mapping[str, Any] | None,
) -> None:
    """Fire-and-forget publish of a classifier decision for observability.

    Returns immediately when ``extra`` is falsy, so callers can pass an
    optional observability mapping straight through without guarding.
    Otherwise the first 12 entries of ``matches`` are reduced to
    ``name`` / ``score`` pairs and
    :func:`observability.publish_classifier_event` is scheduled with
    :func:`asyncio.create_task`; this function does not await the publish.

    The scheduled task is held in a module-level set until it completes so
    that it cannot be garbage-collected while still running.

    Args:
        result (dict[str, Any]): Classifier result; ``strategy`` and
            ``tools`` are read from it for the event payload.
        matches (list[dict[str, Any]] | None): Ranked vector matches, or
            ``None``.  Only the first 12 are reported.
        phase (str): Pipeline phase label attached to the event.
        extra (Mapping[str, Any] | None): Context carrying ``request_id``,
            ``channel_id``, ``user_id`` and ``platform``; missing keys are
            reported as empty strings.

    Returns:
        None: The event is published on a background task.

    Raises:
        RuntimeError: Propagated from :func:`asyncio.create_task` when no
            event loop is running (only reachable when ``extra`` is truthy).
    """
    if not extra:
        return
    top: list[dict[str, Any]] = [
        {
            "name": m.get("name", ""),
            "score": float(m.get("score", 0.0)),
        }
        for m in (matches or [])[:12]
    ]
    task = asyncio.create_task(
        publish_classifier_event(
            request_id=str(extra.get("request_id", "")),
            channel_id=str(extra.get("channel_id", "")),
            user_id=str(extra.get("user_id", "")),
            platform=str(extra.get("platform", "")),
            strategy=str(result.get("strategy", "")),
            tools=list(result.get("tools") or []),
            top_matches=top,
            phase=phase,
        ),
    )
    # Keep the task alive until it finishes, then drop the reference.
    _PENDING_OBSERVABILITY_TASKS.add(task)
    task.add_done_callback(_PENDING_OBSERVABILITY_TASKS.discard)


# Module-level logger shared by every function and method in this file.
logger = logging.getLogger(__name__)

# Legacy monolithic Redis hashes for tools: field name -> JSON-encoded
# centroid vector / JSON-encoded metadata.  Read with HGETALL by
# ``VectorClassifier._load_tool_embeddings``.
TOOL_EMBEDDINGS_HASH_KEY = "stargazer:tool_embeddings"
TOOL_METADATA_HASH_KEY = "stargazer:tool_metadata"

# Same layout for skills; read by ``VectorClassifier._load_skill_embeddings``.
SKILL_EMBEDDINGS_HASH_KEY = "stargazer:skill_embeddings"
SKILL_METADATA_HASH_KEY = "stargazer:skill_metadata"

# Defaults for ``VectorClassifier.__init__`` (``similarity_threshold`` and
# ``top_k``).  The two ``SKILL`` values are not referenced in this part of the
# file — presumably the defaults for skill classification; confirm at the
# call site.
DEFAULT_SIMILARITY_THRESHOLD = 0.30
DEFAULT_SKILL_SIMILARITY_THRESHOLD = 0.12
DEFAULT_SKILL_TOP_K = 12
DEFAULT_TOP_K = 15

# NOTE(review): not referenced in the visible part of this module — confirm
# where (or whether) this threshold is still used.
TOOL_EXPANSION_THRESHOLD = 0.85

# Defaults for the keyword-only threshold arguments of
# ``VectorClassifier.__init__``.
DEFAULT_STRATEGY_FORCE_THRESHOLD = 0.80
DEFAULT_STRATEGY_OPTIONAL_THRESHOLD = 0.30
DEFAULT_GROUP_EXPANSION_THRESHOLD = 0.55
# Browser automation tools are noisy in retrieval; require a stronger match.
DEFAULT_BROWSER_TOOL_SIMILARITY_THRESHOLD = 0.60
# Tool names starting with this prefix get the stricter browser threshold.
BROWSER_TOOL_NAME_PREFIX = "browser_"

# Lexical gates applied on top of vector retrieval.  Each entry is
# ``(tool-name prefix, keywords)``: a vector match whose name starts with the
# prefix is dropped unless *text* contains at least one of the keywords
# (case-insensitive substring).  Entries are checked in order and only the
# first matching prefix is consulted.  Explicit tool names in the user
# message are always kept (tokens like ``gitea_push`` contain the keyword).
_TOOL_PREFIX_LEXICAL_GATES: tuple[tuple[str, tuple[str, ...]], ...] = (
    (BROWSER_TOOL_NAME_PREFIX, ("browser",)),
    ("sporestack_", ("sporestack",)),
    ("gitea_", ("gitea",)),
    ("minecraft_", ("minecraft",)),
    ("tailscale_", ("tailscale",)),
    ("terraform_", ("terraform",)),
    ("redis_", ("redis",)),
    # Tools use k8s_* prefix; allow either spelling in natural language.
    ("k8s_", ("kubernetes", "k8s")),
)


def _filter_gated_prefix_vector_matches(
    matches: list[dict[str, Any]],
    text: str,
    *,
    explicit_names: set[str] | frozenset[str] | None = None,
) -> list[dict[str, Any]]:
    """Remove gated-prefix vector hits whose product keyword is absent from *text*.

    Every candidate is looked up against :data:`_TOOL_PREFIX_LEXICAL_GATES`.
    The first gate whose prefix the tool name starts with decides its fate:
    the match survives only if one of that gate's keywords appears in the
    lower-cased *text*.  Names with no matching gate always survive, and
    names present in ``explicit_names`` skip the gate check altogether.

    Purely in-memory; when anything is dropped, one ``INFO`` line reports
    the number of dropped matches per prefix.

    Args:
        matches (list[dict[str, Any]]): Ranked vector matches; each is read
            for its ``name`` key.  An empty list is returned as-is.
        text (str): Message or response text searched for gate keywords.
        explicit_names (set[str] | frozenset[str] | None): Tool names that
            must never be filtered out.

    Returns:
        list[dict[str, Any]]: Surviving matches, in their original order.
    """
    if not matches:
        return matches

    exempt = explicit_names or frozenset()
    haystack = (text or "").lower()
    kept: list[dict[str, Any]] = []
    drop_counts: dict[str, int] = {}

    for match in matches:
        tool_name = str(match.get("name", ""))
        if tool_name not in exempt:
            # Only the first gate whose prefix matches is consulted.
            gate = next(
                (g for g in _TOOL_PREFIX_LEXICAL_GATES if tool_name.startswith(g[0])),
                None,
            )
            if gate is not None:
                prefix, keywords = gate
                if all(kw not in haystack for kw in keywords):
                    drop_counts[prefix] = drop_counts.get(prefix, 0) + 1
                    continue
        kept.append(match)

    if drop_counts:
        logger.info(
            "Dropped gated-prefix vector match(es) (keyword not in message): %s",
            drop_counts,
        )
    return kept


# Maximal runs of ASCII letters, digits and underscores — the characters
# typical ``snake_case`` tool names are made of.  Deliberately narrower than
# the Unicode-aware ``\\w``: a non-ASCII letter ends the run, so a token
# consists only of these ASCII characters.
_EXPLICIT_TOOL_TOKEN_RE = re.compile(r"[A-Za-z0-9_]+")

# Inline code / backtick-wrapped spans (one line); inner text is stripped
# and must match a registered name exactly.
_EXPLICIT_BACKTICK_RE = re.compile(r"`([^`\n]+)`")

# Run explicit-name scan off the event loop for huge pastes (chars).
# NOTE(review): not referenced in the visible part of this module — the
# consumer is presumably further down; confirm.
_EXPLICIT_SCAN_TO_THREAD_CHARS = 96_000


# Typographic apostrophes are folded to ASCII so that "don’t have" is matched
# by the same patterns as "don't have".
_TYPOGRAPHIC_APOSTROPHES = str.maketrans(
    {"\u2018": "'", "\u2019": "'", "\u02bc": "'"}
)

# Phrases suggesting the assistant is asking for a tool it does not have.
# Matched against lower-cased, apostrophe-normalised text; compiled once.
_TOOL_REQUEST_PATTERNS: tuple[re.Pattern[str], ...] = tuple(
    re.compile(source)
    for source in (
        r"\bi (?:still )?(?:need|lack|require|am missing"
        r"|don't have|do not have)\b",
        r"\bwithout (?:the |access to )?(?:\w+ )?tool",
        r"\bunable to (?:use|access|call|execute)\b",
        r"\bmy tool ?belt (?:doesn't|does not)"
        r" (?:have|contain|include)\b",
        r"\bgive me (?:the )?\w+ tool",
        r"\bi need [a-z_]+ tool",
        r"\bmissing (?:the )?(?:ability|capability"
        r"|tool|function)",
        r"\bdon't have (?:the |access to )?"
        r"(?:\w+ )?(?:tool|function|command)",
    )
)


def detect_tool_request_keywords(response_text: str) -> bool:
    """Return *True* when the bot seems to request missing tools.

    This lightweight regex check gates the heavier embedding-based tool
    expansion to avoid false positives on legitimate no-tool responses.

    The text is lower-cased and typographic apostrophes (``’``, ``‘``,
    ``ʼ``) are folded to ``'`` before matching, so contractions are detected
    whichever apostrophe the text uses.  The first matching pattern is
    logged at ``DEBUG`` level.

    Args:
        response_text (str): The assistant response to inspect; an empty or
            falsy value yields ``False``.

    Returns:
        bool: ``True`` if any tool-request phrase is found, else ``False``.
    """
    if not response_text:
        return False

    text_lower = response_text.lower().translate(_TYPOGRAPHIC_APOSTROPHES)
    for pattern in _TOOL_REQUEST_PATTERNS:
        if pattern.search(text_lower):
            logger.debug(
                "Tool request keyword detected: %r",
                pattern.pattern,
            )
            return True
    return False
@functools.lru_cache(maxsize=32)
def _explicit_tool_lookup(names_key: tuple[str, ...]) -> frozenset[str]:
    """Return the registered tool names as a memoized ``frozenset``.

    The tuple argument doubles as the cache key: passing the same tuple of
    names again returns the previously built set instead of constructing a
    new one.  The cache keeps at most 32 distinct name tuples.  No I/O.

    Used by :func:`find_tools_explicitly_named`, which passes a sorted tuple
    of the de-duplicated tool names.

    Args:
        names_key (tuple[str, ...]): Tool names; must be hashable because it
            is used as the cache key.

    Returns:
        frozenset[str]: The same names as an immutable set for membership
        tests.
    """
    name_set: frozenset[str] = frozenset(names_key)
    return name_set
def find_tools_explicitly_named(
    message: str,
    valid_names: Iterable[str],
) -> list[str]:
    """Return the registered tool names that *message* spells out verbatim.

    Two kinds of mention are recognised:

    * a maximal run of ASCII letters, digits and underscores that equals a
      registered name (the usual ``snake_case`` case);
    * the stripped text between a pair of backticks on one line that equals
      a registered name **exactly** — this is how names containing hyphens
      or other punctuation can be matched.

    Args:
        message (str): Text to scan; a falsy value yields ``[]``.
        valid_names (Iterable[str]): Registered tool names.  String entries
            are stripped, other values are converted with ``str``; blank
            entries and duplicates are ignored.

    Returns:
        list[str]: Matching names ordered by first occurrence in *message*,
        each listed once.
    """
    if not message:
        return []

    registered: set[str] = set()
    for candidate in valid_names:
        cleaned = candidate.strip() if isinstance(candidate, str) else str(candidate)
        if cleaned:
            registered.add(cleaned)
    if not registered:
        return []

    lookup = _explicit_tool_lookup(tuple(sorted(registered)))

    # (position, name) pairs: bare-token hits first, then backtick hits, so
    # that the stable sort below keeps token hits ahead on equal positions.
    hits: list[tuple[int, str]] = [
        (found.start(), found.group(0))
        for found in _EXPLICIT_TOOL_TOKEN_RE.finditer(message)
        if found.group(0) in lookup
    ]
    for found in _EXPLICIT_BACKTICK_RE.finditer(message):
        quoted = found.group(1).strip()
        if quoted in lookup:
            hits.append((found.start(1), quoted))

    by_position = sorted(hits, key=lambda hit: hit[0])
    # dict.fromkeys keeps the first occurrence of each name, in order.
    ordered_hits = list(dict.fromkeys(name for _, name in by_position))

    if ordered_hits:
        logger.debug(
            "Explicit tool names in message: %s",
            ordered_hits,
        )
    return ordered_hits
class VectorClassifier:
    """Semantic vector-based classifier for tool selection.

    Parameters
    ----------
    redis_client:
        An async Redis connection (``redis.asyncio.Redis``).
    similarity_threshold:
        Minimum cosine similarity for a match.
    top_k:
        Maximum number of tools to return.
    api_key:
        OpenRouter API key.  Falls back to the ``OPENROUTER_API_KEY``
        env var.
    strategy_force_threshold, strategy_optional_threshold:
        Keyword-only score cutoffs stored on the instance for choosing the
        ``"force"`` / ``"optional"`` strategy.
    group_expansion_threshold:
        Keyword-only minimum score for a match to trigger prefix/named-group
        expansion.
    browser_tool_similarity_threshold:
        Keyword-only, stricter minimum score applied to ``browser_*`` tools.
    """
def __init__(
    self,
    redis_client: aioredis.Redis,
    similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
    top_k: int = DEFAULT_TOP_K,
    api_key: str | None = None,
    *,
    strategy_force_threshold: float = DEFAULT_STRATEGY_FORCE_THRESHOLD,
    strategy_optional_threshold: float = DEFAULT_STRATEGY_OPTIONAL_THRESHOLD,
    group_expansion_threshold: float = DEFAULT_GROUP_EXPANSION_THRESHOLD,
    browser_tool_similarity_threshold: float = (
        DEFAULT_BROWSER_TOOL_SIMILARITY_THRESHOLD
    ),
) -> None:
    """Record the Redis client and tuning knobs; perform no I/O.

    All caches start empty: the tool/skill embedding and metadata caches
    and their ``(N, D)`` matrices are ``None``, the ordered name/id lists
    are empty, and the cached RediSearch document counts are ``None``.
    They are filled lazily by the loader and probe methods.  The embedding
    client is not created here either; only the API key is resolved.  One
    ``INFO`` line summarising the configuration is logged.

    Args:
        redis_client (aioredis.Redis): Async Redis connection used for
            embedding reads and RediSearch queries.
        similarity_threshold (float): Minimum cosine similarity for a tool
            match to be kept.
        top_k (int): Maximum number of tools returned from retrieval.
        api_key (str | None): OpenRouter API key; when falsy, the
            ``OPENROUTER_API_KEY`` environment variable is used (empty
            string if unset).
        strategy_force_threshold (float): Score cutoff stored for the
            ``"force"`` strategy decision.
        strategy_optional_threshold (float): Score cutoff stored for the
            ``"optional"`` strategy decision.
        group_expansion_threshold (float): Minimum score for a tool to
            trigger prefix/named-group expansion.
        browser_tool_similarity_threshold (float): Stricter minimum score
            for ``browser_*`` tool matches.
    """
    # --- connection and credentials ---------------------------------
    self._redis = redis_client
    self._embedding_client: OpenRouterEmbeddings | None = None
    if api_key:
        self._api_key = api_key
    else:
        self._api_key = os.environ.get("OPENROUTER_API_KEY", "")

    # --- retrieval thresholds ---------------------------------------
    self.similarity_threshold = similarity_threshold
    self.top_k = top_k
    self.strategy_force_threshold = strategy_force_threshold
    self.strategy_optional_threshold = strategy_optional_threshold
    self.group_expansion_threshold = group_expansion_threshold
    self.browser_tool_similarity_threshold = browser_tool_similarity_threshold

    # --- lazily populated tool state --------------------------------
    self._tool_embeddings_cache: dict[str, np.ndarray] | None = None
    self._tool_metadata_cache: dict[str, dict[str, Any]] | None = None
    self._tool_names_list: list[str] = []
    self._tool_embeddings_matrix: np.ndarray | None = None

    # --- lazily populated skill state -------------------------------
    self._skill_embeddings_cache: dict[str, np.ndarray] | None = None
    self._skill_metadata_cache: dict[str, dict[str, Any]] | None = None
    self._skill_ids_list: list[str] = []
    self._skill_embeddings_matrix: np.ndarray | None = None

    # --- memoized RediSearch document counts ------------------------
    self._cached_tool_rs_docs: int | None = None
    self._cached_skill_rs_docs: int | None = None

    logger.info(
        "VectorClassifier initialized: "
        "threshold=%s, top_k=%s, group_expansion=%s, "
        "strategy_force=%s, strategy_optional=%s, browser_threshold=%s",
        similarity_threshold,
        top_k,
        group_expansion_threshold,
        strategy_force_threshold,
        strategy_optional_threshold,
        browser_tool_similarity_threshold,
    )
def _similarity_threshold_for_tool_name(self, name: str) -> float: """Minimum cosine score to keep *name* from vector retrieval. Resolves the per-tool acceptance threshold, applying the stricter ``browser_tool_similarity_threshold`` to noisy ``browser_*`` tools (which retrieve too eagerly) and the general ``similarity_threshold`` to everything else. Pure attribute read with no I/O. Called by :meth:`_find_matching_tools`, :meth:`_legacy_find_matching_tools_sync`, and :meth:`classify_response_for_missing_tools` in this module to filter and rank candidate matches. Args: name (str): The tool name whose threshold is being resolved. Returns: float: The minimum cosine similarity required to keep the tool. """ if name.startswith(BROWSER_TOOL_NAME_PREFIX): return self.browser_tool_similarity_threshold return self.similarity_threshold def _group_expand_min_score_for_tool(self, name: str) -> float: """Min score to expand a prefix/named group triggered by *name*. Resolves the score a matched tool must reach before its whole prefix/named group is pulled in. For ``browser_*`` tools it returns the larger of the general ``group_expansion_threshold`` and the stricter browser threshold, so a weak browser hit never floods the tool set with its siblings; other tools use ``group_expansion_threshold`` directly. Pure attribute read with no I/O. Called by :meth:`_expand_tool_prefixes` in this module while deciding which matches are strong enough to trigger group expansion. Args: name (str): The matched tool name driving the expansion decision. Returns: float: The minimum score required to expand this tool's group. 
""" if name.startswith(BROWSER_TOOL_NAME_PREFIX): return max( self.group_expansion_threshold, self.browser_tool_similarity_threshold, ) return self.group_expansion_threshold # -------------------------------------------------------------- # Embedding client # -------------------------------------------------------------- async def _get_embedding_client(self) -> OpenRouterEmbeddings: """Lazily build and cache the OpenRouter embedding client. Constructs an :class:`~rag_system.openrouter_embeddings.OpenRouterEmbeddings` on first use (using the API key resolved in ``__init__``) and reuses the same instance thereafter, so the underlying HTTP session is created only once per classifier. No network call happens here — the client only issues HTTP requests later when its ``embed_text`` method runs. Called by :meth:`_get_query_embedding` in this module; the cached client is torn down by :meth:`close`. Returns: OpenRouterEmbeddings: The shared embedding client for this classifier. """ if self._embedding_client is None: self._embedding_client = OpenRouterEmbeddings( api_key=self._api_key, ) return self._embedding_client # -------------------------------------------------------------- # RediSearch KNN (preferred) vs legacy monolithic hashes # -------------------------------------------------------------- async def _tool_redisearch_has_docs(self) -> bool: """Report whether the tool RediSearch index holds any documents. Decides which retrieval path the classifier takes: when the ``idx:tool_embeddings`` index has documents, RediSearch KNN is used; otherwise the legacy monolithic-hash batch cosine path runs. Calls :func:`classifiers.redis_vector_index.redisearch_index_doc_count` against Redis once and memoizes a non-negative count in ``_cached_tool_rs_docs`` so subsequent checks avoid the round trip. A negative count (index missing or error) is treated as not cached and is re-probed next time. 
Called by :meth:`_find_matching_tools`, :meth:`classify`, :meth:`classify_response_for_missing_tools`, and :meth:`_ensure_tool_names_for_expansion` in this module. Returns: bool: ``True`` if at least one tool document is indexed. """ if self._cached_tool_rs_docs is not None and self._cached_tool_rs_docs >= 0: return self._cached_tool_rs_docs > 0 n = await redisearch_index_doc_count(self._redis, TOOL_INDEX_NAME) if n >= 0: self._cached_tool_rs_docs = n return n > 0 async def _skill_redisearch_has_docs(self) -> bool: """Report whether the skill RediSearch index holds any documents. Skill-side analogue of :meth:`_tool_redisearch_has_docs`: when the ``idx:skill_embeddings`` index has documents the classifier uses RediSearch KNN for skills, otherwise it falls back to the legacy batch-cosine path. Queries Redis once via :func:`classifiers.redis_vector_index.redisearch_index_doc_count` and caches a non-negative count in ``_cached_skill_rs_docs``; a negative result is re-probed on the next call. Called by :meth:`_find_matching_skills` and :meth:`classify_skills` in this module. Returns: bool: ``True`` if at least one skill document is indexed. """ if self._cached_skill_rs_docs is not None and self._cached_skill_rs_docs >= 0: return self._cached_skill_rs_docs > 0 n = await redisearch_index_doc_count(self._redis, SKILL_INDEX_NAME) if n >= 0: self._cached_skill_rs_docs = n return n > 0 async def _ensure_tool_names_for_expansion(self) -> None: """Populate ``_tool_names_list`` for :meth:`_expand_tool_prefixes`. Group expansion needs the full universe of registered tool names, but the RediSearch KNN path only returns the top matches. 
This backfills ``_tool_names_list`` when it is empty: it scans every indexed tool name from Redis via :func:`classifiers.redis_vector_index.scan_tool_names` when the tool index has documents, and otherwise (or if the scan came back empty) falls back to :meth:`_load_tool_embeddings`, which loads the legacy monolithic hash and rebuilds the name list as a side effect. Returns once the list is non-empty and is a no-op on later calls. Called by :meth:`classify_response_for_missing_tools` in this module before it expands prefixes on the response path. """ if self._tool_names_list: return if await self._tool_redisearch_has_docs(): self._tool_names_list = await scan_tool_names(self._redis) if not self._tool_names_list: await self._load_tool_embeddings() # -------------------------------------------------------------- # Load / cache tool embeddings from Redis # -------------------------------------------------------------- async def _load_tool_embeddings( self, force_reload: bool = False, ) -> bool: """Load and cache the legacy monolithic tool-embedding hashes. Fallback loader for when RediSearch KNN is unavailable: reads the precomputed centroid vectors from the ``stargazer:tool_embeddings`` hash and per-tool metadata from ``stargazer:tool_metadata`` via Redis ``HGETALL``, decoding each JSON-encoded vector into a ``float32`` numpy array. It populates ``_tool_embeddings_cache`` and ``_tool_metadata_cache``, builds the ordered ``_tool_names_list``, and stacks an ``(N, D)`` matrix into ``_tool_embeddings_matrix`` for batch cosine scoring. Already-loaded state is reused unless ``force_reload`` is set. Per-entry parse failures are logged and skipped; a missing hash or a top-level error is logged and yields ``False``. Logs an ``INFO`` count on success. Called by :meth:`classify`, :meth:`_find_matching_tools`, :meth:`classify_response_for_missing_tools`, and :meth:`_ensure_tool_names_for_expansion` in this module. 
Args: force_reload (bool): When ``True``, re-read from Redis even if a cache is already populated. Returns: bool: ``True`` when embeddings were loaded (or already cached), ``False`` on a missing hash or error. """ if self._tool_embeddings_cache is not None and not force_reload: return True try: embeddings_data: dict = await self._redis.hgetall( TOOL_EMBEDDINGS_HASH_KEY, ) if not embeddings_data: logger.warning( "No tool embeddings in Redis key: %s", TOOL_EMBEDDINGS_HASH_KEY, ) return False self._tool_embeddings_cache = {} for name, emb_json in embeddings_data.items(): try: if isinstance(name, bytes): name = name.decode("utf-8") if isinstance(emb_json, bytes): emb_json = emb_json.decode("utf-8") vec = np.array( json.loads(emb_json), dtype=np.float32, ) self._tool_embeddings_cache[name] = vec except Exception as exc: logger.warning( "Failed to parse embedding for " "tool %r: %s", name, exc, ) meta_data: dict = await self._redis.hgetall( TOOL_METADATA_HASH_KEY, ) self._tool_metadata_cache = {} for name, meta_json in meta_data.items(): try: if isinstance(name, bytes): name = name.decode("utf-8") if isinstance(meta_json, bytes): meta_json = meta_json.decode("utf-8") self._tool_metadata_cache[name] = json.loads(meta_json) except Exception as exc: logger.warning( "Failed to parse metadata for " "tool %r: %s", name, exc, ) # Build ordered names list and (N, D) matrix for batch cosine. self._tool_names_list = list(self._tool_embeddings_cache.keys()) if self._tool_names_list: self._tool_embeddings_matrix = np.stack( [self._tool_embeddings_cache[n] for n in self._tool_names_list], axis=0, ) else: self._tool_embeddings_matrix = None logger.info( "Loaded %d tool embeddings from Redis", len(self._tool_embeddings_cache), ) return True except Exception as exc: logger.error( "Failed to load tool embeddings: %s", exc, ) return False async def _load_skill_embeddings( self, force_reload: bool = False, ) -> bool: """Load skill centroid embeddings from Redis (separate from tools). 
Skill-side analogue of :meth:`_load_tool_embeddings` and the fallback for when skill RediSearch KNN is unavailable. Reads centroid vectors from the ``stargazer:skill_embeddings`` hash and metadata from ``stargazer:skill_metadata`` via Redis ``HGETALL``, decodes each JSON-encoded vector into a ``float32`` numpy array, and populates ``_skill_embeddings_cache``, ``_skill_metadata_cache``, the ordered ``_skill_ids_list``, and the ``(N, D)`` ``_skill_embeddings_matrix`` used for batch cosine. Already-loaded state is reused unless ``force_reload`` is set; an empty hash resets the caches to empty and returns ``False``. Per-entry parse failures are logged and skipped, and a top-level error yields ``False``. Logs an ``INFO`` count on success. Called by :meth:`_find_matching_skills` and :meth:`classify_skills` in this module. Args: force_reload (bool): When ``True``, re-read from Redis even if a cache is already populated. Returns: bool: ``True`` when skill embeddings were loaded (or already cached), ``False`` when none exist or an error occurred. 
""" if self._skill_embeddings_cache is not None and not force_reload: return True try: embeddings_data: dict = await self._redis.hgetall( SKILL_EMBEDDINGS_HASH_KEY, ) if not embeddings_data: self._skill_embeddings_cache = {} self._skill_metadata_cache = {} self._skill_ids_list = [] self._skill_embeddings_matrix = None return False self._skill_embeddings_cache = {} for sid, emb_json in embeddings_data.items(): try: if isinstance(sid, bytes): sid = sid.decode("utf-8") if isinstance(emb_json, bytes): emb_json = emb_json.decode("utf-8") vec = np.array( json.loads(emb_json), dtype=np.float32, ) self._skill_embeddings_cache[sid] = vec except Exception as exc: logger.warning( "Failed to parse skill embedding for %r: %s", sid, exc, ) meta_data: dict = await self._redis.hgetall( SKILL_METADATA_HASH_KEY, ) self._skill_metadata_cache = {} for sid, meta_json in meta_data.items(): try: if isinstance(sid, bytes): sid = sid.decode("utf-8") if isinstance(meta_json, bytes): meta_json = meta_json.decode("utf-8") self._skill_metadata_cache[sid] = json.loads(meta_json) except Exception as exc: logger.warning( "Failed to parse skill metadata for %r: %s", sid, exc, ) self._skill_ids_list = list( self._skill_embeddings_cache.keys(), ) if self._skill_ids_list: self._skill_embeddings_matrix = np.stack( [self._skill_embeddings_cache[s] for s in self._skill_ids_list], axis=0, ) else: self._skill_embeddings_matrix = None logger.info( "Loaded %d skill embeddings from Redis", len(self._skill_embeddings_cache), ) return True except Exception as exc: logger.error("Failed to load skill embeddings: %s", exc) return False # -------------------------------------------------------------- # Query embedding # -------------------------------------------------------------- async def _get_query_embedding( self, query: str, ) -> np.ndarray | None: """Embed *query* into a single dense vector via OpenRouter. 
Turns the user message (or assistant response) into the query vector that all downstream cosine and KNN scoring runs against. Obtains the lazily-built client from :meth:`_get_embedding_client` and awaits its ``embed_text``, which issues an HTTP embedding request to OpenRouter. Failures and empty embeddings are logged and collapsed to ``None`` so callers can fall back to explicit-mention-only handling. Called by :meth:`classify`, :meth:`classify_skills`, and :meth:`classify_response_for_missing_tools` in this module whenever a precomputed embedding was not supplied. Args: query (str): The text to embed. Returns: np.ndarray | None: The query embedding, or ``None`` on error or an empty result. """ try: client = await self._get_embedding_client() embedding = await client.embed_text(query) if embedding.size == 0: logger.warning( "Empty embedding returned for query", ) return None return embedding except Exception as exc: logger.error( "Failed to get query embedding: %s", exc, ) return None # -------------------------------------------------------------- # Similarity search # -------------------------------------------------------------- async def _find_matching_tools( self, query_embedding: np.ndarray, ) -> list[dict[str, Any]]: """Find tools via RediSearch KNN when indexed, else legacy batch cosine. Core retrieval step that ranks registered tools against the query embedding. When :meth:`_tool_redisearch_has_docs` reports an indexed corpus it runs an over-fetched KNN search through :func:`classifiers.redis_vector_index.knn_search_tools` against Redis, then filters every row by the per-tool threshold from :meth:`_similarity_threshold_for_tool_name`, sorts by score, and keeps the top ``top_k``. Otherwise it ensures the legacy hash cache is loaded via :meth:`_load_tool_embeddings` and offloads the CPU-bound matmul to :meth:`_legacy_find_matching_tools_sync` on a worker thread. 
Called by :meth:`classify` and :meth:`classify_response_for_missing_tools` in this module, and by ``tests/test_vector_redisearch_knn.py``. Args: query_embedding (np.ndarray): The embedded query to score tools against. Returns: list[dict[str, Any]]: Up to ``top_k`` matches sorted by descending score, each with at least ``name`` and ``score``. """ if await self._tool_redisearch_has_docs(): knn_k = max(self.top_k * 8, 256, 500) raw = await knn_search_tools( self._redis, query_embedding, knn_k=knn_k, ) scores = [ row for row in raw if row["score"] >= self._similarity_threshold_for_tool_name( row["name"], ) ] scores.sort(key=lambda x: x["score"], reverse=True) return scores[: self.top_k] if self._tool_embeddings_cache is None: await self._load_tool_embeddings() if not self._tool_embeddings_cache or self._tool_embeddings_matrix is None: return [] return await asyncio.to_thread( self._legacy_find_matching_tools_sync, query_embedding, ) def _legacy_find_matching_tools_sync( self, query_embedding: np.ndarray, ) -> list[dict[str, Any]]: """CPU-heavy RediSearch fallback: matmul + filter + sort (thread pool). Synchronous in-process scorer used when no tool RediSearch index is available. It computes cosine similarity between the query and the cached ``(N, D)`` ``_tool_embeddings_matrix`` via :func:`utils.cosine.cosine_batch`, keeps each tool whose score clears the per-name threshold from :meth:`_similarity_threshold_for_tool_name`, attaches cached metadata, sorts by score, and truncates to ``top_k``. No I/O — it reads only in-memory caches; the numpy matmul makes it the CPU-bound part, which is why the caller runs it on a worker thread. Called by :meth:`_find_matching_tools` in this module (via :func:`asyncio.to_thread`). Args: query_embedding (np.ndarray): The embedded query to score tools against. Returns: list[dict[str, Any]]: Up to ``top_k`` matches with ``name``, ``score``, and ``metadata``, sorted by descending score. 
""" sims = cosine_batch( query_embedding, self._tool_embeddings_matrix, ) meta_cache = self._tool_metadata_cache or {} scores = [] for i, name in enumerate(self._tool_names_list): thr = self._similarity_threshold_for_tool_name(name) if sims[i] >= thr: scores.append( { "name": name, "score": float(sims[i]), "metadata": meta_cache.get(name, {}), } ) scores.sort(key=lambda x: x["score"], reverse=True) return scores[: self.top_k] async def _find_matching_skills( self, query_embedding: np.ndarray, *, similarity_threshold: float, top_k: int, ) -> list[dict[str, Any]]: """Cosine similarity via RediSearch KNN when indexed, else legacy batch. Skill-side counterpart to :meth:`_find_matching_tools`. When :meth:`_skill_redisearch_has_docs` reports an indexed corpus it runs an over-fetched KNN search through :func:`classifiers.redis_vector_index.knn_search_skills` against Redis, filters rows by ``similarity_threshold``, sorts, and keeps the top ``top_k``. Otherwise it ensures the legacy cache is loaded via :meth:`_load_skill_embeddings` and offloads the matmul to :meth:`_legacy_find_matching_skills_sync` on a worker thread. Called by :meth:`classify_skills` in this module. Args: query_embedding (np.ndarray): The embedded query to score skills against. similarity_threshold (float): Minimum cosine score for a skill to be kept. top_k (int): Maximum number of skills to return. Returns: list[dict[str, Any]]: Up to ``top_k`` skill matches sorted by descending score. 
""" if await self._skill_redisearch_has_docs(): knn_k = max(top_k * 50, 2000, 512) raw = await knn_search_skills( self._redis, query_embedding, knn_k=knn_k, ) scores = [row for row in raw if row["score"] >= similarity_threshold] scores.sort(key=lambda x: x["score"], reverse=True) return scores[:top_k] if self._skill_embeddings_cache is None: await self._load_skill_embeddings() if not self._skill_embeddings_cache or self._skill_embeddings_matrix is None: return [] return await asyncio.to_thread( self._legacy_find_matching_skills_sync, query_embedding, similarity_threshold, top_k, ) def _legacy_find_matching_skills_sync( self, query_embedding: np.ndarray, similarity_threshold: float, top_k: int, ) -> list[dict[str, Any]]: """CPU-heavy RediSearch fallback for skills (thread pool). Synchronous in-process skill scorer used when no skill RediSearch index is available. It computes cosine similarity between the query and the cached ``(N, D)`` ``_skill_embeddings_matrix`` via :func:`utils.cosine.cosine_batch`, keeps every skill clearing ``similarity_threshold``, enriches each with cached name/description metadata, sorts by score, and truncates to ``top_k``. No I/O — it reads only in-memory caches; the numpy matmul is the CPU-bound part, so the caller runs it on a worker thread. Called by :meth:`_find_matching_skills` in this module (via :func:`asyncio.to_thread`). Args: query_embedding (np.ndarray): The embedded query to score skills against. similarity_threshold (float): Minimum cosine score for a skill to be kept. top_k (int): Maximum number of skills to return. Returns: list[dict[str, Any]]: Up to ``top_k`` skill matches with ``skill_id``, ``name``, ``description``, ``score``, and ``metadata``, sorted by descending score. 
""" sims = cosine_batch( query_embedding, self._skill_embeddings_matrix, ) meta_cache = self._skill_metadata_cache or {} scores: list[dict[str, Any]] = [ { "skill_id": sid, "name": meta_cache.get(sid, {}).get("name", sid), "description": meta_cache.get(sid, {}).get("description", ""), "score": float(sims[i]), "metadata": meta_cache.get(sid, {}), } for i, sid in enumerate(self._skill_ids_list) if sims[i] >= similarity_threshold ] scores.sort(key=lambda x: x["score"], reverse=True) return scores[:top_k] def _expand_tool_prefixes( self, tool_names: list[str], *, tool_scores: dict[str, float] | None = None, registry_tool_names: Iterable[str] | None = None, ) -> tuple[list[str], dict[str, list[str]]]: """Score-aware group expansion. Tools scoring at or above ``group_expansion_threshold`` (or explicitly mentioned by the user) trigger full prefix/named-group expansion. Tools below that threshold are kept individually without pulling in their siblings. Returns ``(expanded_names, triggered_groups)`` where *triggered_groups* maps prefix/group-label to the tools that were added by expansion (for observability). 
""" names_list = self._tool_names_list if not names_list and registry_tool_names is not None: names_list = list(dict.fromkeys(registry_tool_names)) if not names_list: return list(tool_names), {} scores = tool_scores or {} result = set(tool_names) triggered: dict[str, list[str]] = {} for name in tool_names: score = scores.get(name, 1.0) expand_min = self._group_expand_min_score_for_tool(name) if score < expand_min: continue for prefix in TOOL_PREFIX_GROUPS: if name.startswith(prefix): added: list[str] = [] for t in names_list: if t.startswith(prefix) and t not in result: result.add(t) added.append(t) if added: triggered[f"prefix:{prefix}"] = added break available = set(names_list) if registry_tool_names is not None: available |= set(registry_tool_names) for group in TOOL_NAMED_GROUPS: expanding_members = [ n for n in result if n in group and scores.get(n, 1.0) >= self._group_expand_min_score_for_tool(n) ] if expanding_members: added = [] for t in group: if t in available and t not in result: result.add(t) added.append(t) if added: label = next(iter(group)) triggered[f"named:{label}"] = added return list(result), triggered # -------------------------------------------------------------- # Public API # --------------------------------------------------------------
async def classify(
    self,
    message: str,
    query_embedding: np.ndarray | None = None,
    registry_tool_names: Iterable[str] | None = None,
    *,
    scan_explicit_tool_mentions: bool = True,
    observability_extra: Mapping[str, Any] | None = None,
) -> dict[str, Any]:
    """Classify *message* and return tool names + strategy.

    Flow: blank messages short-circuit to an empty ``strategy="none"``
    result; otherwise explicit tool mentions are (optionally) scanned,
    vector matches are retrieved, both are merged and group-expanded, a
    strategy is derived from the best vector score, and essentials plus
    contextual tiers are appended for any non-``"none"`` strategy.  When
    tool embeddings or the query embedding are unavailable, only the
    explicitly named tools (if any) are returned with ``"optional"``.

    Parameters
    ----------
    message:
        Text to classify.
    query_embedding:
        Pre-computed embedding for *message*.  When provided the
        internal embedding API call is skipped.
    registry_tool_names:
        Registered tool names (e.g. registry keys).  When provided,
        any name that appears as a whole token in *message* is
        included in the tool set alongside vector matches, and the
        names are offered to group expansion.
    scan_explicit_tool_mentions:
        When ``True`` (default), scan *message* for explicit registered
        tool names.  Set ``False`` for non-user text (e.g. assistant
        drafts, response postprocessing) so mentions in those strings
        never inflate the tool set.
    observability_extra:
        Optional mapping forwarded to the classifier observability
        publisher together with the result.

    Returns
    -------
    dict
        A dict with keys ``tools``, ``strategy``, ``complexity``, and
        ``safety`` (plus ``tool_scores`` on every non-blank path).
    """
    logger.info(
        "VectorClassifier.classify() for: %s",
        message[:100] if message else "<blank>",
    )

    # ── blank / whitespace-only messages ──────────────────────
    # An empty query produces a near-zero embedding whose cosine
    # similarity is >= threshold for *every* tool, returning all
    # tools and blowing past the 512 function-declaration limit.
    if not message or not message.strip():
        logger.info(
            "Blank message received, returning empty tool set (strategy=none)",
        )
        blank_result = {
            "complexity": "moderate",
            "safety": "safe",
            "strategy": "none",
            "tools": [],
        }
        _emit_classifier_observability(
            blank_result,
            None,
            phase="user_message",
            extra=observability_extra,
        )
        return blank_result

    # Bind the registry before branching: group expansion below reads it on
    # every path, including ``scan_explicit_tool_mentions=False`` (where it
    # used to be unbound and raised UnboundLocalError).
    _reg = registry_tool_names or ()
    if not scan_explicit_tool_mentions:
        explicit = []
    elif len(message) >= _EXPLICIT_SCAN_TO_THREAD_CHARS:
        # Long inputs are scanned on a worker thread.
        explicit = await asyncio.to_thread(
            find_tools_explicitly_named,
            message,
            _reg,
        )
    else:
        explicit = find_tools_explicitly_named(message, _reg)

    def _add_always_on_tools(tools: list[str]) -> list[str]:
        """Append essentials and activated tier tools to *tools* in place.

        Returns the names of the tiers that were activated.
        """
        present = set(tools)
        for tool in self._get_essential_tools():
            if tool not in present:
                tools.append(tool)
                present.add(tool)
        tier_tools, tiers = self._get_tier_tools(tools, message)
        tools.extend(tier_tools)
        return tiers

    def _explicit_only_result(result_log_fmt: str) -> dict[str, Any]:
        """Build, publish and return the result used when vectors are unusable.

        With explicit mentions present they are group-expanded (score 1.0
        each) and topped up with essentials/tiers; otherwise the default
        empty ``"optional"`` result is returned.
        """
        fallback: dict[str, Any] = {
            "complexity": "moderate",
            "safety": "safe",
            "strategy": "optional",
            "tools": [],
            "tool_scores": {},
        }
        if explicit:
            explicit_scores = {n: 1.0 for n in explicit}
            expanded, _ = self._expand_tool_prefixes(
                list(explicit),
                tool_scores=explicit_scores,
                registry_tool_names=_reg,
            )
            fallback["tools"] = list(expanded)
            fallback["tool_scores"] = dict(explicit_scores)
            _add_always_on_tools(fallback["tools"])
            logger.info(
                result_log_fmt,
                fallback["strategy"],
                len(fallback["tools"]),
            )
        _emit_classifier_observability(
            fallback,
            None,
            phase="user_message",
            extra=observability_extra,
        )
        return fallback

    if await self._tool_redisearch_has_docs():
        embeddings_loaded = True
    else:
        embeddings_loaded = await self._load_tool_embeddings()

    if not embeddings_loaded:
        logger.warning(
            "Tool embeddings not available, returning default result",
        )
        return _explicit_only_result(
            "VectorClassifier result (explicit-only, no "
            "embeddings): strategy=%s, tools_count=%d",
        )

    query_emb = (
        query_embedding
        if query_embedding is not None
        else await self._get_query_embedding(message)
    )
    if query_emb is None:
        logger.warning(
            "Failed to get query embedding, returning default result",
        )
        return _explicit_only_result(
            "VectorClassifier result (explicit-only, no "
            "query embedding): strategy=%s, tools_count=%d",
        )

    matches = await self._find_matching_tools(query_emb)
    matches = _filter_gated_prefix_vector_matches(
        matches,
        message,
        explicit_names=set(explicit),
    )

    vector_names: list[str] = [t["name"] for t in matches] if matches else []
    all_scores: dict[str, float] = (
        {t["name"]: t["score"] for t in matches} if matches else {}
    )
    # Explicit mentions count as full-confidence unless a vector score exists.
    for n in explicit:
        all_scores.setdefault(n, 1.0)

    # Explicit names first, then vector hits, de-duplicated in order.
    pre_expansion = list(dict.fromkeys([*explicit, *vector_names]))
    combined, triggered_groups = self._expand_tool_prefixes(
        pre_expansion,
        tool_scores=all_scores,
        registry_tool_names=_reg,
    )

    strategy = "none"
    if matches:
        # The first match is treated as the best score.
        max_score = matches[0]["score"]
        if max_score > self.strategy_force_threshold:
            strategy = "force"
        elif max_score > self.strategy_optional_threshold:
            strategy = "optional"
        logger.info(
            "Vector match: %d tools, max_score=%.4f, strategy=%s",
            len(vector_names),
            max_score,
            strategy,
        )
        for t in matches[:5]:
            logger.debug(
                " - %s: %.4f",
                t["name"],
                t["score"],
            )
    else:
        logger.info(
            "No tools matched above threshold",
        )

    # An explicit mention always makes tools available, even without a
    # strong vector signal.
    if explicit and strategy == "none":
        strategy = "optional"

    result: dict[str, Any] = {
        "complexity": "moderate",
        "safety": "safe",
        "strategy": strategy,
        "tools": combined,
        "tool_scores": all_scores,
    }

    activated_tiers: list[str] = []
    if strategy != "none":
        activated_tiers = _add_always_on_tools(result["tools"])

    logger.info(
        "VectorClassifier result: "
        "strategy=%s, tools_count=%d "
        "(vector=%d, explicit=%d, expanded_groups=%d, tiers=%s)",
        result["strategy"],
        len(result["tools"]),
        len(vector_names),
        len(explicit),
        sum(len(v) for v in triggered_groups.values()),
        activated_tiers or "none",
    )
    if triggered_groups:
        logger.info(
            "Group expansion detail: %s",
            {k: len(v) for k, v in triggered_groups.items()},
        )
    _emit_classifier_observability(
        result,
        matches,
        phase="user_message",
        extra=observability_extra,
    )
    return result
@staticmethod def _trim_skills_catalog( skills: list[dict[str, Any]], max_catalog_chars: int, ) -> list[dict[str, Any]]: """Keep top skills until cumulative name+description fits budget. Bounds the size of the skill catalog surfaced to the LLM (progressive disclosure) so a long tail of low-scoring skills cannot blow the prompt budget. Walks the already score-sorted skills, charging roughly ``len(name) + len(description) + 32`` characters each, and stops once adding the next entry would exceed ``max_catalog_chars`` (while always keeping at least one). A non-positive budget or empty input is returned unchanged. Pure in-memory work with no I/O. Called by :meth:`classify_skills` in this module on its final result. Args: skills (list[dict[str, Any]]): Score-ranked skill dicts carrying ``name`` and ``description``. max_catalog_chars (int): Approximate character budget; values <= 0 disable trimming. Returns: list[dict[str, Any]]: The leading skills that fit the budget. """ if max_catalog_chars <= 0 or not skills: return skills out: list[dict[str, Any]] = [] used = 0 for s in skills: chunk = len(s.get("name", "")) + len(s.get("description", "")) + 32 if used + chunk > max_catalog_chars and out: break out.append(s) used += chunk return out
async def classify_skills(
    self,
    message: str,
    query_embedding: np.ndarray | None = None,
    *,
    similarity_threshold: float = DEFAULT_SKILL_SIMILARITY_THRESHOLD,
    top_k: int = DEFAULT_SKILL_TOP_K,
    max_catalog_chars: int = 4000,
) -> list[dict[str, Any]]:
    """Return the slim skill catalog relevant to *message*.

    Steps, each of which can end the call with an empty list:

    1. Blank or whitespace-only *message*.
    2. No skill corpus: neither :meth:`_skill_redisearch_has_docs` nor
       :meth:`_load_skill_embeddings` reports one.
    3. No query vector: *query_embedding* is ``None`` and
       :meth:`_get_query_embedding` also yields ``None``.

    Otherwise the matches from :meth:`_find_matching_skills` are projected
    to four catalog fields and bounded by :meth:`_trim_skills_catalog`.

    Args:
        message (str): Text to route skills for.
        query_embedding (np.ndarray | None): Precomputed embedding for
            *message*; when given, no embedding call is made.
        similarity_threshold (float): Minimum score passed to the matcher.
        top_k (int): Maximum number of matches requested from the matcher.
        max_catalog_chars (int): Character budget for the final trim.

    Returns:
        list[dict[str, Any]]: Dicts with ``skill_id``, ``name``,
        ``description`` and ``score``.
    """
    if not message or not message.strip():
        return []

    # The legacy loader is only consulted when the index has no documents.
    has_index = await self._skill_redisearch_has_docs()
    if not has_index and not await self._load_skill_embeddings():
        return []

    if query_embedding is not None:
        q_emb = query_embedding
    else:
        q_emb = await self._get_query_embedding(message)
    if q_emb is None:
        return []

    ranked = await self._find_matching_skills(
        q_emb,
        similarity_threshold=similarity_threshold,
        top_k=top_k,
    )
    catalog_fields = ("skill_id", "name", "description", "score")
    catalog: list[dict[str, Any]] = [
        {field: hit[field] for field in catalog_fields} for hit in ranked
    ]
    return self._trim_skills_catalog(catalog, max_catalog_chars)
async def classify_response_for_missing_tools(
    self,
    response_text: str,
    current_tools: list[str],
    threshold: float = TOOL_EXPANSION_THRESHOLD,
    *,
    observability_extra: Mapping[str, Any] | None = None,
) -> list[str]:
    """Suggest tools the assistant appears to need but was not given.

    Supports dynamic tool expansion: *response_text* is embedded and scored
    by **vector similarity only**.  :func:`find_tools_explicitly_named` is
    deliberately not used, so tool names that merely appear in assistant
    output never add tools.  Surviving matches are group-expanded and
    anything already in *current_tools* is dropped.

    Args:
        response_text (str): Assistant output to analyse.
        current_tools (list[str]): Tool names already available.
        threshold (float): Similarity threshold used for this lookup; a
            candidate must also clear its own per-name threshold.
        observability_extra (Mapping[str, Any] | None): When truthy, the
            outcome (including early exits) is published for observability.

    Returns:
        list[str]: New tool names, or ``[]`` when embeddings or the query
        embedding are unavailable or nothing qualifies.
    """
    logger.info(
        "classify_response_for_missing_tools: %d current tools",
        len(current_tools),
    )

    def _report(
        strategy: str,
        tools: list[str],
        source_matches: list[dict[str, Any]],
    ) -> None:
        """Publish the outcome, but only when the caller asked for it."""
        if not observability_extra:
            return
        top_matches = [
            {"name": m.get("name", ""), "score": float(m.get("score", 0.0))}
            for m in source_matches[:12]
        ]
        _emit_classifier_observability(
            {
                "strategy": strategy,
                "tools": tools,
            },
            top_matches,
            phase="assistant_response",
            extra=observability_extra,
        )

    if not await self._tool_redisearch_has_docs():
        if not await self._load_tool_embeddings():
            _report("embeddings_not_loaded", [], [])
            return []

    query_emb = await self._get_query_embedding(
        response_text,
    )
    if query_emb is None:
        _report("no_query_embedding", [], [])
        return []

    # Swap the instance-wide threshold for the duration of the lookup and
    # always restore it.
    # NOTE(review): this mutates shared instance state across an await; any
    # other coroutine reading ``similarity_threshold`` on the same instance
    # meanwhile would see the temporary value — confirm this is acceptable.
    saved_threshold = self.similarity_threshold
    self.similarity_threshold = threshold
    try:
        matches = await self._find_matching_tools(
            query_emb,
        )
    finally:
        self.similarity_threshold = saved_threshold

    matches = _filter_gated_prefix_vector_matches(
        matches,
        response_text,
        explicit_names=None,
    )

    await self._ensure_tool_names_for_expansion()

    already_available = set(current_tools)
    score_by_name = {m["name"]: m["score"] for m in matches}

    candidates: list[str] = []
    for hit in matches:
        name = hit["name"]
        if name in already_available:
            continue
        floor = max(threshold, self._similarity_threshold_for_tool_name(name))
        if hit["score"] >= floor:
            candidates.append(name)

    expanded, _ = self._expand_tool_prefixes(
        candidates,
        tool_scores=score_by_name,
    )
    # Expansion may pull in siblings the caller already has; drop them again.
    new_tools = [t for t in expanded if t not in already_available]

    if new_tools:
        logger.info(
            "Found %d potential new tools: %s",
            len(new_tools),
            new_tools[:10],
        )
    else:
        logger.info("No new tools found above threshold")

    _report("tool_expansion", new_tools, matches)
    return new_tools
# ------------------------------------------------------------------
# Tiered essential tools
# ------------------------------------------------------------------

# Always-on tools, returned first by _get_essential_tools().
_CORE_ESSENTIALS: list[str] = [
    "no_tool", "no_response", "request_tool_injection", "activate_skill",
    "extend_tool_loop", "list_all_tools", "search_tools",
    "calculate_math_expression", "wait", "check_task", "await_task",
    "redirect_task", "upload_file", "universal_decode", "stargazer_ban",
    "stargazer_shadowban",
    # Chaos Switch lattice tools — forced passive
    "view_chaos_weather", "view_lattice_position", "grant_lattice_consent",
    "move_lattice_position",
]

# Knowledge-graph / short-term-note tools, also part of the essentials.
_KNOWLEDGE_TIER: list[str] = [
    "store_knowledge", "add_entity", "add_relationship", "query_knowledge",
    "get_entity", "list_entities", "delete_entity", "delete_relationship",
    "search_knowledge", "write_short_term_note", "read_short_term_notes",
    "clear_short_term_notes",
]

# Contextual tiers: added by _get_tier_tools() only when triggered.
_FILE_TIER: list[str] = [
    "create_file", "read_file", "delete_file", "edit_file",
    "git_read_repo_file", "read_own_docs",
]

_GOAL_TIER: list[str] = [
    "create_goal", "get_goal", "list_channel_goals", "update_goal",
    "delete_goal", "add_subtask", "update_subtask", "list_subtasks",
    "remove_subtask", "list_all_goals",
]

_WEBHOOK_TIER: list[str] = [
    "create_webhook", "list_webhooks", "delete_webhook", "edit_webhook",
    "execute_webhook",
]

_MODERATION_TIER: list[str] = [
    "kick_user", "ban_user", "timeout_user", "block_user",
]

_PLATFORM_TIER: list[str] = [
    "discord_react", "discord_embed",
]

# Constellation relationship graph — auto-injected on bond/role keywords
_CONSTELLATION_TIER: list[str] = [
    "constellation_bond", "constellation_link", "constellation_view",
    "constellation_remove",
]

# Per-tier trigger keywords.  _get_tier_tools() tests each one as a plain
# substring of the lower-cased message.
_TIER_KEYWORDS: dict[str, list[str]] = {
    "file": [
        "file", "read", "write", "edit", "upload", "download", "document",
        "doc", "attachment", "repo", "git",
    ],
    "goal": [
        "goal", "task", "subtask", "todo", "milestone", "objective",
        "progress", "plan", "track",
    ],
    "webhook": [
        "webhook", "hook", "automat", "trigger", "endpoint", "callback",
        "notify", "notification",
    ],
    "moderation": [
        "ban", "kick", "mute", "timeout", "block", "moderate", "moderation",
        "punish", "warn",
    ],
    "platform": ["react", "emoji", "embed", "discord"],
    # Constellation: bond types + relationship words + ALL roles
    "constellation": [
        # ---- Bond / relationship vocabulary ----
        "bond", "boyfriend", "girlfriend", "partner", "relationship",
        "dating", "dynamic", "polycule", "poly", "metamour", "fwb",
        "situationship", "constellation", "intimate", "romantic", "sexual",
        "platonic", "partnered", "lover", "wife", "husband", "fiance",
        "fiancee", "spouse", "significant other",
        # ---- dom_sub_switch ----
        "dominant", "domme", "submissive", "switch", "dom_leaning_switch",
        "sub_leaning_switch", "hard_dom", "hard_domme", "soft_dom",
        "soft_domme", "gentle_dom", "gentle_domme", "pleasure_submissive",
        "unruly_submissive", "alpha_submissive", "femdom_submissive",
        "conditional_sub", "bedroom_dom", "bedroom_sub",
        # ---- top_bottom ----
        "top", "bottom", "total_top", "total_bottom", "versatile",
        "vers_top", "vers_bottom", "submissive_top", "dominant_bottom",
        "power_bottom", "bossy_bottom", "stone_top", "stone_bottom",
        # ---- master_slave ----
        "master", "mistress", "mxtress", "owner", "slave", "slave_girl",
        "pleasure_slave", "sex_slave", "conditional_slave", "femdom_slave",
        # ---- authority_titles ----
        "sir", "maam", "miss", "madame", "goddess", "deity", "tyrant",
        "divine", "virtuoso",
        # ---- sadist_masochist ----
        "sadist", "masochist", "sadomasochist", "dominant_sadist",
        "submissive_sadist", "submissive_masochist", "dominant_masochist",
        "hard_sadist", "hard_masochist", "primal_sadist",
        "primal_masochist", "emotional_sadist", "emotional_masochist",
        "intellectual_sadist", "intellectual_masochist",
        # ---- caregiver_little ----
        "daddy", "mommy", "caregiver", "guardian", "little", "middle",
        "babygirl", "babyboy", "little_girl", "little_boy", "brat",
        "bratty_sub", "brat_tamer",
        # ---- pet_play ----
        "pet", "kitten", "puppy", "pony", "bunny", "fox", "handler",
        "trainer", "pet_owner",
        # ---- primal ----
        "primal_predator", "primal_prey", "primal_switch", "hunter",
        "alpha", "omega",
        # ---- service ----
        "service_submissive", "service_top", "butler", "maid", "attendant",
        "devotee", "worshipper",
        # ---- rope_bondage ----
        "rigger", "rope_bunny", "rope_top", "rope_bottom", "shibari_artist",
        # ---- objectification ----
        "property", "possession", "drone", "ritual_object", "furniture",
        "toy",
        # ---- other ----
        "voyeur", "exhibitionist", "cuckold", "cuckquean", "hotwife",
        "stag", "vixen", "bull", "swinger", "hedonist", "sensualist",
        "kinkster", "fetishist", "pervert", "vanilla", "curious",
        "exploring",
        # ---- kink infrastructure ----
        "collar", "leash", "protocol", "kink", "limits", "safeword",
        "scene", "aftercare",
    ],
}


@classmethod
def _get_essential_tools(cls) -> list[str]:
    """Return a fresh list of the core essentials followed by the knowledge tier.

    The remaining tiers are not included here; they are contributed by
    :meth:`_get_tier_tools` when a tier is triggered.
    """
    return [*cls._CORE_ESSENTIALS, *cls._KNOWLEDGE_TIER]


@classmethod
def _get_tier_tools(
    cls,
    tool_names: list[str],
    message: str = "",
) -> tuple[list[str], list[str]]:
    """Work out which contextual tiers apply and which tools they add.

    A tier fires when either signal holds:

    - one of its tools is already present in *tool_names* (or was added by
      an earlier tier during this call), or
    - one of its :data:`_TIER_KEYWORDS` entries occurs as a substring of the
      lower-cased *message*.

    Returns:
        tuple[list[str], list[str]]: ``(tools, activated_tiers)`` — the tier
        tools not already present, in tier order, and the names of the tiers
        that fired.
    """
    seen = set(tool_names)
    haystack = message.lower() if message else ""

    tiers: tuple[tuple[str, list[str]], ...] = (
        ("file", cls._FILE_TIER),
        ("goal", cls._GOAL_TIER),
        ("webhook", cls._WEBHOOK_TIER),
        ("moderation", cls._MODERATION_TIER),
        ("platform", cls._PLATFORM_TIER),
        ("constellation", cls._CONSTELLATION_TIER),
    )

    additions: list[str] = []
    fired: list[str] = []
    for label, members in tiers:
        # Membership is checked first; keywords are only consulted as a
        # fallback.
        if seen.isdisjoint(members):
            keywords = cls._TIER_KEYWORDS.get(label, [])
            if not any(kw in haystack for kw in keywords):
                continue
        fired.append(label)
        for tool in members:
            if tool in seen:
                continue
            additions.append(tool)
            seen.add(tool)

    return additions, fired
async def close(self) -> None:
    """Close the underlying embedding client, if one was created.

    The cached client reference is detached *before* its ``close()`` is
    awaited.  That way the reference is cleared even when ``close()``
    raises (the exception still propagates to the caller), so a dead client
    is never left cached for reuse, and a repeated call cannot close the
    same client twice.

    Safe to call when no client exists; in that case nothing is awaited.
    Only ``self._embedding_client`` is touched.
    """
    client, self._embedding_client = self._embedding_client, None
    if client is not None:
        await client.close()
# ------------------------------------------------------------------ # Standalone helpers for initialisation scripts # ------------------------------------------------------------------
async def initialize_tool_embeddings_from_file(
    index_file_path: str,
    redis_client: aioredis.Redis,
    api_key: str | None = None,
    force_recompute: bool = False,
) -> bool:
    """Build per-tool centroid embeddings from an index file and store them.

    The JSON index maps tool name to an info dict.  Each tool's
    ``synthetic_queries`` (falling back to its ``description``, then to
    ``"use <name>"``) are embedded in bulk via
    :func:`compute_tool_centroids_bulk`.  Centroids and a small metadata
    record are written to the two Redis hashes and then stored per tool
    through ``store_tool_embedding_hash``.

    Args:
        index_file_path (str): Path of the JSON tool index.
        redis_client (aioredis.Redis): Async Redis connection to write to.
        api_key (str | None): Embedding API key; falls back to the
            ``OPENROUTER_API_KEY`` environment variable.
        force_recompute (bool): When ``False`` and the embeddings hash is
            non-empty, nothing is recomputed.  When ``True`` both hashes are
            deleted before being rewritten.

    Returns:
        bool: ``True`` when embeddings were stored or already present;
        ``False`` when nothing could be stored or any exception occurred
        (exceptions are logged, not raised).
    """
    embedding_client: OpenRouterEmbeddings | None = None
    try:
        if not force_recompute:
            already_stored = await redis_client.hlen(
                TOOL_EMBEDDINGS_HASH_KEY,
            )
            if already_stored > 0:
                logger.info(
                    "Tool embeddings already in Redis (%d tools), skipping",
                    already_stored,
                )
                return True

        logger.info(
            "Loading tool index data from %s",
            index_file_path,
        )

        def _read_index() -> dict[str, Any]:
            """Parse the index file; synchronous so it can run off-loop."""
            with open(index_file_path, "r", encoding="utf-8") as handle:
                return json.load(handle)

        tool_data: dict[str, Any] = await asyncio.to_thread(_read_index)
        logger.info(
            "Loaded %d tools from index file",
            len(tool_data),
        )

        embedding_client = OpenRouterEmbeddings(
            api_key=api_key or os.environ.get("OPENROUTER_API_KEY", ""),
        )

        # Pass 1: decide which texts represent each tool.
        tool_queries: dict[str, list[str]] = {}
        meta_by_tool: dict[str, tuple[str, list[str]]] = {}
        for tool_name, tool_info in tool_data.items():
            try:
                queries = tool_info.get("synthetic_queries", [])
                desc = tool_info.get("description", "")
                if not queries:
                    logger.warning(
                        "Tool %r has no synthetic queries, using description",
                        tool_name,
                    )
                    queries = [desc] if desc else [f"use {tool_name}"]
                tool_queries[tool_name] = queries
                meta_by_tool[tool_name] = (desc, queries)
            except Exception as exc:
                # A malformed entry is logged and skipped; the rest proceed.
                logger.error(
                    "Failed to process tool %r: %s",
                    tool_name,
                    exc,
                )

        logger.info(
            "Computing tool centroids for %d tools in bulk...",
            len(tool_queries),
        )
        centroids = await compute_tool_centroids_bulk(
            embedding_client,
            tool_queries,
        )
        logger.info("Tool centroids computation complete.")

        # Pass 2: serialise centroids and metadata for the Redis hashes.
        embeddings_to_store: dict[str, str] = {}
        metadata_to_store: dict[str, str] = {}
        for tool_name, centroid in centroids.items():
            if centroid is None:
                logger.warning(
                    "No embeddings computed for %r",
                    tool_name,
                )
                continue
            desc, queries = meta_by_tool[tool_name]
            query_count = len(queries)
            embeddings_to_store[tool_name] = json.dumps(centroid.tolist())
            metadata_to_store[tool_name] = json.dumps(
                {
                    "name": tool_name,
                    "description": desc,
                    "query_count": query_count,
                }
            )
            logger.debug(
                "Computed centroid for %r from %d queries",
                tool_name,
                query_count,
            )

        if not embeddings_to_store:
            logger.error("No tool embeddings to store")
            return False

        if force_recompute:
            await redis_client.delete(
                TOOL_EMBEDDINGS_HASH_KEY,
            )
            await redis_client.delete(
                TOOL_METADATA_HASH_KEY,
            )
        await redis_client.hset(
            TOOL_EMBEDDINGS_HASH_KEY,
            mapping=embeddings_to_store,
        )
        await redis_client.hset(
            TOOL_METADATA_HASH_KEY,
            mapping=metadata_to_store,
        )

        from classifiers.redis_vector_index import store_tool_embedding_hash

        # Store each tool again as its own document, one at a time.
        for tool_name, meta_json in metadata_to_store.items():
            vector = np.array(
                json.loads(embeddings_to_store[tool_name]),
                dtype=np.float32,
            )
            await store_tool_embedding_hash(
                redis_client,
                tool_name,
                vector,
                json.loads(meta_json),
            )

        logger.info(
            "Stored %d tool embeddings in Redis",
            len(embeddings_to_store),
        )
        return True

    except Exception as exc:
        logger.error(
            "Failed to initialize tool embeddings: %s",
            exc,
            exc_info=True,
        )
        return False
    finally:
        # Runs on every exit path, including the early returns above.
        if embedding_client is not None:
            await embedding_client.close()
async def reload_tool_embeddings(
    redis_client: aioredis.Redis,
    api_key: str | None = None,
) -> bool:
    """Force a rebuild of the tool embeddings from ``tool_index_data.json``.

    Thin wrapper around :func:`initialize_tool_embeddings_from_file`: it
    points that function at the ``tool_index_data.json`` file located in the
    same directory as this module and passes ``force_recompute=True``.

    Args:
        redis_client (aioredis.Redis): Async Redis connection handed through
            to the initializer.
        api_key (str | None): API key handed through to the initializer.

    Returns:
        bool: Whatever :func:`initialize_tool_embeddings_from_file` returns.
    """
    module_dir = os.path.dirname(__file__)
    index_path = os.path.join(module_dir, "tool_index_data.json")
    return await initialize_tool_embeddings_from_file(
        index_path,
        redis_client,
        api_key=api_key,
        force_recompute=True,
    )