# Source code for classifiers.vector_classifier

"""Vector-based classifier for tool selection.

Lightweight semantic vector classifier that replaces sending all
tools to the LLM with deterministic vector retrieval.  Pre-computed
centroid embeddings (stored in Redis hashes) are compared against
user-query embeddings via cosine similarity to select the most
relevant tools.
"""

from __future__ import annotations

import asyncio
import functools
import json
import logging
import os
import re
from typing import Any, Iterable

import numpy as np
import redis.asyncio as aioredis

from rag_system.openrouter_embeddings import OpenRouterEmbeddings
from utils.cosine import cosine_batch

logger = logging.getLogger(__name__)

# Redis hash keys: per-tool centroid embeddings (JSON-encoded float
# lists) and per-tool metadata (JSON objects), written by
# initialize_tool_embeddings_from_file().
TOOL_EMBEDDINGS_HASH_KEY = "stargazer:tool_embeddings"
TOOL_METADATA_HASH_KEY = "stargazer:tool_metadata"

# Defaults for VectorClassifier: minimum cosine similarity for a match
# and the maximum number of tools returned per query.
DEFAULT_SIMILARITY_THRESHOLD = 0.15
DEFAULT_TOP_K = 20

# Stricter score cut-off used by classify_response_for_missing_tools()
# when deciding whether to inject additional tools mid-conversation.
TOOL_EXPANSION_THRESHOLD = 0.85

# When any tool from these prefixes is selected, pull in all tools with that prefix.
TOOL_PREFIX_GROUPS = ("wg_", "ovpn_", "ipsec_", "desktop_")

# Maximal ASCII "word" runs — same characters Python uses in ``re`` word
# boundaries for typical ``snake_case`` tool names (avoids Unicode ``\w``
# so we do not split identifiers on UTF-8 letters).
_EXPLICIT_TOOL_TOKEN_RE = re.compile(r"[A-Za-z0-9_]+")

# Inline code / backtick-wrapped spans (one line); inner text is stripped
# and must match a registered name exactly.
_EXPLICIT_BACKTICK_RE = re.compile(r"`([^`\n]+)`")

# Messages at least this many characters long run the explicit-name scan
# in a worker thread (asyncio.to_thread) so huge pastes do not block the
# event loop.
_EXPLICIT_SCAN_TO_THREAD_CHARS = 96_000


# Compiled once at import time: the pattern list is loop-invariant, so
# there is no reason to rebuild it (and re-run the re-cache lookup) on
# every call of this hot gating check.
_TOOL_REQUEST_PATTERNS: tuple[re.Pattern[str], ...] = tuple(
    re.compile(p)
    for p in (
        r"\bi (?:still )?(?:need|lack|require|am missing"
        r"|don't have|do not have)\b",
        r"\bwithout (?:the |access to )?(?:\w+ )?tool",
        r"\bunable to (?:use|access|call|execute)\b",
        r"\bmy tool ?belt (?:doesn't|does not)"
        r" (?:have|contain|include)\b",
        r"\bgive me (?:the )?\w+ tool",
        r"\bi need [a-z_]+ tool",
        r"\bmissing (?:the )?(?:ability|capability"
        r"|tool|function)",
        r"\bdon't have (?:the |access to )?"
        r"(?:\w+ )?(?:tool|function|command)",
    )
)


def detect_tool_request_keywords(response_text: str) -> bool:
    """Return *True* when the bot seems to request missing tools.

    This lightweight regex check gates the heavier embedding-based
    tool expansion to avoid false positives on legitimate no-tool
    responses.

    Args:
        response_text: Assistant reply text to inspect; ``None``/empty
            is treated as "no request".

    Returns:
        bool: ``True`` if any missing-tool phrasing is detected.
    """
    if not response_text:
        return False
    text_lower = response_text.lower()
    for pattern in _TOOL_REQUEST_PATTERNS:
        if pattern.search(text_lower):
            # Log the raw pattern text, same as the original %r output.
            logger.debug(
                "Tool request keyword detected: %r",
                pattern.pattern,
            )
            return True
    return False
