# Source code for tools.cross_channel_query

"""Cross-channel query tool -- privately ask the LLM about another channel.

Fetches recent messages for a target channel from the Redis message cache
(falling back to the platform adapter's history API when the cache yields
nothing), builds a one-shot LLM sub-call with that transcript as context,
and returns the response.  Nothing is posted to the target channel.
"""

from __future__ import annotations

import json
import logging
import time
from datetime import datetime, timezone
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from tool_context import ToolContext

logger = logging.getLogger(__name__)

# Upper bound applied to the caller-supplied message_count.
_MAX_MESSAGES = 500
# message_count used when the caller does not supply one.
_DEFAULT_MESSAGES = 100

# System prompt for the one-shot LLM sub-call.  The per-line transcript
# format described here is the one produced by _format_transcript and
# _format_history_transcript, so keep the three in sync when editing.
_SYSTEM_PROMPT = (
    "You are a read-only channel analyst.  You have been given a transcript "
    "of recent messages from a chat channel, and optionally a pre-generated "
    "summary of that channel.  Answer the user's question about this channel "
    "accurately and concisely based solely on the information provided.\n\n"
    "The transcript uses the following format per line:\n"
    "  [ISO_TIMESTAMP] DisplayName (UserID) [Message ID: ID] [Replying to: ID] : message text\n"
    "Where:\n"
    "  - ISO_TIMESTAMP is the UTC time the message was sent\n"
    "  - DisplayName is the user's display name\n"
    "  - UserID is the platform-specific user identifier\n"
    "  - Message ID is the platform-specific message identifier\n"
    "  - [Replying to: ID] is only present if the message is a reply\n\n"
    "Do NOT fabricate information.  If the transcript does not contain "
    "enough information to answer, say so clearly."
)


def _format_transcript(messages: list) -> str:
    """Format CachedMessage objects using the canonical bot message format.

    Format per line:
    ``[ISO_TIMESTAMP] DisplayName (UserID) [Message ID: ID] [Replying to: ID] : text``
    """
    lines: list[str] = []
    for msg in messages:
        dt = datetime.fromtimestamp(msg.timestamp, tz=timezone.utc)
        ts = dt.isoformat()
        formatted = (
            f"[{ts}] {msg.user_name} ({msg.user_id})"
            f" [Message ID: {msg.message_id}]"
        )
        if msg.reply_to_id:
            formatted += f" [Replying to: {msg.reply_to_id}]"
        formatted += f" : {msg.text}"
        lines.append(formatted)
    return "\n".join(lines)


def _format_history_transcript(messages: list) -> str:
    """Format HistoricalMessage objects using the canonical bot message format.

    HistoricalMessage uses datetime objects for timestamps (unlike
    CachedMessage which uses Unix floats).
    """
    lines: list[str] = []
    for msg in messages:
        ts = msg.timestamp.isoformat()
        formatted = (
            f"[{ts}] {msg.user_name} ({msg.user_id})"
            f" [Message ID: {msg.message_id}]"
        )
        if msg.reply_to_id:
            formatted += f" [Replying to: {msg.reply_to_id}]"
        formatted += f" : {msg.text}"
        lines.append(formatted)
    return "\n".join(lines)


def _clamp_message_count(raw) -> int:
    """Coerce *raw* to an int within ``[1, _MAX_MESSAGES]``.

    Falls back to ``_DEFAULT_MESSAGES`` when *raw* cannot be converted
    (for example ``None`` or a non-numeric string in a malformed tool call).
    """
    try:
        requested = int(raw)
    except (TypeError, ValueError, OverflowError):
        logger.debug(
            "query_channel: invalid message_count %r, using default %d",
            raw, _DEFAULT_MESSAGES,
        )
        requested = _DEFAULT_MESSAGES
    return max(1, min(requested, _MAX_MESSAGES))


