"""Redis-backed batched embedding queue for non-critical embedding generation.
Accumulates (redis_key, text) pairs in a Redis sorted set
(``embed_queue:pending``) and flushes them to the Gemini batch embeddings
API on a timer or when the batch reaches a size threshold. Resolved
embeddings are written back to the corresponding Redis hashes so they
become vector-searchable.
Persistence guarantees:
* Items are written to the ZSET *before* any API call, so a crash or
restart never loses pending work.
* On startup, any leftover items from a previous run are drained
automatically (fire-and-forget — no in-process Future for those).
* On API failure, items are re-added to the ZSET for the next drain
cycle.
A single background :class:`asyncio.Task` (``_drain_task``) drains the
queue — it atomically claims a batch by moving it from the pending ZSET
into the in-flight ZSET (``embed_queue:inflight``) with a Lua script,
calls the embedding API, and writes results back. After each batch it
loops back to check for more items; when the pending ZSET is empty it
exits and the next :meth:`enqueue` call schedules a new drain after the
batching interval.
Usage::
queue = EmbeddingBatchQueue(
openrouter, redis, model="google/gemini-embedding-001",
)
await queue.start() # drains any leftovers from previous run
future = await queue.enqueue("msg:abc-123", "Hello world")
embedding = await future
await queue.stop()
"""
from __future__ import annotations
import asyncio
import jsonutil as json
import logging
import threading
import time
from dataclasses import dataclass
from typing import Any
import numpy as np
import redis.asyncio as aioredis
from openrouter_client import OpenRouterClient
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Default delay, in seconds, between an enqueue and the timer-driven flush
# (used as the default ``flush_interval`` of ``EmbeddingBatchQueue``).
DEFAULT_FLUSH_INTERVAL = 3600.0
# Default backlog size that triggers an immediate flush; also the number of
# items claimed per drain iteration (default ``max_batch_size``).
DEFAULT_MAX_BATCH_SIZE = 50
# Size of the sub-batches a claimed batch is split into for the embedding API.
API_BATCH_LIMIT = 50
# Sorted set holding serialized items that are waiting to be embedded.
PENDING_ZSET_KEY = "embed_queue:pending"
# Sorted set holding items claimed by a drain, scored with their claim time.
INFLIGHT_ZSET_KEY = "embed_queue:inflight"
# Hash mapping each ``redis_key`` to its serialized member in the pending ZSET.
DEDUP_HASH_KEY = "embed_queue:dedup"
# Hash keyed by ``redis_key``; presumably holds per-item retry counts —
# confirm against ``_requeue_items`` (not visible in this part of the file).
RETRIES_HASH_KEY = "embed_queue:retries"
# Age, in seconds, after which an in-flight item is treated as orphaned and
# moved back to the pending ZSET by ``_reclaim_inflight``.
INFLIGHT_RECLAIM_SECONDS = 300.0
# After this many failed flush attempts for the same ``redis_key``, the
# embedding queue drops the item: a zero vector is written to
# ``redis_key`` so vector search still works, the dedup/retries tracking
# is cleaned up, and the item is NOT re-added to the pending ZSET.
MAX_EMBED_QUEUE_RETRIES = 3
@dataclass
class _QueueItem:
    """One unit of embedding work claimed from the queue."""

    # Redis hash key whose ``embedding`` field receives the resulting vector.
    redis_key: str
    # Text that will be sent to the embedding API.
    text: str
    # Sorted-set score recorded for the item when it was decoded.
    score: float
def _embed_to_bytes(embedding: list[float]) -> bytes:
    """Serialize an embedding vector into a packed ``float32`` byte string.

    The values are converted to a NumPy ``float32`` array (platform-native
    byte order) and dumped as one contiguous buffer. Within this module the
    result is stored under the ``embedding`` field of an item's Redis hash.

    Args:
        embedding (list[float]): The vector to serialize.

    Returns:
        bytes: Four bytes per element, in input order; empty for an empty
        vector.
    """
    vector = np.asarray(embedding, dtype=np.float32)
    return vector.tobytes()
def _serialize_item(redis_key: str, text: str) -> str:
    """Build the compact JSON string used as a queue member.

    The same ``(redis_key, text)`` pair must always produce the same string,
    because other parts of this module re-serialize an item to find and
    remove its member from a sorted set. The separators are therefore pinned
    to the whitespace-free form.

    Args:
        redis_key (str): Hash key the embedding will be written to.
        text (str): Text to be embedded.

    Returns:
        str: A JSON object string of the form
        ``{"redis_key":...,"text":...}``.
    """
    payload = {"redis_key": redis_key, "text": text}
    return json.dumps(payload, separators=(",", ":"))
