"""Image generation via Civitai API.
Generates AI images using Civitai's generation service, saves them to
temporary files, and returns a JSON result listing the saved image file paths.
"""
from __future__ import annotations
import asyncio
import jsonutil as json
import logging
import os
import random
import tempfile
from typing import Optional, Dict, Any, TYPE_CHECKING
import aiofiles
import aiohttp
import httpx
from tools._safe_http import safe_http_request, safe_httpx_client
if TYPE_CHECKING:
from tool_context import ToolContext
logger = logging.getLogger(__name__)

# Civitai checkpoint used by ``run`` when the caller does not pass ``model_urn``.
DEFAULT_MODEL_URN = "urn:air:sdxl:checkpoint:civitai:1277670@1896532"

# Lower-cased sampler aliases -> scheduler identifiers sent to the Civitai
# generator. ``run`` lower-cases the caller's value before the lookup and passes
# unknown values through unchanged, so canonical names that are not keys here
# (e.g. "DPM2MKarras") still reach the API as given.
SCHEDULER_MAP = {
    # Plain Euler. The canonical Civitai value "Euler" lower-cases to this key,
    # so it must not point at the ancestral variant (that has its own aliases).
    "euler": "Euler",
    "euler_a": "EulerA",
    "eulera": "EulerA",
    "euler_ancestral": "EulerA",
    "heun": "Heun",
    "lms": "LMS",
    "lms_karras": "LMSKarras",
    "dpm2": "DPM2",
    "dpm2_karras": "DPM2Karras",
    "dpm2_a": "DPM2A",
    "dpm2_a_karras": "DPM2AKarras",
    "dpmpp_2m": "DPM2M",
    "dpmpp_2m_karras": "DPM2MKarras",
    "dpmpp_sde": "DPMSDE",
    "dpmpp_sde_karras": "DPMSDEKarras",
    "dpm_fast": "DPMFast",
    "dpm_adaptive": "DPMAdaptive",
    "ddim": "DDIM",
    "plms": "PLMS",
    "unipc": "UniPC",
    "uni_pc": "UniPC",
    "lcm": "LCM",
    "ddpm": "DDPM",
    "deis": "DEIS",
}
async def _download_image(url_or_blob: str, api_token: str = "") -> Optional[bytes]:
    """Download generated image bytes from a Civitai URL or blob hash.

    A Civitai result may be either a direct image URL or a 64-character hex blob
    hash. For a blob hash this tries the known Civitai CDN URL variants in turn
    (width-resized, original, then the authenticated blob-download endpoint) and
    accepts only a payload that starts with a PNG or JPEG signature. HTML error
    pages and any other content are skipped. For a plain URL it fetches through
    ``tools._safe_http`` (SSRF-guarded; redirects are re-validated there) and
    returns the body of a 200 response. Every failure is logged and yields
    ``None`` so the caller can skip that image.

    Called by :func:`run` for each generated image after the Civitai job
    completes. Its side effect is the outbound HTTP fetch: :mod:`aiohttp` for
    blobs, the safe ``httpx`` client for URLs. When no token is passed, it reads
    ``CIVITAI_API_TOKEN`` from the environment for the blob endpoint.

    Args:
        url_or_blob (str): A direct image URL or a 64-character hex blob hash.
            Surrounding whitespace is ignored.
        api_token (str): Optional Civitai token for the authenticated blob
            endpoint; falls back to the ``CIVITAI_API_TOKEN`` environment variable.

    Returns:
        Optional[bytes]: The downloaded bytes, or ``None`` if the input is empty,
        every source failed, or a blob source returned non-image content.
    """
    source = (url_or_blob or "").strip()
    if not source:
        logger.warning("No Civitai image URL or blob hash to download")
        return None
    try:
        is_blob_hash = len(source) == 64 and all(
            c in "0123456789ABCDEFabcdef" for c in source
        )
        if is_blob_hash:
            token = api_token or os.getenv("CIVITAI_API_TOKEN", "")
            cdn_base = f"https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/{source}"
            candidates = [
                (f"{cdn_base}/width=1024", {}),
                (f"{cdn_base}/original=true", {}),
                (
                    f"https://civitai.com/api/download/blobs/{source}",
                    {"Authorization": f"Bearer {token}"} if token else {},
                ),
            ]
            timeout = aiohttp.ClientTimeout(total=60)
            async with aiohttp.ClientSession(timeout=timeout) as session:
                for blob_url, headers in candidates:
                    try:
                        async with session.get(blob_url, headers=headers) as response:
                            if response.status != 200:
                                logger.debug(
                                    "Civitai blob fetch %s returned HTTP %s",
                                    blob_url,
                                    response.status,
                                )
                                continue
                            data = await response.read()
                    except Exception as exc:
                        # One variant failing is expected; fall through to the next.
                        logger.debug("Civitai blob fetch %s failed: %s", blob_url, exc)
                        continue
                    # Only real PNG/JPEG payloads are accepted; this also rejects
                    # HTML error pages that are served with a 200 status.
                    if data.startswith((b"\x89PNG", b"\xff\xd8\xff")):
                        return data
                    logger.debug("Civitai blob %s returned non-image content", blob_url)
            logger.warning("Could not download Civitai blob %s from any known source", source)
            return None

        try:
            async with safe_httpx_client(timeout=httpx.Timeout(60.0)) as client:
                response = await safe_http_request(
                    client, "GET", source, max_redirects=5
                )
                if response.status_code != 200:
                    logger.warning(
                        "Civitai image download returned HTTP %s", response.status_code
                    )
                    return None
                return response.content
        except ValueError as exc:
            # The safe HTTP helper rejects disallowed destinations with ValueError.
            logger.error("Blocked Civitai download URL: %s", exc)
            return None
    except Exception as exc:
        logger.error("Error downloading image: %s", exc)
        return None
