"""Multimodal input sanitization, image compression, and model target helpers."""
from __future__ import annotations
import base64
import copy
import logging
from typing import Any
from platforms.media_common import (
CLAUDE_MAX_IMAGE_BYTES,
detect_image_mimetype_from_bytes,
shrink_image_under_max_bytes,
)
logger = logging.getLogger(__name__)

# Friendly short names accepted by the local proxy, mapped to the canonical
# provider model id. resolve_model_alias() consults this map before every API
# call so token caps, Gemini detection, and the response header all see the
# canonical name.
_MODEL_ALIASES: dict[str, str] = {
    "gemini-3-pro-preview": "gemini-3.1-pro-preview",
}

# Model ids (compared after alias resolution) whose routed endpoint rejects
# image / video / audio input, so those parts are stripped before the first
# request. Independently, any non-Gemini id ending in ":free" is treated as
# text-only as well (free tiers typically omit vision/audio/video), which makes
# the ":free" entry below redundant but harmless.
_MODEL_IDS_TEXT_ONLY_BEFORE_REQUEST: frozenset[str] = frozenset(
    (
        # Arcee
        "arcee-ai/trinity-large-thinking",
        # DeepSeek
        "deepseek/deepseek-v4-flash",
        "deepseek/deepseek-v4-flash:free",
        "deepseek/deepseek-v4-pro",
        # Moonshot
        "moonshotai/kimi-k2.6",
        # xAI
        "x-ai/grok-4.3",
        # Xiaomi
        "xiaomi/mimo-v2-pro",
        "xiaomi/mimo-v2.5",
        "xiaomi/mimo-v2.5-pro",
    )
)

# Marker put in front of the text of an assistant turn that gets re-roled to
# "user", so the model can tell the content was its own earlier output.
_RELABELED_ASSISTANT_PREFIX = "[My previous response:]\n"

# Text part substituted for a Claude-bound image that cannot be brought under
# the size limit. NOTE(review): the wording hard-codes "4 MB"; confirm it still
# matches CLAUDE_MAX_IMAGE_BYTES.
_CLAUDE_OVERSIZED_IMAGE_PLACEHOLDER = (
    "[Image omitted: larger than 4 MB and could not be compressed "
    "enough for Claude-compatible models.]"
)
def _strip_audio_parts(messages: list[dict[str, Any]]) -> int:
    """Drop audio content parts from every message's list ``content``, in place.

    A part counts as audio when it is a dict whose ``type`` (stringified and
    lowercased) contains ``"audio"``. That covers ``input_audio`` and any other
    audio-flavoured type name. Used as the fallback retry after an endpoint
    answers 404 "No endpoints found that support input audio".

    Messages whose ``content`` is not a list are skipped. ``content`` is only
    reassigned (to a new list) for messages that actually lost a part.

    Returns the total number of parts dropped across all messages.
    """

    def is_audio(part: Any) -> bool:
        return isinstance(part, dict) and "audio" in str(part.get("type", "")).lower()

    total_dropped = 0
    for message in messages:
        parts = message.get("content")
        if not isinstance(parts, list):
            continue
        kept = [part for part in parts if not is_audio(part)]
        dropped = len(parts) - len(kept)
        if dropped:
            message["content"] = kept
            total_dropped += dropped
    return total_dropped
def _strip_video_url_parts(messages: list[dict[str, Any]]) -> int:
    """Delete every ``video_url`` content part from *messages*, in place.

    Applied before requests to models that do not understand the ``video_url``
    multimodal type (non-Gemini models). It is also applied to Claude-family
    models, which additionally go through
    :func:`_strip_video_and_audio_file_parts`.

    Only messages whose ``content`` is a list are inspected, and ``content`` is
    replaced by a new list only when at least one part was deleted.

    Returns how many parts were deleted in total.
    """
    deleted_total = 0
    for message in messages:
        content = message.get("content")
        if not isinstance(content, list):
            continue
        survivors: list[Any] = []
        for part in content:
            if isinstance(part, dict) and part.get("type") == "video_url":
                continue
            survivors.append(part)
        deleted_here = len(content) - len(survivors)
        if deleted_here > 0:
            deleted_total += deleted_here
            message["content"] = survivors
    return deleted_total
