"""RAG Auto-Search Manager.
Manages per-channel auto-search configuration and provides context
injection for automatic RAG searches on user messages.
Adapted for the v3 multi-platform architecture: channel keys use
``platform:channel_id`` composite format.
"""
from __future__ import annotations
import asyncio
import json
import logging
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
import redis.asyncio as aioredis
logger = logging.getLogger(__name__)
REDIS_KEY_PREFIX = "stargazer:v3:rag:auto_search:"
_CLOUD_STORE_PREFIX = "cloud_usr_"
_CLOUD_SHARE_PREFIX = "stargazer:v3:cloud_rag:shared:"
[docs]
class RAGAutoSearchManager:
"""Per-channel auto-search configuration backed by Redis."""
[docs]
def __init__(self, redis_client: aioredis.Redis) -> None:
"""Initialize the instance.
Args:
redis_client (aioredis.Redis): Redis connection client.
"""
self._redis = redis_client
# -- config CRUD ---------------------------------------------------------
[docs]
async def set_channel_config(
self,
channel_key: str,
store_names: List[str],
enabled: bool = True,
n_results: int = 3,
min_score: float = 0.5,
) -> Dict[str, Any]:
"""Set auto-search configuration for a channel.
Parameters
----------
channel_key:
Composite key ``"platform:channel_id"``.
"""
config: Dict[str, Any] = {
"channel_key": channel_key,
"store_names": store_names,
"enabled": enabled,
"n_results": min(max(1, n_results), 10),
"min_score": min(max(0.0, min_score), 1.0),
"updated_at": datetime.now(timezone.utc).isoformat(),
}
key = f"{REDIS_KEY_PREFIX}{channel_key}"
await self._redis.set(key, json.dumps(config))
logger.info(
"Set RAG auto-search for %s: stores=%s enabled=%s",
channel_key, store_names, enabled,
)
return config
[docs]
async def get_channel_config(
self, channel_key: str,
) -> Optional[Dict[str, Any]]:
"""Retrieve the channel config.
Args:
channel_key (str): The channel key value.
Returns:
Optional[Dict[str, Any]]: The result.
"""
key = f"{REDIS_KEY_PREFIX}{channel_key}"
data = await self._redis.get(key)
return json.loads(data) if data else None
[docs]
async def disable_channel(self, channel_key: str) -> bool:
"""Disable channel.
Args:
channel_key (str): The channel key value.
Returns:
bool: True on success, False otherwise.
"""
config = await self.get_channel_config(channel_key)
if not config:
return False
config["enabled"] = False
config["updated_at"] = datetime.now(timezone.utc).isoformat()
key = f"{REDIS_KEY_PREFIX}{channel_key}"
await self._redis.set(key, json.dumps(config))
return True
[docs]
async def remove_channel_config(self, channel_key: str) -> bool:
"""Delete the specified channel config.
Args:
channel_key (str): The channel key value.
Returns:
bool: True on success, False otherwise.
"""
key = f"{REDIS_KEY_PREFIX}{channel_key}"
return (await self._redis.delete(key)) > 0
# -- search --------------------------------------------------------------
async def _can_access_cloud_store(
self,
store_name: str,
user_id: str,
channel_key: str,
) -> bool:
"""Check if *user_id* may search a ``cloud_usr_`` store.
Access is granted when:
1. The user owns the store (user_id embedded in the name), OR
2. The store has been shared with the channel via Redis set.
"""
owner_id = store_name[len(_CLOUD_STORE_PREFIX):].split("_", 1)[0]
if user_id == owner_id:
return True
try:
return bool(await self._redis.sismember(
f"{_CLOUD_SHARE_PREFIX}{store_name}", channel_key,
))
except Exception:
return False
[docs]
    async def search_for_message(
        self,
        channel_key: str,
        message_content: str,
        chunk_size: int = 10_000,
        query_embedding: list[float] | None = None,
        user_id: str = "",
    ) -> Optional[str]:
        """Perform auto-search if the channel is configured.

        All configured stores are queried concurrently; hits are
        filtered by the channel's ``min_score``, merged, sorted by
        similarity, and the top ``n_results`` are formatted as XML.

        Parameters
        ----------
        channel_key:
            Composite key ``"platform:channel_id"``.
        message_content:
            The user message, used as the search query.
        chunk_size:
            Maximum characters per extracted chunk.
        query_embedding:
            Pre-computed embedding for *message_content*. When provided
            it is forwarded to ChromaDB to skip its internal embedding
            call.
        user_id:
            The message author. Used to enforce access control on
            ``cloud_usr_`` stores.

        Returns XML-formatted RAG context string, or ``None`` when the
        channel is unconfigured/disabled or nothing matched.
        """
        config = await self.get_channel_config(channel_key)
        if not config or not config.get("enabled"):
            return None
        store_names = config.get("store_names", [])
        if not store_names:
            return None
        n_results = config.get("n_results", 3)
        # NOTE(review): fallback 0.1 differs from set_channel_config's
        # 0.5 default — confirm whether this is intentional for configs
        # stored without a min_score.
        min_score = config.get("min_score", 0.1)
        # Imported lazily — presumably to avoid a circular import; verify.
        from .file_rag_manager import get_rag_store

        async def _search_store(store_name: str) -> List[Dict[str, Any]]:
            # Per-store search. Returns [] on access denial or any error,
            # so a single failing store never breaks the whole search.
            if store_name.startswith(_CLOUD_STORE_PREFIX):
                if not await self._can_access_cloud_store(
                    store_name, user_id, channel_key,
                ):
                    return []
            try:
                store = get_rag_store(store_name)
                # store.search runs in a worker thread via to_thread.
                # Over-fetch (n_results * 2) to leave headroom for the
                # min_score filter below.
                results = await asyncio.to_thread(
                    store.search,
                    query=message_content,
                    n_results=n_results * 2,
                    return_content=True,
                    query_embedding=query_embedding,
                )
                store_results: List[Dict[str, Any]] = []
                for r in results:
                    score = r.get("similarity_score", 0)
                    # Drop hits below threshold; zero/missing scores are
                    # also discarded by the truthiness check.
                    if score and score >= min_score:
                        content = r.get("content", "")
                        chunk = self._extract_relevant_chunk(
                            content, message_content, max_size=chunk_size,
                        )
                        store_results.append({
                            "store": store_name,
                            "filename": r.get("filename", "unknown"),
                            "file_path": r.get("file_path", ""),
                            "source_url": r.get("source_url"),
                            "similarity_score": score,
                            "chunk": chunk,
                        })
                return store_results
            except Exception as e:
                # Best effort: log and continue with the other stores.
                logger.warning(
                    "RAG auto-search failed for store '%s': %s",
                    store_name, e,
                )
                return []

        # Fan out to all stores concurrently. Each _search_store call
        # handles its own errors, so return_exceptions=False is safe.
        search_results = await asyncio.gather(
            *(_search_store(s) for s in store_names),
            return_exceptions=False,
        )
        all_results: List[Dict[str, Any]] = []
        for store_results in search_results:
            all_results.extend(store_results)
        if not all_results:
            return None
        # Best matches across all stores first, then keep the top N.
        all_results.sort(
            key=lambda x: x.get("similarity_score", 0), reverse=True,
        )
        return self._format_rag_context(all_results[:n_results])