def _patch_civitai_generator_client(civitai_sdk: Any) -> None:
    """Point the civitai-py client singleton at the current ``CIVITAI_API_TOKEN``.

    civitai-py reads ``CIVITAI_API_TOKEN`` once at import time. A per-user token
    placed in ``os.environ`` afterwards would therefore never reach the SDK. This
    helper copies the current environment value onto the live client object, in
    two places: its ``api_token`` attribute (when it has one) and its
    ``Authorization`` header (when ``headers`` is a dict). Nothing happens when
    the module has no ``civitai`` client or the environment token is empty.

    Called by :func:`run` right after importing the SDK, and by
    :func:`_poll_job_status` before polling. It mutates the SDK's module-level
    client in place.

    Args:
        civitai_sdk (Any): The imported ``civitai`` module whose ``civitai``
            client singleton is updated.

    Returns:
        None
    """
    live_client = getattr(civitai_sdk, "civitai", None)
    current_token = os.environ.get("CIVITAI_API_TOKEN", "")
    if live_client is not None and current_token:
        if hasattr(live_client, "api_token"):
            live_client.api_token = current_token
        client_headers = getattr(live_client, "headers", None)
        if isinstance(client_headers, dict):
            client_headers["Authorization"] = f"Bearer {current_token}"
def _require_civitai_generator_sdk(civitai_sdk: Any) -> None:
    """Fail fast unless ``civitai_sdk`` is the civitai-py image-generation SDK.

    Two PyPI distributions install a top-level ``civitai`` module, and only
    civitai-py provides image generation. This checks that the module has truthy
    ``image`` and ``jobs`` attributes and that ``image.create`` is callable. If
    not, it raises an actionable error naming the right package, rather than
    letting generation fail later with an opaque ``AttributeError``.

    Called by :func:`run` immediately after importing the SDK. It only inspects
    attributes; its only effect is the raised error.

    Args:
        civitai_sdk (Any): The imported ``civitai`` module to validate.

    Returns:
        None

    Raises:
        RuntimeError: If the module lacks the civitai-py generator surface.
    """
    image_namespace = getattr(civitai_sdk, "image", None)
    jobs_namespace = getattr(civitai_sdk, "jobs", None)
    has_generator_api = (
        bool(image_namespace)
        and bool(jobs_namespace)
        and callable(getattr(image_namespace, "create", None))
    )
    if has_generator_api:
        return
    raise RuntimeError(
        "civitai_generate_image requires the civitai-py package (`pip install civitai-py`). "
        "The PyPI package named 'civitai' is a different library and does not expose image generation."
    )
