# Source code for tools.comfyui_generate_image

"""Image generation via Civitai API.

Generates AI images using Civitai's generation service, downloads the results
to temporary files, and returns a JSON payload listing the saved file paths.
"""

from __future__ import annotations

import asyncio
import jsonutil as json
import logging
import os
import random
import tempfile
from typing import Optional, Dict, Any, TYPE_CHECKING

import aiofiles
import aiohttp
import httpx

from tools._safe_http import safe_http_request, safe_httpx_client

if TYPE_CHECKING:
    from tool_context import ToolContext

logger = logging.getLogger(__name__)

# Checkpoint used by ``run`` when the caller does not supply ``model_urn``
# (a Civitai AIR URN string).
DEFAULT_MODEL_URN = "urn:air:sdxl:checkpoint:civitai:1277670@1896532"

# Lower-case sampler aliases -> scheduler names sent to Civitai. ``run`` looks
# up ``scheduler.lower()`` here and passes unknown names through unchanged, so
# callers may use either an alias below or the Civitai name directly.
# NOTE(review): plain ``euler`` maps to ``EulerA`` (ancestral), the same as
# ``euler_a`` — confirm this is intended.
SCHEDULER_MAP = dict(
    euler="EulerA",
    euler_a="EulerA",
    eulera="EulerA",
    heun="Heun",
    lms="LMS",
    lms_karras="LMSKarras",
    dpm2="DPM2",
    dpm2_a="DPM2A",
    dpm2_a_karras="DPM2AKarras",
    dpmpp_2m="DPM2M",
    dpmpp_2m_karras="DPM2MKarras",
    dpmpp_sde="DPMSDE",
    dpmpp_sde_karras="DPMSDEKarras",
    dpm_fast="DPMFast",
    dpm_adaptive="DPMAdaptive",
    ddim="DDIM",
    plms="PLMS",
    unipc="UniPC",
    lcm="LCM",
    ddpm="DDPM",
    deis="DEIS",
)


# Civitai endpoints used to resolve a bare blob hash into image bytes.
_CIVITAI_IMAGE_CDN = "https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA"
_CIVITAI_BLOB_DOWNLOAD = "https://civitai.com/api/download/blobs"
# Leading bytes of PNG and JPEG files; anything else (e.g. an HTML error page)
# is rejected on the blob path.
_IMAGE_MAGIC_PREFIXES = (b"\x89PNG", b"\xff\xd8\xff")


def _is_blob_hash(value: str) -> bool:
    """Return True when ``value`` is exactly 64 hexadecimal characters."""
    return len(value) == 64 and all(c in "0123456789ABCDEFabcdef" for c in value)


def _civitai_blob_urls(blob_hash: str, api_token: str) -> list:
    """Return the ``(url, headers)`` candidates tried, in order, for a blob hash.

    The width-resized and original CDN variants are unauthenticated; the
    blob-download endpoint gets a ``Bearer`` header only when a token is given.
    """
    return [
        (f"{_CIVITAI_IMAGE_CDN}/{blob_hash}/width=1024", {}),
        (f"{_CIVITAI_IMAGE_CDN}/{blob_hash}/original=true", {}),
        (
            f"{_CIVITAI_BLOB_DOWNLOAD}/{blob_hash}",
            {"Authorization": f"Bearer {api_token}"} if api_token else {},
        ),
    ]


