# Source code for tools.play_music

"""Play music in Discord voice channels using Lyria RealTime via ctx.adapter."""

from __future__ import annotations

import asyncio
import logging
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from tool_context import ToolContext

logger = logging.getLogger(__name__)

# Scale names accepted by the ``scale`` parameter. run() resolves these as
# attributes of ``google.genai.types.Scale``; keep this tuple in sync with it.
_SCALES = (
    "C_MAJOR_A_MINOR",
    "D_FLAT_MAJOR_B_FLAT_MINOR",
    "D_MAJOR_B_MINOR",
    "E_FLAT_MAJOR_C_MINOR",
    "E_MAJOR_D_FLAT_MINOR",
    "F_MAJOR_D_MINOR",
    "G_FLAT_MAJOR_E_FLAT_MINOR",
    "G_MAJOR_E_MINOR",
    "A_FLAT_MAJOR_F_MINOR",
    "A_MAJOR_G_FLAT_MINOR",
    "B_FLAT_MAJOR_G_MINOR",
    "B_MAJOR_A_FLAT_MINOR",
    "SCALE_UNSPECIFIED",
)
# Generation modes accepted by ``music_generation_mode``. run() resolves these
# as attributes of ``google.genai.types.MusicGenerationMode``.
_GENERATION_MODES = ("QUALITY", "DIVERSITY", "VOCALIZATION")

# Built from _SCALES so the human-readable list and the schema enum cannot drift.
_SCALE_ENUM_DESC = "Musical scale. One of: " + ", ".join(_SCALES) + "."
_PROMPT_DESC = (
    "Music description. Format: 'prompt1:weight1, prompt2:weight2'. "
    "Examples: instruments (Piano, Synth Pads, 808 Hip Hop Beat), "
    "genres (Minimal Techno, Lo-Fi Hip Hop, Jazz Fusion), "
    "moods (Ambient, Chill, Dreamy, Upbeat). Default: 'Violin metal'."
)

TOOL_NAME = "play_music"
TOOL_DESCRIPTION = (
    "Join a Discord voice channel and start playing AI-generated music "
    "using Lyria RealTime. Supports prompts, BPM, temperature, brightness, "
    "scale, generation mode, guidance, density, muting, and more. "
    "Requires ctx.adapter for Discord voice access."
)
# JSON-schema description of run()'s arguments. Only ``guild_id`` is required.
# The closed string sets carry an ``enum`` so the caller is steered toward
# valid values instead of discovering mistakes at playback time.
TOOL_PARAMETERS = {
    "type": "object",
    "properties": {
        "guild_id": {"type": "string", "description": "Discord server (guild) ID."},
        "channel_id": {
            "type": "string",
            "description": "Voice channel ID (optional, auto-detects if omitted).",
        },
        "prompt": {"type": "string", "description": _PROMPT_DESC},
        "bpm": {"type": "integer", "description": "Beats per minute (60-200).", "default": 135},
        "temperature": {"type": "number", "description": "Creativity (0.0-3.0).", "default": 1.0},
        "brightness": {"type": "number", "description": "Tonal quality (0.0-1.0).", "default": 0.5},
        "scale": {
            "type": "string",
            "description": _SCALE_ENUM_DESC,
            "enum": list(_SCALES),
        },
        "music_generation_mode": {
            "type": "string",
            "description": "Mode: QUALITY, DIVERSITY, or VOCALIZATION (vocal-like as instrument).",
            "enum": list(_GENERATION_MODES),
        },
        "guidance": {"type": "number", "description": "Prompt strictness (0.0-6.0)."},
        "density": {"type": "number", "description": "Note density (0.0-1.0)."},
        "mute_bass": {"type": "boolean", "description": "Reduce bass output."},
        "mute_drums": {"type": "boolean", "description": "Reduce drums output."},
        "only_bass_and_drums": {"type": "boolean", "description": "Only bass and drums."},
        "top_k": {"type": "integer", "description": "Sampling parameter (1-1000)."},
        "seed": {"type": "integer", "description": "Random seed (0-2147483647)."},
    },
    "required": ["guild_id"],
}


_DEFAULT_PROMPT = "Violin metal"

# (name, lower bound, upper bound, coercion) for each numeric generation
# parameter. The bounds match the ranges advertised in TOOL_PARAMETERS.
_NUMERIC_LIMITS = (
    ("bpm", 60, 200, int),
    ("temperature", 0.0, 3.0, float),
    ("brightness", 0.0, 1.0, float),
    ("guidance", 0.0, 6.0, float),
    ("density", 0.0, 1.0, float),
    ("top_k", 1, 1000, int),
    ("seed", 0, 2147483647, int),
)
_BOOLEAN_PARAMS = ("mute_bass", "mute_drums", "only_bass_and_drums")