async def _poll_job_status(
    job_token: str, max_wait: int = 300, poll_interval: float = 2
) -> Optional[Dict[str, Any]]:
    """Poll a Civitai generation job until its results are ready or it times out.

    ``image.create`` returns a job token and the image is produced
    asynchronously. This loops on ``civitai.jobs.get(token=...)`` every
    ``poll_interval`` seconds and treats a job as complete once its ``result`` is
    ready:
    - a non-empty list whose entries are all dicts with a truthy ``available``; or
    - a dict with a truthy ``available`` or ``blobUrl``.

    A response without a ``jobs`` list but with top-level ``result``/``jobId``
    keys is treated as a single job. Errors raised while polling or parsing are
    logged at debug level and retried, so a transient blip does not abort
    generation. The function gives up once ``max_wait`` seconds have elapsed.

    Called by :func:`run` once the create request returns a token. Before
    polling through the civitai-py SDK, it calls
    :func:`_patch_civitai_generator_client` so the current token is used.

    Args:
        job_token (str): The Civitai job token returned by ``image.create``.
        max_wait (int): Maximum seconds to poll before giving up (default 300).
        poll_interval (float): Seconds to sleep between polls (default 2).

    Returns:
        Optional[Dict[str, Any]]: On success, a dict with the response ``token``
        and a ``jobs`` list holding the first completed job. Returns ``None``
        when the token is empty or the deadline passes.
    """
    # Nothing to poll without a token; bail out before touching the SDK.
    if not job_token:
        return None

    import civitai as civitai_sdk

    _patch_civitai_generator_client(civitai_sdk)
    loop = asyncio.get_running_loop()
    deadline = loop.time() + max_wait
    while loop.time() <= deadline:
        try:
            job_response = await civitai_sdk.jobs.get(token=job_token)
            jobs = []
            if job_response:
                jobs = job_response.get("jobs") or []
                # Some responses describe a single job at the top level instead
                # of nesting it under "jobs".
                if not jobs and ("result" in job_response or "jobId" in job_response):
                    jobs = [job_response]
            for job in jobs:
                result = job.get("result")
                if not result:
                    continue
                if isinstance(result, list):
                    ready = all(
                        isinstance(r, dict) and r.get("available", False)
                        for r in result
                    )
                else:
                    ready = isinstance(result, dict) and bool(
                        result.get("available", False) or result.get("blobUrl")
                    )
                if ready:
                    return {"token": job_response.get("token"), "jobs": [job]}
        except Exception as exc:
            # Transient SDK/HTTP or response-shape errors are retried until the
            # deadline instead of aborting the whole generation.
            logger.debug("Civitai job poll failed; retrying: %s", exc)
        await asyncio.sleep(poll_interval)
    logger.warning("Civitai job did not complete within %s seconds", max_wait)
    return None
TOOL_NAME = "civitai_generate_image"

# Tool description exposed to the caller. ``run`` returns a JSON object (file
# list, seed, parameters), not a bare path.
TOOL_DESCRIPTION = (
    "Generate AI images using Civitai's generation service. "
    "Returns a JSON object listing the file paths of the generated images and the seed used. "
    "Supports various models, schedulers, and generation parameters."
)

# JSON-schema for the tool arguments. The width/height descriptions mention the
# 1024 cap because ``run`` limits the dimensions sent to the generator to 1024.
TOOL_PARAMETERS = {
    "type": "object",
    "properties": {
        "prompt": {
            "type": "string",
            "description": "Text description of the image to generate.",
        },
        "negative_prompt": {
            "type": "string",
            "description": "What to avoid in the image.",
        },
        "width": {
            "type": "integer",
            "description": "Image width in pixels (64-2048; the generator caps output width at 1024).",
            "default": 1024,
        },
        "height": {
            "type": "integer",
            "description": "Image height in pixels (64-2048; the generator caps output height at 1024).",
            "default": 1024,
        },
        "steps": {
            "type": "integer",
            "description": "Sampling steps (1-100).",
            "default": 30,
        },
        "cfg": {
            "type": "number",
            "description": "CFG scale (1.0-20.0).",
            "default": 4.0,
        },
        "seed": {"type": "integer", "description": "Random seed (-1 for random)."},
        "scheduler": {
            "type": "string",
            "description": (
                "Sampling scheduler: a Civitai scheduler name (e.g. DPM2MKarras, "
                "EulerA, DDIM) or a sampler alias such as dpmpp_2m_karras."
            ),
            "default": "DPM2MKarras",
        },
        "model_urn": {"type": "string", "description": "Civitai model URN."},
        "clip_skip": {
            "type": "integer",
            "description": "CLIP layers to skip.",
            "default": 2,
        },
    },
    "required": ["prompt"],
}
def _coerce_clamped(value: Any, low: Any, high: Any, default: Any, cast: Any) -> Any:
    """Cast ``value`` (``default`` when it is ``None``) and clamp it into ``[low, high]``.

    Keeps a tool call that passes ``null`` or a numeric string working instead of
    failing on the comparison. Non-numeric input raises whatever ``cast`` raises
    (``ValueError``/``TypeError``).
    """
    if value is None:
        value = default
    return max(low, min(high, cast(value)))


