"""Image generation via Civitai API.
Generates AI images using Civitai's generation service, downloads them to
local temporary files, and returns a JSON string listing their file paths.
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import random
import tempfile
from typing import Optional, Dict, Any, TYPE_CHECKING
import aiofiles
import aiohttp
from tools._safe_http import assert_safe_http_url
if TYPE_CHECKING:
from tool_context import ToolContext
logger = logging.getLogger(__name__)

# Checkpoint requested when run() is called without an explicit model_urn.
DEFAULT_MODEL_URN = "urn:air:sdxl:checkpoint:civitai:1277670@1896532"

# Scheduler aliases (lower-case) -> scheduler names placed in the Civitai
# request's "params.scheduler" field. run() looks the caller's value up
# lower-cased and passes names that are not listed here through unchanged.
SCHEDULER_MAP = dict(
    euler="EulerA",
    euler_a="EulerA",
    eulera="EulerA",
    heun="Heun",
    lms="LMS",
    lms_karras="LMSKarras",
    dpm2="DPM2",
    dpm2_a="DPM2A",
    dpm2_a_karras="DPM2AKarras",
    dpmpp_2m="DPM2M",
    dpmpp_2m_karras="DPM2MKarras",
    dpmpp_sde="DPMSDE",
    dpmpp_sde_karras="DPMSDEKarras",
    dpm_fast="DPMFast",
    dpm_adaptive="DPMAdaptive",
    ddim="DDIM",
    plms="PLMS",
    unipc="UniPC",
    lcm="LCM",
    ddpm="DDPM",
    deis="DEIS",
)
def _is_html_payload(data: bytes) -> bool:
    """Return True when *data* looks like an HTML document rather than image bytes.

    Used to reject error/login pages that a server delivers with HTTP 200.
    """
    head = data[:32].lstrip().lower()
    return head.startswith((b"<!doctype", b"<html"))


async def _download_image(url_or_blob: str, api_token: str = "") -> Optional[bytes]:
    """Download a generated image given either a Civitai blob key or a URL.

    A 64-character hex string is treated as a Civitai blob key. Three endpoints
    are tried in order: the image CDN at width=1024, the image CDN original,
    and the authenticated blob-download API. The first HTTP 200 response whose
    body starts with a PNG or JPEG signature is returned.

    Any other string is validated with ``assert_safe_http_url`` and fetched
    directly. An HTML body is rejected instead of being returned as image data.

    Args:
        url_or_blob: Blob key (64 hex characters) or an http(s) URL.
        api_token: Bearer token for the blob-download endpoint. When empty,
            the ``CIVITAI_API_TOKEN`` environment variable is used instead.

    Returns:
        The downloaded bytes, or ``None`` if nothing usable was obtained.
        Errors are logged, never raised.
    """
    if not isinstance(url_or_blob, str) or not url_or_blob.strip():
        return None
    timeout = aiohttp.ClientTimeout(total=60)
    try:
        is_blob_key = len(url_or_blob) == 64 and all(
            c in "0123456789ABCDEFabcdef" for c in url_or_blob
        )
        if is_blob_key:
            token = api_token or os.getenv("CIVITAI_API_TOKEN", "")
            candidates = [
                (f"https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/{url_or_blob}/width=1024", {}),
                (f"https://image.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/{url_or_blob}/original=true", {}),
                (
                    f"https://civitai.com/api/download/blobs/{url_or_blob}",
                    {"Authorization": f"Bearer {token}"} if token else {},
                ),
            ]
            async with aiohttp.ClientSession() as session:
                for blob_url, headers in candidates:
                    try:
                        async with session.get(blob_url, headers=headers, timeout=timeout) as response:
                            if response.status != 200:
                                continue
                            data = await response.read()
                    except Exception as exc:
                        # Best effort: fall through to the next endpoint, but leave a trace.
                        logger.debug("Civitai blob fetch failed for %s: %s", blob_url, exc)
                        continue
                    # Only accept real image payloads (PNG / JPEG signatures);
                    # this also rules out HTML error pages.
                    if data.startswith((b"\x89PNG", b"\xff\xd8\xff")):
                        return data
            return None

        try:
            safe_url = assert_safe_http_url(url_or_blob.strip())
        except ValueError as exc:
            logger.error("Blocked Civitai download URL: %s", exc)
            return None
        # NOTE(review): aiohttp follows redirects by default, so
        # assert_safe_http_url only vets the first hop. Consider
        # allow_redirects=False or re-validating redirect targets.
        async with aiohttp.ClientSession() as session:
            async with session.get(safe_url, timeout=timeout) as response:
                if response.status != 200:
                    return None
                data = await response.read()
        if _is_html_payload(data):
            logger.warning("Civitai download URL returned an HTML page instead of an image")
            return None
        return data
    except Exception as exc:
        logger.error("Error downloading image: %s", exc)
        return None
def _extract_ready_job(job_response: Any) -> Optional[Dict[str, Any]]:
    """Return ``{"token": ..., "jobs": [job]}`` for the first ready job, else None.

    A job is ready when its ``result`` is a non-empty list whose items are all
    dicts with ``available`` truthy, or a dict with ``available`` or
    ``blobUrl`` set. A response with no ``jobs`` list but with a top-level
    ``result``/``jobId`` is treated as a single job.
    """
    if not isinstance(job_response, dict) or not job_response:
        return None
    jobs = job_response.get("jobs") or []
    if not jobs and ("result" in job_response or "jobId" in job_response):
        jobs = [job_response]
    for job in jobs:
        if not isinstance(job, dict):
            continue
        result = job.get("result")
        if not result:
            continue
        if isinstance(result, list):
            ready = all(isinstance(r, dict) and r.get("available", False) for r in result)
        elif isinstance(result, dict):
            ready = bool(result.get("available", False) or result.get("blobUrl"))
        else:
            ready = False
        if ready:
            return {"token": job_response.get("token"), "jobs": [job]}
    return None


async def _poll_job_status(job_token: str, max_wait: int = 300) -> Optional[Dict[str, Any]]:
    """Poll Civitai for a submitted job until it has a ready result or time runs out.

    Calls ``civitai.jobs.get`` roughly every 2 seconds. Transient errors are
    retried rather than raised; the last error is logged if the deadline
    passes.

    Args:
        job_token: Token returned when the job was submitted.
        max_wait: Maximum number of seconds to keep polling.

    Returns:
        ``{"token": ..., "jobs": [ready_job]}`` once a job is ready, or
        ``None`` on timeout or when *job_token* is empty.
    """
    if not job_token:
        return None
    import civitai

    poll_interval = 2
    loop = asyncio.get_running_loop()
    deadline = loop.time() + max_wait
    last_error: Optional[BaseException] = None
    while True:
        try:
            job_response = await civitai.jobs.get(token=job_token)
        except Exception as exc:
            # Keep retrying, but remember why so a timeout is diagnosable.
            last_error = exc
            logger.debug("Civitai job poll failed (will retry): %s", exc)
            job_response = None
        ready = _extract_ready_job(job_response)
        if ready is not None:
            return ready
        remaining = deadline - loop.time()
        if remaining <= 0:
            break
        # Never sleep past the deadline.
        await asyncio.sleep(min(poll_interval, remaining))
    logger.warning(
        "Civitai job not ready after %ss (last poll error: %s)", max_wait, last_error
    )
    return None
# Tool registration metadata: the tool's name, a description, and a
# JSON-schema description of run()'s parameters. Presumably consumed by the
# bot's tool loader (TODO: confirm against the registry).
TOOL_NAME = "civitai_generate_image"
TOOL_DESCRIPTION = (
    "Generate AI images using Civitai's generation service. "
    "Returns the file path to the generated image. "
    "Supports various models, schedulers, and generation parameters."
)
TOOL_PARAMETERS = {
    "type": "object",
    "properties": {
        "prompt": {"type": "string", "description": "Text description of the image to generate."},
        "negative_prompt": {"type": "string", "description": "What to avoid in the image."},
        "width": {"type": "integer", "description": "Image width in pixels (64-2048).", "default": 1024},
        "height": {"type": "integer", "description": "Image height in pixels (64-2048).", "default": 1024},
        "steps": {"type": "integer", "description": "Sampling steps (1-100).", "default": 30},
        "cfg": {"type": "number", "description": "CFG scale (1.0-20.0).", "default": 4.0},
        "seed": {"type": "integer", "description": "Random seed (-1 for random)."},
        "scheduler": {"type": "string", "description": "Sampling scheduler.", "default": "DPM2MKarras"},
        "model_urn": {"type": "string", "description": "Civitai model URN."},
        "clip_skip": {"type": "integer", "description": "CLIP layers to skip.", "default": 2},
    },
    "required": ["prompt"],
}
# Guards the process-wide CIVITAI_API_TOKEN environment variable, which run()
# temporarily overrides with a per-user key while it talks to the Civitai SDK.
# Created lazily so the lock binds to the event loop that first uses it.
_civitai_env_lock: Optional[asyncio.Lock] = None


def _get_civitai_env_lock() -> asyncio.Lock:
    """Return the module-wide lock that serializes CIVITAI_API_TOKEN overrides."""
    global _civitai_env_lock
    if _civitai_env_lock is None:
        _civitai_env_lock = asyncio.Lock()
    return _civitai_env_lock


def _clamp_param(value: Any, low: Any, high: Any, default: Any, cast: Any) -> Any:
    """Convert *value* with *cast* and clamp it to ``[low, high]``.

    ``None`` (e.g. an explicit null in a tool call) yields *default*.
    Raises TypeError/ValueError when *value* cannot be converted.
    """
    if value is None:
        return default
    return max(low, min(high, cast(value)))


def _job_result_items(job_result: Dict[str, Any]) -> list:
    """Flatten the ``result`` payloads of every returned job into one list."""
    items: list = []
    for job in job_result.get("jobs", []):
        job_res = job.get("result") if isinstance(job, dict) else None
        if isinstance(job_res, dict):
            items.append(job_res)
        elif isinstance(job_res, list):
            items.extend(job_res)
    return items


def _image_reference(item: Any) -> Optional[str]:
    """Return the URL / blob key to download for one result item, or None to skip it."""
    if isinstance(item, dict):
        if not item.get("available", True):
            return None
        ref = item.get("blobUrl") or item.get("url") or item.get("blobKey")
    else:
        ref = item
    return ref if isinstance(ref, str) and ref else None


def _image_extension(data: bytes) -> str:
    """Pick a file extension from the image's leading magic bytes (PNG fallback)."""
    if data.startswith(b"\xff\xd8\xff"):
        return ".jpg"
    if data[:4] == b"RIFF" and data[8:12] == b"WEBP":
        return ".webp"
    return ".png"


