# Source code for tools.edit_image

"""Edit existing image(s) via native Gemini API and send to the current channel.

Supports multiple input images for compositing, style blending,
and multi-reference editing.  # 🔥 multi-image witch energy
"""

from __future__ import annotations

import asyncio
import base64
import hashlib
import jsonutil as json
import logging
from io import BytesIO
from typing import Any, TYPE_CHECKING

if TYPE_CHECKING:
    from tool_context import ToolContext

from tools._safe_http import assert_safe_http_url, safe_http_request, safe_httpx_client

logger = logging.getLogger(__name__)

TOOL_NAME = "edit_image"
TOOL_DESCRIPTION = (
    "Edit one or more existing images using AI. Provide URL(s) to "
    "source image(s) and a text prompt describing the desired changes. "
    "Supports multiple input images for compositing, style blending, "
    "or multi-reference editing. The edited image is sent to the "
    "current channel."
)

# Upper bound on source images accepted per call. run() rejects requests
# above it, and the image_urls description below is derived from it so the
# advertised limit cannot drift from the enforced one.
_MAX_INPUT_IMAGES = 10

TOOL_PARAMETERS = {
    "type": "object",
    "properties": {
        "image_urls": {
            "type": "array",
            "items": {"type": "string"},
            "description": (
                "List of source image URLs to edit or combine "
                f"(must be publicly accessible). Supports 1-{_MAX_INPUT_IMAGES} images."
            ),
        },
        "image_url": {
            "type": "string",
            "description": (
                "Single source image URL (legacy, prefer image_urls). "
                "If both image_url and image_urls are provided, "
                "image_url is prepended to the list."
            ),
        },
        "prompt": {
            "type": "string",
            "description": (
                "Description of the edits to apply or how to "
                "combine/blend the input images."
            ),
        },
        "aspect_ratio": {
            "type": "string",
            "description": (
                "Output aspect ratio. Default: 1:1. "
                "Supported: 1:1, 2:3, 3:2, 3:4, 4:3, 4:5, "
                "5:4, 9:16, 16:9, 21:9."
            ),
        },
    },
    "required": ["prompt"],
}


async def _download_image(url: str) -> bytes | None:
    """Fetch the raw bytes of a remote image over HTTP(S), with SSRF guards.

    Used to pull each user-supplied source image into memory before it is
    re-encoded and handed to Gemini for editing. The URL is first normalized and
    validated by :func:`tools._safe_http.assert_safe_http_url` (which blocks
    internal/loopback/metadata addresses to prevent SSRF); the body is then
    fetched through the hardened :func:`tools._safe_http.safe_httpx_client` and
    :func:`tools._safe_http.safe_http_request` (at most 5 redirects, 30s timeout).
    Any failure is logged and swallowed so a single bad URL does not abort the
    whole edit.

    Note that this is a module-local helper distinct from the same-named
    functions in other tool modules; here it is called only by
    :func:`_download_and_encode`.

    Args:
        url: The image URL to download (leading/trailing whitespace is stripped).
            Non-string values are rejected and yield ``None``.

    Returns:
        bytes | None: The raw image bytes on success, or ``None`` if the URL is
        not a string, is blocked, the request fails, or the response is a
        non-success status.
    """
    if not isinstance(url, str):
        # Tool arguments originate from a model; a malformed entry must be
        # reported like any other bad URL rather than raising AttributeError
        # from ``.strip()`` below.
        logger.error(
            "Blocked URL for edit_image: expected a string, got %s",
            type(url).__name__,
        )
        return None
    try:
        url = assert_safe_http_url(url.strip())
    except ValueError as exc:
        logger.error("Blocked URL for edit_image: %s", exc)
        return None
    try:
        async with safe_httpx_client(timeout=30.0) as http:
            resp = await safe_http_request(http, "GET", url, max_redirects=5)
            resp.raise_for_status()
            return resp.content
    except Exception as exc:
        # Deliberate best-effort: the caller treats None as "skip this image".
        logger.error("Failed to download image from %s: %s", url, exc)
        return None