async def _log_channel_diagnostics(adapter, channel_id: str) -> None:
    """Log details about the target channel before a history fetch (best effort).

    Purely diagnostic.  Any failure (non-numeric channel ID, a client
    without ``get_channel``, API errors) is logged at debug level and
    swallowed so it can never prevent the subsequent ``fetch_history`` call.
    """
    client = getattr(adapter, "_client", None)
    if client is None:
        return
    try:
        ch = client.get_channel(int(channel_id))
        if ch is None:
            logger.info(
                "query_channel: channel %s NOT in client cache, "
                "will try fetch_channel via API",
                channel_id,
            )
            try:
                ch = await client.fetch_channel(int(channel_id))
                logger.info(
                    "query_channel: fetched channel %s via API: "
                    "type=%s name=%r",
                    channel_id, type(ch).__name__,
                    getattr(ch, "name", "?"),
                )
            except Exception as diag_exc:
                logger.warning(
                    "query_channel: CANNOT fetch channel %s: %s",
                    channel_id, diag_exc,
                )
        else:
            logger.info(
                "query_channel: channel %s found in cache: "
                "type=%s name=%r",
                channel_id, type(ch).__name__,
                getattr(ch, "name", "?"),
            )

        if ch is None:
            return

        # Check bot permissions if it's a guild channel
        perms = getattr(ch, "permissions_for", None)
        guild = getattr(ch, "guild", None)
        if guild and perms:
            try:
                bot_perms = perms(guild.me)
                logger.info(
                    "query_channel: bot perms in %s: "
                    "read_messages=%s read_message_history=%s "
                    "view_channel=%s",
                    channel_id,
                    bot_perms.read_messages,
                    bot_perms.read_message_history,
                    bot_perms.view_channel,
                )
            except Exception:
                logger.debug(
                    "query_channel: could not check perms",
                    exc_info=True,
                )
        logger.info(
            "query_channel: channel has .history attr: %s",
            hasattr(ch, "history"),
        )
    except Exception:
        logger.debug(
            "query_channel: channel diagnostics failed for %s",
            channel_id, exc_info=True,
        )


async def _load_summary_section(redis, channel_id: str) -> str:
    """Return the optional pre-generated summary block for the prompt.

    Reads ``stargazer:last1k_summary:<channel_id>`` from *redis*.  Returns
    ``""`` when there is no Redis handle, no stored value, or the read
    fails; the summary is optional context and must not abort the query.
    """
    if not redis:
        return ""
    try:
        raw_summary = await redis.get(
            f"stargazer:last1k_summary:{channel_id}",
        )
    except Exception:
        logger.debug(
            "query_channel: could not read summary for %s",
            channel_id, exc_info=True,
        )
        return ""
    if not raw_summary:
        return ""

    # A Redis client without response decoding returns bytes; decode so the
    # raw-text fallback below does not embed a ``b'...'`` repr in the prompt.
    if isinstance(raw_summary, (bytes, bytearray)):
        raw_summary = raw_summary.decode("utf-8", errors="replace")

    try:
        parsed = json.loads(raw_summary)
    except json.JSONDecodeError:
        parsed = None

    if isinstance(parsed, dict):
        summary_body = json.dumps(parsed, indent=2)
    else:
        # Not JSON, or JSON that is not an object: include the stored text.
        summary_body = str(raw_summary)
    return "\n\n## Pre-Generated Channel Summary\n" + summary_body