def _deserialize_item(raw: str | bytes, score: float) -> _QueueItem:
    """Turn a serialized queue member back into a :class:`_QueueItem`.

    This is the inverse of :func:`_serialize_item`. ``bytes`` input is decoded
    with the default codec first; ``str`` input is parsed as-is. The supplied
    *score* is stored on the item unchanged.

    Args:
        raw (str | bytes): Member string as stored in the sorted set.
        score (float): Score to attach to the reconstructed item.

    Returns:
        _QueueItem: Item populated with ``redis_key``, ``text`` and ``score``.

    Raises:
        KeyError: If the decoded object lacks ``redis_key`` or ``text``.
        Exception: Whatever the JSON decoder (or ``bytes.decode``) raises for
            malformed input propagates unchanged.
    """
    text_form = raw if not isinstance(raw, bytes) else raw.decode()
    payload = json.loads(text_form)
    redis_key = payload["redis_key"]
    text = payload["text"]
    return _QueueItem(redis_key=redis_key, text=text, score=score)
def _silence_future_exception(fut: asyncio.Future) -> None:
    """Mark a finished future's exception as retrieved.

    Intended as a done-callback: reading ``fut.exception()`` stops asyncio
    from logging "exception was never retrieved" for futures nobody awaits.
    Futures that are still pending or were cancelled are left untouched.

    Args:
        fut (asyncio.Future): The future to inspect.

    Returns:
        None
    """
    if not fut.done() or fut.cancelled():
        return
    try:
        fut.exception()
    except Exception:
        # Deliberate best-effort: this callback must never raise.
        pass
[docs]
class EmbeddingBatchQueue:
"""Redis-backed queue that batches embedding requests and flushes
them periodically via the Gemini batch API.
Parameters
----------
openrouter:
Shared API client with ``embed_batch()`` support.
redis:
Async Redis client for the persistent ZSET and writing embeddings.
model:
Embedding model identifier.
flush_interval:
Seconds to wait after the first enqueue before starting a flush.
max_batch_size:
Flush immediately when the queue reaches this size.
"""
[docs]
def __init__(
    self,
    openrouter: OpenRouterClient,
    redis: aioredis.Redis,
    model: str = "google/gemini-embedding-001",
    flush_interval: float = DEFAULT_FLUSH_INTERVAL,
    max_batch_size: int = DEFAULT_MAX_BATCH_SIZE,
) -> None:
    """Store collaborators and tuning values; perform no I/O.

    Only attributes are assigned here. The timer and drain task slots start
    out as ``None`` and are filled in later by the scheduling helpers. The
    Lua source kept in ``_claim_script`` is what
    :meth:`_claim_pending_batch` sends to Redis to move a batch from the
    pending ZSET into the in-flight ZSET.

    Args:
        openrouter (OpenRouterClient): API client used for embedding calls.
        redis (aioredis.Redis): Async Redis client used for all queue state.
        model (str): Embedding model identifier passed to the API.
        flush_interval (float): Seconds the batching timer sleeps before it
            triggers a drain.
        max_batch_size (int): Backlog size that triggers an immediate
            drain, and the number of items claimed per drain iteration.
    """
    # Background task slots and the lock that serialises drain scheduling.
    self._timer_task: asyncio.Task | None = None
    self._drain_task: asyncio.Task | None = None
    self._drain_coord_lock = threading.Lock()

    # Injected collaborators.
    self._openrouter = openrouter
    self._redis = redis

    # Tuning parameters.
    self._model = model
    self._flush_interval = flush_interval
    self._max_batch_size = max_batch_size

    # KEYS[1]=pending ZSET, KEYS[2]=in-flight ZSET,
    # ARGV[1]=claim timestamp, ARGV[2]=max items to claim.
    self._claim_script = (
        "\n"
        "local pending = KEYS[1]\n"
        "local inflight = KEYS[2]\n"
        "local now = tonumber(ARGV[1])\n"
        "local count = tonumber(ARGV[2])\n"
        "local items = redis.call('ZRANGE', pending, 0, count - 1)\n"
        "for _, m in ipairs(items) do\n"
        "redis.call('ZREM', pending, m)\n"
        "redis.call('ZADD', inflight, now, m)\n"
        "end\n"
        "return items\n"
    )
async def _reclaim_inflight(self) -> None:
    """Move stale in-flight members back onto the pending ZSET.

    A member counts as stale when its in-flight score (the claim time) is
    older than ``INFLIGHT_RECLAIM_SECONDS``. Each stale member is removed
    from the in-flight ZSET and re-added to the pending ZSET with the current
    time as its score, all through one pipeline. Any exception is logged and
    swallowed so that recovery never blocks startup.

    Called by :meth:`start`.

    Returns:
        None
    """
    threshold = time.time() - INFLIGHT_RECLAIM_SECONDS
    try:
        stale_members = await self._redis.zrangebyscore(
            INFLIGHT_ZSET_KEY,
            "-inf",
            threshold,
        )
        if stale_members:
            mover = self._redis.pipeline()
            for stale_member in stale_members:
                mover.zrem(INFLIGHT_ZSET_KEY, stale_member)
                mover.zadd(PENDING_ZSET_KEY, {stale_member: time.time()})
            await mover.execute()
            logger.warning(
                "Reclaimed %d stale in-flight embedding queue items",
                len(stale_members),
            )
    except Exception:
        logger.exception("Failed to reclaim in-flight embedding queue items")