async def _download_blob(blob_hash: str, api_token: str) -> Optional[bytes]:
    """Try each Civitai URL variant for ``blob_hash``; return the first PNG/JPEG payload.

    Each variant is best-effort: HTTP errors, transport exceptions and non-image
    payloads are logged at debug level and the next variant is tried.

    Args:
        blob_hash (str): 64-character hex blob hash.
        api_token (str): Civitai token for the authenticated endpoint (may be empty).

    Returns:
        Optional[bytes]: Image bytes, or ``None`` if no variant yielded an image.
    """
    # ClientTimeout(total=60) bounds each request made through this session.
    timeout = aiohttp.ClientTimeout(total=60)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        for blob_url, headers in _civitai_blob_urls(blob_hash, api_token):
            try:
                async with session.get(blob_url, headers=headers) as response:
                    if response.status != 200:
                        logger.debug(
                            "Civitai blob URL %s returned HTTP %s",
                            blob_url,
                            response.status,
                        )
                        continue
                    data = await response.read()
            except Exception as exc:  # best-effort: fall through to the next variant
                logger.debug("Civitai blob URL %s failed: %s", blob_url, exc)
                continue
            # HTML error pages (and any other non-image body) fail this check.
            if data.startswith(_IMAGE_MAGIC_PREFIXES):
                return data
            logger.debug("Civitai blob URL %s returned a non-image payload", blob_url)
    logger.warning("No Civitai blob URL variant returned an image for %s", blob_hash)
    return None


async def _download_image(url_or_blob: str, api_token: str = "") -> Optional[bytes]:
    """Download generated image bytes from a Civitai URL or blob hash.

    Civitai returns a result that may be either a direct image URL or a 64-char
    hex blob hash. Surrounding whitespace is stripped first.

    For a blob hash, this tries the Civitai URL variants in turn: width-resized,
    original, and the authenticated blob-download endpoint. It only accepts
    payloads with a PNG or JPEG magic header.

    For anything else, it fetches through ``tools._safe_http``: the request is
    SSRF-guarded and redirects are re-validated. A blocked URL raises
    ``ValueError``, which is logged here.

    Every failure path is logged and yields ``None`` so the caller can skip that
    image. Empty or non-string input returns ``None`` immediately.

    Called by :func:`run` for each generated image after the Civitai job completes.

    Args:
        url_or_blob (str): A direct image URL or a 64-character hex blob hash.
        api_token (str): Optional Civitai token for the authenticated blob endpoint;
            falls back to the ``CIVITAI_API_TOKEN`` environment variable.

    Returns:
        Optional[bytes]: The downloaded image bytes, or ``None`` if every source
        failed. On the blob path, non-image content also yields ``None``.
    """
    if not isinstance(url_or_blob, str) or not url_or_blob.strip():
        return None
    source = url_or_blob.strip()
    try:
        if _is_blob_hash(source):
            return await _download_blob(
                source, api_token or os.getenv("CIVITAI_API_TOKEN", "")
            )
        try:
            async with safe_httpx_client(timeout=httpx.Timeout(60.0)) as client:
                response = await safe_http_request(
                    client, "GET", source, max_redirects=5
                )
                if response.status_code != 200:
                    # Log the status only: result URLs may carry signed parameters.
                    logger.warning(
                        "Civitai image download returned HTTP %s", response.status_code
                    )
                    return None
                return response.content
        except ValueError as exc:
            logger.error("Blocked Civitai download URL: %s", exc)
            return None
    except Exception as e:
        logger.error("Error downloading image: %s", e)
        return None


def _patch_civitai_generator_client(civitai_sdk: Any) -> None:
    """Point the civitai-py client singleton at the current ``CIVITAI_API_TOKEN``.

    The civitai-py client singleton keeps the token it was built with. According
    to the original design note, the SDK reads ``CIVITAI_API_TOKEN`` at import
    time, so a per-user token placed into ``os.environ`` later would otherwise be
    ignored.

    This rewrites the live client's ``api_token`` attribute (when it has one) and
    its ``Authorization`` header (when ``headers`` is a dict) in place. It does
    nothing when the module has no ``civitai`` attribute or the environment token
    is empty/unset.

    Called by :func:`run` right after importing the SDK, and by
    :func:`_poll_job_status` before it starts polling. Its only side effect is
    mutating the SDK's module-level client object.

    Args:
        civitai_sdk (Any): The imported ``civitai`` module whose ``civitai`` client
            singleton is patched.

    Returns:
        None
    """
    sdk_client = getattr(civitai_sdk, "civitai", None)
    current_token = os.environ.get("CIVITAI_API_TOKEN", "")
    if not current_token or sdk_client is None:
        return

    if hasattr(sdk_client, "api_token"):
        sdk_client.api_token = current_token

    client_headers = getattr(sdk_client, "headers", None)
    if not isinstance(client_headers, dict):
        return
    client_headers["Authorization"] = "Bearer " + current_token


