"""Per-channel message queue with batching.
Ensures temporal consistency: messages are processed in order per channel,
and rapid-succession messages can be collected into batches for a single
combined response.
"""
from __future__ import annotations
import asyncio
import base64
import json
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Callable, Awaitable, Union
from platforms.base import IncomingMessage, Attachment
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
[docs]
@dataclass
class QueuedMessage:
    """A single inbound message parked in a channel's queue.

    Holds a normalised copy of the incoming message (``platform``,
    ``channel_id``, ``user_id``, ``user_name`` and ``text``) together with
    the UTC time at which the instance was created. ``raw`` can keep the
    original platform message object and ``extra`` carries any additional
    adapter-specific metadata. Messages are handled on their own or grouped
    into a :class:`MessageBatch`.
    """

    platform: str
    channel_id: str
    user_id: str
    user_name: str
    text: str
    # Timezone-aware UTC timestamp captured at construction time.
    queued_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    # A fresh dict per instance, so metadata is never shared between messages.
    extra: dict[str, Any] = field(default_factory=dict)
    # Original platform message object (e.g. an IncomingMessage), if supplied.
    raw: Any = field(default=None)
[docs]
@dataclass
class MessageBatch:
    """Messages from one channel that arrived close together in time.

    ``messages`` keeps the arrivals in order, while ``first_at`` and
    ``last_at`` bound the batch's time window. Both timestamps start at
    construction time; :meth:`add` moves ``last_at`` forward on every
    append and pins ``first_at`` when the first message lands.
    """

    messages: list[QueuedMessage] = field(default_factory=list)
    first_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    last_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    def add(self, msg: QueuedMessage) -> None:
        """Append *msg* and slide the batch window forward.

        ``last_at`` is reset to the current UTC time on every call. When
        *msg* is the first message in the batch, ``first_at`` is set to the
        same instant.

        Args:
            msg (QueuedMessage): The message to append.
        """
        self.messages.append(msg)
        stamp = datetime.now(timezone.utc)
        self.last_at = stamp
        is_first = len(self.messages) == 1
        if is_first:
            self.first_at = stamp

    @property
    def size(self) -> int:
        """Number of messages currently held by the batch.

        Used, for example, when :meth:`MessageQueue._finalize_batch_unlocked`
        logs a finalised batch.

        Returns:
            int: ``len(self.messages)``.
        """
        return len(self.messages)

    @property
    def channel_id(self) -> str:
        """Channel of the batch, taken from its first message.

        Returns:
            str: ``channel_id`` of the first message in the batch.

        Raises:
            ValueError: If the batch holds no messages yet.
        """
        messages = self.messages
        if messages:
            return messages[0].channel_id
        raise ValueError("Empty batch has no channel_id")

    def unique_authors(self) -> list[str]:
        """Return each distinct ``user_id`` once, in first-appearance order.

        Pure helper with no side effects; ``dict.fromkeys`` keeps insertion
        order, so earlier authors come first.

        Returns:
            list[str]: The de-duplicated author ids.
        """
        return list(dict.fromkeys(m.user_id for m in self.messages))
# A unit of work on a channel queue: either one message or a batch of them.
QueueItem = Union[QueuedMessage, MessageBatch]
# Signature of an async handler that consumes one QueueItem and returns nothing.
ProcessorCallback = Callable[[QueueItem], Awaitable[None]]
def _serialize_incoming_message(msg: IncomingMessage) -> dict[str, Any]:
"""Convert an :class:`~platforms.base.IncomingMessage` to a JSON-safe dict.
Flattens every field of the platform message into primitives so it can be
embedded in a Redis-bound queue payload; binary attachment bytes are
base64-encoded and the timestamp is rendered as an ISO-8601 string.
Called only by :func:`_serialize_extra`, which uses it to preserve a raw
``IncomingMessage`` stashed inside a :class:`QueuedMessage`'s ``extra``
map across the Redis round-trip. The inverse is
:func:`_deserialize_incoming_message`.
Args:
msg (IncomingMessage): The platform message to serialize.
Returns:
dict[str, Any]: A JSON-serializable mapping of the message's fields,
with attachment data under ``data_b64`` and ``timestamp`` as ISO text.
"""
return {
"platform": msg.platform,
"channel_id": msg.channel_id,
"user_id": msg.user_id,
"user_name": msg.user_name,
"text": msg.text,
"is_addressed": msg.is_addressed,
"attachments": [
{
"data_b64": base64.b64encode(att.data).decode("utf-8"),
"mimetype": att.mimetype,
"filename": att.filename,
"source_url": att.source_url,
}
for att in msg.attachments
],
"channel_name": msg.channel_name,
"timestamp": msg.timestamp.isoformat(),
"message_id": msg.message_id,
"reply_to_id": msg.reply_to_id,
"extra": msg.extra,
"reactions": msg.reactions,
"unified_user_id": msg.unified_user_id,
"user_aliases": msg.user_aliases,
}
def _deserialize_incoming_message(data: dict[str, Any]) -> IncomingMessage:
    """Rebuild an :class:`~platforms.base.IncomingMessage` from a plain mapping.

    Reverses :func:`_serialize_incoming_message`.
    - Each attachment's ``data_b64`` is decoded back to bytes.
    - ``timestamp`` is parsed with :meth:`datetime.fromisoformat`. A missing
      or empty value becomes the current UTC time.
    - Optional fields fall back to defaults when absent: empty string for
      ``channel_name``, ``message_id``, ``reply_to_id``, ``reactions`` and an
      attachment's ``source_url``; ``{}`` for ``extra``; ``[]`` for
      ``user_aliases``; ``None`` for ``unified_user_id``.

    Args:
        data (dict[str, Any]): A mapping shaped like the output of
            :func:`_serialize_incoming_message`.

    Returns:
        IncomingMessage: The reconstructed platform message.

    Raises:
        KeyError: If a required key (``platform``, ``channel_id``,
            ``user_id``, ``user_name``, ``text``, ``is_addressed``, or an
            attachment's ``data_b64``/``mimetype``/``filename``) is missing.
    """
    attachments = [
        Attachment(
            data=base64.b64decode(entry["data_b64"]),
            mimetype=entry["mimetype"],
            filename=entry["filename"],
            source_url=entry.get("source_url", ""),
        )
        for entry in data.get("attachments", [])
    ]

    raw_ts = data.get("timestamp")
    if raw_ts:
        timestamp = datetime.fromisoformat(raw_ts)
    else:
        timestamp = datetime.now(timezone.utc)

    return IncomingMessage(
        platform=data["platform"],
        channel_id=data["channel_id"],
        user_id=data["user_id"],
        user_name=data["user_name"],
        text=data["text"],
        is_addressed=data["is_addressed"],
        attachments=attachments,
        channel_name=data.get("channel_name", ""),
        timestamp=timestamp,
        message_id=data.get("message_id", ""),
        reply_to_id=data.get("reply_to_id", ""),
        extra=data.get("extra", {}),
        reactions=data.get("reactions", ""),
        unified_user_id=data.get("unified_user_id"),
        user_aliases=data.get("user_aliases", []),
    )
