"""Edit existing image(s) via native Gemini API and send to the current channel.
Supports multiple input images for compositing, style blending,
and multi-reference editing. # 🔥 multi-image witch energy
"""
from __future__ import annotations
import asyncio
import base64
import hashlib
import json
import logging
from io import BytesIO
from typing import Any, TYPE_CHECKING
if TYPE_CHECKING:
from tool_context import ToolContext
from tools._safe_http import assert_safe_http_url
logger = logging.getLogger(__name__)
# Name of this tool.
# NOTE(review): presumably the identifier the tool loader / model uses to
# invoke it -- confirm against the code that registers tools.
TOOL_NAME = "edit_image"
# Human-readable summary of the tool. It mirrors what run() does: take one
# or more source image URLs plus a text prompt and send the edited image to
# the current channel.
TOOL_DESCRIPTION = (
    "Edit one or more existing images using AI. Provide URL(s) to "
    "source image(s) and a text prompt describing the desired changes. "
    "Supports multiple input images for compositing, style blending, "
    "or multi-reference editing. The edited image is sent to the "
    "current channel."
)
# JSON-Schema-style description of the keyword arguments accepted by run().
# Only "prompt" is marked required here; run() itself additionally rejects
# calls that supply no image URL through either "image_urls" or "image_url".
TOOL_PARAMETERS = {
    "type": "object",
    "properties": {
        # Preferred way to pass source images; the advertised upper bound
        # matches _MAX_INPUT_IMAGES.
        "image_urls": {
            "type": "array",
            "items": {"type": "string"},
            "description": (
                "List of source image URLs to edit or combine "
                "(must be publicly accessible). Supports 1-10 images."
            ),
        },
        # Legacy single-URL form; run() places it ahead of image_urls.
        "image_url": {
            "type": "string",
            "description": (
                "Single source image URL (legacy, prefer image_urls). "
                "If both image_url and image_urls are provided, "
                "image_url is prepended to the list."
            ),
        },
        # Free-text edit instruction; run() appends it after all image parts.
        "prompt": {
            "type": "string",
            "description": (
                "Description of the edits to apply or how to "
                "combine/blend the input images."
            ),
        },
        # Passed through to the image backend; run() defaults it to "1:1".
        "aspect_ratio": {
            "type": "string",
            "description": (
                "Output aspect ratio. Default: 1:1. "
                "Supported: 1:1, 2:3, 3:2, 3:4, 4:3, 4:5, "
                "5:4, 9:16, 16:9, 21:9."
            ),
        },
    },
    "required": ["prompt"],
}
# Maximum number of distinct source images accepted per call. run() returns
# a "Too many images" error above this count. Keep it in sync with the
# "Supports 1-10 images" wording in TOOL_PARAMETERS.
_MAX_INPUT_IMAGES = 10
async def _download_image(url: str) -> bytes | None:
    """Download an image from a URL, validating every hop of the request.

    The URL is passed through ``assert_safe_http_url`` before any request is
    made. Redirects are followed manually so that each redirect target goes
    through the same check; otherwise an allowed URL could redirect to a
    target the check would have rejected.

    Args:
        url: Source image URL. Surrounding whitespace is stripped.

    Returns:
        The raw response body, or ``None`` if the URL (or a redirect target)
        is rejected, the redirect chain is too long, or the request fails.
        Failures are logged, never raised.
    """
    import httpx

    max_redirects = 10

    # Tool arguments come from a model; a non-string entry must not crash
    # the whole call with AttributeError on .strip().
    if not isinstance(url, str):
        logger.error(
            "Blocked URL for edit_image: expected str, got %s",
            type(url).__name__,
        )
        return None

    try:
        url = assert_safe_http_url(url.strip())
    except ValueError as exc:
        logger.error("Blocked URL for edit_image: %s", exc)
        return None

    # NOTE(review): the body is read fully into memory with no size cap --
    # consider a streaming download with a byte limit.
    try:
        async with httpx.AsyncClient(
            timeout=30.0, follow_redirects=False,
        ) as http:
            for _ in range(max_redirects + 1):
                resp = await http.get(url)
                # With follow_redirects disabled, httpx exposes the request
                # it *would* have sent next; None means a final response.
                redirect = resp.next_request
                if redirect is None:
                    resp.raise_for_status()
                    return resp.content
                try:
                    url = assert_safe_http_url(str(redirect.url))
                except ValueError as exc:
                    logger.error(
                        "Blocked redirect for edit_image: %s", exc,
                    )
                    return None
            logger.error(
                "Too many redirects while downloading image from %s", url,
            )
            return None
    except Exception as exc:
        # Best-effort boundary: callers treat None as "skip this image".
        logger.error("Failed to download image from %s: %s", url, exc)
        return None