def _require_civitai_generator_sdk(civitai_sdk: Any) -> None:
    """Fail fast unless ``civitai_sdk`` exposes the civitai-py generation API.

    Two PyPI packages ship a top-level ``civitai`` module, and only civitai-py
    offers image generation. The module must therefore have truthy ``image`` and
    ``jobs`` attributes, and ``image.create`` must be callable. Otherwise a
    ``RuntimeError`` names the right package to install, instead of letting an
    ``AttributeError`` surface halfway through a generation.

    Called once by :func:`run` immediately after importing the SDK. This only
    inspects attributes; it has no side effects apart from raising.

    Args:
        civitai_sdk (Any): The imported ``civitai`` module to validate.

    Returns:
        None

    Raises:
        RuntimeError: If the module lacks the civitai-py generator surface.
    """
    image_api = getattr(civitai_sdk, "image", None)
    jobs_api = getattr(civitai_sdk, "jobs", None)
    looks_like_generator_sdk = (
        bool(image_api)
        and bool(jobs_api)
        and callable(getattr(image_api, "create", None))
    )
    if looks_like_generator_sdk:
        return
    raise RuntimeError(
        "civitai_generate_image requires the civitai-py package (`pip install civitai-py`). "
        "The PyPI package named 'civitai' is a different library and does not expose image generation."
    )


def _find_completed_job(job_response: Any) -> Optional[Dict[str, Any]]:
    """Return the first job in a ``jobs.get`` response whose result is ready, else ``None``.

    Two response shapes are accepted: ``{"jobs": [...]}``, and a bare single-job
    dict carrying ``result`` or ``jobId``.

    A list ``result`` counts as ready only when every entry is a dict with a
    truthy ``available``. A dict ``result`` counts as ready when ``available`` is
    truthy or it carries a ``blobUrl``.

    Falsy/non-dict responses and non-dict job entries are skipped rather than
    raising, so one malformed entry cannot hide a completed sibling.
    """
    if not isinstance(job_response, dict) or not job_response:
        return None
    jobs = job_response.get("jobs") or []
    if not jobs and ("result" in job_response or "jobId" in job_response):
        jobs = [job_response]
    for job in jobs:
        if not isinstance(job, dict):
            continue
        result = job.get("result")
        if not result:
            continue
        if isinstance(result, list):
            if all(isinstance(r, dict) and r.get("available", False) for r in result):
                return job
        elif isinstance(result, dict) and (
            result.get("available", False) or result.get("blobUrl")
        ):
            return job
    return None


async def _poll_job_status(
    job_token: str, max_wait: int = 300, poll_interval: float = 2.0
) -> Optional[Dict[str, Any]]:
    """Poll a Civitai generation job until its results are ready or it times out.

    After ``image.create`` returns a job token, the image is produced
    asynchronously. This calls ``civitai.jobs.get(token=...)`` every
    ``poll_interval`` seconds until :func:`_find_completed_job` reports a ready
    result, or ``max_wait`` seconds have elapsed.

    Exceptions from ``jobs.get`` are logged at debug level and polling continues,
    so a transient blip does not abort generation. An empty token returns
    ``None`` immediately without touching the SDK.

    Called by :func:`run` once the create request returns a token. Before
    polling, it calls :func:`_patch_civitai_generator_client` so the current
    environment token is used.

    Args:
        job_token (str): The Civitai job token returned by ``image.create``.
        max_wait (int): Maximum seconds to poll before giving up (default 300).
        poll_interval (float): Seconds to sleep between polls (default 2.0).

    Returns:
        Optional[Dict[str, Any]]: ``{"token": ..., "jobs": [completed_job]}`` once
        results are available, or ``None`` on an empty token or timeout.
    """
    if not job_token:
        return None

    import civitai as civitai_sdk

    _patch_civitai_generator_client(civitai_sdk)
    loop = asyncio.get_running_loop()
    deadline = loop.time() + max_wait
    while loop.time() <= deadline:
        try:
            job_response = await civitai_sdk.jobs.get(token=job_token)
        except Exception as exc:  # transient API/network error: keep polling
            logger.debug("Polling Civitai job failed, retrying: %s", exc)
        else:
            completed = _find_completed_job(job_response)
            if completed is not None:
                return {"token": job_response.get("token"), "jobs": [completed]}
        await asyncio.sleep(poll_interval)
    logger.warning("Civitai job did not complete within %s seconds", max_wait)
    return None


