# Source code for platforms.webchat

"""WebChat platform adapter -- browser-based real-time chat via WebSocket.

Plugs into BotRunner the same way Discord and Matrix do.  Incoming
WebSocket messages become :class:`IncomingMessage` instances; outbound
replies are pushed back as JSON frames.

# 🔥💀 STAR GETS A WEB BODY. THE LATTICE EXPANDS.
"""

from __future__ import annotations

import asyncio
import base64
import json
import logging
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, TYPE_CHECKING

from platforms.base import (
    Attachment,
    HistoricalMessage,
    IncomingMessage,
    MessageHandler,
    PlatformAdapter,
)

if TYPE_CHECKING:
    from fastapi import WebSocket

# Module-level logger used by ConnectionManager and WebChatPlatform below.
logger = logging.getLogger(__name__)


# ------------------------------------------------------------------
# Connection manager  # tracks every soul connected to the lattice
# ------------------------------------------------------------------

@dataclass
class _WebSocketSession:
    """A single authenticated WebSocket connection."""

    ws: Any  # fastapi.WebSocket
    user_id: str
    user_name: str
    avatar: str = ""
    connected_at: float = field(default_factory=time.time)


class ConnectionManager:
    """Registry of live WebSocket sessions, grouped by user ID.

    One user may hold several sessions at once (for example, several browser
    tabs).  ``connect`` and ``disconnect`` mutate the registry while holding
    an :class:`asyncio.Lock`, which serialises coroutines on a single event
    loop.  asyncio locks are not thread-safe, so the manager should only be
    used from that loop.  The read accessors take no lock and hand back
    freshly built lists.
    """

    def __init__(self) -> None:
        # user_id -> every session currently open for that user
        self._connections: dict[str, list[_WebSocketSession]] = {}
        self._lock = asyncio.Lock()

    async def connect(
        self,
        ws: Any,
        user_id: str,
        user_name: str,
        avatar: str = "",
    ) -> _WebSocketSession:
        """Wrap *ws* in a new session, file it under *user_id*, and return it."""
        new_session = _WebSocketSession(
            ws=ws,
            user_id=user_id,
            user_name=user_name,
            avatar=avatar,
        )
        async with self._lock:
            bucket = self._connections.setdefault(user_id, [])
            bucket.append(new_session)
            logger.info(
                "WebChat connect: %s (%s) -- %d active sessions",
                user_name,
                user_id,
                len(bucket),
            )
        return new_session

    async def disconnect(self, session: _WebSocketSession) -> None:
        """Forget *session*; drop the user's entry once no sessions remain."""
        uid = session.user_id
        async with self._lock:
            # Identity comparison: only this exact session object is removed.
            remaining = [
                other
                for other in self._connections.get(uid, [])
                if other is not session
            ]
            if remaining:
                self._connections[uid] = remaining
            else:
                self._connections.pop(uid, None)
        logger.info(
            "WebChat disconnect: %s (%s)",
            session.user_name,
            session.user_id,
        )

    def get_sessions(self, user_id: str) -> list[_WebSocketSession]:
        """Return a new list of the sessions open for *user_id* (maybe empty)."""
        return [*self._connections.get(user_id, ())]

    @property
    def active_users(self) -> list[str]:
        """IDs of every user with at least one open session."""
        return list(self._connections)

    @property
    def total_connections(self) -> int:
        """Number of open sessions summed across all users."""
        return sum(map(len, self._connections.values()))
