# Source code for tools.music_steering

"""Real-time music steering for Lyria RealTime via ctx.adapter."""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from tool_context import ToolContext

logger = logging.getLogger(__name__)

# Scale names listed, in this order, in the ``scale`` parameter description so
# the calling model sees the accepted spellings.
_SCALE_NAMES = (
    "C_MAJOR_A_MINOR",
    "D_FLAT_MAJOR_B_FLAT_MINOR",
    "D_MAJOR_B_MINOR",
    "E_FLAT_MAJOR_C_MINOR",
    "E_MAJOR_D_FLAT_MINOR",
    "F_MAJOR_D_MINOR",
    "G_FLAT_MAJOR_E_FLAT_MINOR",
    "G_MAJOR_E_MINOR",
    "A_FLAT_MAJOR_F_MINOR",
    "A_MAJOR_G_FLAT_MINOR",
    "B_FLAT_MAJOR_G_MINOR",
    "B_MAJOR_A_FLAT_MINOR",
    "SCALE_UNSPECIFIED",
)
_SCALE_ENUM_DESC = f"Musical scale. One of: {', '.join(_SCALE_NAMES)}."

TOOL_NAME = "steer_music"
TOOL_DESCRIPTION = " ".join(
    [
        "Steer currently playing music in real-time by changing parameters",
        "like prompt, BPM, temperature, brightness, scale, mode, guidance,",
        "density, muting options, and more. Requires active music session.",
    ]
)


def _param(json_type: str, description: str) -> dict[str, str]:
    """Return one JSON-schema property entry with its ``type`` and ``description``."""
    return {"type": json_type, "description": description}


# JSON-schema object describing the arguments of ``run``; only ``guild_id``
# is required, every steering knob is optional.
TOOL_PARAMETERS = {
    "type": "object",
    "properties": {
        "guild_id": _param("string", "Discord server (guild) ID."),
        "prompt": _param(
            "string",
            "New weighted prompts (e.g., 'epic rock:1.5, calm strings:0.5').",
        ),
        "bpm": _param("integer", "New BPM (60-200)."),
        "temperature": _param("number", "New creativity level (0.0-3.0)."),
        "brightness": _param("number", "New tonal quality (0.0-1.0)."),
        "scale": _param("string", _SCALE_ENUM_DESC),
        "music_generation_mode": _param(
            "string", "Mode: QUALITY, DIVERSITY, or VOCALIZATION."
        ),
        "guidance": _param("number", "Prompt strictness (0.0-6.0)."),
        "density": _param("number", "Note density (0.0-1.0)."),
        "mute_bass": _param("boolean", "Reduce bass."),
        "mute_drums": _param("boolean", "Reduce drums."),
        "only_bass_and_drums": _param("boolean", "Only bass and drums."),
        "top_k": _param("integer", "Sampling parameter (1-1000)."),
        "seed": _param("integer", "Random seed (0-2147483647)."),
    },
    "required": ["guild_id"],
}


def _apply_numeric(
    specs: tuple[tuple[str, str, Any, type, float, float], ...],
    config_updates: dict[str, Any],
    steered_params: list[str],
) -> str | None:
    """Validate numeric steering settings and record the accepted ones.

    Each spec is ``(config_key, label, raw_value, cast, lo, hi)``. Specs whose
    raw value is ``None`` are skipped. Every other value is converted with
    *cast*, checked against the inclusive range ``[lo, hi]``, stored in
    *config_updates* under *config_key*, and described in *steered_params*.

    Returns:
        The first ``"Error: ..."`` message encountered, or ``None`` if every
        supplied value was accepted.
    """
    for key, label, raw, cast, lo, hi in specs:
        if raw is None:
            continue
        try:
            # Convert before comparing: tool arguments may not match the
            # declared JSON types (e.g. "120" instead of 120).
            value = cast(raw)
        except (TypeError, ValueError, OverflowError):
            return f"Error: {key} must be a number."
        # NaN fails both comparisons, so it is rejected here as well.
        if not lo <= value <= hi:
            return f"Error: {key} must be between {lo} and {hi}."
        config_updates[key] = value
        steered_params.append(f"{label} to {raw}")
    return None


def _lookup_enum(enum_cls: Any, raw: Any) -> Any:
    """Return the member of *enum_cls* named *raw*, or ``None`` if there is none.

    The name is matched after stripping whitespace and upper-casing, so
    ``"quality"`` resolves the same as ``"QUALITY"``. Only declared members
    are considered. A plain ``getattr`` would also accept class attributes
    such as ``__doc__`` or ``mro`` and forward a non-enum value.
    """
    name = str(raw).strip().upper()
    members = getattr(enum_cls, "__members__", None)
    if members is None:
        # Not a standard Enum class; fall back to attribute lookup.
        return getattr(enum_cls, name, None)
    return members.get(name)