# Tool registration metadata consumed by the tool loader alongside ``run``.
TOOL_NAME = "civitai_generate_image"
# ``run`` returns a JSON object (``success``, ``files``, ``seed``, ``parameters``)
# rather than a bare path, so the description says so.
TOOL_DESCRIPTION = (
    "Generate AI images using Civitai's generation service. "
    "Returns a JSON result listing the file path(s) of the generated image(s). "
    "Supports various models, schedulers, and generation parameters."
)
# JSON-schema for the tool arguments. Ranges mirror the clamping in ``run``:
# width/height are clamped to 64-2048, then capped at 1024 for the API request.
TOOL_PARAMETERS = {
    "type": "object",
    "properties": {
        "prompt": {
            "type": "string",
            "description": "Text description of the image to generate.",
        },
        "negative_prompt": {
            "type": "string",
            "description": "What to avoid in the image.",
        },
        "width": {
            "type": "integer",
            "description": "Image width in pixels (64-2048; generation is capped at 1024).",
            "default": 1024,
        },
        "height": {
            "type": "integer",
            "description": "Image height in pixels (64-2048; generation is capped at 1024).",
            "default": 1024,
        },
        "steps": {
            "type": "integer",
            "description": "Sampling steps (1-100).",
            "default": 30,
        },
        "cfg": {
            "type": "number",
            "description": "CFG scale (1.0-20.0).",
            "default": 4.0,
        },
        "seed": {"type": "integer", "description": "Random seed (-1 for random)."},
        "scheduler": {
            "type": "string",
            "description": (
                "Sampling scheduler: a Civitai name such as DPM2MKarras, "
                "or an alias such as dpmpp_2m_karras."
            ),
            "default": "DPM2MKarras",
        },
        "model_urn": {"type": "string", "description": "Civitai model URN."},
        "clip_skip": {
            "type": "integer",
            "description": "CLIP layers to skip.",
            "default": 2,
        },
    },
    "required": ["prompt"],
}


def _clamp_number(value: Any, default: Any, low: Any, high: Any, cast: Any) -> Any:
    """Convert ``value`` with ``cast`` (using ``default`` when ``None``) and clamp to ``[low, high]``.

    Raises:
        TypeError, ValueError: If ``value`` cannot be converted by ``cast``.
    """
    if value is None:
        value = default
    return max(low, min(high, cast(value)))


def _resolve_seed(seed: Any) -> int:
    """Return ``seed`` as an int, or a fresh random 32-bit seed when it is ``None`` or negative."""
    if seed is not None:
        seed = int(seed)
        if seed >= 0:
            return seed
    return random.randint(0, 2**32 - 1)


def _collect_result_items(job_result: Dict[str, Any]) -> list:
    """Flatten the ``result`` field of every job in a poll result into one list."""
    items: list = []
    for job in job_result.get("jobs", []):
        if not isinstance(job, dict):
            continue
        job_res = job.get("result")
        if isinstance(job_res, dict):
            items.append(job_res)
        elif isinstance(job_res, list):
            items.extend(job_res)
    return items


def _result_item_source(item: Any) -> Any:
    """Return the download source for one result item: ``blobUrl``, ``url`` or ``blobKey``, else the item itself."""
    if isinstance(item, dict):
        return item.get("blobUrl") or item.get("url") or item.get("blobKey")
    return item


