# Source code for rag_system.auto_search

"""RAG Auto-Search Manager.

Manages per-channel auto-search configuration and provides context
injection for automatic RAG searches on user messages.

Adapted for the v3 multi-platform architecture: channel keys use
``platform:channel_id`` composite format.
"""

from __future__ import annotations

import asyncio
import jsonutil as json
import logging
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import redis.asyncio as aioredis

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Prefix of the Redis string key that holds one channel's auto-search config;
# the composite channel key ("platform:channel_id") is appended to it.
REDIS_KEY_PREFIX = "stargazer:v3:rag:auto_search:"
# Store names starting with this prefix are access-controlled cloud stores;
# the text between the prefix and the next underscore is compared against
# the requesting user's id to decide ownership.
_CLOUD_STORE_PREFIX = "cloud_usr_"
# Prefix of the Redis set key (prefix + store name) whose members are the
# channel keys a cloud store has been shared with.
_CLOUD_SHARE_PREFIX = "stargazer:v3:cloud_rag:shared:"


class RAGAutoSearchManager:
    """Per-channel RAG auto-search configuration and query fan-out, backed by Redis.

    Stores one JSON config record per channel under the
    ``stargazer:v3:rag:auto_search:`` Redis key prefix and, when a channel is
    enabled, runs a semantic search across its configured stores for every
    inbound user message so the result can be injected into the LLM context.

    The manager owns no state besides the Redis client it is given: configs
    are plain ``GET`` / ``SET`` / ``SCAN`` / ``DELETE`` operations on string
    keys, cloud-store access control is a ``SISMEMBER`` on
    ``stargazer:v3:cloud_rag:shared:`` sets, and the vector search itself is
    delegated to the store objects returned by
    :func:`rag_system.file_rag_manager.get_rag_store`.

    Constructed with a live ``redis.asyncio`` client by the RAG tool layer
    (``tools/rag.py``, ``tools/cloud_rag.py``), the web config API
    (``web/rag_config_api.py``), and the inference message pipeline, which
    calls :meth:`search_for_message` once per message.
    """

    def __init__(self, redis_client: aioredis.Redis) -> None:
        """Bind the manager to a Redis client.

        Args:
            redis_client: ``redis.asyncio`` client used for every read and
                write performed by this manager.
        """
        self._redis = redis_client

    @staticmethod
    def _config_key(channel_key: str) -> str:
        """Return the Redis key that stores the config for *channel_key*."""
        return f"{REDIS_KEY_PREFIX}{channel_key}"

    # -- config CRUD ---------------------------------------------------------

    async def set_channel_config(
        self,
        channel_key: str,
        store_names: List[str],
        enabled: bool = True,
        n_results: int = 3,
        min_score: float = 0.5,
    ) -> Dict[str, Any]:
        """Create or overwrite the auto-search config for one channel.

        The record is JSON-serialized into the single Redis key
        ``stargazer:v3:rag:auto_search:<channel_key>`` with ``SET`` and the
        change is logged at ``info`` level.

        Args:
            channel_key: Composite ``"platform:channel_id"`` identifier.
            store_names: RAG store names to search for this channel.
            enabled: Whether auto-search is active for the channel.
            n_results: Number of chunks to inject; clamped to 1-10.
            min_score: Minimum similarity to keep a chunk; clamped to 0.0-1.0.

        Returns:
            The persisted config dict (post-clamping, with a fresh UTC
            ``updated_at`` timestamp).
        """
        config: Dict[str, Any] = {
            "channel_key": channel_key,
            "store_names": store_names,
            "enabled": enabled,
            # Clamp so callers cannot persist out-of-range values.
            "n_results": min(max(1, n_results), 10),
            "min_score": min(max(0.0, min_score), 1.0),
            "updated_at": datetime.now(timezone.utc).isoformat(),
        }
        await self._redis.set(self._config_key(channel_key), json.dumps(config))
        logger.info(
            "Set RAG auto-search for %s: stores=%s enabled=%s",
            channel_key,
            store_names,
            enabled,
        )
        return config

    async def get_channel_config(
        self,
        channel_key: str,
    ) -> Optional[Dict[str, Any]]:
        """Load and decode the persisted auto-search config for a channel.

        Args:
            channel_key: Composite ``"platform:channel_id"`` identifier.

        Returns:
            The decoded config dict, or ``None`` when the key is missing or
            holds an empty value.
        """
        data = await self._redis.get(self._config_key(channel_key))
        return json.loads(data) if data else None

    async def disable_channel(self, channel_key: str) -> bool:
        """Turn off auto-search for a channel without discarding its stores.

        This is a soft toggle: the record is re-written with ``enabled`` set
        to ``False`` and a refreshed ``updated_at``; every other field
        (notably ``store_names``) is preserved so the channel can be
        re-enabled later.

        Args:
            channel_key: Composite ``"platform:channel_id"`` identifier.

        Returns:
            ``True`` if a config existed and was updated, ``False`` if there
            was nothing to disable.
        """
        config = await self.get_channel_config(channel_key)
        if not config:
            return False

        config["enabled"] = False
        config["updated_at"] = datetime.now(timezone.utc).isoformat()
        await self._redis.set(self._config_key(channel_key), json.dumps(config))
        return True

    async def remove_channel_config(self, channel_key: str) -> bool:
        """Permanently delete a channel's auto-search config from Redis.

        Unlike :meth:`disable_channel`, this discards the stored record
        entirely.

        Args:
            channel_key: Composite ``"platform:channel_id"`` identifier.

        Returns:
            ``True`` if a key was actually deleted, ``False`` if no config
            existed for the channel.
        """
        deleted = await self._redis.delete(self._config_key(channel_key))
        return deleted > 0

    async def list_configured_channels(self) -> List[Dict[str, Any]]:
        """Enumerate every channel that currently has an auto-search config.

        Walks all ``stargazer:v3:rag:auto_search:*`` keys with a ``SCAN``
        iterator and issues one ``GET`` per key.

        Returns:
            One decoded config dict per configured channel, in scan order.
            Keys whose value is empty or has vanished between the scan and
            the read are skipped.
        """
        configs: List[Dict[str, Any]] = []
        async for key in self._redis.scan_iter(f"{REDIS_KEY_PREFIX}*"):
            data = await self._redis.get(key)
            if data:
                configs.append(json.loads(data))
        return configs

    # -- search --------------------------------------------------------------

    async def _can_access_cloud_store(
        self,
        store_name: str,
        user_id: str,
        channel_key: str,
    ) -> bool:
        """Check whether *user_id* may search a ``cloud_usr_`` store.

        Access is granted when either:

        1. the user owns the store (the owner id is the segment between the
           ``cloud_usr_`` prefix and the next underscore), or
        2. the store has been shared with the channel, i.e. *channel_key* is
           a member of the ``stargazer:v3:cloud_rag:shared:<store_name>``
           Redis set.

        The check fails closed: a Redis error is logged and treated as
        "no access".

        Args:
            store_name: Store name; the caller has already verified that it
                starts with the ``cloud_usr_`` prefix.
            user_id: Author of the message, or ``""`` when unknown.
            channel_key: Composite ``"platform:channel_id"`` identifier.

        Returns:
            ``True`` if the search may proceed, ``False`` otherwise.
        """
        owner_id = store_name[len(_CLOUD_STORE_PREFIX):].split("_", 1)[0]
        # An unknown author (empty user_id) must never be treated as the
        # owner, even if the store name carries an empty owner segment.
        if user_id and user_id == owner_id:
            return True

        try:
            return bool(
                await self._redis.sismember(
                    f"{_CLOUD_SHARE_PREFIX}{store_name}",
                    channel_key,
                )
            )
        except Exception as exc:
            logger.warning(
                "Cloud RAG share lookup failed for store '%s': %s",
                store_name,
                exc,
            )
            return False

    async def search_for_message(
        self,
        channel_key: str,
        message_content: str,
        chunk_size: int = 10_000,
        query_embedding: Optional[List[float]] = None,
        user_id: str = "",
    ) -> Optional[str]:
        """Run the channel's auto-search for one message, if configured.

        Every configured store is searched concurrently; hits below the
        channel's ``min_score`` are dropped, the remainder is merged, sorted
        by similarity (best first) and the top ``n_results`` are rendered by
        :meth:`_format_rag_context`.

        Args:
            channel_key: Composite ``"platform:channel_id"`` identifier.
            message_content: The user message used as the search query.
            chunk_size: Maximum number of characters kept from each chunk.
            query_embedding: Pre-computed 3072-d embedding for
                *message_content*. When provided it is forwarded to the
                store's ``search`` call so the store can skip a redundant
                embedding call.
            user_id: The message author. Used to enforce access control on
                ``cloud_usr_`` stores.

        Returns:
            The XML-formatted RAG context string, or ``None`` when the
            message is blank, the channel is not configured or disabled, no
            stores are configured, or nothing passed the score threshold.
        """
        if not (message_content or "").strip():
            return None

        config = await self.get_channel_config(channel_key)
        if not config or not config.get("enabled"):
            return None

        # De-duplicate while keeping order: a store listed twice would be
        # searched twice and its chunks would crowd out other results.
        store_names = list(dict.fromkeys(config.get("store_names") or []))
        if not store_names:
            return None

        n_results = config.get("n_results", 3)
        # Fallback for records lacking the field; set_channel_config always
        # writes an explicit min_score.
        min_score = config.get("min_score", 0.1)

        # Imported at call time, as in the original implementation.
        from .file_rag_manager import get_rag_store

        async def _search_store(store_name: str) -> List[Dict[str, Any]]:
            """Search one store and return its hits that meet *min_score*.

            Returns an empty list when access to a ``cloud_usr_`` store is
            denied or when the store lookup/search raises, so one failing
            store never aborts the gathered fan-out.
            """
            if store_name.startswith(
                _CLOUD_STORE_PREFIX
            ) and not await self._can_access_cloud_store(
                store_name,
                user_id,
                channel_key,
            ):
                return []

            try:
                store = get_rag_store(store_name)
                # store.search is blocking, so run it in a worker thread.
                # Over-fetch (5x) because low-scoring hits are filtered below.
                results = await asyncio.to_thread(
                    store.search,
                    query=message_content,
                    n_results=n_results * 5,
                    return_content=True,
                    content_mode="chunks",
                    query_embedding=query_embedding,
                )
                hits: List[Dict[str, Any]] = []
                for r in results:
                    score = r.get("similarity_score")
                    if score is None or score < min_score:
                        continue
                    hits.append(
                        {
                            "store": store_name,
                            "filename": r.get("filename", "unknown"),
                            "file_path": r.get("file_path", ""),
                            "source_url": r.get("source_url"),
                            "similarity_score": score,
                            "chunk": (r.get("content", "") or "")[:chunk_size],
                        }
                    )
                return hits
            except Exception as e:
                logger.warning(
                    "RAG auto-search failed for store '%s': %s",
                    store_name,
                    e,
                )
                return []

        per_store = await asyncio.gather(
            *(_search_store(name) for name in store_names),
            return_exceptions=False,
        )
        all_results: List[Dict[str, Any]] = [
            hit for store_hits in per_store for hit in store_hits
        ]
        if not all_results:
            return None

        all_results.sort(
            key=lambda x: x.get("similarity_score", 0),
            reverse=True,
        )
        return self._format_rag_context(all_results[:n_results])

    # -- helpers -------------------------------------------------------------

    @staticmethod
    def _extract_relevant_chunk(
        content: str,
        query: str,
        max_size: int = 10_000,
    ) -> str:
        """Trim long content down to the paragraphs most relevant to a query.

        Content that already fits within *max_size* (after stripping) is
        returned as-is. Otherwise the text is split into paragraphs (falling
        back to lines), each scored by how many query words longer than
        three characters it contains plus a 0.5 boost for the first five
        paragraphs, and the best-scoring paragraphs are packed greedily and
        then emitted in their original order.

        Note:
            *max_size* budgets the selected text only; the ``"\\n..."`` and
            ``"[... additional content truncated ...]"`` markers are
            appended on top of it, so the result can be slightly longer.

        Args:
            content: The full text to condense.
            query: The user query whose terms drive paragraph scoring.
            max_size: Character budget for the selected text.

        Returns:
            The condensed excerpt, ``""`` for empty input, or the stripped
            content unchanged when it already fits.
        """
        if not content:
            return ""

        content = content.strip()
        if len(content) <= max_size:
            return content

        query_words = {w.lower() for w in query.split() if len(w) > 3}
        if not query_words:
            # Nothing to rank by: fall back to a plain head truncation.
            return content[:max_size].rstrip() + "\n..."

        paragraphs = content.split("\n\n")
        if len(paragraphs) == 1:
            paragraphs = content.split("\n")

        scored = []
        for i, para in enumerate(paragraphs):
            if not para.strip():
                continue
            para_lower = para.lower()
            score = sum(1 for w in query_words if w in para_lower)
            if i < 5:
                score += 0.5
            scored.append((score, i, para))

        if not scored:
            return content[:max_size].rstrip() + "\n..."

        # Best score first; ties keep document order.
        scored.sort(key=lambda x: (-x[0], x[1]))

        selected = []
        current_size = 0
        for _score, idx, para in scored:
            para = para.strip()
            # +2 accounts for the "\n\n" separator used when joining.
            if current_size + len(para) + 2 <= max_size:
                selected.append((idx, para))
                current_size += len(para) + 2

        if not selected:
            # Even the best paragraph is over budget: truncate it.
            return scored[0][2][:max_size].rstrip() + "\n..."

        selected.sort(key=lambda x: x[0])
        result = "\n\n".join(p for _, p in selected)
        if len(result) < len(content) - 100:
            result += "\n\n[... additional content truncated ...]"
        return result

    @staticmethod
    def _read_whole_file_hint(store: str, filename: str) -> str:
        """Build the XML-safe ``rag_read_store_file(...)`` hint for a result.

        Args:
            store: Name of the RAG store holding the file.
            filename: Indexed filename to reference in the hint.

        Returns:
            The ``rag_read_store_file(...)`` call string, XML-escaped so it
            can be embedded as an attribute value.
        """
        return RAGAutoSearchManager._escape_xml(
            f"rag_read_store_file(store_name='{store}', filename='{filename}')"
        )

    @staticmethod
    def _format_rag_context(results: List[Dict[str, Any]]) -> str:
        """Format search results as an XML string for LLM context injection.

        Each result becomes one ``rag_result`` element. ``object_storage_uri``
        is added when ``file_path`` looks like a URI (contains ``://``) and
        ``source_url`` when present; ``read_whole_file`` names the tool call
        for the full file text. Attribute values are XML-escaped; the chunk
        text itself is emitted verbatim.

        Args:
            results: Result dicts as built by :meth:`search_for_message`.

        Returns:
            The preamble comment followed by the ``rag_result`` elements, or
            ``""`` when *results* is empty.
        """
        if not results:
            return ""

        escape = RAGAutoSearchManager._escape_xml
        preamble = (
            "<!-- RAG excerpts are retrieval chunks; use rag_read_store_file "
            "for full file text. -->\n"
        )

        # NOTE(review): the leading whitespace inside the rag_result literals
        # is reproduced as visible in the rendered source; confirm against
        # the repository file.
        xml_parts: List[str] = []
        for r in results:
            store = str(r.get("store", ""))
            filename = str(r.get("filename", ""))
            attrs = [
                f'store="{escape(store)}"',
                f'filename="{escape(filename)}"',
                f'similarity="{r.get("similarity_score", 0)}"',
                f'read_whole_file="'
                f'{RAGAutoSearchManager._read_whole_file_hint(store, filename)}"',
            ]

            uri = r.get("file_path", "")
            if uri and "://" in uri:
                attrs.append(f'object_storage_uri="{escape(str(uri))}"')

            source_url = r.get("source_url")
            if source_url:
                attrs.append(f'source_url="{escape(str(source_url))}"')

            attr_text = " ".join(attrs)
            xml_parts.append(
                f" <rag_result {attr_text}>\n"
                f'{r.get("chunk", "")}\n'
                f" </rag_result>"
            )

        return preamble + "\n".join(xml_parts)

    @staticmethod
    def _escape_xml(text: str) -> str:
        """Escape the five XML special characters in a string.

        Args:
            text: Raw text to escape.

        Returns:
            *text* with ``&``, ``<``, ``>``, ``"`` and ``'`` replaced by
            their entity forms, or ``""`` for empty input.
        """
        if not text:
            return ""
        # "&" must be replaced first so the entities added below are not
        # escaped a second time.
        return (
            text.replace("&", "&amp;")
            .replace("<", "&lt;")
            .replace(">", "&gt;")
            .replace('"', "&quot;")
            .replace("'", "&apos;")
        )