@functools.lru_cache(maxsize=32) def _explicit_tool_lookup(names_key: tuple[str, ...]) -> frozenset[str]: """Cached frozenset of tool names for fast token look-ups.""" return frozenset(names_key)
def find_tools_explicitly_named(
    message: str,
    valid_names: Iterable[str],
) -> list[str]:
    """Return tool names that appear verbatim in *message* as whole tokens.

    Two detectors run over the message:

    * Maximal runs of ASCII letters, digits, and underscores (typical
      ``snake_case`` tools), equivalent to word boundaries for those
      names.

    * Text inside ASCII backticks (``inline code``): inner text is
      stripped and must match a registered tool name **exactly**, so
      names containing hyphens or other punctuation still match when
      quoted.

    Hits are ordered by first occurrence in the message; each tool
    appears at most once.
    """
    if not message:
        return []

    # Normalise the registry: strip string names, stringify anything
    # else, drop empties.
    registered: set[str] = set()
    for raw in valid_names:
        cleaned = raw.strip() if isinstance(raw, str) else str(raw)
        if cleaned:
            registered.add(cleaned)
    if not registered:
        return []

    lookup = _explicit_tool_lookup(tuple(sorted(registered)))

    # Record the earliest position each registered name was seen at.
    # Token matches are scanned before backtick matches, mirroring the
    # original event order for equal positions.
    first_pos: dict[str, int] = {}

    def _note(name: str, pos: int) -> None:
        """Internal helper: keep the smallest position per name."""
        prev = first_pos.get(name)
        if prev is None or pos < prev:
            first_pos[name] = pos

    for m in _EXPLICIT_TOOL_TOKEN_RE.finditer(message):
        word = m.group(0)
        if word in lookup:
            _note(word, m.start())
    for m in _EXPLICIT_BACKTICK_RE.finditer(message):
        quoted = m.group(1).strip()
        if quoted in lookup:
            _note(quoted, m.start(1))

    # Stable sort over dict insertion order reproduces the original
    # tie-breaking (token hits before backtick hits at equal offsets).
    ordered_hits = sorted(first_pos, key=first_pos.get)
    if ordered_hits:
        logger.debug(
            "Explicit tool names in message: %s",
            ordered_hits,
        )
    return ordered_hits
