# Source code for embedding_queue

"""Redis-backed batched embedding queue for non-critical embedding generation.

Accumulates (redis_key, text) pairs in a Redis sorted set
(``embed_queue:pending``) and flushes them to the Gemini batch embeddings
API on a timer or when the batch reaches a size threshold.  Resolved
embeddings are written back to the corresponding Redis hashes so they
become vector-searchable.

Persistence guarantees:

* Items are written to the ZSET *before* any API call, so a crash or
  restart never loses pending work.
* On startup, any leftover items from a previous run are drained
  automatically (fire-and-forget — no in-process Future for those).
* On API failure, items are re-added to the ZSET for the next drain
  cycle.

A single background :class:`asyncio.Task` (``_drain_task``) drains the
queue — it pops a batch from the ZSET with ``ZPOPMIN``, calls the
embedding API, and writes results back.  After each successful batch it
loops back to check for more items; when the ZSET is empty it exits and
the next :meth:`enqueue` call schedules a new drain after the batching
interval.

Usage::

    queue = EmbeddingBatchQueue(
        openrouter, redis, model="google/gemini-embedding-001",
    )
    await queue.start()   # drains any leftovers from previous run

    future = await queue.enqueue("msg:abc-123", "Hello world")
    embedding = await future

    await queue.stop()
"""

from __future__ import annotations

import asyncio
import json
import logging
import time
from dataclasses import dataclass

import numpy as np
import redis.asyncio as aioredis

from openrouter_client import OpenRouterClient

logger = logging.getLogger(__name__)

DEFAULT_FLUSH_INTERVAL = 3600.0
DEFAULT_MAX_BATCH_SIZE = 50
API_BATCH_LIMIT = 50

PENDING_ZSET_KEY = "embed_queue:pending"


@dataclass
class _QueueItem:
    """One pending embedding request popped from the Redis ZSET."""

    # Redis hash key whose "embedding" field will receive the result.
    redis_key: str
    # Raw text to embed.
    text: str
    # ZSET score (enqueue timestamp); reused on re-queue so retried
    # items keep their original ordering.
    score: float
    # In-process awaiter for the caller; None for items recovered from
    # a previous run (fire-and-forget).
    future: asyncio.Future[list[float]] | None = None


def _embed_to_bytes(embedding: list[float]) -> bytes:
    return np.array(embedding, dtype=np.float32).tobytes()


def _serialize_item(redis_key: str, text: str) -> str:
    return json.dumps({"redis_key": redis_key, "text": text}, separators=(",", ":"))


def _deserialize_item(raw: str | bytes, score: float) -> _QueueItem:
    """Decode a ZSET member (JSON, possibly as bytes) back into a queue item.

    The *score* is carried through unchanged; recovered items get no
    in-process future.
    """
    if isinstance(raw, bytes):
        raw = raw.decode()
    payload = json.loads(raw)
    return _QueueItem(
        redis_key=payload["redis_key"],
        text=payload["text"],
        score=score,
    )


