# Source: classifiers.vector_classifier
"""Vector-based classifier for tool selection.
Lightweight semantic vector classifier that replaces sending all
tools to the LLM with deterministic vector retrieval. Pre-computed
centroid embeddings (stored in Redis hashes) are compared against
user-query embeddings via cosine similarity to select the most
relevant tools.
"""
from __future__ import annotations
import asyncio
import functools
import json
import logging
import os
import re
from typing import Any, Iterable
import numpy as np
import redis.asyncio as aioredis
from rag_system.openrouter_embeddings import OpenRouterEmbeddings
from utils.cosine import cosine_batch
logger = logging.getLogger(__name__)
# Redis hash keys where per-tool centroid embeddings and their metadata live.
TOOL_EMBEDDINGS_HASH_KEY = "stargazer:tool_embeddings"
TOOL_METADATA_HASH_KEY = "stargazer:tool_metadata"
# Minimum cosine similarity for a vector match during normal classification.
DEFAULT_SIMILARITY_THRESHOLD = 0.15
# Maximum number of matched tools returned per classification pass.
DEFAULT_TOP_K = 20
# Stricter similarity floor used when scanning a response for missing tools.
TOOL_EXPANSION_THRESHOLD = 0.85
# When any tool from these prefixes is selected, pull in all tools with that prefix.
TOOL_PREFIX_GROUPS = ("wg_", "ovpn_", "ipsec_", "desktop_")
# Maximal ASCII “word” runs — same characters Python uses in ``re`` word
# boundaries for typical ``snake_case`` tool names (avoids Unicode ``\\w``
# so we do not split identifiers on UTF-8 letters).
_EXPLICIT_TOOL_TOKEN_RE = re.compile(r"[A-Za-z0-9_]+")
# Inline code / backtick-wrapped spans (one line); inner text is stripped
# and must match a registered name exactly.
_EXPLICIT_BACKTICK_RE = re.compile(r"`([^`\n]+)`")
# Run explicit-name scan off the event loop for huge pastes (chars).
_EXPLICIT_SCAN_TO_THREAD_CHARS = 96_000
def detect_tool_request_keywords(response_text: str) -> bool:
    """Return ``True`` when *response_text* looks like a missing-tool request.

    A cheap regex gate in front of the heavier embedding-based tool
    expansion: only responses that textually complain about lacking a
    tool/ability trigger the vector pass, avoiding false positives on
    legitimate no-tool responses.

    Args:
        response_text (str): Raw assistant response text (may be empty).

    Returns:
        bool: ``True`` if any missing-tool phrase matches.
    """
    if not response_text:
        return False
    text_lower = response_text.lower()
    # Phrases a model typically emits when it lacks a tool; matched
    # case-insensitively against the lowered text.
    patterns = (
        r"\bi (?:still )?(?:need|lack|require|am missing"
        r"|don't have|do not have)\b",
        r"\bwithout (?:the |access to )?(?:\w+ )?tool",
        r"\bunable to (?:use|access|call|execute)\b",
        r"\bmy tool ?belt (?:doesn't|does not)"
        r" (?:have|contain|include)\b",
        r"\bgive me (?:the )?\w+ tool",
        r"\bi need [a-z_]+ tool",
        r"\bmissing (?:the )?(?:ability|capability"
        r"|tool|function)",
        r"\bdon't have (?:the |access to )?"
        r"(?:\w+ )?(?:tool|function|command)",
    )
    for pattern in patterns:
        if re.search(pattern, text_lower):
            logger.debug(
                "Tool request keyword detected: %r",
                pattern,
            )
            return True
    return False
@functools.lru_cache(maxsize=32)
def _explicit_tool_lookup(names_key: tuple[str, ...]) -> frozenset[str]:
    """Memoised conversion of a sorted name tuple into a frozenset.

    The bounded cache lets repeated classification calls against the
    same tool registry reuse a single set object for O(1) whole-token
    membership tests.
    """
    return frozenset(names_key)