def _serialize_extra(extra: dict[str, Any]) -> dict[str, Any]:
"""Make a :class:`QueuedMessage`'s ``extra`` map JSON-serializable.
Copies the mapping, replacing any nested :class:`~platforms.base.IncomingMessage`
value (typically stored under the ``incoming_message`` key) with its
serialized dict form via :func:`_serialize_incoming_message`; all other
values are passed through unchanged on the assumption they are already
JSON-safe.
Called by :func:`_serialize_item` for both standalone messages and each
message inside a batch. The inverse is :func:`_deserialize_extra`.
Args:
extra (dict[str, Any]): The ``extra`` mapping attached to a queued
message.
Returns:
dict[str, Any]: A shallow copy with ``IncomingMessage`` values
serialized to dicts.
"""
out = {}
for k, v in extra.items():
if isinstance(v, IncomingMessage):
out[k] = _serialize_incoming_message(v)
else:
out[k] = v
return out
def _deserialize_extra(extra: dict[str, Any]) -> dict[str, Any]:
"""Rehydrate an ``extra`` map produced by :func:`_serialize_extra`.
Copies the mapping, turning the ``incoming_message`` entry (when present
as a dict) back into an :class:`~platforms.base.IncomingMessage` via
:func:`_deserialize_incoming_message`; all other keys pass through
unchanged.
Called by :func:`_deserialize_item` while reconstructing a message or each
message of a batch.
Args:
extra (dict[str, Any]): The serialized ``extra`` mapping read back from
the queue payload.
Returns:
dict[str, Any]: A shallow copy with ``incoming_message`` rebuilt as an
:class:`~platforms.base.IncomingMessage` when applicable.
"""
out = {}
for k, v in extra.items():
if k == "incoming_message" and isinstance(v, dict):
out[k] = _deserialize_incoming_message(v)
else:
out[k] = v
return out
def _queued_message_to_dict(msg: QueuedMessage) -> dict[str, Any]:
    """Encode one :class:`QueuedMessage` as a JSON-ready mapping.

    Shared by both branches of :func:`_serialize_item`, so standalone and
    batched messages always carry the same field set that
    :func:`_deserialize_item` reads back. ``queued_at`` becomes ISO-8601
    text and ``extra`` goes through :func:`_serialize_extra`. ``raw`` is not
    written; on decode it is recovered from ``extra["incoming_message"]``
    when present.

    Args:
        msg (QueuedMessage): The message to encode.

    Returns:
        dict[str, Any]: The message's wire fields.
    """
    return {
        "platform": msg.platform,
        "channel_id": msg.channel_id,
        "user_id": msg.user_id,
        "user_name": msg.user_name,
        "text": msg.text,
        "queued_at": msg.queued_at.isoformat(),
        "extra": _serialize_extra(msg.extra),
    }


def _serialize_item(item: QueueItem) -> str:
    """Encode a queue item (single message or batch) as a JSON string.

    The payload is tagged with a ``type`` discriminator (``"message"`` or
    ``"batch"``) so :func:`_deserialize_item` can rebuild the right object.
    - A message's fields are inlined next to ``type``.
    - A batch stores its messages under ``messages`` plus ISO-8601
      ``first_at``/``last_at``.

    Per-message encoding lives in :func:`_queued_message_to_dict` so the two
    shapes cannot drift apart. This is the value :meth:`RedisQueue.put` and
    :meth:`RedisQueue.enqueue_front` push into the channel's Redis list.

    Args:
        item (QueueItem): A :class:`QueuedMessage` or :class:`MessageBatch`.

    Returns:
        str: The JSON-encoded item.

    Raises:
        ValueError: If *item* is neither a :class:`QueuedMessage` nor a
            :class:`MessageBatch`.
    """
    if isinstance(item, QueuedMessage):
        # "type" first, then the message fields: same key order as before.
        data = {"type": "message", **_queued_message_to_dict(item)}
    elif isinstance(item, MessageBatch):
        data = {
            "type": "batch",
            "messages": [_queued_message_to_dict(msg) for msg in item.messages],
            "first_at": item.first_at.isoformat(),
            "last_at": item.last_at.isoformat(),
        }
    else:
        raise ValueError(f"Unknown item type: {type(item)}")
    return json.dumps(data)
def _deserialize_item(serialized: str) -> QueueItem:
"""Decode a JSON string back into a :class:`QueuedMessage` or batch.
Inverse of :func:`_serialize_item`: branches on the ``type`` discriminator,
parses ISO timestamps, rehydrates each message's ``extra`` via
:func:`_deserialize_extra`, and re-exposes any embedded
:class:`~platforms.base.IncomingMessage` as the ``raw`` field of the
resulting :class:`QueuedMessage`.
Called by :meth:`RedisQueue.get` to turn a value popped from the Redis list
back into a queue item. (A same-named but unrelated helper exists in
``embedding_queue.py``.)
Args:
serialized (str): A JSON string previously produced by
:func:`_serialize_item`.
Returns:
QueueItem: The decoded :class:`QueuedMessage` or :class:`MessageBatch`.
Raises:
ValueError: If the payload's ``type`` field is neither ``"message"``
nor ``"batch"``.
"""
data = json.loads(serialized)
if data["type"] == "message":
extra = _deserialize_extra(data.get("extra", {}))
queued_at_str = data.get("queued_at")
queued_at = datetime.fromisoformat(queued_at_str) if queued_at_str else datetime.now(timezone.utc)
raw = extra.get("incoming_message")
return QueuedMessage(
platform=data["platform"],
channel_id=data["channel_id"],
user_id=data["user_id"],
user_name=data["user_name"],
text=data["text"],
queued_at=queued_at,
extra=extra,
raw=raw,
)
elif data["type"] == "batch":
batch = MessageBatch()
first_at_str = data.get("first_at")
last_at_str = data.get("last_at")
if first_at_str:
batch.first_at = datetime.fromisoformat(first_at_str)
if last_at_str:
batch.last_at = datetime.fromisoformat(last_at_str)
for msg_data in data.get("messages", []):
extra = _deserialize_extra(msg_data.get("extra", {}))
queued_at_str = msg_data.get("queued_at")
queued_at = datetime.fromisoformat(queued_at_str) if queued_at_str else datetime.now(timezone.utc)
raw = extra.get("incoming_message")
qmsg = QueuedMessage(
platform=msg_data["platform"],
channel_id=msg_data["channel_id"],
user_id=msg_data["user_id"],
user_name=msg_data["user_name"],
text=msg_data["text"],
queued_at=queued_at,
extra=extra,
raw=raw,
)
batch.messages.append(qmsg)
return batch
else:
raise ValueError(f"Unknown serialization type: {data.get('type')}")
[docs]
class RedisQueue:
    """FIFO of queue items for one channel, stored in a Redis list.

    Items are JSON-encoded with :func:`_serialize_item` and kept in the list
    ``message_queue:redis:{channel_key}``, so they outlive the process.
    :meth:`MessageQueue._queue` creates one of these per channel when it has
    both a Redis client and an event bus; otherwise it uses
    :class:`InMemoryRedisQueue`, which exposes the same methods.
    """

    def __init__(self, redis: Any, channel_key: str) -> None:
        """Bind the queue to *redis* and derive the backing list key.

        Args:
            redis (Any): Async Redis client (expected to provide ``rpush``,
                ``lpush``, ``blpop``, ``llen`` and ``delete``); ``None`` makes
                :meth:`qsize`, :meth:`empty` and :meth:`clear` report an empty
                queue.
            channel_key (str): Channel identifier used to namespace the key.
        """
        self.redis = redis
        self.channel_key = channel_key
        self.key = f"message_queue:redis:{channel_key}"

    async def put(self, item: QueueItem) -> None:
        """Serialize *item* and ``RPUSH`` it onto the tail of the list.

        Args:
            item (QueueItem): Message or batch to append.
        """
        await self.redis.rpush(self.key, _serialize_item(item))

    async def enqueue_front(self, item: QueueItem) -> None:
        """Serialize *item* and ``LPUSH`` it so it is popped before older items.

        Args:
            item (QueueItem): Message or batch to place at the head.
        """
        await self.redis.lpush(self.key, _serialize_item(item))

    async def get(self) -> QueueItem:
        """Block on ``BLPOP`` (timeout 0, i.e. no timeout) and decode the result.

        ``bytes`` replies are decoded as UTF-8 before being handed to
        :func:`_deserialize_item`.

        Returns:
            QueueItem: The item popped from the head of the list.

        Raises:
            RuntimeError: If ``BLPOP`` returns a falsy reply.
        """
        popped = await self.redis.blpop(self.key, timeout=0)
        if not popped:
            raise RuntimeError("Queue empty or blpop returned None")
        payload = popped[1]
        if isinstance(payload, bytes):
            payload = payload.decode("utf-8")
        return _deserialize_item(payload)

    async def qsize(self) -> int:
        """Return ``LLEN`` of the list, or ``0`` when no client is bound.

        Returns:
            int: Number of queued items.
        """
        return 0 if self.redis is None else await self.redis.llen(self.key)

    async def empty(self) -> bool:
        """Return ``True`` when the list has no items or no client is bound.

        Returns:
            bool: Whether the queue is empty.
        """
        if self.redis is None:
            return True
        return await self.redis.llen(self.key) == 0

    async def clear(self) -> int:
        """Delete the list and report how many items it held.

        ``LLEN`` and ``DEL`` are issued as two separate commands, so an item
        pushed between them is deleted without being counted.

        Returns:
            int: The length read before deletion (``0`` without a client).
        """
        if self.redis is None:
            return 0
        removed = await self.redis.llen(self.key)
        await self.redis.delete(self.key)
        return removed

    def task_done(self) -> None:
        """No-op counterpart of :meth:`asyncio.Queue.task_done`.

        Present so callers can treat this class and
        :class:`InMemoryRedisQueue` interchangeably; Redis has nothing to
        track.
        """
        return None
