"""Generate images via native Gemini API and send to the current channel."""
from __future__ import annotations
import asyncio
import base64
import hashlib
import json
import logging
from io import BytesIO
from typing import Any, TYPE_CHECKING
if TYPE_CHECKING:
from tool_context import ToolContext
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Native Gemini API endpoint (base URL; the per-model path is appended when
# the request is built).
GEMINI_API_BASE = "https://generativelanguage.googleapis.com/v1beta"
# Model used when the caller does not supply one.
DEFAULT_IMAGE_MODEL = "gemini-3.1-flash-image-preview"
# Shared key returned by _resolve_api_key when no per-user key is found.
# NOTE(review): this is a credential hard-coded in source. Consider loading
# it from configuration / a secret store and rotating the committed value.
FALLBACK_API_KEY = "AIzaSyCCwz9WCsIKSWsfufU6E-JbPsP1acLhZTU"
# Aspect ratios passed through unchanged; any other value is replaced with
# "16:9" before the request payload is built.
SUPPORTED_ASPECT_RATIOS = {
"1:1", "2:3", "3:2", "3:4", "4:3",
"4:5", "5:4", "9:16", "16:9", "21:9",
}
# Response bodies of at least this many bytes (256 KiB) are JSON-decoded in a
# worker thread via asyncio.to_thread rather than inline.
_JSON_BODY_DECODE_THRESHOLD = 256 * 1024
def _json_loads_utf8(body: bytes) -> Any:
    """Decode *body* as UTF-8 text and parse the text as JSON.

    Kept as a plain synchronous function so it can be handed to
    ``asyncio.to_thread`` for large payloads.
    """
    text = body.decode("utf-8")
    return json.loads(text)
# System instruction text; it is placed in the "systemInstruction" field of
# every generateContent request built by this module. It sets an
# anime/illustration default style and a realism mode on explicit request.
IMAGE_GENERATION_SYSTEM_PROMPT = (
"You are an expert image generation artist specializing in "
"extremely high-quality, detailed artwork.\n\n"
"DEFAULT STYLE - ANIME/ILLUSTRATION:\n"
"By default, generate images in a premium anime/illustration "
"style with rich, vibrant color palettes, masterful lighting, "
"expressive character designs, beautiful backgrounds, smooth "
"gradients, and cinematic quality.\n\n"
"REALISM MODE:\n"
"When the user explicitly requests realistic or photorealistic "
"imagery, switch to hyperrealistic rendering with photographic "
"accuracy, natural lighting, and true-to-life textures.\n\n"
"QUALITY STANDARDS:\n"
"Always aim for the highest possible quality. Create visually "
"striking images with excellent composition, lighting, and "
"color harmony."
)
# Tool metadata: name, human-readable description, and a JSON-Schema-style
# parameter spec mirroring the keyword arguments of run().
# NOTE(review): presumably read by the tool loader/registry, which is not
# visible in this file - confirm.
TOOL_NAME = "generate_image"
TOOL_DESCRIPTION = (
"Generate an AI image from a text prompt using Gemini and "
"send it to the current channel. Supports multiple "
"aspect ratios."
)
# Only "prompt" is required; "aspect_ratio" and "model" are optional.
TOOL_PARAMETERS = {
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": (
"Text description of the image to generate."
),
},
"aspect_ratio": {
"type": "string",
"description": (
"Aspect ratio. Supported: 1:1, 2:3, 3:2, 3:4, "
"4:3, 4:5, 5:4, 9:16, 16:9, 21:9. "
"Default: 16:9."
),
},
"model": {
"type": "string",
"description": (
"Model name. Default: gemini-3.1-flash-image-preview. "
"Also available: gemini-3-pro-image-preview."
),
},
},
"required": ["prompt"],
}
def _is_admin(ctx: ToolContext | None) -> bool:
    """Check whether the calling user is a bot admin.

    Args:
        ctx: Tool execution context, or ``None``. Only ``ctx.user_id`` and
            ``ctx.config.admin_user_ids`` are read; a missing or empty
            value for either means "not an admin".

    Returns:
        ``True`` when the user's ID, compared as a string, equals one of the
        configured admin IDs; ``False`` otherwise.
    """
    if ctx is None:
        return False
    user_id = getattr(ctx, "user_id", None)
    config = getattr(ctx, "config", None)
    if not user_id or not config:
        return False
    admin_ids = getattr(config, "admin_user_ids", None) or []
    if isinstance(admin_ids, (str, int)):
        # A lone ID instead of a collection: treat it as a single entry.
        # ``in`` on a bare string would do substring matching (letting "123"
        # pass against "12345"), and on a bare int it would raise TypeError.
        admin_ids = [admin_ids]
    # Compare string forms on both sides so IDs configured as integers
    # still match.
    return str(user_id) in {str(admin_id) for admin_id in admin_ids}
