# Source code for rag_system.auto_search

"""RAG Auto-Search Manager.

Manages per-channel auto-search configuration and provides context
injection for automatic RAG searches on user messages.

Adapted for the v3 multi-platform architecture: channel keys use
``platform:channel_id`` composite format.
"""

from __future__ import annotations

import asyncio
import json
import logging
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import redis.asyncio as aioredis

logger = logging.getLogger(__name__)

REDIS_KEY_PREFIX = "stargazer:v3:rag:auto_search:"
_CLOUD_STORE_PREFIX = "cloud_usr_"
_CLOUD_SHARE_PREFIX = "stargazer:v3:cloud_rag:shared:"


class RAGAutoSearchManager:
    """Per-channel auto-search configuration backed by Redis.

    Each channel's configuration is stored as a JSON document under
    ``REDIS_KEY_PREFIX + channel_key``, where ``channel_key`` is the
    composite ``"platform:channel_id"`` format used by the v3
    multi-platform architecture.
    """

    def __init__(self, redis_client: aioredis.Redis) -> None:
        """Initialize the instance.

        Args:
            redis_client (aioredis.Redis): Async Redis connection client
                used for all configuration reads and writes.
        """
        self._redis = redis_client
# -- config CRUD ---------------------------------------------------------
[docs] async def set_channel_config( self, channel_key: str, store_names: List[str], enabled: bool = True, n_results: int = 3, min_score: float = 0.5, ) -> Dict[str, Any]: """Set auto-search configuration for a channel. Parameters ---------- channel_key: Composite key ``"platform:channel_id"``. """ config: Dict[str, Any] = { "channel_key": channel_key, "store_names": store_names, "enabled": enabled, "n_results": min(max(1, n_results), 10), "min_score": min(max(0.0, min_score), 1.0), "updated_at": datetime.now(timezone.utc).isoformat(), } key = f"{REDIS_KEY_PREFIX}{channel_key}" await self._redis.set(key, json.dumps(config)) logger.info( "Set RAG auto-search for %s: stores=%s enabled=%s", channel_key, store_names, enabled, ) return config
[docs] async def get_channel_config( self, channel_key: str, ) -> Optional[Dict[str, Any]]: """Retrieve the channel config. Args: channel_key (str): The channel key value. Returns: Optional[Dict[str, Any]]: The result. """ key = f"{REDIS_KEY_PREFIX}{channel_key}" data = await self._redis.get(key) return json.loads(data) if data else None
[docs] async def disable_channel(self, channel_key: str) -> bool: """Disable channel. Args: channel_key (str): The channel key value. Returns: bool: True on success, False otherwise. """ config = await self.get_channel_config(channel_key) if not config: return False config["enabled"] = False config["updated_at"] = datetime.now(timezone.utc).isoformat() key = f"{REDIS_KEY_PREFIX}{channel_key}" await self._redis.set(key, json.dumps(config)) return True
[docs] async def remove_channel_config(self, channel_key: str) -> bool: """Delete the specified channel config. Args: channel_key (str): The channel key value. Returns: bool: True on success, False otherwise. """ key = f"{REDIS_KEY_PREFIX}{channel_key}" return (await self._redis.delete(key)) > 0
[docs] async def list_configured_channels(self) -> List[Dict[str, Any]]: """List configured channels. Returns: List[Dict[str, Any]]: The result. """ configs: List[Dict[str, Any]] = [] async for key in self._redis.scan_iter(f"{REDIS_KEY_PREFIX}*"): data = await self._redis.get(key) if data: configs.append(json.loads(data)) return configs
# -- search -------------------------------------------------------------- async def _can_access_cloud_store( self, store_name: str, user_id: str, channel_key: str, ) -> bool: """Check if *user_id* may search a ``cloud_usr_`` store. Access is granted when: 1. The user owns the store (user_id embedded in the name), OR 2. The store has been shared with the channel via Redis set. """ owner_id = store_name[len(_CLOUD_STORE_PREFIX):].split("_", 1)[0] if user_id == owner_id: return True try: return bool(await self._redis.sismember( f"{_CLOUD_SHARE_PREFIX}{store_name}", channel_key, )) except Exception: return False
[docs] async def search_for_message( self, channel_key: str, message_content: str, chunk_size: int = 10_000, query_embedding: list[float] | None = None, user_id: str = "", ) -> Optional[str]: """Perform auto-search if the channel is configured. Parameters ---------- query_embedding: Pre-computed embedding for *message_content*. When provided it is forwarded to ChromaDB to skip its internal embedding call. user_id: The message author. Used to enforce access control on ``cloud_usr_`` stores. Returns XML-formatted RAG context string, or ``None``. """ config = await self.get_channel_config(channel_key) if not config or not config.get("enabled"): return None store_names = config.get("store_names", []) if not store_names: return None n_results = config.get("n_results", 3) min_score = config.get("min_score", 0.1) from .file_rag_manager import get_rag_store async def _search_store(store_name: str) -> List[Dict[str, Any]]: if store_name.startswith(_CLOUD_STORE_PREFIX): if not await self._can_access_cloud_store( store_name, user_id, channel_key, ): return [] try: store = get_rag_store(store_name) results = await asyncio.to_thread( store.search, query=message_content, n_results=n_results * 2, return_content=True, query_embedding=query_embedding, ) store_results: List[Dict[str, Any]] = [] for r in results: score = r.get("similarity_score", 0) if score and score >= min_score: content = r.get("content", "") chunk = self._extract_relevant_chunk( content, message_content, max_size=chunk_size, ) store_results.append({ "store": store_name, "filename": r.get("filename", "unknown"), "file_path": r.get("file_path", ""), "source_url": r.get("source_url"), "similarity_score": score, "chunk": chunk, }) return store_results except Exception as e: logger.warning( "RAG auto-search failed for store '%s': %s", store_name, e, ) return [] search_results = await asyncio.gather( *(_search_store(s) for s in store_names), return_exceptions=False, ) all_results: List[Dict[str, Any]] = [] for 
store_results in search_results: all_results.extend(store_results) if not all_results: return None all_results.sort( key=lambda x: x.get("similarity_score", 0), reverse=True, ) return self._format_rag_context(all_results[:n_results])
# -- helpers ------------------------------------------------------------- @staticmethod def _extract_relevant_chunk( content: str, query: str, max_size: int = 10_000, ) -> str: """Internal helper: extract relevant chunk. Args: content (str): Content data. query (str): Search query or input string. max_size (int): The max size value. Returns: str: Result string. """ if not content: return "" content = content.strip() if len(content) <= max_size: return content query_words = {w.lower() for w in query.split() if len(w) > 3} if not query_words: return content[:max_size].rstrip() + "\n..." paragraphs = content.split("\n\n") if len(paragraphs) == 1: paragraphs = content.split("\n") scored = [] for i, para in enumerate(paragraphs): if not para.strip(): continue para_lower = para.lower() score = sum(1 for w in query_words if w in para_lower) if i < 5: score += 0.5 scored.append((score, i, para)) if not scored: return content[:max_size].rstrip() + "\n..." scored.sort(key=lambda x: (-x[0], x[1])) selected = [] current_size = 0 for _score, idx, para in scored: para = para.strip() if current_size + len(para) + 2 <= max_size: selected.append((idx, para)) current_size += len(para) + 2 if not selected: return scored[0][2][:max_size].rstrip() + "\n..." selected.sort(key=lambda x: x[0]) result = "\n\n".join(p for _, p in selected) if len(result) < len(content) - 100: result += "\n\n[... additional content truncated ...]" return result @staticmethod def _format_rag_context(results: List[Dict[str, Any]]) -> str: """Format search results as an XML string for LLM context injection. Includes ``object_storage_uri`` when present so the LLM can reference the full cloud file location. 
""" xml_parts: List[str] = [] for r in results: attrs = ( f'store="{r.get("store", "")}" ' f'filename="{r.get("filename", "")}" ' f'similarity="{r.get("similarity_score", 0)}"' ) uri = r.get("file_path", "") if uri and "://" in uri: attrs += f' object_storage_uri="{uri}"' source_url = r.get("source_url") if source_url: attrs += f' source_url="{source_url}"' xml_parts.append( f" <rag_result {attrs}>\n" f'{r.get("chunk", "")}\n' f" </rag_result>" ) return "\n".join(xml_parts) @staticmethod def _escape_xml(text: str) -> str: """Internal helper: escape xml. Args: text (str): Text content. Returns: str: Result string. """ if not text: return "" return ( text.replace("&", "&amp;") .replace("<", "&lt;") .replace(">", "&gt;") .replace('"', "&quot;") .replace("'", "&apos;") )