[docs]
class InMemoryRedisQueue:
    """In-process queue exposing the same interface as :class:`RedisQueue`.

    Backed by an unbounded :class:`asyncio.Queue` stored in ``self.q``.
    Items live only as long as the process. :meth:`MessageQueue._queue`
    picks this class unless it has both a Redis client and an event bus.
    :meth:`enqueue_front` is the only method that touches a private
    attribute of the underlying queue (its ``_queue`` deque), because
    :class:`asyncio.Queue` has no public head-insert.
    """

    def __init__(self) -> None:
        """Create the backing unbounded :class:`asyncio.Queue` (``self.q``)."""
        self.q = asyncio.Queue()

    async def put(self, item: QueueItem) -> None:
        """Append *item* at the tail of the queue (mirrors :meth:`RedisQueue.put`).

        Args:
            item (QueueItem): Message or batch to append.
        """
        await self.q.put(item)

    async def enqueue_front(self, item: QueueItem) -> None:
        """Insert *item* at the head of the queue so it is consumed next.

        Mirrors :meth:`RedisQueue.enqueue_front`.
        1. The item is added with :meth:`asyncio.Queue.put_nowait`. This does
           the queue's normal bookkeeping: it increments the unfinished-task
           counter that ``task_done()``/``join()`` rely on, clears the
           ``join()`` event and wakes a parked :meth:`get`.
        2. The item is then rotated from the tail of the underlying deque to
           its head.

        Routing through ``put_nowait`` means each front-inserted item is
        balanced by exactly one ``task_done()``. Inserting into the deque
        directly would leave the counter short, making a later
        ``task_done()`` raise ``ValueError`` and ``join()`` hang.

        Args:
            item (QueueItem): Message or batch to consume next.
        """
        # Unbounded queue (maxsize 0), so put_nowait never raises QueueFull.
        self.q.put_nowait(item)
        # put_nowait appended on the right; move it to the left end. Nothing
        # else can run between these two synchronous statements, so a getter
        # woken above will find this item at the head.
        self.q._queue.rotate(1)

    async def get(self) -> QueueItem:
        """Wait for and return the item at the head of the queue.

        Returns:
            QueueItem: The next :class:`QueuedMessage` or :class:`MessageBatch`.
        """
        return await self.q.get()

    async def qsize(self) -> int:
        """Return the number of buffered items.

        Returns:
            int: ``self.q.qsize()``.
        """
        return self.q.qsize()

    async def empty(self) -> bool:
        """Return ``True`` when no items are buffered.

        Returns:
            bool: ``self.q.empty()``.
        """
        return self.q.empty()

    async def clear(self) -> int:
        """Discard every buffered item, marking each one done.

        Mirrors :meth:`RedisQueue.clear`.

        Returns:
            int: The number of items removed.
        """
        removed = 0
        while not self.q.empty():
            try:
                self.q.get_nowait()
            except asyncio.QueueEmpty:
                break
            # Keep the unfinished-task counter balanced for join() waiters.
            self.q.task_done()
            removed += 1
        return removed

    def task_done(self) -> None:
        """Mark one previously retrieved item as processed.

        Delegates to :meth:`asyncio.Queue.task_done`, so each :meth:`get` is
        balanced by one call.
        """
        self.q.task_done()