def _coerce_in_range(name, value, lo, hi, cast):
    """Coerce *value* with *cast* and require ``lo <= value <= hi``.

    Returns ``(number, None)`` on success or ``(None, error_message)``.
    """
    try:
        number = cast(value)
    except (TypeError, ValueError, OverflowError):
        return None, f"Error: {name} must be a number between {lo} and {hi}."
    # NaN fails this comparison too, so it is rejected like any out-of-range value.
    if not lo <= number <= hi:
        return None, f"Error: {name} must be between {lo} and {hi}."
    return number, None


def _lookup_enum(enum_cls, name):
    """Return the member of *enum_cls* named *name* (case-insensitive), else None."""
    if not isinstance(name, str):
        return None
    member = getattr(enum_cls, name.strip().upper(), None)
    # The isinstance guard stops attributes that are not members (e.g. "mro")
    # from being mistaken for a valid value.
    return member if isinstance(member, enum_cls) else None


def _build_generation_config(types, values):
    """Validate raw tool arguments and build a ``LiveMusicGenerationConfig``.

    *types* is the ``google.genai.types`` module. *values* maps parameter
    names to raw values; a value of ``None`` means "not specified" and is left out.
    Returns ``(config, None)`` on success or ``(None, error_message)``.
    """
    params = {}
    for name, lo, hi, cast in _NUMERIC_LIMITS:
        raw = values.get(name)
        if raw is None:
            continue
        number, error = _coerce_in_range(name, raw, lo, hi, cast)
        if error:
            return None, error
        params[name] = number

    for name, enum_cls, label in (
        ("scale", types.Scale, "scale"),
        ("music_generation_mode", types.MusicGenerationMode, "mode"),
    ):
        raw = values.get(name)
        if raw is None:
            continue
        member = _lookup_enum(enum_cls, raw)
        if member is None:
            return None, f"Error: Invalid {label} '{raw}'."
        params[name] = member

    for name in _BOOLEAN_PARAMS:
        raw = values.get(name)
        if raw is not None:
            params[name] = bool(raw)

    try:
        return types.LiveMusicGenerationConfig(**params), None
    except (TypeError, ValueError) as e:
        return None, f"Error: Invalid music configuration: {e}"


def _resolve_voice_channel(guild, channel_id, discord):
    """Pick the target voice channel.

    Uses the explicit *channel_id* if one is given. Otherwise it picks the
    first voice channel that has members.
    Returns ``(channel, None)`` on success or ``(None, error_message)``.
    """
    if channel_id:
        try:
            channel = guild.get_channel(int(channel_id))
        except (ValueError, TypeError):
            return None, "Error: Invalid channel_id."
        if not channel or not isinstance(channel, discord.VoiceChannel):
            return None, f"Error: Voice channel {channel_id} not found or not a voice channel."
        return channel, None
    for channel in guild.voice_channels:
        if channel.members:
            return channel, None
    return None, "Error: No voice channel specified and none with members found."


async def _resolve_api_key(ctx):
    """Choose the Gemini key: the user's own key first, then the shared default key.

    Returns ``(api_key, using_default_key, error_message)``. The default key is
    subject to a per-user daily limit when a Redis store and user id are available.
    """
    import os

    if ctx.redis and ctx.user_id:
        from tools.manage_api_keys import get_user_api_key

        user_key = await get_user_api_key(
            ctx.user_id,
            "gemini",
            redis_client=ctx.redis,
            channel_id=ctx.channel_id,
            config=getattr(ctx, "config", None),
        )
        if user_key:
            return user_key, False, None

        from tools.manage_api_keys import check_default_key_limit, default_key_limit_error

        allowed, current, limit = await check_default_key_limit(
            ctx.user_id, "play_music", ctx.redis, daily_limit=20,
        )
        if not allowed:
            return None, False, default_key_limit_error("play_music", current, limit)

    # Never embed API keys in source. The shared fallback key comes from the
    # environment. NOTE(review): confirm deployments set one of these variables.
    default_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
    if not default_key:
        return None, False, "Error: No Gemini API key available for music playback."
    return default_key, True, None


async def _connect_voice(guild, voice_channel):
    """Return a voice client connected to *voice_channel*.

    Reuses, moves, or replaces the guild's existing client as needed.
    """
    voice_client = guild.voice_client
    if voice_client is not None:
        if not voice_client.is_connected():
            # A stale, disconnected client can't be reused. Drop it and reconnect.
            await voice_client.disconnect(force=True)
            voice_client = None
        elif voice_client.channel != voice_channel:
            await voice_client.move_to(voice_channel)
    if voice_client is None:
        voice_client = await voice_channel.connect(
            timeout=60.0, self_deaf=True, reconnect=True
        )
    return voice_client


async def _cleanup_after_failure(lyria, guild, guild_id):
    """Best-effort teardown after a failed session start.

    Errors are logged rather than raised, so the original failure is still reported.
    """
    try:
        await lyria.stop(guild_id)
    except Exception:
        logger.warning("play_music: lyria.stop failed during cleanup", exc_info=True)
    if guild.voice_client:
        try:
            await guild.voice_client.disconnect(force=True)
        except Exception:
            logger.debug("play_music: voice disconnect failed during cleanup", exc_info=True)


