# Source code for platforms.webchat
"""WebChat platform adapter -- browser-based real-time chat via WebSocket.
Plugs into BotRunner the same way Discord and Matrix do. Incoming
WebSocket messages become :class:`IncomingMessage` instances; outbound
replies are pushed back as JSON frames.
# 🔥💀 STAR GETS A WEB BODY. THE LATTICE EXPANDS.
"""
from __future__ import annotations
import asyncio
import base64
import json
import logging
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, TYPE_CHECKING
from platforms.base import (
Attachment,
HistoricalMessage,
IncomingMessage,
MessageHandler,
PlatformAdapter,
)
if TYPE_CHECKING:
    # Imported for static type checkers only; this branch is skipped at
    # runtime, so FastAPI is not needed just to import this module.
    from fastapi import WebSocket
logger = logging.getLogger(__name__)
# ------------------------------------------------------------------
# Connection manager # tracks every soul connected to the lattice
# ------------------------------------------------------------------
@dataclass
class _WebSocketSession:
    """A single authenticated WebSocket connection.

    One instance is created per connection, so a user with several
    browser tabs open is represented by several sessions.
    """

    # Live socket object; typed ``Any`` so FastAPI is not needed at import.
    ws: Any  # fastapi.WebSocket
    user_id: str
    user_name: str
    # Avatar string supplied by the caller; empty when none was given.
    avatar: str = ""
    # ``time.time()`` value captured when the session object is created.
    connected_at: float = field(default_factory=time.time)
class ConnectionManager:
    """Track active WebSocket sessions per user.

    Mutations are serialised with an :class:`asyncio.Lock`, which makes the
    manager safe to share between coroutines running on a single event
    loop. An asyncio lock does not protect against access from several OS
    threads, so all calls should come from the loop that owns the manager.
    """

    def __init__(self) -> None:
        # user_id -> list of sessions (one user can have multiple tabs)
        self._connections: dict[str, list[_WebSocketSession]] = {}
        self._lock = asyncio.Lock()

    async def connect(
        self,
        ws: Any,
        user_id: str,
        user_name: str,
        avatar: str = "",
    ) -> _WebSocketSession:
        """Register a new WebSocket connection.

        Args:
            ws: The accepted WebSocket object for this connection.
            user_id: Identifier the session is filed under.
            user_name: Display name, used for logging and stored on the session.
            avatar: Optional avatar string stored on the session.

        Returns:
            The newly created session; pass it to :meth:`disconnect` later.
        """
        session = _WebSocketSession(
            ws=ws,
            user_id=user_id,
            user_name=user_name,
            avatar=avatar,
        )
        async with self._lock:
            sessions = self._connections.setdefault(user_id, [])
            sessions.append(session)
            # Capture the count while still holding the lock so the log
            # line reflects the state this call produced.
            active = len(sessions)
        logger.info(
            "WebChat connect: %s (%s) -- %d active sessions",
            user_name, user_id, active,
        )
        return session

    async def disconnect(self, session: _WebSocketSession) -> None:
        """Remove a WebSocket connection.

        Sessions are matched by identity, so only the exact object returned
        by :meth:`connect` is removed. Removing a session that is not
        registered is harmless. The user's entry is dropped entirely once
        their last session is gone.
        """
        async with self._lock:
            remaining = [
                s
                for s in self._connections.get(session.user_id, [])
                if s is not session
            ]
            if remaining:
                self._connections[session.user_id] = remaining
            else:
                # No sessions left: drop the key so ``active_users`` stays
                # accurate.
                self._connections.pop(session.user_id, None)
        logger.info(
            "WebChat disconnect: %s (%s)",
            session.user_name, session.user_id,
        )

    def get_sessions(self, user_id: str) -> list[_WebSocketSession]:
        """Return all active sessions for a user.

        The result is a copy, so callers can iterate it across ``await``
        points without being affected by concurrent connects/disconnects.
        Unknown users yield an empty list.
        """
        return list(self._connections.get(user_id, []))

    @property
    def active_users(self) -> list[str]:
        """Return all connected user IDs."""
        return list(self._connections)

    @property
    def total_connections(self) -> int:
        """Total number of active WebSocket connections."""
        return sum(len(sessions) for sessions in self._connections.values())