def _data_uri_mediatype_lower(file_data: str) -> str | None:
"""Return the MIME type segment of a ``data:`` URI, lowercased, or ``None``."""
if not isinstance(file_data, str) or not file_data.lower().startswith("data:"):
return None
rest = file_data[5:]
semi = rest.find(";")
comma = rest.find(",")
end = len(rest)
if comma >= 0 and semi >= 0:
end = min(semi, comma)
elif comma >= 0:
end = comma
elif semi >= 0:
end = semi
mediatype = rest[:end].strip().lower()
return mediatype or None
def _strip_video_and_audio_file_parts(messages: list[dict[str, Any]]) -> int:
    """Drop ``file`` parts that inline video or audio as a ``data:`` URI.

    A part is dropped when it is a dict with ``type == "file"`` whose ``file``
    entry is a dict and whose ``file_data`` parses (via
    :func:`_data_uri_mediatype_lower`) to a ``video/*`` or ``audio/*`` media
    type. Claude (and some proxies) reject or mishandle such parts.

    Messages whose ``content`` is not a list are skipped. ``content`` is only
    reassigned when something was dropped. Mutates *messages* in place.
    Returns the number of parts dropped.
    """

    def is_av_file(part: Any) -> bool:
        if not isinstance(part, dict) or part.get("type") != "file":
            return False
        file_info = part.get("file")
        if not isinstance(file_info, dict):
            return False
        media = _data_uri_mediatype_lower(file_info.get("file_data", ""))
        return bool(media) and media.startswith(("video/", "audio/"))

    dropped_total = 0
    for message in messages:
        content = message.get("content")
        if not isinstance(content, list):
            continue
        kept = [part for part in content if not is_av_file(part)]
        if len(kept) < len(content):
            dropped_total += len(content) - len(kept)
            message["content"] = kept
    return dropped_total
def _strip_image_url_parts(messages: list[dict[str, Any]]) -> int:
    """Remove every ``image_url`` content part from *messages*, in place.

    This is the fallback retry when an endpoint responds HTTP 404 with a body
    such as "No endpoints found that support image input".

    Messages whose ``content`` is not a list are ignored. ``content`` is swapped
    for a new list only when at least one part was removed.

    Returns the number of parts removed.
    """
    removed_total = 0
    for message in messages:
        content = message.get("content")
        if not isinstance(content, list):
            continue
        survivors = [
            part
            for part in content
            if not isinstance(part, dict) or part.get("type") != "image_url"
        ]
        if len(survivors) != len(content):
            removed_total += len(content) - len(survivors)
            message["content"] = survivors
    return removed_total
def _part_is_non_text(part: Any) -> bool:
    """Tell whether *part* is a multimodal (non-``text``) content part.

    True for any dict whose ``type`` is present and not ``"text"``. That covers
    ``image_url``, ``file`` (e.g. inline images/PDFs from image-generation
    tools and embeds), ``video_url``, ``input_audio``, and any unknown type.
    Providers reject such parts inside an assistant/model turn. Non-dict
    values and dicts with a missing/``None`` type are not considered media.
    """
    return isinstance(part, dict) and part.get("type") not in (None, "text")
