"""Play music in Discord voice channels using Lyria RealTime via ctx.adapter."""
from __future__ import annotations
import asyncio
import logging
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from tool_context import ToolContext
logger = logging.getLogger(__name__)

# Scale names listed to the model in the ``scale`` parameter description.
_SCALE_NAMES = (
    "C_MAJOR_A_MINOR",
    "D_FLAT_MAJOR_B_FLAT_MINOR",
    "D_MAJOR_B_MINOR",
    "E_FLAT_MAJOR_C_MINOR",
    "E_MAJOR_D_FLAT_MINOR",
    "F_MAJOR_D_MINOR",
    "G_FLAT_MAJOR_E_FLAT_MINOR",
    "G_MAJOR_E_MINOR",
    "A_FLAT_MAJOR_F_MINOR",
    "A_MAJOR_G_FLAT_MINOR",
    "B_FLAT_MAJOR_G_MINOR",
    "B_MAJOR_A_FLAT_MINOR",
    "SCALE_UNSPECIFIED",
)
_SCALE_ENUM_DESC = f"Musical scale. One of: {', '.join(_SCALE_NAMES)}."

# Human-readable guidance for the free-form ``prompt`` argument.
_PROMPT_DESC = " ".join(
    [
        "Music description.",
        "Format: 'prompt1:weight1, prompt2:weight2'.",
        "Examples: instruments (Piano, Synth Pads, 808 Hip Hop Beat),",
        "genres (Minimal Techno, Lo-Fi Hip Hop, Jazz Fusion),",
        "moods (Ambient, Chill, Dreamy, Upbeat).",
        "Default: 'Violin metal'.",
    ]
)

TOOL_NAME = "play_music"
TOOL_DESCRIPTION = " ".join(
    [
        "Join a Discord voice channel and start playing AI-generated music "
        "using Lyria RealTime.",
        "Supports prompts, BPM, temperature, brightness, scale, generation "
        "mode, guidance, density, muting, and more.",
        "Requires ctx.adapter for Discord voice access.",
    ]
)


def _prop(json_type, description, **extra):
    """Build one JSON-schema property: ``type``, ``description``, then any extras."""
    return {"type": json_type, "description": description, **extra}


