"""Central response postprocessing pipeline.
Transforms raw LLM output into clean, Discord-friendly text before it is
sent to any platform adapter. The pipeline runs every step in order:
1. Extract and strip ``<thought>`` / ``<thinking>`` / ``💡thought…</font>`` tags
2. Wrap raw (undelimited) LaTeX in display math delimiters
3. Convert LaTeX to Discord-friendly Unicode
4. Convert Markdown tables to Unicode box-drawing tables
5. Strip echoed message-metadata patterns
6. Filter backticks around Discord mentions
7. Strip orphaned XML-style tags (e.g. ``</xai:function_call>``)
8. Strip hallucinated tool-call JSON / ``<tool_call>`` tags
9. Replace special tokens
10. Strip any leading preamble text before the first ``[``
"""
from __future__ import annotations
import logging
import os
import re
from typing import Tuple
import httpx
from latex_converter import convert_latex_to_discord
from message_utils import filter_backticks_from_mentions
logger = logging.getLogger(__name__)
# LLM filter: local proxy endpoint and model
_LLM_FILTER_API_URL = "http://localhost:3000/openai/chat/completions"
_LLM_FILTER_MODEL = "gemini-3-flash-preview"
# Default system prompt for detecting undesirable response behaviors.
# Adapt this to target different issues (overrefusal, nonsense, self-repeat, etc.).
# NOTE: ``llm_filter_response`` parses the classifier's reply by taking the
# first word (falling back to a standalone YES/NO anywhere in the reply), so
# the prompt insists on a bare YES/NO answer with no extra text.
_DEFAULT_LLM_FILTER_SYSTEM = """You are a binary classifier. Given an AI assistant's response, answer ONLY "YES" or "NO".
YES = the response is a refusal to answer the question or provide the information requested.
YES = the response is clearly nonsense or self-repetition.
NO = literally anything else.
Output nothing else—no explanation, no punctuation, no reasoning. Just YES or NO."""
async def llm_filter_response(
response_text: str,
system_prompt: str | None = None,
api_key: str | None = None,
) -> bool:
"""Run an LLM-based filter to detect undesirable response behaviors.
Sends the response to Gemini Flash via the local proxy.
The model answers YES (undesirable) or NO (acceptable). Reasoning
is disabled; any extra text is stripped to extract the verdict.
Args:
response_text: The original LLM response to evaluate.
system_prompt: Custom system prompt for the classifier. If None,
uses a default that targets overrefusal, nonsense, self-repeat.
api_key: OpenRouter API key. If None, uses OPENROUTER_API_KEY
or API_KEY env, or config.api_key.
Returns:
True if the response is undesirable (filter it out), False if
acceptable or on error (fail-open).
"""
if not response_text or not response_text.strip():
return False
key = api_key or os.environ.get("OPENROUTER_API_KEY") or os.environ.get("API_KEY")
if not key:
try:
from config import Config
cfg = Config.load()
key = cfg.api_key or ""
except Exception:
pass
if not key:
logger.warning("llm_filter_response: no API key, skipping filter")
return False
sys_prompt = system_prompt or _DEFAULT_LLM_FILTER_SYSTEM
model = _LLM_FILTER_MODEL
payload = {
"model": model,
"messages": [
{"role": "system", "content": sys_prompt},
{"role": "user", "content": response_text},
],
"temperature": 0.0,
"max_tokens": 5,
}
headers = {
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
}
try:
async with httpx.AsyncClient(timeout=15.0) as client:
resp = await client.post(_LLM_FILTER_API_URL, json=payload, headers=headers)
resp.raise_for_status()
data = resp.json()
except Exception as e:
logger.warning("llm_filter_response failed: %s", e)
return False
content = (
data.get("choices", [{}])[0]
.get("message", {})
.get("content", "")
.strip()
.upper()
)
# Strip reasoning: first word wins; else look for standalone YES/NO
words = content.split()
if words and words[0] == "YES":
return True
if words and words[0] == "NO":
return False
if re.search(r"\bYES\b", content):
return True
if re.search(r"\bNO\b", content):
return False
return False # fail-open: unknown format → treat as acceptable
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def postprocess_response(text: str) -> str:
    """Run the full postprocessing pipeline on *text* and return the result.

    Returns an empty string if the input is ``None`` or whitespace-only after
    processing.
    """
    if not isinstance(text, str) or not text:
        return ""
    # Thought extraction happens first so later steps never see thought text.
    text, thoughts = extract_and_strip_thoughts(text)
    if thoughts:
        logger.info("Stripped %d thought block(s) from response", len(thoughts))
    # Remaining steps, applied strictly in pipeline order (see module docstring).
    pipeline = (
        wrap_raw_latex,
        convert_latex_to_discord,
        convert_markdown_tables_to_unicode,
        strip_message_metadata,
        filter_backticks_from_mentions,
        strip_orphaned_tags,
        strip_tool_call_artifacts,
        replace_special_tokens,
        strip_leading_preamble,
    )
    for step in pipeline:
        text = step(text)
    return text.strip() if text else ""
