Source code for tools.comfyui_generate_image

"""Image generation via Civitai API.

Generates AI images using Civitai's generation service, saves them to
temporary files, and returns a JSON string listing the saved file paths
(or describing the error).
"""

from __future__ import annotations

import asyncio
import json
import logging
import os
import random
import tempfile
from typing import Optional, Dict, Any, TYPE_CHECKING

import aiofiles
import aiohttp

from tools._safe_http import assert_safe_http_url

if TYPE_CHECKING:
    from tool_context import ToolContext

logger = logging.getLogger(__name__)

# Checkpoint used when the caller does not pass ``model_urn``.
DEFAULT_MODEL_URN = "urn:air:sdxl:checkpoint:civitai:1277670@1896532"

# Lowercase scheduler aliases -> scheduler identifiers sent to Civitai.
# Lookups are done on the lowercased caller value; unknown names are
# expected to be forwarded unchanged by the caller.
SCHEDULER_MAP = {
    # Euler / Heun / LMS
    "euler": "EulerA",
    "euler_a": "EulerA",
    "eulera": "EulerA",
    "heun": "Heun",
    "lms": "LMS",
    "lms_karras": "LMSKarras",
    # DPM2
    "dpm2": "DPM2",
    "dpm2_a": "DPM2A",
    "dpm2_a_karras": "DPM2AKarras",
    # DPM++
    "dpmpp_2m": "DPM2M",
    "dpmpp_2m_karras": "DPM2MKarras",
    "dpmpp_sde": "DPMSDE",
    "dpmpp_sde_karras": "DPMSDEKarras",
    "dpm_fast": "DPMFast",
    "dpm_adaptive": "DPMAdaptive",
    # Others
    "ddim": "DDIM",
    "plms": "PLMS",
    "unipc": "UniPC",
    "lcm": "LCM",
    "ddpm": "DDPM",
    "deis": "DEIS",
}


async def _download_image(url_or_blob: str, api_token: str = "") -> Optional[bytes]:
    """Download a generated image from Civitai.

    Two kinds of reference are accepted:

    * A 64-character hexadecimal blob key. Three endpoints are tried in
      order: two ``image.civitai.com`` renditions (``width=1024`` and
      ``original=true``), then ``civitai.com/api/download/blobs`` with a
      bearer token. The first body that starts with a PNG or JPEG signature
      is returned.
    * Anything else is treated as an HTTP(S) URL. It is validated with
      ``assert_safe_http_url`` and fetched directly. Non-200 responses and
      HTML bodies are rejected.

    Args:
        url_or_blob: Hex blob key or HTTP(S) URL of the image. Surrounding
            whitespace is ignored.
        api_token: Civitai API token for the blob-download endpoint. When
            empty, ``CIVITAI_API_TOKEN`` from the environment is used.

    Returns:
        The raw image bytes, or ``None`` when nothing usable was downloaded.
        Failures are logged rather than raised.
    """

    def _looks_like_html(payload: bytes) -> bool:
        # Error/login pages can come back with HTTP 200; never save them as images.
        return payload[:64].lstrip().lower().startswith((b"<!doctype", b"<html"))

    timeout = aiohttp.ClientTimeout(total=60)
    try:
        ref = url_or_blob.strip()
        if len(ref) == 64 and all(c in "0123456789ABCDEFabcdef" for c in ref):
            token = api_token or os.getenv("CIVITAI_API_TOKEN", "")
            auth_headers = {"Authorization": f"Bearer {token}"} if token else {}
            candidates = [
                (f"https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/{ref}/width=1024", {}),
                (f"https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/{ref}/original=true", {}),
                (f"https://civitai.com/api/download/blobs/{ref}", auth_headers),
            ]
            async with aiohttp.ClientSession() as session:
                for blob_url, headers in candidates:
                    try:
                        async with session.get(blob_url, headers=headers, timeout=timeout) as response:
                            if response.status != 200:
                                continue
                            data = await response.read()
                    except Exception as exc:
                        # Best effort: a failing endpoint just means "try the next one".
                        logger.debug("Civitai blob fetch failed (%s): %s", blob_url, exc)
                        continue
                    # Only accept bodies carrying a PNG or JPEG signature.
                    if data.startswith((b"\x89PNG", b"\xff\xd8\xff")):
                        return data
            return None

        try:
            safe_url = assert_safe_http_url(ref)
        except ValueError as exc:
            logger.error("Blocked Civitai download URL: %s", exc)
            return None
        async with aiohttp.ClientSession() as session:
            async with session.get(safe_url, timeout=timeout) as response:
                if response.status != 200:
                    logger.warning("Civitai image download failed with HTTP %s", response.status)
                    return None
                data = await response.read()
        if _looks_like_html(data):
            logger.warning("Civitai image download returned an HTML page, not an image")
            return None
        return data
    except Exception as e:
        logger.error("Error downloading image: %s", e)
        return None


