"""Redis-backed batched embedding queue for non-critical embedding generation.
Accumulates (redis_key, text) pairs in a Redis sorted set
(``embed_queue:pending``) and flushes them to the Gemini batch embeddings
API on a timer or when the batch reaches a size threshold. Resolved
embeddings are written back to the corresponding Redis hashes so they
become vector-searchable.
Persistence guarantees:
* Items are written to the ZSET *before* any API call, so a crash or
restart never loses pending work.
* On startup, any leftover items from a previous run are drained
automatically (fire-and-forget — no in-process Future for those).
* On API failure, items are re-added to the ZSET for the next drain
cycle.
A single background :class:`asyncio.Task` (``_drain_task``) drains the
queue — it pops a batch from the ZSET with ``ZPOPMIN``, calls the
embedding API, and writes results back. After each successful batch it
loops back to check for more items; when the ZSET is empty it exits and
the next :meth:`enqueue` call schedules a new drain after the batching
interval.
Usage::
queue = EmbeddingBatchQueue(
openrouter, redis, model="google/gemini-embedding-001",
)
await queue.start() # drains any leftovers from previous run
future = await queue.enqueue("msg:abc-123", "Hello world")
embedding = await future
await queue.stop()
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
from dataclasses import dataclass
import numpy as np
import redis.asyncio as aioredis
from openrouter_client import OpenRouterClient
logger = logging.getLogger(__name__)
# Seconds to wait after the first enqueue before a timed flush starts (1 hour).
DEFAULT_FLUSH_INTERVAL = 3600.0
# Flush immediately once this many items are pending in the ZSET.
DEFAULT_MAX_BATCH_SIZE = 50
# Maximum number of texts sent in a single embedding API request.
API_BATCH_LIMIT = 50
# Redis sorted-set key holding serialized pending items (score = enqueue time).
PENDING_ZSET_KEY = "embed_queue:pending"
@dataclass
class _QueueItem:
    """A single pending embedding request popped from the Redis ZSET."""

    # Redis hash key the finished embedding vector is written back to.
    redis_key: str
    # Text to embed.
    text: str
    # ZSET score (enqueue timestamp); reused when the item is re-queued.
    score: float
    # In-process future awaiting the vector; None for items recovered
    # from a previous run (fire-and-forget).
    future: asyncio.Future[list[float]] | None = None
def _embed_to_bytes(embedding: list[float]) -> bytes:
return np.array(embedding, dtype=np.float32).tobytes()
def _serialize_item(redis_key: str, text: str) -> str:
return json.dumps({"redis_key": redis_key, "text": text}, separators=(",", ":"))
def _deserialize_item(raw: str | bytes, score: float) -> _QueueItem:
    """Decode a JSON ZSET member (``str`` or ``bytes``) into a queue item.

    The *score* is carried over unchanged so a re-queued item keeps its
    original enqueue-time ordering.
    """
    if isinstance(raw, bytes):
        raw = raw.decode()
    payload = json.loads(raw)
    return _QueueItem(
        redis_key=payload["redis_key"],
        text=payload["text"],
        score=score,
    )
class EmbeddingBatchQueue:
    """Redis-backed queue that batches embedding requests and flushes
    them periodically via the Gemini batch API.

    Parameters
    ----------
    openrouter:
        Shared API client with ``embed_batch()`` support.
    redis:
        Async Redis client for the persistent ZSET and writing embeddings.
    model:
        Embedding model identifier.
    flush_interval:
        Seconds to wait after the first enqueue before starting a flush.
    max_batch_size:
        Flush immediately when the queue reaches this size.
    """

    def __init__(
        self,
        openrouter: OpenRouterClient,
        redis: aioredis.Redis,
        model: str = "google/gemini-embedding-001",
        flush_interval: float = DEFAULT_FLUSH_INTERVAL,
        max_batch_size: int = DEFAULT_MAX_BATCH_SIZE,
    ) -> None:
        self._openrouter = openrouter
        self._redis = redis
        self._model = model
        self._flush_interval = flush_interval
        self._max_batch_size = max_batch_size
        # In-memory futures for callers who await the result.
        # Items recovered from a previous run have no future.
        self._pending_futures: dict[str, asyncio.Future[list[float]]] = {}
        # Delayed-flush timer: sleeps flush_interval, then starts a drain.
        self._timer_task: asyncio.Task | None = None
        # Background loop popping batches from the ZSET and embedding them.
        self._drain_task: asyncio.Task | None = None

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    async def start(self) -> None:
        """Start the queue and drain any leftovers from a previous run."""
        leftover = await self._redis.zcard(PENDING_ZSET_KEY)
        if leftover:
            logger.info(
                "Recovering %d pending embeddings from previous run", leftover,
            )
            self._start_drain_now()
        logger.info(
            "EmbeddingBatchQueue started (flush_interval=%.1fs, max_batch=%d)",
            self._flush_interval, self._max_batch_size,
        )

    async def stop(self) -> None:
        """Cancel the background drain task.

        Items remain in the Redis ZSET and will be recovered on next
        start. In-process futures are cancelled.
        """
        if self._timer_task is not None:
            self._timer_task.cancel()
            self._timer_task = None
        if self._drain_task is not None:
            self._drain_task.cancel()
            # gather(return_exceptions=True) swallows the CancelledError.
            await asyncio.gather(self._drain_task, return_exceptions=True)
            self._drain_task = None
        for fut in self._pending_futures.values():
            if not fut.done():
                fut.cancel()
        self._pending_futures.clear()
        pending = await self._redis.zcard(PENDING_ZSET_KEY)
        logger.info(
            "EmbeddingBatchQueue stopped (%d items persisted for next start)",
            pending,
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def enqueue(
        self, redis_key: str, text: str,
    ) -> asyncio.Future[list[float]]:
        """Add text for deferred embedding. Returns a Future that
        resolves with the embedding vector once the batch is flushed.

        If the same *redis_key* is already pending, the existing future
        is returned (dedup). Empty/whitespace text returns a
        pre-resolved zero-vector future.
        """
        loop = asyncio.get_running_loop()
        if not text or not text.strip():
            from openrouter_client import EMBED_DIMENSIONS
            future: asyncio.Future[list[float]] = loop.create_future()
            future.set_result([0.0] * EMBED_DIMENSIONS)
            return future
        existing = self._pending_futures.get(redis_key)
        if existing is not None and not existing.done():
            return existing
        future = loop.create_future()
        self._pending_futures[redis_key] = future
        score = time.time()
        member = _serialize_item(redis_key, text)
        # Persist *before* scheduling so a crash never loses the item.
        await self._redis.zadd(PENDING_ZSET_KEY, {member: score})
        await self._maybe_trigger_drain()
        return future

    async def enqueue_many(
        self, items: list[tuple[str, str]],
    ) -> list[asyncio.Future[list[float]]]:
        """Atomically enqueue multiple ``(redis_key, text)`` pairs.

        Returns a list of Futures (one per item, in the same order).
        Blank texts resolve immediately to zero vectors; duplicate
        pending keys reuse the existing future, mirroring :meth:`enqueue`.
        """
        if not items:
            return []
        loop = asyncio.get_running_loop()
        futures: list[asyncio.Future[list[float]]] = []
        mapping: dict[str, float] = {}
        now = time.time()
        for redis_key, text in items:
            if not text or not text.strip():
                from openrouter_client import EMBED_DIMENSIONS
                fut: asyncio.Future[list[float]] = loop.create_future()
                fut.set_result([0.0] * EMBED_DIMENSIONS)
                futures.append(fut)
                continue
            existing = self._pending_futures.get(redis_key)
            if existing is not None and not existing.done():
                futures.append(existing)
                continue
            fut = loop.create_future()
            self._pending_futures[redis_key] = fut
            futures.append(fut)
            member = _serialize_item(redis_key, text)
            mapping[member] = now
        if mapping:
            # Single ZADD makes the batch insert atomic in Redis.
            await self._redis.zadd(PENDING_ZSET_KEY, mapping)
            await self._maybe_trigger_drain()
        return futures

    async def pending_count(self) -> int:
        """Number of items waiting in the persistent queue."""
        return await self._redis.zcard(PENDING_ZSET_KEY)

    def flush_now(self) -> None:
        """Trigger an immediate drain of the persistent queue.

        Call this just before an LLM inference so that all pending
        embeddings are available for vector search. Non-blocking:
        the drain runs in a background task.
        """
        self._start_drain_now()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    async def _maybe_trigger_drain(self) -> None:
        """Check ZSET size and schedule or start drain accordingly."""
        size = await self._redis.zcard(PENDING_ZSET_KEY)
        if size >= self._max_batch_size:
            self._start_drain_now()
        elif self._drain_task is None or self._drain_task.done():
            # No drain running: arm the delayed-flush timer if idle.
            if self._timer_task is None or self._timer_task.done():
                self._timer_task = asyncio.create_task(
                    self._trigger_after_delay(),
                )

    def _start_drain_now(self) -> None:
        """Cancel any pending timer and ensure the drain loop is running."""
        if self._timer_task is not None:
            self._timer_task.cancel()
            self._timer_task = None
        if self._drain_task is None or self._drain_task.done():
            self._drain_task = asyncio.create_task(self._drain_loop())

    async def _trigger_after_delay(self) -> None:
        """Wait for the batching interval, then start the drain loop."""
        try:
            await asyncio.sleep(self._flush_interval)
        except asyncio.CancelledError:
            return
        self._start_drain_now()

    async def _drain_loop(self) -> None:
        """Pop batches from the Redis ZSET and process them until empty.

        Exits when the ZSET is empty, or — after a failed batch — backs
        off and schedules a retry instead of looping immediately.
        """
        try:
            while True:
                raw_items = await self._redis.zpopmin(
                    PENDING_ZSET_KEY, self._max_batch_size,
                )
                if not raw_items:
                    # Queue drained: prune resolved futures and exit.
                    self._pending_futures = {
                        k: v for k, v in self._pending_futures.items()
                        if not v.done()
                    }
                    return
                batch: list[_QueueItem] = []
                for member, score in raw_items:
                    try:
                        item = _deserialize_item(member, score)
                    except Exception:
                        logger.warning(
                            "Skipping malformed embed queue item: %r",
                            member,
                        )
                        continue
                    if not item.text or not item.text.strip():
                        logger.debug(
                            "Dropping blank-text item from embed queue: %s",
                            item.redis_key,
                        )
                        fut = self._pending_futures.pop(item.redis_key, None)
                        if fut is not None and not fut.done():
                            from openrouter_client import EMBED_DIMENSIONS
                            fut.set_result([0.0] * EMBED_DIMENSIONS)
                        continue
                    item.future = self._pending_futures.pop(
                        item.redis_key, None,
                    )
                    batch.append(item)
                if batch and not await self._process_batch(batch):
                    # BUG FIX: the failed batch was re-queued to the ZSET,
                    # so continuing the loop would ZPOPMIN the same items
                    # straight back and hammer a failing API in a tight
                    # retry loop. Back off: exit the drain loop and retry
                    # after the normal batching interval.
                    if self._timer_task is None or self._timer_task.done():
                        self._timer_task = asyncio.create_task(
                            self._trigger_after_delay(),
                        )
                    return
        except asyncio.CancelledError:
            logger.warning("EmbeddingBatchQueue drain loop cancelled")

    async def _process_batch(self, batch: list[_QueueItem]) -> bool:
        """Embed one batch and write results to Redis.

        Returns ``True`` on success. On API failure the items are
        re-added to the ZSET so they survive for the next drain cycle
        or the next startup, their futures are failed, and ``False``
        is returned so the caller can back off.
        """
        total = len(batch)
        logger.info(
            "Flushing %d deferred embeddings (sub-batches of %d)",
            total, API_BATCH_LIMIT,
        )
        chunks = [
            batch[start : start + API_BATCH_LIMIT]
            for start in range(0, total, API_BATCH_LIMIT)
        ]

        async def _embed_chunk(chunk: list[_QueueItem]) -> list[list[float]]:
            texts = [item.text for item in chunk]
            return await self._openrouter._embed_gemini_batch(
                texts, self._model,
            )

        try:
            results = await asyncio.gather(
                *(_embed_chunk(c) for c in chunks),
            )
        except Exception:
            logger.warning(
                "Batch embedding API call failed for %d sub-batches "
                "of %d items — re-queuing all",
                len(chunks), total, exc_info=True,
            )
            await self._requeue_items(batch)
            for item in batch:
                if item.future is not None and not item.future.done():
                    item.future.set_exception(
                        RuntimeError("Batch embedding flush failed"),
                    )
            return False
        all_embeddings: list[list[float]] = []
        for chunk_embeddings in results:
            all_embeddings.extend(chunk_embeddings)
        if len(all_embeddings) < total:
            # Defensive: the API returned fewer vectors than requested.
            # Previously zip() silently dropped the tail; re-queue it so
            # the work is retried, and fail its in-process futures.
            missing = batch[len(all_embeddings):]
            logger.warning(
                "Embedding API returned %d vectors for %d items — "
                "re-queuing %d",
                len(all_embeddings), total, len(missing),
            )
            await self._requeue_items(missing)
            for item in missing:
                if item.future is not None and not item.future.done():
                    item.future.set_exception(
                        RuntimeError("Batch embedding flush failed"),
                    )
            batch = batch[: len(all_embeddings)]
        pipe = self._redis.pipeline()
        for item, embedding in zip(batch, all_embeddings):
            pipe.hset(item.redis_key, "embedding", _embed_to_bytes(embedding))
            if item.future is not None and not item.future.done():
                item.future.set_result(embedding)
        try:
            await pipe.execute()
        except Exception:
            # Best-effort write-back: futures already resolved, and items
            # will be regenerated on demand if the hashes lack embeddings.
            logger.warning(
                "Failed to write %d embeddings back to Redis",
                total, exc_info=True,
            )
        logger.debug("Flushed %d embeddings to Redis", total)
        return True

    async def _requeue_items(self, items: list[_QueueItem]) -> None:
        """Re-add failed items to the ZSET for retry."""
        mapping: dict[str, float] = {}
        for item in items:
            member = _serialize_item(item.redis_key, item.text)
            mapping[member] = item.score
        try:
            await self._redis.zadd(PENDING_ZSET_KEY, mapping)
            logger.info("Re-queued %d items for retry", len(items))
        except Exception:
            logger.error(
                "Failed to re-queue %d items to Redis", len(items),
                exc_info=True,
            )