async def _claim_pending_batch(self, count: int) -> list[_QueueItem]:
    """Atomically claim up to *count* items from the pending ZSET.

    Runs ``_claim_script`` via ``EVAL``. The script removes the first *count*
    members (by rank) from ``embed_queue:pending`` and adds them to
    ``embed_queue:inflight`` scored with the claim time. Each returned member
    is then decoded with :func:`_deserialize_item`.

    A member that cannot be decoded is removed from the in-flight ZSET,
    logged and skipped, so one bad entry never fails the whole batch. This
    includes a member that is not valid UTF-8. Decoding happens inside the
    guarded section, and the raw member exactly as Redis returned it is used
    for the ``ZREM`` so the removal always matches the stored entry.

    Called by :meth:`_drain_loop` once per drain iteration.

    Args:
        count (int): Maximum number of items to claim.

    Returns:
        list[_QueueItem]: The claimed, successfully decoded items; empty when
        nothing is pending. Each item's ``score`` is the local time at which
        the batch was decoded.
    """
    raw_items = await self._redis.eval(
        self._claim_script,
        2,
        PENDING_ZSET_KEY,
        INFLIGHT_ZSET_KEY,
        str(time.time()),
        str(count),
    )
    batch: list[_QueueItem] = []
    if not raw_items:
        return batch
    now = time.time()
    for raw_member in raw_items:
        try:
            # _deserialize_item decodes bytes itself, so a UnicodeDecodeError
            # is handled here like any other malformed member.
            item = _deserialize_item(raw_member, now)
        except Exception:
            await self._redis.zrem(INFLIGHT_ZSET_KEY, raw_member)
            logger.warning("Skipping malformed embed queue item: %r", raw_member)
            continue
        batch.append(item)
    return batch
async def _release_inflight(self, batch: list[_QueueItem]) -> None:
    """Drop the given items from the in-flight ZSET.

    Every item is re-serialized with :func:`_serialize_item` to obtain its
    member string, and one ``ZREM`` per member is sent through a single
    pipeline. Nothing is sent to Redis when *batch* is empty.

    Args:
        batch (list[_QueueItem]): Items to remove from
            ``embed_queue:inflight``.
    """
    if batch:
        members = [
            _serialize_item(entry.redis_key, entry.text) for entry in batch
        ]
        remover = self._redis.pipeline()
        for member in members:
            remover.zrem(INFLIGHT_ZSET_KEY, member)
        await remover.execute()
async def _zrem_by_redis_key(self, redis_key: str) -> None:
    """Remove the previously queued member for *redis_key*, if one is known.

    Looks up the serialized member recorded for *redis_key* in the dedup hash
    (``HGET``) and, when one is found, removes it from the pending ZSET
    (``ZREM``). The dedup hash itself is not modified here.

    Called by :meth:`enqueue` and :meth:`enqueue_many` right before they add
    a fresh member, so a newer text for the same key replaces the older one.

    Args:
        redis_key (str): Destination hash key whose earlier pending member
            should be evicted.

    Returns:
        None
    """
    previous_member = await self._redis.hget(DEDUP_HASH_KEY, redis_key)
    if not previous_member:
        return
    await self._redis.zrem(PENDING_ZSET_KEY, previous_member)
# ------------------------------------------------------------------
# Lifecycle
# ------------------------------------------------------------------
[docs]
async def start(self) -> None:
    """Recover leftover work and announce that the queue is online.

    Steps, in order:

    1. :meth:`_reclaim_inflight` moves stale in-flight items back to pending.
    2. The pending and in-flight ZSETs are counted (``ZCARD``).
    3. If either count is non-zero, a drain is started right away via
       :meth:`_start_drain_now`; it is not awaited.
    4. A startup line with the configured interval and batch size is logged.

    Returns:
        None
    """
    await self._reclaim_inflight()

    pending_left = await self._redis.zcard(PENDING_ZSET_KEY)
    inflight_left = await self._redis.zcard(INFLIGHT_ZSET_KEY)
    needs_recovery = bool(pending_left or inflight_left)
    if needs_recovery:
        logger.info(
            "Recovering embedding queue (pending=%d, inflight=%d)",
            pending_left,
            inflight_left,
        )
        self._start_drain_now()

    logger.info(
        "EmbeddingBatchQueue started (flush_interval=%.1fs, max_batch=%d)",
        self._flush_interval,
        self._max_batch_size,
    )