def find_tools_explicitly_named(
    message: str,
    valid_names: Iterable[str],
) -> list[str]:
    """Return tool names that appear verbatim in *message* as whole tokens.

    Detection:
    * Maximal runs of ASCII letters, digits, and underscores (typical
      ``snake_case`` tools), equivalent to word boundaries for those names.
    * Text inside ASCII backticks (``inline code``): inner text is stripped
      and must match a registered tool name **exactly**, so names containing
      hyphens or other punctuation still match when quoted.

    Hits are ordered by first occurrence in the message; each tool appears
    at most once.

    Args:
        message (str): User-supplied text to scan (may be empty).
        valid_names (Iterable[str]): Registered tool names; non-strings
            are coerced with ``str`` and blanks are ignored.

    Returns:
        list[str]: Ordered, de-duplicated tool names found in *message*.
    """
    if not message:
        return []
    # Normalise the registry once; blank entries contribute nothing.
    # (The original also built a parallel list that was only used for an
    # emptiness check -- the set alone suffices.)
    names: set[str] = set()
    for raw in valid_names:
        name = raw.strip() if isinstance(raw, str) else str(raw)
        if name:
            names.add(name)
    if not names:
        return []
    # Sorted tuple -> cached frozenset for O(1) token membership tests.
    lookup = _explicit_tool_lookup(tuple(sorted(names)))
    # Collect (position, name) events from both detectors, then order by
    # first occurrence.
    events: list[tuple[int, str]] = []
    for m in _EXPLICIT_TOOL_TOKEN_RE.finditer(message):
        token = m.group(0)
        if token in lookup:
            events.append((m.start(), token))
    for m in _EXPLICIT_BACKTICK_RE.finditer(message):
        inner = m.group(1).strip()
        if inner in lookup:
            events.append((m.start(1), inner))
    events.sort(key=lambda t: t[0])
    ordered_hits: list[str] = []
    hit_seen: set[str] = set()
    for _pos, name in events:
        if name not in hit_seen:
            hit_seen.add(name)
            ordered_hits.append(name)
    if ordered_hits:
        logger.debug(
            "Explicit tool names in message: %s",
            ordered_hits,
        )
    return ordered_hits
