"""Edit existing image(s) via native Gemini API and send to the current channel.
Supports multiple input images for compositing, style blending,
and multi-reference editing. # 🔥 multi-image witch energy
"""
from __future__ import annotations
import asyncio
import base64
import hashlib
import jsonutil as json
import logging
from io import BytesIO
from typing import Any, TYPE_CHECKING
if TYPE_CHECKING:
from tool_context import ToolContext
from tools._safe_http import assert_safe_http_url, safe_http_request, safe_httpx_client
logger = logging.getLogger(__name__)
# Identifier under which this tool is published.
TOOL_NAME = "edit_image"

# Natural-language summary of the tool's purpose.
TOOL_DESCRIPTION = (
    "Edit one or more existing images using AI. "
    "Provide URL(s) to source image(s) and a text prompt describing the "
    "desired changes. "
    "Supports multiple input images for compositing, style blending, or "
    "multi-reference editing. "
    "The edited image is sent to the current channel."
)

# JSON-Schema description of the arguments accepted by ``run``. Only
# ``prompt`` is mandatory; the image source may arrive through either
# ``image_urls`` (preferred) or the legacy ``image_url``.
TOOL_PARAMETERS = {
    "type": "object",
    "properties": {
        # Preferred way to pass one or several source images.
        "image_urls": {
            "type": "array",
            "items": {"type": "string"},
            "description": (
                "List of source image URLs to edit or combine (must be "
                "publicly accessible). Supports 1-10 images."
            ),
        },
        # Legacy single-URL form, merged in front of ``image_urls``.
        "image_url": {
            "type": "string",
            "description": (
                "Single source image URL (legacy, prefer image_urls). If "
                "both image_url and image_urls are provided, image_url is "
                "prepended to the list."
            ),
        },
        "prompt": {
            "type": "string",
            "description": (
                "Description of the edits to apply or how to combine/blend "
                "the input images."
            ),
        },
        "aspect_ratio": {
            "type": "string",
            "description": (
                "Output aspect ratio. Default: 1:1. Supported: 1:1, 2:3, "
                "3:2, 3:4, 4:3, 4:5, 5:4, 9:16, 16:9, 21:9."
            ),
        },
    },
    "required": ["prompt"],
}

# Upper bound on the number of source images accepted by a single call.
# The "1-10" wording in the ``image_urls`` description mirrors this value.
_MAX_INPUT_IMAGES = 10
async def _download_image(url: str) -> bytes | None:
    """Fetch the raw bytes of a remote image over HTTP(S), with SSRF guards.

    The URL is stripped of surrounding whitespace and validated by
    :func:`tools._safe_http.assert_safe_http_url`, which signals a rejected
    URL by raising ``ValueError``. The body is then fetched through
    :func:`tools._safe_http.safe_httpx_client` (30 s timeout) and
    :func:`tools._safe_http.safe_http_request` (at most 5 redirects).

    Every failure is logged and reported as ``None`` rather than raised, so a
    single bad entry cannot abort a multi-image edit. That includes entries
    that are not strings at all: the URL list originates from a
    model-generated tool call and is not guaranteed to be well typed.

    Args:
        url: The image URL to download.

    Returns:
        The response body on success, or ``None`` if ``url`` is not a string,
        is rejected by the URL guard, or the request fails or returns a
        non-success status.
    """
    # Without this guard ``url.strip()`` raises AttributeError for a
    # non-string entry; that is not a ValueError, so it would escape this
    # helper and propagate to the caller.
    if not isinstance(url, str):
        logger.error(
            "Blocked URL for edit_image: expected a string, got %s",
            type(url).__name__,
        )
        return None

    try:
        safe_url = assert_safe_http_url(url.strip())
    except ValueError as exc:
        logger.error("Blocked URL for edit_image: %s", exc)
        return None

    try:
        async with safe_httpx_client(timeout=30.0) as http:
            resp = await safe_http_request(http, "GET", safe_url, max_redirects=5)
            resp.raise_for_status()
            return resp.content
    except Exception as exc:
        # Deliberately broad: any transport or status error means "this image
        # is unavailable", which the caller reports per URL.
        logger.error("Failed to download image from %s: %s", safe_url, exc)
        return None