[docs]
async def stop(self) -> None:
    """Cancel the background tasks and report what is still queued.

    The flush timer is cancelled without being awaited. The drain task is
    cancelled and then awaited through ``asyncio.gather(...,
    return_exceptions=True)`` so that neither its cancellation nor its errors
    propagate to the caller. Finally the size of the pending ZSET is logged.
    This method does not modify queue contents in Redis.

    Returns:
        None
    """
    timer = self._timer_task
    if timer is not None:
        timer.cancel()
        self._timer_task = None

    drain = self._drain_task
    if drain is not None:
        drain.cancel()
        # Let the cancellation settle before reporting.
        await asyncio.gather(drain, return_exceptions=True)
        self._drain_task = None

    still_queued = await self._redis.zcard(PENDING_ZSET_KEY)
    logger.info(
        "EmbeddingBatchQueue stopped (%d items persisted for next start)",
        still_queued,
    )
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
async def _await_pubsub_result(self, redis_key: str, future: asyncio.Future[list[float]]) -> None:
    """Resolve *future* once the embedding for *redis_key* is stored in Redis.

    Flow:

    1. Fast path: read the ``embedding`` field of the *redis_key* hash and
       resolve immediately if a non-empty vector is already there.
    2. Subscribe to ``sg:embed:done:{redis_key}``, then read the hash once
       more. A publish that happened between step 1 and the subscription
       would otherwise be missed and the future would never resolve.
    3. Poll the subscription. On the first message, re-read the hash and
       settle the future with the vector, with :class:`ValueError` if the
       stored value decodes to an empty vector, or with
       :class:`RuntimeError` if the field is missing.

    Robustness guarantees:

    * The future is only touched while it is still pending, so a caller that
      cancelled it cannot cause ``InvalidStateError`` here.
    * A failing ``subscribe`` settles the future with that error instead of
      leaving it pending forever.
    * ``unsubscribe`` and ``aclose`` are attempted independently, so a
      failing unsubscribe does not leak the pub/sub connection.

    Cancelling the task while it polls ends the wait without settling the
    future.

    Args:
        redis_key (str): Hash key whose ``embedding`` field holds the vector.
        future (asyncio.Future[list[float]]): Future handed to the producer;
            settled in place.

    Returns:
        None: The outcome is delivered through *future*.
    """

    def _settle(result=None, error=None) -> None:
        # The producer may have cancelled the future while we awaited Redis.
        if future.done():
            return
        if error is not None:
            future.set_exception(error)
        else:
            future.set_result(result)

    async def _resolve_if_stored() -> bool:
        # True when a usable vector was found and the wait is over.
        try:
            val = await self._redis.hget(redis_key, "embedding")
            if val is not None:
                from message_cache import _bytes_to_embed
                embedding = _bytes_to_embed(val)
                if embedding:
                    _settle(result=embedding)
                    return True
        except Exception as e:
            logger.warning("Failed to check pre-existing embedding for %s: %s", redis_key, e)
        return False

    # Pre-check: is it already computed?
    if await _resolve_if_stored():
        return

    channel = f"sg:embed:done:{redis_key}"
    pubsub = self._redis.pubsub()
    try:
        try:
            await pubsub.subscribe(channel)
        except Exception as e:
            logger.warning("Error awaiting pubsub embedding result for %s: %s", redis_key, e)
            _settle(error=e)
            return

        # Re-check now that the subscription request has been sent, so a
        # publish that raced the pre-check is not lost.
        if await _resolve_if_stored():
            return

        while not future.done():
            try:
                msg = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
                if msg is not None:
                    val = await self._redis.hget(redis_key, "embedding")
                    if val is None:
                        _settle(error=RuntimeError("Embedding missing from Redis hash despite publish event"))
                    else:
                        from message_cache import _bytes_to_embed
                        embedding = _bytes_to_embed(val)
                        if embedding:
                            _settle(result=embedding)
                        else:
                            _settle(error=ValueError("Corrupted/invalid embedding in await pubsub"))
                    break
                await asyncio.sleep(0.05)
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.warning("Error awaiting pubsub embedding result for %s: %s", redis_key, e)
                _settle(error=e)
                break
    finally:
        # Attempt both cleanup steps even if the first one fails.
        try:
            await pubsub.unsubscribe(channel)
        except Exception as e:
            logger.warning("Error closing pubsub for %s: %s", redis_key, e)
        try:
            await pubsub.aclose()
        except Exception as e:
            logger.warning("Error closing pubsub for %s: %s", redis_key, e)
