"""File-based RAG Manager.
Manages file indexing and retrieval using ChromaDB and OpenRouter embeddings.
Returns entire files, not chunks, for complete context.
Features:
- URL fetching support
- PDF parsing via PyMuPDF
- Chunked embeddings for better semantic matching
- Full file retrieval on search
"""
import hashlib
import json
import logging
import os
import re
from collections import OrderedDict
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Collection, Dict, List, Optional, Tuple
from urllib.parse import unquote, urlparse
import chromadb
import httpx
from chromadb.config import Settings
from .openrouter_embeddings import SyncOpenRouterEmbeddings
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Directory two levels above this file; the default store location is
# derived from it.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
def _sanitize_collection_name(name: str) -> str:
    """Turn an arbitrary store name into a ChromaDB-safe identifier.

    ChromaDB requires 3-512 characters drawn from ``[a-zA-Z0-9._-]`` and a
    name that starts and ends with ``[a-zA-Z0-9]``.

    Args:
        name: Raw, user-supplied store name.

    Returns:
        The cleaned name.  Runs of disallowed characters collapse to a single
        ``_``; names shorter than three characters get a ``store_`` prefix
        (``store_default`` when nothing is left); names longer than 500
        characters are truncated.
    """
    # Replace every disallowed character, then squeeze repeated underscores.
    cleaned = re.sub(r"_+", "_", re.sub(r"[^a-zA-Z0-9._-]", "_", name))
    # Leading/trailing punctuation is not allowed by ChromaDB.
    cleaned = cleaned.strip("._-")

    if not cleaned:
        cleaned = "store_default"
    elif len(cleaned) < 3:
        cleaned = f"store_{cleaned}"

    if len(cleaned) <= 500:
        return cleaned
    # Truncation may expose punctuation at the end again, so strip once more.
    return cleaned[:500].rstrip("._-")
# Root directory under which each store gets its own sub-directory.
DEFAULT_STORE_PATH = os.path.join(_PROJECT_ROOT, "rag_stores")
# Name of the per-store sub-directory that holds stored source files.
STORE_FILES_SUBDIR = "files"

# File suffixes accepted for indexing.
SUPPORTED_EXTENSIONS = {
    # documents and plain text
    ".txt", ".md", ".rst", ".tex", ".bib", ".pdf",
    # scripting / web languages
    ".py", ".js", ".ts", ".jsx", ".tsx",
    ".lua", ".rb", ".php", ".pl", ".pm",
    # compiled languages
    ".c", ".cpp", ".h", ".hpp", ".rs", ".go", ".java", ".kt",
    # R
    ".r", ".R", ".rmd",
    # shells
    ".sh", ".bash", ".zsh", ".fish",
    # markup and styles
    ".html", ".css", ".scss", ".less", ".xml",
    # data and configuration
    ".json", ".yaml", ".yml", ".toml", ".ini", ".cfg", ".csv",
    # query languages
    ".sql", ".graphql",
    # build and environment files
    ".dockerfile", ".makefile",
    ".env", ".gitignore", ".dockerignore",
}

