"""Per-channel message queue with batching.
Ensures temporal consistency: messages are processed in order per channel,
and rapid-succession messages can be collected into batches for a single
combined response.
"""
from __future__ import annotations
import asyncio
import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Callable, Awaitable, Union
logger = logging.getLogger(__name__)
@dataclass
class QueuedMessage:
    """A message waiting in the channel queue.

    Instances are timestamped with UTC ``queued_at`` on creation so the
    batching layer can reason about arrival order.
    """

    platform: str  # source platform identifier (e.g. chat service name)
    channel_id: str  # channel the message arrived in; used as the queue key
    user_id: str  # stable identifier of the author
    user_name: str  # display name of the author
    text: str  # message body
    queued_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))  # UTC arrival time
    extra: dict[str, Any] = field(default_factory=dict)  # platform-specific metadata
    raw: Any = None  # Original IncomingMessage
@dataclass
class MessageBatch:
    """A group of messages collected within a rolling time window.

    Attributes:
        messages: Messages in arrival order.
        first_at: UTC timestamp of the first message added to the batch.
        last_at: UTC timestamp of the most recently added message.
    """

    messages: list[QueuedMessage] = field(default_factory=list)
    first_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    last_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    def add(self, msg: QueuedMessage) -> None:
        """Append *msg* to the batch and refresh the batch timestamps.

        Args:
            msg (QueuedMessage): Incoming message object.
        """
        self.messages.append(msg)
        self.last_at = datetime.now(timezone.utc)
        # Anchor the window start to the actual first arrival rather than
        # the batch's construction time.
        if len(self.messages) == 1:
            self.first_at = self.last_at

    @property
    def size(self) -> int:
        """Number of messages currently in the batch."""
        return len(self.messages)

    @property
    def channel_id(self) -> str:
        """Channel id shared by all messages in the batch.

        Raises:
            ValueError: If the batch contains no messages.
        """
        if not self.messages:
            raise ValueError("Empty batch has no channel_id")
        return self.messages[0].channel_id

    def unique_authors(self) -> list[str]:
        """Return the distinct author user ids in first-seen order."""
        # dict preserves insertion order, so this deduplicates in one pass.
        return list(dict.fromkeys(m.user_id for m in self.messages))