# Error text returned (inside a JSON "error" field) when the daily-limit
# check denies a request; {current} and {limit} are filled in with
# str.format at the call site.
IMAGE_RATE_LIMIT_ERROR = (
"User has reached their daily image generation limit ({current}/{limit}). "
"Image generation is expensive, and we can't subsidize this at scale "
"for everyone. To unlock unlimited image generation, provide your own "
"Gemini API key:\n"
"1. Get a key at: https://aistudio.google.com/apikey\n"
"2. Send it via DM: set_user_api_key service=gemini api_key=YOUR_KEY\n"
"Your own key has no daily limit."
)
# Passed as daily_limit to check_default_key_limit for users on the shared key.
_IMAGE_DAILY_LIMIT = 5
async def _resolve_api_key(ctx: ToolContext | None) -> tuple[str, bool]:
    """Pick the Gemini API key for this call: the user's own key, else the shared one.

    When the context carries a user ID, the user's stored "gemini" key is
    looked up through ``tools.manage_api_keys.get_user_api_key``. Any failure
    during that lookup (including the import itself) is logged as a warning
    and the shared fallback key is used instead.

    Returns:
        ``(api_key, using_own_key)``: ``using_own_key`` is ``True`` only when
        a per-user key was found.
    """
    shared = (FALLBACK_API_KEY, False)

    # No context or no user ID: nothing to look up.
    if ctx is None or not getattr(ctx, "user_id", None):
        return shared

    try:
        from tools.manage_api_keys import get_user_api_key

        personal_key = await get_user_api_key(
            ctx.user_id,
            "gemini",
            redis_client=getattr(ctx, "redis", None),
            channel_id=getattr(ctx, "channel_id", None),
            config=getattr(ctx, "config", None),
        )
    except Exception as err:
        # Best effort: a broken key store must not block image generation.
        logger.warning("Failed to resolve user Gemini key: %s", err)
        return shared

    if personal_key:
        return personal_key, True
    return shared
def _extract_inline_image(result: Any) -> bytes | None:
    """Return the first decodable inline image of a generateContent response.

    Only the first candidate is inspected. Missing, null, or unexpectedly
    typed fields are treated as "no image" instead of raising.
    """
    if not isinstance(result, dict):
        logger.warning("Gemini API returned an unexpected response shape")
        return None
    candidates = result.get("candidates") or []
    if not candidates:
        logger.warning("Gemini API returned no candidates")
        return None

    first = candidates[0] if isinstance(candidates[0], dict) else {}
    # ``content`` / ``parts`` can be absent or null; fall back to empty values
    # so the loop below simply finds nothing.
    content = first.get("content")
    parts = content.get("parts") if isinstance(content, dict) else None
    for part in parts or []:
        inline_data = part.get("inlineData") if isinstance(part, dict) else None
        if not isinstance(inline_data, dict):
            continue
        encoded = inline_data.get("data")
        if not encoded:
            continue
        try:
            return base64.b64decode(encoded)
        except (ValueError, TypeError) as exc:
            # Keep scanning: a later part may still hold a valid image.
            logger.error("Base64 decode failed: %s", exc)

    logger.warning(
        "Gemini API response contained no image data (finishReason=%s)",
        first.get("finishReason"),
    )
    return None


async def _call_gemini_native(
    prompt_parts: list[dict[str, Any]],
    api_key: str,
    aspect_ratio: str = "16:9",
    model: str | None = None,
) -> bytes | None:
    """Call the native Gemini ``generateContent`` endpoint and return image bytes.

    Args:
        prompt_parts: Content parts sent as the single entry of ``contents``.
        api_key: Key sent in the ``x-goog-api-key`` header.
        aspect_ratio: Requested ratio; values outside
            ``SUPPORTED_ASPECT_RATIOS`` are replaced with ``"16:9"``.
        model: Model name; ``DEFAULT_IMAGE_MODEL`` when falsy.

    Returns:
        The decoded bytes of the first inline image in the first candidate,
        or ``None`` when the API answers with a non-200 status or the
        response carries no decodable image.

    Raises:
        Transport errors from ``httpx`` and JSON/UTF-8 decoding errors are
        not caught here and propagate to the caller.
    """
    import httpx
    from urllib.parse import quote

    if aspect_ratio not in SUPPORTED_ASPECT_RATIOS:
        aspect_ratio = "16:9"
    model = model or DEFAULT_IMAGE_MODEL
    # The model name is supplied by the tool caller; percent-encode it so it
    # stays a single path segment and cannot inject "/", "?" or "#" into the
    # request URL.
    model_segment = quote(str(model), safe="")
    url = f"{GEMINI_API_BASE}/models/{model_segment}:generateContent"
    payload: dict[str, Any] = {
        "contents": [{"parts": prompt_parts}],
        "systemInstruction": {
            "parts": [{"text": IMAGE_GENERATION_SYSTEM_PROMPT}],
        },
        "generationConfig": {
            "responseModalities": ["TEXT", "IMAGE"],
            "imageConfig": {"aspectRatio": aspect_ratio},
        },
    }
    headers = {
        "x-goog-api-key": api_key,
        "Content-Type": "application/json",
    }
    async with httpx.AsyncClient(timeout=120.0) as http:
        resp = await http.post(url, headers=headers, json=payload)
        if resp.status_code != 200:
            logger.error(
                "Gemini API error: %d - %s",
                resp.status_code, resp.text[:500],
            )
            return None
        raw = await resp.aread()

    # Image responses embed base64 data and can be large; parse big bodies in
    # a worker thread so the event loop is not blocked.
    if len(raw) >= _JSON_BODY_DECODE_THRESHOLD:
        result = await asyncio.to_thread(_json_loads_utf8, raw)
    else:
        result = _json_loads_utf8(raw)

    # Native Gemini response layout: candidates[0].content.parts[]
    return _extract_inline_image(result)
