# Source code for message_queue

"""Per-channel message queue with batching.

Ensures temporal consistency: messages are processed in order per channel,
and rapid-succession messages can be collected into batches for a single
combined response.
"""

from __future__ import annotations

import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Callable, Awaitable, Union

logger = logging.getLogger(__name__)


@dataclass
class QueuedMessage:
    """A message waiting in the channel queue."""

    platform: str
    channel_id: str
    user_id: str
    user_name: str
    text: str
    # UTC timestamp recorded when the message was enqueued.
    queued_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    # Platform-specific metadata carried alongside the message.
    extra: dict[str, Any] = field(default_factory=dict)
    # Original IncomingMessage object, if the caller supplied one.
    raw: Any = None
@dataclass
class MessageBatch:
    """A group of messages collected within a rolling time window."""

    messages: list[QueuedMessage] = field(default_factory=list)
    # Timestamps bounding the window; both default to construction time.
    first_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    last_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    def add(self, msg: QueuedMessage) -> None:
        """Append *msg* to the batch and refresh the window timestamps.

        Args:
            msg (QueuedMessage): Incoming message object.
        """
        self.messages.append(msg)
        now = datetime.now(timezone.utc)
        self.last_at = now
        if len(self.messages) == 1:
            # The first message anchors the start of the window.
            self.first_at = now

    @property
    def size(self) -> int:
        """Number of messages currently held in the batch."""
        return len(self.messages)

    @property
    def channel_id(self) -> str:
        """Channel id shared by the batched messages (taken from the first).

        Raises:
            ValueError: If the batch contains no messages.
        """
        if not self.messages:
            raise ValueError("Empty batch has no channel_id")
        return self.messages[0].channel_id

    def unique_authors(self) -> list[str]:
        """Return distinct author user ids in first-appearance order."""
        # dict.fromkeys preserves insertion order and drops duplicates.
        return list(dict.fromkeys(m.user_id for m in self.messages))
# A queue entry is either one message or an already-assembled batch.
QueueItem = Union[QueuedMessage, MessageBatch]

