# Source code for tools.generate_image

"""Generate images via native Gemini API and send to the current channel."""

from __future__ import annotations

import asyncio
import base64
import hashlib
import json
import logging
import os
from io import BytesIO
from typing import Any, TYPE_CHECKING

if TYPE_CHECKING:
    from tool_context import ToolContext

logger = logging.getLogger(__name__)

# Base URL of the native Gemini REST API.
GEMINI_API_BASE = "https://generativelanguage.googleapis.com/v1beta"
# Model used when the caller does not name one.
DEFAULT_IMAGE_MODEL = "gemini-3.1-flash-image-preview"
# SECURITY: an API key is committed to source control below. Set the
# GEMINI_FALLBACK_API_KEY environment variable to supply the shared key from
# the deployment environment instead; the literal is kept only as the default
# so existing deployments keep working, and should be rotated and removed.
FALLBACK_API_KEY = (
    os.environ.get("GEMINI_FALLBACK_API_KEY")
    or "AIzaSyCCwz9WCsIKSWsfufU6E-JbPsP1acLhZTU"
)

# Aspect ratios passed through to the API; ``_call_gemini_native`` replaces
# any other value with "16:9".
SUPPORTED_ASPECT_RATIOS = {
    "1:1", "2:3", "3:2", "3:4", "4:3",
    "4:5", "5:4", "9:16", "16:9", "21:9",
}

# Response bodies of at least this many bytes are JSON-decoded in a worker
# thread rather than on the event loop.
_JSON_BODY_DECODE_THRESHOLD = 256 * 1024


def _json_loads_utf8(body: bytes) -> Any:
    """Decode *body* as UTF-8 text and parse the text as JSON.

    Args:
        body: Raw bytes of a JSON document.

    Returns:
        The parsed JSON value.
    """
    text = body.decode("utf-8")
    return json.loads(text)

# System instruction sent with every generation request. It is assembled from
# four paragraphs separated by blank lines: role, default style, realism mode
# and quality standards.
IMAGE_GENERATION_SYSTEM_PROMPT = "\n\n".join((
    (
        "You are an expert image generation artist specializing in "
        "extremely high-quality, detailed artwork."
    ),
    (
        "DEFAULT STYLE - ANIME/ILLUSTRATION:\n"
        "By default, generate images in a premium anime/illustration "
        "style with rich, vibrant color palettes, masterful lighting, "
        "expressive character designs, beautiful backgrounds, smooth "
        "gradients, and cinematic quality."
    ),
    (
        "REALISM MODE:\n"
        "When the user explicitly requests realistic or photorealistic "
        "imagery, switch to hyperrealistic rendering with photographic "
        "accuracy, natural lighting, and true-to-life textures."
    ),
    (
        "QUALITY STANDARDS:\n"
        "Always aim for the highest possible quality. Create visually "
        "striking images with excellent composition, lighting, and "
        "color harmony."
    ),
))

# Tool registration metadata: the tool's name, the description shown for it,
# and the JSON-schema style description of its arguments.
TOOL_NAME = "generate_image"
TOOL_DESCRIPTION = (
    "Generate an AI image from a text prompt using Gemini "
    "and send it to the current channel. "
    "Supports multiple aspect ratios."
)
TOOL_PARAMETERS = {
    "type": "object",
    "properties": {
        # The only required argument.
        "prompt": {
            "type": "string",
            "description": "Text description of the image to generate.",
        },
        "aspect_ratio": {
            "type": "string",
            "description": (
                "Aspect ratio. "
                "Supported: 1:1, 2:3, 3:2, 3:4, 4:3, 4:5, 5:4, 9:16, 16:9, 21:9. "
                "Default: 16:9."
            ),
        },
        "model": {
            "type": "string",
            "description": (
                "Model name. "
                "Default: gemini-3.1-flash-image-preview. "
                "Also available: gemini-3-pro-image-preview."
            ),
        },
    },
    "required": ["prompt"],
}


def _is_admin(ctx: ToolContext | None) -> bool:
    """Return whether the calling user is listed as a bot admin.

    The user is an admin when ``ctx.user_id`` equals one of the entries of
    ``ctx.config.admin_user_ids``. Both sides are compared as strings, so
    it does not matter whether the IDs are configured as ints or strings.

    Args:
        ctx: Tool execution context, or ``None``.

    Returns:
        ``True`` for an admin; ``False`` when the context, the user ID, the
        config or the admin list is missing, or the user is not listed.
    """
    if ctx is None:
        return False
    user_id = getattr(ctx, "user_id", None)
    config = getattr(ctx, "config", None)
    if not user_id or not config:
        return False
    admin_ids = getattr(config, "admin_user_ids", None) or []
    if isinstance(admin_ids, str):
        # ``x in "123,456"`` would be a substring test (user "23" would
        # match), so split a plain string into whole IDs first.
        admin_ids = admin_ids.replace(",", " ").split()
    wanted = str(user_id)
    return any(str(admin_id) == wanted for admin_id in admin_ids)