def _detect_mime_type(image_bytes: bytes) -> str:
    """Return the MIME type of encoded image bytes, as identified by Pillow.

    Args:
        image_bytes: Encoded image data (for example a downloaded file).

    Returns:
        ``image/jpeg``, ``image/png``, ``image/gif`` or ``image/webp``.
        Any other format Pillow can open, or a missing format, falls back
        to ``image/png``.

    Raises:
        Exception: Whatever ``PIL.Image.open`` raises when the bytes cannot
            be identified as an image; this function does not catch it.
    """
    from PIL import Image

    # Only the container format is needed, so close the image right after
    # reading it instead of leaving the handle to the garbage collector.
    with Image.open(BytesIO(image_bytes)) as img:
        fmt = (img.format or "png").lower()

    return {
        "jpeg": "image/jpeg",
        "jpg": "image/jpeg",
        # Pillow reports multi-picture JPEGs as "MPO"; they are JPEG data
        # and must not be labelled PNG by the fallback.
        "mpo": "image/jpeg",
        "png": "image/png",
        "gif": "image/gif",
        "webp": "image/webp",
    }.get(fmt, "image/png")
async def _download_and_encode(url: str) -> dict[str, Any] | None:
    """Download an image and wrap it as a Gemini ``inlineData`` part.

    Args:
        url: Source image URL.

    Returns:
        ``{"inlineData": {"mimeType": ..., "data": <base64 str>}}``, or
        ``None`` if the download fails, the body is empty, or the bytes
        cannot be identified as an image.
    """
    img_bytes = await _download_image(url)
    if not img_bytes:
        return None

    # MIME detection parses the image and raises on non-image payloads
    # (e.g. an HTML error page). Without this guard the exception would
    # escape through asyncio.gather() and fail every image in the batch.
    try:
        mime_type = await asyncio.to_thread(_detect_mime_type, img_bytes)
    except Exception as exc:
        logger.error("Could not identify image from %s: %s", url, exc)
        return None

    b64_data = base64.b64encode(img_bytes).decode("utf-8")
    return {
        "inlineData": {
            "mimeType": mime_type,
            "data": b64_data,
        },
    }
# --- edit_image tool implementation ---
def _merge_image_urls(
    image_url: str | None,
    image_urls: list[str] | str | None,
) -> list[str]:
    """Combine the legacy single URL and the URL list into one clean list.

    The legacy ``image_url`` comes first. Entries are stripped, blank or
    non-string entries are dropped, and duplicates are removed while keeping
    first-seen order.
    """
    # A bare string must be wrapped: extending a list with a str would add
    # it one character at a time.
    if isinstance(image_urls, str):
        image_urls = [image_urls]
    candidates = [image_url, *(image_urls or [])]
    cleaned = [
        c.strip() for c in candidates if isinstance(c, str) and c.strip()
    ]
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(cleaned))


