"""Fire-and-forget task manager for tool execution.
Wraps tool handler coroutines with a configurable timeout. If a tool
completes within the timeout its result is returned inline. Otherwise
the coroutine continues as a background :class:`asyncio.Task` and a
JSON envelope containing a task ID is returned so the LLM can poll for
results later via the ``check_task`` tool.
Output redirect
~~~~~~~~~~~~~~~
Any backgrounded task can have its result automatically delivered to a
channel on any platform when it finishes. Call
:meth:`TaskManager.set_output_redirect` (or use the ``redirect_task``
tool) to configure this.
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, TYPE_CHECKING
if TYPE_CHECKING:
from platforms.base import PlatformAdapter
logger = logging.getLogger(__name__)
from message_utils import split_message
REDIS_KEY_PREFIX = "task_result:"
REDIS_TTL_SECONDS = 86400 # 24 hours
DEFAULT_REDIRECT_MAX_CHARS = 9000
MAX_REDIRECT_MESSAGES = 5
class TaskStatus(str, Enum):
    """Lifecycle state of a backgrounded task.

    Inherits from :class:`str` so member values serialize directly in
    JSON envelopes and compare equal to their plain string form.
    """

    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
@dataclass
class TaskRecord:
    """In-memory record for a tracked background task."""

    # Identity and state ------------------------------------------------
    task_id: str
    tool_name: str
    status: TaskStatus
    created_at: float = field(default_factory=time.time)
    result: str | None = None          # set when status is COMPLETED
    error: str | None = None           # set when status is FAILED
    # Origin of the request --------------------------------------------
    user_id: str = ""
    channel_id: str = ""
    platform: str = ""
    # Underlying task; excluded from repr to keep logs readable.
    asyncio_task: asyncio.Task | None = field(default=None, repr=False)
    # -- Output redirect -----------------------------------------------
    # Channel to deliver the result to when the task finishes.
    redirect_channel_id: str = ""
    # Platform name for the redirect target.
    redirect_platform: str = ""
    # Resolved PlatformAdapter used for delivery; excluded from repr.
    redirect_adapter: Any = field(default=None, repr=False)
    # Max characters of output body to deliver (0 = use default).
    redirect_max_chars: int = 0
class TaskManager:
    """Manage fire-and-forget tool execution with timeout.

    Tool coroutines are raced against :attr:`timeout`; fast tools return
    their result inline, while slow ones continue as tracked background
    :class:`asyncio.Task` objects whose results can later be polled
    (:meth:`get_result`), awaited (:meth:`await_result`), listed
    (:meth:`list_tasks`), or redirected to a channel
    (:meth:`set_output_redirect`).

    Parameters
    ----------
    timeout:
        Seconds to wait for a tool to complete before backgrounding
        it. Defaults to ``10.0``.
    redis:
        Optional async Redis client for persisting completed results.
    """

    def __init__(
        self,
        timeout: float = 10.0,
        redis: Any = None,
    ) -> None:
        """Initialize the instance.

        Args:
            timeout (float): Maximum wait time in seconds before a tool
                is backgrounded.
            redis (Any): Optional async Redis client; ``None`` disables
                persistence of finished results.
        """
        self.timeout = timeout
        self.redis = redis
        self._tasks: dict[str, TaskRecord] = {}
        # Strong references to fire-and-forget helper tasks (Redis
        # persistence, redirect delivery).  The event loop keeps only
        # weak references to tasks, so without this set they could be
        # garbage-collected before they finish running.
        self._aux_tasks: set[asyncio.Task] = set()

    def _spawn_aux(self, coro: Any) -> None:
        """Run *coro* as a background task, holding a strong reference.

        The reference is dropped automatically when the task finishes.
        """
        task = asyncio.create_task(coro)
        self._aux_tasks.add(task)
        task.add_done_callback(self._aux_tasks.discard)

    # ------------------------------------------------------------------
    # Core execution
    # ------------------------------------------------------------------

    async def execute(
        self,
        coro: Any,
        tool_name: str = "",
        user_id: str = "",
        channel_id: str = "",
        platform: str = "",
    ) -> str:
        """Run *coro* with a timeout; background it if it takes too long.

        Returns the tool result string directly when the coroutine
        finishes within :attr:`timeout`, or a JSON envelope with a
        ``task_id`` when it does not.  If the tool raises within the
        timeout, the exception propagates to the caller.
        """
        task = asyncio.create_task(coro)
        done, _ = await asyncio.wait({task}, timeout=self.timeout)
        if done:
            # Finished in time: surface the result (or exception) inline.
            return str(task.result())
        # Timed out: track it and hand the caller a polling envelope.
        task_id = uuid.uuid4().hex[:12]
        record = TaskRecord(
            task_id=task_id,
            tool_name=tool_name,
            status=TaskStatus.RUNNING,
            user_id=user_id,
            channel_id=channel_id,
            platform=platform,
            asyncio_task=task,
        )
        self._tasks[task_id] = record
        task.add_done_callback(
            lambda t: self._on_task_done(task_id, t),
        )
        logger.info(
            "Tool '%s' backgrounded as task %s", tool_name, task_id,
        )
        return json.dumps({
            "task_id": task_id,
            "tool_name": tool_name,
            "status": TaskStatus.RUNNING.value,
            "message": (
                f"Tool '{tool_name}' is running in the background. "
                f"Use check_task with task_id '{task_id}' to get "
                f"the result."
            ),
        })

    # ------------------------------------------------------------------
    # Result retrieval
    # ------------------------------------------------------------------

    async def get_result(
        self, task_id: str, user_id: str | None = None,
    ) -> str:
        """Return the result for *task_id*, or a status update.

        If *user_id* is set, only tasks owned by that user are returned.
        Completed tasks return their raw result string; failed and
        still-running tasks return a JSON status envelope.
        """
        record = self._tasks.get(task_id)
        if record is not None:
            # Ownership check: reply "not found" rather than revealing
            # that another user's task exists.
            if user_id is not None and record.user_id != user_id:
                return json.dumps({
                    "error": f"Task '{task_id}' not found.",
                })
            if record.status == TaskStatus.COMPLETED:
                return record.result or ""
            if record.status == TaskStatus.FAILED:
                return json.dumps({
                    "task_id": task_id,
                    "status": TaskStatus.FAILED.value,
                    "error": record.error,
                })
            return json.dumps({
                "task_id": task_id,
                "tool_name": record.tool_name,
                "status": TaskStatus.RUNNING.value,
                "elapsed_seconds": round(
                    time.time() - record.created_at, 1,
                ),
            })
        # Check Redis for persisted results (survives restarts).
        if self.redis is not None:
            try:
                cached = await self.redis.get(
                    f"{REDIS_KEY_PREFIX}{task_id}",
                )
                if cached:
                    return cached
            except Exception:
                # Best-effort lookup: a Redis outage must not break polling.
                logger.debug(
                    "Redis lookup failed for task %s",
                    task_id, exc_info=True,
                )
        return json.dumps({
            "error": f"Task '{task_id}' not found.",
        })

    async def await_result(
        self, task_id: str, timeout: float = 300.0,
    ) -> str:
        """Block until *task_id* completes and return its result.

        Unlike :meth:`get_result` which returns immediately with a status
        update, this method **awaits** the underlying :class:`asyncio.Task`
        so the caller's coroutine is suspended until the work finishes.

        Parameters
        ----------
        timeout:
            Maximum seconds to wait. Defaults to ``300`` (5 minutes).
            If exceeded, a timeout error JSON envelope is returned.
        """
        record = self._tasks.get(task_id)
        # --- Already finished (in-memory) --------------------------------
        if record is not None:
            if record.status == TaskStatus.COMPLETED:
                return record.result or ""
            if record.status == TaskStatus.FAILED:
                return json.dumps({
                    "task_id": task_id,
                    "status": TaskStatus.FAILED.value,
                    "error": record.error,
                })
            # --- Still running: await the asyncio.Task --------------------
            atask = record.asyncio_task
            if atask is not None:
                try:
                    # shield() keeps the underlying task alive if only
                    # our wait times out or is cancelled.
                    await asyncio.wait_for(
                        asyncio.shield(atask), timeout=timeout,
                    )
                except asyncio.TimeoutError:
                    return json.dumps({
                        "task_id": task_id,
                        "status": "timeout",
                        "error": (
                            f"Task '{task_id}' did not complete within "
                            f"{timeout}s. It is still running in the "
                            f"background — use check_task to poll later."
                        ),
                    })
                except asyncio.CancelledError:
                    pass  # fall through to the status re-check below
                except Exception:
                    pass  # task raised; result captured by _on_task_done
                # _on_task_done is scheduled via loop.call_soon, so it
                # may not have run yet even though the task is done.
                # Yield once so the callback can update the record
                # before we re-read its status.
                await asyncio.sleep(0)
                # Re-read status — _on_task_done has updated the record.
                if record.status == TaskStatus.COMPLETED:
                    return record.result or ""
                if record.status == TaskStatus.FAILED:
                    return json.dumps({
                        "task_id": task_id,
                        "status": TaskStatus.FAILED.value,
                        "error": record.error,
                    })
            # asyncio_task was already None (finished between our checks)
            return await self.get_result(task_id)
        # --- Not in memory: check Redis ----------------------------------
        if self.redis is not None:
            try:
                cached = await self.redis.get(
                    f"{REDIS_KEY_PREFIX}{task_id}",
                )
                if cached:
                    return cached
            except Exception:
                logger.debug(
                    "Redis lookup failed for task %s",
                    task_id, exc_info=True,
                )
        return json.dumps({
            "error": f"Task '{task_id}' not found.",
        })

    async def list_tasks(self, user_id: str | None = None) -> str:
        """Return a JSON summary of tracked tasks.

        If *user_id* is provided, only tasks belonging to that user are
        returned. Pass ``None`` to list all tasks.
        """
        now = time.time()
        tasks = []
        for rec in self._tasks.values():
            if user_id is not None and rec.user_id != user_id:
                continue
            entry: dict[str, Any] = {
                "task_id": rec.task_id,
                "tool_name": rec.tool_name,
                "status": rec.status.value,
                "user_id": rec.user_id,
                "elapsed_seconds": round(
                    now - rec.created_at, 1,
                ),
                "created_at": rec.created_at,
            }
            if rec.status == TaskStatus.FAILED:
                entry["error"] = rec.error
            tasks.append(entry)
        return json.dumps({
            "tasks": tasks,
            "count": len(tasks),
        })

    # ------------------------------------------------------------------
    # Output redirect
    # ------------------------------------------------------------------

    def set_output_redirect(
        self,
        task_id: str,
        channel_id: str,
        platform: str,
        adapter: "PlatformAdapter",
        max_chars: int = 0,
    ) -> str | None:
        """Configure a task to deliver its result to *channel_id* on finish.

        Returns an error string if the task is not found, otherwise
        ``None``.  If the task has already finished, the result is
        delivered to the channel immediately.
        """
        record = self._tasks.get(task_id)
        if record is None:
            return f"Task '{task_id}' not found."
        record.redirect_channel_id = channel_id
        record.redirect_platform = platform
        record.redirect_adapter = adapter
        record.redirect_max_chars = max_chars
        # If the task already finished before the redirect was set,
        # deliver immediately.
        if record.status in (TaskStatus.COMPLETED, TaskStatus.FAILED):
            self._spawn_aux(self._deliver_output(record))
        return None

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    def _on_task_done(
        self, task_id: str, task: asyncio.Task,
    ) -> None:
        """Callback fired when a backgrounded task finishes.

        Updates the record's status/result/error, then kicks off Redis
        persistence and redirect delivery as background tasks.
        """
        record = self._tasks.get(task_id)
        if record is None:
            return
        record.asyncio_task = None
        if task.cancelled():
            record.status = TaskStatus.FAILED
            record.error = "Task was cancelled."
            logger.warning("Task %s was cancelled", task_id)
        elif (exc := task.exception()) is not None:
            record.status = TaskStatus.FAILED
            record.error = f"{type(exc).__name__}: {exc}"
            logger.error(
                "Task %s failed: %s", task_id, exc, exc_info=exc,
            )
        else:
            record.status = TaskStatus.COMPLETED
            record.result = str(task.result())
            logger.info("Task %s completed", task_id)
        # Persist to Redis asynchronously
        if self.redis is not None:
            self._spawn_aux(self._persist_result(record))
        # Deliver output to redirect channel if configured
        if record.redirect_adapter is not None and record.redirect_channel_id:
            self._spawn_aux(self._deliver_output(record))

    async def _persist_result(self, record: TaskRecord) -> None:
        """Store a completed/failed result in Redis (best-effort).

        Keys expire after :data:`REDIS_TTL_SECONDS`; failures are logged
        at debug level and never propagate.
        """
        try:
            if record.status == TaskStatus.COMPLETED:
                value = json.dumps({
                    "task_id": record.task_id,
                    "status": record.status.value,
                    "user_id": record.user_id,
                    "result": record.result or "",
                })
            else:
                value = json.dumps({
                    "task_id": record.task_id,
                    "status": record.status.value,
                    "user_id": record.user_id,
                    "error": record.error,
                })
            await self.redis.set(
                f"{REDIS_KEY_PREFIX}{record.task_id}",
                value,
                ex=REDIS_TTL_SECONDS,
            )
        except Exception:
            logger.debug(
                "Failed to persist task %s to Redis",
                record.task_id, exc_info=True,
            )

    async def _deliver_output(self, record: TaskRecord) -> None:
        """Send the task result to the configured redirect channel.

        Splits across up to :data:`MAX_REDIRECT_MESSAGES` messages and
        only truncates when that budget is exhausted.
        """
        adapter = record.redirect_adapter
        channel_id = record.redirect_channel_id
        if adapter is None or not channel_id:
            return
        elapsed = round(time.time() - record.created_at, 1)
        tool = record.tool_name or "unknown"
        max_chars = record.redirect_max_chars or DEFAULT_REDIRECT_MAX_CHARS
        if record.status == TaskStatus.COMPLETED:
            body = record.result or "(empty result)"
            truncated = len(body) > max_chars
            if truncated:
                body = body[:max_chars]
            header = f"**`{tool}`** completed in {elapsed}s\n"
            suffix = "\n…[truncated]" if truncated else ""
            text = f"{header}```\n{body}{suffix}\n```"
        elif record.status == TaskStatus.FAILED:
            err = record.error or "unknown error"
            text = (
                f"**`{tool}`** failed after {elapsed}s\n"
                f"```\n{err}\n```"
            )
        else:
            # Still running (shouldn't normally happen): nothing to send.
            return
        chunks = split_message(text, max_length=1950)
        try:
            for chunk in chunks[:MAX_REDIRECT_MESSAGES]:
                await adapter.send(channel_id, chunk)
            logger.info(
                "Delivered task %s output (%d msg) to %s:%s",
                record.task_id, min(len(chunks), MAX_REDIRECT_MESSAGES),
                record.redirect_platform, channel_id,
            )
        except Exception:
            logger.exception(
                "Failed to deliver task %s output to %s:%s",
                record.task_id, record.redirect_platform, channel_id,
            )