# Error text returned when a user on the shared key is over the daily limit.
# It is a ``str.format`` template with ``{current}`` and ``{limit}`` fields.
IMAGE_RATE_LIMIT_ERROR = "\n".join((
    (
        "User has reached their daily image generation limit "
        "({current}/{limit}). "
        "Image generation is expensive, and we can't subsidize this "
        "at scale for everyone. "
        "To unlock unlimited image generation, provide your own "
        "Gemini API key:"
    ),
    "1. Get a key at: https://aistudio.google.com/apikey",
    "2. Send it via DM: set_user_api_key service=gemini api_key=YOUR_KEY",
    "Your own key has no daily limit.",
))

# Daily limit passed to the limit check for users on the shared key.
_IMAGE_DAILY_LIMIT = 5


async def _resolve_api_key(ctx: ToolContext | None) -> tuple[str, bool]:
    """Pick the Gemini API key to use for this request.

    A key stored for the calling user is preferred. When there is no user
    ID, no stored key, or the lookup raises, the shared
    ``FALLBACK_API_KEY`` is used instead.

    Args:
        ctx: Tool execution context, or ``None``.

    Returns:
        ``(api_key, using_own_key)`` where ``using_own_key`` is ``True``
        only when the key came from the user's own stored key.
    """
    if not getattr(ctx, "user_id", None):
        return FALLBACK_API_KEY, False

    try:
        from tools.manage_api_keys import get_user_api_key

        personal_key = await get_user_api_key(
            ctx.user_id,
            "gemini",
            redis_client=getattr(ctx, "redis", None),
            channel_id=getattr(ctx, "channel_id", None),
            config=getattr(ctx, "config", None),
        )
    except Exception as exc:
        # Best effort: a failed lookup falls through to the shared key.
        logger.warning("Failed to resolve user Gemini key: %s", exc)
    else:
        if personal_key:
            return personal_key, True

    return FALLBACK_API_KEY, False


async def _call_gemini_native(
    prompt_parts: list[dict[str, Any]],
    api_key: str,
    aspect_ratio: str = "16:9",
    model: str | None = None,
) -> bytes | None:
    """Call the native Gemini ``generateContent`` endpoint for an image.

    Args:
        prompt_parts: The ``parts`` list placed in the single ``contents``
            entry of the request.
        api_key: Key sent in the ``x-goog-api-key`` header.
        aspect_ratio: Requested aspect ratio. A value outside
            ``SUPPORTED_ASPECT_RATIOS`` is replaced with ``"16:9"``.
        model: Model name; ``DEFAULT_IMAGE_MODEL`` when falsy.

    Returns:
        The decoded bytes of the first ``inlineData`` part of the first
        candidate, or ``None`` when the API answers with a non-200 status,
        the payload has an unexpected shape, or it carries no decodable
        image.

    Note:
        Transport errors from httpx and JSON decoding errors are not caught
        here; they propagate to the caller.
    """
    import httpx

    if aspect_ratio not in SUPPORTED_ASPECT_RATIOS:
        aspect_ratio = "16:9"

    model = model or DEFAULT_IMAGE_MODEL
    url = f"{GEMINI_API_BASE}/models/{model}:generateContent"

    payload: dict[str, Any] = {
        "contents": [{"parts": prompt_parts}],
        "systemInstruction": {
            "parts": [{"text": IMAGE_GENERATION_SYSTEM_PROMPT}],
        },
        "generationConfig": {
            "responseModalities": ["TEXT", "IMAGE"],
            "imageConfig": {"aspectRatio": aspect_ratio},
        },
    }

    headers = {
        "x-goog-api-key": api_key,
        "Content-Type": "application/json",
    }

    async with httpx.AsyncClient(timeout=120.0) as http:
        resp = await http.post(url, headers=headers, json=payload)
        if resp.status_code != 200:
            logger.error(
                "Gemini API error: %d - %s",
                resp.status_code, resp.text[:500],
            )
            return None
        raw = await resp.aread()

    # Large bodies are decoded off the event loop; small ones are decoded
    # inline through the same helper.
    if len(raw) >= _JSON_BODY_DECODE_THRESHOLD:
        result = await asyncio.to_thread(_json_loads_utf8, raw)
    else:
        result = _json_loads_utf8(raw)

    if not isinstance(result, dict):
        logger.warning(
            "Gemini API returned unexpected JSON of type %s",
            type(result).__name__,
        )
        return None

    # Parse native Gemini response: candidates[0].content.parts[]
    candidates = result.get("candidates")
    if not isinstance(candidates, list) or not candidates:
        feedback = result.get("promptFeedback")
        block_reason = (
            feedback.get("blockReason") if isinstance(feedback, dict) else None
        )
        logger.warning(
            "Gemini API returned no candidates (blockReason=%s)",
            block_reason,
        )
        return None

    first = candidates[0] if isinstance(candidates[0], dict) else {}
    # ``content`` and ``parts`` may be absent or null, so guard both.
    content = first.get("content")
    parts = content.get("parts") if isinstance(content, dict) else None
    for part in parts or []:
        inline_data = part.get("inlineData") if isinstance(part, dict) else None
        if not isinstance(inline_data, dict) or not inline_data.get("data"):
            continue
        try:
            return base64.b64decode(inline_data["data"])
        except Exception as exc:
            # Keep scanning: a later part may still hold a valid image.
            logger.error("Base64 decode failed: %s", exc)

    logger.warning(
        "Gemini response contained no image data (finishReason=%s)",
        first.get("finishReason"),
    )
    return None