[docs]
class MessageQueue:
"""Per-channel queue that batches rapid-succession messages.
Operates in two modes:
- **In-memory** (no *redis*/*event_bus*): batches are buffered locally
and handed to the processor loop started via :meth:`start_processing`.
- **Distributed** (both *redis* and *event_bus* supplied): batching
state lives in Redis and finalised batches are published to the
cross-service inbound stream via ``event_bus.publish_inbound``; the
local processor loop is not used for delivery in this mode.
Parameters
----------
default_batch_window:
Seconds to wait for additional messages before finalising a batch.
max_batch_size:
Maximum messages per batch before immediate finalisation.
redis:
Optional async Redis client. When provided together with
*event_bus*, enables the distributed batching path (Redis-backed
batch state); otherwise consulted only for per-channel batch
config (window override / disable flag).
event_bus:
Optional event bus used to publish finalised batches to the
cross-service inbound stream. Required (with *redis*) for the
distributed path.
"""
[docs]
def __init__(
self,
default_batch_window: float = 5.0,
max_batch_size: int = 10,
redis: Any = None,
event_bus: Any = None,
) -> None:
"""Set up per-channel batching state and, in distributed mode, the scheduler.
Allocates the per-channel bookkeeping dicts (queues, locks, active
batches, timers, processor tasks, lengths, cancellation flags) and stores
the batching tunables and optional Redis/event-bus handles. When both
*redis* and *event_bus* are supplied the constructor immediately starts
the :meth:`_batch_scheduler_loop` background :class:`asyncio.Task` that
drives the distributed sliding-window flush; otherwise the queue runs in
local in-memory mode.
Instantiated once per process by the owning runner/gateway and exercised
directly by the batching tests; requires a running event loop in
distributed mode because of the scheduler task it creates.
Args:
default_batch_window (float): Seconds to wait for additional messages
before finalising a batch (the default sliding-window length).
max_batch_size (int): Maximum messages allowed in one batch before it
is flushed immediately regardless of the window.
redis (Any): Optional async Redis client; together with *event_bus*
it enables the distributed Redis-backed batching path.
event_bus (Any): Optional event bus whose ``publish_inbound`` delivers
finalised batches to the cross-service inbound stream.
"""
self._queues: dict[str, Union[RedisQueue, InMemoryRedisQueue]] = {}
self._locks: dict[str, asyncio.Lock] = {}
self._processors: dict[str, asyncio.Task] = {}
self._processing: dict[str, bool] = {}
self._active_batches: dict[str, MessageBatch] = {}
self._batch_timers: dict[str, asyncio.Task] = {}
self._batch_locks: dict[str, asyncio.Lock] = {}
self._lengths: dict[str, int] = {}
# Per-channel task tracking for cancellation support.
self._current_tasks: dict[str, asyncio.Task] = {}
self._stop_requested: dict[str, bool] = {}
self.default_batch_window = default_batch_window
self.max_batch_size = max_batch_size
self.redis = redis
self.event_bus = event_bus
self._restore_attempts: dict[str, int] = {}
# Register _batch_scheduler_loop as an asyncio task on startup in distributed mode
if self.redis is not None and self.event_bus is not None:
self._scheduler_task = asyncio.create_task(self._batch_scheduler_loop())
[docs]
def set_event_bus(self, event_bus: Any) -> None:
"""Attach an event bus after construction, lazily starting the scheduler.
Lets a queue created before the event bus existed be upgraded to the
distributed path. Once both a Redis client and the event bus are present
and no scheduler is running yet, it spawns the :meth:`_batch_scheduler_loop`
:class:`asyncio.Task` so Redis-backed batches begin flushing to the
inbound stream.
Provided for the runner/gateway wiring step; no in-repo callers were
found, so it is invoked from the owning process's startup sequence.
Args:
event_bus (Any): The event bus whose ``publish_inbound`` will deliver
finalised batches to the cross-service inbound stream.
"""
self.event_bus = event_bus
if self.redis is not None and self.event_bus is not None and not hasattr(self, "_scheduler_task"):
self._scheduler_task = asyncio.create_task(self._batch_scheduler_loop())
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _queue(self, channel: str) -> Union[RedisQueue, InMemoryRedisQueue]:
"""Get (or lazily create) the backing FIFO queue for a channel.
The first time a channel is seen this picks the backend — a
:class:`RedisQueue` when both a Redis client and event bus are
configured, otherwise an in-process :class:`InMemoryRedisQueue` — and
also initialises that channel's processor lock and processing flag.
Subsequent calls return the cached queue.
Called throughout :class:`MessageQueue` (e.g. :meth:`enqueue`,
:meth:`enqueue_front`, :meth:`start_processing`, :meth:`clear`,
:meth:`_finalize_batch_unlocked`, :meth:`_process_loop`) wherever a
channel's queue is needed.
Args:
channel (str): The fully-qualified channel key to resolve.
Returns:
Union[RedisQueue, InMemoryRedisQueue]: The queue serving *channel*.
"""
if channel not in self._queues:
if self.redis is not None and self.event_bus is not None:
self._queues[channel] = RedisQueue(self.redis, channel)
else:
self._queues[channel] = InMemoryRedisQueue()
self._locks[channel] = asyncio.Lock()
self._processing[channel] = False
return self._queues[channel]
def _batch_lock(self, channel: str) -> asyncio.Lock:
"""Get (or lazily create) the per-channel lock guarding batch mutation.
Serialises access to a channel's in-memory active batch so concurrent
:meth:`enqueue` calls and the timer-driven flush cannot race. Lazily
allocates the :class:`asyncio.Lock` on first use and caches it in
``_batch_locks``.
Called by :meth:`enqueue`, :meth:`_finalize_batch`, and the in-memory
batching path; only relevant to the local (non-Redis) batching mode.
Args:
channel (str): The channel key whose batch lock is needed.
Returns:
asyncio.Lock: The lock protecting *channel*'s active batch.
"""
if channel not in self._batch_locks:
self._batch_locks[channel] = asyncio.Lock()
return self._batch_locks[channel]
async def _batch_window_for(self, channel: str) -> float:
"""Resolve the effective batching window for a channel.
Consults Redis for per-channel overrides: a set
``message_batching_disabled:{channel}`` key forces a ``0.0`` window
(batching off, messages flushed immediately), and a
``message_batch_window:{channel}`` key supplies a custom window in
seconds. Any Redis error is logged at debug level and the configured
``default_batch_window`` is used as a safe fallback.
Called by :meth:`enqueue` at the start of every message to decide
whether and how long to batch.
Args:
channel (str): The channel key whose window is being resolved.
Returns:
float: The window in seconds (``0.0`` when batching is disabled for
the channel, else the override or ``default_batch_window``).
"""
if self.redis is not None:
try:
if await self.redis.exists(f"message_batching_disabled:{channel}"):
return 0.0
raw = await self.redis.get(f"message_batch_window:{channel}")
if raw:
return float(raw)
except Exception:
logger.debug("Redis batch-window lookup failed for %s", channel)
return self.default_batch_window
async def _finalize_batch(self, channel: str) -> None:
"""Lock-protected wrapper that finalises a channel's in-memory batch.
Acquires the per-channel batch lock (via :meth:`_batch_lock`) and then
delegates to :meth:`_finalize_batch_unlocked`, ensuring the flush cannot
race a concurrent :meth:`enqueue`. Only used in the local in-memory
batching mode.
Called by the per-batch timer's ``_fire`` closure created in
:meth:`_start_timer` when the window elapses.
Args:
channel (str): The channel whose active batch should be finalised.
"""
lock = self._batch_lock(channel)
async with lock:
await self._finalize_batch_unlocked(channel)
async def _finalize_batch_unlocked(self, channel: str) -> None:
"""Commit a channel's open in-memory batch to its FIFO queue.
Pops the channel's active :class:`MessageBatch`, cancels and awaits any
pending flush timer, then pushes the batch onto the channel's backing
queue via :meth:`_queue` and ``q.put`` and bumps the cached length so
the processor loop will pick it up. Assumes the caller already holds the
per-channel batch lock (it is the unlocked core of :meth:`_finalize_batch`).
Logs the finalised batch size. No-ops if no batch is open for the
channel.
Called by :meth:`_finalize_batch` (timer-driven) and directly by
:meth:`enqueue` when a batch reaches ``max_batch_size``.
Args:
channel (str): The channel whose active batch should be flushed.
"""
if channel not in self._active_batches:
return
batch = self._active_batches.pop(channel)
timer = self._batch_timers.pop(channel, None)
if timer is not None and not timer.done():
timer.cancel()
try:
await timer
except asyncio.CancelledError:
pass
q = self._queue(channel)
await q.put(batch)
self._lengths[channel] = self._lengths.get(channel, 0) + 1
logger.info(
"Finalised batch for %s with %d message(s)",
channel,
batch.size,
)
async def _start_timer(self, channel: str, window: float) -> None:
"""(Re)arm the sliding-window flush timer for a channel's in-memory batch.
Cancels and awaits any existing timer for the channel, then schedules a
new :class:`asyncio.Task` that sleeps *window* seconds and calls
:meth:`_finalize_batch`. Re-arming on every message is what makes the
window slide forward while messages keep arriving. Local in-memory mode
only.
Called by :meth:`enqueue` each time a message is added to (or starts) a
channel's active batch.
Args:
channel (str): The channel whose flush timer is being armed.
window (float): Seconds to wait before finalising the batch.
"""
old = self._batch_timers.get(channel)
if old is not None and not old.done():
old.cancel()
try:
await old
except asyncio.CancelledError:
pass
async def _fire() -> None:
"""Sleep for the window, then finalise the channel's batch.
Body of the timer task scheduled by the enclosing
:meth:`_start_timer`; swallows :class:`asyncio.CancelledError` so a
re-arm (timer replacement) is silent. No callers other than the
``create_task`` below.
"""
try:
await asyncio.sleep(window)
await self._finalize_batch(channel)
except asyncio.CancelledError:
pass
self._batch_timers[channel] = asyncio.create_task(_fire())
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
[docs]
async def enqueue(self, msg: QueuedMessage) -> None:
    """Accept a message into the channel's batching window.

    The path taken depends on the channel's window (from
    :meth:`_batch_window_for`) and on whether both ``redis`` and
    ``event_bus`` are configured:

    * window <= 0, distributed: publish one ``"message"`` envelope straight
      to the cross-service inbound stream via ``event_bus.publish_inbound``.
    * window <= 0, in-memory: put *msg* itself on the channel queue and bump
      the cached queue length.
    * window > 0, distributed: JSON-encode the message, append it to the
      Redis batch via :meth:`_batch_append_redis` (deadline
      ``time.time() + window``). Then drain the channel immediately via
      :meth:`_force_drain_channel_batch` once the stored batch holds at least
      ``max_batch_size`` messages.
    * window > 0, in-memory: add *msg* to the channel's local
      :class:`MessageBatch` under the channel's batch lock. The batch is
      finalised at ``max_batch_size``; otherwise the flush timer is (re)armed
      via :meth:`_start_timer`.

    ``msg.extra`` may carry ``display_name``, ``cached_message_key`` and the
    original ``incoming_message``. Each is optional. A missing or ``None``
    ``incoming_message`` (or ``incoming_message.extra``) falls back to
    neutral defaults instead of raising.

    Args:
        msg (QueuedMessage): The normalized inbound message to accept.
    """
    extra = msg.extra or {}
    raw_msg = msg.raw

    def _as_text(value: Any) -> Any:
        """Decode Redis ``bytes`` replies to ``str``; pass other values through."""
        return value.decode("utf-8") if isinstance(value, bytes) else value

    def _attachment_entries(*, for_redis: bool) -> list[dict[str, Any]]:
        """Describe the raw message's attachments for the wire.

        The direct-publish envelope passes ``data`` bytes through untouched.
        NOTE(review): that relies on ``publish_inbound`` accepting bytes --
        confirm. The Redis batch entry is JSON-encoded by this method, so its
        bytes are either replaced by a ``cached_message_key`` reference or
        base64-encoded and flagged with ``is_b64``.
        """
        if not raw_msg:
            return []
        cached_key = extra.get("cached_message_key") if for_redis else None
        entries: list[dict[str, Any]] = []
        for att in getattr(raw_msg, "attachments", None) or []:
            if not for_redis:
                entries.append({
                    "data": getattr(att, "data", b""),
                    "mimetype": getattr(att, "mimetype", ""),
                    "filename": getattr(att, "filename", ""),
                    "source_url": getattr(att, "source_url", ""),
                })
                continue
            entry: dict[str, Any] = {
                "mimetype": getattr(att, "mimetype", ""),
                "filename": getattr(att, "filename", ""),
                "source_url": getattr(att, "source_url", ""),
            }
            if cached_key:
                entry["cached_message_key"] = cached_key
            else:
                att_data = getattr(att, "data", b"")
                if isinstance(att_data, bytes) and att_data:
                    entry["data"] = base64.b64encode(att_data).decode("utf-8")
                    entry["is_b64"] = True
            entries.append(entry)
        return entries

    def _wire_fields(attachments: list[dict[str, Any]]) -> dict[str, Any]:
        """Build the per-message fields shared by both distributed paths."""
        incoming = extra.get("incoming_message")
        has_incoming = incoming is not None
        # Treat a missing *or* None ``extra`` the same way, instead of
        # crashing on ``None.get(...)``.
        inc_extra = getattr(incoming, "extra", None) if has_incoming else None
        has_inc_extra = inc_extra is not None
        return {
            "platform": msg.platform,
            # Keep only the part after the last ":" (the whole id if none).
            "channel_id": msg.channel_id.split(":")[-1],
            "user_id": msg.user_id,
            "username": msg.user_name,
            "display_name": extra.get("display_name") or msg.user_name,
            "content": msg.text or "",
            "message_id": incoming.message_id if has_incoming else "",
            "timestamp": msg.queued_at.timestamp(),
            "attachments": attachments,
            "reply_to": incoming.reply_to_id if has_incoming else None,
            "embeds": inc_extra.get("embeds") if has_inc_extra else None,
            "room_name": incoming.channel_name if has_incoming else None,
            "is_dm": inc_extra.get("is_dm") if has_inc_extra else False,
            "guild_id": inc_extra.get("guild_id") if has_inc_extra else None,
            "member_roles": inc_extra.get("member_roles") if has_inc_extra else [],
            "trace_id": inc_extra.get("trace_id") if has_inc_extra else "",
            "enqueued_at": msg.queued_at.timestamp(),
            "is_addressed": getattr(raw_msg, "is_addressed", True) if raw_msg else True,
        }

    channel = msg.channel_id
    window = await self._batch_window_for(channel)
    distributed = self.redis is not None and self.event_bus is not None

    if window <= 0:
        if distributed:
            # Batching disabled: bypass Redis and publish directly.
            envelope = {
                "type": "message",
                "channel_key": msg.channel_id,
                **_wire_fields(_attachment_entries(for_redis=False)),
            }
            await self.event_bus.publish_inbound(envelope)
        else:
            q = self._queue(channel)
            await q.put(msg)
            self._lengths[channel] = self._lengths.get(channel, 0) + 1
        return

    if distributed:
        # Distributed temporal batching path.
        msg_json = json.dumps(_wire_fields(_attachment_entries(for_redis=True)))
        import time  # ``time`` is not among this module's top-level imports.

        execution_ts = time.time() + window
        await self._batch_append_redis(channel, msg_json, execution_ts)
        # Re-read the stored batch (other workers may have appended too) and
        # flush this channel now if it is full.
        stored = await self.redis.hgetall(f"sg:batch:active:{channel}")
        if stored:
            decoded = {_as_text(k): _as_text(v) for k, v in stored.items()}
            try:
                pending = json.loads(decoded.get("messages", "[]"))
            except ValueError:
                pending = []
            if not isinstance(pending, list):
                pending = []
            if len(pending) >= self.max_batch_size:
                # The drain script refuses to claim while the timer score is
                # > now; a time past this message's deadline lets it claim.
                await self._force_drain_channel_batch(
                    channel,
                    now=execution_ts + 1.0,
                )
        return

    lock = self._batch_lock(channel)
    async with lock:
        batch = self._active_batches.get(channel)
        if batch is not None:
            batch.add(msg)
            if batch.size >= self.max_batch_size:
                await self._finalize_batch_unlocked(channel)
            else:
                await self._start_timer(channel, window)
        else:
            batch = MessageBatch()
            batch.add(msg)
            self._active_batches[channel] = batch
            await self._start_timer(channel, window)