[docs]
async def enqueue(
    self,
    redis_key: str,
    text: str,
) -> asyncio.Future[list[float]]:
    """Queue one text for deferred embedding and return a result future.

    Empty or whitespace-only text never touches Redis: the returned future is
    already resolved to a zero vector of ``gemini_embed_pool.EMBED_DIMENSIONS``
    elements.

    Otherwise the method:

    1. creates a future and spawns :meth:`_await_pubsub_result` as a
       background task that settles it once the embedding is stored (if the
       hash already holds an embedding, that task resolves the future with
       the stored vector straight away);
    2. evicts any earlier pending member for the key via
       :meth:`_zrem_by_redis_key`;
    3. records the new member in the dedup hash and adds it to
       ``embed_queue:pending`` in one pipeline;
    4. calls :meth:`_maybe_trigger_drain`.

    If step 2 or 3 fails (or the caller is cancelled meanwhile), the
    background waiter and the future are cancelled before the error is
    re-raised. Otherwise the waiter would keep polling pub/sub for an item
    that was never queued.

    Args:
        redis_key (str): Destination hash whose ``embedding`` field will
            receive the vector.
        text (str): The text to embed.

    Returns:
        asyncio.Future[list[float]]: Future resolving to the embedding
        vector (a zero vector for blank text).

    Raises:
        Exception: Any error from the Redis writes or from
            :meth:`_maybe_trigger_drain` propagates to the caller.
    """
    loop = asyncio.get_running_loop()
    future: asyncio.Future[list[float]] = loop.create_future()
    if not text or not text.strip():
        from gemini_embed_pool import EMBED_DIMENSIONS
        future.set_result([0.0] * EMBED_DIMENSIONS)
        return future

    future.add_done_callback(_silence_future_exception)
    waiter = asyncio.create_task(self._await_pubsub_result(redis_key, future))
    try:
        score = time.time()
        member = _serialize_item(redis_key, text)
        await self._zrem_by_redis_key(redis_key)
        pipe = self._redis.pipeline()
        pipe.hset(DEDUP_HASH_KEY, redis_key, member)
        pipe.zadd(PENDING_ZSET_KEY, {member: score})
        await pipe.execute()
    except BaseException:
        # The item never reached the queue; do not leave the waiter polling.
        waiter.cancel()
        future.cancel()
        raise
    await self._maybe_trigger_drain()
    return future
[docs]
async def enqueue_many(
    self,
    items: list[tuple[str, str]],
) -> list[asyncio.Future[list[float]]]:
    """Queue several ``(redis_key, text)`` pairs for deferred embedding.

    Applies the same per-item rules as :meth:`enqueue`. Blank text yields a
    future pre-resolved to a zero vector of
    ``gemini_embed_pool.EMBED_DIMENSIONS`` elements. Real text gets a future
    plus a background :meth:`_await_pubsub_result` task. All real items share
    one enqueue timestamp as their score. They are then written one by one:
    the earlier pending member is evicted via :meth:`_zrem_by_redis_key`,
    then the dedup hash entry and the pending member are written in a small
    pipeline. :meth:`_maybe_trigger_drain` runs once at the end if anything
    was queued.

    If preparing or writing the items fails part-way (or the caller is
    cancelled), every waiter task spawned by this call and every unresolved
    future is cancelled before the error is re-raised. Otherwise those
    waiters would keep polling pub/sub with nobody holding their futures.
    Items already written to Redis stay queued.

    Args:
        items (list[tuple[str, str]]): ``(redis_key, text)`` pairs, in order.

    Returns:
        list[asyncio.Future[list[float]]]: One future per input item, in
        input order; an empty list when *items* is empty.

    Raises:
        Exception: Any error from the Redis writes or from
            :meth:`_maybe_trigger_drain` propagates to the caller.
    """
    if not items:
        return []
    loop = asyncio.get_running_loop()
    futures: list[asyncio.Future[list[float]]] = []
    waiters: list[asyncio.Task] = []
    pending_members: list[tuple[str, str, float]] = []
    now = time.time()
    try:
        for redis_key, text in items:
            if not text or not text.strip():
                from gemini_embed_pool import EMBED_DIMENSIONS
                fut: asyncio.Future[list[float]] = loop.create_future()
                fut.set_result([0.0] * EMBED_DIMENSIONS)
                futures.append(fut)
                continue
            fut = loop.create_future()
            fut.add_done_callback(_silence_future_exception)
            waiters.append(
                asyncio.create_task(self._await_pubsub_result(redis_key, fut))
            )
            futures.append(fut)
            member = _serialize_item(redis_key, text)
            pending_members.append((redis_key, member, now))
        for rk, member, score in pending_members:
            await self._zrem_by_redis_key(rk)
            pipe = self._redis.pipeline()
            pipe.hset(DEDUP_HASH_KEY, rk, member)
            pipe.zadd(PENDING_ZSET_KEY, {member: score})
            await pipe.execute()
    except BaseException:
        # The caller will never receive these futures; stop their waiters.
        for waiter in waiters:
            waiter.cancel()
        for fut in futures:
            if not fut.done():
                fut.cancel()
        raise
    if pending_members:
        await self._maybe_trigger_drain()
    return futures
[docs]
async def pending_count(self) -> int:
    """Report the size of the pending ZSET.

    Issues a single ``ZCARD`` on ``embed_queue:pending``. Items that have
    already been claimed into the in-flight ZSET are not included.

    Called by :meth:`flush_and_wait`.

    Returns:
        int: Number of members currently in the pending ZSET.
    """
    backlog = await self._redis.zcard(PENDING_ZSET_KEY)
    return backlog
[docs]
def flush_now(self) -> None:
    """Request an immediate drain without waiting for it.

    Synchronous wrapper that simply forwards to :meth:`_start_drain_now`,
    which cancels any armed flush timer and makes sure a drain task exists.
    Use :meth:`flush_and_wait` when the caller needs to wait for the drain.

    Returns:
        None
    """
    start_drain = self._start_drain_now
    start_drain()
