Source code for rag_system.file_rag_manager

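"""File-based RAG Manager.

Manages file indexing and retrieval using ChromaDB and OpenRouter embeddings.
Search results are grouped per file: the best-matching chunks of each file
are merged (up to a size limit) so callers get coherent per-file context.

Features:
- URL fetching support
- PDF parsing via PyMuPDF (with compression of oversized PDFs)
- Chunked embeddings for better semantic matching
- Per-file search results assembled from the matching chunks
"""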

import hashlib
import json
import logging
import os
import re
from collections import OrderedDict
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Collection, Dict, List, Optional, Tuple
from urllib.parse import unquote, urlparse

import chromadb
import httpx
from chromadb.config import Settings

from .openrouter_embeddings import SyncOpenRouterEmbeddings

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)

# Directory two levels above this source file; the default RAG store
# location is built relative to it.
_PROJECT_ROOT = str(Path(os.path.abspath(__file__)).parent.parent)


def _sanitize_collection_name(name: str) -> str:
    """Return a ChromaDB-safe collection name derived from *name*.

    ChromaDB requires 3-512 characters, only ``[a-zA-Z0-9._-]``, and the
    name must start and end with ``[a-zA-Z0-9]``.

    Processing steps:

    - Every disallowed character becomes ``_``.
    - Runs of ``_`` collapse to one.
    - Leading/trailing ``.``, ``_`` and ``-`` are stripped.
    - A result shorter than 3 characters gets a ``store_`` prefix; an
      empty result becomes ``store_default``.
    - Anything longer than 500 characters is truncated to at most 500.
    """
    cleaned = re.sub(r"[^a-zA-Z0-9._-]", "_", name)
    cleaned = re.sub(r"_+", "_", cleaned).strip("._-")

    if not cleaned:
        return "store_default"
    if len(cleaned) < 3:
        return "store_" + cleaned
    if len(cleaned) > 500:
        # Re-strip so the truncated name still ends on an alphanumeric.
        cleaned = cleaned[:500].rstrip("._-")
    return cleaned


# Default parent directory for stores (one sub-directory per store).
DEFAULT_STORE_PATH = os.path.join(_PROJECT_ROOT, "rag_stores")
# Sub-directory inside a store where downloaded source files are saved.
STORE_FILES_SUBDIR = "files"

# File extensions accepted for indexing.
SUPPORTED_EXTENSIONS = {
    # Plain text, documentation and documents
    ".txt", ".md", ".rst", ".tex", ".bib", ".pdf",
    # Programming languages
    ".py", ".js", ".ts", ".jsx", ".tsx",
    ".c", ".cpp", ".h", ".hpp", ".rs", ".go", ".java", ".kt",
    ".r", ".R", ".rmd", ".lua", ".rb", ".php", ".pl", ".pm",
    # Shell scripts
    ".sh", ".bash", ".zsh", ".fish",
    # Data, configuration and query languages
    ".json", ".yaml", ".yml", ".toml", ".ini", ".cfg",
    ".xml", ".csv", ".sql", ".graphql",
    # Web markup and styling
    ".html", ".css", ".scss", ".less",
    # Build and tooling files
    ".dockerfile", ".makefile", ".env", ".gitignore", ".dockerignore",
}

MAX_FILE_SIZE = 15 * 1024 ** 2  # 15 MB
DEFAULT_CHUNK_SIZE = 1_500  # characters per chunk
DEFAULT_CHUNK_OVERLAP = 200  # characters shared by consecutive chunks
CHROMA_MAX_BATCH = 5_000  # maximum records per ChromaDB upsert call


# ---------------------------------------------------------------------------
# PDF helpers
# ---------------------------------------------------------------------------

def extract_pdf_text(file_path: str) -> Optional[str]:
    """Extract text from a PDF file using PyMuPDF.

    Every page that yields non-blank text is prefixed with a
    ``--- Page N ---`` marker (1-based page number). Pages are joined with a
    blank line.

    Args:
        file_path: Path of the PDF to read.

    Returns:
        The extracted text, or ``None`` in any of these cases:

        - PyMuPDF is not installed.
        - The file cannot be opened or parsed.
        - No page contains text.

        Failures are logged, never raised.
    """
    try:
        import fitz  # PyMuPDF
    except ImportError:
        logger.error(
            "PyMuPDF (fitz) not installed. Install with: pip install pymupdf",
        )
        return None

    text_parts: List[str] = []
    try:
        # The context manager closes the document even when extraction
        # fails part-way through (the previous version leaked it).
        with fitz.open(file_path) as doc:
            for page_num, page in enumerate(doc, 1):
                page_text = page.get_text()
                if page_text.strip():
                    text_parts.append(f"--- Page {page_num} ---\n{page_text}")
    except Exception as e:
        logger.error("Failed to extract PDF text from %s: %s", file_path, e)
        return None

    if not text_parts:
        logger.warning("No text extracted from PDF: %s", file_path)
        return None
    return "\n\n".join(text_parts)
def compress_pdf(
    file_path: str,
    output_path: Optional[str] = None,
    remove_images: bool = True,
) -> Tuple[str, int, int]:
    """Compress a PDF using PyMuPDF's ``ez_save()``.

    Args:
        file_path: Source PDF.
        output_path: Destination path. When omitted (or pointing at
            *file_path*) the source is rewritten in place: the result is
            saved to a temporary sibling file and then moved over the
            original, because PyMuPDF does not allow a non-incremental save
            onto the file the document was opened from.
        remove_images: Delete embedded images before saving. Images that
            cannot be deleted are skipped (logged at debug level).

    Returns:
        ``(output_path, original_size, compressed_size)``, sizes in bytes.

    Raises:
        ImportError: If PyMuPDF is not installed.
        OSError: If *file_path* cannot be stat'ed or the output moved.
        Exception: Any error raised by PyMuPDF while opening or saving.
    """
    import fitz  # PyMuPDF

    original_size = os.path.getsize(file_path)
    out_path = output_path or file_path
    in_place = os.path.abspath(out_path) == os.path.abspath(file_path)
    save_path = out_path + ".tmp" if in_place else out_path

    doc = fitz.open(file_path)
    try:
        if remove_images:
            images_removed = 0
            for page in doc:
                for img in page.get_images(full=True):
                    try:
                        page.delete_image(img[0])
                        images_removed += 1
                    except Exception as e:
                        # Best effort: skip images PyMuPDF refuses to delete.
                        logger.debug(
                            "Could not delete image xref %s: %s", img[0], e,
                        )
            if images_removed:
                logger.info("Removed %d images from PDF", images_removed)
        doc.ez_save(save_path)
    except BaseException:
        if in_place:
            # Do not leave a half-written temporary file behind.
            try:
                os.remove(save_path)
            except OSError:
                pass
        raise
    finally:
        doc.close()

    if in_place:
        # The document is closed, so the original can be replaced safely.
        os.replace(save_path, out_path)

    compressed_size = os.path.getsize(out_path)
    reduction = (
        100 - (compressed_size * 100 // original_size) if original_size else 0
    )
    logger.info(
        "Compressed PDF: %d -> %d bytes (%d%% reduction)",
        original_size, compressed_size, reduction,
    )
    return out_path, original_size, compressed_size
# ---------------------------------------------------------------------------
# Text chunking
# ---------------------------------------------------------------------------
def chunk_text(
    text: str,
    chunk_size: int = DEFAULT_CHUNK_SIZE,
    overlap: int = DEFAULT_CHUNK_OVERLAP,
) -> List[str]:
    """Split *text* into overlapping chunks on paragraph/sentence boundaries.

    Splitting strategy:

    - Paragraphs (separated by blank lines) are packed greedily into chunks
      of at most *chunk_size* characters.
    - A paragraph longer than *chunk_size* is packed sentence by sentence.
    - A single sentence longer than *chunk_size* is cut into fixed windows
      of *chunk_size* characters that advance by ``chunk_size - overlap``.

    When a packed chunk is emitted, its last *overlap* characters are
    carried into the next chunk, as long as the result still fits within
    *chunk_size*.

    Args:
        text: Text to split.
        chunk_size: Maximum chunk length in characters.
        overlap: Characters repeated between consecutive chunks; clamped to
            ``[0, chunk_size - 1]``.

    Returns:
        ``[text]`` if *text* already fits; otherwise the list of non-empty
        chunks (packed chunks are whitespace-stripped).

    Raises:
        ValueError: If *text* does not fit and ``chunk_size < 1``.
    """
    if len(text) <= chunk_size:
        return [text]
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    # The stride of the fixed-window split must stay positive (an overlap
    # >= chunk_size used to raise or yield nothing), and a zero overlap must
    # carry nothing forward (``s[-0:]`` is the *whole* string).
    overlap = max(0, min(overlap, chunk_size - 1))
    step = chunk_size - overlap

    chunks: List[str] = []
    current = ""
    # True while ``current`` holds text not yet emitted, as opposed to only
    # the overlap carried over from the previous chunk.
    has_new = False

    def _tail(s: str) -> str:
        """Return the overlap to carry forward from *s* ("" if too short)."""
        return s[-overlap:] if overlap and len(s) > overlap else ""

    def _flush() -> None:
        """Emit the pending chunk (if it has new text); seed the next one."""
        nonlocal current, has_new
        if has_new and current.strip():
            chunks.append(current.strip())
            current = _tail(current)
        else:
            current = ""
        has_new = False

    def _add(piece: str, sep: str) -> None:
        """Append *piece* joined by *sep*, starting a new chunk if needed."""
        nonlocal current, has_new
        if current and len(current) + len(sep) + len(piece) > chunk_size:
            _flush()
        if current and len(current) + len(sep) + len(piece) <= chunk_size:
            current += sep + piece
        else:
            # Nothing pending, or the carried overlap leaves no room.
            current = piece
        has_new = True

    for para in re.split(r"\n\n+", text):
        if len(para) <= chunk_size:
            _add(para, "\n\n")
            continue
        # Oversized paragraph: pack it sentence by sentence instead. Its
        # first sentence is still separated from earlier text by a blank
        # line; later sentences are joined with a space.
        sep = "\n\n"
        for sentence in re.split(r"(?<=[.!?])\s+", para):
            if len(sentence) <= chunk_size:
                _add(sentence, sep)
            else:
                # Even one sentence is too long: emit what is pending, then
                # cut the sentence into overlapping fixed-size windows.
                _flush()
                chunks.extend(
                    sentence[i:i + chunk_size]
                    for i in range(0, len(sentence), step)
                )
                current = _tail(sentence)
                has_new = False
            sep = " "

    if has_new and current.strip():
        chunks.append(current.strip())
    return chunks
# ---------------------------------------------------------------------------
# URL fetching
# ---------------------------------------------------------------------------
def _filename_from_content_disposition(header: str) -> Optional[str]:
    """Return the filename advertised by a ``Content-Disposition`` header.

    The RFC 6266 / RFC 5987 extended form ``filename*=charset'lang'value``
    is preferred and percent-decoded with the declared charset. Otherwise
    the plain ``filename=`` parameter is used (quoted or unquoted;
    percent-decoded).

    Returns:
        The filename, or ``None`` if neither parameter is present or its
        value is empty.
    """
    ext_match = re.search(r"filename\*\s*=\s*([^;]+)", header, re.IGNORECASE)
    if ext_match:
        value = ext_match.group(1).strip().strip('"')
        charset, sep1, rest = value.partition("'")
        _lang, sep2, encoded = rest.partition("'")
        if sep1 and sep2:
            try:
                name = unquote(
                    encoded, encoding=charset or "utf-8", errors="strict",
                )
            except (LookupError, UnicodeDecodeError):
                # Unknown charset or invalid bytes: fall back to UTF-8.
                name = unquote(encoded)
        else:
            name = unquote(value)
        if name:
            return name

    # ``filename\s*=`` cannot match ``filename*=``, so this only finds the
    # plain parameter.
    plain_match = re.search(
        r'filename\s*=\s*(?:"([^"]*)"|([^;]+))', header, re.IGNORECASE,
    )
    if plain_match:
        raw = plain_match.group(1)
        if raw is None:
            raw = plain_match.group(2)
        name = unquote(raw.strip())
        if name:
            return name
    return None


async def fetch_url_content(
    url: str,
    timeout: float = 30.0,
) -> Tuple[Optional[bytes], Optional[str], Optional[str]]:
    """Fetch content from *url*, following redirects.

    The filename is chosen in this order:

    1. The ``Content-Disposition`` header, if it names a file.
    2. The last segment of the URL path, if it contains a dot.
    3. A generic name based on the content type, or ``downloaded_file``
       if the type is unknown.

    Args:
        url: URL to download.
        timeout: Request timeout in seconds, passed to ``httpx``.

    Returns:
        ``(bytes, content_type, filename)``. *content_type* is the media
        type without parameters (possibly ``""``). On any error the failure
        is logged and ``(None, None, None)`` is returned.
    """
    try:
        async with httpx.AsyncClient(
            timeout=timeout,
            follow_redirects=True,
        ) as client:
            response = await client.get(url)
            response.raise_for_status()
            content = response.content
            content_type = (
                response.headers.get("content-type", "")
                .split(";")[0].strip()
            )

            filename = _filename_from_content_disposition(
                response.headers.get("content-disposition", ""),
            )
            if not filename:
                # Fall back to the last URL path segment, but only when it
                # looks like a file name (contains a dot).
                path = unquote(urlparse(url).path)
                last_segment = path.rsplit("/", 1)[-1] if "/" in path else ""
                filename = last_segment if "." in last_segment else None
            if not filename:
                ext_map = {
                    "application/pdf": "document.pdf",
                    "text/plain": "document.txt",
                    "application/json": "data.json",
                    "text/html": "page.html",
                }
                filename = ext_map.get(content_type, "downloaded_file")

            return content, content_type, filename
    except Exception as e:
        logger.error("Error fetching %s: %s", url, e)
        return None, None, None
# ---------------------------------------------------------------------------
# FileRAGManager
# ---------------------------------------------------------------------------
class FileRAGManager:
    """File-based RAG with ChromaDB storage and OpenRouter embeddings.

    Each instance manages one ChromaDB collection named
    ``files_<sanitized store name>``, persisted under
    ``<store_path>/<sanitized store name>``. Local files and downloaded URLs
    are indexed either whole or as chunks. Every chunk carries the
    file-level metadata, so :meth:`search` can group hits back per file.
    """

    def __init__(
        self,
        store_name: str = "default",
        store_path: Optional[str] = None,
        api_key: Optional[str] = None,
        embedding_model: str = "google/gemini-embedding-001",
        max_file_size: int = MAX_FILE_SIZE,
        gemini_only: bool = True,
        document_task_type: Optional[str] = None,
        query_task_type: Optional[str] = None,
    ):
        """Open (creating if needed) the on-disk store and its collection.

        Args:
            store_name: Logical store name; sanitized for the directory and
                collection names.
            store_path: Parent directory for stores (defaults to
                ``DEFAULT_STORE_PATH``).
            api_key: API key forwarded to the embedding function.
            embedding_model: Embedding model identifier.
            max_file_size: Maximum accepted file size in bytes.
            gemini_only: Use only the Gemini API for embeddings.
            document_task_type: Optional Gemini task type for indexed text
                (e.g. ``RETRIEVAL_DOCUMENT``).
            query_task_type: Optional Gemini task type for search queries
                (e.g. ``RETRIEVAL_QUERY``).
        """
        self.store_name = store_name
        self.store_path = store_path or DEFAULT_STORE_PATH
        self.embedding_model = embedding_model
        self.max_file_size = max_file_size

        self._sanitized_name = _sanitize_collection_name(store_name)
        self._collection_name = f"files_{self._sanitized_name}"
        self.db_path = os.path.join(self.store_path, self._sanitized_name)
        os.makedirs(self.db_path, exist_ok=True)
        # Downloaded URL content is saved here (see index_url / read_store_file).
        self.files_path = os.path.join(self.db_path, STORE_FILES_SUBDIR)
        os.makedirs(self.files_path, exist_ok=True)

        self.embedding_fn = SyncOpenRouterEmbeddings(
            api_key=api_key,
            model=embedding_model,
            gemini_only=gemini_only,
            document_task_type=document_task_type,
            query_task_type=query_task_type,
        )

        # Use the shared client registry (prevents singleton collision).
        from chroma_registry import get_client

        self.client = get_client(
            self.db_path,
            settings=Settings(anonymized_telemetry=False, allow_reset=True),
        )
        self.collection = self.client.get_or_create_collection(
            name=self._collection_name,
            embedding_function=self.embedding_fn,
            metadata=self._collection_metadata(),
        )
        logger.info(
            "Initialized FileRAGManager '%s' (collection: %s) at %s",
            store_name, self._collection_name, self.db_path,
        )

    # -- helpers -------------------------------------------------------------

    def _collection_metadata(self) -> Dict[str, Any]:
        """Return the metadata attached to a newly created collection."""
        return {
            "description": f"File RAG store: {self.store_name}",
            "created_at": datetime.now(timezone.utc).isoformat(),
            "embedding_model": self.embedding_model,
        }

    @staticmethod
    def _compute_file_hash(content: str) -> str:
        """Return the SHA-256 hex digest of *content* encoded as UTF-8."""
        return hashlib.sha256(content.encode("utf-8")).hexdigest()

    @staticmethod
    def _get_file_id(file_path: str) -> str:
        """Return the MD5 hex digest of the normalized absolute *file_path*."""
        normalized = os.path.normpath(os.path.abspath(file_path))
        return hashlib.md5(normalized.encode("utf-8")).hexdigest()

    @staticmethod
    def _is_supported_file(file_path: str) -> bool:
        """Return True if *file_path* has a supported extension or name.

        The lower-cased suffix is checked against ``SUPPORTED_EXTENSIONS``.
        The lower-cased file name is checked against a few well-known
        extension-less names and against ``SUPPORTED_EXTENSIONS`` itself.
        The latter is needed because dotfiles such as ``.env`` or
        ``.gitignore`` have an empty ``Path.suffix``.
        """
        path = Path(file_path)
        ext = path.suffix.lower()
        filename = path.name.lower()
        known_files = {
            "dockerfile", "makefile", "readme", "license", "changelog",
        }
        return (
            ext in SUPPORTED_EXTENSIONS
            or filename in known_files
            or filename in SUPPORTED_EXTENSIONS
        )

    @staticmethod
    def _remove_quietly(*paths: str) -> None:
        """Best-effort deletion of *paths*; ``OSError`` is ignored."""
        for path in paths:
            try:
                os.remove(path)
            except OSError:
                pass

    def _read_oversized_pdf(self, file_path: str) -> Optional[str]:
        """Compress an oversized PDF into a temp file and extract its text.

        Returns ``None`` (after logging) when compression fails or the
        compressed copy still exceeds ``max_file_size``. The temporary file
        is always removed.
        """
        import tempfile

        logger.info("PDF exceeds size limit, attempting compression...")
        # Compress into the system temp dir rather than next to the source,
        # which may be read-only and would otherwise collect leftovers.
        fd, compressed_path = tempfile.mkstemp(suffix=".compressed.pdf")
        os.close(fd)
        try:
            _, _, compressed_size = compress_pdf(file_path, compressed_path)
            if compressed_size > self.max_file_size:
                logger.warning(
                    "PDF still too large after compression (%d > %d bytes): %s",
                    compressed_size, self.max_file_size, file_path,
                )
                return None
            return extract_pdf_text(compressed_path)
        except Exception as e:
            logger.warning("PDF compression failed for %s: %s", file_path, e)
            return None
        finally:
            self._remove_quietly(compressed_path)

    def _read_file_content(self, file_path: str) -> Optional[str]:
        """Read *file_path* as text (PDFs via text extraction).

        Oversized PDFs are compressed first and used only if the compressed
        copy fits within ``max_file_size``. Other oversized files are
        rejected. Text files are decoded as UTF-8, falling back to Latin-1.

        Returns:
            The text, or ``None`` if the file is too large, unreadable, or
            yields no text. Errors are logged, never raised.
        """
        try:
            file_size = os.path.getsize(file_path)
            is_pdf = file_path.lower().endswith(".pdf")
            if file_size > self.max_file_size:
                if is_pdf:
                    return self._read_oversized_pdf(file_path)
                logger.warning(
                    "File too large (>%d bytes): %s",
                    self.max_file_size, file_path,
                )
                return None
            if is_pdf:
                return extract_pdf_text(file_path)
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    return f.read()
            except UnicodeDecodeError:
                with open(file_path, "r", encoding="latin-1") as f:
                    return f.read()
        except Exception as e:
            logger.error("Failed to read file %s: %s", file_path, e)
            return None

    @staticmethod
    def _create_embedding_text(file_path: str, content: str) -> str:
        """Prefix *content* with a File/Path/Type header for embedding.

        ``Path`` shows at most the last three components of *file_path*.
        """
        path_parts = Path(file_path).parts
        filename = Path(file_path).name
        ext = Path(file_path).suffix
        return (
            f"File: {filename}\nPath: {'/'.join(path_parts[-3:])}\n"
            f"Type: {ext or 'unknown'}\n\n{content}"
        )

    def _prepare_reindex(
        self,
        where: Dict[str, Any],
        content_hash: str,
        force: bool = False,
    ) -> bool:
        """Look up existing entries matching *where* before (re)indexing.

        Returns ``True`` (the caller should skip) when entries exist, *force*
        is false, and the first entry's ``content_hash`` equals
        *content_hash*. Otherwise existing entries are deleted, so a
        re-index never leaves stale chunks behind, and ``False`` is
        returned. Lookup/delete errors are logged and treated as "nothing
        indexed".
        """
        try:
            existing = self.collection.get(where=where, include=["metadatas"])
            metadatas = existing.get("metadatas") if existing else None
            if not metadatas:
                return False
            if not force and metadatas[0].get("content_hash") == content_hash:
                return True
            old_ids = existing.get("ids", [])
            if old_ids:
                self.collection.delete(ids=old_ids)
        except Exception as e:
            logger.warning(
                "Could not check existing index entries for %s: %s", where, e,
            )
        return False

    def _upsert_batched(
        self,
        ids: List[str],
        documents: List[str],
        metadatas: List[Dict[str, Any]],
    ) -> None:
        """Upsert records in slices of at most ``CHROMA_MAX_BATCH``."""
        for start in range(0, len(ids), CHROMA_MAX_BATCH):
            end = start + CHROMA_MAX_BATCH
            self.collection.upsert(
                ids=ids[start:end],
                documents=documents[start:end],
                metadatas=metadatas[start:end],
            )

    def _write_entries(
        self,
        file_id: str,
        base_metadata: Dict[str, Any],
        content: str,
        doc_header: str,
        use_chunking: bool,
        chunk_size: int,
        chunk_overlap: int,
    ) -> Optional[int]:
        """Embed and upsert *content* under *file_id*.

        If *use_chunking* is set and ``len(content) > chunk_size``, one
        record per chunk is written with ID ``<file_id>_chunk_<i>`` and the
        chunk count is returned. Otherwise a single record with ID *file_id*
        is written and ``None`` is returned.

        Each stored document is *doc_header* followed by the chunk (or full)
        text. ChromaDB/embedding exceptions propagate to the caller.
        """
        if use_chunking and len(content) > chunk_size:
            chunks = [
                c for c in chunk_text(content, chunk_size, chunk_overlap)
                if c and c.strip()
            ]
            if not chunks:
                # Only reachable for whitespace-only content.
                chunks = [content[:chunk_size]]
            total = len(chunks)
            ids = [f"{file_id}_chunk_{i}" for i in range(total)]
            documents = [doc_header + chunk for chunk in chunks]
            metadatas = [
                {
                    **base_metadata,
                    "chunk_index": i,
                    "chunk_count": total,
                    "is_chunked": True,
                }
                for i in range(total)
            ]
            self._upsert_batched(ids, documents, metadatas)
            return total

        metadata = {
            **base_metadata,
            "is_chunked": False,
            "chunk_index": 0,
            "chunk_count": 1,
        }
        self._upsert_batched([file_id], [doc_header + content], [metadata])
        return None

    # -- index ---------------------------------------------------------------

    def index_file(
        self,
        file_path: str,
        tags: Optional[List[str]] = None,
        use_chunking: bool = True,
        chunk_size: int = DEFAULT_CHUNK_SIZE,
        chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
        force: bool = False,
    ) -> Dict[str, Any]:
        """Index a single file into the collection.

        When *force* is True the content-hash dedup check is skipped so the
        file is always re-embedded (but the store is **not** cleared).

        Returns:
            A result dict with ``success`` and either ``error`` or
            ``action`` (``"indexed"`` / ``"skipped"``) plus file details;
            ``chunk_count`` is present when the file was chunked.
        """
        file_path = os.path.abspath(file_path)
        if not os.path.exists(file_path):
            return {"success": False, "error": f"File not found: {file_path}"}
        if not os.path.isfile(file_path):
            return {"success": False, "error": f"Not a file: {file_path}"}
        if not self._is_supported_file(file_path):
            return {
                "success": False,
                "error": f"Unsupported file type: {file_path}",
            }

        content = self._read_file_content(file_path)
        if content is None:
            return {
                "success": False,
                "error": f"Failed to read file: {file_path}",
            }

        file_id = self._get_file_id(file_path)
        content_hash = self._compute_file_hash(content)
        if self._prepare_reindex(
            {"file_path": file_path}, content_hash, force=force,
        ):
            return {
                "success": True,
                "action": "skipped",
                "file_path": file_path,
                "reason": "content unchanged",
            }

        base_metadata: Dict[str, Any] = {
            "file_path": file_path,
            "filename": Path(file_path).name,
            "extension": Path(file_path).suffix,
            "content_hash": content_hash,
            "file_size": len(content),
            "indexed_at": datetime.now(timezone.utc).isoformat(),
            "tags": json.dumps(tags or []),
            "store_name": self.store_name,
            "source_type": "local",
        }
        try:
            chunk_count = self._write_entries(
                file_id,
                base_metadata,
                content,
                self._create_embedding_text(file_path, ""),
                use_chunking,
                chunk_size,
                chunk_overlap,
            )
        except Exception as e:
            logger.error(
                "Failed to index file %s: %s", file_path, e, exc_info=True,
            )
            return {"success": False, "error": str(e)}

        result: Dict[str, Any] = {
            "success": True,
            "action": "indexed",
            "file_id": file_id,
            "file_path": file_path,
            "file_size": len(content),
        }
        if chunk_count is None:
            logger.info("Indexed file: %s", file_path)
        else:
            logger.info(
                "Indexed file with %d chunks: %s", chunk_count, file_path,
            )
            result["chunk_count"] = chunk_count
        return result

    async def index_url(
        self,
        url: str,
        tags: Optional[List[str]] = None,
        use_chunking: bool = True,
        chunk_size: int = DEFAULT_CHUNK_SIZE,
        chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
    ) -> Dict[str, Any]:
        """Download *url*, save it under the store's files dir, and index it.

        Steps:

        - The downloaded copy is stored as ``<url md5[:12]>_<name><ext>``.
        - Oversized PDFs are compressed in place and rejected if still too
          large; other oversized content is rejected before saving.
        - Unchanged content (same hash for the same URL) is skipped.

        Returns:
            A result dict with ``success`` and either ``error`` or
            ``action`` (``"indexed"`` / ``"skipped"``) plus details;
            ``chunk_count`` is present when the content was chunked.
        """
        content_bytes, content_type, filename = await fetch_url_content(url)
        if content_bytes is None:
            return {"success": False, "error": f"Failed to fetch URL: {url}"}

        url_hash = hashlib.md5(url.encode()).hexdigest()[:12]
        ext = Path(filename).suffix if filename else ""
        if not ext:
            ext_map = {
                "application/pdf": ".pdf",
                "text/plain": ".txt",
                "application/json": ".json",
                "text/html": ".html",
            }
            ext = ext_map.get(content_type, ".txt")

        is_pdf = ext.lower() == ".pdf"
        file_size = len(content_bytes)
        if file_size > self.max_file_size and not is_pdf:
            return {
                "success": False,
                "error": (
                    f"File too large: {file_size} bytes "
                    f"(limit: {self.max_file_size})"
                ),
            }

        # Keep only the final name component and neutralize characters that
        # are unsafe in file names.
        base_filename = Path(filename).name if filename else "downloaded"
        base_filename = re.sub(r'[/\\:*?"<>|]', "_", base_filename)
        stored_filename = f"{url_hash}_{base_filename}"
        if not stored_filename.endswith(ext):
            stored_filename += ext
        stored_path = os.path.join(self.files_path, stored_filename)

        try:
            with open(stored_path, "wb") as f:
                f.write(content_bytes)
        except Exception as e:
            return {"success": False, "error": f"Failed to save file: {e}"}

        if is_pdf:
            if file_size > self.max_file_size:
                compressed_path = stored_path + ".compressed.pdf"
                try:
                    _, _, compressed_size = compress_pdf(
                        stored_path, compressed_path,
                    )
                    if compressed_size > self.max_file_size:
                        self._remove_quietly(compressed_path, stored_path)
                        return {
                            "success": False,
                            "error": (
                                f"PDF still too large after compression: "
                                f"{compressed_size}"
                            ),
                        }
                    os.replace(compressed_path, stored_path)
                except Exception as e:
                    self._remove_quietly(compressed_path, stored_path)
                    return {
                        "success": False,
                        "error": f"PDF compression failed: {e}",
                    }
            content = extract_pdf_text(stored_path)
            if content is None:
                return {
                    "success": False,
                    "error": "Failed to extract text from PDF",
                }
        else:
            try:
                content = content_bytes.decode("utf-8")
            except UnicodeDecodeError:
                # Latin-1 maps every byte value, so this cannot fail.
                content = content_bytes.decode("latin-1")

        content_hash = self._compute_file_hash(content)
        file_id = hashlib.md5(url.encode()).hexdigest()
        if self._prepare_reindex({"source_url": url}, content_hash):
            return {
                "success": True,
                "action": "skipped",
                "url": url,
                "reason": "content unchanged",
            }

        base_metadata: Dict[str, Any] = {
            "file_path": stored_path,
            "filename": stored_filename,
            "extension": ext,
            "content_hash": content_hash,
            "file_size": len(content),
            "indexed_at": datetime.now(timezone.utc).isoformat(),
            "tags": json.dumps(tags or []),
            "store_name": self.store_name,
            "source_type": "url",
            "source_url": url,
        }
        doc_header = f"File: {stored_filename}\nURL: {url}\nType: {ext}\n\n"
        try:
            chunk_count = self._write_entries(
                file_id,
                base_metadata,
                content,
                doc_header,
                use_chunking,
                chunk_size,
                chunk_overlap,
            )
        except Exception as e:
            logger.error("Failed to index URL %s: %s", url, e, exc_info=True)
            return {"success": False, "error": str(e)}

        result: Dict[str, Any] = {
            "success": True,
            "action": "indexed",
            "url": url,
            "filename": stored_filename,
            "file_size": len(content),
        }
        if chunk_count is None:
            logger.info("Indexed URL: %s", url)
        else:
            logger.info("Indexed URL with %d chunks: %s", chunk_count, url)
            result["chunk_count"] = chunk_count
        result["stored_path"] = stored_path
        return result

    def index_directory(
        self,
        directory_path: str,
        recursive: bool = True,
        tags: Optional[List[str]] = None,
        exclude_patterns: Optional[List[str]] = None,
        max_workers: int = 6,
        force: bool = False,
        allowed_extensions: Optional[Collection[str]] = None,
    ) -> Dict[str, Any]:
        """Index all supported files in *directory_path*.

        Concurrency: when *max_workers* > 1, files are indexed concurrently
        using a thread pool. Each file's embedding batches are already
        parallelised inside the embedding function, so even
        ``max_workers=1`` benefits from concurrent API calls.

        Re-embedding: *force* bypasses the per-file content-hash dedup check
        without clearing the store, so already-indexed files get
        re-embedded.

        Exclusion: patterns starting with ``*`` match the end of the
        lower-cased path. Other patterns are case-sensitive substrings.
        File paths are matched *relative to directory_path*, so a pattern
        that happens to occur in the root path itself does not exclude
        everything. Directory names are matched on their own name.

        Extension filter: when *allowed_extensions* is set, only files whose
        suffix (normalized to a leading dot, lowercase) appears in it are
        queued. ``None`` means no extension filter (all supported types
        under *SUPPORTED_EXTENSIONS*).

        Returns:
            ``{"success", "indexed", "skipped", "failed", "files"}``, where
            ``files`` holds the per-file result dicts.
        """
        import threading
        from concurrent.futures import ThreadPoolExecutor, as_completed

        directory_path = os.path.abspath(directory_path)
        if not os.path.isdir(directory_path):
            return {
                "success": False,
                "error": f"Not a directory: {directory_path}",
            }

        patterns: List[str] = exclude_patterns or [
            "__pycache__", ".git", "node_modules", ".venv", "venv",
            ".env", "*.pyc", "*.pyo", ".DS_Store",
        ]
        results: Dict[str, Any] = {
            "success": True,
            "indexed": 0,
            "skipped": 0,
            "failed": 0,
            "files": [],
        }

        def should_exclude(path: str) -> bool:
            path_lower = path.lower()
            for pat in patterns:
                if pat.startswith("*"):
                    if path_lower.endswith(pat[1:]):
                        return True
                elif pat in path:
                    return True
            return False

        file_paths: List[str] = []
        walker = (
            os.walk(directory_path)
            if recursive
            else [(directory_path, [], os.listdir(directory_path))]
        )
        for root, dirs, files in walker:
            if recursive:
                # Prune excluded directories so os.walk does not descend.
                dirs[:] = [d for d in dirs if not should_exclude(d)]
            for fname in files:
                file_path = os.path.join(root, fname)
                if should_exclude(os.path.relpath(file_path, directory_path)):
                    continue
                if not recursive and not os.path.isfile(file_path):
                    continue
                file_paths.append(file_path)

        if allowed_extensions is not None:
            ext_set: set[str] = set()
            for raw in allowed_extensions:
                e = raw.strip().lower()
                if not e:
                    continue
                if not e.startswith("."):
                    e = "." + e
                ext_set.add(e)
            if ext_set:
                file_paths = [
                    fp for fp in file_paths
                    if Path(fp).suffix.lower() in ext_set
                ]

        total_files = len(file_paths)
        done_count = 0
        lock = threading.Lock()

        def _record(result: Dict[str, Any]) -> None:
            nonlocal done_count
            with lock:
                results["files"].append(result)
                if result.get("success"):
                    key = (
                        "indexed"
                        if result.get("action") == "indexed"
                        else "skipped"
                    )
                    results[key] += 1
                else:
                    results["failed"] += 1
                done_count += 1
                if done_count % 25 == 0 or done_count == total_files:
                    logger.info(
                        "Progress: %d / %d files (%d indexed, %d skipped, %d failed)",
                        done_count, total_files,
                        results["indexed"], results["skipped"], results["failed"],
                    )

        def _index_one(fp: str) -> Dict[str, Any]:
            # An unexpected exception is recorded as a failure for this file
            # instead of aborting the whole directory run.
            try:
                return self.index_file(fp, tags=tags, force=force)
            except Exception as e:
                logger.error(
                    "Unexpected error indexing %s: %s", fp, e, exc_info=True,
                )
                return {"success": False, "error": str(e), "file_path": fp}

        if max_workers > 1 and len(file_paths) > 1:
            workers = min(max_workers, len(file_paths))
            with ThreadPoolExecutor(max_workers=workers) as pool:
                futures = [pool.submit(_index_one, fp) for fp in file_paths]
                for future in as_completed(futures):
                    _record(future.result())
        else:
            for fp in file_paths:
                _record(_index_one(fp))

        logger.info(
            "Directory indexing complete: %d indexed, %d skipped, %d failed",
            results["indexed"], results["skipped"], results["failed"],
        )
        return results

    # -- search --------------------------------------------------------------

    @staticmethod
    def _merge_chunks(
        chunks: List[tuple],
        max_content_size: int,
    ) -> str:
        """Merge matched chunks into a single string.

        *chunks* is a list of ``(chunk_index, chunk_text)`` tuples sorted by
        ``chunk_index``. Adjacent chunks (index differs by 1) are joined
        directly; non-adjacent chunks are separated by a ``[...]`` marker so
        the reader knows content was skipped.

        Budget rules:

        - Each part is budgeted at ``len(text) + 10`` characters.
        - The first part that would exceed *max_content_size* is truncated,
          and only kept if more than 200 characters of budget remain.
        """
        if not chunks:
            return ""

        merged_parts: List[str] = []
        current_size = 0
        prev_idx: int | None = None

        for idx, text in chunks:
            text = text.strip()
            if not text:
                continue
            is_gap = prev_idx is not None and idx != prev_idx + 1
            prefix = "\n\n[...]\n\n" if is_gap else ""

            if current_size + len(text) + 10 > max_content_size:
                remaining = max_content_size - current_size - 10
                if remaining > 200:
                    # Keep the gap marker on a truncated piece as well.
                    merged_parts.append(prefix + text[:remaining].rstrip())
                break

            merged_parts.append(prefix + text)
            current_size += len(text) + 10
            prev_idx = idx

        return "\n".join(merged_parts)

    @staticmethod
    def _build_tag_filter(
        tags: Optional[List[str]],
    ) -> Optional[Dict[str, Any]]:
        """Build the ChromaDB ``where`` filter for *tags* (``None`` if empty).

        A single tag yields its clause directly, because ChromaDB rejects an
        ``$or`` with fewer than two operands.
        """
        if not tags:
            return None
        # NOTE(review): tags are stored as a JSON string; confirm the
        # installed ChromaDB version accepts ``$contains`` in a *metadata*
        # ``where`` filter (some versions may only allow it in
        # ``where_document``).
        clauses = [{"tags": {"$contains": tag}} for tag in tags]
        return clauses[0] if len(clauses) == 1 else {"$or": clauses}

    def search(
        self,
        query: str,
        n_results: int = 5,
        tags: Optional[List[str]] = None,
        return_content: bool = True,
        query_embedding: list[float] | None = None,
        max_content_size: int = 8000,
    ) -> List[Dict[str, Any]]:
        """Semantic search returning relevant chunks per file.

        Instead of returning entire file contents, this collects the matching
        chunk texts that ChromaDB found and merges them (respecting
        ``max_content_size``). Small files whose full text fits within one
        chunk are returned in full automatically.

        Args:
            query: Natural-language search query.
            n_results: Maximum number of files to return.
            tags: Optional tag filter (a file matches if it has any of them).
            return_content: Include chunk text in results.
            query_embedding: Pre-computed query embedding (skips ChromaDB's
                internal embedding call).
            max_content_size: Maximum characters of merged chunk text to
                return per file (default 8000).

        Returns:
            Up to *n_results* dicts ordered best match first. Each has file
            metadata and ``similarity_score`` (``1 - best chunk distance``),
            plus ``content`` when *return_content* is set. Errors are logged
            and an empty list is returned.
        """
        try:
            query_kwargs: Dict[str, Any] = {
                # Over-fetch chunks: several hits may belong to one file.
                "n_results": n_results * 5,
                "where": self._build_tag_filter(tags),
                "include": ["metadatas", "distances", "documents"],
            }
            if query_embedding is not None:
                query_kwargs["query_embeddings"] = [query_embedding]
            else:
                query_kwargs["query_texts"] = [query]

            results = self.collection.query(**query_kwargs)
            if not results or not results.get("metadatas"):
                return []

            documents = (results.get("documents") or [[]])[0]
            distances = (
                results["distances"][0] if results.get("distances") else None
            )

            # Group chunk hits by file, tracking each file's best distance.
            grouped: Dict[str, Dict[str, Any]] = {}
            for i, metadata in enumerate(results["metadatas"][0]):
                file_path = metadata.get("file_path", "")
                distance = distances[i] if distances is not None else 1.0
                chunk_index = int(metadata.get("chunk_index", 0))
                doc_text = documents[i] if i < len(documents) else ""

                group = grouped.setdefault(file_path, {
                    "metadata": metadata,
                    "best_distance": distance,
                    "chunks": [],
                })
                if distance < group["best_distance"]:
                    group["best_distance"] = distance
                    group["metadata"] = metadata
                group["chunks"].append((distance, chunk_index, doc_text))

            ranked_files = sorted(
                grouped.items(), key=lambda kv: kv[1]["best_distance"],
            )
            output: List[Dict[str, Any]] = []
            for file_path, group in ranked_files[:n_results]:
                metadata = group["metadata"]
                best_distance = group["best_distance"]
                item: Dict[str, Any] = {
                    "file_path": file_path,
                    "filename": metadata.get("filename", ""),
                    "extension": metadata.get("extension", ""),
                    "file_size": metadata.get("file_size", 0),
                    "indexed_at": metadata.get("indexed_at", ""),
                    "tags": json.loads(metadata.get("tags", "[]")),
                    "similarity_score": (
                        1.0 - best_distance
                        if best_distance is not None else None
                    ),
                    "source_type": metadata.get("source_type", "local"),
                    "source_url": metadata.get("source_url", None),
                }
                if return_content:
                    # Keep the best-scoring text per chunk index, then order
                    # the chunks by their position in the file.
                    best_by_index: Dict[int, str] = {}
                    for _, ci, ct in sorted(group["chunks"], key=lambda c: c[0]):
                        best_by_index.setdefault(ci, ct)
                    content = self._merge_chunks(
                        sorted(best_by_index.items()), max_content_size,
                    )
                    item["content"] = (
                        content if content else "[No chunk text available]"
                    )
                output.append(item)

            logger.info(
                "Search query '%s...' returned %d unique files",
                query[:50], len(output),
            )
            return output

        except Exception as e:
            logger.error("Search failed: %s", e, exc_info=True)
            return []

    # -- remove / list / stats -----------------------------------------------

    def remove_file(self, file_path: str) -> Dict[str, Any]:
        """Remove every index entry (all chunks) for a local *file_path*.

        Returns:
            ``{"success": True, "file_path", "entries_removed"}``, or
            ``{"success": False, "error"}`` if nothing was indexed or the
            lookup/delete failed.
        """
        file_path = os.path.abspath(file_path)
        try:
            existing = self.collection.get(
                where={"file_path": file_path}, include=["metadatas"],
            )
            if existing and existing.get("ids"):
                self.collection.delete(ids=existing["ids"])
                return {
                    "success": True,
                    "file_path": file_path,
                    "entries_removed": len(existing["ids"]),
                }
            return {
                "success": False,
                "error": f"File not found in index: {file_path}",
            }
        except Exception as e:
            return {"success": False, "error": str(e)}

    def remove_url(self, url: str) -> Dict[str, Any]:
        """Remove every index entry whose ``source_url`` equals *url*.

        The downloaded copy under ``files_path`` is not deleted.

        Returns:
            ``{"success": True, "url", "entries_removed"}``, or
            ``{"success": False, "error"}``.
        """
        try:
            existing = self.collection.get(
                where={"source_url": url}, include=["metadatas"],
            )
            if existing and existing.get("ids"):
                self.collection.delete(ids=existing["ids"])
                return {
                    "success": True,
                    "url": url,
                    "entries_removed": len(existing["ids"]),
                }
            return {
                "success": False,
                "error": f"URL not found in index: {url}",
            }
        except Exception as e:
            return {"success": False, "error": str(e)}

    def list_indexed_files(self, limit: int = 100) -> List[Dict[str, Any]]:
        """List up to *limit* distinct indexed files.

        A chunked file has one ChromaDB entry per chunk. Entries are paged
        through and collapsed by ``file_path`` so each file appears once
        (previously one row per chunk was returned).

        Returns:
            File summaries, or ``[]`` on error (logged).
        """
        if limit <= 0:
            return []
        try:
            files: Dict[str, Dict[str, Any]] = {}
            page_size = max(limit, 500)
            offset = 0
            while len(files) < limit:
                page = self.collection.get(
                    limit=page_size, offset=offset, include=["metadatas"],
                )
                metadatas = page.get("metadatas") or []
                for m in metadatas:
                    path = m.get("file_path", "")
                    if path in files:
                        continue
                    files[path] = {
                        "file_path": path,
                        "filename": m.get("filename", ""),
                        "extension": m.get("extension", ""),
                        "file_size": m.get("file_size", 0),
                        "indexed_at": m.get("indexed_at", ""),
                        "tags": json.loads(m.get("tags", "[]")),
                    }
                    if len(files) >= limit:
                        break
                if len(metadatas) < page_size:
                    break  # Last page reached.
                offset += page_size
            return list(files.values())
        except Exception as e:
            logger.error("Failed to list files: %s", e)
            return []

    def list_store_files(self) -> List[Dict[str, Any]]:
        """List the files saved in the store's ``files`` directory.

        Returns:
            Dicts with ``filename``, ``size`` (bytes), ``modified`` (UTC
            ISO-8601) and ``path``, sorted by filename. Entries that cannot
            be stat'ed are skipped.
        """
        files: List[Dict[str, Any]] = []
        if not os.path.exists(self.files_path):
            return files
        for fname in os.listdir(self.files_path):
            fp = os.path.join(self.files_path, fname)
            if not os.path.isfile(fp):
                continue
            try:
                st = os.stat(fp)
                modified = datetime.fromtimestamp(
                    st.st_mtime, tz=timezone.utc,
                ).isoformat()
            except (OSError, OverflowError, ValueError):
                # Vanished/unreadable file or an unrepresentable mtime.
                continue
            files.append({
                "filename": fname,
                "size": st.st_size,
                "modified": modified,
                "path": fp,
            })
        return sorted(files, key=lambda x: x.get("filename", ""))

    def read_store_file(self, filename: str) -> Dict[str, Any]:
        """Read a file from the store's ``files`` directory as text.

        *filename* must be a bare name; anything containing a path separator
        or ``..`` is rejected. PDFs go through text extraction. Other files
        are decoded as UTF-8, falling back to Latin-1.

        Returns:
            ``{"success": True, "filename", "content", "size"}`` or
            ``{"success": False, "error"}``.
        """
        if "/" in filename or "\\" in filename or ".." in filename:
            return {"success": False, "error": "Invalid filename."}

        fp = os.path.join(self.files_path, filename)
        if not os.path.isfile(fp):
            return {"success": False, "error": f"File not found: {filename}"}

        try:
            if filename.lower().endswith(".pdf"):
                content = extract_pdf_text(fp)
                if content is None:
                    return {
                        "success": False,
                        "error": "Failed to extract text from PDF",
                    }
            else:
                try:
                    with open(fp, "r", encoding="utf-8") as f:
                        content = f.read()
                except UnicodeDecodeError:
                    with open(fp, "r", encoding="latin-1") as f:
                        content = f.read()
            return {
                "success": True,
                "filename": filename,
                "content": content,
                "size": len(content),
            }
        except Exception as e:
            return {"success": False, "error": str(e)}

    def close(self) -> None:
        """Release the ChromaDB client and its underlying SQLite resources."""
        try:
            if hasattr(self.client, "close"):
                self.client.close()
        except Exception:
            logger.debug(
                "Error closing ChromaDB client for '%s'",
                self.store_name,
                exc_info=True,
            )

    def get_stats(self) -> Dict[str, Any]:
        """Return basic store statistics, or ``{"error": ...}`` on failure.

        ``file_count`` is ``collection.count()``, i.e. the number of index
        entries; a chunked file contributes one entry per chunk.
        """
        try:
            return {
                "store_name": self.store_name,
                "store_path": self.db_path,
                "file_count": self.collection.count(),
                "embedding_model": self.embedding_model,
            }
        except Exception as e:
            return {"error": str(e)}

    def clear(self) -> Dict[str, Any]:
        """Delete and recreate the collection, dropping all index entries.

        Files saved under ``files_path`` are not touched.

        Returns:
            ``{"success": True, "message"}`` or ``{"success": False, "error"}``.
        """
        try:
            self.client.delete_collection(name=self._collection_name)
            self.collection = self.client.create_collection(
                name=self._collection_name,
                embedding_function=self.embedding_fn,
                metadata=self._collection_metadata(),
            )
            logger.info("Cleared RAG store: %s", self.store_name)
            return {
                "success": True,
                "message": f"Store '{self.store_name}' cleared",
            }
        except Exception as e:
            return {"success": False, "error": str(e)}
# --------------------------------------------------------------------------- # Global store registry (LRU-bounded) # --------------------------------------------------------------------------- _STORE_REGISTRY_MAX_SIZE = int(os.environ.get("RAG_STORE_CACHE_SIZE", "5")) _store_registry: OrderedDict[tuple, FileRAGManager] = OrderedDict() # Sphinx-generated docs store: Gemini retrieval task types for index vs query. STARGAZER_DOCS_STORE_NAME = "stargazer_docs" STARGAZER_DOCS_DOCUMENT_TASK = "RETRIEVAL_DOCUMENT" STARGAZER_DOCS_QUERY_TASK = "RETRIEVAL_QUERY"
def get_rag_store(
    store_name: str = "default",
    api_key: Optional[str] = None,
    max_file_size: Optional[int] = None,
    gemini_only: bool = True,
    document_task_type: Optional[str] = None,
    query_task_type: Optional[str] = None,
) -> FileRAGManager:
    """Get or create a RAG store by name (LRU-cached).

    At most ``_STORE_REGISTRY_MAX_SIZE`` stores are kept open, and never
    fewer than one. When creating a store pushes the cache over that limit,
    the least recently used entries are evicted and closed. The store just
    created is never among them.

    Cache entries are keyed by ``store_name`` plus the optional embedding
    task types, so different embedding configurations do not share one
    client. ``api_key``, ``max_file_size`` and ``gemini_only`` are NOT part
    of the key. They only take effect when the store is first created; a
    later call with different values receives the cached instance.

    A caller that still holds a manager which later gets evicted is left
    with one whose ``close()`` has already been called.

    Args:
        store_name: Name of the store to open.
        api_key: Forwarded to ``FileRAGManager`` on creation.
        max_file_size: Forwarded on creation when not ``None``.
        gemini_only: Forwarded as ``gemini_only=True`` on creation when truthy.
        document_task_type: Embedding task type for indexed content.
        query_task_type: Embedding task type for search queries.

    Returns:
        FileRAGManager: The cached or newly created manager.
    """
    cache_key = (
        store_name,
        document_task_type,
        query_task_type,
    )
    if cache_key in _store_registry:
        _store_registry.move_to_end(cache_key)
        return _store_registry[cache_key]

    kwargs: Dict[str, Any] = {"store_name": store_name, "api_key": api_key}
    if max_file_size is not None:
        kwargs["max_file_size"] = max_file_size
    if gemini_only:
        kwargs["gemini_only"] = True
    if document_task_type is not None:
        kwargs["document_task_type"] = document_task_type
    if query_task_type is not None:
        kwargs["query_task_type"] = query_task_type

    store = FileRAGManager(**kwargs)
    _store_registry[cache_key] = store

    # Clamp to >= 1: the new store sits at the MRU end, so with a limit of at
    # least one it can never be popped (and closed) before being returned.
    limit = max(1, _STORE_REGISTRY_MAX_SIZE)
    while len(_store_registry) > limit:
        evicted_key, evicted_store = _store_registry.popitem(last=False)
        logger.info("Evicting RAG store %s from cache (LRU)", evicted_key)
        evicted_store.close()
    return store
def get_stargazer_docs_store() -> FileRAGManager:
    """Open the shared store that holds Sphinx / tool documentation.

    The store is fetched through ``get_rag_store`` under
    ``STARGAZER_DOCS_STORE_NAME``. Indexed chunks are embedded with the
    ``RETRIEVAL_DOCUMENT`` task type and search queries with
    ``RETRIEVAL_QUERY`` (Gemini embedding task types).
    """
    docs_store_options = {
        "document_task_type": STARGAZER_DOCS_DOCUMENT_TASK,
        "query_task_type": STARGAZER_DOCS_QUERY_TASK,
    }
    return get_rag_store(STARGAZER_DOCS_STORE_NAME, **docs_store_options)
def list_rag_stores(store_root: Optional[str] = None) -> List[str]:
    """List the names of all RAG store directories.

    Args:
        store_root: Directory to scan. Defaults to ``DEFAULT_STORE_PATH``.

    Returns:
        List[str]: Names of the sub-directories of the root, sorted
        alphabetically to match ``list_rag_stores_with_stats``. Empty when
        the root is missing or is not a directory.
    """
    root = DEFAULT_STORE_PATH if store_root is None else store_root
    # isdir (rather than exists) also covers a root that is a regular file,
    # which would otherwise make os.listdir raise NotADirectoryError.
    if not os.path.isdir(root):
        return []
    return sorted(
        entry
        for entry in os.listdir(root)
        if os.path.isdir(os.path.join(root, entry))
    )
[docs] def list_rag_stores_with_stats() -> List[Dict[str, Any]]: """List stores with file counts using only filesystem ops (no ChromaDB). Counts physical files in each store's ``files/`` subdirectory as a lightweight proxy for the indexed entry count. This never opens a ChromaDB client and therefore uses zero additional RAM. """ if not os.path.exists(DEFAULT_STORE_PATH): return [] stores: List[Dict[str, Any]] = [] for name in sorted(os.listdir(DEFAULT_STORE_PATH)): store_dir = os.path.join(DEFAULT_STORE_PATH, name) if not os.path.isdir(store_dir): continue files_dir = os.path.join(store_dir, STORE_FILES_SUBDIR) try: file_count = len(os.listdir(files_dir)) if os.path.isdir(files_dir) else 0 except OSError: file_count = 0 stores.append({"name": name, "file_count": file_count}) return stores
def delete_rag_store(store_name: str) -> Dict[str, Any]:
    """Delete a RAG store completely.

    The store directory is located via ``_sanitize_collection_name``. Every
    cached manager whose store name sanitizes to the same directory is
    evicted from the registry and closed. The directory is then removed
    from disk.

    Args:
        store_name: Store name. Aliases that sanitize to the same directory
            name (e.g. ``"my store"`` and ``"my_store"``) refer to the same
            store.

    Returns:
        Dict[str, Any]: ``{"success": True, "message", "path_deleted"}`` on
        success. Otherwise ``{"success": False, "error": <message>}``,
        including when the directory does not exist.
    """
    import shutil

    sanitized_name = _sanitize_collection_name(store_name)
    store_path = os.path.join(DEFAULT_STORE_PATH, sanitized_name)
    if not os.path.exists(store_path):
        return {
            "success": False,
            "error": f"Store '{store_name}' does not exist",
        }
    try:
        # Registry keys hold the raw name passed to get_rag_store. Compare
        # sanitized names so that a store opened under an alias of
        # store_name is also closed before its directory is removed;
        # otherwise an open client would be left pointing at deleted files.
        keys_to_drop = [
            key
            for key in list(_store_registry.keys())
            if _sanitize_collection_name(key[0]) == sanitized_name
        ]
        for key in keys_to_drop:
            evicted = _store_registry.pop(key, None)
            if evicted is not None:
                evicted.close()
        shutil.rmtree(store_path)
        logger.info("Deleted RAG store: %s (path: %s)", store_name, store_path)
        return {
            "success": True,
            "message": f"Store '{store_name}' deleted successfully",
            "path_deleted": store_path,
        }
    except Exception as e:
        return {"success": False, "error": str(e)}