async def _batch_append_redis(
    self,
    channel_key: str,
    message_json: str,
    execution_ts: float,
) -> None:
    """Append one serialized message to a channel's Redis batch in one EVAL.

    The Lua script is read once from ``atomic_batch_append.lua`` (a relative
    path, so resolved against the process's working directory) and cached on
    ``self._lua_script``. If reading fails, a built-in fallback script is used.
    The fallback does the following:

    * appends *message_json* to the JSON array in the ``messages`` field of
      ``sg:batch:active:{channel_key}``, stamping ``first_at`` on the first
      message;
    * sets the channel's score in the ``sg:batch:timers`` sorted set to
      *execution_ts*;
    * refreshes the batch key's TTL.

    Does nothing when no Redis client is configured.

    NOTE(review): the TTL comes from ``default_batch_window`` (+120 s), not the
    channel's own window. A per-channel window longer than that could let the
    batch key expire before it is flushed -- confirm.

    NOTE(review): the fallback stamps ``first_at`` as ``execution_ts - 1``,
    i.e. it is derived from the deadline rather than the arrival time --
    confirm this matches the on-disk script.

    Called by :meth:`enqueue` on the distributed batching path and by
    :meth:`_restore_batch` when re-queuing messages after a failed publish.

    Args:
        channel_key (str): The channel whose Redis batch is being appended to.
        message_json (str): The already-serialized message object to append.
        execution_ts (float): Unix time at which the batch should next be
            eligible to flush (the sliding-window deadline).
    """
    redis = self.redis
    if redis is None:
        return
    keys = (f"sg:batch:active:{channel_key}", "sg:batch:timers")
    ttl_ms = int((self.default_batch_window + 120.0) * 1000)
    if not hasattr(self, "_lua_script"):
        try:
            with open("atomic_batch_append.lua", "r") as script_file:
                source = script_file.read()
        except Exception:
            source = """local batch_key = KEYS[1]
local timer_key = KEYS[2]
local channel_key = ARGV[1]
local message_json = ARGV[2]
local execution_ts = tonumber(ARGV[3])
local batch_ttl_ms = tonumber(ARGV[4])
local current = redis.call('HGET', batch_key, 'messages')
if current then
local trimmed = string.sub(current, 1, -2)
redis.call('HSET', batch_key, 'messages', trimmed .. ',' .. message_json .. ']')
else
redis.call('HSET', batch_key, 'messages', '[' .. message_json .. ']')
redis.call('HSET', batch_key, 'first_at', tostring(execution_ts - 1))
end
redis.call('ZADD', timer_key, execution_ts, channel_key)
if batch_ttl_ms and batch_ttl_ms > 0 then
redis.call('PEXPIRE', batch_key, batch_ttl_ms)
end
return 1"""
        self._lua_script = source
    argv = (channel_key, message_json, str(execution_ts), str(ttl_ms))
    await redis.eval(self._lua_script, len(keys), *keys, *argv)
    logger.debug("Message appended to Redis batch for channel: %s", channel_key)