class VectorClassifier:
    """Semantic vector-based classifier for tool selection.

    Compares pre-computed per-tool centroid embeddings (loaded from
    Redis) against the user-query embedding via cosine similarity and
    returns the most relevant tools.

    Parameters
    ----------
    redis_client:
        An async Redis connection (``redis.asyncio.Redis``).
    similarity_threshold:
        Minimum cosine similarity for a match.
    top_k:
        Maximum number of tools to return.
    api_key:
        OpenRouter API key.  Falls back to the
        ``OPENROUTER_API_KEY`` env var.
    """

    def __init__(
        self,
        redis_client: aioredis.Redis,
        similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
        top_k: int = DEFAULT_TOP_K,
        api_key: str | None = None,
    ) -> None:
        """Store configuration and initialise the (empty) caches.

        Args:
            redis_client (aioredis.Redis): Redis connection client.
            similarity_threshold (float): Minimum cosine similarity for
                a tool to count as a match.
            top_k (int): Maximum number of tools returned per query.
            api_key (str | None): OpenRouter API key; falls back to the
                ``OPENROUTER_API_KEY`` environment variable.
        """
        self._redis = redis_client
        self.similarity_threshold = similarity_threshold
        self.top_k = top_k
        # Created lazily on first use; see _get_embedding_client().
        self._embedding_client: OpenRouterEmbeddings | None = None
        self._api_key = api_key or os.environ.get(
            "OPENROUTER_API_KEY",
            "",
        )
        # tool name -> centroid vector; None until _load_tool_embeddings
        # succeeds (None also signals "not loaded yet", distinct from {}).
        self._tool_embeddings_cache: (
            dict[str, np.ndarray] | None
        ) = None
        # tool name -> metadata dict parsed from Redis.
        self._tool_metadata_cache: (
            dict[str, dict[str, Any]] | None
        ) = None
        # Pre-built matrix for batch cosine similarity: shape (N, D).
        # Built alongside _tool_embeddings_cache.
        self._tool_names_list: list[str] = []
        self._tool_embeddings_matrix: np.ndarray | None = None
        logger.info(
            "VectorClassifier initialized: "
            "threshold=%s, top_k=%s",
            similarity_threshold,
            top_k,
        )
    # --------------------------------------------------------------
    # Embedding client
    # --------------------------------------------------------------
    async def _get_embedding_client(self) -> OpenRouterEmbeddings:
        """Return the embedding client, constructing it on first use.

        Returns:
            OpenRouterEmbeddings: The memoised client instance.
        """
        if self._embedding_client is None:
            self._embedding_client = OpenRouterEmbeddings(
                api_key=self._api_key,
            )
        return self._embedding_client

    # --------------------------------------------------------------
    # Load / cache tool embeddings from Redis
    # --------------------------------------------------------------
    async def _load_tool_embeddings(
        self,
        force_reload: bool = False,
    ) -> bool:
        """Populate the in-memory embedding/metadata caches from Redis.

        Individual unparsable hash entries are skipped with a warning;
        only a missing hash or a Redis failure makes the load fail.

        Args:
            force_reload (bool): Re-read from Redis even when the
                cache is already populated.

        Returns:
            bool: True on success, False otherwise.
        """
        if (
            self._tool_embeddings_cache is not None
            and not force_reload
        ):
            return True
        try:
            embeddings_data: dict = await self._redis.hgetall(
                TOOL_EMBEDDINGS_HASH_KEY,
            )
            if not embeddings_data:
                logger.warning(
                    "No tool embeddings in Redis key: %s",
                    TOOL_EMBEDDINGS_HASH_KEY,
                )
                return False
            self._tool_embeddings_cache = {}
            for name, emb_json in embeddings_data.items():
                try:
                    # Redis may return bytes depending on the client's
                    # decode_responses setting; normalise to str.
                    if isinstance(name, bytes):
                        name = name.decode("utf-8")
                    if isinstance(emb_json, bytes):
                        emb_json = emb_json.decode("utf-8")
                    vec = np.array(
                        json.loads(emb_json),
                        dtype=np.float32,
                    )
                    self._tool_embeddings_cache[name] = vec
                except Exception as exc:
                    logger.warning(
                        "Failed to parse embedding for "
                        "tool %r: %s",
                        name,
                        exc,
                    )
            meta_data: dict = await self._redis.hgetall(
                TOOL_METADATA_HASH_KEY,
            )
            self._tool_metadata_cache = {}
            for name, meta_json in meta_data.items():
                try:
                    if isinstance(name, bytes):
                        name = name.decode("utf-8")
                    if isinstance(meta_json, bytes):
                        meta_json = meta_json.decode("utf-8")
                    self._tool_metadata_cache[name] = (
                        json.loads(meta_json)
                    )
                except Exception as exc:
                    logger.warning(
                        "Failed to parse metadata for "
                        "tool %r: %s",
                        name,
                        exc,
                    )
            # Build ordered names list and (N, D) matrix for batch
            # cosine; row i of the matrix belongs to names_list[i].
            self._tool_names_list = list(
                self._tool_embeddings_cache.keys()
            )
            if self._tool_names_list:
                self._tool_embeddings_matrix = np.stack(
                    [self._tool_embeddings_cache[n]
                     for n in self._tool_names_list],
                    axis=0,
                )
            else:
                self._tool_embeddings_matrix = None
            logger.info(
                "Loaded %d tool embeddings from Redis",
                len(self._tool_embeddings_cache),
            )
            return True
        except Exception as exc:
            logger.error(
                "Failed to load tool embeddings: %s",
                exc,
            )
            return False

    # --------------------------------------------------------------
    # Query embedding
    # --------------------------------------------------------------
    async def _get_query_embedding(
        self,
        query: str,
    ) -> np.ndarray | None:
        """Embed *query*, returning ``None`` on any failure.

        Args:
            query (str): Text to embed.

        Returns:
            np.ndarray | None: Embedding vector, or ``None`` when the
            API call fails or yields an empty vector.
        """
        try:
            client = await self._get_embedding_client()
            embedding = await client.embed_text(query)
            if embedding.size == 0:
                logger.warning(
                    "Empty embedding returned for query",
                )
                return None
            return embedding
        except Exception as exc:
            logger.error(
                "Failed to get query embedding: %s",
                exc,
            )
            return None

    # --------------------------------------------------------------
    # Similarity search
    # --------------------------------------------------------------
    async def _find_matching_tools(
        self,
        query_embedding: np.ndarray,
    ) -> list[dict[str, Any]]:
        """Return tools scoring at or above the similarity threshold.

        Uses a single batch matrix–vector multiply (via
        ``cosine_batch``) instead of a per-tool Python loop for a
        significant speed-up at 3072-dimensional embeddings.

        Args:
            query_embedding (np.ndarray): The query embedding value.

        Returns:
            list[dict[str, Any]]: Up to ``top_k`` dicts with keys
            ``name``, ``score`` and ``metadata``, sorted by score
            descending.
        """
        # Lazy-load on first use; if loading fails the guards below
        # return an empty result instead of raising.
        if self._tool_embeddings_cache is None:
            await self._load_tool_embeddings()
        if (
            not self._tool_embeddings_cache
            or self._tool_embeddings_matrix is None
        ):
            return []
        sims = cosine_batch(
            query_embedding,
            self._tool_embeddings_matrix,
        )
        meta_cache = self._tool_metadata_cache or {}
        scores: list[dict[str, Any]] = [
            {
                "name": name,
                "score": float(sims[i]),
                "metadata": meta_cache.get(name, {}),
            }
            for i, name in enumerate(self._tool_names_list)
            if sims[i] >= self.similarity_threshold
        ]
        scores.sort(key=lambda x: x["score"], reverse=True)
        return scores[: self.top_k]

    def _expand_tool_prefixes(self, tool_names: list[str]) -> list[str]:
        """Expand *tool_names* with their whole prefix groups.

        When any selected tool starts with one of the
        ``TOOL_PREFIX_GROUPS`` prefixes (wg_, ovpn_, ipsec_, desktop_),
        every known tool sharing that prefix is added.  NOTE: the
        result is set-deduplicated, so its order is not guaranteed.
        """
        if not self._tool_names_list:
            return tool_names
        result = set(tool_names)
        for name in tool_names:
            for prefix in TOOL_PREFIX_GROUPS:
                if name.startswith(prefix):
                    for t in self._tool_names_list:
                        if t.startswith(prefix):
                            result.add(t)
                    # A name matches at most one group; stop scanning.
                    break
        return list(result)

    # --------------------------------------------------------------
    # Public API
    # --------------------------------------------------------------
    async def classify(
        self,
        message: str,
        query_embedding: np.ndarray | None = None,
        registry_tool_names: Iterable[str] | None = None,
        *,
        scan_explicit_tool_mentions: bool = True,
    ) -> dict[str, Any]:
        """Classify *message* and return tool names + strategy.

        Parameters
        ----------
        query_embedding:
            Pre-computed embedding for *message*.  When provided the
            internal embedding API call is skipped.
        registry_tool_names:
            Registered tool names (e.g. registry keys).  When
            provided, any name that appears as a whole token in
            *message* is included in the tool set alongside vector
            matches.
        scan_explicit_tool_mentions:
            When ``True`` (default), scan *message* for explicit
            registered tool names.  Set ``False`` for non-user text
            (e.g. assistant drafts, response postprocessing) so
            mentions in those strings never inflate the tool set.

        Returns a dict with keys ``tools``, ``strategy``,
        ``complexity``, and ``safety``.
        """
        logger.info(
            "VectorClassifier.classify() for: %s",
            message[:100] if message else "<blank>",
        )
        # ── blank / whitespace-only messages ────────────────────
        # An empty query produces a near-zero embedding whose cosine
        # similarity is >= threshold for *every* tool, returning all
        # tools and blowing past the 512 function-declaration limit.
        if not message or not message.strip():
            logger.info(
                "Blank message received, returning empty "
                "tool set (strategy=none)",
            )
            return {
                "complexity": "moderate",
                "safety": "safe",
                "strategy": "none",
                "tools": [],
            }
        # Explicit whole-token mentions of registered tool names;
        # offloaded to a thread for very large pastes.
        if scan_explicit_tool_mentions:
            _reg = registry_tool_names or ()
            if len(message) >= _EXPLICIT_SCAN_TO_THREAD_CHARS:
                explicit = await asyncio.to_thread(
                    find_tools_explicitly_named,
                    message,
                    _reg,
                )
            else:
                explicit = find_tools_explicitly_named(message, _reg)
        else:
            explicit = []

        def _with_essentials(
            tools: list[str],
            strategy: str,
        ) -> dict[str, Any]:
            """Internal helper: build result dict + essential tools."""
            out: dict[str, Any] = {
                "complexity": "moderate",
                "safety": "safe",
                "strategy": strategy,
                "tools": list(tools),
            }
            # Essentials are only added when tools will be offered.
            if out["strategy"] != "none":
                for tool in self._get_essential_tools():
                    if tool not in out["tools"]:
                        out["tools"].append(tool)
            return out

        embeddings_loaded = await self._load_tool_embeddings()
        result: dict[str, Any] = {
            "complexity": "moderate",
            "safety": "safe",
            "strategy": "optional",
            "tools": [],
        }
        # Degraded path 1: no embeddings in Redis — fall back to the
        # explicit mentions (if any) instead of failing outright.
        if not embeddings_loaded:
            logger.warning(
                "Tool embeddings not available, "
                "returning default result",
            )
            if explicit:
                result = _with_essentials(
                    self._expand_tool_prefixes(list(explicit)),
                    "optional",
                )
                logger.info(
                    "VectorClassifier result (explicit-only, no "
                    "embeddings): strategy=%s, tools_count=%d",
                    result["strategy"],
                    len(result["tools"]),
                )
            return result
        query_emb = (
            query_embedding
            if query_embedding is not None
            else await self._get_query_embedding(message)
        )
        # Degraded path 2: embedding API failed — same explicit-only
        # fallback.
        if query_emb is None:
            logger.warning(
                "Failed to get query embedding, "
                "returning default result",
            )
            if explicit:
                result = _with_essentials(
                    self._expand_tool_prefixes(list(explicit)),
                    "optional",
                )
                logger.info(
                    "VectorClassifier result (explicit-only, no "
                    "query embedding): strategy=%s, tools_count=%d",
                    result["strategy"],
                    len(result["tools"]),
                )
            return result
        matches = await self._find_matching_tools(query_emb)
        vector_names: list[str] = [
            t["name"] for t in matches
        ] if matches else []
        # Merge explicit + vector hits (explicit first, order-stable
        # dedupe), then pull in whole prefix groups.
        combined = self._expand_tool_prefixes(
            list(dict.fromkeys([*explicit, *vector_names])),
        )
        if matches:
            # Strategy is driven by the best score: very confident
            # matches force tool use, middling ones leave it optional.
            max_score = matches[0]["score"]
            if max_score > 0.8:
                result["strategy"] = "force"
            elif max_score > 0.2:
                result["strategy"] = "optional"
            else:
                result["strategy"] = "none"
            logger.info(
                "Vector match: %d tools, "
                "max_score=%.4f, strategy=%s",
                len(vector_names),
                max_score,
                result["strategy"],
            )
            for t in matches[:5]:
                logger.debug(
                    " - %s: %.4f",
                    t["name"],
                    t["score"],
                )
        else:
            logger.info(
                "No tools matched above threshold",
            )
            result["strategy"] = "none"
        # An explicit mention always keeps tools on the table.
        if explicit and result["strategy"] == "none":
            result["strategy"] = "optional"
        result["tools"] = combined
        if result["strategy"] != "none":
            for tool in self._get_essential_tools():
                if tool not in result["tools"]:
                    result["tools"].append(tool)
        logger.info(
            "VectorClassifier result: "
            "strategy=%s, tools_count=%d",
            result["strategy"],
            len(result["tools"]),
        )
        return result