async def _poll_job_status(job_token: str, max_wait: int = 300) -> Optional[Dict[str, Any]]:
    """Poll Civitai until a generation job has a usable result.

    ``civitai.jobs.get(token=job_token)`` is called every 2 seconds. The
    response may carry a ``jobs`` list. If it has no jobs but does have a
    ``result`` or ``jobId`` key, the response itself is treated as the job.

    A job counts as ready when its ``result`` is either:

    * a non-empty list of dicts that all have a truthy ``available``, or
    * a dict with a truthy ``available`` or ``blobUrl``.

    Args:
        job_token: Token returned by ``civitai.image.create``. An empty token
            yields ``None`` immediately.
        max_wait: Overall time budget in seconds. Each individual status
            request is also bounded by the time remaining.

    Returns:
        ``{'token': <response token>, 'jobs': [<first ready job>]}``, or
        ``None`` if nothing became ready within ``max_wait``. SDK errors are
        logged and the poll is retried.
    """
    import civitai

    def _is_ready(job: Dict[str, Any]) -> bool:
        result = job.get("result")
        if isinstance(result, list):
            return bool(result) and all(
                isinstance(r, dict) and r.get("available", False) for r in result
            )
        if isinstance(result, dict):
            return bool(result.get("available", False) or result.get("blobUrl"))
        return False

    if not job_token:
        return None

    poll_interval = 2
    loop = asyncio.get_running_loop()
    deadline = loop.time() + max_wait
    while True:
        remaining = deadline - loop.time()
        if remaining < 0:
            return None
        try:
            # Bound each request by the time left so one hung call cannot
            # stretch the overall wait past max_wait.
            job_response = await asyncio.wait_for(
                civitai.jobs.get(token=job_token), timeout=remaining
            )
            if job_response:
                jobs = job_response.get("jobs") or []
                if not jobs and ("result" in job_response or "jobId" in job_response):
                    # Single-job payload rather than a {'jobs': [...]} wrapper.
                    jobs = [job_response]
                ready = next((job for job in jobs if _is_ready(job)), None)
                if ready is not None:
                    return {"token": job_response.get("token"), "jobs": [ready]}
        except Exception as exc:
            logger.warning("Error polling Civitai job status: %r", exc)
        await asyncio.sleep(poll_interval)


TOOL_NAME = "civitai_generate_image"

# Human/LLM-facing summary of the tool, one sentence per list entry.
TOOL_DESCRIPTION = " ".join(
    [
        "Generate AI images using Civitai's generation service.",
        "Returns the file path to the generated image.",
        "Supports various models, schedulers, and generation parameters.",
    ]
)

# JSON-Schema-style description of the keyword arguments accepted by ``run``.
TOOL_PARAMETERS = {
    "type": "object",
    "properties": {
        "prompt": {
            "type": "string",
            "description": "Text description of the image to generate.",
        },
        "negative_prompt": {
            "type": "string",
            "description": "What to avoid in the image.",
        },
        "width": {
            "type": "integer",
            "description": "Image width in pixels (64-2048).",
            "default": 1024,
        },
        "height": {
            "type": "integer",
            "description": "Image height in pixels (64-2048).",
            "default": 1024,
        },
        "steps": {
            "type": "integer",
            "description": "Sampling steps (1-100).",
            "default": 30,
        },
        "cfg": {
            "type": "number",
            "description": "CFG scale (1.0-20.0).",
            "default": 4.0,
        },
        "seed": {
            "type": "integer",
            "description": "Random seed (-1 for random).",
        },
        "scheduler": {
            "type": "string",
            "description": "Sampling scheduler.",
            "default": "DPM2MKarras",
        },
        "model_urn": {
            "type": "string",
            "description": "Civitai model URN.",
        },
        "clip_skip": {
            "type": "integer",
            "description": "CLIP layers to skip.",
            "default": 2,
        },
    },
    "required": ["prompt"],
}


