Source code for openrouter_client.sanitization

"""Multimodal input sanitization, image compression, and model target helpers."""

from __future__ import annotations

import base64
import copy
import logging
from typing import Any

from platforms.media_common import (
    CLAUDE_MAX_IMAGE_BYTES,
    detect_image_mimetype_from_bytes,
    shrink_image_under_max_bytes,
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Short model aliases that the local proxy understands but that differ from
# the canonical provider model id.  Applied before every API call so caps,
# Gemini detection, and the response header all see the canonical name.
# Looked up by resolve_model_alias(); ids not listed here pass through as-is.
_MODEL_ALIASES: dict[str, str] = {
    "gemini-3-pro-preview": "gemini-3.1-pro-preview",
}

# Resolved model ids (after :func:`resolve_model_alias`) that do not accept
# image / video / audio on the routed endpoint — strip before first request.
# Ids are compared after ``.strip().lower()``, so entries must be lowercase.
# Any OpenRouter id ending in ``:free`` is also treated as text-only (free
# tiers typically omit vision/audio/video) unless it targets Google Gemini,
# which stays multimodal even on ``:free`` tiers; both rules are applied in
# _model_strips_all_multimodal_before_request().
_MODEL_IDS_TEXT_ONLY_BEFORE_REQUEST: frozenset[str] = frozenset(
    {
        "arcee-ai/trinity-large-thinking",
        "deepseek/deepseek-v4-flash",
        # Also matched by the ``:free`` suffix rule; listed explicitly anyway.
        "deepseek/deepseek-v4-flash:free",
        "deepseek/deepseek-v4-pro",
        "moonshotai/kimi-k2.6",
        "x-ai/grok-4.3",
        "xiaomi/mimo-v2-pro",
        "xiaomi/mimo-v2.5",
        "xiaomi/mimo-v2.5-pro",
    }
)

# Prepended to the first text part of media-carrying assistant turns that
# _relabel_assistant_image_turns_as_user() re-roles to ``user``, marking the
# content as the model's own earlier output.  When such a turn has no text
# part, a new part holding this prefix minus its trailing newline is inserted.
_RELABELED_ASSISTANT_PREFIX = "[My previous response:]\n"

# Text part that _clamp_claude_oversized_images() substitutes for a base64
# image data URL that is over the Claude size limit and fails to decode,
# looks corrupt, or cannot be shrunk under the limit.
# NOTE(review): the "4 MB" wording is hard-coded; confirm it stays in sync
# with CLAUDE_MAX_IMAGE_BYTES.
_CLAUDE_OVERSIZED_IMAGE_PLACEHOLDER = (
    "[Image omitted: larger than 4 MB and could not be compressed enough "
    "for Claude-compatible models.]"
)


def _strip_audio_parts(messages: list[dict[str, Any]]) -> int:
    """Remove all ``input_audio`` content parts from *messages* in-place.

    Used as a fallback retry when the endpoint returns a 404 "No endpoints
    found that support input audio" error.  Returns the number of parts
    removed.
    """
    removed = 0
    for m in messages:
        c = m.get("content")
        if not isinstance(c, list):
            continue
        new_c = [
            p
            for p in c
            if not (isinstance(p, dict) and "audio" in str(p.get("type", "")).lower())
        ]
        if len(new_c) < len(c):
            removed += len(c) - len(new_c)
            m["content"] = new_c
    return removed


def _strip_video_url_parts(messages: list[dict[str, Any]]) -> int:
    """Remove all ``video_url`` content parts from *messages* in-place.

    Used before sending to models that do not support the ``video_url``
    multimodal type (non-Gemini models), and for Claude-family models
    (which also run :func:`_strip_video_and_audio_file_parts`).  Returns
    the number of parts removed.
    """
    removed = 0
    for m in messages:
        c = m.get("content")
        if not isinstance(c, list):
            continue
        new_c = [
            p for p in c if not (isinstance(p, dict) and p.get("type") == "video_url")
        ]
        if len(new_c) < len(c):
            removed += len(c) - len(new_c)
            m["content"] = new_c
    return removed


def _data_uri_mediatype_lower(file_data: str) -> str | None:
    """Return the MIME type segment of a ``data:`` URI, lowercased, or ``None``."""
    if not isinstance(file_data, str) or not file_data.lower().startswith("data:"):
        return None
    rest = file_data[5:]
    semi = rest.find(";")
    comma = rest.find(",")
    end = len(rest)
    if comma >= 0 and semi >= 0:
        end = min(semi, comma)
    elif comma >= 0:
        end = comma
    elif semi >= 0:
        end = semi
    mediatype = rest[:end].strip().lower()
    return mediatype or None


def _strip_video_and_audio_file_parts(messages: list[dict[str, Any]]) -> int:
    """Drop ``file`` parts carrying inline ``data:video/*`` or ``data:audio/*`` payloads.

    Claude (and some proxies) reject or mishandle such structured parts, so
    they are removed from every list-valued ``content`` in *messages*.  The
    list is reassigned only when at least one part goes.  A part is kept if
    its ``file`` entry is not a dict, or if its ``file_data`` is not a
    ``data:`` URI with a video/audio media type.

    Returns the number of parts dropped.
    """

    def _is_av_file(part: Any) -> bool:
        if not (isinstance(part, dict) and part.get("type") == "file"):
            return False
        file_info = part.get("file")
        if not isinstance(file_info, dict):
            return False
        media = _data_uri_mediatype_lower(file_info.get("file_data", ""))
        return bool(media) and media.startswith(("video/", "audio/"))

    dropped_total = 0
    for msg in messages:
        content = msg.get("content")
        if not isinstance(content, list):
            continue
        retained = [part for part in content if not _is_av_file(part)]
        if len(retained) < len(content):
            dropped_total += len(content) - len(retained)
            msg["content"] = retained
    return dropped_total


def _strip_image_url_parts(messages: list[dict[str, Any]]) -> int:
    """Remove all ``image_url`` content parts from *messages* in-place.

    Used as a fallback retry when the endpoint returns HTTP 404 with a body
    like "No endpoints found that support image input".  Returns the number
    of parts removed.
    """
    removed = 0
    for m in messages:
        c = m.get("content")
        if not isinstance(c, list):
            continue
        new_c = [
            p for p in c if not (isinstance(p, dict) and p.get("type") == "image_url")
        ]
        if len(new_c) < len(c):
            removed += len(c) - len(new_c)
            m["content"] = new_c
    return removed


def _part_is_non_text(part: Any) -> bool:
    """Return True for any content part that is not a plain ``text`` part.

    Covers every multimodal shape this codebase produces — ``image_url``,
    ``file`` (e.g. images/PDFs sent inline by image-generation tools and
    embeds), ``video_url``, ``input_audio`` — plus any future/foreign type.
    All of these are rejected by providers when they appear in an
    assistant/model turn.
    """
    if not isinstance(part, dict):
        return False
    return part.get("type") not in (None, "text")


def _relabel_assistant_image_turns_as_user(
    messages: list[dict[str, Any]],
) -> int:
    """Re-role ``assistant`` turns that carry non-text media to ``user``.

    No provider accepts media inside an assistant/model turn: Anthropic
    rejects ``image`` blocks and Google Gemini returns ``INVALID_ARGUMENT``
    for any inline media in a ``model`` turn. Bot-sent media — notably images
    that image-generation tools post inside Discord embeds, which can be
    represented as ``file`` or ``video_url`` parts rather than ``image_url`` —
    must therefore be carried as ``user`` content. This converts any such
    assistant turn so the model still sees the media without a validation
    error.

    A short text prefix marks the content as the model's own prior output.
    It is prepended to the first ``text`` part, or a synthetic text part is
    inserted at the front when the turn has no text part at all.  A text
    part whose ``text`` is missing or ``None`` is treated as empty rather
    than raising ``TypeError``.  Idempotent: once relabeled to ``user``,
    subsequent calls skip the message.

    Only messages whose ``content`` is a list are considered.  Mutates
    *messages* (including their content lists and text parts) in place.
    Returns the number of messages relabeled.
    """
    relabeled = 0
    for m in messages:
        if m.get("role") != "assistant":
            continue
        c = m.get("content")
        if not isinstance(c, list) or not any(_part_is_non_text(p) for p in c):
            continue
        m["role"] = "user"
        first_text = next(
            (p for p in c if isinstance(p, dict) and p.get("type") == "text"),
            None,
        )
        if first_text is None:
            # No text part to prefix: insert one.  The trailing newline is
            # dropped because nothing follows it inside this part.
            c.insert(
                0,
                {"type": "text", "text": _RELABELED_ASSISTANT_PREFIX.rstrip("\n")},
            )
        else:
            # ``.get("text", "")`` alone still yields ``None`` for an explicit
            # ``"text": None``, and concatenating that raises ``TypeError``.
            existing = first_text.get("text")
            body = "" if existing is None else str(existing)
            first_text["text"] = _RELABELED_ASSISTANT_PREFIX + body
        relabeled += 1
    return relabeled


def _cap_video_url_parts(messages: list[dict[str, Any]], max_keep: int) -> int:
    """Drop older ``video_url`` content parts; keep the last *max_keep* in order.

    Traversal order: each message in *messages*, then each part in that
    message's list ``content``. Mutates *messages* in place by assigning
    new ``content`` lists only where parts were removed.

    Returns the number of parts removed.
    """
    if max_keep <= 0:
        return 0
    positions: list[tuple[int, int]] = []
    for mi, m in enumerate(messages):
        c = m.get("content")
        if not isinstance(c, list):
            continue
        for pi, p in enumerate(c):
            if isinstance(p, dict) and p.get("type") == "video_url":
                positions.append((mi, pi))
    n_vid = len(positions)
    if n_vid <= max_keep:
        return 0
    remove_set = set(positions[: n_vid - max_keep])
    stripped = n_vid - max_keep
    by_msg: dict[int, set[int]] = {}
    for mi, pi in remove_set:
        by_msg.setdefault(mi, set()).add(pi)
    for mi, drop_pi in by_msg.items():
        m = messages[mi]
        c = m.get("content")
        if not isinstance(c, list):
            continue
        m["content"] = [p for pi, p in enumerate(c) if pi not in drop_pi]
    return stripped


def resolve_model_alias(model: str) -> str:
    """Map a friendly model alias onto the canonical provider model id.

    The local proxy understands short names such as
    ``gemini-3-pro-preview`` and routes them to a canonical backend model.
    Translating the alias up front means everything downstream (token caps,
    Gemini detection, the response header) works with the real model name.
    Ids that are not registered aliases are returned unchanged.
    """
    try:
        return _MODEL_ALIASES[model]
    except KeyError:
        return model
def _model_targets_gemini(model: str) -> bool: """Return True if *model* is routed to Google Gemini. Delegates to :func:`model_capabilities.get_capabilities` for registered models; falls back to substring match for unregistered aliases. """ from model_capabilities import get_capabilities caps = get_capabilities(model) if caps.provider == "google": return True return "gemini" in model.lower() def _model_targets_claude(model: str) -> bool: """Return True if *model* id refers to an Anthropic Claude family model. Delegates to :func:`model_capabilities.get_capabilities` for registered models; falls back to substring match for unregistered aliases. """ from model_capabilities import get_capabilities caps = get_capabilities(model) if caps.provider == "anthropic": return True return "claude" in model.lower() def _model_targets_gemma4(model: str) -> bool: """Return True if *model* id refers to a Gemma 4 family model.""" return "gemma4" in model.lower().replace("-", "") def _approx_b64_decoded_len(b64: str) -> int: """Upper-bound-style length of decoded base64 without decoding (valid padding).""" pad = 0 if b64.endswith("=="): pad = 2 elif b64.endswith("="): pad = 1 return max(0, (len(b64) * 3) // 4 - pad) def _clamp_claude_oversized_images( messages: list[dict[str, Any]], effective_model: str, ) -> int: """Shrink or strip ``image_url`` data-URI parts over 4 MiB for Claude models. Mutates *messages* in place. Returns the number of parts shrunk or replaced. 
""" if not _model_targets_claude(effective_model): return 0 import openrouter_client max_image_bytes = getattr(openrouter_client, "CLAUDE_MAX_IMAGE_BYTES", CLAUDE_MAX_IMAGE_BYTES) changed = 0 for m in messages: c = m.get("content") if not isinstance(c, list): continue for i, part in enumerate(c): if not isinstance(part, dict) or part.get("type") != "image_url": continue iu = part.get("image_url") if not isinstance(iu, dict): continue url = iu.get("url") if not isinstance(url, str) or not url.startswith("data:"): continue try: meta, b64_payload = url[5:].split(",", 1) except ValueError: continue if "base64" not in meta.lower(): continue mime = meta.split(";")[0].strip() if not mime.lower().startswith("image/"): continue if _approx_b64_decoded_len(b64_payload) <= max_image_bytes: continue try: raw = base64.b64decode(b64_payload) except Exception: c[i] = {"type": "text", "text": _CLAUDE_OVERSIZED_IMAGE_PLACEHOLDER} changed += 1 continue # Valid base64 of a payload above the limit implies len(raw) > limit. # If the declared size is over the limit but decoded bytes are not, # the data URL is corrupt (e.g. invalid alphabet decodes to empty). 
if len(raw) <= max_image_bytes: c[i] = {"type": "text", "text": _CLAUDE_OVERSIZED_IMAGE_PLACEHOLDER} changed += 1 continue smaller = shrink_image_under_max_bytes( raw, mime, max_bytes=max_image_bytes, ) if smaller is None: c[i] = {"type": "text", "text": _CLAUDE_OVERSIZED_IMAGE_PLACEHOLDER} changed += 1 continue out_mime = detect_image_mimetype_from_bytes(smaller) or mime new_b64 = base64.standard_b64encode(smaller).decode("ascii") c[i] = { "type": "image_url", "image_url": {"url": f"data:{out_mime};base64,{new_b64}"}, } changed += 1 return changed def _model_strips_all_multimodal_before_request(model: str) -> bool: """True if *model* should be sent text-only (strip image, video, audio parts).""" key = resolve_model_alias(model).strip().lower() if key in _MODEL_IDS_TEXT_ONLY_BEFORE_REQUEST: return True # Google Gemini models natively support multimodal/vision capabilities, # even when routed through a ":free" tier on OpenRouter. if _model_targets_gemini(key): return False return key.endswith(":free") def _strip_additional_properties(node: Any) -> Any: """Remove ``additionalProperties`` at every depth (Gemini function schemas reject it).""" if isinstance(node, dict): return { k: _strip_additional_properties(v) for k, v in node.items() if k != "additionalProperties" } if isinstance(node, list): return [_strip_additional_properties(x) for x in node] return node def _sanitize_openai_tools_for_gemini( tools: list[dict[str, Any]], ) -> list[dict[str, Any]]: """Deep-copy OpenAI-format tools and strip unsupported JSON Schema keys.""" out: list[dict[str, Any]] = [] for t in tools: t2 = copy.deepcopy(t) fn = t2.get("function") if isinstance(fn, dict): params = fn.get("parameters") if isinstance(params, dict): fn["parameters"] = _strip_additional_properties(params) out.append(t2) return out