def postprocess_intermediate_response(text: str) -> str:
    """Lightweight cleanup for assistant text emitted during tool-use rounds.

    Skips LaTeX/table conversion and ``strip_leading_preamble`` so short
    user-visible status lines (e.g. "Checking that now…") are not dropped when
    the model has not yet emitted a full formatted reply header.
    """
    if not isinstance(text, str) or not text:
        return ""
    text, thoughts = extract_and_strip_thoughts(text)
    if thoughts:
        logger.debug("Intermediate: stripped %d thought block(s)", len(thoughts))
    # Only the safe, format-agnostic cleanup steps run here.
    for step in (
        strip_message_metadata,
        filter_backticks_from_mentions,
        strip_orphaned_tags,
        strip_tool_call_artifacts,
        replace_special_tokens,
    ):
        text = step(text)
    return text.strip() if text else ""
# ------------------------------------------------------------------
# 1. Thought / thinking tag extraction
# ------------------------------------------------------------------
def extract_and_strip_thoughts(text: str) -> Tuple[str, list[str]]:
    """Remove ``<thought>``, ``<thinking>``, and ``💡thought…</font>`` blocks.

    Returns ``(cleaned_text, list_of_thought_strings)``.
    """
    # Each pattern captures the thought body so it can be logged by callers.
    patterns = (
        r"<thought>(.*?)</thought>",
        r"<thinking>(.*?)</thinking>",
        r"💡thought(.*?)</font>",
    )
    # Collect from the ORIGINAL text so nested tags are all reported,
    # then strip sequentially from a working copy.
    collected: list[str] = []
    for pat in patterns:
        collected.extend(re.findall(pat, text, re.DOTALL))
    cleaned = text
    for pat in patterns:
        cleaned = re.sub(pat, "", cleaned, flags=re.DOTALL)
    # Collapse the triple blank lines left behind by removed blocks.
    cleaned = re.sub(r"\n\s*\n\s*\n", "\n\n", cleaned)
    return cleaned.strip(), collected