async def run(
    prompt: str,
    image_urls: list[str] | None = None,
    image_url: str | None = None,
    aspect_ratio: str = "1:1",
    ctx: ToolContext | None = None,
) -> str:
    """Edit or combine the given images and send the result to the channel.

    Args:
        prompt: Description of edits or how to combine images.
        image_urls: List of source image URLs. A single string is also
            tolerated and treated as a one-element list.
        image_url: Single source image URL (legacy compat); placed first.
        aspect_ratio: Output aspect ratio, passed through to the backend.
        ctx: Tool execution context; must provide a platform adapter.

    Returns:
        str: A JSON object string. On success it has ``success``,
        ``images_used``, ``filename``, ``result`` and optionally
        ``file_url`` / ``warnings``; on failure it has ``error``. The one
        exception is a missing adapter, which returns a plain error string.
    """
    if ctx is None or ctx.adapter is None:
        return "Error: No platform adapter available."

    urls = _merge_image_urls(image_url, image_urls)
    if not urls:
        return json.dumps({
            "error": (
                "No image URLs provided. Supply at least one image "
                "via image_urls (array) or image_url (string)."
            ),
        })
    if len(urls) > _MAX_INPUT_IMAGES:
        return json.dumps({
            "error": (
                f"Too many images ({len(urls)}). "
                f"Maximum is {_MAX_INPUT_IMAGES}."
            ),
        })

    # Imported lazily, and only once the cheap argument checks have passed.
    from tools.generate_image import (
        _call_gemini_native,
        _resolve_api_key,
        _is_admin,
        _IMAGE_DAILY_LIMIT,
        IMAGE_RATE_LIMIT_ERROR,
    )

    api_key, using_own_key = await _resolve_api_key(ctx)

    # Usage is metered only for non-admin users on the default/fallback
    # key, and only when there is a Redis handle and a user id to count
    # against. Computed once so the check and the increment always agree.
    redis = getattr(ctx, "redis", None)
    metered = bool(
        not using_own_key
        and not _is_admin(ctx)
        and redis
        and ctx.user_id
    )

    if metered:
        from tools.manage_api_keys import check_default_key_limit

        allowed, current, limit = await check_default_key_limit(
            ctx.user_id, "image_generation", redis,
            daily_limit=_IMAGE_DAILY_LIMIT,
        )
        if not allowed:
            return json.dumps({
                "error": IMAGE_RATE_LIMIT_ERROR.format(
                    current=current, limit=limit,
                ),
            })

    # Download all images concurrently. return_exceptions=True keeps one
    # broken image from aborting the whole batch with an unhandled error.
    results = await asyncio.gather(
        *(_download_and_encode(u) for u in urls),
        return_exceptions=True,
    )

    # Build Gemini parts: all images first, then the text prompt.
    prompt_parts: list[dict[str, Any]] = []
    failed_urls: list[str] = []
    for url, result in zip(urls, results):
        if isinstance(result, BaseException):
            logger.error(
                "Failed to prepare image from %s: %s", url, result,
            )
            failed_urls.append(url)
        elif result is None:
            failed_urls.append(url)
        else:
            prompt_parts.append(result)

    if not prompt_parts:
        return json.dumps({
            "error": (
                "Failed to download any of the provided images. "
                f"Failed URLs: {failed_urls}"
            ),
        })

    prompt_parts.append({"text": prompt})

    try:
        img_bytes = await _call_gemini_native(
            prompt_parts, api_key, aspect_ratio,
        )
        if not img_bytes:
            return json.dumps({
                "error": "No image was generated.",
            })

        from PIL import Image

        def _convert_to_png(data: bytes) -> bytes:
            """Re-encode arbitrary image bytes as PNG."""
            with Image.open(BytesIO(data)) as img:
                buf = BytesIO()
                img.save(buf, format="PNG")
            return buf.getvalue()

        # Image re-encoding is CPU-bound; keep it off the event loop.
        png_bytes = await asyncio.to_thread(_convert_to_png, img_bytes)

        # Content-derived name: identical output yields the same filename.
        digest = hashlib.sha256(png_bytes).hexdigest()[:16]
        fname = f"edited_{digest}.png"

        file_url = await ctx.adapter.send_file(
            ctx.channel_id, png_bytes, fname, "image/png",
        )
        ctx.sent_files.append({
            "data": png_bytes,
            "filename": fname,
            "mimetype": "image/png",
            "file_url": file_url or "",
        })

        # Count usage only after the image was actually delivered.
        if metered:
            from tools.manage_api_keys import increment_default_key_usage

            await increment_default_key_usage(
                ctx.user_id, "image_generation", redis,
            )

        result_info: dict[str, Any] = {
            "success": True,
            # Every part except the trailing text prompt is an image.
            "images_used": len(prompt_parts) - 1,
            "filename": fname,
            "result": "Image edited and sent to the channel.",
        }
        if file_url:
            result_info["file_url"] = file_url
        if failed_urls:
            result_info["warnings"] = (
                f"Failed to download {len(failed_urls)} image(s): "
                f"{failed_urls}"
            )
        return json.dumps(result_info)

    except Exception as exc:
        logger.exception("Image edit error: %s", exc)
        # The message goes back to the caller/channel. Exception text from
        # an HTTP client may include the request URL, which could carry the
        # API key, so scrub the key before returning it.
        detail = str(exc)
        if isinstance(api_key, str) and api_key:
            detail = detail.replace(api_key, "[redacted]")
        return json.dumps({
            "error": f"Image editing failed: {detail}",
        })