def _image_extension(data: bytes) -> str:
    """Return a file extension matching the image signature at the start of ``data``.

    Recognizes JPEG and WebP; everything else (including PNG) uses ``"png"``,
    the extension the saved files previously always had.
    """
    if data.startswith(b"\xff\xd8\xff"):
        return "jpg"
    if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
        return "webp"
    return "png"


async def run(
    prompt: str,
    negative_prompt: str = None,
    width: int = 1024,
    height: int = 1024,
    steps: int = 30,
    cfg: float = 4.0,
    seed: int = None,
    scheduler: str = "DPM2MKarras",
    model_urn: str = None,
    clip_skip: int = 2,
    ctx: ToolContext | None = None,
) -> str:
    """Generate one or more AI images via Civitai and return their file paths.

    Entry point for the ``civitai_generate_image`` tool. Steps:

    1. Resolve a Civitai API token: per-user via
       ``tools.manage_api_keys.get_user_api_key`` when ``ctx`` has Redis and a
       user id, otherwise the environment token.
    2. Temporarily install that token into ``os.environ`` so the civitai-py SDK
       picks it up.
    3. Coerce and clamp the generation parameters: map scheduler aliases via
       :data:`SCHEDULER_MAP`, cap width/height to the generator's 1024 limit,
       default the model to :data:`DEFAULT_MODEL_URN`, and choose a random seed
       when the seed is unset or negative.
    4. Submit the job with retries through ``civitai.image.create``.
    5. Wait for it with :func:`_poll_job_status`.
    6. Download each result with :func:`_download_image` and write it to a
       per-image temp file.

    The token install happens inside the ``try`` whose ``finally`` restores the
    original ``CIVITAI_API_TOKEN``, so it is restored on every exit path.

    Dispatched by ``tool_loader.py`` as the ``civitai_generate_image`` handler
    (located via ``getattr(module, "run")``). Side effects: the temporary
    ``os.environ`` token swap, the Civitai HTTP calls, and writing image files
    into freshly created temp directories. The extension follows the downloaded
    bytes (``.jpg``/``.webp``, otherwise ``.png``).

    Args:
        prompt (str): Text description of the image to generate (required).
        negative_prompt (str): Optional description of what to avoid.
        width (int): Requested width in pixels. Clamped to 64-2048; the API call
            is additionally capped at 1024. ``None`` means 1024.
        height (int): Requested height in pixels. Clamped to 64-2048 and
            API-capped at 1024. ``None`` means 1024.
        steps (int): Sampling steps (clamped to 1-100).
        cfg (float): CFG (classifier-free guidance) scale (clamped to 1.0-20.0).
        seed (int): Random seed. A random value is chosen when ``None`` or
            negative (e.g. -1).
        scheduler (str): Sampling scheduler name or alias (mapped via
            :data:`SCHEDULER_MAP`); empty/``None`` means "DPM2MKarras".
        model_urn (str): Optional Civitai model URN; defaults to
            :data:`DEFAULT_MODEL_URN`.
        clip_skip (int): Number of CLIP layers to skip.
        ctx (ToolContext | None): Tool execution context; supplies Redis and the
            user id used to resolve a per-user API key.

    Returns:
        str: On success, a JSON object with ``success``, the saved ``files``
        list, the ``seed`` and the effective ``parameters``. Otherwise an
        ``{"error": ...}`` object, for any of: missing prompt, missing key,
        invalid arguments, API failure, missing job token, timeout, or download
        failure.
    """
    if not prompt or not prompt.strip():
        return json.dumps({"error": "Prompt is required"})

    user_civitai_key = None
    if ctx and ctx.redis and ctx.user_id:
        from tools.manage_api_keys import get_user_api_key

        user_civitai_key = await get_user_api_key(
            ctx.user_id,
            "civitai",
            redis_client=ctx.redis,
            channel_id=ctx.channel_id,
            config=getattr(ctx, "config", None),
        )
    if not user_civitai_key and not os.getenv("CIVITAI_API_TOKEN"):
        from tools.manage_api_keys import missing_api_key_error

        return json.dumps({"error": missing_api_key_error("civitai")})

    _prev_civitai_token = os.environ.get("CIVITAI_API_TOKEN")
    try:
        # The swap sits inside the ``try`` so the ``finally`` below restores the
        # previous token even when argument normalization raises.
        # NOTE(review): os.environ is process-wide, so concurrent runs with
        # different user keys can still observe each other's token.
        if user_civitai_key:
            os.environ["CIVITAI_API_TOKEN"] = user_civitai_key

        width = _coerce_clamped(width, 64, 2048, 1024, int)
        height = _coerce_clamped(height, 64, 2048, 1024, int)
        # civitai-py validates params.width/height to max 1024 (Civitai generator contract).
        api_width = min(width, 1024)
        api_height = min(height, 1024)
        steps = _coerce_clamped(steps, 1, 100, 30, int)
        cfg = _coerce_clamped(cfg, 1.0, 20.0, 4.0, float)
        scheduler = str(scheduler or "DPM2MKarras")
        scheduler = SCHEDULER_MAP.get(scheduler.lower(), scheduler)
        # None or a negative value (the tool schema advertises -1) means
        # "random"; pick it locally so the response reports the seed used.
        seed = None if seed is None else int(seed)
        if seed is None or seed < 0:
            seed = random.randint(0, 2**32 - 1)
        model_urn = model_urn or DEFAULT_MODEL_URN

        import civitai as civitai_sdk

        _require_civitai_generator_sdk(civitai_sdk)
        _patch_civitai_generator_client(civitai_sdk)

        input_data = {
            "model": model_urn,
            "params": {
                "prompt": prompt,
                "scheduler": scheduler,
                "steps": steps,
                "cfgScale": cfg,
                "width": api_width,
                "height": api_height,
                "seed": seed,
                "clipSkip": clip_skip,
            },
        }
        if negative_prompt:
            input_data["params"]["negativePrompt"] = negative_prompt

        response = None
        for attempt in range(3):
            try:
                response = await civitai_sdk.image.create(input_data, wait=False)
                break
            except Exception as e:
                if attempt == 2:
                    return json.dumps({"error": f"Civitai API error: {str(e)[:100]}"})
                logger.warning("Civitai image.create attempt %d failed: %s", attempt + 1, e)
                # Exponential backoff between attempts: 2s, then 4s.
                await asyncio.sleep(2 ** (attempt + 1))
        if not response:
            return json.dumps({"error": "No response from Civitai API"})

        job_token = response.get("token")
        if not job_token:
            return json.dumps({"error": "No job token in Civitai response"})
        job_result = await _poll_job_status(job_token, max_wait=300)
        if not job_result:
            return json.dumps({"error": "Job failed or timed out"})

        # A job's ``result`` may be a single dict or a list of result entries.
        results = []
        for job in job_result.get("jobs", []):
            job_res = job.get("result")
            if isinstance(job_res, dict):
                results.append(job_res)
            elif isinstance(job_res, list):
                results.extend(job_res)

        saved_files = []
        for i, result_item in enumerate(results):
            if isinstance(result_item, dict):
                if not result_item.get("available", True):
                    continue
                image_url = (
                    result_item.get("blobUrl")
                    or result_item.get("url")
                    or result_item.get("blobKey")
                )
            else:
                image_url = result_item
            if not image_url:
                continue
            image_data = await _download_image(
                image_url, api_token=user_civitai_key or ""
            )
            if not image_data:
                continue
            tmp_dir = tempfile.mkdtemp()
            filename = f"civitai_{seed}_{i}.{_image_extension(image_data)}"
            filepath = os.path.join(tmp_dir, filename)
            async with aiofiles.open(filepath, "wb") as f:
                await f.write(image_data)
            saved_files.append(filepath)

        if not saved_files:
            return json.dumps({"error": "Failed to download generated images"})
        return json.dumps(
            {
                "success": True,
                "message": f"Generated {len(saved_files)} image(s)",
                "files": saved_files,
                "seed": seed,
                "parameters": {
                    "width": api_width,
                    "height": api_height,
                    "requested_width": width,
                    "requested_height": height,
                    "steps": steps,
                    "cfg": cfg,
                    "prompt": prompt,
                    "scheduler": scheduler,
                    "model": model_urn,
                },
            }
        )
    except Exception as e:
        logger.exception("Unexpected error: %s", e)
        return json.dumps({"error": f"Unexpected error: {str(e)}"})
    finally:
        # Undo the per-user token swap so the key does not outlive this call.
        if user_civitai_key:
            if _prev_civitai_token is not None:
                os.environ["CIVITAI_API_TOKEN"] = _prev_civitai_token
            else:
                os.environ.pop("CIVITAI_API_TOKEN", None)