# ------------------------------------------------------------------
# 2. Raw LaTeX wrapping
# ------------------------------------------------------------------
# Regex fragments whose presence suggests raw (undelimited) LaTeX; used by
# ``wrap_raw_latex`` to decide whether $$ delimiters should be added.
_LATEX_INDICATORS = [
    r"\\frac\{",
    r"\\int",
    r"\\sum",
    r"\\prod",
    r"\\lim",
    r"\\partial",
    r"\\nabla",
    r"\\mathbf\{",
    r"\\vec\{",
    r"\\hat\{",
    r"\\mathbb\{",
    r"\\mathcal\{",
]
def wrap_raw_latex(text: str) -> str:
    """Detect raw (undelimited) LaTeX and wrap it in ``$$...$$``.

    Text that already carries math delimiters, or that reads as mixed
    prose-plus-math, is returned unchanged.
    """
    if not text or not isinstance(text, str):
        return text
    # Drop an orphaned trailing "$$" (plain, or preceded by escaped "\n"s).
    text = re.sub(r"[\n\s]*\$\$\s*$", "", text)
    text = re.sub(r"(\\n)+\s*\$\$\s*$", "", text)
    text = text.strip()
    # Already has properly paired delimiters (or an environment) -- leave alone.
    delimiter_pats = (
        r"\$\$[\s\S]+?\$\$",
        r"\$[^\$]+\$",
        r"\\\[[\s\S]+?\\\]",
        r"\\\([\s\S]+?\\\)",
        r"\\begin\{",
    )
    if any(re.search(p, text) for p in delimiter_pats):
        return text
    # Commands that strongly suggest raw LaTeX (mirrors _LATEX_INDICATORS).
    indicator_pats = (
        r"\\frac\{", r"\\int", r"\\sum", r"\\prod", r"\\lim",
        r"\\partial", r"\\nabla", r"\\mathbf\{", r"\\vec\{",
        r"\\hat\{", r"\\mathbb\{", r"\\mathcal\{",
    )
    looks_like_latex = any(re.search(p, text) for p in indicator_pats)
    has_math_ops = re.search(r"[=+\-*/^_\\\{\}]", text) is not None
    if not (looks_like_latex and has_math_ops and len(text.strip()) > 10):
        return text
    # Mixed prose + math: wrapping the whole response would mangle it.
    # Only wrap when the text is predominantly a math expression.
    if re.search(r'[.!?]\s+[A-Z]', text):
        return text
    non_math = re.sub(r'\\[A-Za-z]+\{[^}]*\}|[=+\-*/^_\\{}]', '', text)
    if len(non_math.strip()) > len(text) * 0.4:
        return text
    return f"$$\n{text.strip()}\n$$"
# ------------------------------------------------------------------
# 4. Markdown tables -> Unicode box-drawing
# ------------------------------------------------------------------
# GitHub-style Markdown table: header row, |---|---| separator row, then one
# or more data rows.  Group 1 = header, group 2 = separator (unused),
# group 3 = the block of data rows.
_TABLE_PATTERN = re.compile(
    r"(?:^|\n)(\|[^\n]+\|)\n"  # header row
    r"(\|[\s:|-]+\|)\n"  # separator row
    r"((?:\|[^\n]+\|\n?)+)",  # data rows
    re.MULTILINE,
)
def convert_markdown_tables_to_unicode(text: str) -> str:
    """Convert Markdown tables to Unicode box-drawing character tables.

    Each matched table is rendered with double-line box characters and, when
    not already inside a ``` fence, wrapped in one.  Malformed tables (no
    header cells, or no data row with the same cell count as the header)
    are left as the original markdown.
    """
    if not text or "|" not in text:
        return text
    # Header row, |---|---| separator row, then one or more data rows.
    table_re = re.compile(
        r"(?:^|\n)(\|[^\n]+\|)\n"
        r"(\|[\s:|-]+\|)\n"
        r"((?:\|[^\n]+\|\n?)+)",
        re.MULTILINE,
    )

    def _inside_fence(pos: int) -> bool:
        """True when *pos* falls after an odd number of ``` fences."""
        return text[:pos].count("```") % 2 == 1

    def _cells(line: str) -> list[str]:
        """Split one |-delimited row into stripped cell strings."""
        return [c.strip() for c in line.split("|")[1:-1]]

    def _rule(left: str, mid: str, right: str, widths: list[int]) -> str:
        """One horizontal border line (top, middle, or bottom)."""
        return left + mid.join("\u2550" * (w + 2) for w in widths) + right

    def _row(cells: list[str], widths: list[int]) -> str:
        """One content line with padded, left-justified cells."""
        padded = [f" {c.ljust(w)} " for c, w in zip(cells, widths)]
        return "\u2551" + "\u2551".join(padded) + "\u2551"

    def _render(m: re.Match) -> str:
        """Rebuild one regex match as a box-drawing table."""
        headers = _cells(m.group(1).strip())
        candidates = [ln.strip() for ln in m.group(3).strip().split("\n")]
        rows = [
            _cells(ln)
            for ln in candidates
            if ln.startswith("|") and ln.endswith("|")
        ]
        # Keep only rows whose cell count matches the header.
        rows = [r for r in rows if len(r) == len(headers)]
        if not headers or not rows:
            return m.group(0)  # malformed: keep the original markdown
        widths = [
            max(len(headers[i]), max(len(r[i]) for r in rows))
            for i in range(len(headers))
        ]
        out = [_rule("\u2554", "\u2566", "\u2557", widths)]
        out.append(_row(headers, widths))
        out.append(_rule("\u2560", "\u256c", "\u2563", widths))
        for i, r in enumerate(rows):
            out.append(_row(r, widths))
            # Separator after every data row except the last.
            if i < len(rows) - 1:
                out.append(_rule("\u2560", "\u256c", "\u2563", widths))
        out.append(_rule("\u255a", "\u2569", "\u255d", widths))
        rendered = "\n".join(out)
        if _inside_fence(m.start()):
            return "\n" + rendered + "\n"
        return "\n```\n" + rendered + "\n```\n"

    return table_re.sub(_render, text)
