# Source code for task_manager

"""Fire-and-forget task manager for tool execution.

Wraps tool handler coroutines with a configurable timeout.  If a tool
completes within the timeout its result is returned inline.  Otherwise
the coroutine continues as a background :class:`asyncio.Task` and a
JSON envelope containing a task ID is returned so the LLM can poll for
results later via the ``check_task`` tool.

Output redirect
~~~~~~~~~~~~~~~
Any backgrounded task can have its result automatically delivered to a
channel on any platform when it finishes.  Call
:meth:`TaskManager.set_output_redirect` (or use the ``redirect_task``
tool) to configure this.
"""

from __future__ import annotations

import asyncio
import json
import logging
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, TYPE_CHECKING

if TYPE_CHECKING:
    from platforms.base import PlatformAdapter

logger = logging.getLogger(__name__)

from message_utils import split_message

REDIS_KEY_PREFIX = "task_result:"
REDIS_TTL_SECONDS = 86400  # 24 hours

DEFAULT_REDIRECT_MAX_CHARS = 9000
MAX_REDIRECT_MESSAGES = 5


class TaskStatus(str, Enum):
    """Lifecycle states for a tracked background task.

    Inherits from :class:`str` so values compare equal to their plain
    string form and serialize directly into JSON envelopes.
    """

    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
@dataclass
class TaskRecord:
    """In-memory record for a tracked background task."""

    task_id: str
    tool_name: str
    status: TaskStatus
    created_at: float = field(default_factory=time.time)
    result: str | None = None
    error: str | None = None
    user_id: str = ""
    channel_id: str = ""
    platform: str = ""
    # Live handle while the work is still running; cleared on completion.
    asyncio_task: asyncio.Task | None = field(default=None, repr=False)

    # -- Output redirect -----------------------------------------------
    # Channel to deliver the result to when the task finishes.
    redirect_channel_id: str = ""
    # Platform name for the redirect target.
    redirect_platform: str = ""
    # Resolved PlatformAdapter used for delivery.
    redirect_adapter: Any = field(default=None, repr=False)
    # Max characters of output body to deliver (0 = use default).
    redirect_max_chars: int = 0
class TaskManager:
    """Manage fire-and-forget tool execution with timeout.

    Parameters
    ----------
    timeout:
        Seconds to wait for a tool to complete before backgrounding
        it.  Defaults to ``10.0``.
    redis:
        Optional async Redis client for persisting completed results.
    """

    def __init__(
        self,
        timeout: float = 10.0,
        redis: Any = None,
    ) -> None:
        """Create a manager with an empty task registry.

        Args:
            timeout (float): Seconds to wait before backgrounding a tool.
            redis (Any): Optional async Redis client, or ``None``.
        """
        self.timeout = timeout
        self.redis = redis
        # task_id -> TaskRecord for every task this manager has backgrounded.
        self._tasks: dict[str, TaskRecord] = {}