def _coerce_int(value: Any, low: int, high: int) -> int:
    """Convert *value* to ``int`` and clamp it into ``[low, high]``."""
    return max(low, min(high, int(value)))


def _coerce_float(value: Any, low: float, high: float) -> float:
    """Convert *value* to ``float`` and clamp it into ``[low, high]``."""
    return max(low, min(high, float(value)))


def _collect_result_items(job_result: Dict[str, Any]) -> list[Any]:
    """Flatten the ``result`` entries (dict or list) of every polled job."""
    items: list[Any] = []
    for job in job_result.get("jobs", []):
        job_res = job.get("result")
        if isinstance(job_res, dict):
            items.append(job_res)
        elif isinstance(job_res, list):
            items.extend(job_res)
    return items


def _image_extension(data: bytes) -> str:
    """Pick a file extension from the image's magic bytes (``.png`` by default)."""
    if data.startswith(b"\xff\xd8\xff"):
        return ".jpg"
    if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
        return ".webp"
    return ".png"


async def _save_result_images(result_items: list[Any], seed: int, api_token: str) -> list[str]:
    """Download each available result item and write it into one temp directory.

    Returns the paths of the files that were written, in result order.
    """
    saved_files: list[str] = []
    tmp_dir: Optional[str] = None
    for index, item in enumerate(result_items):
        if isinstance(item, dict):
            if not item.get("available", True):
                continue
            image_ref = item.get("blobUrl") or item.get("url") or item.get("blobKey")
        else:
            # Non-dict entries are handed to the downloader as-is
            # (presumably a URL or blob key string).
            image_ref = item
        if not image_ref:
            continue
        image_data = await _download_image(image_ref, api_token=api_token)
        if not image_data:
            continue
        if tmp_dir is None:
            # One directory per call rather than one per image.
            tmp_dir = tempfile.mkdtemp(prefix="civitai_")
        filepath = os.path.join(
            tmp_dir, f"civitai_{seed}_{index}{_image_extension(image_data)}"
        )
        async with aiofiles.open(filepath, "wb") as fh:
            await fh.write(image_data)
        saved_files.append(filepath)
    return saved_files