[docs]
async def flush_and_wait(self, timeout: float = 15.0) -> None:
    """Force a drain and wait until it has finished or *timeout* elapses.

    After calling :meth:`_start_drain_now` the method loops until the
    deadline. Each iteration reads :meth:`pending_count` and looks at the
    local drain task:

    * pending is empty **and** no drain task is running: done. An empty
      pending ZSET alone is not enough, because claiming a batch empties it
      while the embeddings for that batch are still being computed.
    * a drain task is running: wait on it (shielded, so a timed-out wait
      never cancels the drain) for at most one second and never past the
      deadline.
    * items are pending but no drain is running: start one and pause
      briefly.

    Only the drain task of this instance is observed. Items being processed
    by other processes are not waited for.

    Args:
        timeout (float): Maximum number of seconds to wait.

    Returns:
        None
    """
    self._start_drain_now()
    deadline = time.monotonic() + timeout
    while True:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            break
        count = await self.pending_count()
        # Sample the task after the await: it may have finished meanwhile.
        task = self._drain_task
        drain_live = task is not None and not task.done()
        if count == 0 and not drain_live:
            break
        if drain_live:
            try:
                await asyncio.wait_for(
                    asyncio.shield(task),
                    timeout=min(1.0, remaining),
                )
            except asyncio.TimeoutError:
                pass
            except Exception:
                # The drain task failed; back off before re-evaluating.
                await asyncio.sleep(0.1)
        else:
            self._start_drain_now()
            await asyncio.sleep(0.1)
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
async def _maybe_trigger_drain(self) -> None:
    """Apply the size-or-time flush policy after an enqueue.

    * Backlog (``ZCARD`` of the pending ZSET) at or above
      ``max_batch_size``: drain immediately via :meth:`_start_drain_now`.
    * Otherwise, if a drain is already running or a timer is already armed:
      do nothing.
    * Otherwise: arm a :meth:`_trigger_after_delay` timer task.

    Called by :meth:`enqueue` and :meth:`enqueue_many`.

    Returns:
        None
    """
    backlog = await self._redis.zcard(PENDING_ZSET_KEY)
    if backlog >= self._max_batch_size:
        self._start_drain_now()
        return

    drain = self._drain_task
    if drain is not None and not drain.done():
        # A running drain will pick the new item up on its next iteration.
        return

    timer = self._timer_task
    if timer is not None and not timer.done():
        return

    self._timer_task = asyncio.create_task(
        self._trigger_after_delay(),
    )
def _start_drain_now(self) -> None:
    """Cancel the flush timer and make sure a drain task exists.

    While holding ``_drain_coord_lock``, any stored timer task is cancelled
    and cleared. A new :meth:`_drain_loop` task is then created unless the
    stored drain task is still running, so repeated calls never start a
    second concurrent drain. Nothing is awaited here.

    Returns:
        None
    """
    with self._drain_coord_lock:
        timer = self._timer_task
        if timer is not None:
            timer.cancel()
            self._timer_task = None

        drain = self._drain_task
        if drain is not None and not drain.done():
            return
        self._drain_task = asyncio.create_task(self._drain_loop())
async def _trigger_after_delay(self) -> None:
    """Wait one batching interval, then start a drain.

    Sleeps for ``flush_interval`` seconds. If the sleep completes,
    :meth:`_start_drain_now` is called. If the task is cancelled during the
    sleep, it ends quietly without starting a drain.

    Spawned by :meth:`_maybe_trigger_drain` and stored as ``_timer_task``.

    Returns:
        None
    """
    try:
        await asyncio.sleep(self._flush_interval)
    except asyncio.CancelledError:
        # Timer was superseded (or the queue is stopping): nothing to flush.
        pass
    else:
        self._start_drain_now()
async def _drain_loop(self) -> None:
    """Claim and process batches until nothing is left in the pending ZSET.

    Each iteration claims up to ``max_batch_size`` items with
    :meth:`_claim_pending_batch` and returns when the claim comes back empty.
    Items are visited in claim order:

    * Blank-text items are sunk immediately. One pipeline writes a zero
      vector to the item's hash, deletes its retry and dedup entries,
      publishes ``sg:embed:done:{redis_key}`` and removes the member from
      the in-flight ZSET. A failure of that pipeline is logged and the item
      is skipped.
    * All other items are collected and handed to :meth:`_process_batch`.

    Cancellation of the task is caught and logged, so the task finishes
    normally. Any other exception propagates and ends the task.

    Spawned as ``_drain_task`` by :meth:`_start_drain_now`.

    Returns:
        None
    """

    async def _sink_blank(entry) -> None:
        # Write a zero vector so the key is still usable, then drop the item.
        logger.debug(
            "Dropping blank-text item from embed queue: %s",
            entry.redis_key,
        )
        from gemini_embed_pool import EMBED_DIMENSIONS
        zero_blob = _embed_to_bytes([0.0] * EMBED_DIMENSIONS)
        try:
            sink = self._redis.pipeline()
            sink.hset(entry.redis_key, "embedding", zero_blob)
            sink.hdel(RETRIES_HASH_KEY, entry.redis_key)
            sink.hdel(DEDUP_HASH_KEY, entry.redis_key)
            sink.publish(f"sg:embed:done:{entry.redis_key}", "1")
            sink.zrem(
                INFLIGHT_ZSET_KEY,
                _serialize_item(entry.redis_key, entry.text),
            )
            await sink.execute()
        except Exception as e:
            logger.warning(
                "Failed to sink and publish blank text for %s: %s",
                entry.redis_key,
                e,
            )

    try:
        while True:
            claimed = await self._claim_pending_batch(self._max_batch_size)
            if not claimed:
                return
            embeddable = []
            for entry in claimed:
                if entry.text and entry.text.strip():
                    embeddable.append(entry)
                else:
                    await _sink_blank(entry)
            if embeddable:
                await self._process_batch(embeddable)
    except asyncio.CancelledError:
        logger.warning("EmbeddingBatchQueue drain loop cancelled")