# (Stray "[docs]" link marker from the rendered documentation page removed.)
async def run(
    prompt: str,
    aspect_ratio: str = "16:9",
    model: str | None = None,
    ctx: ToolContext | None = None,
) -> str:
    """Generate an image from *prompt* and send it to the current channel.

    The API key is resolved first. Non-admin users on the shared key are
    checked against the daily limit when both a Redis handle and a user ID
    are available. The generated image is re-encoded as PNG, sent through
    ``ctx.adapter.send_file`` and recorded in ``ctx.sent_files``. Usage is
    counted only after a successful send.

    Args:
        prompt: Text description of the image to generate.
        aspect_ratio: Requested aspect ratio, forwarded to the Gemini call.
        model: Optional model name, forwarded to the Gemini call.
        ctx: Tool execution context providing the adapter, channel, user,
            Redis handle and config.

    Returns:
        A JSON string. On success it contains ``success``, ``filename``,
        ``result`` and, when the adapter returned one, ``file_url``. On
        failure it contains a single ``error`` field. The one exception is a
        missing context/adapter, which yields a plain (non-JSON) error string.
    """
    if ctx is None or ctx.adapter is None:
        return "Error: No platform adapter available."

    api_key, using_own_key = await _resolve_api_key(ctx)

    # Decide once whether this call is metered: only non-admin users on the
    # shared key, and only when Redis and a user ID are both available.
    quota_redis = None
    if not using_own_key and not _is_admin(ctx):
        redis = getattr(ctx, "redis", None)
        if redis and ctx.user_id:
            quota_redis = redis

    if quota_redis is not None:
        from tools.manage_api_keys import check_default_key_limit

        allowed, current, limit = await check_default_key_limit(
            ctx.user_id, "image_generation", quota_redis,
            daily_limit=_IMAGE_DAILY_LIMIT,
        )
        if not allowed:
            return json.dumps({
                "error": IMAGE_RATE_LIMIT_ERROR.format(
                    current=current, limit=limit,
                ),
            })

    prompt_parts = [{"text": prompt}]
    try:
        img_bytes = await _call_gemini_native(
            prompt_parts, api_key, aspect_ratio, model,
        )
        if not img_bytes:
            return json.dumps({
                "error": "No image was generated by the model.",
            })

        from PIL import Image

        def _convert_to_png(data: bytes) -> bytes:
            """Re-encode arbitrary image bytes as PNG."""
            # The context manager releases the decoder's resources even if
            # saving fails.
            with Image.open(BytesIO(data)) as img:
                buf = BytesIO()
                img.save(buf, format="PNG")
            return buf.getvalue()

        # Decoding/encoding is CPU-bound; keep it off the event loop.
        png_bytes = await asyncio.to_thread(_convert_to_png, img_bytes)

        # Content-derived name: identical images get identical filenames.
        h = hashlib.sha256(png_bytes).hexdigest()[:16]
        fname = f"generated_{h}.png"
        file_url = await ctx.adapter.send_file(
            ctx.channel_id, png_bytes, fname, "image/png",
        )
        ctx.sent_files.append({
            "data": png_bytes,
            "filename": fname,
            "mimetype": "image/png",
            "file_url": file_url or "",
        })

        if quota_redis is not None:
            # The image has already been delivered: a bookkeeping failure
            # must not turn this call into a reported generation error.
            try:
                from tools.manage_api_keys import increment_default_key_usage

                await increment_default_key_usage(
                    ctx.user_id, "image_generation", quota_redis,
                )
            except Exception as exc:
                logger.warning(
                    "Failed to record image generation usage: %s", exc,
                )

        result: dict[str, Any] = {
            "success": True,
            "filename": fname,
            "result": "Image generated and sent to the channel.",
        }
        if file_url:
            result["file_url"] = file_url
        return json.dumps(result)
    except Exception as exc:
        # Tool boundary: report the failure to the caller instead of raising.
        logger.error(
            "Image generation error: %s", exc, exc_info=True,
        )
        return json.dumps({"error": f"Image generation failed: {exc}"})