# -- helpers -------------------------------------------------------------
@staticmethod
def _extract_relevant_chunk(
content: str, query: str, max_size: int = 10_000,
) -> str:
"""Internal helper: extract relevant chunk.
Args:
content (str): Content data.
query (str): Search query or input string.
max_size (int): The max size value.
Returns:
str: Result string.
"""
if not content:
return ""
content = content.strip()
if len(content) <= max_size:
return content
query_words = {w.lower() for w in query.split() if len(w) > 3}
if not query_words:
return content[:max_size].rstrip() + "\n..."
paragraphs = content.split("\n\n")
if len(paragraphs) == 1:
paragraphs = content.split("\n")
scored = []
for i, para in enumerate(paragraphs):
if not para.strip():
continue
para_lower = para.lower()
score = sum(1 for w in query_words if w in para_lower)
if i < 5:
score += 0.5
scored.append((score, i, para))
if not scored:
return content[:max_size].rstrip() + "\n..."
scored.sort(key=lambda x: (-x[0], x[1]))
selected = []
current_size = 0
for _score, idx, para in scored:
para = para.strip()
if current_size + len(para) + 2 <= max_size:
selected.append((idx, para))
current_size += len(para) + 2
if not selected:
return scored[0][2][:max_size].rstrip() + "\n..."
selected.sort(key=lambda x: x[0])
result = "\n\n".join(p for _, p in selected)
if len(result) < len(content) - 100:
result += "\n\n[... additional content truncated ...]"
return result
@staticmethod
def _format_rag_context(results: List[Dict[str, Any]]) -> str:
"""Format search results as an XML string for LLM context injection.
Includes ``object_storage_uri`` when present so the LLM can
reference the full cloud file location.
"""
xml_parts: List[str] = []
for r in results:
attrs = (
f'store="{r.get("store", "")}" '
f'filename="{r.get("filename", "")}" '
f'similarity="{r.get("similarity_score", 0)}"'
)
uri = r.get("file_path", "")
if uri and "://" in uri:
attrs += f' object_storage_uri="{uri}"'
source_url = r.get("source_url")
if source_url:
attrs += f' source_url="{source_url}"'
xml_parts.append(
f" <rag_result {attrs}>\n"
f'{r.get("chunk", "")}\n'
f" </rag_result>"
)
return "\n".join(xml_parts)
@staticmethod
def _escape_xml(text: str) -> str:
"""Internal helper: escape xml.
Args:
text (str): Text content.
Returns:
str: Result string.
"""
if not text:
return ""
return (
text.replace("&", "&")
.replace("<", "<")
.replace(">", ">")
.replace('"', """)
.replace("'", "'")
)