class EmbeddingBatchQueue:
    """Redis-backed queue that batches embedding requests and flushes
    them periodically via the Gemini batch API.

    Parameters
    ----------
    openrouter:
        Shared API client with ``embed_batch()`` support.
    redis:
        Async Redis client for the persistent ZSET and writing embeddings.
    model:
        Embedding model identifier.
    flush_interval:
        Seconds to wait after the first enqueue before starting a flush.
    max_batch_size:
        Flush immediately when the queue reaches this size.
    """

    def __init__(
        self,
        openrouter: OpenRouterClient,
        redis: aioredis.Redis,
        model: str = "google/gemini-embedding-001",
        flush_interval: float = DEFAULT_FLUSH_INTERVAL,
        max_batch_size: int = DEFAULT_MAX_BATCH_SIZE,
    ) -> None:
        self._openrouter = openrouter
        self._redis = redis
        self._model = model
        self._flush_interval = flush_interval
        self._max_batch_size = max_batch_size
        # In-memory futures for callers who await the result.
        # Items recovered from a previous run have no future.
        self._pending_futures: dict[str, asyncio.Future[list[float]]] = {}
        self._timer_task: asyncio.Task | None = None
        self._drain_task: asyncio.Task | None = None

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    async def start(self) -> None:
        """Start the queue and drain any leftovers from a previous run."""
        leftover = await self._redis.zcard(PENDING_ZSET_KEY)
        if leftover:
            logger.info(
                "Recovering %d pending embeddings from previous run",
                leftover,
            )
            self._start_drain_now()
        logger.info(
            "EmbeddingBatchQueue started (flush_interval=%.1fs, max_batch=%d)",
            self._flush_interval,
            self._max_batch_size,
        )

    async def stop(self) -> None:
        """Cancel the background drain task.

        Items remain in the Redis ZSET and will be recovered on next
        start.  In-process futures are cancelled.
        """
        if self._timer_task is not None:
            self._timer_task.cancel()
            self._timer_task = None
        if self._drain_task is not None:
            self._drain_task.cancel()
            # Swallow the CancelledError so stop() itself never raises.
            await asyncio.gather(self._drain_task, return_exceptions=True)
            self._drain_task = None
        for fut in self._pending_futures.values():
            if not fut.done():
                fut.cancel()
        self._pending_futures.clear()
        pending = await self._redis.zcard(PENDING_ZSET_KEY)
        logger.info(
            "EmbeddingBatchQueue stopped (%d items persisted for next start)",
            pending,
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def enqueue(
        self,
        redis_key: str,
        text: str,
    ) -> asyncio.Future[list[float]]:
        """Add text for deferred embedding.

        Returns a Future that resolves with the embedding vector once
        the batch is flushed.  If the same *redis_key* is already
        pending, the existing future is returned (dedup).
        Empty/whitespace text returns a pre-resolved zero-vector future.
        """
        loop = asyncio.get_running_loop()
        if not text or not text.strip():
            future: asyncio.Future[list[float]] = loop.create_future()
            future.set_result(self._zero_vector())
            return future
        existing = self._pending_futures.get(redis_key)
        if existing is not None and not existing.done():
            return existing
        future = loop.create_future()
        self._pending_futures[redis_key] = future
        # Persist before any API work so a crash never loses the item.
        score = time.time()
        member = _serialize_item(redis_key, text)
        await self._redis.zadd(PENDING_ZSET_KEY, {member: score})
        await self._maybe_trigger_drain()
        return future

    async def enqueue_many(
        self,
        items: list[tuple[str, str]],
    ) -> list[asyncio.Future[list[float]]]:
        """Atomically enqueue multiple ``(redis_key, text)`` pairs.

        Returns a list of Futures (one per item, in the same order).
        Blank texts resolve immediately to a zero vector; keys already
        pending reuse the existing future (dedup).
        """
        if not items:
            return []
        loop = asyncio.get_running_loop()
        futures: list[asyncio.Future[list[float]]] = []
        mapping: dict[str, float] = {}
        now = time.time()
        for redis_key, text in items:
            if not text or not text.strip():
                fut: asyncio.Future[list[float]] = loop.create_future()
                fut.set_result(self._zero_vector())
                futures.append(fut)
                continue
            existing = self._pending_futures.get(redis_key)
            if existing is not None and not existing.done():
                futures.append(existing)
                continue
            fut = loop.create_future()
            self._pending_futures[redis_key] = fut
            futures.append(fut)
            member = _serialize_item(redis_key, text)
            mapping[member] = now
        if mapping:
            # Single ZADD makes the multi-item enqueue atomic in Redis.
            await self._redis.zadd(PENDING_ZSET_KEY, mapping)
            await self._maybe_trigger_drain()
        return futures

    async def pending_count(self) -> int:
        """Number of items waiting in the persistent queue."""
        return await self._redis.zcard(PENDING_ZSET_KEY)

    def flush_now(self) -> None:
        """Trigger an immediate drain of the persistent queue.

        Call this just before an LLM inference so that all pending
        embeddings are available for vector search.  Non-blocking: the
        drain runs in a background task.
        """
        self._start_drain_now()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _zero_vector() -> list[float]:
        """Zero embedding returned for blank text.

        The import is deliberately function-local — presumably to avoid
        a circular import with ``openrouter_client``; confirm before
        hoisting it to module level.
        """
        from openrouter_client import EMBED_DIMENSIONS

        return [0.0] * EMBED_DIMENSIONS

    async def _maybe_trigger_drain(self) -> None:
        """Check ZSET size and schedule or start drain accordingly."""
        size = await self._redis.zcard(PENDING_ZSET_KEY)
        if size >= self._max_batch_size:
            self._start_drain_now()
        elif self._drain_task is None or self._drain_task.done():
            if self._timer_task is None or self._timer_task.done():
                self._timer_task = asyncio.create_task(
                    self._trigger_after_delay(),
                )

    def _start_drain_now(self) -> None:
        """Cancel any pending timer and ensure the drain loop is running."""
        if self._timer_task is not None:
            self._timer_task.cancel()
            self._timer_task = None
        if self._drain_task is None or self._drain_task.done():
            self._drain_task = asyncio.create_task(self._drain_loop())

    async def _trigger_after_delay(self) -> None:
        """Wait for the batching interval, then start the drain loop."""
        try:
            await asyncio.sleep(self._flush_interval)
        except asyncio.CancelledError:
            return
        self._start_drain_now()

    async def _drain_loop(self) -> None:
        """Pop batches from the Redis ZSET and process them until empty."""
        try:
            while True:
                raw_items = await self._redis.zpopmin(
                    PENDING_ZSET_KEY,
                    self._max_batch_size,
                )
                if not raw_items:
                    # Prune resolved futures so the dict doesn't grow forever.
                    self._pending_futures = {
                        k: v
                        for k, v in self._pending_futures.items()
                        if not v.done()
                    }
                    return
                batch: list[_QueueItem] = []
                for member, score in raw_items:
                    try:
                        item = _deserialize_item(member, score)
                    except Exception:
                        logger.warning(
                            "Skipping malformed embed queue item: %r",
                            member,
                        )
                        continue
                    if not item.text or not item.text.strip():
                        logger.debug(
                            "Dropping blank-text item from embed queue: %s",
                            item.redis_key,
                        )
                        fut = self._pending_futures.pop(item.redis_key, None)
                        if fut is not None and not fut.done():
                            fut.set_result(self._zero_vector())
                        continue
                    # Attach the in-process future (None for recovered items).
                    item.future = self._pending_futures.pop(
                        item.redis_key,
                        None,
                    )
                    batch.append(item)
                if batch:
                    await self._process_batch(batch)
        except asyncio.CancelledError:
            logger.warning("EmbeddingBatchQueue drain loop cancelled")
        except Exception:
            # An unexpected failure (e.g. a Redis error) previously killed
            # the background task silently; log it so operators can see it.
            logger.exception("EmbeddingBatchQueue drain loop failed")

    async def _process_batch(self, batch: list[_QueueItem]) -> None:
        """Embed one batch and write results to Redis.

        On API failure the items are re-added to the ZSET so they
        survive for the next drain cycle or the next startup.
        """
        total = len(batch)
        logger.info(
            "Flushing %d deferred embeddings (sub-batches of %d)",
            total,
            API_BATCH_LIMIT,
        )
        chunks = [
            batch[start : start + API_BATCH_LIMIT]
            for start in range(0, total, API_BATCH_LIMIT)
        ]

        async def _embed_chunk(chunk: list[_QueueItem]) -> list[list[float]]:
            texts = [item.text for item in chunk]
            return await self._openrouter._embed_gemini_batch(
                texts,
                self._model,
            )

        try:
            results = await asyncio.gather(
                *(_embed_chunk(c) for c in chunks),
            )
        except Exception:
            logger.warning(
                "Batch embedding API call failed for %d sub-batches "
                "of %d items — re-queuing all",
                len(chunks),
                total,
                exc_info=True,
            )
            await self._requeue_items(batch)
            for item in batch:
                if item.future is not None and not item.future.done():
                    item.future.set_exception(
                        RuntimeError("Batch embedding flush failed"),
                    )
            return

        all_embeddings: list[list[float]] = []
        for chunk_embeddings in results:
            all_embeddings.extend(chunk_embeddings)

        if len(all_embeddings) != total:
            # Defensive: a short API response under plain zip() would
            # silently drop the tail and leave its futures unresolved
            # forever.  Re-queue the unmatched remainder and restore
            # their futures so the next drain can resolve them.
            logger.warning(
                "Embedding API returned %d vectors for %d items — "
                "re-queuing the unmatched remainder",
                len(all_embeddings),
                total,
            )
            matched = min(len(all_embeddings), total)
            leftover = batch[matched:]
            for item in leftover:
                if item.future is not None and not item.future.done():
                    self._pending_futures[item.redis_key] = item.future
            await self._requeue_items(leftover)
            batch = batch[:matched]
            all_embeddings = all_embeddings[:matched]

        pipe = self._redis.pipeline()
        for item, embedding in zip(batch, all_embeddings):
            pipe.hset(item.redis_key, "embedding", _embed_to_bytes(embedding))
            if item.future is not None and not item.future.done():
                item.future.set_result(embedding)
        try:
            await pipe.execute()
        except Exception:
            logger.warning(
                "Failed to write %d embeddings back to Redis",
                total,
                exc_info=True,
            )
        logger.debug("Flushed %d embeddings to Redis", total)

    async def _requeue_items(self, items: list[_QueueItem]) -> None:
        """Re-add failed items to the ZSET for retry (original scores kept)."""
        if not items:
            # zadd() with an empty mapping raises in redis-py; nothing to do.
            return
        mapping: dict[str, float] = {}
        for item in items:
            member = _serialize_item(item.redis_key, item.text)
            mapping[member] = item.score
        try:
            await self._redis.zadd(PENDING_ZSET_KEY, mapping)
            logger.info("Re-queued %d items for retry", len(items))
        except Exception:
            logger.error(
                "Failed to re-queue %d items to Redis",
                len(items),
                exc_info=True,
            )