def _detect_mime_type(image_bytes: bytes) -> str:
    """Infer an image's MIME type from its raw bytes via Pillow.

    Opens the bytes with Pillow to read the detected format and maps it to the
    corresponding ``image/*`` MIME type that Gemini's ``inlineData`` part
    expects. This is synchronous work, so :func:`_download_and_encode` runs it
    via :func:`asyncio.to_thread` to keep the event loop responsive.

    Args:
        image_bytes: The raw, already-downloaded image bytes.

    Returns:
        str: The detected MIME type, defaulting to ``"image/png"`` for any
        format not in the mapping below.

    Raises:
        Exception: Whatever Pillow raises when the bytes are not a
            recognizable image (e.g. ``PIL.UnidentifiedImageError``); it is
            propagated unchanged.
    """
    from PIL import Image

    # Only the detected format is needed; the context manager releases the
    # image object as soon as it has been read.
    with Image.open(BytesIO(image_bytes)) as img:
        fmt = (img.format or "png").lower()
    return {
        "jpeg": "image/jpeg",
        "jpg": "image/jpeg",
        # Pillow reports multi-picture JPEGs (typical of phone cameras) as
        # "MPO"; the payload is still JPEG data, so label it accordingly
        # instead of falling through to the PNG default.
        "mpo": "image/jpeg",
        "png": "image/png",
        "gif": "image/gif",
        "webp": "image/webp",
    }.get(fmt, "image/png")


async def _download_and_encode(url: str) -> dict[str, Any] | None:
    """Download an image and return it as a Gemini ``inlineData`` part.

    Args:
        url: The source image URL, passed to :func:`_download_image`.

    Returns:
        dict[str, Any] | None: ``{"inlineData": {"mimeType": ..., "data": ...}}``
        with base64-encoded image data, or ``None`` when the download yields
        no bytes or the bytes cannot be identified as an image.
    """
    img_bytes = await _download_image(url)
    if not img_bytes:
        return None

    try:
        mime_type = await asyncio.to_thread(_detect_mime_type, img_bytes)
    except Exception as exc:
        # Pillow raises when the URL served something that is not a decodable
        # image (an HTML error page, for instance). Report it as a failed
        # image, as the contract above promises, instead of letting the
        # exception escape into the caller's asyncio.gather.
        logger.error(
            "Downloaded content from %s is not a usable image: %s", url, exc
        )
        return None

    b64_data = base64.b64encode(img_bytes).decode("utf-8")

    return {
        "inlineData": {
            "mimeType": mime_type,
            "data": b64_data,
        },
    }


def _collect_urls(image_url: Any, image_urls: Any) -> list[str]:
    """Merge the legacy ``image_url`` and the ``image_urls`` list into one list.

    The legacy single URL (if any) comes first, followed by the list entries.
    Entries are whitespace-stripped; empty and non-string entries are dropped;
    duplicates are removed while preserving first-seen order.

    Args:
        image_url: Legacy single-URL argument as received by :func:`run`.
        image_urls: URL-list argument as received by :func:`run`. A bare
            string is accepted and treated as a one-element list.

    Returns:
        list[str]: The cleaned, de-duplicated URLs in input order.
    """
    candidates: list[Any] = []
    if image_url:
        candidates.append(image_url)
    if isinstance(image_urls, str):
        # Tool-calling models sometimes send a bare string for an array
        # parameter; list.extend() would split it into single characters.
        candidates.append(image_urls)
    elif isinstance(image_urls, (list, tuple)):
        candidates.extend(image_urls)

    # dict keeps insertion order, so this is an order-preserving de-dupe.
    unique: dict[str, None] = {}
    for candidate in candidates:
        if not isinstance(candidate, str):
            continue
        cleaned = candidate.strip()
        if cleaned:
            unique.setdefault(cleaned, None)
    return list(unique)


def _convert_to_png(data: bytes) -> bytes:
    """Re-encode arbitrary image bytes as PNG.

    Normalizes whatever format Gemini returns into PNG so the result can be
    given a stable ``.png`` filename and content type before being uploaded to
    the channel. This is synchronous Pillow decode/encode work; :func:`run`
    calls it through :func:`asyncio.to_thread`.

    Args:
        data: The raw image bytes produced by Gemini.

    Returns:
        bytes: The same image re-encoded as PNG.
    """
    from PIL import Image

    buf = BytesIO()
    with Image.open(BytesIO(data)) as img:
        img.save(buf, format="PNG")
    return buf.getvalue()