def _detect_mime_type(image_bytes: bytes) -> str:
    """Map raw image bytes to the MIME type of the format Pillow detects.

    Pillow is imported lazily inside the function. The call is synchronous;
    within this module it is invoked from :func:`_download_and_encode` via
    :func:`asyncio.to_thread`.

    Args:
        image_bytes: The raw, already-downloaded image bytes.

    Returns:
        ``image/jpeg``, ``image/png``, ``image/gif`` or ``image/webp``
        according to the detected format; ``image/png`` when Pillow reports
        no format or one that is not in the table.
    """
    from PIL import Image

    fallback_mime = "image/png"
    mime_by_format = {
        "jpeg": "image/jpeg",
        "jpg": "image/jpeg",
        "png": fallback_mime,
        "gif": "image/gif",
        "webp": "image/webp",
    }

    detected = Image.open(BytesIO(image_bytes)).format
    # ``format`` can be None; the original treats that the same as PNG.
    format_key = detected.lower() if detected else "png"
    return mime_by_format.get(format_key, fallback_mime)
async def _download_and_encode(url: str) -> dict[str, Any] | None:
    """Download one image and wrap it as a Gemini ``inlineData`` part.

    The bytes are fetched with :func:`_download_image`, their MIME type is
    detected off the event loop with :func:`_detect_mime_type`, and the
    payload is base64-encoded.

    Args:
        url: Source image URL.

    Returns:
        ``{"inlineData": {"mimeType": ..., "data": ...}}`` on success, or
        ``None`` if the download fails, the body is empty, or the bytes cannot
        be decoded as an image.
    """
    img_bytes = await _download_image(url)
    if not img_bytes:
        return None

    try:
        mime_type = await asyncio.to_thread(_detect_mime_type, img_bytes)
    except Exception as exc:
        # A URL can return 200 with a non-image body (HTML error page,
        # truncated file, ...), which makes Pillow raise. Report it as a
        # failed image instead of letting it escape and abort the whole call.
        logger.error("Failed to decode image from %s: %s", url, exc)
        return None

    b64_data = base64.b64encode(img_bytes).decode("utf-8")
    return {
        "inlineData": {
            "mimeType": mime_type,
            "data": b64_data,
        },
    }
[docs]
def _normalize_image_urls(
    image_url: str | None,
    image_urls: list[str] | str | None,
) -> list[str]:
    """Merge the legacy single URL and the URL list into one clean list.

    The legacy ``image_url`` comes first. Entries are stripped of surrounding
    whitespace; blank and non-string entries are dropped; duplicates are
    removed while keeping first-seen order.
    """
    candidates: list[Any] = [image_url]
    if isinstance(image_urls, str):
        # A bare string must count as ONE URL; ``list.extend`` would iterate
        # it character by character.
        candidates.append(image_urls)
    elif image_urls:
        candidates.extend(image_urls)

    stripped = (item.strip() for item in candidates if isinstance(item, str))
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(url for url in stripped if url))