async def _process_batch(self, batch: list[_QueueItem]) -> None:
    """Embed one claimed batch and write the resulting vectors back to Redis.

    The compute-and-persist core invoked once per non-empty batch by
    :meth:`_drain_loop`. The batch is split into ``API_BATCH_LIMIT``-sized
    chunks that are embedded concurrently through the ``_embed_chunk``
    closure (which calls ``OpenRouterClient._embed_gemini_batch`` — the only
    network I/O). A single pipeline then writes each ``embedding`` blob
    (``HSET`` via :func:`_embed_to_bytes`), clears the retry/dedup hash
    entries, and publishes ``sg:embed:done:{redis_key}``.

    On success the batch is released from in-flight via
    :meth:`_release_inflight`. If the API call fails, returns a different
    number of vectors than texts, or the write-back fails, the whole batch
    is routed to :meth:`_requeue_items` for a bounded retry. An ok/error
    ``embed_queue_drain`` event is emitted through
    ``observability.publish_debug_event`` in every case.

    Args:
        batch (list[_QueueItem]): The claimed, non-blank items to embed.

    Returns:
        None
    """
    t0 = time.monotonic()
    total = len(batch)
    logger.info(
        "Flushing %d deferred embeddings (sub-batches of %d)",
        total,
        API_BATCH_LIMIT,
    )
    chunks = [
        batch[start : start + API_BATCH_LIMIT]
        for start in range(0, total, API_BATCH_LIMIT)
    ]

    def _emit_drain_event(
        task_name: str,
        error: Exception | None = None,
    ) -> None:
        """Schedule a fire-and-forget ``embed_queue_drain`` debug event.

        ``error=None`` reports status ``ok`` with zero errors; otherwise the
        whole batch is reported as failed and ``str(error)`` is attached.
        """
        # Imported lazily, exactly where the event is emitted, so the
        # embedding path itself never depends on the observability module.
        from observability import publish_debug_event

        error_count = 0 if error is None else total
        payload: dict[str, Any] = {
            "batch_size": total,
            "error_count": error_count,
        }
        if error is not None:
            payload["error"] = str(error)
        asyncio.create_task(
            publish_debug_event(
                "embed_queue_drain",
                "embedding_queue",
                status="ok" if error is None else "error",
                duration_ms=(time.monotonic() - t0) * 1000,
                preview=f"processed={total} error_count={error_count}",
                payload=payload,
            ),
            name=task_name,
        )

    async def _embed_chunk(chunk: list[_QueueItem]) -> list[list[float]]:
        """Embed a single API-sized sub-batch via the OpenRouter Gemini path.

        Args:
            chunk (list[_QueueItem]): A slice of the claimed batch no larger
                than ``API_BATCH_LIMIT`` items.

        Returns:
            list[list[float]]: One embedding vector per item, in input order.

        Raises:
            ValueError: If the API returns a different number of vectors
                than texts were sent.
            Exception: Whatever ``_embed_gemini_batch`` raises on API failure,
                propagated unchanged to the caller.
        """
        texts = [item.text for item in chunk]
        vectors = await self._openrouter._embed_gemini_batch(
            texts,
            self._model,
        )
        # The caller pairs vectors with items positionally. A short response
        # would otherwise be truncated silently by ``zip`` and the unmatched
        # items released from in-flight without ever getting an embedding.
        if len(vectors) != len(chunk):
            raise ValueError(
                f"embedding API returned {len(vectors)} vectors "
                f"for {len(chunk)} texts"
            )
        return vectors

    try:
        results = await asyncio.gather(
            *(_embed_chunk(c) for c in chunks),
        )
    except Exception as e:
        logger.warning(
            "Batch embedding API call failed for %d sub-batches "
            "of %d items — re-queuing all",
            len(chunks),
            total,
            exc_info=True,
        )
        _emit_drain_event("obs_embed_drain_err", e)
        await self._requeue_items(batch)
        return

    # ``gather`` returns results in argument order, so flattening the chunk
    # results restores the original batch order.
    all_embeddings: list[list[float]] = [
        vector for chunk_vectors in results for vector in chunk_vectors
    ]

    pipe = self._redis.pipeline()
    for item, embedding in zip(batch, all_embeddings):
        pipe.hset(item.redis_key, "embedding", _embed_to_bytes(embedding))
        pipe.hdel(RETRIES_HASH_KEY, item.redis_key)
        pipe.hdel(DEDUP_HASH_KEY, item.redis_key)
        pipe.publish(f"sg:embed:done:{item.redis_key}", "1")
    try:
        await pipe.execute()
    except Exception as e:
        logger.warning(
            "Failed to write %d embeddings back to Redis",
            total,
            exc_info=True,
        )
        _emit_drain_event("obs_embed_drain_err2", e)
        await self._requeue_items(batch)
        return

    await self._release_inflight(batch)
    logger.debug("Flushed %d embeddings to Redis", total)
    _emit_drain_event("obs_embed_drain_ok")