MAX_FILE_SIZE = 15 * 1024 ** 2  # 15 MB
# Chunk length and overlap, measured in characters.
DEFAULT_CHUNK_SIZE = 1500
DEFAULT_CHUNK_OVERLAP = 200
# Largest number of records sent to ChromaDB in one upsert call.
CHROMA_MAX_BATCH = 5000
# ---------------------------------------------------------------------------
# PDF helpers
# ---------------------------------------------------------------------------
[docs]
def compress_pdf(
    file_path: str,
    output_path: Optional[str] = None,
    remove_images: bool = True,
) -> Tuple[str, int, int]:
    """Compress a PDF using PyMuPDF's ``ez_save()``.

    Args:
        file_path: Path of the PDF to compress.
        output_path: Destination path.  When omitted (or when it points at
            *file_path*) the source file is replaced by the compressed copy.
        remove_images: Delete every embedded image before saving.

    Returns:
        ``(output_path, original_size, compressed_size)``; sizes in bytes.

    Raises:
        Exception: Errors from PyMuPDF or the filesystem are propagated
            after the document handle has been closed.
    """
    import fitz  # PyMuPDF

    original_size = os.path.getsize(file_path)
    out_path = output_path or file_path

    # PyMuPDF rejects a non-incremental save onto the document's own file,
    # so an in-place request is written to a sibling temp file first and
    # swapped in once the document is closed.
    in_place = os.path.abspath(out_path) == os.path.abspath(file_path)
    save_path = f"{out_path}.compress.tmp" if in_place else out_path

    doc = fitz.open(file_path)
    try:
        if remove_images:
            images_removed = 0
            for page in doc:
                for img in page.get_images(full=True):
                    try:
                        # img[0] is the image's xref number.
                        page.delete_image(img[0])
                        images_removed += 1
                    except Exception:
                        # Best effort: one undeletable image must not abort
                        # the whole compression.
                        logger.debug(
                            "Could not delete image xref %s", img[0],
                            exc_info=True,
                        )
            if images_removed:
                logger.info("Removed %d images from PDF", images_removed)
        doc.ez_save(save_path)
    except Exception:
        if in_place:
            # Do not leave a partial temp file next to the source.
            try:
                os.remove(save_path)
            except OSError:
                pass
        raise
    finally:
        doc.close()

    if in_place:
        os.replace(save_path, out_path)

    compressed_size = os.path.getsize(out_path)
    reduction = (
        100 - (compressed_size * 100 // original_size)
        if original_size else 0
    )
    logger.info(
        "Compressed PDF: %d -> %d bytes (%d%% reduction)",
        original_size, compressed_size, reduction,
    )
    return out_path, original_size, compressed_size
# ---------------------------------------------------------------------------
# Text chunking
# ---------------------------------------------------------------------------
[docs]
def chunk_text(
    text: str,
    chunk_size: int = DEFAULT_CHUNK_SIZE,
    overlap: int = DEFAULT_CHUNK_OVERLAP,
) -> List[str]:
    """Split *text* into overlapping chunks on paragraph/sentence boundaries.

    Paragraphs (separated by blank lines) are packed into chunks of at most
    *chunk_size* characters.  A paragraph that is itself too long is split
    into sentences, and a sentence that is still too long is cut into
    fixed-size windows.  When a chunk is emitted, its last *overlap*
    characters are carried into the next chunk if they fit next to the
    following piece of text.

    Args:
        text: Text to split.
        chunk_size: Maximum chunk length in characters; must be positive.
        overlap: Characters repeated between consecutive chunks.  Values
            outside ``0 .. chunk_size - 1`` are clamped into that range.

    Returns:
        The list of chunks.  Text no longer than *chunk_size* is returned
        unchanged as a single chunk.

    Raises:
        ValueError: If *chunk_size* is not positive.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    # A window step of ``chunk_size - overlap`` must stay >= 1.
    overlap = max(0, min(overlap, chunk_size - 1))

    if len(text) <= chunk_size:
        return [text]

    chunks: List[str] = []
    current = ""      # text accumulated for the chunk being built
    has_new = False   # True once ``current`` holds text not yet emitted

    def tail_of(value: str) -> str:
        """Return the overlap suffix of *value* ('' when overlap is 0)."""
        # ``value[-0:]`` would be the whole string, hence the explicit guard.
        return value[-overlap:] if overlap and len(value) > overlap else ""

    def flush() -> None:
        """Emit ``current`` if it has new text and seed the next chunk."""
        nonlocal current, has_new
        if has_new:
            emitted = current.strip()
            if emitted:
                chunks.append(emitted)
            current = tail_of(current)
        else:
            # Only carried-over overlap is left: never emit it on its own.
            current = ""
        has_new = False

    def add(piece: str, sep: str) -> None:
        """Append *piece* (at most chunk_size long) to the current chunk."""
        nonlocal current, has_new
        if not piece.strip():
            return
        if current and len(current) + len(sep) + len(piece) > chunk_size:
            flush()
            if current and len(current) + len(sep) + len(piece) > chunk_size:
                # The overlap tail does not fit beside this piece.
                current = ""
        current = f"{current}{sep}{piece}" if current else piece
        has_new = True

    for para in re.split(r"\n\n+", text):
        if len(para) <= chunk_size:
            add(para, "\n\n")
            continue

        # Oversized paragraph: start a new chunk and go sentence by sentence.
        flush()
        for sentence in re.split(r"(?<=[.!?])\s+", para):
            if len(sentence) <= chunk_size:
                add(sentence, " ")
                continue

            # Oversized sentence: cut overlapping fixed-size windows.
            flush()
            step = chunk_size - overlap
            for start in range(0, len(sentence), step):
                window = sentence[start:start + chunk_size].strip()
                if window:
                    chunks.append(window)
                if start + chunk_size >= len(sentence):
                    # This window reached the end; further windows would
                    # only repeat text already emitted.
                    break
            current, has_new = tail_of(sentence), False

    flush()
    return chunks
# ---------------------------------------------------------------------------
# URL fetching
# ---------------------------------------------------------------------------
[docs]
def _filename_from_content_disposition(header_value: str) -> Optional[str]:
    """Return the file name declared in a Content-Disposition header, if any."""
    if not header_value:
        return None
    # The stdlib header parser copes with quoted names (including spaces)
    # and RFC 2231 ``filename*=charset''value`` parameters.
    from email.message import Message

    message = Message()
    message["content-disposition"] = header_value
    try:
        raw_name = message.get_filename()
    except (LookupError, TypeError, ValueError):
        return None
    if not raw_name:
        return None
    # The header is untrusted: keep only the final path component.
    name = unquote(str(raw_name)).replace("\\", "/").rsplit("/", 1)[-1]
    return name.strip() or None


async def fetch_url_content(
    url: str,
    timeout: float = 30.0,
) -> Tuple[Optional[bytes], Optional[str], Optional[str]]:
    """Fetch content from *url*.

    The file name is taken from the ``Content-Disposition`` header, then
    from the last segment of the URL path (only if it contains a dot), and
    finally from a default chosen by content type.

    Args:
        url: Address to download; redirects are followed.
        timeout: Timeout passed to the HTTP client.

    Returns:
        ``(bytes, content_type, filename)``.  On any failure the error is
        logged and ``(None, None, None)`` is returned instead of raising.
    """
    try:
        async with httpx.AsyncClient(
            timeout=timeout, follow_redirects=True,
        ) as client:
            response = await client.get(url)
            response.raise_for_status()
            content = response.content
            # Media types are case-insensitive; normalise for the lookups.
            content_type = (
                response.headers.get("content-type", "")
                .split(";")[0].strip().lower()
            )
            filename = _filename_from_content_disposition(
                response.headers.get("content-disposition", ""),
            )
            if not filename:
                path = unquote(urlparse(url).path)
                if path and "/" in path:
                    filename = path.split("/")[-1]
                # A path segment without an extension is not a usable name.
                if not filename or "." not in filename:
                    filename = None
            if not filename:
                ext_map = {
                    "application/pdf": "document.pdf",
                    "text/plain": "document.txt",
                    "application/json": "data.json",
                    "text/html": "page.html",
                }
                filename = ext_map.get(content_type, "downloaded_file")
            return content, content_type, filename
    except Exception as e:
        logger.error("Error fetching %s: %s", url, e)
        return None, None, None
# ---------------------------------------------------------------------------
# FileRAGManager
# ---------------------------------------------------------------------------
[docs]
class FileRAGManager:
    """File-based RAG with ChromaDB storage and OpenRouter embeddings.

    Each instance works on one named store: a directory below the store
    path (named after the sanitized store name) that holds the ChromaDB data
    plus a sub-directory for stored source files.  Documents are kept in a
    single collection and embedded through ``SyncOpenRouterEmbeddings``.
    """
[docs]
def __init__(
    self,
    store_name: str = "default",
    store_path: Optional[str] = None,
    api_key: Optional[str] = None,
    embedding_model: str = "google/gemini-embedding-001",
    max_file_size: int = MAX_FILE_SIZE,
    gemini_only: bool = True,
    document_task_type: Optional[str] = None,
    query_task_type: Optional[str] = None,
):
    """Open (or create) the store's directories, client and collection.

    Args:
        store_name: Logical store name; a sanitized form names the
            directory and the collection.
        store_path: Parent directory for stores; falls back to
            ``DEFAULT_STORE_PATH``.
        api_key: Key handed to the embedding function.
        embedding_model: Embedding model identifier.
        max_file_size: Size limit in bytes applied when reading files.
        gemini_only: Use only the Gemini API for embeddings.
        document_task_type: Optional Gemini task type for indexed text
            (e.g. ``RETRIEVAL_DOCUMENT``).
        query_task_type: Optional Gemini task type for search queries
            (e.g. ``RETRIEVAL_QUERY``).
    """
    # -- plain configuration ------------------------------------------------
    self.store_name = store_name
    self.store_path = store_path if store_path else DEFAULT_STORE_PATH
    self.embedding_model = embedding_model
    self.max_file_size = max_file_size

    # -- derived names and on-disk layout -----------------------------------
    self._sanitized_name = _sanitize_collection_name(store_name)
    self._collection_name = f"files_{self._sanitized_name}"
    self.db_path = os.path.join(self.store_path, self._sanitized_name)
    os.makedirs(self.db_path, exist_ok=True)
    self.files_path = os.path.join(self.db_path, STORE_FILES_SUBDIR)
    os.makedirs(self.files_path, exist_ok=True)

    # -- embeddings ---------------------------------------------------------
    self.embedding_fn = SyncOpenRouterEmbeddings(
        api_key=api_key,
        model=embedding_model,
        gemini_only=gemini_only,
        document_task_type=document_task_type,
        query_task_type=query_task_type,
    )

    # -- ChromaDB client and collection -------------------------------------
    # Clients come from the shared registry to prevent singleton collisions.
    from chroma_registry import get_client

    client_settings = Settings(anonymized_telemetry=False, allow_reset=True)
    self.client = get_client(self.db_path, settings=client_settings)

    collection_metadata = {
        "description": f"File RAG store: {store_name}",
        "created_at": datetime.now(timezone.utc).isoformat(),
        "embedding_model": embedding_model,
    }
    self.collection = self.client.get_or_create_collection(
        name=self._collection_name,
        embedding_function=self.embedding_fn,
        metadata=collection_metadata,
    )

    logger.info(
        "Initialized FileRAGManager '%s' (collection: %s) at %s",
        store_name, self._collection_name, self.db_path,
    )
# -- helpers -------------------------------------------------------------
@staticmethod
def _compute_file_hash(content: str) -> str:
    """Return the SHA-256 hex digest of *content* encoded as UTF-8.

    Args:
        content: Text whose fingerprint is wanted.

    Returns:
        The hexadecimal digest.
    """
    hasher = hashlib.sha256()
    hasher.update(content.encode("utf-8"))
    return hasher.hexdigest()
@staticmethod
def _get_file_id(file_path: str) -> str:
    """Derive a stable identifier from a file's normalized absolute path.

    Args:
        file_path: Path to the file; relative paths are resolved against
            the current working directory.

    Returns:
        The MD5 hex digest of the normalized path (an identifier, not a
        security measure).
    """
    absolute = os.path.abspath(file_path)
    canonical = os.path.normpath(absolute)
    digest = hashlib.md5(canonical.encode("utf-8"))
    return digest.hexdigest()
@staticmethod
def _is_supported_file(file_path: str) -> bool:
    """Tell whether *file_path* has a type that can be indexed.

    A file qualifies when its lower-cased suffix is in
    ``SUPPORTED_EXTENSIONS``, when its whole lower-cased name is listed
    there (dotfiles such as ``.gitignore`` have an empty suffix), or when
    the name is one of a few well-known extension-less files.

    Args:
        file_path: Path of the candidate file.

    Returns:
        True if the file type is supported, False otherwise.
    """
    path = Path(file_path)
    name = path.name.lower()
    known_files = {
        "dockerfile", "makefile", "readme", "license", "changelog",
    }
    return (
        path.suffix.lower() in SUPPORTED_EXTENSIONS
        # ``Path(".env").suffix`` is "", so dotfiles are matched by name.
        or name in SUPPORTED_EXTENSIONS
        or name in known_files
    )
def _read_file_content(self, file_path: str) -> Optional[str]:
    """Read a file's text, extracting it from PDFs when necessary.

    Files larger than ``self.max_file_size`` are rejected, except PDFs,
    which are first compressed to a temporary copy and accepted if that
    copy fits the limit.  Text files are decoded as UTF-8 with a latin-1
    fallback.

    Args:
        file_path: Path of the file to read.

    Returns:
        The text, or ``None`` when the file is too large or cannot be read
        (the reason is logged).
    """
    try:
        file_size = os.path.getsize(file_path)
        is_pdf = file_path.lower().endswith(".pdf")

        if file_size > self.max_file_size:
            if not is_pdf:
                logger.warning(
                    "File too large (>%d bytes): %s",
                    self.max_file_size, file_path,
                )
                return None

            logger.info(
                "PDF exceeds size limit, attempting compression...",
            )
            compressed_path = file_path + ".compressed.pdf"
            try:
                _, _, compressed_size = compress_pdf(
                    file_path, compressed_path,
                )
                if compressed_size <= self.max_file_size:
                    return extract_pdf_text(compressed_path)
                logger.warning(
                    "PDF still too large after compression "
                    "(%d bytes): %s",
                    compressed_size, file_path,
                )
            except Exception:
                logger.warning(
                    "Could not compress oversized PDF %s",
                    file_path, exc_info=True,
                )
            finally:
                # The compressed copy is temporary on every path, including
                # failures; it may not exist if compression itself failed.
                try:
                    os.remove(compressed_path)
                except OSError:
                    pass
            return None

        if is_pdf:
            return extract_pdf_text(file_path)

        try:
            with open(file_path, "r", encoding="utf-8") as f:
                return f.read()
        except UnicodeDecodeError:
            # latin-1 maps every byte, so this fallback always decodes.
            with open(file_path, "r", encoding="latin-1") as f:
                return f.read()
    except Exception as e:
        logger.error("Failed to read file %s: %s", file_path, e)
        return None
@staticmethod
def _create_embedding_text(file_path: str, content: str) -> str:
    """Prefix *content* with a short header describing its source file.

    Args:
        file_path: Path of the file the content came from.
        content: Text (whole file or one chunk) to be embedded.

    Returns:
        A header with the file name, the last three path components and
        the suffix (``unknown`` when there is none), followed by a blank
        line and the content.
    """
    location = Path(file_path)
    trailing_path = "/".join(location.parts[-3:])
    kind = location.suffix or "unknown"
    header = f"File: {location.name}\nPath: {trailing_path}\nType: {kind}"
    return f"{header}\n\n{content}"
# -- index ---------------------------------------------------------------
[docs]
def index_file(
    self,
    file_path: str,
    tags: Optional[List[str]] = None,
    use_chunking: bool = True,
    chunk_size: int = DEFAULT_CHUNK_SIZE,
    chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
    force: bool = False,
) -> Dict[str, Any]:
    """Index a single file into the collection.

    When *force* is True the content-hash dedup check is skipped so
    the file is always re-embedded (but the store is **not** cleared).

    Args:
        file_path: File to index; converted to an absolute path.
        tags: Optional tags stored (JSON-encoded) with every entry.
        use_chunking: Split content longer than *chunk_size* into chunks.
        chunk_size: Maximum chunk length in characters.
        chunk_overlap: Overlap between consecutive chunks in characters.
        force: Re-embed even when the stored content hash matches.

    Returns:
        A result dict.  ``success`` is always present; on success
        ``action`` is ``"indexed"`` or ``"skipped"``; on failure ``error``
        holds the reason.
    """
    file_path = os.path.abspath(file_path)
    if not os.path.exists(file_path):
        return {"success": False, "error": f"File not found: {file_path}"}
    if not os.path.isfile(file_path):
        return {"success": False, "error": f"Not a file: {file_path}"}
    if not self._is_supported_file(file_path):
        return {
            "success": False,
            "error": f"Unsupported file type: {file_path}",
        }

    content = self._read_file_content(file_path)
    if content is None:
        return {
            "success": False,
            "error": f"Failed to read file: {file_path}",
        }

    file_id = self._get_file_id(file_path)
    content_hash = self._compute_file_hash(content)

    # Skip unchanged files; otherwise drop the old entries first so a file
    # that now yields fewer chunks leaves no stale ones behind.
    try:
        existing = self.collection.get(
            where={"file_path": file_path}, include=["metadatas"],
        )
        if existing and existing.get("metadatas"):
            unchanged = (
                existing["metadatas"][0].get("content_hash") == content_hash
            )
            if unchanged and not force:
                return {
                    "success": True, "action": "skipped",
                    "file_path": file_path,
                    "reason": "content unchanged",
                }
            old_ids = existing.get("ids", [])
            if old_ids:
                self.collection.delete(ids=old_ids)
    except Exception:
        # Best effort: indexing continues via upsert, but the failure is
        # recorded because stale entries may remain in the collection.
        logger.warning(
            "Could not check or replace existing entries for %s",
            file_path, exc_info=True,
        )

    source = Path(file_path)
    base_metadata: Dict[str, Any] = {
        "file_path": file_path,
        "filename": source.name,
        "extension": source.suffix,
        "content_hash": content_hash,
        "file_size": len(content),  # characters of extracted text
        "indexed_at": datetime.now(timezone.utc).isoformat(),
        "tags": json.dumps(tags or []),
        "store_name": self.store_name,
        "source_type": "local",
    }

    try:
        if use_chunking and len(content) > chunk_size:
            chunks = [
                c for c in chunk_text(content, chunk_size, chunk_overlap)
                if c and c.strip()
            ]
            if not chunks:
                # Nothing usable came back; index the leading slice so the
                # file is still represented.
                chunks = [content[:chunk_size]]

            ids: List[str] = []
            documents: List[str] = []
            metadatas: List[Dict[str, Any]] = []
            for i, chk in enumerate(chunks):
                ids.append(f"{file_id}_chunk_{i}")
                documents.append(
                    self._create_embedding_text(file_path, chk),
                )
                metadatas.append({
                    **base_metadata,
                    "chunk_index": i,
                    "chunk_count": len(chunks),
                    "is_chunked": True,
                })

            # Upsert in bounded batches.
            for start in range(0, len(ids), CHROMA_MAX_BATCH):
                end = start + CHROMA_MAX_BATCH
                self.collection.upsert(
                    ids=ids[start:end],
                    documents=documents[start:end],
                    metadatas=metadatas[start:end],
                )
            logger.info(
                "Indexed file with %d chunks: %s", len(chunks), file_path,
            )
            return {
                "success": True, "action": "indexed",
                "file_id": file_id, "file_path": file_path,
                "file_size": len(content), "chunk_count": len(chunks),
            }

        base_metadata.update({
            "is_chunked": False, "chunk_index": 0, "chunk_count": 1,
        })
        self.collection.upsert(
            ids=[file_id],
            documents=[self._create_embedding_text(file_path, content)],
            metadatas=[base_metadata],
        )
        logger.info("Indexed file: %s", file_path)
        return {
            "success": True, "action": "indexed",
            "file_id": file_id, "file_path": file_path,
            "file_size": len(content),
        }
    except Exception as e:
        logger.error(
            "Failed to index file %s: %s", file_path, e, exc_info=True,
        )
        return {"success": False, "error": str(e)}
[docs]
async def index_url(
    self,
    url: str,
    tags: Optional[List[str]] = None,
    use_chunking: bool = True,
    chunk_size: int = DEFAULT_CHUNK_SIZE,
    chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
) -> Dict[str, Any]:
    """Download *url*, store a copy in the store's files directory, index it.

    Oversized non-PDF downloads are rejected.  Oversized PDFs are
    compressed and rejected only if still too large.  Entries previously
    indexed for the same URL are replaced unless the content hash is
    unchanged.

    Args:
        url: Address of the document.
        tags: Optional tags stored (JSON-encoded) with every entry.
        use_chunking: Split content longer than *chunk_size* into chunks.
        chunk_size: Maximum chunk length in characters.
        chunk_overlap: Overlap between consecutive chunks in characters.

    Returns:
        A result dict.  ``success`` is always present; on success
        ``action`` is ``"indexed"`` or ``"skipped"``; on failure ``error``
        holds the reason.
    """
    content_bytes, content_type, filename = await fetch_url_content(url)
    if content_bytes is None:
        return {"success": False, "error": f"Failed to fetch URL: {url}"}

    url_digest = hashlib.md5(url.encode()).hexdigest()
    url_hash = url_digest[:12]  # short prefix keeps stored names unique
    file_id = url_digest

    ext = Path(filename).suffix if filename else ""
    if not ext:
        ext_map = {
            "application/pdf": ".pdf",
            "text/plain": ".txt",
            "application/json": ".json",
            "text/html": ".html",
        }
        ext = ext_map.get(content_type, ".txt")
    is_pdf = ext.lower() == ".pdf"

    file_size = len(content_bytes)
    if file_size > self.max_file_size and not is_pdf:
        return {
            "success": False,
            "error": (
                f"File too large: {file_size} bytes "
                f"(limit: {self.max_file_size})"
            ),
        }

    base_filename = Path(filename).name if filename else "downloaded"
    base_filename = re.sub(r'[/\\:*?"<>|]', "_", base_filename)
    stored_filename = f"{url_hash}_{base_filename}"
    if not stored_filename.endswith(ext):
        stored_filename += ext
    stored_path = os.path.join(self.files_path, stored_filename)
    try:
        with open(stored_path, "wb") as f:
            f.write(content_bytes)
    except Exception as e:
        return {"success": False, "error": f"Failed to save file: {e}"}

    if is_pdf:
        if file_size > self.max_file_size:
            compressed_path = stored_path + ".compressed.pdf"
            try:
                _, _, compressed_size = compress_pdf(
                    stored_path, compressed_path,
                )
                if compressed_size <= self.max_file_size:
                    os.replace(compressed_path, stored_path)
                else:
                    for p in (compressed_path, stored_path):
                        try:
                            os.remove(p)
                        except OSError:
                            pass
                    return {
                        "success": False,
                        "error": (
                            f"PDF still too large after compression: "
                            f"{compressed_size}"
                        ),
                    }
            except Exception as e:
                # Remove both the download and any partial compressed copy.
                for p in (compressed_path, stored_path):
                    try:
                        os.remove(p)
                    except OSError:
                        pass
                return {
                    "success": False,
                    "error": f"PDF compression failed: {e}",
                }
        content = extract_pdf_text(stored_path)
        if content is None:
            return {
                "success": False,
                "error": "Failed to extract text from PDF",
            }
    else:
        try:
            content = content_bytes.decode("utf-8")
        except UnicodeDecodeError:
            # latin-1 maps every byte, so this fallback cannot fail.
            content = content_bytes.decode("latin-1")

    content_hash = self._compute_file_hash(content)

    # Skip unchanged content; otherwise drop the old entries first.
    try:
        existing = self.collection.get(
            where={"source_url": url}, include=["metadatas"],
        )
        if existing and existing.get("metadatas"):
            meta0 = existing["metadatas"][0]
            if meta0.get("content_hash") == content_hash:
                return {
                    "success": True, "action": "skipped",
                    "url": url, "reason": "content unchanged",
                }
            old_ids = existing.get("ids", [])
            if old_ids:
                self.collection.delete(ids=old_ids)
    except Exception:
        # Best effort: indexing continues via upsert, but the failure is
        # recorded because stale entries may remain in the collection.
        logger.warning(
            "Could not check or replace existing entries for %s",
            url, exc_info=True,
        )

    base_metadata: Dict[str, Any] = {
        "file_path": stored_path,
        "filename": stored_filename,
        "extension": ext,
        "content_hash": content_hash,
        "file_size": len(content),  # characters of extracted text
        "indexed_at": datetime.now(timezone.utc).isoformat(),
        "tags": json.dumps(tags or []),
        "store_name": self.store_name,
        "source_type": "url",
        "source_url": url,
    }
    doc_header = f"File: {stored_filename}\nURL: {url}\nType: {ext}\n\n"

    try:
        if use_chunking and len(content) > chunk_size:
            chunks = [
                c for c in chunk_text(content, chunk_size, chunk_overlap)
                if c and c.strip()
            ]
            if not chunks:
                chunks = [content[:chunk_size]]

            ids: List[str] = []
            documents: List[str] = []
            metadatas: List[Dict[str, Any]] = []
            for i, chk in enumerate(chunks):
                ids.append(f"{file_id}_chunk_{i}")
                documents.append(doc_header + chk)
                metadatas.append({
                    **base_metadata,
                    "chunk_index": i,
                    "chunk_count": len(chunks),
                    "is_chunked": True,
                })

            # Upsert in bounded batches, as index_file does.
            for start in range(0, len(ids), CHROMA_MAX_BATCH):
                end = start + CHROMA_MAX_BATCH
                self.collection.upsert(
                    ids=ids[start:end],
                    documents=documents[start:end],
                    metadatas=metadatas[start:end],
                )
            logger.info(
                "Indexed URL with %d chunks: %s", len(chunks), url,
            )
            return {
                "success": True, "action": "indexed",
                "url": url, "filename": stored_filename,
                "file_size": len(content), "chunk_count": len(chunks),
                "stored_path": stored_path,
            }

        base_metadata.update({
            "is_chunked": False, "chunk_index": 0, "chunk_count": 1,
        })
        self.collection.upsert(
            ids=[file_id],
            documents=[doc_header + content],
            metadatas=[base_metadata],
        )
        logger.info("Indexed URL: %s", url)
        return {
            "success": True, "action": "indexed",
            "url": url, "filename": stored_filename,
            "file_size": len(content), "stored_path": stored_path,
        }
    except Exception as e:
        logger.error("Failed to index URL %s: %s", url, e, exc_info=True)
        return {"success": False, "error": str(e)}
[docs]
def index_directory(
    self,
    directory_path: str,
    recursive: bool = True,
    tags: Optional[List[str]] = None,
    exclude_patterns: Optional[List[str]] = None,
    max_workers: int = 6,
    force: bool = False,
    allowed_extensions: Optional[Collection[str]] = None,
) -> Dict[str, Any]:
    """Index all supported files in *directory_path*.

    When *max_workers* > 1, files are indexed concurrently using a
    thread pool.

    *force* bypasses the per-file content-hash dedup check without
    clearing the store, so already-indexed files get re-embedded.

    When *allowed_extensions* is set, only files whose suffix (after
    normalizing to a leading dot, lowercase) appears in the collection
    are queued; ``None`` means no extension filter.

    Exclusion patterns are matched against each path *relative to*
    *directory_path*, so the location of the directory itself never
    causes files to be excluded.  A pattern starting with ``*`` is a
    case-insensitive suffix match; any other pattern is a substring match.

    Args:
        directory_path: Directory to scan.
        recursive: Descend into sub-directories.
        tags: Tags applied to every indexed file.
        exclude_patterns: Patterns to skip; a built-in default list is used
            when omitted or empty.
        max_workers: Number of worker threads (1 or less = sequential).
        force: Re-embed files even when unchanged.
        allowed_extensions: Optional suffix allow-list.

    Returns:
        A dict with ``success``, the ``indexed`` / ``skipped`` / ``failed``
        counters and the per-file results under ``files``.
    """
    directory_path = os.path.abspath(directory_path)
    if not os.path.isdir(directory_path):
        return {
            "success": False,
            "error": f"Not a directory: {directory_path}",
        }

    patterns = exclude_patterns or [
        "__pycache__", ".git", "node_modules", ".venv", "venv",
        ".env", "*.pyc", "*.pyo", ".DS_Store",
    ]
    results: Dict[str, Any] = {
        "success": True,
        "indexed": 0,
        "skipped": 0,
        "failed": 0,
        "files": [],
    }

    def should_exclude(rel_path: str) -> bool:
        lowered = rel_path.lower()
        for pat in patterns:
            if pat.startswith("*"):
                if lowered.endswith(pat[1:]):
                    return True
            elif pat in rel_path:
                return True
        return False

    # -- collect candidate files -------------------------------------------
    file_paths: List[str] = []
    walker = (
        os.walk(directory_path)
        if recursive
        else [(directory_path, [], os.listdir(directory_path))]
    )
    for root, dirs, files in walker:
        if recursive:
            # Prune excluded directories in place so os.walk skips them.
            dirs[:] = [d for d in dirs if not should_exclude(d)]
        for fname in files:
            file_path = os.path.join(root, fname)
            if should_exclude(os.path.relpath(file_path, directory_path)):
                continue
            # os.listdir also returns directories.
            if not recursive and not os.path.isfile(file_path):
                continue
            file_paths.append(file_path)

    if allowed_extensions is not None:
        ext_set: set[str] = set()
        for raw in allowed_extensions:
            e = raw.strip().lower()
            if not e:
                continue
            if not e.startswith("."):
                e = "." + e
            ext_set.add(e)
        if ext_set:
            file_paths = [
                fp for fp in file_paths
                if Path(fp).suffix.lower() in ext_set
            ]

    total_files = len(file_paths)
    done_count = 0

    def _record(result: Dict[str, Any]) -> None:
        """Fold one per-file result into the summary."""
        # Only ever called from the calling thread, so no lock is needed.
        nonlocal done_count
        results["files"].append(result)
        if result.get("success"):
            key = (
                "indexed"
                if result.get("action") == "indexed"
                else "skipped"
            )
            results[key] += 1
        else:
            results["failed"] += 1
        done_count += 1
        if done_count % 25 == 0 or done_count == total_files:
            logger.info(
                "Progress: %d / %d files (%d indexed, %d skipped, %d failed)",
                done_count, total_files,
                results["indexed"], results["skipped"], results["failed"],
            )

    def _index_one(fp: str) -> Dict[str, Any]:
        """Index one file; an unexpected error becomes a failed result."""
        try:
            return self.index_file(fp, tags=tags, force=force)
        except Exception as exc:
            logger.error(
                "Unexpected error indexing %s: %s", fp, exc, exc_info=True,
            )
            return {"success": False, "error": str(exc), "file_path": fp}

    # -- index ---------------------------------------------------------------
    if max_workers > 1 and len(file_paths) > 1:
        from concurrent.futures import ThreadPoolExecutor, as_completed

        workers = min(max_workers, len(file_paths))
        with ThreadPoolExecutor(max_workers=workers) as pool:
            futures = [pool.submit(_index_one, fp) for fp in file_paths]
            for future in as_completed(futures):
                _record(future.result())
    else:
        for file_path in file_paths:
            _record(_index_one(file_path))

    logger.info(
        "Directory indexing complete: %d indexed, %d skipped, %d failed",
        results["indexed"], results["skipped"], results["failed"],
    )
    return results
# -- search --------------------------------------------------------------
@staticmethod
def _merge_chunks(
    chunks: List[tuple],
    max_content_size: int,
) -> str:
    """Merge matched chunks into a single string.

    *chunks* is a list of ``(chunk_index, chunk_text)`` tuples sorted
    by ``chunk_index``.  Pieces are joined with a newline; a chunk that
    does not directly follow the previous one (index differs by more
    than 1) is preceded by a ``[...]`` marker so the reader knows content
    was skipped, including when that chunk had to be truncated.

    Args:
        chunks: ``(chunk_index, chunk_text)`` pairs in index order.
        max_content_size: Approximate character budget for the result;
            every chunk is charged its length plus a small fixed overhead.

    Returns:
        The merged text.  Merging stops at the first chunk that exceeds
        the budget; that chunk is included truncated only when more than
        200 characters of budget remain.  Empty when nothing fits.
    """
    gap_marker = "\n\n[...]\n\n"
    overhead = 10  # per-chunk allowance for separators and markers

    merged_parts: List[str] = []
    used = 0
    prev_idx = None
    for idx, text in chunks:
        text = text.strip()
        if not text:
            continue
        skipped_content = prev_idx is not None and idx != prev_idx + 1
        prefix = gap_marker if skipped_content else ""

        if used + len(text) + overhead > max_content_size:
            remaining = max_content_size - used - overhead
            # A very short fragment is not worth returning.
            if remaining > 200:
                merged_parts.append(prefix + text[:remaining].rstrip())
            break

        merged_parts.append(prefix + text)
        used += len(text) + overhead
        prev_idx = idx
    return "\n".join(merged_parts)
[docs]
def search(
    self,
    query: str,
    n_results: int = 5,
    tags: Optional[List[str]] = None,
    return_content: bool = True,
    query_embedding: list[float] | None = None,
    max_content_size: int = 8000,
) -> List[Dict[str, Any]]:
    """Semantic search returning relevant chunks per file.

    Instead of returning entire file contents, this collects the
    matching chunk texts that ChromaDB found and merges them
    (respecting ``max_content_size``).  Files are ranked by their
    best-matching chunk.

    Args:
        query: Natural-language search query.
        n_results: Maximum number of files to return.
        tags: Optional tag filter; an entry matches if it has any of them.
        return_content: Include chunk text in results.
        query_embedding: Pre-computed query embedding (used instead of
            the query text).
        max_content_size: Maximum characters of merged chunk text
            to return per file (default 8000).

    Returns:
        One dict per file, best match first.  An empty list is returned
        when nothing matches or when the query fails (the error is logged).
    """
    try:
        where_filter = None
        if tags:
            tag_clauses = [{"tags": {"$contains": tag}} for tag in tags]
            # Chroma rejects $or with fewer than two sub-expressions, so a
            # single tag is passed as a bare clause.
            # NOTE(review): confirm the installed Chroma version accepts
            # ``$contains`` on the JSON-encoded ``tags`` metadata string.
            where_filter = (
                tag_clauses[0] if len(tag_clauses) == 1
                else {"$or": tag_clauses}
            )

        query_kwargs: Dict[str, Any] = {
            # Over-fetch: several hits may belong to the same file.
            "n_results": n_results * 5,
            "where": where_filter,
            "include": ["metadatas", "distances", "documents"],
        }
        if query_embedding is not None:
            query_kwargs["query_embeddings"] = [query_embedding]
        else:
            query_kwargs["query_texts"] = [query]

        results = self.collection.query(**query_kwargs)
        if not results or not results.get("metadatas"):
            return []

        # -- group hits by file ------------------------------------------------
        documents = results.get("documents", [[]])[0]
        file_chunks: Dict[str, Dict[str, Any]] = {}
        for i, metadata in enumerate(results["metadatas"][0]):
            file_path = metadata.get("file_path", "")
            distance = (
                results["distances"][0][i]
                if results.get("distances") else 1.0
            )
            chunk_index = int(metadata.get("chunk_index", 0))
            hit_text = documents[i] if i < len(documents) else ""

            bucket = file_chunks.setdefault(file_path, {
                "metadata": metadata,
                "best_distance": distance,
                "chunks": [],
            })
            if distance < bucket["best_distance"]:
                bucket["best_distance"] = distance
                bucket["metadata"] = metadata
            bucket["chunks"].append((distance, chunk_index, hit_text))

        # -- build one result per file, best match first -----------------------
        output = []
        sorted_files = sorted(
            file_chunks.items(),
            key=lambda item: item[1]["best_distance"],
        )
        for file_path, data in sorted_files[:n_results]:
            metadata = data["metadata"]
            best_distance = data["best_distance"]
            hit: Dict[str, Any] = {
                "file_path": file_path,
                "filename": metadata.get("filename", ""),
                "extension": metadata.get("extension", ""),
                "file_size": metadata.get("file_size", 0),
                "indexed_at": metadata.get("indexed_at", ""),
                "tags": json.loads(metadata.get("tags", "[]")),
                "similarity_score": (
                    1.0 - best_distance
                    if best_distance is not None else None
                ),
                "source_type": metadata.get("source_type", "local"),
                "source_url": metadata.get("source_url", None),
            }
            if return_content:
                # Document order; for a repeated chunk index the closest
                # match sorts first and is the one kept.
                ordered = sorted(
                    data["chunks"], key=lambda c: (c[1], c[0]),
                )
                seen_indices: set = set()
                deduped = []
                for _, ci, ct in ordered:
                    if ci not in seen_indices:
                        seen_indices.add(ci)
                        deduped.append((ci, ct))
                content = self._merge_chunks(deduped, max_content_size)
                hit["content"] = (
                    content if content else "[No chunk text available]"
                )
            output.append(hit)

        logger.info(
            "Search query '%s...' returned %d unique files",
            query[:50], len(output),
        )
        return output
    except Exception as e:
        logger.error("Search failed: %s", e, exc_info=True)
        return []
# -- remove / list / stats -----------------------------------------------
[docs]
def remove_file(self, file_path: str) -> Dict[str, Any]:
    """Remove every index entry recorded for a local file.

    Only the collection is touched; nothing is deleted on disk.

    Args:
        file_path: Path of the file; converted to an absolute path before
            the lookup.

    Returns:
        On success ``success``, ``file_path`` and ``entries_removed``;
        otherwise ``success=False`` with an ``error`` message.
    """
    file_path = os.path.abspath(file_path)
    try:
        found = self.collection.get(
            where={"file_path": file_path}, include=["metadatas"],
        )
        entry_ids = found.get("ids") if found else None
        if not entry_ids:
            return {
                "success": False,
                "error": f"File not found in index: {file_path}",
            }
        self.collection.delete(ids=entry_ids)
        return {
            "success": True,
            "file_path": file_path,
            "entries_removed": len(entry_ids),
        }
    except Exception as e:
        return {"success": False, "error": str(e)}
[docs]
def remove_url(self, url: str) -> Dict[str, Any]:
    """Remove every index entry recorded for a URL.

    Only the collection is touched; the stored copy of the download is
    left on disk.

    Args:
        url: The URL exactly as it was indexed.

    Returns:
        On success ``success``, ``url`` and ``entries_removed``;
        otherwise ``success=False`` with an ``error`` message.
    """
    try:
        found = self.collection.get(
            where={"source_url": url}, include=["metadatas"],
        )
        entry_ids = found.get("ids") if found else None
        if not entry_ids:
            return {
                "success": False,
                "error": f"URL not found in index: {url}",
            }
        self.collection.delete(ids=entry_ids)
        return {
            "success": True,
            "url": url,
            "entries_removed": len(entry_ids),
        }
    except Exception as e:
        return {"success": False, "error": str(e)}
[docs]
def list_indexed_files(self, limit: int = 100) -> List[Dict[str, Any]]:
    """List indexed files, one entry per file.

    Only the entry with ``chunk_index == 0`` is fetched for each file, so
    chunked files are not repeated and *limit* counts files, not chunks.

    Args:
        limit: Maximum number of files to return.

    Returns:
        A list of dicts with ``file_path``, ``filename``, ``extension``,
        ``file_size``, ``indexed_at`` and ``tags``.  An empty list is
        returned when the lookup fails (the error is logged).
    """
    try:
        # Entries written by index_file / index_url always carry
        # ``chunk_index``; chunk 0 exists for chunked and unchunked files.
        # NOTE(review): entries written without ``chunk_index`` metadata
        # would not be listed -- confirm no such legacy data exists.
        results = self.collection.get(
            where={"chunk_index": 0},
            limit=limit,
            include=["metadatas"],
        )
        files: List[Dict[str, Any]] = []
        seen_paths = set()
        for m in results.get("metadatas") or []:
            path = m.get("file_path", "")
            # Guard against the same path being stored under several ids.
            if path in seen_paths:
                continue
            seen_paths.add(path)
            files.append({
                "file_path": path,
                "filename": m.get("filename", ""),
                "extension": m.get("extension", ""),
                "file_size": m.get("file_size", 0),
                "indexed_at": m.get("indexed_at", ""),
                "tags": json.loads(m.get("tags", "[]")),
            })
        return files
    except Exception as e:
        logger.error("Failed to list files: %s", e)
        return []
[docs]
def list_store_files(self) -> List[Dict[str, Any]]:
    """List the regular files in the store's files directory.

    Returns:
        One dict per file with ``filename``, ``size`` (bytes), ``modified``
        (UTC ISO-8601 timestamp) and ``path``, sorted by file name.  Empty
        when the directory does not exist; entries that cannot be stat-ed
        are skipped.
    """
    if not os.path.exists(self.files_path):
        return []

    entries: List[Dict[str, Any]] = []
    for name in os.listdir(self.files_path):
        full_path = os.path.join(self.files_path, name)
        if not os.path.isfile(full_path):
            continue
        try:
            info = os.stat(full_path)
            modified_at = datetime.fromtimestamp(
                info.st_mtime, tz=timezone.utc,
            )
            entries.append({
                "filename": name,
                "size": info.st_size,
                "modified": modified_at.isoformat(),
                "path": full_path,
            })
        except Exception:
            # Best effort: a file that vanished or is unreadable is skipped.
            pass

    entries.sort(key=lambda item: item.get("filename", ""))
    return entries
[docs]
def read_store_file(self, filename: str) -> Dict[str, Any]:
    """Read one file from the store's files directory.

    Args:
        filename: Bare file name.  Names containing a path separator or
            ``..`` are rejected so the read stays inside the directory.

    Returns:
        On success ``success``, ``filename``, ``content`` and ``size``
        (length of the returned text); otherwise ``success=False`` with
        an ``error`` message.
    """
    if any(token in filename for token in ("/", "\\", "..")):
        return {"success": False, "error": "Invalid filename."}

    target = os.path.join(self.files_path, filename)
    if not os.path.isfile(target):
        return {"success": False, "error": f"File not found: {filename}"}

    try:
        if filename.lower().endswith(".pdf"):
            text = extract_pdf_text(target)
            if text is None:
                return {
                    "success": False,
                    "error": "Failed to extract text from PDF",
                }
        else:
            try:
                with open(target, "r", encoding="utf-8") as handle:
                    text = handle.read()
            except UnicodeDecodeError:
                # Fall back to latin-1, which can decode any byte.
                with open(target, "r", encoding="latin-1") as handle:
                    text = handle.read()
        return {
            "success": True,
            "filename": filename,
            "content": text,
            "size": len(text),
        }
    except Exception as e:
        return {"success": False, "error": str(e)}
[docs]
def close(self) -> None:
    """Close the ChromaDB client if it exposes a ``close`` method.

    Errors raised while closing are logged at debug level and suppressed,
    so this is always safe to call.
    """
    try:
        client = self.client
        if hasattr(client, "close"):
            client.close()
    except Exception:
        logger.debug(
            "Error closing ChromaDB client for '%s'",
            self.store_name, exc_info=True,
        )
[docs]
def get_stats(self) -> Dict[str, Any]:
    """Report basic facts about this store.

    Returns:
        A dict with ``store_name``, ``store_path``, ``file_count`` (the
        value of ``collection.count()``) and ``embedding_model``, or
        ``{"error": ...}`` if gathering them fails.
    """
    stats: Dict[str, Any] = {}
    try:
        stats["store_name"] = self.store_name
        stats["store_path"] = self.db_path
        stats["file_count"] = self.collection.count()
        stats["embedding_model"] = self.embedding_model
    except Exception as e:
        return {"error": str(e)}
    return stats
[docs]
def clear(self) -> Dict[str, Any]:
    """Empty the store by deleting and re-creating its collection.

    Returns:
        ``{"success": True, "message": ...}`` on success, otherwise
        ``{"success": False, "error": ...}``.
    """
    try:
        name = self._collection_name
        self.client.delete_collection(name=name)
        fresh_metadata = {
            "description": f"File RAG store: {self.store_name}",
            "created_at": datetime.now(timezone.utc).isoformat(),
            "embedding_model": self.embedding_model,
        }
        self.collection = self.client.create_collection(
            name=name,
            embedding_function=self.embedding_fn,
            metadata=fresh_metadata,
        )
    except Exception as e:
        return {"success": False, "error": str(e)}

    logger.info("Cleared RAG store: %s", self.store_name)
    return {
        "success": True,
        "message": f"Store '{self.store_name}' cleared",
    }
# ---------------------------------------------------------------------------
# Global store registry (LRU-bounded)
# ---------------------------------------------------------------------------
# Maximum number of stores kept open at once; configurable through the
# RAG_STORE_CACHE_SIZE environment variable.
_STORE_REGISTRY_MAX_SIZE = int(os.getenv("RAG_STORE_CACHE_SIZE", "5"))
# Open managers by cache key, least recently used first.
_store_registry: OrderedDict[tuple, FileRAGManager] = OrderedDict()

# Sphinx-generated docs store: Gemini retrieval task types for index vs query.
STARGAZER_DOCS_STORE_NAME = "stargazer_docs"
STARGAZER_DOCS_DOCUMENT_TASK = "RETRIEVAL_DOCUMENT"
STARGAZER_DOCS_QUERY_TASK = "RETRIEVAL_QUERY"
[docs]
def get_rag_store(
    store_name: str = "default",
    api_key: Optional[str] = None,
    max_file_size: Optional[int] = None,
    gemini_only: bool = True,
    document_task_type: Optional[str] = None,
    query_task_type: Optional[str] = None,
) -> FileRAGManager:
    """Get or create a RAG store by name (LRU-cached).

    At most ``_STORE_REGISTRY_MAX_SIZE`` stores (never fewer than one) are
    kept open simultaneously.  When a new store would exceed the limit the
    least recently used entry is closed and evicted.

    Cache entries are keyed by ``store_name``, the optional embedding task
    types and ``gemini_only`` so different embedding configurations do not
    share one manager.  ``api_key`` and ``max_file_size`` are not part of
    the key: they only take effect when the manager is first created.

    Args:
        store_name: Name of the store.
        api_key: Key for the embedding service.
        max_file_size: Per-file size limit in bytes; the manager's default
            is used when omitted.
        gemini_only: Forwarded to the manager as given.
        document_task_type: Optional task type for indexed text.
        query_task_type: Optional task type for search queries.

    Returns:
        The cached or newly created manager.
    """
    cache_key = (
        store_name,
        document_task_type,
        query_task_type,
        bool(gemini_only),
    )
    cached = _store_registry.get(cache_key)
    if cached is not None:
        _store_registry.move_to_end(cache_key)
        return cached

    kwargs: Dict[str, Any] = {
        "store_name": store_name,
        "api_key": api_key,
        # Always forwarded: omitting it for False would silently fall back
        # to the manager's default of True.
        "gemini_only": gemini_only,
    }
    if max_file_size is not None:
        kwargs["max_file_size"] = max_file_size
    if document_task_type is not None:
        kwargs["document_task_type"] = document_task_type
    if query_task_type is not None:
        kwargs["query_task_type"] = query_task_type

    store = FileRAGManager(**kwargs)
    _store_registry[cache_key] = store

    # A configured size of 0 or less must not evict (and close) the store
    # that is about to be returned.
    capacity = max(1, _STORE_REGISTRY_MAX_SIZE)
    while len(_store_registry) > capacity:
        _evicted_key, evicted_store = _store_registry.popitem(last=False)
        logger.info("Evicting RAG store %s from cache (LRU)", _evicted_key)
        evicted_store.close()
    return store
[docs]
def get_stargazer_docs_store() -> FileRAGManager:
    """Return the shared RAG store for Sphinx / tool documentation.

    The store is opened through ``get_rag_store`` with the
    ``RETRIEVAL_DOCUMENT`` task type for indexed chunks and
    ``RETRIEVAL_QUERY`` for search queries (Gemini embedding task types).
    """
    task_types = {
        "document_task_type": STARGAZER_DOCS_DOCUMENT_TASK,
        "query_task_type": STARGAZER_DOCS_QUERY_TASK,
    }
    return get_rag_store(STARGAZER_DOCS_STORE_NAME, **task_types)
[docs]
def list_rag_stores() -> List[str]:
    """List all available RAG store directory names.

    Returns:
        The names of the sub-directories of ``DEFAULT_STORE_PATH``; empty
        when that directory does not exist.
    """
    if not os.path.exists(DEFAULT_STORE_PATH):
        return []

    store_names: List[str] = []
    for entry in os.listdir(DEFAULT_STORE_PATH):
        if os.path.isdir(os.path.join(DEFAULT_STORE_PATH, entry)):
            store_names.append(entry)
    return store_names
[docs]
def list_rag_stores_with_stats() -> List[Dict[str, Any]]:
    """List stores with file counts using only filesystem ops (no ChromaDB).

    The count is the number of directory entries in each store's
    ``files/`` sub-directory, a lightweight proxy for the indexed entry
    count that never opens a ChromaDB client.

    Returns:
        ``{"name": ..., "file_count": ...}`` dicts sorted by store name;
        empty when ``DEFAULT_STORE_PATH`` does not exist.
    """
    if not os.path.exists(DEFAULT_STORE_PATH):
        return []

    summary: List[Dict[str, Any]] = []
    for name in sorted(os.listdir(DEFAULT_STORE_PATH)):
        store_dir = os.path.join(DEFAULT_STORE_PATH, name)
        if not os.path.isdir(store_dir):
            continue
        files_dir = os.path.join(store_dir, STORE_FILES_SUBDIR)
        file_count = 0
        try:
            if os.path.isdir(files_dir):
                file_count = len(os.listdir(files_dir))
        except OSError:
            # An unreadable directory is reported as empty.
            file_count = 0
        summary.append({"name": name, "file_count": file_count})
    return summary
[docs]
def delete_rag_store(store_name: str) -> Dict[str, Any]:
    """Delete a RAG store completely.

    Every cached manager that maps to the same on-disk directory is closed
    and evicted first, then the directory is removed.

    Args:
        store_name: Name of the store; it is sanitized the same way as
            when the store was created.

    Returns:
        ``success=True`` with ``message`` and ``path_deleted``, or
        ``success=False`` with an ``error`` message.
    """
    import shutil

    sanitized_name = _sanitize_collection_name(store_name)
    store_path = os.path.join(DEFAULT_STORE_PATH, sanitized_name)
    if not os.path.exists(store_path):
        return {
            "success": False,
            "error": f"Store '{store_name}' does not exist",
        }
    try:
        # Registry keys hold the raw store name, and different raw names
        # can sanitize to the same directory, so compare sanitized forms.
        keys_to_drop = [
            k for k in list(_store_registry.keys())
            if _sanitize_collection_name(k[0]) == sanitized_name
        ]
        for k in keys_to_drop:
            evicted = _store_registry.pop(k, None)
            if evicted is not None:
                evicted.close()
        shutil.rmtree(store_path)
        logger.info("Deleted RAG store: %s (path: %s)", store_name, store_path)
        return {
            "success": True,
            "message": f"Store '{store_name}' deleted successfully",
            "path_deleted": store_path,
        }
    except Exception as e:
        return {"success": False, "error": str(e)}