async def _batch_scheduler_loop(self) -> None:
    """Poll for elapsed Redis batch timers roughly once per second, forever.

    Each iteration runs :meth:`_check_expired_batches` and then sleeps one
    second. An ordinary exception from a sweep is logged and the loop keeps
    going. A :class:`asyncio.CancelledError` raised during a sweep ends the
    loop without sleeping; a cancellation during the sleep propagates
    normally.

    Started as an :class:`asyncio.Task` by :meth:`__init__` (or later by
    :meth:`set_event_bus`) once both Redis and an event bus are available.
    """

    async def _sweep_once() -> bool:
        """Run one sweep; return ``False`` when it was cancelled."""
        try:
            await self._check_expired_batches()
        except asyncio.CancelledError:
            return False
        except Exception:
            logger.exception("Error in batch scheduler loop")
        return True

    logger.info("Starting Redis message batch scheduler loop")
    while await _sweep_once():
        await asyncio.sleep(1.0)
async def _batch_drain_redis(
    self,
    channel_key: str,
    batch_key: str,
    now: float,
) -> list[Any] | None:
    """Claim, read and delete one channel's expired Redis batch in a single EVAL.

    The Lua script comes from ``atomic_batch_drain.lua`` (read once, relative
    to the working directory, cached on ``self._lua_drain_script``) or a
    built-in fallback. The fallback refuses the claim when *channel_key* has
    no score in ``sg:batch:timers``, or when its score is later than *now*.
    Otherwise it reads ``messages`` and ``first_at`` from *batch_key*, removes
    the timer entry, and deletes the batch key.

    Args:
        channel_key: ``{platform}:{channel_id}`` member in the timer set.
        batch_key: Redis hash holding the batch payload.
        now: Unix time used as the expiry cutoff.

    Returns:
        The script's reply (``[messages_json, first_at]`` from the fallback)
        when this caller wins the claim. ``None`` when Redis is not configured,
        the batch was already drained, or its deadline has not yet passed.
    """
    redis = self.redis
    if redis is None:
        return None
    if not hasattr(self, "_lua_drain_script"):
        try:
            with open("atomic_batch_drain.lua", "r") as script_file:
                source = script_file.read()
        except Exception:
            source = """local timer_key = KEYS[2]
local channel_key = ARGV[1]
local now = tonumber(ARGV[2])
local score = redis.call('ZSCORE', timer_key, channel_key)
if not score then
return false
end
if tonumber(score) > now then
return false
end
local batch_key = KEYS[1]
local messages = redis.call('HGET', batch_key, 'messages')
local first_at = redis.call('HGET', batch_key, 'first_at')
redis.call('ZREM', timer_key, channel_key)
redis.call('DEL', batch_key)
return {messages, first_at}"""
        self._lua_drain_script = source
    claimed = await redis.eval(
        self._lua_drain_script,
        2,
        batch_key,
        "sg:batch:timers",
        channel_key,
        str(now),
    )
    return claimed if claimed else None
async def _restore_batch(
    self,
    channel_key: str,
    messages: list[dict[str, Any]],
    *,
    max_attempts: int = 5,
    count_attempt: bool = True,
) -> None:
    """Re-append drained messages so a failed publish does not silently drop them.

    Every message is re-appended via :meth:`_batch_append_redis`, with the
    timer re-armed to ``now + default_batch_window`` so a later scheduler tick
    retries the flush.

    When *count_attempt* is true (a publish failure), the restore is counted
    in ``self._restore_attempts[channel_key]``. Once the count of consecutive
    restores would exceed *max_attempts*, the messages are dropped with an
    ERROR log instead. The counter is then reset, so the channel's next
    failure starts a fresh retry budget. The counter is also reset when a
    publish succeeds (see :meth:`_drain_and_publish_channel_batch`). With
    *count_attempt* false the messages are always restored and the counter is
    left untouched.

    Args:
        channel_key: ``{platform}:{channel_id}`` key of the batch.
        messages: The parsed message dicts that were drained from Redis.
        max_attempts: Consecutive counted restores allowed before giving up.
        count_attempt: Whether this restore consumes the retry budget.
    """
    if self.redis is None or not messages:
        return
    if count_attempt:
        attempts = self._restore_attempts.get(channel_key, 0) + 1
        if attempts > max_attempts:
            # Give up on this batch; reset so later failures get a new budget
            # rather than being dropped immediately.
            self._restore_attempts.pop(channel_key, None)
            logger.error(
                "Restore batch for %s exceeded max retries (%d) — dropping %d message(s)",
                channel_key,
                max_attempts,
                len(messages),
            )
            return
        self._restore_attempts[channel_key] = attempts
    import time  # ``time`` is not among this module's top-level imports.

    execution_ts = time.time() + self.default_batch_window
    for m in messages:
        try:
            await self._batch_append_redis(
                channel_key, json.dumps(m), execution_ts
            )
        except Exception:
            logger.exception(
                "Failed to restore a batched message for %s after publish failure",
                channel_key,
            )

async def _force_drain_channel_batch(
    self,
    channel_key: str,
    *,
    now: float,
) -> None:
    """Atomically drain and publish one channel's batch (max-size flush).

    Unlike :meth:`_check_expired_batches`, this only touches *channel_key*,
    so a full batch on one channel cannot prematurely flush others.

    Args:
        channel_key: ``{platform}:{channel_id}`` key identifying the batch.
        now: Cutoff passed to the claim script and stamped on the envelope.
    """
    batch_key = f"sg:batch:active:{channel_key}"
    await self._drain_and_publish_channel_batch(channel_key, batch_key, now)