async def run(
    prompt: str,
    negative_prompt: str | None = None,
    width: int = 1024,
    height: int = 1024,
    steps: int = 30,
    cfg: float = 4.0,
    seed: int | None = None,
    scheduler: str = "DPM2MKarras",
    model_urn: str | None = None,
    clip_skip: int = 2,
    ctx: ToolContext | None = None,
) -> str:
    """Generate images on Civitai and save them to temporary files.

    Args:
        prompt: Text description of the image; must be non-blank.
        negative_prompt: Optional description of what to avoid.
        width: Image width in pixels, coerced to ``int`` and clamped to 64-2048.
        height: Image height in pixels, coerced to ``int`` and clamped to 64-2048.
        steps: Sampling steps, coerced to ``int`` and clamped to 1-100.
        cfg: CFG scale, coerced to ``float`` and clamped to 1.0-20.0.
        seed: Generation seed. ``None`` or any negative value (the tool
            schema documents ``-1``) selects a random 32-bit seed.
        scheduler: Scheduler name. Lowercase aliases found in
            ``SCHEDULER_MAP`` are translated; anything else is sent as given.
        model_urn: Civitai model URN. Defaults to ``DEFAULT_MODEL_URN``.
        clip_skip: Number of CLIP layers to skip, sent unchanged.
        ctx: Tool execution context. When it provides ``redis`` and
            ``user_id``, the user's stored Civitai key is looked up and
            temporarily installed as ``CIVITAI_API_TOKEN`` for this call.

    Returns:
        A JSON string. On failure it is ``{"error": ...}``. On success it is
        ``{"success": true, "message", "files", "seed", "parameters"}``,
        where ``files`` lists the saved image paths.
    """
    if not prompt or not prompt.strip():
        return json.dumps({"error": "Prompt is required"})

    import civitai

    user_civitai_key = None
    if ctx and ctx.redis and ctx.user_id:
        from tools.manage_api_keys import get_user_api_key

        user_civitai_key = await get_user_api_key(
            ctx.user_id,
            "civitai",
            redis_client=ctx.redis,
            channel_id=ctx.channel_id,
            config=getattr(ctx, "config", None),
        )
    if not user_civitai_key and not os.getenv("CIVITAI_API_TOKEN"):
        from tools.manage_api_keys import missing_api_key_error

        return json.dumps({"error": missing_api_key_error("civitai")})

    # Normalise parameters *before* touching os.environ, so bad input can
    # never leave a user's key behind in the process environment.
    try:
        width = _coerce_int(width, 64, 2048)
        height = _coerce_int(height, 64, 2048)
        steps = _coerce_int(steps, 1, 100)
        cfg = _coerce_float(cfg, 1.0, 20.0)
        seed = None if seed is None else int(seed)
    except (TypeError, ValueError) as exc:
        return json.dumps({"error": f"Invalid numeric parameter: {exc}"})
    scheduler_name = str(scheduler or "DPM2MKarras")
    scheduler = SCHEDULER_MAP.get(scheduler_name.lower(), scheduler_name)
    if seed is None or seed < 0:
        seed = random.randint(0, 2**32 - 1)
    model_urn = model_urn or DEFAULT_MODEL_URN

    prev_token = os.environ.get("CIVITAI_API_TOKEN")
    try:
        if user_civitai_key:
            # NOTE(review): os.environ is process-wide, so a concurrent run()
            # for another user can observe this key while we await below.
            # Passing the key to the SDK directly (if it supports that) would
            # remove the race -- confirm with the civitai SDK.
            os.environ["CIVITAI_API_TOKEN"] = user_civitai_key

        input_data = {
            "model": model_urn,
            "params": {
                "prompt": prompt,
                "scheduler": scheduler,
                "steps": steps,
                "cfgScale": cfg,
                "width": width,
                "height": height,
                "seed": seed,
                "clipSkip": clip_skip,
            },
        }
        if negative_prompt:
            input_data["params"]["negativePrompt"] = negative_prompt

        # Submit with up to 3 attempts and exponential backoff (2s, 4s).
        max_attempts = 3
        response = None
        for attempt in range(max_attempts):
            try:
                response = await civitai.image.create(input_data, wait=False)
                break
            except Exception as e:
                if attempt == max_attempts - 1:
                    return json.dumps({"error": f"Civitai API error: {str(e)[:100]}"})
                logger.warning("Civitai image.create attempt %d failed: %s", attempt + 1, e)
                await asyncio.sleep(2 ** (attempt + 1))
        if not response:
            return json.dumps({"error": "No response from Civitai API"})

        job_result = await _poll_job_status(response.get("token"), max_wait=300)
        if not job_result:
            return json.dumps({"error": "Job failed or timed out"})

        saved_files = await _save_result_images(
            _collect_result_items(job_result), seed, user_civitai_key or ""
        )
        if not saved_files:
            return json.dumps({"error": "Failed to download generated images"})

        return json.dumps({
            "success": True,
            "message": f"Generated {len(saved_files)} image(s)",
            "files": saved_files,
            "seed": seed,
            "parameters": {
                "width": width,
                "height": height,
                "steps": steps,
                "cfg": cfg,
                "prompt": prompt,
                "scheduler": scheduler,
                "model": model_urn,
            },
        })
    except Exception as e:
        logger.exception("Unexpected error: %s", e)
        return json.dumps({"error": f"Unexpected error: {str(e)}"})
    finally:
        if user_civitai_key:
            if prev_token is not None:
                os.environ["CIVITAI_API_TOKEN"] = prev_token
            else:
                os.environ.pop("CIVITAI_API_TOKEN", None)