# A queue entry is either a single message or a finalized batch of messages.
QueueItem = Union[QueuedMessage, MessageBatch]
# Async handler invoked by the per-channel processor loop for each queue item.
ProcessorCallback = Callable[[QueueItem], Awaitable[None]]
class MessageQueue:
    """Per-channel queue that processes messages in order.

    Parameters
    ----------
    default_batch_window:
        Seconds to wait for additional messages before finalising a batch.
    max_batch_size:
        Maximum messages per batch before immediate finalisation.
    redis:
        Optional async Redis client for per-channel batch config.
    """

    def __init__(
        self,
        default_batch_window: float = 5.0,
        max_batch_size: int = 10,
        redis: Any = None,
    ) -> None:
        """Initialise per-channel bookkeeping structures.

        Args:
            default_batch_window (float): Fallback batch window in seconds,
                used when no per-channel override is found in Redis.
            max_batch_size (int): Messages per batch before immediate
                finalisation.
            redis (Any): Optional async Redis client; ``None`` disables
                per-channel config lookups.
        """
        self._queues: dict[str, asyncio.Queue[QueueItem]] = {}
        self._locks: dict[str, asyncio.Lock] = {}
        self._processors: dict[str, asyncio.Task] = {}
        self._processing: dict[str, bool] = {}
        self._active_batches: dict[str, MessageBatch] = {}
        self._batch_timers: dict[str, asyncio.Task] = {}
        self._batch_locks: dict[str, asyncio.Lock] = {}
        # Per-channel task tracking for cancellation support.
        self._current_tasks: dict[str, asyncio.Task] = {}
        self._stop_requested: dict[str, bool] = {}
        self.default_batch_window = default_batch_window
        self.max_batch_size = max_batch_size
        self.redis = redis

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _queue(self, channel: str) -> asyncio.Queue[QueueItem]:
        """Return the queue for *channel*, creating it (together with its
        per-channel lock and processing flag) on first use."""
        if channel not in self._queues:
            self._queues[channel] = asyncio.Queue()
            self._locks[channel] = asyncio.Lock()
            self._processing[channel] = False
        return self._queues[channel]

    def _batch_lock(self, channel: str) -> asyncio.Lock:
        """Return the batching lock for *channel*, creating it on first use."""
        if channel not in self._batch_locks:
            self._batch_locks[channel] = asyncio.Lock()
        return self._batch_locks[channel]

    async def _batch_window_for(self, channel: str) -> float:
        """Resolve the batch window (in seconds) for *channel*.

        Redis overrides, when a client is configured:

        * ``message_batching_disabled:<channel>`` present -> ``0.0``
          (batching off for the channel).
        * ``message_batch_window:<channel>`` set -> that value.

        Falls back to ``default_batch_window`` when Redis is absent, the
        keys are unset, or the lookup fails.
        """
        if self.redis is not None:
            try:
                if await self.redis.exists(f"message_batching_disabled:{channel}"):
                    return 0.0
                raw = await self.redis.get(f"message_batch_window:{channel}")
                if raw:
                    return float(raw)
            except Exception:
                # Best effort: a Redis hiccup must not block message intake.
                logger.debug("Redis batch-window lookup failed for %s", channel)
        return self.default_batch_window

    async def _finalize_batch(self, channel: str) -> None:
        """Finalize the active batch for *channel* under the batch lock."""
        lock = self._batch_lock(channel)
        async with lock:
            await self._finalize_batch_unlocked(channel)

    async def _finalize_batch_unlocked(self, channel: str) -> None:
        """Move the active batch for *channel* onto the processing queue.

        Caller must hold the channel's batch lock. No-op when there is no
        active batch. Any pending window timer is stopped, unless the
        timer itself is the task performing this finalisation.
        """
        if channel not in self._active_batches:
            return
        batch = self._active_batches.pop(channel)
        timer = self._batch_timers.pop(channel, None)
        # BUGFIX: when finalisation is triggered *by* the timer task
        # (_fire -> _finalize_batch), the popped timer IS the currently
        # running task. Cancelling and awaiting it here would make a task
        # await itself (RuntimeError) and mark the in-flight finalisation
        # as cancel-pending. Only reap timers other than ourselves.
        if (
            timer is not None
            and timer is not asyncio.current_task()
            and not timer.done()
        ):
            timer.cancel()
            try:
                await timer
            except asyncio.CancelledError:
                pass
        q = self._queue(channel)
        await q.put(batch)
        logger.info(
            "Finalised batch for %s with %d message(s)", channel, batch.size,
        )

    async def _start_timer(self, channel: str, window: float) -> None:
        """(Re)arm the rolling batch-window timer for *channel*.

        Cancels any previous timer so the window restarts from the most
        recent message, then schedules finalisation *window* seconds out.
        """
        old = self._batch_timers.get(channel)
        if old is not None and not old.done():
            old.cancel()
            try:
                await old
            except asyncio.CancelledError:
                pass

        async def _fire() -> None:
            # Sleep out the window, then finalize. Cancellation simply
            # means the window was extended or finalised elsewhere.
            try:
                await asyncio.sleep(window)
                await self._finalize_batch(channel)
            except asyncio.CancelledError:
                pass

        self._batch_timers[channel] = asyncio.create_task(_fire())

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    async def enqueue(self, msg: QueuedMessage) -> None:
        """Add a message to the channel queue (may batch)."""
        channel = msg.channel_id
        window = await self._batch_window_for(channel)
        if window <= 0:
            # Batching disabled: hand the message straight to the queue.
            q = self._queue(channel)
            await q.put(msg)
            return
        lock = self._batch_lock(channel)
        async with lock:
            if channel in self._active_batches:
                batch = self._active_batches[channel]
                batch.add(msg)
                if batch.size >= self.max_batch_size:
                    # Size cap reached: ship the batch immediately.
                    await self._finalize_batch_unlocked(channel)
                else:
                    # Rolling window: restart the timer from this message.
                    await self._start_timer(channel, window)
            else:
                batch = MessageBatch()
                batch.add(msg)
                self._active_batches[channel] = batch
                await self._start_timer(channel, window)

    def is_channel_processing(self, channel: str) -> bool:
        """Return ``True`` while an item for *channel* is being processed."""
        return self._processing.get(channel, False)

    def queue_size(self, channel: str) -> int:
        """Return the number of items waiting in *channel*'s queue (0 if none)."""
        q = self._queues.get(channel)
        return q.qsize() if q else 0

    async def start_processing(
        self, channel: str, callback: ProcessorCallback,
    ) -> None:
        """Ensure a processor loop is running for *channel*."""
        q = self._queue(channel)
        lock = self._locks[channel]
        async with lock:
            task = self._processors.get(channel)
            if task is not None and not task.done():
                return  # A live processor already exists.
            self._processors[channel] = asyncio.create_task(
                self._process_loop(channel, callback),
            )

    async def stop_processing(self, channel: str) -> None:
        """Cancel and await the processor loop for *channel*, if any."""
        task = self._processors.pop(channel, None)
        if task is not None and not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

    async def clear(self, channel: str) -> int:
        """Drop all items currently queued for *channel*.

        Returns:
            int: Number of items removed.
        """
        q = self._queues.get(channel)
        if q is None:
            return 0
        count = 0
        while not q.empty():
            try:
                q.get_nowait()
                q.task_done()
                count += 1
            except asyncio.QueueEmpty:
                break
        return count

    def stats(self) -> dict[str, dict[str, Any]]:
        """Return per-channel queue statistics.

        Returns:
            dict[str, dict[str, Any]]: Mapping of channel id to
            ``queue_size``, ``is_processing`` and ``has_processor``.
        """
        out: dict[str, dict[str, Any]] = {}
        for ch, q in self._queues.items():
            out[ch] = {
                "queue_size": q.qsize(),
                "is_processing": self._processing.get(ch, False),
                "has_processor": (
                    ch in self._processors
                    and not self._processors[ch].done()
                ),
            }
        return out

    # ------------------------------------------------------------------
    # Processor loop
    # ------------------------------------------------------------------
    def cancel_current(self, channel: str) -> bool:
        """Cancel the in-flight processing task for *channel*.

        Returns ``True`` if a task was found and cancelled, ``False`` if
        no processing was active for the channel.
        """
        task = self._current_tasks.get(channel)
        if task is not None and not task.done():
            self._stop_requested[channel] = True
            task.cancel()
            logger.info("Cancelled current processing task for %s", channel)
            return True
        return False

    async def _process_loop(
        self, channel: str, callback: ProcessorCallback,
    ) -> None:
        """Consume *channel*'s queue in order, invoking *callback* per item.

        Exits when idle for 60 seconds with no batch still collecting, or
        when the processor task itself is cancelled.
        """
        q = self._queue(channel)
        while True:
            try:
                try:
                    item = await asyncio.wait_for(q.get(), timeout=60.0)
                except asyncio.TimeoutError:
                    if channel in self._active_batches:
                        # A batch is still collecting; keep the loop alive
                        # so its finalisation has a consumer.
                        continue
                    break  # Idle: let this processor wind down.
                self._processing[channel] = True
                # Run the callback as its own task so cancel_current() can
                # abort the item without killing this loop.
                task = asyncio.create_task(callback(item))
                self._current_tasks[channel] = task
                try:
                    await task
                except asyncio.CancelledError:
                    if self._stop_requested.pop(channel, False):
                        # User-requested abort of this item only; keep looping.
                        logger.info(
                            "Processing stopped by user for %s", channel,
                        )
                    else:
                        # The loop itself was cancelled: make sure the
                        # callback task dies too, then propagate.
                        if not task.done():
                            task.cancel()
                        raise
                except Exception:
                    label = (
                        f"batch({item.size})"
                        if isinstance(item, MessageBatch)
                        else "message"
                    )
                    logger.exception(
                        "Error processing %s in %s", label, channel,
                    )
                finally:
                    self._current_tasks.pop(channel, None)
                    q.task_done()
                    self._processing[channel] = False
            except asyncio.CancelledError:
                break
            except Exception:
                logger.exception("Fatal error in queue processor for %s", channel)
                await asyncio.sleep(1.0)
        self._processing[channel] = False
        # Deregister this processor, unless it was already replaced.
        lock = self._locks.get(channel)
        if lock:
            async with lock:
                if self._processors.get(channel) is asyncio.current_task():
                    del self._processors[channel]