async def _drain_and_publish_channel_batch(
    self,
    channel_key: str,
    batch_key: str,
    now: float,
) -> None:
    """Atomically claim one channel's batch and publish it (or restore it).

    This is the drain path shared by the size-triggered flush
    (:meth:`_force_drain_channel_batch`) and the timer-expiry sweep
    (:meth:`_check_expired_batches`).

    It first claims ``batch_key`` through :meth:`_batch_drain_redis`. That
    returns a falsy value when another worker already drained the batch
    (double-flush guard) or a newer message re-armed the timer (lost-append
    guard); in either case this method does nothing. Otherwise the claimed
    payload is decoded and JSON-parsed, and a ``batch`` envelope is published
    via the event bus's ``publish_inbound``.

    * On success, the channel's restore-retry counter is cleared.
    * If the publish raises, the messages are handed to :meth:`_restore_batch`
      as a counted retry.
    * If no event bus is configured, they are restored without consuming the
      retry budget, so they survive until a bus is available.

    Args:
        channel_key: ``{platform}:{channel_id}`` key identifying the batch.
        batch_key: Redis key (``sg:batch:active:{channel_key}``) holding the
            claimable batch payload.
        now: Current epoch seconds, stamped onto the published envelope as
            ``enqueued_at`` and passed to the atomic-claim script.
    """
    drained = await self._batch_drain_redis(channel_key, batch_key, now)
    if not drained:
        return
    messages_json = drained[0] if len(drained) > 0 else None
    first_at = drained[1] if len(drained) > 1 else None
    if isinstance(messages_json, bytes):
        messages_json = messages_json.decode("utf-8")
    if isinstance(first_at, bytes):
        first_at = first_at.decode("utf-8")
    if not messages_json:
        logger.warning(
            "Drained batch for %s had no messages payload", channel_key
        )
        return
    try:
        messages = json.loads(messages_json)
    except (TypeError, ValueError):
        logger.error(
            "Failed to parse drained batch messages for %s: %s",
            channel_key,
            messages_json,
        )
        messages = []
    if not messages:
        return
    if self.event_bus is None:
        # Nothing to publish to: keep the messages (without spending the
        # retry budget) until an event bus becomes available.
        await self._restore_batch(channel_key, messages, count_attempt=False)
        return
    first_msg = messages[0]
    platform = first_msg.get("platform", channel_key.split(":")[0])
    channel_id = first_msg.get("channel_id", channel_key.split(":")[-1])
    envelope = {
        "type": "batch",
        "channel_key": channel_key,
        "platform": platform,
        "channel_id": channel_id,
        "messages": messages,
        "first_at": first_at,
        "trace_id": first_msg.get("trace_id")
        or first_msg.get("extra", {}).get("trace_id")
        or "",
        "enqueued_at": now,
    }
    try:
        await self.event_bus.publish_inbound(envelope)
    except Exception:
        # The batch is already deleted from Redis -- restore the drained
        # messages so they are retried rather than silently dropped.
        logger.exception(
            "Failed to publish drained batch for %s; restoring for retry",
            channel_key,
        )
        await self._restore_batch(channel_key, messages)
    else:
        # A successful publish ends any failure streak for this channel.
        self._restore_attempts.pop(channel_key, None)
        logger.info(
            "Finalised batch for %s with %d message(s)",
            channel_key,
            len(messages),
        )
async def _check_expired_batches(self, now: float | None = None) -> None:
    """Find batch timers that have elapsed and flush each one.

    Queries the ``sg:batch:timers`` Redis sorted set for channels whose
    scheduled execution time is ``<= now``. For each one it calls
    :meth:`_drain_and_publish_channel_batch` to atomically claim, drain and
    publish that channel's batch.

    Errors are isolated per channel. A channel whose flush keeps raising stays
    at the front of the sorted set, so aborting the whole sweep on its error
    would starve every later channel on every tick. Such errors are therefore
    logged and the sweep moves on. A failure of the range query itself is
    logged and ends this sweep.

    Called every second by :meth:`_batch_scheduler_loop` (with *now*
    defaulting to the current wall-clock time), and directly by the batching
    tests with an explicit timestamp. Does nothing when no Redis client is
    configured.

    Args:
        now (float | None): Unix timestamp used as the expiry cutoff; when
            ``None``, ``time.time()`` is used.
    """
    if self.redis is None:
        return
    if now is None:
        import time  # ``time`` is not among this module's top-level imports.

        now = time.time()
    try:
        expired = await self.redis.zrangebyscore("sg:batch:timers", "-inf", now)
    except Exception:
        logger.exception("Error checking expired batches")
        return
    for channel_key in expired or ():
        try:
            if isinstance(channel_key, bytes):
                channel_key = channel_key.decode("utf-8")
            await self._drain_and_publish_channel_batch(
                channel_key,
                f"sg:batch:active:{channel_key}",
                now,
            )
        except Exception:
            logger.exception("Error flushing expired batch for %s", channel_key)
[docs]
async def enqueue_front(self, item: QueueItem) -> None:
    """Place *item* at the head of its channel's queue so it is consumed next.

    The channel comes from ``item.channel_id``. The backing queue (obtained
    via :meth:`_queue`) receives the item through its own ``enqueue_front``,
    and the cached per-channel length in ``self._lengths`` grows by one. This
    lets a high-priority item (e.g. a re-queued or system message) jump ahead
    of normal FIFO order.

    Args:
        item (QueueItem): The :class:`QueuedMessage` or :class:`MessageBatch`
            to put at the front of its channel's queue.
    """
    key = item.channel_id
    backing = self._queue(key)
    await backing.enqueue_front(item)
    self._lengths[key] = 1 + self._lengths.get(key, 0)
[docs]
def is_channel_processing(self, channel: str) -> bool:
    """Tell whether *channel* currently has a callback in flight.

    Reads the in-process ``_processing`` flag that :meth:`_process_loop`
    raises while a dequeued item's callback runs. The check is synchronous,
    does not block, and does not query any backend.

    Args:
        channel (str): The channel key to check.

    Returns:
        bool: The stored flag for the channel, or ``False`` for a channel
        with no recorded flag.
    """
    flags = self._processing
    if channel in flags:
        return flags[channel]
    return False
[docs]
def queue_size(self, channel: str) -> int:
    """Report how many committed items are waiting in *channel*'s queue.

    The value comes from the in-process ``_lengths`` counter kept in step by
    the enqueue and processor paths, so no backend is queried. Messages still
    collecting in an open batch are not counted.

    Args:
        channel (str): The channel key to size.

    Returns:
        int: The cached queue length, or ``0`` for an unknown channel.
    """
    counts = self._lengths
    return counts[channel] if channel in counts else 0
[docs]
async def start_processing(
    self,
    channel: str,
    callback: ProcessorCallback,
) -> None:
    """Make sure exactly one :meth:`_process_loop` task services *channel*.

    The channel's backing queue is created first (via :meth:`_queue`). Then,
    while holding the channel's entry in ``self._locks``, a new loop task is
    started only if none is registered or the registered one has finished.
    Holding the lock keeps concurrent callers from launching duplicate loops.
    This is the in-memory delivery mode; in distributed mode batches travel
    over the event bus instead.

    Args:
        channel (str): The channel to start processing.
        callback (ProcessorCallback): Async callable invoked with each
            dequeued :class:`QueuedMessage` or :class:`MessageBatch`.
    """
    self._queue(channel)  # create the backing queue before the loop runs
    async with self._locks[channel]:
        running = self._processors.get(channel)
        if running is None or running.done():
            self._processors[channel] = asyncio.create_task(
                self._process_loop(channel, callback),
            )