class VectorClassifier:
    """Semantic vector-based classifier for tool selection.

    Compares a query embedding against pre-computed per-tool centroid
    embeddings (loaded from Redis hashes) via cosine similarity and
    returns the most relevant tool names plus a strategy hint.

    Parameters
    ----------
    redis_client:
        An async Redis connection (``redis.asyncio.Redis``).
    similarity_threshold:
        Minimum cosine similarity for a match.
    top_k:
        Maximum number of tools to return.
    api_key:
        OpenRouter API key. Falls back to the
        ``OPENROUTER_API_KEY`` env var.
    """

    def __init__(
        self,
        redis_client: aioredis.Redis,
        similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD,
        top_k: int = DEFAULT_TOP_K,
        api_key: str | None = None,
    ) -> None:
        """Store configuration; no I/O happens until first use.

        Args:
            redis_client (aioredis.Redis): Redis connection client.
            similarity_threshold (float): Minimum cosine similarity.
            top_k (int): Maximum number of tools to return.
            api_key (str | None): OpenRouter key; env-var fallback.
        """
        self._redis = redis_client
        self.similarity_threshold = similarity_threshold
        self.top_k = top_k
        # Lazily-created embedding API client (see _get_embedding_client).
        self._embedding_client: OpenRouterEmbeddings | None = None
        self._api_key = api_key or os.environ.get(
            "OPENROUTER_API_KEY", "",
        )
        # name -> centroid vector; None until the first successful load.
        self._tool_embeddings_cache: (
            dict[str, np.ndarray] | None
        ) = None
        self._tool_metadata_cache: (
            dict[str, dict[str, Any]] | None
        ) = None
        # Pre-built matrix for batch cosine similarity: shape (N, D).
        # Built alongside _tool_embeddings_cache; names keep row order.
        self._tool_names_list: list[str] = []
        self._tool_embeddings_matrix: np.ndarray | None = None
        logger.info(
            "VectorClassifier initialized: "
            "threshold=%s, top_k=%s",
            similarity_threshold, top_k,
        )

    # --------------------------------------------------------------
    # Embedding client
    # --------------------------------------------------------------
    async def _get_embedding_client(self) -> OpenRouterEmbeddings:
        """Return the embedding client, creating it on first use.

        Returns:
            OpenRouterEmbeddings: Shared client bound to our API key.
        """
        if self._embedding_client is None:
            self._embedding_client = OpenRouterEmbeddings(
                api_key=self._api_key,
            )
        return self._embedding_client

    # --------------------------------------------------------------
    # Load / cache tool embeddings from Redis
    # --------------------------------------------------------------
    async def _load_tool_embeddings(
        self, force_reload: bool = False,
    ) -> bool:
        """Load per-tool centroid embeddings and metadata from Redis.

        Populates the name->vector cache, the metadata cache, and the
        stacked ``(N, D)`` matrix used for batch cosine similarity.
        Individual unparseable entries are skipped with a warning.

        Args:
            force_reload (bool): Re-read Redis even if a cache exists.

        Returns:
            bool: ``True`` on success, ``False`` when Redis is empty or
            the read fails entirely.
        """
        if (
            self._tool_embeddings_cache is not None
            and not force_reload
        ):
            return True
        try:
            embeddings_data: dict = await self._redis.hgetall(
                TOOL_EMBEDDINGS_HASH_KEY,
            )
            if not embeddings_data:
                logger.warning(
                    "No tool embeddings in Redis key: %s",
                    TOOL_EMBEDDINGS_HASH_KEY,
                )
                return False
            self._tool_embeddings_cache = {}
            for name, emb_json in embeddings_data.items():
                try:
                    # Redis may return bytes depending on client config.
                    if isinstance(name, bytes):
                        name = name.decode("utf-8")
                    if isinstance(emb_json, bytes):
                        emb_json = emb_json.decode("utf-8")
                    vec = np.array(
                        json.loads(emb_json),
                        dtype=np.float32,
                    )
                    self._tool_embeddings_cache[name] = vec
                except Exception as exc:
                    logger.warning(
                        "Failed to parse embedding for "
                        "tool %r: %s",
                        name, exc,
                    )
            meta_data: dict = await self._redis.hgetall(
                TOOL_METADATA_HASH_KEY,
            )
            self._tool_metadata_cache = {}
            for name, meta_json in meta_data.items():
                try:
                    if isinstance(name, bytes):
                        name = name.decode("utf-8")
                    if isinstance(meta_json, bytes):
                        meta_json = meta_json.decode("utf-8")
                    self._tool_metadata_cache[name] = (
                        json.loads(meta_json)
                    )
                except Exception as exc:
                    logger.warning(
                        "Failed to parse metadata for "
                        "tool %r: %s",
                        name, exc,
                    )
            # Build ordered names list and (N, D) matrix for batch cosine.
            self._tool_names_list = list(
                self._tool_embeddings_cache.keys()
            )
            if self._tool_names_list:
                self._tool_embeddings_matrix = np.stack(
                    [self._tool_embeddings_cache[n]
                     for n in self._tool_names_list],
                    axis=0,
                )
            else:
                self._tool_embeddings_matrix = None
            logger.info(
                "Loaded %d tool embeddings from Redis",
                len(self._tool_embeddings_cache),
            )
            return True
        except Exception as exc:
            logger.error(
                "Failed to load tool embeddings: %s", exc,
            )
            return False

    # --------------------------------------------------------------
    # Query embedding
    # --------------------------------------------------------------
    async def _get_query_embedding(
        self, query: str,
    ) -> np.ndarray | None:
        """Embed *query* via the OpenRouter client.

        Args:
            query (str): Text to embed.

        Returns:
            np.ndarray | None: Embedding vector, or ``None`` when the
            API call fails or yields an empty vector.
        """
        try:
            client = await self._get_embedding_client()
            embedding = await client.embed_text(query)
            if embedding.size == 0:
                logger.warning(
                    "Empty embedding returned for query",
                )
                return None
            return embedding
        except Exception as exc:
            logger.error(
                "Failed to get query embedding: %s", exc,
            )
            return None

    # --------------------------------------------------------------
    # Similarity search
    # --------------------------------------------------------------
    async def _find_matching_tools(
        self,
        query_embedding: np.ndarray,
        threshold: float | None = None,
    ) -> list[dict[str, Any]]:
        """Rank all tools against *query_embedding* by cosine similarity.

        Uses a single batch matrix-vector multiply (via ``cosine_batch``)
        instead of a per-tool Python loop for a significant speed-up at
        3072-dimensional embeddings.

        Args:
            query_embedding (np.ndarray): Embedding for the query.
            threshold (float | None): Optional per-call similarity floor;
                defaults to ``self.similarity_threshold``. Passing it
                explicitly avoids mutating shared instance state.

        Returns:
            list[dict[str, Any]]: Up to ``top_k`` dicts with ``name``,
            ``score``, and ``metadata``, sorted by descending score.
        """
        if self._tool_embeddings_cache is None:
            await self._load_tool_embeddings()
        if (
            not self._tool_embeddings_cache
            or self._tool_embeddings_matrix is None
        ):
            return []
        floor = (
            self.similarity_threshold
            if threshold is None
            else threshold
        )
        sims = cosine_batch(
            query_embedding, self._tool_embeddings_matrix,
        )
        meta_cache = self._tool_metadata_cache or {}
        scores: list[dict[str, Any]] = [
            {
                "name": name,
                "score": float(sims[i]),
                "metadata": meta_cache.get(name, {}),
            }
            for i, name in enumerate(self._tool_names_list)
            if sims[i] >= floor
        ]
        scores.sort(key=lambda x: x["score"], reverse=True)
        return scores[: self.top_k]

    def _expand_tool_prefixes(self, tool_names: list[str]) -> list[str]:
        """Expand prefix groups: a hit on any ``wg_``/``ovpn_``/``ipsec_``/
        ``desktop_`` tool pulls in every known tool with that prefix.

        Output order is deterministic: input names first (original order),
        then group members in registry order. The original returned
        ``list(set(...))``, whose order varied per process.
        """
        if not self._tool_names_list:
            return tool_names
        # dict.fromkeys preserves insertion order while de-duplicating.
        expanded: dict[str, None] = dict.fromkeys(tool_names)
        for name in tool_names:
            for prefix in TOOL_PREFIX_GROUPS:
                if name.startswith(prefix):
                    for candidate in self._tool_names_list:
                        if candidate.startswith(prefix):
                            expanded.setdefault(candidate, None)
                    break  # a name belongs to at most one prefix group
        return list(expanded)

    # --------------------------------------------------------------
    # Public API
    # --------------------------------------------------------------
    async def classify(
        self,
        message: str,
        query_embedding: np.ndarray | None = None,
        registry_tool_names: Iterable[str] | None = None,
        *,
        scan_explicit_tool_mentions: bool = True,
    ) -> dict[str, Any]:
        """Classify *message* and return tool names + strategy.

        Parameters
        ----------
        query_embedding:
            Pre-computed embedding for *message*. When provided the
            internal embedding API call is skipped.
        registry_tool_names:
            Registered tool names (e.g. registry keys). When provided,
            any name that appears as a whole token in *message* is
            included in the tool set alongside vector matches.
        scan_explicit_tool_mentions:
            When ``True`` (default), scan *message* for explicit registered
            tool names. Set ``False`` for non-user text (e.g. assistant
            drafts, response postprocessing) so mentions in those strings
            never inflate the tool set.

        Returns a dict with keys ``tools``, ``strategy``,
        ``complexity``, and ``safety``.
        """
        logger.info(
            "VectorClassifier.classify() for: %s",
            message[:100] if message else "<blank>",
        )
        # ── blank / whitespace-only messages ──────────────────────
        # An empty query produces a near-zero embedding whose cosine
        # similarity is >= threshold for *every* tool, returning all
        # tools and blowing past the 512 function-declaration limit.
        if not message or not message.strip():
            logger.info(
                "Blank message received, returning empty "
                "tool set (strategy=none)",
            )
            return {
                "complexity": "moderate",
                "safety": "safe",
                "strategy": "none",
                "tools": [],
            }
        if scan_explicit_tool_mentions:
            _reg = registry_tool_names or ()
            # Very large pastes are scanned off the event loop.
            if len(message) >= _EXPLICIT_SCAN_TO_THREAD_CHARS:
                explicit = await asyncio.to_thread(
                    find_tools_explicitly_named,
                    message,
                    _reg,
                )
            else:
                explicit = find_tools_explicitly_named(message, _reg)
        else:
            explicit = []

        def _with_essentials(
            tools: list[str],
            strategy: str,
        ) -> dict[str, Any]:
            """Build a result dict; append essential tools unless 'none'."""
            out: dict[str, Any] = {
                "complexity": "moderate",
                "safety": "safe",
                "strategy": strategy,
                "tools": list(tools),
            }
            if out["strategy"] != "none":
                for tool in self._get_essential_tools():
                    if tool not in out["tools"]:
                        out["tools"].append(tool)
            return out

        embeddings_loaded = await self._load_tool_embeddings()
        result: dict[str, Any] = {
            "complexity": "moderate",
            "safety": "safe",
            "strategy": "optional",
            "tools": [],
        }
        if not embeddings_loaded:
            logger.warning(
                "Tool embeddings not available, "
                "returning default result",
            )
            # Explicit mentions still count even without embeddings.
            if explicit:
                result = _with_essentials(
                    self._expand_tool_prefixes(list(explicit)),
                    "optional",
                )
                logger.info(
                    "VectorClassifier result (explicit-only, no "
                    "embeddings): strategy=%s, tools_count=%d",
                    result["strategy"],
                    len(result["tools"]),
                )
            return result
        query_emb = (
            query_embedding
            if query_embedding is not None
            else await self._get_query_embedding(message)
        )
        if query_emb is None:
            logger.warning(
                "Failed to get query embedding, "
                "returning default result",
            )
            if explicit:
                result = _with_essentials(
                    self._expand_tool_prefixes(list(explicit)),
                    "optional",
                )
                logger.info(
                    "VectorClassifier result (explicit-only, no "
                    "query embedding): strategy=%s, tools_count=%d",
                    result["strategy"],
                    len(result["tools"]),
                )
            return result
        matches = await self._find_matching_tools(query_emb)
        vector_names: list[str] = [
            t["name"] for t in matches
        ] if matches else []
        # Explicit mentions first, then vector matches, de-duplicated in
        # order, then prefix-group expansion.
        combined = self._expand_tool_prefixes(
            list(dict.fromkeys([*explicit, *vector_names])),
        )
        if matches:
            max_score = matches[0]["score"]
            # Strategy from best similarity: force > 0.8,
            # optional > 0.2, otherwise none.
            if max_score > 0.8:
                result["strategy"] = "force"
            elif max_score > 0.2:
                result["strategy"] = "optional"
            else:
                result["strategy"] = "none"
            logger.info(
                "Vector match: %d tools, "
                "max_score=%.4f, strategy=%s",
                len(vector_names),
                max_score,
                result["strategy"],
            )
            for t in matches[:5]:
                logger.debug(
                    " - %s: %.4f",
                    t["name"], t["score"],
                )
        else:
            logger.info(
                "No tools matched above threshold",
            )
            result["strategy"] = "none"
        # Explicit user mentions upgrade a "none" verdict to optional.
        if explicit and result["strategy"] == "none":
            result["strategy"] = "optional"
        result["tools"] = combined
        if result["strategy"] != "none":
            for tool in self._get_essential_tools():
                if tool not in result["tools"]:
                    result["tools"].append(tool)
        logger.info(
            "VectorClassifier result: "
            "strategy=%s, tools_count=%d",
            result["strategy"],
            len(result["tools"]),
        )
        return result

    async def classify_response_for_missing_tools(
        self,
        response_text: str,
        current_tools: list[str],
        threshold: float = TOOL_EXPANSION_THRESHOLD,
    ) -> list[str]:
        """Find tools the bot might need but lacks.

        Used for dynamic tool expansion when the bot signals
        it needs tools not included in the original set.
        Runs **vector similarity only** on *response_text* — not
        :func:`find_tools_explicitly_named`, so tool names that appear in
        assistant output or postprocessed reply text never add tools.

        Args:
            response_text (str): Assistant output to scan semantically.
            current_tools (list[str]): Tools already in the session.
            threshold (float): Similarity floor for new candidates.

        Returns:
            list[str]: Prefix-expanded new tool names not already held.
        """
        logger.info(
            "classify_response_for_missing_tools: "
            "%d current tools",
            len(current_tools),
        )
        loaded = await self._load_tool_embeddings()
        if not loaded:
            return []
        query_emb = await self._get_query_embedding(
            response_text,
        )
        if query_emb is None:
            return []
        # Pass the stricter threshold per-call instead of temporarily
        # mutating self.similarity_threshold, which raced with any
        # concurrent classify() on the same instance.
        matches = await self._find_matching_tools(
            query_emb, threshold=threshold,
        )
        current_set = set(current_tools)
        new_tools = [
            t["name"] for t in matches
            if (
                t["name"] not in current_set
                and t["score"] >= threshold
            )
        ]
        # Prefix-group expansion may reintroduce held tools; filter again.
        new_tools = [
            t for t in self._expand_tool_prefixes(new_tools)
            if t not in current_set
        ]
        if new_tools:
            logger.info(
                "Found %d potential new tools: %s",
                len(new_tools), new_tools[:10],
            )
        else:
            logger.info("No new tools found above threshold")
        return new_tools

    @staticmethod
    def _get_essential_tools() -> list[str]:
        """Tools always included for non-NONE strategies."""
        return [
            "no_tool",
            "no_response",
            "kick_user",
            "ban_user",
            "timeout_user",
            "block_user",
            "store_knowledge",
            "add_entity",
            "add_relationship",
            "query_knowledge",
            "get_entity",
            "list_entities",
            "delete_entity",
            "delete_relationship",
            "search_knowledge",
            "write_short_term_note",
            "read_short_term_notes",
            "clear_short_term_notes",
            "calculate_math_expression",
            "get_user_profile",
            "universal_decode",
            "wait",
            "check_task",
            "await_task",
            "redirect_task",
            "create_file",
            "read_file",
            "delete_file",
            "edit_file",
            "upload_file",
            "git_read_repo_file",
            "create_goal",
            "get_goal",
            "list_channel_goals",
            "update_goal",
            "delete_goal",
            "add_subtask",
            "update_subtask",
            "list_subtasks",
            "remove_subtask",
            "list_all_goals",
            "create_webhook",
            "list_webhooks",
            "delete_webhook",
            "edit_webhook",
            "execute_webhook",
            "list_all_tools",
            "current_time",
            "read_own_docs",
            "semantic_search",
            "discord_react",
            "discord_embed",
            "extend_tool_loop",
            "request_tool_injection",
        ]

    async def close(self) -> None:
        """Close the underlying embedding client (if one was created)."""
        if self._embedding_client is not None:
            await self._embedding_client.close()
            self._embedding_client = None