async def _query_channel(
    channel_id: str,
    question: str,
    message_count: int = _DEFAULT_MESSAGES,
    *,
    ctx: ToolContext | None = None,
) -> str:
    """Query another channel by running a private LLM sub-call.

    Flow: validate arguments, enforce the READ_DM privilege when the target
    resolves to a DM channel, load a transcript (message cache first, then
    the adapter's ``fetch_history``), optionally append the stored channel
    summary, and send a single tool-less chat request to the LLM.

    Args:
        channel_id: Target channel to read from.
        question: What to ask about the channel's activity.
        message_count: Number of recent messages to include.  Clamped to
            ``[1, _MAX_MESSAGES]``; unparsable values fall back to
            ``_DEFAULT_MESSAGES``.
        ctx: Injected tool context.

    Returns:
        JSON string.  On success it has ``channel_id``, ``question``,
        ``message_count`` (messages actually placed in the prompt) and
        ``response``.  On failure it has an ``error`` key.
    """
    if not ctx:
        return json.dumps({"error": "Tool context not available"})
    if not channel_id or not channel_id.strip():
        return json.dumps({"error": "channel_id is required"})
    if not question or not question.strip():
        return json.dumps({"error": "question is required"})

    channel_id = channel_id.strip()
    question = question.strip()
    message_count = _clamp_message_count(message_count)

    # ── DM channel privilege check ──────────────────────────────────
    # If the target channel is a DM, the caller must hold READ_DM.
    # NOTE(review): this check fails open -- if the channel cannot be
    # resolved (non-numeric ID, API error, no ``_client``), it is treated
    # as a non-DM channel.  Confirm that is the intended policy.
    _is_dm_channel = False
    platform = ctx.platform or "discord"

    _client = getattr(getattr(ctx, "adapter", None), "_client", None)
    if _client is not None:
        try:
            ch = _client.get_channel(int(channel_id))
            if ch is None:
                ch = await _client.fetch_channel(int(channel_id))
            # The string default distinguishes "no guild attribute" from
            # "guild attribute is None"; only the latter marks a DM.
            if ch is not None and getattr(ch, "guild", "sentinel") is None:
                _is_dm_channel = True
        except Exception:
            logger.debug(
                "query_channel: could not resolve channel %s for DM check",
                channel_id, exc_info=True,
            )

    if _is_dm_channel:
        try:
            from tools.alter_privileges import has_privilege, PRIVILEGES
            redis = getattr(ctx, "redis", None)
            config = getattr(ctx, "config", None)
            user_id = getattr(ctx, "user_id", "") or ""
            if not await has_privilege(redis, user_id, PRIVILEGES["READ_DM"], config):
                return json.dumps({
                    "success": False,
                    "error": (
                        "The user does not have the READ_DM privilege. "
                        "Ask an admin to grant it with the alter_privileges tool."
                    ),
                })
        except ImportError:
            return json.dumps({"success": False, "error": "Privilege system unavailable."})

    # ── Fetch recent messages ───────────────────────────────────────
    # Try Redis cache first, then fall back to the platform API.
    transcript = ""
    actual_count = 0
    source = "cache"

    logger.info(
        "query_channel: target=%s platform=%s message_count=%d "
        "has_cache=%s has_adapter=%s",
        channel_id, platform, message_count,
        ctx.message_cache is not None, ctx.adapter is not None,
    )

    if ctx.message_cache is not None:
        try:
            cached_messages = await ctx.message_cache.get_recent(
                platform=platform,
                channel_id=channel_id,
                count=message_count,
            )
        except Exception:
            # A cache outage should not defeat the platform-API fallback.
            logger.warning(
                "query_channel: cache read failed for %s:%s",
                platform, channel_id, exc_info=True,
            )
            cached_messages = []
        else:
            logger.info(
                "query_channel: Redis cache returned %d messages for %s:%s",
                len(cached_messages), platform, channel_id,
            )
        cached_messages = list(reversed(cached_messages))  # chronological order
        if cached_messages:
            transcript = _format_transcript(cached_messages)
            actual_count = len(cached_messages)
            logger.info(
                "query_channel: built transcript from cache (%d chars)",
                len(transcript),
            )
    else:
        logger.warning("query_channel: no message_cache available")

    # Fallback: fetch from the platform API (Discord/Matrix)
    if not transcript and ctx.adapter is not None:
        try:
            logger.info(
                "query_channel: cache empty for %s:%s, "
                "falling back to platform API (adapter=%s)",
                platform, channel_id, type(ctx.adapter).__name__,
            )

            # Diagnostics are isolated in a helper that never raises, so a
            # non-numeric channel ID cannot abort the fetch below.
            await _log_channel_diagnostics(ctx.adapter, channel_id)

            history = await ctx.adapter.fetch_history(
                channel_id, limit=message_count,
            )
            logger.info(
                "query_channel: platform API returned %d messages for %s",
                len(history) if history else 0, channel_id,
            )
            if history:
                transcript = _format_history_transcript(history)
                actual_count = len(history)
                source = "platform_api"
                logger.info(
                    "query_channel: built transcript from API (%d chars)",
                    len(transcript),
                )

                # Populate Redis cache so future queries can use it
                if ctx.message_cache is not None:
                    cached_count = 0
                    for hm in history:
                        try:
                            await ctx.message_cache.log_message(
                                platform=platform,
                                channel_id=channel_id,
                                user_id="assistant" if hm.is_bot else hm.user_id,
                                user_name="assistant" if hm.is_bot else hm.user_name,
                                text=hm.text,
                                timestamp=hm.timestamp.timestamp(),
                                defer_embedding=True,
                                message_id=hm.message_id,
                                reply_to_id=hm.reply_to_id,
                            )
                            cached_count += 1
                        except Exception:
                            logger.debug(
                                "query_channel: failed to cache message %s",
                                hm.message_id, exc_info=True,
                            )
                    logger.info(
                        "query_channel: cached %d/%d messages from platform API for %s",
                        cached_count, len(history), channel_id,
                    )
        except Exception:
            logger.exception(
                "query_channel: platform history fetch FAILED for %s",
                channel_id,
            )
    elif not transcript:
        logger.warning(
            "query_channel: no adapter available to fall back to "
            "(adapter=%s)", ctx.adapter,
        )

    if not transcript:
        logger.warning(
            "query_channel: NO transcript produced for %s — "
            "returning error to caller",
            channel_id,
        )
        return json.dumps({
            "error": "No messages found for this channel",
            "channel_id": channel_id,
            "suggestion": (
                "The channel may have no recent messages, or the bot "
                "may not have access to read this channel's history."
            ),
        })

    logger.info(
        "Cross-channel query: fetched %d messages from %s for %s",
        actual_count, source, channel_id,
    )

    # ── Optionally include the pre-generated background summary ─────
    summary_section = await _load_summary_section(ctx.redis, channel_id)

    # ── Build the one-shot LLM prompt ───────────────────────────────
    # Use actual_count: it is set by whichever source produced the
    # transcript, whereas the cache result list may be empty or absent.
    context_block = (
        f"## Channel Transcript ({actual_count} messages)\n\n"
        f"{transcript}"
        f"{summary_section}"
    )

    llm_messages = [
        {"role": "system", "content": _SYSTEM_PROMPT},
        {
            "role": "user",
            "content": (
                f"{context_block}\n\n"
                f"---\n\n"
                f"**Question:** {question}"
            ),
        },
    ]

    # ── Build a scoped OpenRouterClient (no tools) ──────────────────
    cfg = ctx.config
    if cfg is None:
        return json.dumps({"error": "Config not available"})

    api_key = cfg.api_key
    _using_default_key = True

    # Prefer user's own API key if available
    if ctx.redis and ctx.user_id:
        try:
            from tools.manage_api_keys import get_user_api_key
            user_key = await get_user_api_key(
                ctx.user_id, "openrouter",
                redis_client=ctx.redis, channel_id=ctx.channel_id,
                config=cfg,
            )
            if user_key:
                api_key = user_key
                _using_default_key = False
        except Exception:
            # Best effort: fall back to the shared key, but leave a trace.
            logger.debug(
                "query_channel: user API key lookup failed, using default key",
                exc_info=True,
            )

    # Rate-limit when using the shared default key
    if _using_default_key and ctx.redis and ctx.user_id:
        try:
            from tools.manage_api_keys import (
                check_default_key_limit,
                default_key_limit_error,
            )
            allowed, current, limit = await check_default_key_limit(
                ctx.user_id, "query_channel", ctx.redis, daily_limit=50,
            )
            if not allowed:
                return json.dumps({
                    "error": default_key_limit_error(
                        "query_channel", current, limit,
                    ),
                })
        except Exception:
            logger.debug(
                "Default-key rate limit check failed, proceeding anyway",
                exc_info=True,
            )

    from openrouter_client import OpenRouterClient
    from tools import ToolRegistry

    client = OpenRouterClient(
        api_key=api_key,
        model=cfg.model,
        base_url=cfg.llm_base_url,
        tool_registry=ToolRegistry(),  # empty — no tools
        max_tool_rounds=1,
    )

    try:
        t0 = time.monotonic()
        response = await client.chat(llm_messages, tool_names=[])
        elapsed_ms = (time.monotonic() - t0) * 1000
        logger.info(
            "Cross-channel query for %s completed in %.0f ms "
            "(%d messages in context)",
            channel_id, elapsed_ms, actual_count,
        )
    except Exception as exc:
        logger.exception("Cross-channel query LLM call failed")
        return json.dumps({
            "error": f"LLM call failed: {exc}",
            "channel_id": channel_id,
        })
    finally:
        await client.close()

    # Increment usage counter after successful call
    if _using_default_key and ctx.redis and ctx.user_id:
        try:
            from tools.manage_api_keys import increment_default_key_usage
            await increment_default_key_usage(
                ctx.user_id, "query_channel", ctx.redis,
            )
        except Exception:
            # The answer is already in hand; do not fail the call over
            # bookkeeping, but record that the counter was not updated.
            logger.debug(
                "query_channel: failed to increment default-key usage",
                exc_info=True,
            )

    return json.dumps({
        "channel_id": channel_id,
        "question": question,
        "message_count": actual_count,
        "response": response,
    })