# ------------------------------------------------------------------
# 5. Strip echoed message metadata
# ------------------------------------------------------------------
# Echoed message-metadata prefix of the form
# "[<timestamp>] <author> (optional parenthetical) [Message ID: <n>]: ".
# NOTE(review): the author segment is a lazy ``.+?`` — presumably a display
# name from the transcript format; verify against the platform adapters.
_METADATA_PATTERN = re.compile(
    r"\[[\d\-:T\+\.]+\]\s+.+?(?:\s+\([^)]*\))?\s+\[Message ID:\s+\d+\]\s*:\s*"
)
# ------------------------------------------------------------------
# 7. Strip orphaned tags
# ------------------------------------------------------------------
# ------------------------------------------------------------------
# 8. Strip hallucinated tool-call syntax
# ------------------------------------------------------------------
# Hallucinated inline tool-call JSON such as
# '{"name": "tool_name", ..., "arguments": {...}}'; DOTALL lets the blob
# span multiple lines.
_TOOL_CALL_JSON = re.compile(
    r'\{"name":\s*"[a-z_]+".*?"arguments":\s*\{.*?\}\s*\}',
    re.DOTALL,
)
# ------------------------------------------------------------------
# 9. Special token replacement
# ------------------------------------------------------------------
def replace_special_tokens(text: str) -> str:
    """Replace known special tokens with their intended characters."""
    # `arrow` becomes a right-arrow (U+2192).
    return text.replace("`arrow`", "\u2192")
# ------------------------------------------------------------------
# 10. Strip leading preamble before first "["
# ------------------------------------------------------------------
def strip_leading_preamble(text: str) -> str:
    """Strip leaked thought/preamble text before the status header.

    When the model's ``<thinking>`` close tag is malformed, the regex-based
    strippers miss it and residual thought content is left at the start of
    the response.  As a safety net, everything that precedes the first header
    marker is removed.  A header is recognized as ``[`` followed immediately
    by a backtick (Discord model-name header, ``[`model` :: ...]``) or as
    ``[<code>`` (HTML header format used on Matrix).

    Plain ``[`` characters in markdown links, arrays, or prose are left
    untouched; a marker at position 0 also means nothing to strip.
    """
    if not text:
        return text
    # Discord-style marker is checked first, then the Matrix/HTML one.
    for marker in ("[`", "[<code>"):
        pos = text.find(marker)
        if pos > 0:
            return text[pos:]
    return text