async def run(
    guild_id: str,
    prompt: str | None = None,
    bpm: int | None = None,
    temperature: float | None = None,
    brightness: float | None = None,
    scale: str | None = None,
    music_generation_mode: str | None = None,
    guidance: float | None = None,
    density: float | None = None,
    mute_bass: bool | None = None,
    mute_drums: bool | None = None,
    only_bass_and_drums: bool | None = None,
    top_k: int | None = None,
    seed: int | None = None,
    ctx: "ToolContext | None" = None,
) -> str:
    """Apply real-time steering changes to the music playing in a guild.

    Only the settings that are supplied (not ``None``; an empty *prompt* is
    ignored) are forwarded to ``ctx.adapter.lyria_service.steer``.

    Args:
        guild_id: Discord guild ID as a string; parsed with ``int``.
        prompt: Weighted prompt text, parsed with ``parse_prompts``.
        bpm, temperature, brightness, guidance, density, top_k, seed:
            Numeric settings. Each is converted and checked against the
            range stated in this tool's parameter descriptions.
        scale, music_generation_mode: Member names of
            ``google.genai.types.Scale`` / ``MusicGenerationMode``, matched
            case-insensitively.
        mute_bass, mute_drums, only_bass_and_drums: Instrument toggles,
            coerced with ``bool``.
        ctx: Tool context; must carry a non-``None`` ``adapter``.

    Returns:
        A success summary listing what was changed, or a message explaining
        why nothing was sent (failures start with ``"Error:"``).
    """
    if ctx is None or getattr(ctx, "adapter", None) is None:
        return "Error: Discord adapter (ctx.adapter) is required."

    from services.lyria_session import parse_prompts
    from tools._discord_helpers import get_discord_client

    client = get_discord_client(ctx)
    if isinstance(client, str):
        # A string result is passed straight back (presumably an error message).
        return client

    lyria = getattr(ctx.adapter, "lyria_service", None)
    if lyria is None:
        return "Error: Music service is not available."

    try:
        guild_id_int = int(guild_id)
    except (ValueError, TypeError):
        return "Error: Invalid guild_id."

    if not lyria.is_playing(guild_id_int):
        return "Error: Music is not currently playing in that server."

    from google.genai import types

    prompts = None
    if prompt:
        prompts = parse_prompts(prompt)
        if not prompts:
            return "Error: Invalid prompt."

    steered_params: list[str] = []
    config_updates: dict[str, Any] = {}

    error = _apply_numeric(
        (
            ("bpm", "BPM", bpm, int, 60, 200),
            ("temperature", "temperature", temperature, float, 0.0, 3.0),
            ("brightness", "brightness", brightness, float, 0.0, 1.0),
        ),
        config_updates,
        steered_params,
    )
    if error is not None:
        return error

    if scale is not None:
        scale_enum = _lookup_enum(types.Scale, scale)
        if scale_enum is None:
            return f"Error: Invalid scale '{scale}'."
        config_updates["scale"] = scale_enum
        steered_params.append(f"scale to {scale}")

    if music_generation_mode is not None:
        mode_enum = _lookup_enum(types.MusicGenerationMode, music_generation_mode)
        if mode_enum is None:
            return f"Error: Invalid mode '{music_generation_mode}'."
        config_updates["music_generation_mode"] = mode_enum
        steered_params.append(f"generation mode to {music_generation_mode}")

    error = _apply_numeric(
        (
            ("guidance", "guidance", guidance, float, 0.0, 6.0),
            ("density", "density", density, float, 0.0, 1.0),
        ),
        config_updates,
        steered_params,
    )
    if error is not None:
        return error

    for bool_param_name, bool_val in (
        ("mute_bass", mute_bass),
        ("mute_drums", mute_drums),
        ("only_bass_and_drums", only_bass_and_drums),
    ):
        if bool_val is not None:
            config_updates[bool_param_name] = bool(bool_val)
            steered_params.append(f"{bool_param_name} to {bool_val}")

    error = _apply_numeric(
        (
            ("top_k", "top_k", top_k, int, 1, 1000),
            ("seed", "seed", seed, int, 0, 2147483647),
        ),
        config_updates,
        steered_params,
    )
    if error is not None:
        return error

    if not steered_params and prompts is None:
        return "No parameters were provided to steer the music."

    try:
        await lyria.steer(
            guild_id_int,
            prompts=prompts,
            config_updates=config_updates or None,
        )
    except Exception as exc:
        # Tool boundary: every other failure here is reported as an error
        # string, so do the same for service errors (CancelledError is a
        # BaseException and still propagates).
        logger.exception("Failed to steer music in guild %s", guild_id_int)
        return f"Error: Failed to steer music: {exc}"

    if prompts is not None:
        steered_params.insert(0, f"prompt to '{prompt}'")
    return f"Successfully steered music: set {', '.join(steered_params)}."