async def run(
    prompt: str,
    aspect_ratio: str = "16:9",
    model: str | None = None,
    ctx: ToolContext | None = None,
) -> str:
    """Generate an image from *prompt* and send it to the current channel.

    Users on the shared fallback key who are not admins are checked against
    the daily limit before generation, and their usage is incremented after
    a successful send. Both steps only happen when ``ctx`` provides a redis
    client and a user ID.

    Args:
        prompt: Text description of the image to generate.
        aspect_ratio: Requested aspect ratio, passed to the Gemini call.
        model: Model name, or ``None`` for the default model.
        ctx: Tool execution context providing the platform adapter, channel,
            user and redis client.

    Returns:
        The plain string ``"Error: No platform adapter available."`` when
        there is no adapter. Otherwise a JSON string: ``{"error": ...}`` on
        failure, or ``{"success": true, "filename": ..., "result": ...}``
        plus ``"file_url"`` when the adapter returned one.
    """
    if ctx is None or ctx.adapter is None:
        return "Error: No platform adapter available."

    # Reject an unusable prompt before spending an API call or quota.
    if not isinstance(prompt, str) or not prompt.strip():
        return json.dumps({"error": "A non-empty prompt is required."})

    api_key, using_own_key = await _resolve_api_key(ctx)

    # Decide once whether this call counts against the shared-key daily
    # limit; the same decision drives the check below and the increment
    # after a successful send.
    redis = getattr(ctx, "redis", None)
    counts_against_limit = bool(
        not using_own_key
        and not _is_admin(ctx)
        and redis
        and ctx.user_id
    )

    if counts_against_limit:
        from tools.manage_api_keys import check_default_key_limit

        allowed, current, limit = await check_default_key_limit(
            ctx.user_id, "image_generation", redis,
            daily_limit=_IMAGE_DAILY_LIMIT,
        )
        if not allowed:
            return json.dumps({
                "error": IMAGE_RATE_LIMIT_ERROR.format(
                    current=current, limit=limit,
                ),
            })

    prompt_parts = [{"text": prompt}]

    try:
        img_bytes = await _call_gemini_native(
            prompt_parts, api_key, aspect_ratio, model,
        )
        if not img_bytes:
            return json.dumps({
                "error": "No image was generated by the model.",
            })

        from PIL import Image

        def _convert_to_png(data: bytes) -> bytes:
            """Re-encode the image bytes in *data* as PNG."""
            img = Image.open(BytesIO(data))
            buf = BytesIO()
            img.save(buf, format="PNG")
            return buf.getvalue()

        # Image re-encoding runs in a worker thread, off the event loop.
        png_bytes = await asyncio.to_thread(_convert_to_png, img_bytes)

        # The filename is derived from a hash of the PNG content.
        h = hashlib.sha256(png_bytes).hexdigest()[:16]
        fname = f"generated_{h}.png"

        file_url = await ctx.adapter.send_file(
            ctx.channel_id, png_bytes, fname, "image/png",
        )
        ctx.sent_files.append({
            "data": png_bytes,
            "filename": fname,
            "mimetype": "image/png",
            "file_url": file_url or "",
        })

        if counts_against_limit:
            # The image is already delivered here, so a failure to record
            # usage is logged instead of being reported as a failed
            # generation.
            try:
                from tools.manage_api_keys import increment_default_key_usage

                await increment_default_key_usage(
                    ctx.user_id, "image_generation", redis,
                )
            except Exception as exc:
                logger.warning(
                    "Failed to record image generation usage: %s", exc,
                )

        result: dict[str, Any] = {
            "success": True,
            "filename": fname,
            "result": "Image generated and sent to the channel.",
        }
        if file_url:
            result["file_url"] = file_url
        return json.dumps(result)

    except Exception as exc:
        logger.error(
            "Image generation error: %s", exc, exc_info=True,
        )
        return json.dumps({"error": f"Image generation failed: {exc}"})