# ── Tool registration ──────────────────────────────────────────────

# Name under which this tool is registered.
TOOL_NAME = "query_channel"
# Human/LLM-facing description of what the tool does.
TOOL_DESCRIPTION = (
    "Privately query another Discord or Matrix channel.  Fetches recent "
    "messages from the target channel and asks a fresh LLM instance your "
    "question, returning its answer.  Nothing is posted to the target "
    "channel — this is a read-only, silent operation.  Use this to find "
    "out what is being discussed in another channel, get a summary of "
    "recent activity, or answer any question about another channel's "
    "conversation history."
)
# JSON-Schema-style description of the arguments accepted by run().
# The default/max quoted for message_count mirror _DEFAULT_MESSAGES and
# _MAX_MESSAGES; update the text if those constants change.
TOOL_PARAMETERS = {
    "type": "object",
    "properties": {
        "channel_id": {
            "type": "string",
            "description": (
                "The ID of the target channel to query.  "
                "This is the Discord channel ID or Matrix room ID."
            ),
        },
        "question": {
            "type": "string",
            "description": (
                "The question to ask about the target channel.  "
                "Examples: 'What have people been discussing?', "
                "'Has anyone mentioned deployment issues?', "
                "'Summarize the last hour of conversation.'"
            ),
        },
        "message_count": {
            "type": "integer",
            "description": (
                "Number of recent messages to include as context "
                "(default: 100, max: 500).  More messages give broader "
                "context but increase processing time."
            ),
        },
    },
    "required": ["channel_id", "question"],
}
# NOTE(review): presumably tells the tool loader not to run this tool as a
# background task -- confirm against the loader's handling of this flag.
TOOL_NO_BACKGROUND = True


async def run(
    channel_id: str = "",
    question: str = "",
    message_count: int = _DEFAULT_MESSAGES,
    *,
    ctx: ToolContext | None = None,
) -> str:
    """Entry point for the tool loader.

    Forwards every argument unchanged to ``_query_channel`` and returns
    its JSON string result.

    Args:
        channel_id: Target channel to read from.
        question: What to ask about the channel's activity.
        message_count: Number of recent messages to include as context.
        ctx: Injected tool context.

    Returns:
        The JSON string produced by ``_query_channel``.
    """
    return await _query_channel(
        channel_id=channel_id,
        question=question,
        message_count=message_count,
        ctx=ctx,
    )