[docs]
async def cancel_batch_timer(self, channel: str) -> None:
    """Remove and cancel *channel*'s pending in-memory flush timer, if any.

    The timer task is popped from ``self._batch_timers``. If it is still
    running, it is cancelled and awaited so the cancellation has fully
    settled before this returns; the expected :class:`asyncio.CancelledError`
    is swallowed. Nothing is cancelled or logged when no live timer is armed.
    The channel's open in-memory batch, if any, is not touched.

    Called by :meth:`stop_processing` as part of stopping a channel.

    Args:
        channel (str): The channel whose flush timer should be cancelled.
    """
    pending = self._batch_timers.pop(channel, None)
    if pending is None or pending.done():
        return
    pending.cancel()
    try:
        await pending
    except asyncio.CancelledError:
        pass
    logger.info("Cancelled batch timer for channel %s", channel)
[docs]
async def stop_processing(self, channel: str) -> None:
    """Stop all work for *channel*: in-flight callback, flush timer, and loop.

    Steps, in order:

    1. Cancel the running callback via :meth:`cancel_current`.
    2. Cancel the pending batch timer via :meth:`cancel_batch_timer`.
    3. Pop the channel's :meth:`_process_loop` task from ``self._processors``
       and, if it is still running, cancel it and await it until the
       cancellation settles.

    Used when a channel must stop being serviced entirely (e.g. a user
    stop/cancel request from ``web/redis_platform_api.py``).

    Args:
        channel (str): The channel to stop processing.
    """
    logger.info("Stopping queue processing for channel %s", channel)
    self.cancel_current(channel)
    await self.cancel_batch_timer(channel)
    loop_task = self._processors.pop(channel, None)
    if loop_task is None or loop_task.done():
        return
    loop_task.cancel()
    try:
        await loop_task
    except asyncio.CancelledError:
        pass
    logger.info("Stopped queue processor loop task for channel %s", channel)
[docs]
async def clear(self, channel: str) -> int:
    """Drop everything waiting in *channel*'s queue and zero its cached length.

    The backing queue's own ``clear`` does the removal. A channel without a
    queue is left untouched and reports nothing removed.

    Args:
        channel (str): The channel whose queue should be emptied.

    Returns:
        int: The count reported by the backing queue's ``clear``, or ``0``
        when the channel has no queue.
    """
    backing = self._queues.get(channel)
    removed = 0
    if backing is not None:
        removed = await backing.clear()
        self._lengths[channel] = 0
    return removed
[docs]
def stats(self) -> dict[str, dict[str, Any]]:
    """Return a per-channel health snapshot for every channel with a queue.

    Everything is read from in-process bookkeeping (``_lengths``,
    ``_processing``, ``_processors``), so the call is cheap and synchronous.

    Returns:
        dict[str, dict[str, Any]]: Maps each key of ``self._queues`` to a dict
        with three entries:

        * ``queue_size`` -- the cached length, default ``0``;
        * ``is_processing`` -- the processing flag, default ``False``;
        * ``has_processor`` -- true when a registered loop task exists and
          has not finished.
    """
    return {
        ch: {
            "queue_size": self._lengths.get(ch, 0),
            "is_processing": self._processing.get(ch, False),
            "has_processor": (
                ch in self._processors and not self._processors[ch].done()
            ),
        }
        for ch in self._queues
    }
# ------------------------------------------------------------------
# Processor loop
# ------------------------------------------------------------------
[docs]
def cancel_current(self, channel: str) -> bool:
    """Cancel the callback task currently running for *channel*, if any.

    Before cancelling, ``_stop_requested[channel]`` is set so that
    :meth:`_process_loop` treats the resulting cancellation as a user stop
    rather than a shutdown of the loop itself.

    Args:
        channel (str): The channel whose in-flight callback should stop.

    Returns:
        bool: ``True`` if a live task was found and cancelled, ``False`` if
        nothing was running for the channel.
    """
    running = self._current_tasks.get(channel)
    if running is None or running.done():
        return False
    self._stop_requested[channel] = True
    running.cancel()
    logger.info("Cancelled current processing task for %s", channel)
    return True
async def _process_loop(
    self,
    channel: str,
    callback: ProcessorCallback,
) -> None:
    """Drain a channel's queue, running *callback* on each item in order.

    This is the body of the per-channel processor :class:`asyncio.Task`. It
    pops one item at a time, with a 60 s wait so it can retire itself when
    idle unless a batch is still accumulating. It then marks the channel as
    processing and runs *callback* in a tracked child task. That lets a user
    stop request (:meth:`cancel_current`) cancel just the in-flight item
    without killing the loop.

    Callback errors are logged and skipped; a genuine loop cancellation breaks
    out. On exit, while it is still the registered processor, it removes the
    channel's queue, lock, processing flag and batch lock, so the channel can
    be cleanly restarted later.

    Launched by :meth:`start_processing`; it owns the lifecycle of one
    channel's processing.

    Args:
        channel (str): The channel this loop services.
        callback (ProcessorCallback): Async callable invoked with each
            dequeued :class:`QueuedMessage` or :class:`MessageBatch`.
    """
    q = self._queue(channel)
    while True:
        try:
            try:
                # Bounded wait so an idle loop can retire itself.
                item = await asyncio.wait_for(q.get(), timeout=60.0)
                self._lengths[channel] = max(0, self._lengths.get(channel, 0) - 1)
            except asyncio.TimeoutError:
                # An in-memory batch is still open; its flush timer is
                # expected to enqueue it, so keep waiting instead of exiting.
                if channel in self._active_batches:
                    continue
                break
            self._processing[channel] = True
            # Run the callback as its own task so cancel_current() can stop
            # just this item while the loop survives.
            task = asyncio.create_task(callback(item))
            self._current_tasks[channel] = task
            try:
                await task
            except asyncio.CancelledError:
                # NOTE(review): cancelling this loop task while it awaits
                # `task` is delivered by cancelling `task` itself, so it also
                # lands here. If `_stop_requested[channel]` is set at that
                # moment (stop_processing calls cancel_current just before
                # cancelling the loop), the loop may log a user stop and keep
                # running instead of exiting -- confirm.
                if self._stop_requested.pop(channel, False):
                    logger.info(
                        "Processing stopped by user for %s",
                        channel,
                    )
                else:
                    if not task.done():
                        task.cancel()
                    raise
            except Exception:
                label = (
                    f"batch({item.size})"
                    if isinstance(item, MessageBatch)
                    else "message"
                )
                logger.exception(
                    "Error processing %s in %s",
                    label,
                    channel,
                )
            finally:
                # Always settle bookkeeping for this item, whatever happened.
                self._current_tasks.pop(channel, None)
                q.task_done()
                self._processing[channel] = False
        except asyncio.CancelledError:
            break
        except Exception:
            logger.exception("Fatal error in queue processor for %s", channel)
            # Back off briefly before retrying after an unexpected error.
            await asyncio.sleep(1.0)
    self._processing[channel] = False
    lock = self._locks.get(channel)
    if lock:
        async with lock:
            # Tear down only while still the registered processor, so a
            # replacement loop (started after stop_processing popped this
            # one) keeps its queue and lock.
            if self._processors.get(channel) is asyncio.current_task():
                del self._processors[channel]
                self._queues.pop(channel, None)
                self._locks.pop(channel, None)
                self._processing.pop(channel, None)
                self._batch_locks.pop(channel, None)