async def _submit_generation(civitai_module: Any, input_data: Dict[str, Any]) -> Any:
    """Submit a generation job without waiting for it to finish.

    Makes up to 3 attempts, sleeping 2 s and then 4 s between failures.
    The last error is re-raised.
    """
    attempts = 3
    for attempt in range(attempts):
        try:
            return await civitai_module.image.create(input_data, wait=False)
        except Exception:
            if attempt == attempts - 1:
                raise
            await asyncio.sleep(2 ** (attempt + 1))
    return None  # unreachable: the loop either returns or raises


async def run(
    prompt: str,
    negative_prompt: Optional[str] = None,
    width: int = 1024,
    height: int = 1024,
    steps: int = 30,
    cfg: float = 4.0,
    seed: Optional[int] = None,
    scheduler: str = "DPM2MKarras",
    model_urn: Optional[str] = None,
    clip_skip: int = 2,
    ctx: ToolContext | None = None,
) -> str:
    """Generate image(s) with Civitai and save them to local temporary files.

    Numeric parameters are clamped to range: width/height 64-2048, steps
    1-100, cfg 1.0-20.0. ``None`` for any of them falls back to the default.
    A missing or negative *seed* is replaced by a random 32-bit seed, and the
    seed actually used is reported in the result.

    If *ctx* identifies a user with a stored Civitai key, that key is used for
    this request by temporarily overriding ``CIVITAI_API_TOKEN``. Otherwise
    the existing environment variable is used.

    Args:
        prompt: Text description of the image to generate (required).
        negative_prompt: What to avoid in the image.
        width: Image width in pixels.
        height: Image height in pixels.
        steps: Number of sampling steps.
        cfg: Classifier-free guidance (CFG) scale.
        seed: Random seed; ``None`` or a negative value picks one at random.
        scheduler: Scheduler name. Aliases listed in ``SCHEDULER_MAP`` are
            translated; anything else is passed through unchanged.
        model_urn: Civitai model URN; defaults to ``DEFAULT_MODEL_URN``.
        clip_skip: Number of CLIP layers to skip.
        ctx: Tool execution context, used to look up the user's API key.

    Returns:
        A JSON string. On success it is ``{"success": true, "files": [...],
        "seed": ..., ...}``; on failure it is ``{"error": "..."}``.
    """
    if not isinstance(prompt, str) or not prompt.strip():
        return json.dumps({"error": "Prompt is required"})

    # Normalise inputs before touching any shared state, so a bad value from a
    # tool call becomes an error result instead of an escaping exception.
    try:
        width = _clamp_param(width, 64, 2048, 1024, int)
        height = _clamp_param(height, 64, 2048, 1024, int)
        steps = _clamp_param(steps, 1, 100, 30, int)
        cfg = _clamp_param(cfg, 1.0, 20.0, 4.0, float)
        clip_skip = 2 if clip_skip is None else int(clip_skip)
        seed = None if seed is None else int(seed)
    except (TypeError, ValueError) as exc:
        return json.dumps({"error": f"Invalid parameter: {exc}"})
    if seed is None or seed < 0:
        # The tool schema documents -1 as "random". Pick the seed locally so
        # the value reported back is the one actually used.
        seed = random.randint(0, 2**32 - 1)
    scheduler_name = str(scheduler) if scheduler else "DPM2MKarras"
    scheduler = SCHEDULER_MAP.get(scheduler_name.lower(), scheduler_name)
    model_urn = model_urn or DEFAULT_MODEL_URN

    user_civitai_key = None
    if ctx and ctx.redis and ctx.user_id:
        from tools.manage_api_keys import get_user_api_key
        user_civitai_key = await get_user_api_key(
            ctx.user_id, "civitai",
            redis_client=ctx.redis, channel_id=ctx.channel_id,
            config=getattr(ctx, "config", None),
        )
    if not user_civitai_key and not os.getenv("CIVITAI_API_TOKEN"):
        from tools.manage_api_keys import missing_api_key_error
        return json.dumps({"error": missing_api_key_error("civitai")})

    import civitai

    input_data: Dict[str, Any] = {
        "model": model_urn,
        "params": {
            "prompt": prompt, "scheduler": scheduler, "steps": steps,
            "cfgScale": cfg, "width": width, "height": height,
            "seed": seed, "clipSkip": clip_skip,
        },
    }
    if negative_prompt:
        input_data["params"]["negativePrompt"] = negative_prompt

    try:
        # The Civitai SDK authenticates through the process-wide
        # CIVITAI_API_TOKEN variable. Overriding it is only safe while no
        # other coroutine is using the SDK, so submission and polling are
        # serialized behind a lock (one generation at a time per process).
        # The override is always undone in `finally`, even on cancellation.
        async with _get_civitai_env_lock():
            previous_token = os.environ.get("CIVITAI_API_TOKEN")
            # Capture the token now so downloads (done outside the lock) do
            # not read an env value another request may have swapped in.
            download_token = user_civitai_key or previous_token or ""
            if user_civitai_key:
                os.environ["CIVITAI_API_TOKEN"] = user_civitai_key
            try:
                try:
                    response = await _submit_generation(civitai, input_data)
                except Exception as exc:
                    return json.dumps({"error": f"Civitai API error: {str(exc)[:100]}"})
                if not response:
                    return json.dumps({"error": "No response from Civitai API"})
                job_token = response.get("token")
                if not job_token:
                    return json.dumps({"error": "Civitai API returned no job token"})
                job_result = await _poll_job_status(job_token, max_wait=300)
            finally:
                if user_civitai_key:
                    if previous_token is not None:
                        os.environ["CIVITAI_API_TOKEN"] = previous_token
                    else:
                        os.environ.pop("CIVITAI_API_TOKEN", None)

        if not job_result:
            return json.dumps({"error": "Job failed or timed out"})

        saved_files = []
        output_dir = None  # one temp dir per call, created on first success
        for index, item in enumerate(_job_result_items(job_result)):
            ref = _image_reference(item)
            if ref is None:
                continue
            image_data = await _download_image(ref, api_token=download_token)
            if not image_data:
                continue
            if output_dir is None:
                output_dir = tempfile.mkdtemp(prefix="civitai_")
            filename = f"civitai_{seed}_{index}{_image_extension(image_data)}"
            filepath = os.path.join(output_dir, filename)
            async with aiofiles.open(filepath, "wb") as fh:
                await fh.write(image_data)
            saved_files.append(filepath)

        if not saved_files:
            return json.dumps({"error": "Failed to download generated images"})
        return json.dumps({
            "success": True,
            "message": f"Generated {len(saved_files)} image(s)",
            "files": saved_files,
            "seed": seed,
            "parameters": {
                "width": width, "height": height, "steps": steps, "cfg": cfg,
                "prompt": prompt, "scheduler": scheduler, "model": model_urn,
            },
        })
    except Exception as exc:
        logger.exception("Unexpected error during Civitai generation: %s", exc)
        return json.dumps({"error": f"Unexpected error: {str(exc)}"})