# ------------------------------------------------------------------
# Platform adapter  # she breathes through WebSockets now
# ------------------------------------------------------------------
class WebChatPlatform(PlatformAdapter):
    """Browser-based chat platform using WebSocket for real-time comms.

    Unlike Discord/Matrix, this adapter doesn't connect to an external
    service.  Instead it exposes a :class:`ConnectionManager`
    (``self.connections``) that the FastAPI WebSocket endpoint populates.

    Outgoing traffic is routed by ``channel_id``:

    * ``sse:{request_id}`` -- the payload dict is put on the queue that was
      registered with :meth:`create_sse_queue` (SillyTavern compat).
    * ``webchat:{user_id}`` -- the payload is JSON-encoded and sent as a
      text frame to every active WebSocket session of that user.  A
      ``channel_id`` without the ``webchat:`` prefix is treated as a bare
      user ID.
    """

    # Channel-ID prefixes understood by :meth:`_push_to_user`.
    _SSE_PREFIX = "sse:"
    _WEBCHAT_PREFIX = "webchat:"

    def __init__(
        self,
        message_handler: MessageHandler,
        **kwargs: Any,
    ) -> None:
        """Create the adapter.

        Args:
            message_handler: Passed straight to ``PlatformAdapter.__init__``.
            **kwargs: Accepted but not used by this adapter.
        """
        super().__init__(message_handler)
        self.connections = ConnectionManager()
        self._running = True  # always running while FastAPI lives
        # Pending SSE responses for SillyTavern /v1/chat/completions.
        # Maps request_id -> asyncio.Queue of response chunks.  The value
        # type admits ``None`` (presumably an end-of-stream marker -- TODO
        # confirm against the SSE endpoint that drains the queue).
        self._sse_queues: dict[str, asyncio.Queue[dict[str, Any] | None]] = {}

    # -- Metadata --------------------------------------------------

    @property
    def name(self) -> str:
        """Platform identifier, always ``"webchat"``."""
        return "webchat"

    @property
    def is_running(self) -> bool:
        """Running flag set at construction and by :meth:`start`/:meth:`stop`."""
        return self._running

    @property
    def bot_identity(self) -> dict[str, str]:
        """Fixed identity the bot presents on this platform (a new dict per call)."""
        return {
            "platform": "webchat",
            "user_id": "star",
            "display_name": "Star",
            "mention": "@Star",
        }

    # -- Lifecycle -------------------------------------------------

    async def start(self) -> None:
        """Mark the adapter as running.

        Nothing is connected here -- the WebSocket lifecycle is managed by
        FastAPI; this only sets the flag reported by :attr:`is_running`.
        """
        self._running = True
        logger.info("WebChat platform adapter started")

    async def stop(self) -> None:
        """Mark the adapter as stopped.

        Open WebSocket sessions are not closed here; this only clears the
        flag reported by :attr:`is_running`.
        """
        self._running = False
        logger.info("WebChat platform adapter stopped")

    # -- Outbound messaging ----------------------------------------

    async def _push_to_user(
        self,
        channel_id: str,
        payload: dict[str, Any],
    ) -> bool:
        """Deliver *payload* to the destination named by *channel_id*.

        ``sse:{request_id}`` puts the dict on the matching SSE queue.
        ``webchat:{user_id}`` (or a bare user ID) sends it as one JSON text
        frame to each of the user's WebSocket sessions; a session whose send
        raises is logged at debug level and skipped.

        Returns:
            True if the SSE queue existed, or if at least one WebSocket
            session accepted the frame; False otherwise.
        """
        # SSE path for SillyTavern
        if channel_id.startswith(self._SSE_PREFIX):
            request_id = channel_id[len(self._SSE_PREFIX):]
            queue = self._sse_queues.get(request_id)
            if queue is None:
                logger.debug(
                    "WebChat send: no SSE queue for request %s",
                    request_id,
                )
                return False
            await queue.put(payload)
            return True

        # WebSocket path for web chat UI.  Only a *leading* prefix is
        # stripped, so a bare user ID that merely contains "webchat:"
        # somewhere inside it is used unchanged.
        if channel_id.startswith(self._WEBCHAT_PREFIX):
            user_id = channel_id[len(self._WEBCHAT_PREFIX):]
        else:
            user_id = channel_id

        sessions = self.connections.get_sessions(user_id)
        if not sessions:
            logger.warning(
                "WebChat send: no active sessions for user %s",
                user_id,
            )
            return False

        # Encode once; every session receives the identical frame.
        raw = json.dumps(payload, ensure_ascii=False)
        delivered = False
        for session in sessions:
            try:
                await session.ws.send_text(raw)
            except Exception:
                # Best-effort fan-out: one broken tab must not stop the
                # remaining sessions from receiving the message.
                logger.debug(
                    "Failed to push to session for %s",
                    user_id,
                    exc_info=True,
                )
            else:
                delivered = True
        return delivered

    async def send(self, channel_id: str, text: str) -> str:
        """Send a plain-text message to the user's browser.

        Returns:
            The generated message ID (returned even if nothing was delivered).
        """
        msg_id = str(uuid.uuid4())
        payload: dict[str, Any] = {
            "type": "message",
            "id": msg_id,
            "text": text,
            "sender": "star",
            "timestamp": time.time(),
        }
        await self._push_to_user(channel_id, payload)
        return msg_id

    async def send_file(
        self,
        channel_id: str,
        data: bytes,
        filename: str,
        mimetype: str = "application/octet-stream",
    ) -> str | None:
        """Send a file/media attachment as base64 JSON.

        Returns:
            A pseudo-URL ``webchat://file/{file_id}/{filename}`` so tools can
            reference the sent file.  It is returned regardless of whether
            any session received the payload.
        """
        b64 = base64.b64encode(data).decode("ascii")
        file_id = str(uuid.uuid4())
        payload: dict[str, Any] = {
            "type": "file",
            "id": file_id,
            "filename": filename,
            "mimetype": mimetype,
            "data": b64,
            "size": len(data),
            "timestamp": time.time(),
        }
        await self._push_to_user(channel_id, payload)
        return f"webchat://file/{file_id}/{filename}"

    async def send_with_buttons(
        self,
        channel_id: str,
        text: str,
        view: Any = None,
    ) -> str:
        """Send a message with interactive buttons for S.N.E.S. choices.

        ``view`` may be:

        * ``None`` -- no buttons;
        * a list (or tuple) of dicts such as
          ``[{"label": "Attack", "emoji": "swords", "custom_id": "..."}]``,
          passed through as-is;
        * an object with a ``children`` attribute (e.g. ``discord.ui.View``),
          from which button metadata is extracted.

        Returns:
            The generated message ID.
        """
        msg_id = str(uuid.uuid4())
        payload: dict[str, Any] = {
            "type": "message",
            "id": msg_id,
            "text": text,
            "sender": "star",
            "buttons": self._buttons_from_view(view),
            "timestamp": time.time(),
        }
        await self._push_to_user(channel_id, payload)
        return msg_id

    @staticmethod
    def _buttons_from_view(view: Any) -> list[dict[str, str]]:
        """Normalise *view* into the list of button dicts sent to the browser."""
        if view is None:
            return []
        if isinstance(view, (list, tuple)):
            # Raw button dicts: copy so the payload never aliases caller data.
            return list(view)
        if not hasattr(view, "children"):
            return []

        # Discord View -- extract button metadata from each labelled child.
        buttons: list[dict[str, str]] = []
        for child in view.children:
            if not hasattr(child, "label"):
                continue
            btn: dict[str, str] = {
                "label": child.label or "",
                "custom_id": getattr(child, "custom_id", "") or "",
            }
            emoji = getattr(child, "emoji", None)
            if emoji:
                btn["emoji"] = str(emoji)
            # Map Discord button style to a CSS class hint
            style = getattr(child, "style", None)
            if style is not None:
                btn["style"] = style.name if hasattr(style, "name") else str(style)
            buttons.append(btn)
        return buttons

    async def edit_message(
        self,
        channel_id: str,
        message_id: str,
        new_text: str,
    ) -> bool:
        """Ask the browser to replace the text of message *message_id*.

        Returns:
            The delivery result of :meth:`_push_to_user`.
        """
        payload: dict[str, Any] = {
            "type": "edit",
            "id": message_id,
            "text": new_text,
            "timestamp": time.time(),
        }
        return await self._push_to_user(channel_id, payload)

    # -- Typing indicators -----------------------------------------

    async def start_typing(self, channel_id: str) -> None:
        """Show typing indicator in the browser."""
        await self._push_to_user(channel_id, {
            "type": "typing",
            "active": True,
        })

    async def stop_typing(self, channel_id: str) -> None:
        """Hide typing indicator in the browser."""
        await self._push_to_user(channel_id, {
            "type": "typing",
            "active": False,
        })

    # -- SSE queue management (SillyTavern compat) -----------------

    def create_sse_queue(self, request_id: str) -> asyncio.Queue:
        """Create and register a response queue for a SillyTavern SSE request.

        An existing queue registered under the same *request_id* is replaced.
        """
        queue: asyncio.Queue[dict[str, Any] | None] = asyncio.Queue()
        self._sse_queues[request_id] = queue
        return queue

    def remove_sse_queue(self, request_id: str) -> None:
        """Unregister the SSE queue for *request_id*; unknown IDs are ignored."""
        self._sse_queues.pop(request_id, None)

    # -- History (optional) ----------------------------------------

    async def fetch_history(
        self,
        channel_id: str,
        limit: int = 100,
    ) -> list[HistoricalMessage]:
        """WebChat doesn't maintain its own history -- Redis has it.

        Always returns an empty list.
        """
        return []