# JSON schema describing the arguments accepted by ``run``.
TOOL_PARAMETERS = {
    "type": "object",
    "properties": {
        "guild_id": _prop("string", "Discord server (guild) ID."),
        "channel_id": _prop(
            "string", "Voice channel ID (optional, auto-detects if omitted)."
        ),
        "prompt": _prop("string", _PROMPT_DESC),
        "bpm": _prop("integer", "Beats per minute (60-200).", default=135),
        "temperature": _prop("number", "Creativity (0.0-3.0).", default=1.0),
        "brightness": _prop("number", "Tonal quality (0.0-1.0).", default=0.5),
        "scale": _prop("string", _SCALE_ENUM_DESC),
        "music_generation_mode": _prop(
            "string",
            "Mode: QUALITY, DIVERSITY, or VOCALIZATION (vocal-like as instrument).",
        ),
        "guidance": _prop("number", "Prompt strictness (0.0-6.0)."),
        "density": _prop("number", "Note density (0.0-1.0)."),
        "mute_bass": _prop("boolean", "Reduce bass output."),
        "mute_drums": _prop("boolean", "Reduce drums output."),
        "only_bass_and_drums": _prop("boolean", "Only bass and drums."),
        "top_k": _prop("integer", "Sampling parameter (1-1000)."),
        "seed": _prop("integer", "Random seed (0-2147483647)."),
    },
    "required": ["guild_id"],
}
[docs]
async def run(
guild_id: str,
channel_id: str = None,
prompt: str = None,
bpm: int = 135,
temperature: float = 1.0,
brightness: float = 0.5,
scale: str = None,
music_generation_mode: str = None,
guidance: float = None,
density: float = None,
mute_bass: bool = None,
mute_drums: bool = None,
only_bass_and_drums: bool = None,
top_k: int = None,
seed: int = None,
ctx: "ToolContext | None" = None,
) -> str:
"""Execute this tool and return the result."""
if ctx is None or not hasattr(ctx, "adapter") or ctx.adapter is None:
return "Error: Discord adapter (ctx.adapter) is required for voice channel music playback."
from tools._discord_helpers import get_discord_client
from services.lyria_session import parse_prompts
client = get_discord_client(ctx)
if isinstance(client, str):
return client
lyria = ctx.adapter.lyria_service
try:
guild_id_int = int(guild_id)
except (ValueError, TypeError):
return "Error: Invalid guild_id. Must be an integer."
guild = client.get_guild(guild_id_int)
if not guild:
return "Error: Guild not found."
if lyria.is_playing(guild_id_int):
return "Error: Music is already playing. Use steer_music to modify."
import discord
voice_channel = None
if channel_id:
try:
voice_channel = guild.get_channel(int(channel_id))
if not voice_channel or not isinstance(voice_channel, discord.VoiceChannel):
return f"Error: Voice channel {channel_id} not found or not a voice channel."
except (ValueError, TypeError):
return "Error: Invalid channel_id."
else:
voice_channels = [ch for ch in guild.voice_channels if len(ch.members) > 0]
if voice_channels:
voice_channel = voice_channels[0]
else:
return "Error: No voice channel specified and none with members found."
if guild.voice_client is not None and guild.voice_client.is_playing():
return "Error: Already playing in a channel."
_using_default_key = False
gemini_key = None
if ctx and ctx.redis and ctx.user_id:
from tools.manage_api_keys import get_user_api_key
gemini_key = await get_user_api_key(
ctx.user_id,
"gemini",
redis_client=ctx.redis,
channel_id=ctx.channel_id,
config=getattr(ctx, "config", None),
)
if not gemini_key:
if ctx and ctx.redis and ctx.user_id:
from tools.manage_api_keys import check_default_key_limit, default_key_limit_error
allowed, current, limit = await check_default_key_limit(
ctx.user_id, "play_music", ctx.redis, daily_limit=20,
)
if not allowed:
return default_key_limit_error("play_music", current, limit)
api_key = gemini_key or "AIzaSyCCwz9WCsIKSWsfufU6E-JbPsP1acLhZTU"
_using_default_key = not gemini_key
await lyria.init_client(api_key)
voice_client = guild.voice_client
try:
if voice_client is not None:
if not voice_client.is_connected():
await voice_client.disconnect(force=True)
voice_client = None
elif voice_client.channel != voice_channel:
await voice_client.move_to(voice_channel)
if voice_client is None:
voice_client = await voice_channel.connect(
timeout=60.0, self_deaf=True, reconnect=True
)
except Exception as e:
return f"Error: Couldn't connect to voice channel: {e}"
if prompt is None:
prompt = "Violin metal"
parsed_prompts = parse_prompts(prompt)
if not parsed_prompts:
return "Error: Invalid prompt format."
from google.genai import types
config_params = {"bpm": bpm, "brightness": brightness, "temperature": temperature}
if scale is not None:
scale_enum = getattr(types.Scale, scale, None)
if scale_enum is None:
return f"Error: Invalid scale '{scale}'."
config_params["scale"] = scale_enum
if music_generation_mode is not None:
mode_enum = getattr(types.MusicGenerationMode, music_generation_mode, None)
if mode_enum is None:
return f"Error: Invalid mode '{music_generation_mode}'."
config_params["music_generation_mode"] = mode_enum
for param_name, param_val, lo, hi in [
("guidance", guidance, 0.0, 6.0),
("density", density, 0.0, 1.0),
]:
if param_val is not None:
if not (lo <= param_val <= hi):
return f"Error: {param_name} must be between {lo} and {hi}."
config_params[param_name] = float(param_val)
for bool_param in ["mute_bass", "mute_drums", "only_bass_and_drums"]:
val = locals().get(bool_param)
if val is not None:
config_params[bool_param] = bool(val)
if top_k is not None:
if not (1 <= top_k <= 1000):
return "Error: top_k must be between 1 and 1000."
config_params["top_k"] = int(top_k)
if seed is not None:
if not (0 <= seed <= 2147483647):
return "Error: seed must be between 0 and 2147483647."
config_params["seed"] = int(seed)
config = types.LiveMusicGenerationConfig(**config_params)
event_loop = asyncio.get_event_loop()
try:
result = await lyria.start_session(
guild_id=guild_id_int,
voice_client=voice_client,
guild=guild,
prompts=parsed_prompts,
config=config,
event_loop=event_loop,
cleanup_callback=None,
)
if isinstance(result, str) and result.startswith("Error:"):
return result
if _using_default_key and ctx and ctx.redis and ctx.user_id:
from tools.manage_api_keys import increment_default_key_usage
await increment_default_key_usage(ctx.user_id, "play_music", ctx.redis)
bitrate = result if isinstance(result, int) else 128
tier = getattr(guild, "premium_tier", 0) or 0
return f"Playing music: `{prompt}` in {voice_channel.name} ({bitrate} kbps, Tier {tier})"
except Exception as e:
logger.error("Error in play_music: %s", e, exc_info=True)
await lyria.stop(guild_id_int)
if guild.voice_client:
try:
await guild.voice_client.disconnect(force=True)
except Exception:
pass
return f"Error: {e}"