# ------------------------------------------------------------------
# Standalone helpers for initialisation scripts
# ------------------------------------------------------------------
async def initialize_tool_embeddings_from_file(
    index_file_path: str,
    redis_client: aioredis.Redis,
    api_key: str | None = None,
    force_recompute: bool = False,
) -> bool:
    """Compute centroid embeddings and store in Redis.

    Reads ``tool_index_data.json``, embeds every synthetic query
    per tool, calculates the L2-normalised centroid, and writes the
    result into the Redis hashes consumed by :class:`VectorClassifier`.

    Args:
        index_file_path (str): Path to the tool index JSON file.
        redis_client (aioredis.Redis): Async Redis connection.
        api_key (str | None): OpenRouter key; falls back to the
            ``OPENROUTER_API_KEY`` env var.
        force_recompute (bool): When ``True``, clear and rebuild the
            hashes even if embeddings already exist.

    Returns:
        bool: ``True`` on success (or when embeddings already exist and
        ``force_recompute`` is ``False``); ``False`` on failure.
    """
    embedding_client: OpenRouterEmbeddings | None = None
    try:
        if not force_recompute:
            existing = await redis_client.hlen(
                TOOL_EMBEDDINGS_HASH_KEY,
            )
            if existing > 0:
                logger.info(
                    "Tool embeddings already in Redis "
                    "(%d tools), skipping",
                    existing,
                )
                return True
        logger.info(
            "Loading tool index data from %s",
            index_file_path,
        )

        def _read_index() -> dict[str, Any]:
            """Blocking JSON read, run off the event loop."""
            with open(
                index_file_path, "r", encoding="utf-8",
            ) as f:
                return json.load(f)

        tool_data: dict[str, Any] = await asyncio.to_thread(_read_index)
        logger.info(
            "Loaded %d tools from index file",
            len(tool_data),
        )
        resolved_key = api_key or os.environ.get(
            "OPENROUTER_API_KEY", "",
        )
        embedding_client = OpenRouterEmbeddings(
            api_key=resolved_key,
        )
        embeddings_to_store: dict[str, str] = {}
        metadata_to_store: dict[str, str] = {}
        for tool_name, tool_info in tool_data.items():
            try:
                queries = tool_info.get(
                    "synthetic_queries", [],
                )
                desc = tool_info.get("description", "")
                if not queries:
                    # Fall back to the description, then a stub phrase,
                    # so every tool gets at least one embedding input.
                    logger.warning(
                        "Tool %r has no synthetic "
                        "queries, using description",
                        tool_name,
                    )
                    queries = (
                        [desc] if desc
                        else [f"use {tool_name}"]
                    )
                try:
                    # Drop empty vectors the API sometimes returns.
                    embs = [
                        e for e in
                        await embedding_client.embed_texts(
                            queries,
                        )
                        if e.size > 0
                    ]
                except Exception as exc:
                    logger.warning(
                        "Failed to embed queries "
                        "for %r: %s",
                        tool_name, exc,
                    )
                    embs = []
                if not embs:
                    logger.warning(
                        "No embeddings computed for %r",
                        tool_name,
                    )
                    continue
                # Centroid of all query embeddings, L2-normalised so a
                # dot product with a unit query vector equals cosine.
                centroid = np.mean(embs, axis=0)
                norm = np.linalg.norm(centroid)
                if norm > 0:
                    centroid = centroid / norm
                embeddings_to_store[tool_name] = (
                    json.dumps(centroid.tolist())
                )
                metadata_to_store[tool_name] = json.dumps({
                    "name": tool_name,
                    "description": desc,
                    "query_count": len(queries),
                })
                logger.debug(
                    "Computed centroid for %r from "
                    "%d queries",
                    tool_name, len(embs),
                )
            except Exception as exc:
                # One bad tool must not abort the whole rebuild.
                logger.error(
                    "Failed to process tool %r: %s",
                    tool_name, exc,
                )
        if embeddings_to_store:
            if force_recompute:
                # Drop stale hashes so removed tools do not linger.
                await redis_client.delete(
                    TOOL_EMBEDDINGS_HASH_KEY,
                )
                await redis_client.delete(
                    TOOL_METADATA_HASH_KEY,
                )
            await redis_client.hset(
                TOOL_EMBEDDINGS_HASH_KEY,
                mapping=embeddings_to_store,
            )
            await redis_client.hset(
                TOOL_METADATA_HASH_KEY,
                mapping=metadata_to_store,
            )
            logger.info(
                "Stored %d tool embeddings in Redis",
                len(embeddings_to_store),
            )
            return True
        logger.error("No tool embeddings to store")
        return False
    except Exception as exc:
        logger.error(
            "Failed to initialize tool embeddings: %s",
            exc, exc_info=True,
        )
        return False
    finally:
        if embedding_client is not None:
            await embedding_client.close()
async def reload_tool_embeddings(
    redis_client: aioredis.Redis,
    api_key: str | None = None,
) -> bool:
    """Force a full rebuild of embeddings from ``tool_index_data.json``.

    Convenience wrapper around
    :func:`initialize_tool_embeddings_from_file` that resolves the
    bundled index file next to this module and always recomputes.

    Args:
        redis_client (aioredis.Redis): Async Redis connection.
        api_key (str | None): Optional OpenRouter key override.

    Returns:
        bool: ``True`` when the rebuild succeeded.
    """
    index_file = os.path.join(
        os.path.dirname(__file__), "tool_index_data.json",
    )
    return await initialize_tool_embeddings_from_file(
        index_file,
        redis_client,
        api_key=api_key,
        force_recompute=True,
    )