async def run(
    prompt: str,
    image_urls: list[str] | None = None,
    image_url: str | None = None,
    aspect_ratio: str = "1:1",
    ctx: ToolContext | None = None,
) -> str:
    """Edit or combine the given images with Gemini and post the result.

    Steps: normalize the URL arguments, enforce the per-call image limit,
    resolve the API key and (when the default-key quota applies) check the
    daily limit, download all sources concurrently, call Gemini with the
    images followed by the text prompt, re-encode the output as PNG, send it
    through ``ctx.adapter`` and record it in ``ctx.sent_files``. The
    default-key usage counter is incremented only after the file was sent.

    Args:
        prompt: Description of edits or how to combine images.
        image_urls: Source image URLs. A single string is accepted and
            treated as one URL; blank or non-string entries are ignored.
        image_url: Single source image URL (legacy compat); placed first.
        aspect_ratio: Output aspect ratio passed through to Gemini.
        ctx: Tool execution context; must provide a platform adapter.

    Returns:
        A plain error string when no adapter is available; otherwise a JSON
        string holding either ``{"error": ...}`` or a success object with
        ``images_used``, ``filename`` and, when applicable, ``file_url`` and
        ``warnings`` for sources that could not be used.
    """
    if ctx is None or ctx.adapter is None:
        return "Error: No platform adapter available."

    # Validate the cheap arguments first, before importing the backend.
    urls = _normalize_image_urls(image_url, image_urls)
    if not urls:
        return json.dumps(
            {
                "error": (
                    "No image URLs provided. Supply at least one image "
                    "via image_urls (array) or image_url (string)."
                ),
            }
        )
    if len(urls) > _MAX_INPUT_IMAGES:
        return json.dumps(
            {
                "error": (
                    f"Too many images ({len(urls)}). "
                    f"Maximum is {_MAX_INPUT_IMAGES}."
                ),
            }
        )

    from tools.generate_image import (
        _call_gemini_native,
        _default_key_image_quota_applies,
        _resolve_api_key,
        _IMAGE_DAILY_LIMIT,
        IMAGE_RATE_LIMIT_ERROR,
    )

    api_key, using_own_key = await _resolve_api_key(ctx)

    # Rate-limit default/fallback key users (exempt: admin, BYPASS_RATELIMIT, own key)
    if await _default_key_image_quota_applies(ctx, using_own_key):
        redis = getattr(ctx, "redis", None)
        from tools.manage_api_keys import check_default_key_limit

        allowed, current, limit = await check_default_key_limit(
            ctx.user_id,
            "image_generation",
            redis,
            daily_limit=_IMAGE_DAILY_LIMIT,
        )
        if not allowed:
            return json.dumps(
                {
                    "error": IMAGE_RATE_LIMIT_ERROR.format(
                        current=current,
                        limit=limit,
                    ),
                }
            )

    # Download all images concurrently. ``return_exceptions=True`` keeps one
    # unexpected failure from aborting the other downloads and this call.
    outcomes = await asyncio.gather(
        *(_download_and_encode(u) for u in urls),
        return_exceptions=True,
    )

    # Build Gemini parts: all images first, then the text prompt.
    prompt_parts: list[dict[str, Any]] = []
    failed_urls: list[str] = []
    for url, outcome in zip(urls, outcomes):
        if isinstance(outcome, BaseException):
            logger.error("Failed to prepare image from %s: %s", url, outcome)
            failed_urls.append(url)
        elif outcome is None:
            failed_urls.append(url)
        else:
            prompt_parts.append(outcome)

    if not prompt_parts:
        return json.dumps(
            {
                "error": (
                    "Failed to download any of the provided images. "
                    f"Failed URLs: {failed_urls}"
                ),
            }
        )

    prompt_parts.append({"text": prompt})

    try:
        img_bytes = await _call_gemini_native(
            prompt_parts,
            api_key,
            aspect_ratio,
        )
        if not img_bytes:
            return json.dumps(
                {
                    "error": "No image was generated.",
                }
            )

        from PIL import Image

        def _convert_to_png(data: bytes) -> bytes:
            """Re-encode arbitrary image bytes as PNG.

            Synchronous Pillow decode/encode; run via ``asyncio.to_thread``
            so the result can be sent with a ``.png`` name and content type.
            """
            img = Image.open(BytesIO(data))
            buf = BytesIO()
            img.save(buf, format="PNG")
            return buf.getvalue()

        png_bytes = await asyncio.to_thread(_convert_to_png, img_bytes)

        # Content-derived filename: identical output yields the same name.
        h = hashlib.sha256(png_bytes).hexdigest()[:16]
        fname = f"edited_{h}.png"

        file_url = await ctx.adapter.send_file(
            ctx.channel_id,
            png_bytes,
            fname,
            "image/png",
        )
        ctx.sent_files.append(
            {
                "data": png_bytes,
                "filename": fname,
                "mimetype": "image/png",
                "file_url": file_url or "",
            }
        )

        # Count usage only once the image has actually been sent.
        if await _default_key_image_quota_applies(ctx, using_own_key):
            redis = getattr(ctx, "redis", None)
            from tools.manage_api_keys import increment_default_key_usage

            await increment_default_key_usage(
                ctx.user_id,
                "image_generation",
                redis,
            )

        result_info: dict[str, Any] = {
            "success": True,
            # The last part is the text prompt, not an image.
            "images_used": len(prompt_parts) - 1,
            "filename": fname,
            "result": "Image edited and sent to the channel.",
        }
        if file_url:
            result_info["file_url"] = file_url
        if failed_urls:
            result_info["warnings"] = (
                f"Failed to download {len(failed_urls)} image(s): {failed_urls}"
            )
        return json.dumps(result_info)
    except Exception as exc:
        # Top-level boundary of the tool: log with traceback and report the
        # failure as JSON instead of raising into the caller.
        logger.error(
            "Image edit error: %s",
            exc,
            exc_info=True,
        )
        return json.dumps(
            {
                "error": f"Image editing failed: {exc}",
            }
        )