# ------------------------------------------------------------------ # Core execution # ------------------------------------------------------------------
[docs] async def execute( self, coro: Any, tool_name: str = "", user_id: str = "", channel_id: str = "", platform: str = "", ) -> str: """Run *coro* with a timeout; background it if it takes too long. Returns the tool result string directly when the coroutine finishes within :attr:`timeout`, or a JSON envelope with a ``task_id`` when it does not. """ task = asyncio.create_task(coro) done, _ = await asyncio.wait({task}, timeout=self.timeout) if done: return str(task.result()) task_id = uuid.uuid4().hex[:12] record = TaskRecord( task_id=task_id, tool_name=tool_name, status=TaskStatus.RUNNING, user_id=user_id, channel_id=channel_id, platform=platform, asyncio_task=task, ) self._tasks[task_id] = record task.add_done_callback( lambda t: self._on_task_done(task_id, t), ) logger.info( "Tool '%s' backgrounded as task %s", tool_name, task_id, ) return json.dumps({ "task_id": task_id, "tool_name": tool_name, "status": TaskStatus.RUNNING.value, "message": ( f"Tool '{tool_name}' is running in the background. " f"Use check_task with task_id '{task_id}' to get " f"the result." ), })
# ------------------------------------------------------------------ # Result retrieval # ------------------------------------------------------------------
[docs] async def get_result( self, task_id: str, user_id: str | None = None, ) -> str: """Return the result for *task_id*, or a status update. If *user_id* is set, only tasks owned by that user are returned. """ record = self._tasks.get(task_id) if record is not None: if user_id is not None and record.user_id != user_id: return json.dumps({ "error": f"Task '{task_id}' not found.", }) if record.status == TaskStatus.COMPLETED: return record.result or "" if record.status == TaskStatus.FAILED: return json.dumps({ "task_id": task_id, "status": TaskStatus.FAILED.value, "error": record.error, }) return json.dumps({ "task_id": task_id, "tool_name": record.tool_name, "status": TaskStatus.RUNNING.value, "elapsed_seconds": round( time.time() - record.created_at, 1, ), }) # Check Redis for persisted results if self.redis is not None: try: cached = await self.redis.get( f"{REDIS_KEY_PREFIX}{task_id}", ) if cached: return cached except Exception: logger.debug( "Redis lookup failed for task %s", task_id, exc_info=True, ) return json.dumps({ "error": f"Task '{task_id}' not found.", })
[docs] async def await_result( self, task_id: str, timeout: float = 300.0, ) -> str: """Block until *task_id* completes and return its result. Unlike :meth:`get_result` which returns immediately with a status update, this method **awaits** the underlying :class:`asyncio.Task` so the caller's coroutine is suspended until the work finishes. Parameters ---------- timeout: Maximum seconds to wait. Defaults to ``300`` (5 minutes). If exceeded, a timeout error JSON envelope is returned. """ record = self._tasks.get(task_id) # --- Already finished (in-memory) -------------------------------- if record is not None: if record.status == TaskStatus.COMPLETED: return record.result or "" if record.status == TaskStatus.FAILED: return json.dumps({ "task_id": task_id, "status": TaskStatus.FAILED.value, "error": record.error, }) # --- Still running: await the asyncio.Task -------------------- atask = record.asyncio_task if atask is not None: try: await asyncio.wait_for( asyncio.shield(atask), timeout=timeout, ) except asyncio.TimeoutError: return json.dumps({ "task_id": task_id, "status": "timeout", "error": ( f"Task '{task_id}' did not complete within " f"{timeout}s. It is still running in the " f"background — use check_task to poll later." ), }) except asyncio.CancelledError: pass # fall through to the status re-check below except Exception: pass # task raised; result captured by _on_task_done # Re-read status — _on_task_done has updated the record. 
if record.status == TaskStatus.COMPLETED: return record.result or "" if record.status == TaskStatus.FAILED: return json.dumps({ "task_id": task_id, "status": TaskStatus.FAILED.value, "error": record.error, }) # asyncio_task was already None (finished between our checks) return await self.get_result(task_id) # --- Not in memory: check Redis ---------------------------------- if self.redis is not None: try: cached = await self.redis.get( f"{REDIS_KEY_PREFIX}{task_id}", ) if cached: return cached except Exception: logger.debug( "Redis lookup failed for task %s", task_id, exc_info=True, ) return json.dumps({ "error": f"Task '{task_id}' not found.", })
[docs] async def list_tasks(self, user_id: str | None = None) -> str: """Return a JSON summary of tracked tasks. If *user_id* is provided, only tasks belonging to that user are returned. Pass ``None`` to list all tasks. """ now = time.time() tasks = [] for rec in self._tasks.values(): if user_id is not None and rec.user_id != user_id: continue entry: dict[str, Any] = { "task_id": rec.task_id, "tool_name": rec.tool_name, "status": rec.status.value, "user_id": rec.user_id, "elapsed_seconds": round( now - rec.created_at, 1, ), "created_at": rec.created_at, } if rec.status == TaskStatus.FAILED: entry["error"] = rec.error tasks.append(entry) return json.dumps({ "tasks": tasks, "count": len(tasks), })
# ------------------------------------------------------------------ # Output redirect # ------------------------------------------------------------------
[docs] def set_output_redirect( self, task_id: str, channel_id: str, platform: str, adapter: "PlatformAdapter", max_chars: int = 0, ) -> str | None: """Configure a task to deliver its result to *channel_id* on finish. Returns an error string if the task is not found or already finished, otherwise ``None``. """ record = self._tasks.get(task_id) if record is None: return f"Task '{task_id}' not found." record.redirect_channel_id = channel_id record.redirect_platform = platform record.redirect_adapter = adapter record.redirect_max_chars = max_chars # If the task already finished before the redirect was set, # deliver immediately. if record.status in (TaskStatus.COMPLETED, TaskStatus.FAILED): asyncio.create_task(self._deliver_output(record)) return None
# ------------------------------------------------------------------ # Internal # ------------------------------------------------------------------ def _on_task_done( self, task_id: str, task: asyncio.Task, ) -> None: """Callback fired when a backgrounded task finishes.""" record = self._tasks.get(task_id) if record is None: return record.asyncio_task = None if task.cancelled(): record.status = TaskStatus.FAILED record.error = "Task was cancelled." logger.warning("Task %s was cancelled", task_id) elif (exc := task.exception()) is not None: record.status = TaskStatus.FAILED record.error = f"{type(exc).__name__}: {exc}" logger.error( "Task %s failed: %s", task_id, exc, exc_info=exc, ) else: record.status = TaskStatus.COMPLETED record.result = str(task.result()) logger.info("Task %s completed", task_id) # Persist to Redis asynchronously if self.redis is not None: asyncio.create_task(self._persist_result(record)) # Deliver output to redirect channel if configured if record.redirect_adapter is not None and record.redirect_channel_id: asyncio.create_task(self._deliver_output(record)) async def _persist_result(self, record: TaskRecord) -> None: """Store a completed/failed result in Redis.""" try: if record.status == TaskStatus.COMPLETED: value = json.dumps({ "task_id": record.task_id, "status": record.status.value, "user_id": record.user_id, "result": record.result or "", }) else: value = json.dumps({ "task_id": record.task_id, "status": record.status.value, "user_id": record.user_id, "error": record.error, }) await self.redis.set( f"{REDIS_KEY_PREFIX}{record.task_id}", value, ex=REDIS_TTL_SECONDS, ) except Exception: logger.debug( "Failed to persist task %s to Redis", record.task_id, exc_info=True, ) async def _deliver_output(self, record: TaskRecord) -> None: """Send the task result to the configured redirect channel. Splits across up to :data:`MAX_REDIRECT_MESSAGES` messages and only truncates when that budget is exhausted. 
""" adapter = record.redirect_adapter channel_id = record.redirect_channel_id if adapter is None or not channel_id: return elapsed = round(time.time() - record.created_at, 1) tool = record.tool_name or "unknown" max_chars = record.redirect_max_chars or DEFAULT_REDIRECT_MAX_CHARS if record.status == TaskStatus.COMPLETED: body = record.result or "(empty result)" truncated = len(body) > max_chars if truncated: body = body[:max_chars] header = f"**`{tool}`** completed in {elapsed}s\n" suffix = "\n…[truncated]" if truncated else "" text = f"{header}```\n{body}{suffix}\n```" elif record.status == TaskStatus.FAILED: err = record.error or "unknown error" text = ( f"**`{tool}`** failed after {elapsed}s\n" f"```\n{err}\n```" ) else: return chunks = split_message(text, max_length=1950) try: for chunk in chunks[:MAX_REDIRECT_MESSAGES]: await adapter.send(channel_id, chunk) logger.info( "Delivered task %s output (%d msg) to %s:%s", record.task_id, min(len(chunks), MAX_REDIRECT_MESSAGES), record.redirect_platform, channel_id, ) except Exception: logger.exception( "Failed to deliver task %s output to %s:%s", record.task_id, record.redirect_platform, channel_id, )