[docs] async def classify_response_for_missing_tools( self, response_text: str, current_tools: list[str], threshold: float = TOOL_EXPANSION_THRESHOLD, ) -> list[str]: """Find tools the bot might need but lacks. Used for dynamic tool expansion when the bot signals it needs tools not included in the original set. Runs **vector similarity only** on *response_text* — not :func:`find_tools_explicitly_named`, so tool names that appear in assistant output or postprocessed reply text never add tools. """ logger.info( "classify_response_for_missing_tools: " "%d current tools", len(current_tools), ) loaded = await self._load_tool_embeddings() if not loaded: return [] query_emb = await self._get_query_embedding( response_text, ) if query_emb is None: return [] original = self.similarity_threshold self.similarity_threshold = threshold try: matches = await self._find_matching_tools( query_emb, ) finally: self.similarity_threshold = original current_set = set(current_tools) new_tools = [ t["name"] for t in matches if ( t["name"] not in current_set and t["score"] >= threshold ) ] new_tools = [ t for t in self._expand_tool_prefixes(new_tools) if t not in current_set ] if new_tools: logger.info( "Found %d potential new tools: %s", len(new_tools), new_tools[:10], ) else: logger.info("No new tools found above threshold") return new_tools
@staticmethod def _get_essential_tools() -> list[str]: """Tools always included for non-NONE strategies.""" return [ "no_tool", "no_response", "kick_user", "ban_user", "timeout_user", "block_user", "store_knowledge", "add_entity", "add_relationship", "query_knowledge", "get_entity", "list_entities", "delete_entity", "delete_relationship", "search_knowledge", "write_short_term_note", "read_short_term_notes", "clear_short_term_notes", "calculate_math_expression", "get_user_profile", "universal_decode", "wait", "check_task", "await_task", "redirect_task", "create_file", "read_file", "delete_file", "edit_file", "upload_file", "git_read_repo_file", "create_goal", "get_goal", "list_channel_goals", "update_goal", "delete_goal", "add_subtask", "update_subtask", "list_subtasks", "remove_subtask", "list_all_goals", "create_webhook", "list_webhooks", "delete_webhook", "edit_webhook", "execute_webhook", "list_all_tools", "current_time", "read_own_docs", "semantic_search", "discord_react", "discord_embed", "extend_tool_loop", "request_tool_injection", ]
[docs] async def close(self) -> None: """Close the underlying embedding client.""" if self._embedding_client is not None: await self._embedding_client.close() self._embedding_client = None
# ------------------------------------------------------------------ # Standalone helpers for initialisation scripts # ------------------------------------------------------------------
async def initialize_tool_embeddings_from_file(
    index_file_path: str,
    redis_client: aioredis.Redis,
    api_key: str | None = None,
    force_recompute: bool = False,
) -> bool:
    """Compute centroid embeddings and store in Redis.

    Reads ``tool_index_data.json``, embeds every synthetic query per
    tool, calculates the centroid, and writes the result into Redis
    hashes.

    Args:
        index_file_path (str): Path to the tool index JSON file.
        redis_client (aioredis.Redis): Redis connection to write to.
        api_key (str | None): OpenRouter API key; falls back to the
            ``OPENROUTER_API_KEY`` environment variable.
        force_recompute (bool): Recompute and overwrite even when
            embeddings already exist in Redis.

    Returns:
        bool: True when embeddings were stored (or already present),
        False on failure.
    """
    embedding_client: OpenRouterEmbeddings | None = None
    try:
        # Skip the (expensive) embedding run when Redis already holds
        # data, unless the caller explicitly forces a recompute.
        if not force_recompute:
            existing = await redis_client.hlen(
                TOOL_EMBEDDINGS_HASH_KEY,
            )
            if existing > 0:
                logger.info(
                    "Tool embeddings already in Redis "
                    "(%d tools), skipping",
                    existing,
                )
                return True
        logger.info(
            "Loading tool index data from %s",
            index_file_path,
        )

        def _read_index() -> dict[str, Any]:
            """Internal helper: blocking JSON read (run in a thread).

            Returns:
                dict[str, Any]: Parsed tool index data.
            """
            with open(
                index_file_path,
                "r",
                encoding="utf-8",
            ) as f:
                return json.load(f)

        # File I/O is offloaded so the event loop stays responsive.
        tool_data: dict[str, Any] = await asyncio.to_thread(_read_index)
        logger.info(
            "Loaded %d tools from index file",
            len(tool_data),
        )
        resolved_key = api_key or os.environ.get(
            "OPENROUTER_API_KEY",
            "",
        )
        embedding_client = OpenRouterEmbeddings(
            api_key=resolved_key,
        )
        # JSON-serialised payloads, written to Redis in one hset each.
        embeddings_to_store: dict[str, str] = {}
        metadata_to_store: dict[str, str] = {}
        for tool_name, tool_info in tool_data.items():
            try:
                queries = tool_info.get(
                    "synthetic_queries",
                    [],
                )
                desc = tool_info.get("description", "")
                # Fall back to the description (or the bare name) when
                # a tool has no synthetic queries to embed.
                if not queries:
                    logger.warning(
                        "Tool %r has no synthetic "
                        "queries, using description",
                        tool_name,
                    )
                    queries = (
                        [desc] if desc else [f"use {tool_name}"]
                    )
                try:
                    # Drop empty vectors returned by the API.
                    embs = [
                        e
                        for e in await embedding_client.embed_texts(
                            queries,
                        )
                        if e.size > 0
                    ]
                except Exception as exc:
                    logger.warning(
                        "Failed to embed queries "
                        "for %r: %s",
                        tool_name,
                        exc,
                    )
                    embs = []
                if not embs:
                    logger.warning(
                        "No embeddings computed for %r",
                        tool_name,
                    )
                    continue
                # Centroid of all query embeddings, L2-normalised so
                # cosine similarity reduces to a dot product.
                centroid = np.mean(embs, axis=0)
                norm = np.linalg.norm(centroid)
                if norm > 0:
                    centroid = centroid / norm
                embeddings_to_store[tool_name] = (
                    json.dumps(centroid.tolist())
                )
                metadata_to_store[tool_name] = json.dumps({
                    "name": tool_name,
                    "description": desc,
                    "query_count": len(queries),
                })
                logger.debug(
                    "Computed centroid for %r from "
                    "%d queries",
                    tool_name,
                    len(embs),
                )
            except Exception as exc:
                logger.error(
                    "Failed to process tool %r: %s",
                    tool_name,
                    exc,
                )
        if embeddings_to_store:
            # On forced recompute, clear stale hashes so removed tools
            # do not linger alongside the fresh data.
            if force_recompute:
                await redis_client.delete(
                    TOOL_EMBEDDINGS_HASH_KEY,
                )
                await redis_client.delete(
                    TOOL_METADATA_HASH_KEY,
                )
            await redis_client.hset(
                TOOL_EMBEDDINGS_HASH_KEY,
                mapping=embeddings_to_store,
            )
            await redis_client.hset(
                TOOL_METADATA_HASH_KEY,
                mapping=metadata_to_store,
            )
            logger.info(
                "Stored %d tool embeddings in Redis",
                len(embeddings_to_store),
            )
            return True
        logger.error("No tool embeddings to store")
        return False
    except Exception as exc:
        logger.error(
            "Failed to initialize tool embeddings: %s",
            exc,
            exc_info=True,
        )
        return False
    finally:
        # Always release the HTTP client, success or failure.
        if embedding_client is not None:
            await embedding_client.close()
async def reload_tool_embeddings(
    redis_client: aioredis.Redis,
    api_key: str | None = None,
) -> bool:
    """Force-recompute embeddings from ``tool_index_data.json``.

    Convenience wrapper around
    :func:`initialize_tool_embeddings_from_file` that uses the index
    file shipped next to this module and always overwrites Redis.
    """
    index_path = os.path.join(
        os.path.dirname(__file__),
        "tool_index_data.json",
    )
    return await initialize_tool_embeddings_from_file(
        index_path,
        redis_client,
        api_key=api_key,
        force_recompute=True,
    )