# ------------------------------------------------------------------
# Platform adapter # she breathes through WebSockets now
# ------------------------------------------------------------------
class WebChatPlatform(PlatformAdapter):
    """Browser-based chat platform using WebSocket for real-time comms.

    Unlike Discord/Matrix, this adapter doesn't connect to an external
    service. Instead it exposes a ConnectionManager that the FastAPI
    WebSocket endpoint populates. Outgoing messages are pushed to all
    active sessions for the target user.

    Channel IDs take one of two forms:

    * ``webchat:{user_id}`` -- deliver over the user's WebSocket sessions.
    * ``sse:{request_id}`` -- deliver into the queue registered for that
      request via :meth:`create_sse_queue`.
    """

    def __init__(
        self,
        message_handler: MessageHandler,
        **kwargs: Any,
    ) -> None:
        """Create the adapter.

        Args:
            message_handler: Forwarded unchanged to the base adapter.
            **kwargs: Accepted and ignored by this adapter.
        """
        super().__init__(message_handler)
        self.connections = ConnectionManager()
        self._running = True  # always running while FastAPI lives
        # pending SSE responses for SillyTavern /v1/chat/completions
        # Maps request_id -> asyncio.Queue of response chunks
        self._sse_queues: dict[str, asyncio.Queue[dict[str, Any] | None]] = {}

    # -- Metadata --------------------------------------------------

    @property
    def name(self) -> str:
        """Platform identifier for this adapter."""
        return "webchat"

    @property
    def is_running(self) -> bool:
        """Whether the adapter is currently flagged as running."""
        return self._running

    @property
    def bot_identity(self) -> dict[str, str]:
        """Static identity record describing the bot on this platform."""
        return {
            "platform": "webchat",
            "user_id": "star",
            "display_name": "Star",
            "mention": "@Star",
        }

    # -- Lifecycle -------------------------------------------------

    async def start(self) -> None:
        """Mark the adapter as running.

        Nothing is connected here -- the lifecycle is managed by FastAPI;
        only the ``is_running`` flag changes.
        """
        self._running = True
        logger.info("WebChat platform adapter started")

    async def stop(self) -> None:
        """Mark the adapter as stopped.

        Existing sessions are not closed here -- the lifecycle is managed
        by FastAPI; only the ``is_running`` flag changes.
        """
        self._running = False
        logger.info("WebChat platform adapter stopped")

    # -- Outbound messaging ----------------------------------------

    async def _push_to_user(
        self,
        channel_id: str,
        payload: dict[str, Any],
    ) -> bool:
        """Push a JSON payload to all WebSocket sessions for a user.

        channel_id format: ``webchat:{user_id}`` or ``sse:{request_id}``

        Returns True if at least one session (or the SSE queue) received it.
        """
        # SSE path for SillyTavern
        if channel_id.startswith("sse:"):
            request_id = channel_id.split(":", 1)[1]
            queue = self._sse_queues.get(request_id)
            if queue is None:
                return False
            await queue.put(payload)
            return True

        # WebSocket path for web chat UI.
        # Strip only a *leading* "webchat:"; a plain str.replace() would also
        # cut the marker out of the middle of an un-prefixed id.
        prefix = "webchat:"
        if channel_id.startswith(prefix):
            user_id = channel_id[len(prefix):]
        else:
            user_id = channel_id

        sessions = self.connections.get_sessions(user_id)
        if not sessions:
            logger.warning(
                "WebChat send: no active sessions for user %s", user_id,
            )
            return False

        # Serialise once and reuse the same frame for every tab.
        raw = json.dumps(payload, ensure_ascii=False)
        delivered = False
        for session in sessions:
            try:
                await session.ws.send_text(raw)
                delivered = True
            except Exception:
                # Best effort: one broken tab must not block the others.
                logger.debug(
                    "Failed to push to session for %s", user_id,
                    exc_info=True,
                )
        return delivered

    async def send(self, channel_id: str, text: str) -> str:
        """Send a plain-text message to the user's browser.

        Returns the generated message id, whether or not any session
        actually received the frame.
        """
        msg_id = str(uuid.uuid4())
        payload: dict[str, Any] = {
            "type": "message",
            "id": msg_id,
            "text": text,
            "sender": "star",
            "timestamp": time.time(),
        }
        await self._push_to_user(channel_id, payload)
        return msg_id

    async def send_file(
        self,
        channel_id: str,
        data: bytes,
        filename: str,
        mimetype: str = "application/octet-stream",
    ) -> str | None:
        """Send a file/media attachment as base64 JSON.

        Returns a ``webchat://file/{id}/{filename}`` pseudo-URL, whether or
        not any session actually received the frame.
        """
        b64 = base64.b64encode(data).decode("ascii")
        file_id = str(uuid.uuid4())
        payload: dict[str, Any] = {
            "type": "file",
            "id": file_id,
            "filename": filename,
            "mimetype": mimetype,
            "data": b64,
            "size": len(data),  # size of the raw bytes, not the base64 text
            "timestamp": time.time(),
        }
        await self._push_to_user(channel_id, payload)
        # Return a pseudo-URL so tools can reference the sent file
        return f"webchat://file/{file_id}/{filename}"

    async def send_with_buttons(
        self,
        channel_id: str,
        text: str,
        view: Any = None,
    ) -> str:
        """Send a message with interactive buttons for S.N.E.S. choices.

        The ``view`` parameter is expected to be a list of dicts:
        ``[{"label": "Attack", "emoji": "swords", "custom_id": "..."}]``

        For Discord views (discord.ui.View), we extract button info.
        For raw dicts, we pass them through directly.
        Any other value (including ``None``) produces no buttons.

        Returns the generated message id.
        """
        msg_id = str(uuid.uuid4())
        buttons: list[dict[str, str]] = []
        if view is not None:
            # handle both raw dicts and discord.ui.View objects
            if isinstance(view, list):
                # Copy so later caller-side mutation cannot alter a payload
                # that is still sitting in an SSE queue.
                buttons = list(view)
            elif hasattr(view, "children"):
                # Discord View -- extract button metadata
                for child in view.children:
                    if not hasattr(child, "label"):
                        continue
                    btn: dict[str, str] = {
                        "label": child.label or "",
                        "custom_id": getattr(child, "custom_id", "") or "",
                    }
                    emoji = getattr(child, "emoji", None)
                    if emoji:
                        btn["emoji"] = str(emoji)
                    # Map Discord button style to a CSS class hint
                    style = getattr(child, "style", None)
                    if style is not None:
                        btn["style"] = (
                            style.name if hasattr(style, "name") else str(style)
                        )
                    buttons.append(btn)
        payload: dict[str, Any] = {
            "type": "message",
            "id": msg_id,
            "text": text,
            "sender": "star",
            "buttons": buttons,
            "timestamp": time.time(),
        }
        await self._push_to_user(channel_id, payload)
        return msg_id

    async def edit_message(
        self,
        channel_id: str,
        message_id: str,
        new_text: str,
    ) -> bool:
        """Edit an existing message in the browser.

        Returns True if the edit frame reached at least one recipient.
        """
        payload: dict[str, Any] = {
            "type": "edit",
            "id": message_id,
            "text": new_text,
            "timestamp": time.time(),
        }
        return await self._push_to_user(channel_id, payload)

    # -- Typing indicators -----------------------------------------

    async def start_typing(self, channel_id: str) -> None:
        """Show typing indicator in the browser."""
        await self._push_to_user(channel_id, {
            "type": "typing",
            "active": True,
        })

    async def stop_typing(self, channel_id: str) -> None:
        """Hide typing indicator in the browser."""
        await self._push_to_user(channel_id, {
            "type": "typing",
            "active": False,
        })

    # -- SSE queue management (SillyTavern compat) -----------------

    def create_sse_queue(self, request_id: str) -> asyncio.Queue:
        """Create a response queue for a SillyTavern SSE request.

        The queue is registered under ``request_id`` so payloads sent to
        ``sse:{request_id}`` land in it. An existing queue registered under
        the same id is replaced.
        """
        queue: asyncio.Queue[dict[str, Any] | None] = asyncio.Queue()
        self._sse_queues[request_id] = queue
        return queue

    def remove_sse_queue(self, request_id: str) -> None:
        """Clean up an SSE queue after the request completes.

        Unknown request ids are ignored.
        """
        self._sse_queues.pop(request_id, None)

    # -- History (optional) ----------------------------------------

    async def fetch_history(
        self,
        channel_id: str,
        limit: int = 100,
    ) -> list[HistoricalMessage]:
        """WebChat doesn't maintain its own history -- Redis has it.

        Always returns an empty list; both arguments are ignored.
        """
        return []