"""Real-time music steering for Lyria RealTime via ctx.adapter."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from tool_context import ToolContext
logger = logging.getLogger(__name__)

# Human-readable list of the scale names offered for the ``scale`` argument.
_SCALE_ENUM_DESC = "Musical scale. One of: " + ", ".join(
    (
        "C_MAJOR_A_MINOR",
        "D_FLAT_MAJOR_B_FLAT_MINOR",
        "D_MAJOR_B_MINOR",
        "E_FLAT_MAJOR_C_MINOR",
        "E_MAJOR_D_FLAT_MINOR",
        "F_MAJOR_D_MINOR",
        "G_FLAT_MAJOR_E_FLAT_MINOR",
        "G_MAJOR_E_MINOR",
        "A_FLAT_MAJOR_F_MINOR",
        "A_MAJOR_G_FLAT_MINOR",
        "B_FLAT_MAJOR_G_MINOR",
        "B_MAJOR_A_FLAT_MINOR",
        "SCALE_UNSPECIFIED",
    )
) + "."

# Name under which this tool is registered.
TOOL_NAME = "steer_music"

# Free-text summary of what the tool does.
TOOL_DESCRIPTION = (
    "Steer currently playing music in real-time by changing parameters like "
    "prompt, BPM, temperature, brightness, scale, mode, guidance, density, "
    "muting options, and more. Requires active music session."
)

# JSON-Schema-style argument description: an object whose properties mirror
# the keyword parameters of ``run``; only ``guild_id`` is mandatory.
TOOL_PARAMETERS = {
    "type": "object",
    "properties": {
        "guild_id": dict(type="string", description="Discord server (guild) ID."),
        "prompt": dict(
            type="string",
            description="New weighted prompts (e.g., 'epic rock:1.5, calm strings:0.5').",
        ),
        "bpm": dict(type="integer", description="New BPM (60-200)."),
        "temperature": dict(type="number", description="New creativity level (0.0-3.0)."),
        "brightness": dict(type="number", description="New tonal quality (0.0-1.0)."),
        "scale": dict(type="string", description=_SCALE_ENUM_DESC),
        "music_generation_mode": dict(
            type="string",
            description="Mode: QUALITY, DIVERSITY, or VOCALIZATION.",
        ),
        "guidance": dict(type="number", description="Prompt strictness (0.0-6.0)."),
        "density": dict(type="number", description="Note density (0.0-1.0)."),
        "mute_bass": dict(type="boolean", description="Reduce bass."),
        "mute_drums": dict(type="boolean", description="Reduce drums."),
        "only_bass_and_drums": dict(type="boolean", description="Only bass and drums."),
        "top_k": dict(type="integer", description="Sampling parameter (1-1000)."),
        "seed": dict(type="integer", description="Random seed (0-2147483647)."),
    },
    "required": ["guild_id"],
}
[docs]
async def run(
guild_id: str,
prompt: str = None,
bpm: int = None,
temperature: float = None,
brightness: float = None,
scale: str = None,
music_generation_mode: str = None,
guidance: float = None,
density: float = None,
mute_bass: bool = None,
mute_drums: bool = None,
only_bass_and_drums: bool = None,
top_k: int = None,
seed: int = None,
ctx: "ToolContext | None" = None,
) -> str:
"""Execute this tool and return the result."""
if ctx is None or not hasattr(ctx, "adapter") or ctx.adapter is None:
return "Error: Discord adapter (ctx.adapter) is required."
from tools._discord_helpers import get_discord_client
from services.lyria_session import parse_prompts
client = get_discord_client(ctx)
if isinstance(client, str):
return client
lyria = ctx.adapter.lyria_service
try:
guild_id_int = int(guild_id)
except (ValueError, TypeError):
return "Error: Invalid guild_id."
if not lyria.is_playing(guild_id_int):
return "Error: Music is not currently playing in that server."
from google.genai import types
steered_params = []
config_updates: dict[str, Any] = {}
if prompt:
parsed_prompts = parse_prompts(prompt)
if parsed_prompts:
prompts = parsed_prompts
else:
return "Error: Invalid prompt."
else:
prompts = None
if bpm is not None:
config_updates["bpm"] = int(bpm)
steered_params.append(f"BPM to {bpm}")
if temperature is not None:
config_updates["temperature"] = float(temperature)
steered_params.append(f"temperature to {temperature}")
if brightness is not None:
config_updates["brightness"] = float(brightness)
steered_params.append(f"brightness to {brightness}")
if scale is not None:
scale_enum = getattr(types.Scale, scale, None)
if scale_enum is None:
return f"Error: Invalid scale '{scale}'."
config_updates["scale"] = scale_enum
steered_params.append(f"scale to {scale}")
if music_generation_mode is not None:
mode_enum = getattr(types.MusicGenerationMode, music_generation_mode, None)
if mode_enum is None:
return f"Error: Invalid mode '{music_generation_mode}'."
config_updates["music_generation_mode"] = mode_enum
steered_params.append(f"generation mode to {music_generation_mode}")
for param_name, param_val, lo, hi in [
("guidance", guidance, 0.0, 6.0),
("density", density, 0.0, 1.0),
]:
if param_val is not None:
if not (lo <= param_val <= hi):
return f"Error: {param_name} must be between {lo} and {hi}."
config_updates[param_name] = float(param_val)
steered_params.append(f"{param_name} to {param_val}")
for bool_param_name, bool_val in [
("mute_bass", mute_bass),
("mute_drums", mute_drums),
("only_bass_and_drums", only_bass_and_drums),
]:
if bool_val is not None:
config_updates[bool_param_name] = bool(bool_val)
steered_params.append(f"{bool_param_name} to {bool_val}")
if top_k is not None:
if not (1 <= top_k <= 1000):
return "Error: top_k must be between 1 and 1000."
config_updates["top_k"] = int(top_k)
steered_params.append(f"top_k to {top_k}")
if seed is not None:
if not (0 <= seed <= 2147483647):
return "Error: seed must be between 0 and 2147483647."
config_updates["seed"] = int(seed)
steered_params.append(f"seed to {seed}")
if not steered_params and prompts is None:
return "No parameters were provided to steer the music."
await lyria.steer(
guild_id_int,
prompts=prompts,
config_updates=config_updates if config_updates else None,
)
if prompts is not None:
steered_params.insert(0, f"prompt to '{prompt}'")
return f"Successfully steered music: set {', '.join(steered_params)}."