def _relabel_assistant_image_turns_as_user(
    messages: list[dict[str, Any]],
) -> int:
    """Move media-bearing ``assistant`` turns to the ``user`` role.

    Providers refuse media inside an assistant/model turn. Anthropic rejects
    ``image`` blocks, and Gemini returns ``INVALID_ARGUMENT`` for inline media
    in a ``model`` turn. Media the bot posted itself (notably images from
    image-generation tools inside Discord embeds, which may arrive as ``file``
    or ``video_url`` parts rather than ``image_url``) is therefore carried as
    ``user`` content so the model still sees it.

    A message is converted when its role is ``assistant``, its ``content`` is a
    list, and at least one part satisfies :func:`_part_is_non_text`. For each
    converted message:

    * ``role`` is set to ``"user"``;
    * the first ``text`` part gets ``_RELABELED_ASSISTANT_PREFIX`` prepended,
      or, when there is no ``text`` part, a new text part holding the prefix
      (minus its trailing newline) is inserted at index 0.

    Converted turns are no longer ``assistant``, so a second call skips them.
    Mutates *messages* in place. Returns the number of messages converted.
    """
    converted = 0
    for message in messages:
        if message.get("role") != "assistant":
            continue
        parts = message.get("content")
        if not isinstance(parts, list) or not any(map(_part_is_non_text, parts)):
            continue

        message["role"] = "user"
        first_text_part = next(
            (p for p in parts if isinstance(p, dict) and p.get("type") == "text"),
            None,
        )
        if first_text_part is None:
            parts.insert(
                0,
                {"type": "text", "text": _RELABELED_ASSISTANT_PREFIX.rstrip("\n")},
            )
        else:
            first_text_part["text"] = (
                _RELABELED_ASSISTANT_PREFIX + first_text_part.get("text", "")
            )
        converted += 1
    return converted
def _cap_video_url_parts(messages: list[dict[str, Any]], max_keep: int) -> int:
    """Keep only the last *max_keep* ``video_url`` parts, dropping older ones.

    Parts are ordered message by message, then by position within each
    message's list ``content``. A non-positive *max_keep* disables capping
    entirely (nothing is removed). Messages that lose parts get a freshly
    built ``content`` list; all other messages are left untouched.

    Returns the number of parts removed.
    """
    if max_keep <= 0:
        return 0

    # Collect (message index, part index) for every video part, in order.
    video_slots: list[tuple[int, int]] = []
    for msg_idx, message in enumerate(messages):
        content = message.get("content")
        if not isinstance(content, list):
            continue
        video_slots.extend(
            (msg_idx, part_idx)
            for part_idx, part in enumerate(content)
            if isinstance(part, dict) and part.get("type") == "video_url"
        )

    excess = len(video_slots) - max_keep
    if excess <= 0:
        return 0

    # The oldest `excess` videos go; group their indices by message.
    doomed: dict[int, set[int]] = {}
    for msg_idx, part_idx in video_slots[:excess]:
        doomed.setdefault(msg_idx, set()).add(part_idx)

    for msg_idx, drop in doomed.items():
        message = messages[msg_idx]
        content = message.get("content")
        if isinstance(content, list):
            message["content"] = [
                part for idx, part in enumerate(content) if idx not in drop
            ]
    return excess
def resolve_model_alias(model: str) -> str:
    """Expand a short model alias to its canonical provider model id.

    The local proxy accepts friendly short names (e.g. ``gemini-3-pro-preview``)
    but routes them to a canonical backend model. Resolving the alias here
    ensures all downstream logic (token caps, Gemini detection, response header)
    sees the true model name.

    The lookup is an exact, case-sensitive match against ``_MODEL_ALIASES``.
    Ids with no alias entry are returned unchanged.
    """
    return _MODEL_ALIASES.get(model, model)
def _model_targets_gemini(model: str) -> bool:
    """Report whether *model* is routed to Google Gemini.

    First asks :func:`model_capabilities.get_capabilities` for the model's
    provider and returns True for ``"google"``. Otherwise it falls back to a
    case-insensitive ``"gemini"`` substring check, so unregistered aliases
    still match.
    """
    # Imported lazily, as in the original; presumably avoids an import cycle
    # or start-up cost (not verifiable from this file).
    from model_capabilities import get_capabilities

    provider = get_capabilities(model).provider
    return provider == "google" or "gemini" in model.lower()
def _model_targets_claude(model: str) -> bool:
    """Report whether *model* is an Anthropic Claude family model.

    First asks :func:`model_capabilities.get_capabilities` for the model's
    provider and returns True for ``"anthropic"``. Otherwise it falls back to a
    case-insensitive ``"claude"`` substring check for unregistered aliases.
    """
    from model_capabilities import get_capabilities

    provider = get_capabilities(model).provider
    return provider == "anthropic" or "claude" in model.lower()