async def run(
    prompt: str,
    negative_prompt: str = None,
    width: int = 1024,
    height: int = 1024,
    steps: int = 30,
    cfg: float = 4.0,
    seed: int = None,
    scheduler: str = "DPM2MKarras",
    model_urn: str = None,
    clip_skip: int = 2,
    ctx: ToolContext | None = None,
) -> str:
    """Generate one or more AI images via Civitai and return their file paths.

    Entry point for the ``civitai_generate_image`` tool. Steps:

    1. Validate and normalize the generation parameters:
       - width/height are clamped to 64-2048, then capped at 1024 for the API.
       - steps are clamped to 1-100 and cfg to 1.0-20.0.
       - ``None`` numeric arguments fall back to their defaults.
       - ``None`` or negative seeds are randomized.
       - Scheduler aliases are mapped via :data:`SCHEDULER_MAP`.
       - The model defaults to :data:`DEFAULT_MODEL_URN`.
    2. Resolve a Civitai API token. A per-user key comes from
       ``tools.manage_api_keys.get_user_api_key`` when ``ctx`` has Redis and a
       user id; otherwise the environment token is used.
    3. Temporarily install a per-user key into ``os.environ`` so the civitai-py
       SDK picks it up.
    4. Submit the job through ``civitai.image.create``, making up to three
       attempts with backoff.
    5. Wait for the job with :func:`_poll_job_status`.
    6. Download each result with :func:`_download_image` and write it to its own
       temp directory.

    Parameters are validated before the environment is touched. The token swap
    happens inside the ``try`` whose ``finally`` restores the original
    ``CIVITAI_API_TOKEN``, so a per-user key cannot leak on any exit path.

    Dispatched by ``tool_loader.py`` as the ``civitai_generate_image`` handler.

    Args:
        prompt (str): Text description of the image to generate (required).
        negative_prompt (str): Optional description of what to avoid.
        width (int): Requested width in pixels (clamped to 64-2048; the API call
            is additionally capped at 1024).
        height (int): Requested height in pixels (clamped to 64-2048; API-capped
            at 1024).
        steps (int): Sampling steps (clamped to 1-100).
        cfg (float): CFG (classifier-free guidance) scale (clamped to 1.0-20.0).
        seed (int): Random seed; ``None`` or a negative value picks a random seed.
        scheduler (str): Sampling scheduler name or alias (mapped via
            :data:`SCHEDULER_MAP`); ``None`` falls back to ``"DPM2MKarras"``.
        model_urn (str): Optional Civitai model URN; defaults to
            :data:`DEFAULT_MODEL_URN`.
        clip_skip (int): Number of CLIP layers to skip (``None`` -> 2).
        ctx (ToolContext | None): Tool execution context; supplies Redis and the
            user id used to resolve a per-user API key.

    Returns:
        str: On success, a JSON object with ``success``, the saved ``files`` list,
        the ``seed``, and the effective ``parameters``. Otherwise an
        ``{"error": ...}`` object, for any of: missing prompt, invalid parameter,
        missing key, API failure, missing job token, timeout, or download failure.
    """
    if not prompt or not prompt.strip():
        return json.dumps({"error": "Prompt is required"})

    try:
        width = _clamp_number(width, 1024, 64, 2048, int)
        height = _clamp_number(height, 1024, 64, 2048, int)
        steps = _clamp_number(steps, 30, 1, 100, int)
        cfg = _clamp_number(cfg, 4.0, 1.0, 20.0, float)
        clip_skip = 2 if clip_skip is None else int(clip_skip)
        seed = _resolve_seed(seed)
    except (TypeError, ValueError) as exc:
        return json.dumps({"error": f"Invalid generation parameter: {exc}"})

    # civitai-py validates params.width/height to max 1024 (Civitai generator contract).
    api_width = min(width, 1024)
    api_height = min(height, 1024)
    scheduler = str(scheduler or "DPM2MKarras")
    scheduler = SCHEDULER_MAP.get(scheduler.lower(), scheduler)
    model_urn = model_urn or DEFAULT_MODEL_URN

    user_civitai_key = None
    if ctx and ctx.redis and ctx.user_id:
        from tools.manage_api_keys import get_user_api_key

        user_civitai_key = await get_user_api_key(
            ctx.user_id,
            "civitai",
            redis_client=ctx.redis,
            channel_id=ctx.channel_id,
            config=getattr(ctx, "config", None),
        )
    if not user_civitai_key and not os.getenv("CIVITAI_API_TOKEN"):
        from tools.manage_api_keys import missing_api_key_error

        return json.dumps({"error": missing_api_key_error("civitai")})

    # NOTE(review): os.environ and the civitai-py client are process-wide, and this
    # coroutine awaits while the per-user key is installed, so concurrent runs with
    # different keys can observe each other's token. Confirm callers serialize runs.
    _prev_civitai_token = os.environ.get("CIVITAI_API_TOKEN")
    try:
        if user_civitai_key:
            os.environ["CIVITAI_API_TOKEN"] = user_civitai_key

        import civitai as civitai_sdk

        _require_civitai_generator_sdk(civitai_sdk)
        _patch_civitai_generator_client(civitai_sdk)
        input_data = {
            "model": model_urn,
            "params": {
                "prompt": prompt,
                "scheduler": scheduler,
                "steps": steps,
                "cfgScale": cfg,
                "width": api_width,
                "height": api_height,
                "seed": seed,
                "clipSkip": clip_skip,
            },
        }
        if negative_prompt:
            input_data["params"]["negativePrompt"] = negative_prompt

        response = None
        max_attempts = 3
        for attempt in range(max_attempts):
            try:
                response = await civitai_sdk.image.create(input_data, wait=False)
                break
            except Exception as e:
                if attempt == max_attempts - 1:
                    return json.dumps({"error": f"Civitai API error: {str(e)[:100]}"})
                # Exponential backoff between attempts: 2s, then 4s.
                await asyncio.sleep(2 ** (attempt + 1))
        if not response:
            return json.dumps({"error": "No response from Civitai API"})

        job_token = response.get("token")
        if not job_token:
            return json.dumps({"error": "Civitai API did not return a job token"})
        job_result = await _poll_job_status(job_token, max_wait=300)
        if not job_result:
            return json.dumps({"error": "Job failed or timed out"})

        saved_files = []
        for i, result_item in enumerate(_collect_result_items(job_result)):
            if isinstance(result_item, dict) and not result_item.get("available", True):
                continue
            image_url = _result_item_source(result_item)
            if not image_url:
                continue
            image_data = await _download_image(
                image_url, api_token=user_civitai_key or ""
            )
            if not image_data:
                continue
            # One fresh directory per image; the caller consumes the returned paths.
            # NOTE(review): the extension is always .png although JPEG bytes are
            # also accepted by _download_image — confirm consumers don't rely on it.
            tmp_dir = tempfile.mkdtemp()
            filepath = os.path.join(tmp_dir, f"civitai_{seed}_{i}.png")
            async with aiofiles.open(filepath, "wb") as f:
                await f.write(image_data)
            saved_files.append(filepath)

        if not saved_files:
            return json.dumps({"error": "Failed to download generated images"})
        return json.dumps(
            {
                "success": True,
                "message": f"Generated {len(saved_files)} image(s)",
                "files": saved_files,
                "seed": seed,
                "parameters": {
                    "width": api_width,
                    "height": api_height,
                    "requested_width": width,
                    "requested_height": height,
                    "steps": steps,
                    "cfg": cfg,
                    "prompt": prompt,
                    "scheduler": scheduler,
                    "model": model_urn,
                },
            }
        )
    except Exception as e:
        logger.error("Unexpected error: %s", e, exc_info=True)
        return json.dumps({"error": f"Unexpected error: {str(e)}"})
    finally:
        if user_civitai_key:
            if _prev_civitai_token is not None:
                os.environ["CIVITAI_API_TOKEN"] = _prev_civitai_token
            else:
                os.environ.pop("CIVITAI_API_TOKEN", None)