async def run(
    prompt: str,
    image_urls: list[str] | None = None,
    image_url: str | None = None,
    aspect_ratio: str = "1:1",
    ctx: ToolContext | None = None,
) -> str:
    """Execute this tool and return the result.

    Validates the inputs, applies the default-key daily image quota, downloads
    all source images concurrently, asks Gemini for the edited image,
    re-encodes it as PNG and sends it to ``ctx.channel_id`` via ``ctx.adapter``.
    Source images that cannot be downloaded or decoded are skipped and listed
    under ``"warnings"``; the call fails only if none of them is usable.

    Args:
        prompt: Description of edits or how to combine images.
        image_urls: List of source image URLs (a bare string is accepted as a
            single URL).
        image_url: Single source image URL (legacy compat); placed before
            ``image_urls`` when both are given.
        aspect_ratio: Output aspect ratio, passed through to Gemini.
        ctx: Tool execution context.

    Returns:
        str: JSON result string, or a plain error string when no platform
        adapter is available.
    """
    if ctx is None or ctx.adapter is None:
        return "Error: No platform adapter available."

    # Merge legacy single URL + array into one clean, de-duplicated list.
    urls = _collect_urls(image_url, image_urls)
    if not urls:
        return json.dumps(
            {
                "error": (
                    "No image URLs provided. Supply at least one image "
                    "via image_urls (array) or image_url (string)."
                ),
            }
        )
    if len(urls) > _MAX_INPUT_IMAGES:
        return json.dumps(
            {
                "error": (
                    f"Too many images ({len(urls)}). "
                    f"Maximum is {_MAX_INPUT_IMAGES}."
                ),
            }
        )

    # Imported only once the request is known to be well-formed.
    from tools.generate_image import (
        _call_gemini_native,
        _default_key_image_quota_applies,
        _resolve_api_key,
        _IMAGE_DAILY_LIMIT,
        IMAGE_RATE_LIMIT_ERROR,
    )

    api_key, using_own_key = await _resolve_api_key(ctx)

    # Rate-limit default/fallback key users (exempt: admin, BYPASS_RATELIMIT, own key)
    if await _default_key_image_quota_applies(ctx, using_own_key):
        redis = getattr(ctx, "redis", None)
        from tools.manage_api_keys import check_default_key_limit

        allowed, current, limit = await check_default_key_limit(
            ctx.user_id,
            "image_generation",
            redis,
            daily_limit=_IMAGE_DAILY_LIMIT,
        )
        if not allowed:
            return json.dumps(
                {
                    "error": IMAGE_RATE_LIMIT_ERROR.format(
                        current=current,
                        limit=limit,
                    ),
                }
            )

    # Download all images concurrently. return_exceptions=True keeps one
    # image that raises from cancelling the rest or escaping run() unhandled.
    results = await asyncio.gather(
        *(_download_and_encode(u) for u in urls),
        return_exceptions=True,
    )

    # Build Gemini parts: all images first, then the text prompt.
    prompt_parts: list[dict[str, Any]] = []
    failed_urls: list[str] = []
    for url, result in zip(urls, results):
        if isinstance(result, BaseException):
            if not isinstance(result, Exception):
                # Cancellation / interpreter shutdown must keep propagating.
                raise result
            logger.error("Failed to prepare image from %s: %s", url, result)
            failed_urls.append(url)
        elif result is None:
            failed_urls.append(url)
        else:
            prompt_parts.append(result)

    if not prompt_parts:
        return json.dumps(
            {
                "error": (
                    "Failed to download any of the provided images. "
                    f"Failed URLs: {failed_urls}"
                ),
            }
        )

    images_used = len(prompt_parts)
    prompt_parts.append({"text": prompt})

    try:
        img_bytes = await _call_gemini_native(
            prompt_parts,
            api_key,
            aspect_ratio,
        )
        if not img_bytes:
            return json.dumps(
                {
                    "error": "No image was generated.",
                }
            )

        png_bytes = await asyncio.to_thread(_convert_to_png, img_bytes)
        # Content hash gives a stable, collision-resistant filename.
        h = hashlib.sha256(png_bytes).hexdigest()[:16]
        fname = f"edited_{h}.png"

        file_url = await ctx.adapter.send_file(
            ctx.channel_id,
            png_bytes,
            fname,
            "image/png",
        )
        ctx.sent_files.append(
            {
                "data": png_bytes,
                "filename": fname,
                "mimetype": "image/png",
                "file_url": file_url or "",
            }
        )

        try:
            if await _default_key_image_quota_applies(ctx, using_own_key):
                from tools.manage_api_keys import increment_default_key_usage

                await increment_default_key_usage(
                    ctx.user_id,
                    "image_generation",
                    getattr(ctx, "redis", None),
                )
        except Exception:
            # The image has already been delivered; a bookkeeping failure must
            # not be reported back to the caller as a failed edit.
            logger.exception(
                "Failed to record default-key image usage for user %s",
                getattr(ctx, "user_id", None),
            )

        result_info: dict[str, Any] = {
            "success": True,
            "images_used": images_used,
            "filename": fname,
            "result": "Image edited and sent to the channel.",
        }
        if file_url:
            result_info["file_url"] = file_url
        if failed_urls:
            result_info["warnings"] = (
                f"Failed to download {len(failed_urls)} image(s): "
                f"{failed_urls}"
            )
        return json.dumps(result_info)

    except Exception as exc:
        logger.error(
            "Image edit error: %s",
            exc,
            exc_info=True,
        )
        return json.dumps(
            {
                "error": f"Image editing failed: {exc}",
            }
        )