# Async callback invoked by the processor loop for each queue item.
ProcessorCallback = Callable[[QueueItem], Awaitable[None]]
class MessageQueue:
    """Per-channel queue that processes messages in order.

    Rapid-succession messages for a channel are collected into a
    ``MessageBatch`` during a rolling time window; the batch is pushed
    onto the channel's FIFO queue when the window elapses or when it
    reaches ``max_batch_size``, whichever comes first.

    Parameters
    ----------
    default_batch_window:
        Seconds to wait for additional messages before finalising a batch.
    max_batch_size:
        Maximum messages per batch before immediate finalisation.
    redis:
        Optional async Redis client for per-channel batch config.
    """

    def __init__(
        self,
        default_batch_window: float = 5.0,
        max_batch_size: int = 10,
        redis: Any = None,
    ) -> None:
        """Initialize the instance.

        Args:
            default_batch_window (float): Fallback batch window in seconds
                when no per-channel override is found in Redis.
            max_batch_size (int): A batch is finalised immediately once it
                holds this many messages.
            redis (Any): Optional async Redis client; ``None`` disables
                per-channel config lookups.
        """
        # Per-channel FIFO of QueuedMessage / MessageBatch items.
        self._queues: dict[str, asyncio.Queue[QueueItem]] = {}
        # Guards processor-task creation per channel.
        self._locks: dict[str, asyncio.Lock] = {}
        # Long-lived processor loop task per channel.
        self._processors: dict[str, asyncio.Task] = {}
        # True while a callback is actively running for the channel.
        self._processing: dict[str, bool] = {}
        # Batch currently accumulating messages, per channel.
        self._active_batches: dict[str, MessageBatch] = {}
        # Timer task that finalises the active batch when its window ends.
        self._batch_timers: dict[str, asyncio.Task] = {}
        # Guards the active-batch/timer pair per channel.
        self._batch_locks: dict[str, asyncio.Lock] = {}
        # Per-channel task tracking for cancellation support.
        self._current_tasks: dict[str, asyncio.Task] = {}
        self._stop_requested: dict[str, bool] = {}
        self.default_batch_window = default_batch_window
        self.max_batch_size = max_batch_size
        self.redis = redis

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _queue(self, channel: str) -> asyncio.Queue[QueueItem]:
        """Return the queue for *channel*, creating per-channel state lazily.

        Args:
            channel (str): The channel id.

        Returns:
            asyncio.Queue[QueueItem]: The channel's FIFO queue.
        """
        if channel not in self._queues:
            self._queues[channel] = asyncio.Queue()
            self._locks[channel] = asyncio.Lock()
            self._processing[channel] = False
        return self._queues[channel]

    def _batch_lock(self, channel: str) -> asyncio.Lock:
        """Return the batch lock for *channel*, creating it on first use.

        Args:
            channel (str): The channel id.

        Returns:
            asyncio.Lock: Lock guarding the channel's active batch/timer.
        """
        if channel not in self._batch_locks:
            self._batch_locks[channel] = asyncio.Lock()
        return self._batch_locks[channel]

    async def _batch_window_for(self, channel: str) -> float:
        """Resolve the batch window (seconds) for *channel*.

        Consults Redis for a per-channel disable flag and window override
        when a client is configured; any Redis failure falls back to the
        default (best-effort by design).

        Args:
            channel (str): The channel id.

        Returns:
            float: Window in seconds; ``0.0`` means batching is disabled.
        """
        if self.redis is not None:
            try:
                if await self.redis.exists(f"message_batching_disabled:{channel}"):
                    return 0.0
                raw = await self.redis.get(f"message_batch_window:{channel}")
                if raw:
                    return float(raw)
            except Exception:
                # Config lookup is best-effort; fall back to the default.
                logger.debug("Redis batch-window lookup failed for %s", channel)
        return self.default_batch_window

    async def _finalize_batch(self, channel: str) -> None:
        """Finalise the active batch for *channel* under the batch lock.

        Args:
            channel (str): The channel id.
        """
        lock = self._batch_lock(channel)
        async with lock:
            await self._finalize_batch_unlocked(channel)

    async def _finalize_batch_unlocked(self, channel: str) -> None:
        """Move the active batch for *channel* onto its queue.

        Caller must already hold the channel's batch lock. Safe to call
        from within the batch timer task itself.

        Args:
            channel (str): The channel id.
        """
        if channel not in self._active_batches:
            return
        batch = self._active_batches.pop(channel)
        timer = self._batch_timers.pop(channel, None)
        # BUG FIX: when the window elapses, this method runs *inside* the
        # timer task (_fire -> _finalize_batch -> here), so ``timer`` is
        # the currently running task. Cancelling and awaiting it raised
        # ``RuntimeError: Task cannot await on itself`` before ``q.put``
        # executed, silently dropping every naturally-expired batch. Only
        # tear the timer down when we are NOT the timer.
        if (
            timer is not None
            and timer is not asyncio.current_task()
            and not timer.done()
        ):
            timer.cancel()
            try:
                await timer
            except asyncio.CancelledError:
                pass
        q = self._queue(channel)
        await q.put(batch)
        logger.info(
            "Finalised batch for %s with %d message(s)",
            channel,
            batch.size,
        )

    async def _start_timer(self, channel: str, window: float) -> None:
        """(Re)start the finalisation timer for *channel*.

        Cancels any previous timer so the window is rolling: each new
        message pushes finalisation back by *window* seconds.

        Args:
            channel (str): The channel id.
            window (float): Seconds until the batch is finalised.
        """
        old = self._batch_timers.get(channel)
        if old is not None and not old.done():
            old.cancel()
            try:
                await old
            except asyncio.CancelledError:
                pass

        async def _fire() -> None:
            """Sleep out the window, then finalise the batch."""
            try:
                await asyncio.sleep(window)
                await self._finalize_batch(channel)
            except asyncio.CancelledError:
                # Timer superseded or batch finalised early.
                pass

        self._batch_timers[channel] = asyncio.create_task(_fire())

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    async def enqueue(self, msg: QueuedMessage) -> None:
        """Add a message to the channel queue (may batch)."""
        channel = msg.channel_id
        window = await self._batch_window_for(channel)
        if window <= 0:
            # Batching disabled: deliver the bare message immediately.
            q = self._queue(channel)
            await q.put(msg)
            return
        lock = self._batch_lock(channel)
        async with lock:
            if channel in self._active_batches:
                batch = self._active_batches[channel]
                batch.add(msg)
                if batch.size >= self.max_batch_size:
                    # Size cap reached: ship the batch without waiting.
                    await self._finalize_batch_unlocked(channel)
                else:
                    # Rolling window: restart the timer on every message.
                    await self._start_timer(channel, window)
            else:
                batch = MessageBatch()
                batch.add(msg)
                self._active_batches[channel] = batch
                await self._start_timer(channel, window)

    def is_channel_processing(self, channel: str) -> bool:
        """Check whether a callback is currently running for *channel*.

        Args:
            channel (str): The channel id.

        Returns:
            bool: True while an item is being processed, False otherwise.
        """
        return self._processing.get(channel, False)

    def queue_size(self, channel: str) -> int:
        """Number of items waiting in the channel's queue.

        Args:
            channel (str): The channel id.

        Returns:
            int: Pending item count (0 if the channel has no queue yet).
        """
        q = self._queues.get(channel)
        return q.qsize() if q else 0

    async def start_processing(
        self,
        channel: str,
        callback: ProcessorCallback,
    ) -> None:
        """Ensure a processor loop is running for *channel*."""
        q = self._queue(channel)
        lock = self._locks[channel]
        async with lock:
            task = self._processors.get(channel)
            if task is not None and not task.done():
                # A live processor already exists; nothing to do.
                return
            self._processors[channel] = asyncio.create_task(
                self._process_loop(channel, callback),
            )

    async def stop_processing(self, channel: str) -> None:
        """Cancel and await the channel's processor loop, if running.

        Args:
            channel (str): The channel id.
        """
        task = self._processors.pop(channel, None)
        if task is not None and not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

    async def clear(self, channel: str) -> int:
        """Drop all pending items from the channel's queue.

        Note: does not touch a still-accumulating active batch.

        Args:
            channel (str): The channel id.

        Returns:
            int: Number of items removed.
        """
        q = self._queues.get(channel)
        if q is None:
            return 0
        count = 0
        while not q.empty():
            try:
                q.get_nowait()
                q.task_done()
                count += 1
            except asyncio.QueueEmpty:
                break
        return count

    def stats(self) -> dict[str, dict[str, Any]]:
        """Snapshot per-channel queue metrics.

        Returns:
            dict[str, dict[str, Any]]: For each known channel: queue size,
            whether a callback is mid-flight, and whether a live processor
            task exists.
        """
        out: dict[str, dict[str, Any]] = {}
        for ch, q in self._queues.items():
            out[ch] = {
                "queue_size": q.qsize(),
                "is_processing": self._processing.get(ch, False),
                "has_processor": (
                    ch in self._processors and not self._processors[ch].done()
                ),
            }
        return out

    # ------------------------------------------------------------------
    # Processor loop
    # ------------------------------------------------------------------
    def cancel_current(self, channel: str) -> bool:
        """Cancel the in-flight processing task for *channel*.

        Returns ``True`` if a task was found and cancelled, ``False``
        if no processing was active for the channel.
        """
        task = self._current_tasks.get(channel)
        if task is not None and not task.done():
            # Mark this as a user-requested stop so the loop does not
            # treat the cancellation as its own shutdown.
            self._stop_requested[channel] = True
            task.cancel()
            logger.info("Cancelled current processing task for %s", channel)
            return True
        return False

    async def _process_loop(
        self,
        channel: str,
        callback: ProcessorCallback,
    ) -> None:
        """Consume the channel queue, invoking *callback* per item.

        Exits after 60s of inactivity (unless a batch is still open) or
        on cancellation, and removes its own registration on the way out.

        Args:
            channel (str): The channel id.
            callback (ProcessorCallback): Awaited once per queue item.
        """
        q = self._queue(channel)
        while True:
            try:
                try:
                    item = await asyncio.wait_for(q.get(), timeout=60.0)
                except asyncio.TimeoutError:
                    if channel in self._active_batches:
                        # A batch is still open; keep waiting for it.
                        continue
                    # Idle channel: let the loop wind down.
                    break
                self._processing[channel] = True
                # Run the callback as a separate task so cancel_current()
                # can target it without killing this loop.
                task = asyncio.create_task(callback(item))
                self._current_tasks[channel] = task
                try:
                    await task
                except asyncio.CancelledError:
                    if self._stop_requested.pop(channel, False):
                        # User-requested stop: swallow and keep looping.
                        logger.info(
                            "Processing stopped by user for %s",
                            channel,
                        )
                    else:
                        # The loop itself was cancelled; propagate after
                        # making sure the callback task dies too.
                        if not task.done():
                            task.cancel()
                        raise
                except Exception:
                    label = (
                        f"batch({item.size})"
                        if isinstance(item, MessageBatch)
                        else "message"
                    )
                    logger.exception(
                        "Error processing %s in %s",
                        label,
                        channel,
                    )
                finally:
                    self._current_tasks.pop(channel, None)
                    q.task_done()
                    self._processing[channel] = False
            except asyncio.CancelledError:
                break
            except Exception:
                # Defensive: never let an unexpected error kill the loop.
                logger.exception("Fatal error in queue processor for %s", channel)
                await asyncio.sleep(1.0)
        self._processing[channel] = False
        lock = self._locks.get(channel)
        if lock:
            async with lock:
                # Deregister only if we are still the registered processor.
                if self._processors.get(channel) is asyncio.current_task():
                    del self._processors[channel]