async def _requeue_items(self, items: list[_QueueItem]) -> None:
    """Apply the bounded-retry policy to a batch whose flush just failed.

    The failure-handling path for :meth:`_process_batch`. It increments the
    per-key counter in ``embed_queue:retries`` (``HINCRBY``) for every item,
    then splits the items in two groups:

    * Keys that have reached ``MAX_EMBED_QUEUE_RETRIES`` are given up on:
      the hash gets a zero vector via :func:`_embed_to_bytes`, the
      retry/dedup hash entries are cleared, ``sg:embed:done:{redis_key}``
      is published, and the member is removed from the in-flight ZSET so
      the item is not left behind there with a freshly reset retry counter.
    * All other items are removed from in-flight (``ZREM``) and added back
      onto ``embed_queue:pending`` (``ZADD`` with their original score)
      for another drain attempt.

    Each Redis stage is wrapped so a failure here is logged rather than
    raised. If the retry counters cannot be incremented, every item is
    re-queued without applying the cap.

    Args:
        items (list[_QueueItem]): The items from the failed batch to retry or
            permanently drop; a no-op when empty.

    Returns:
        None
    """
    if not items:
        return

    from gemini_embed_pool import EMBED_DIMENSIONS

    try:
        pipe = self._redis.pipeline()
        for item in items:
            pipe.hincrby(RETRIES_HASH_KEY, item.redis_key, 1)
        counts_raw = await pipe.execute()
    except Exception:
        logger.error(
            "Failed to increment embed_queue retry counts for %d items — "
            "re-queueing without cap",
            len(items),
            exc_info=True,
        )
        # A count of 0 is always below the cap, so everything is re-queued.
        counts_raw = [0] * len(items)

    counts: list[int] = []
    for raw in counts_raw:
        try:
            counts.append(int(raw))
        except (TypeError, ValueError):
            counts.append(0)

    to_requeue: list[_QueueItem] = []
    to_drop: list[tuple[_QueueItem, int]] = []
    for item, count in zip(items, counts):
        if count >= MAX_EMBED_QUEUE_RETRIES:
            to_drop.append((item, count))
        else:
            to_requeue.append(item)

    if to_drop:
        zero_blob = _embed_to_bytes([0.0] * EMBED_DIMENSIONS)
        try:
            pipe = self._redis.pipeline()
            for item, _ in to_drop:
                pipe.hset(item.redis_key, "embedding", zero_blob)
                pipe.hdel(RETRIES_HASH_KEY, item.redis_key)
                pipe.hdel(DEDUP_HASH_KEY, item.redis_key)
                pipe.publish(f"sg:embed:done:{item.redis_key}", "1")
                # Release the claim as well. Without this the dropped item
                # stays in the in-flight ZSET while its retry counter has
                # just been cleared, which defeats the cap.
                pipe.zrem(
                    INFLIGHT_ZSET_KEY,
                    _serialize_item(item.redis_key, item.text),
                )
            await pipe.execute()
        except Exception:
            logger.error(
                "Failed to sink zero-vector for %d dropped embed_queue items",
                len(to_drop),
                exc_info=True,
            )
        else:
            # Only claim "wrote zero vector" when the sink actually succeeded.
            for item, count in to_drop:
                logger.warning(
                    "Dropping embed_queue item %s after %d failed attempts "
                    "(wrote zero vector)",
                    item.redis_key,
                    count,
                )

    if to_requeue:
        mapping: dict[str, float] = {}
        pipe = self._redis.pipeline()
        for item in to_requeue:
            member = _serialize_item(item.redis_key, item.text)
            mapping[member] = item.score
            pipe.zrem(INFLIGHT_ZSET_KEY, member)
        try:
            pipe.zadd(PENDING_ZSET_KEY, mapping)
            await pipe.execute()
            logger.info(
                "Re-queued %d items for retry (dropped %d at cap)",
                len(to_requeue),
                len(to_drop),
            )
        except Exception:
            logger.error(
                "Failed to re-queue %d items to Redis",
                len(to_requeue),
                exc_info=True,
            )