def _model_targets_gemma4(model: str) -> bool:
    """Report whether *model* names a Gemma 4 family model.

    Matching is case-insensitive and ignores hyphens, so ``gemma-4-...`` and
    ``Gemma4`` both match. Note that any id containing ``gemma4`` after hyphen
    removal matches, including e.g. ``gemma-40``.
    """
    hyphenless = "".join(model.lower().split("-"))
    return "gemma4" in hyphenless
def _approx_b64_decoded_len(b64: str) -> int:
    """Estimate the decoded byte length of base64 text *b64* without decoding.

    Uses ``len * 3 // 4`` minus the trailing ``=`` padding (counted up to 2),
    clamped at 0. This is exact for clean, correctly padded base64. Any extra
    characters (e.g. whitespace) make it an over-estimate.
    """
    trailing_equals = len(b64) - len(b64.rstrip("="))
    padding = min(2, trailing_equals)
    return max(0, len(b64) * 3 // 4 - padding)
def _split_base64_image_data_url(url: Any) -> tuple[str, str] | None:
    """Split a base64 ``data:image/...`` URL into ``(mime, base64_payload)``.

    Returns ``None`` unless all of these hold:

    * *url* is a string starting with ``data:`` (case-sensitive);
    * a ``,`` separates the metadata from the payload;
    * the metadata mentions ``base64`` (case-insensitive);
    * the declared MIME type starts with ``image/`` (case-insensitive).

    The returned MIME keeps its original case with surrounding whitespace
    stripped.
    """
    if not isinstance(url, str) or not url.startswith("data:"):
        return None
    meta, sep, b64_payload = url[5:].partition(",")
    if not sep:
        return None
    if "base64" not in meta.lower():
        return None
    mime = meta.split(";")[0].strip()
    if not mime.lower().startswith("image/"):
        return None
    return mime, b64_payload


def _clamped_image_part(
    b64_payload: str,
    mime: str,
    max_image_bytes: int,
) -> dict[str, Any]:
    """Build the replacement content part for an image estimated to be too large.

    Decodes *b64_payload* and asks :func:`shrink_image_under_max_bytes` for a
    version no larger than *max_image_bytes*. On success, returns a fresh
    base64 ``image_url`` part whose MIME is re-detected from the new bytes
    (falling back to *mime*). Otherwise it returns a text part carrying
    ``_CLAUDE_OVERSIZED_IMAGE_PLACEHOLDER``. The failure cases are:

    * the payload is not decodable base64;
    * the decoded size contradicts the estimate;
    * the shrink helper returns ``None``.
    """
    placeholder = {"type": "text", "text": _CLAUDE_OVERSIZED_IMAGE_PLACEHOLDER}
    try:
        raw = base64.b64decode(b64_payload)
    except ValueError:
        # binascii.Error (e.g. incorrect padding) subclasses ValueError, and a
        # str payload containing non-ASCII characters raises ValueError too.
        return placeholder
    # Only payloads whose *estimated* decoded size exceeds the limit reach this
    # point. b64decode silently discards non-alphabet characters, so a decoded
    # size at or under the limit is treated as a corrupt data URL.
    if len(raw) <= max_image_bytes:
        return placeholder
    smaller = shrink_image_under_max_bytes(
        raw,
        mime,
        max_bytes=max_image_bytes,
    )
    if smaller is None:
        return placeholder
    out_mime = detect_image_mimetype_from_bytes(smaller) or mime
    new_b64 = base64.standard_b64encode(smaller).decode("ascii")
    return {
        "type": "image_url",
        "image_url": {"url": f"data:{out_mime};base64,{new_b64}"},
    }


def _clamp_claude_oversized_images(
    messages: list[dict[str, Any]],
    effective_model: str,
) -> int:
    """Shrink or replace oversized base64 ``image_url`` parts for Claude models.

    No-op (returns 0) unless :func:`_model_targets_claude` accepts
    *effective_model*. The byte limit is ``openrouter_client.CLAUDE_MAX_IMAGE_BYTES``
    when that module defines it, else the ``platforms.media_common`` constant.

    Each ``image_url`` part holding a base64 ``data:image/...`` URL whose
    estimated decoded size exceeds the limit is replaced in its content list.
    The replacement is either a recompressed image part or a text placeholder
    (see :func:`_clamped_image_part`). Parts that are not data URLs, not
    base64, not images, or already within the limit are left as-is.

    Mutates *messages* in place. Returns the number of parts replaced.
    """
    if not _model_targets_claude(effective_model):
        return 0
    import openrouter_client

    # Read the limit through the client module so an override defined there
    # takes precedence over the media_common default.
    max_image_bytes = getattr(
        openrouter_client, "CLAUDE_MAX_IMAGE_BYTES", CLAUDE_MAX_IMAGE_BYTES
    )
    changed = 0
    for m in messages:
        c = m.get("content")
        if not isinstance(c, list):
            continue
        for i, part in enumerate(c):
            if not isinstance(part, dict) or part.get("type") != "image_url":
                continue
            iu = part.get("image_url")
            if not isinstance(iu, dict):
                continue
            parsed = _split_base64_image_data_url(iu.get("url"))
            if parsed is None:
                continue
            mime, b64_payload = parsed
            # Cheap size estimate first; only decode payloads that look too big.
            if _approx_b64_decoded_len(b64_payload) <= max_image_bytes:
                continue
            c[i] = _clamped_image_part(b64_payload, mime, max_image_bytes)
            changed += 1
    return changed
def _model_strips_all_multimodal_before_request(model: str) -> bool:
    """Return True when *model* must be sent text-only (no image/video/audio).

    The id is alias-resolved, then stripped and lowercased. Ids listed in
    ``_MODEL_IDS_TEXT_ONLY_BEFORE_REQUEST`` are always text-only. Any other id
    ending in ``:free`` is text-only unless it targets Gemini. Gemini models
    keep native multimodal support even on an OpenRouter ``:free`` tier.
    """
    normalized = resolve_model_alias(model).strip().lower()
    if normalized in _MODEL_IDS_TEXT_ONLY_BEFORE_REQUEST:
        return True
    targets_gemini = _model_targets_gemini(normalized)
    return not targets_gemini and normalized.endswith(":free")
def _strip_additional_properties(node: Any) -> Any:
    """Return a copy of *node* with every ``additionalProperties`` key removed.

    Gemini function schemas reject that keyword. Dicts and lists are rebuilt
    recursively as plain ``dict``/``list``. Any other value (including tuples)
    is returned as-is. The input is not mutated.

    NOTE: the key is dropped at every depth, so a schema property literally
    named ``additionalProperties`` inside ``properties`` would be removed too.
    """
    if isinstance(node, list):
        return [_strip_additional_properties(item) for item in node]
    if not isinstance(node, dict):
        return node
    cleaned: dict[Any, Any] = {}
    for key, value in node.items():
        if key == "additionalProperties":
            continue
        cleaned[key] = _strip_additional_properties(value)
    return cleaned
def _sanitize_openai_tools_for_gemini(
    tools: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Return Gemini-safe deep copies of OpenAI-format tool definitions.

    Every tool is deep-copied, so the caller's list and dicts are never
    mutated. When a copy has a dict ``function`` whose ``parameters`` is a
    dict, that schema is replaced with :func:`_strip_additional_properties`
    applied to it. Other tools pass through as plain copies. Order is
    preserved.
    """

    def _sanitize_one(tool: dict[str, Any]) -> dict[str, Any]:
        clone = copy.deepcopy(tool)
        function_spec = clone.get("function")
        if not isinstance(function_spec, dict):
            return clone
        schema = function_spec.get("parameters")
        if isinstance(schema, dict):
            function_spec["parameters"] = _strip_additional_properties(schema)
        return clone

    return [_sanitize_one(tool) for tool in tools]