async def run(
    guild_id: str,
    channel_id: str | None = None,
    prompt: str | None = None,
    bpm: int = 135,
    temperature: float = 1.0,
    brightness: float = 0.5,
    scale: str | None = None,
    music_generation_mode: str | None = None,
    guidance: float | None = None,
    density: float | None = None,
    mute_bass: bool | None = None,
    mute_drums: bool | None = None,
    only_bass_and_drums: bool | None = None,
    top_k: int | None = None,
    seed: int | None = None,
    ctx: "ToolContext | None" = None,
) -> str:
    """Join a Discord voice channel and start a Lyria RealTime music session.

    Every user-supplied argument is validated before the Gemini client is
    initialised or the bot joins voice. Bad input therefore never leaves the
    bot idling in a channel.

    Args:
        guild_id: Discord guild ID, as a string of digits.
        channel_id: Voice channel ID. If omitted, the first voice channel
            with members is used.
        prompt: Weighted prompt string ("a:1.0, b:0.5"). Defaults to "Violin metal".
        bpm, temperature, brightness, guidance, density, top_k, seed:
            Numeric generation parameters, range-checked against the limits
            advertised in TOOL_PARAMETERS. ``None`` leaves a parameter unset.
        scale, music_generation_mode: Enum member names (case-insensitive).
        mute_bass, mute_drums, only_bass_and_drums: Optional boolean toggles.
        ctx: Tool context. It must expose ``adapter`` (with ``lyria_service``),
            ``redis``, ``user_id`` and ``channel_id``.

    Returns:
        A status message. Most failures are strings starting with "Error:".
        Messages produced by get_discord_client / default_key_limit_error are
        passed through unchanged.
    """
    if ctx is None or getattr(ctx, "adapter", None) is None:
        return "Error: Discord adapter (ctx.adapter) is required for voice channel music playback."

    from tools._discord_helpers import get_discord_client
    from services.lyria_session import parse_prompts

    client = get_discord_client(ctx)
    if isinstance(client, str):
        return client
    lyria = ctx.adapter.lyria_service

    try:
        guild_id_int = int(guild_id)
    except (ValueError, TypeError):
        return "Error: Invalid guild_id. Must be an integer."
    guild = client.get_guild(guild_id_int)
    if not guild:
        return "Error: Guild not found."
    if lyria.is_playing(guild_id_int):
        return "Error: Music is already playing. Use steer_music to modify."

    import discord

    voice_channel, error = _resolve_voice_channel(guild, channel_id, discord)
    if error:
        return error
    if guild.voice_client is not None and guild.voice_client.is_playing():
        return "Error: Already playing in a channel."

    # --- Validate all input before any side effect (client init, voice join). ---
    if prompt is None:
        prompt = _DEFAULT_PROMPT
    parsed_prompts = parse_prompts(prompt)
    if not parsed_prompts:
        return "Error: Invalid prompt format."

    from google.genai import types

    config, error = _build_generation_config(
        types,
        {
            "bpm": bpm,
            "temperature": temperature,
            "brightness": brightness,
            "scale": scale,
            "music_generation_mode": music_generation_mode,
            "guidance": guidance,
            "density": density,
            "mute_bass": mute_bass,
            "mute_drums": mute_drums,
            "only_bass_and_drums": only_bass_and_drums,
            "top_k": top_k,
            "seed": seed,
        },
    )
    if error:
        return error

    api_key, using_default_key, error = await _resolve_api_key(ctx)
    if error:
        return error
    await lyria.init_client(api_key)

    try:
        voice_client = await _connect_voice(guild, voice_channel)
    except Exception as e:
        logger.warning("play_music: voice connect failed", exc_info=True)
        return f"Error: Couldn't connect to voice channel: {e}"

    try:
        result = await lyria.start_session(
            guild_id=guild_id_int,
            voice_client=voice_client,
            guild=guild,
            prompts=parsed_prompts,
            config=config,
            event_loop=asyncio.get_running_loop(),
            cleanup_callback=None,
        )
    except Exception as e:
        logger.exception("Error in play_music: %s", e)
        await _cleanup_after_failure(lyria, guild, guild_id_int)
        return f"Error: {e}"

    if isinstance(result, str) and result.startswith("Error:"):
        # NOTE(review): the voice client stays connected on this path. Confirm
        # that lyria.start_session cleans up after itself when it reports an error.
        return result

    if using_default_key and ctx.redis and ctx.user_id:
        from tools.manage_api_keys import increment_default_key_usage

        try:
            await increment_default_key_usage(ctx.user_id, "play_music", ctx.redis)
        except Exception:
            # Don't tear down a session that is already playing because the
            # usage counter failed. Log it so the miss is visible.
            logger.warning("play_music: failed to record default-key usage", exc_info=True)

    bitrate = result if isinstance(result, int) else 128
    tier = getattr(guild, "premium_tier", 0) or 0
    return f"Playing music: `{prompt}` in {voice_channel.name} ({bitrate} kbps, Tier {tier})"