Source code for tools.music_steering

"""Real-time music steering for Lyria RealTime via ctx.adapter."""

from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from tool_context import ToolContext

logger = logging.getLogger(__name__)

# Scale names offered to the model for the ``scale`` argument; ``run`` looks
# the chosen name up on ``google.genai.types.Scale``.
_SCALE_NAMES = (
    "C_MAJOR_A_MINOR",
    "D_FLAT_MAJOR_B_FLAT_MINOR",
    "D_MAJOR_B_MINOR",
    "E_FLAT_MAJOR_C_MINOR",
    "E_MAJOR_D_FLAT_MINOR",
    "F_MAJOR_D_MINOR",
    "G_FLAT_MAJOR_E_FLAT_MINOR",
    "G_MAJOR_E_MINOR",
    "A_FLAT_MAJOR_F_MINOR",
    "A_MAJOR_G_FLAT_MINOR",
    "B_FLAT_MAJOR_G_MINOR",
    "B_MAJOR_A_FLAT_MINOR",
    "SCALE_UNSPECIFIED",
)
_SCALE_ENUM_DESC = "Musical scale. One of: " + ", ".join(_SCALE_NAMES) + "."

# Tool registration metadata: the tool name, the model-facing description,
# and a JSON-Schema-style description of the arguments ``run`` accepts.
TOOL_NAME = "steer_music"
TOOL_DESCRIPTION = (
    "Steer currently playing music in real-time by changing parameters like "
    "prompt, BPM, temperature, brightness, scale, mode, guidance, density, "
    "muting options, and more. Requires active music session."
)

# (argument name, JSON type, description) in the order they are advertised.
_PROPERTY_SPECS = (
    ("guild_id", "string", "Discord server (guild) ID."),
    (
        "prompt",
        "string",
        "New weighted prompts (e.g., 'epic rock:1.5, calm strings:0.5').",
    ),
    ("bpm", "integer", "New BPM (60-200)."),
    ("temperature", "number", "New creativity level (0.0-3.0)."),
    ("brightness", "number", "New tonal quality (0.0-1.0)."),
    ("scale", "string", _SCALE_ENUM_DESC),
    (
        "music_generation_mode",
        "string",
        "Mode: QUALITY, DIVERSITY, or VOCALIZATION.",
    ),
    ("guidance", "number", "Prompt strictness (0.0-6.0)."),
    ("density", "number", "Note density (0.0-1.0)."),
    ("mute_bass", "boolean", "Reduce bass."),
    ("mute_drums", "boolean", "Reduce drums."),
    ("only_bass_and_drums", "boolean", "Only bass and drums."),
    ("top_k", "integer", "Sampling parameter (1-1000)."),
    ("seed", "integer", "Random seed (0-2147483647)."),
)

TOOL_PARAMETERS = {
    "type": "object",
    "properties": {
        arg_name: {"type": json_type, "description": text}
        for arg_name, json_type, text in _PROPERTY_SPECS
    },
    "required": ["guild_id"],
}


[docs] async def run( guild_id: str, prompt: str = None, bpm: int = None, temperature: float = None, brightness: float = None, scale: str = None, music_generation_mode: str = None, guidance: float = None, density: float = None, mute_bass: bool = None, mute_drums: bool = None, only_bass_and_drums: bool = None, top_k: int = None, seed: int = None, ctx: "ToolContext | None" = None, ) -> str: """Steer the currently playing Lyria RealTime music session in real time. Applies the supplied subset of generation parameters (weighted prompts, BPM, temperature, brightness, scale, generation mode, guidance, density, the bass/drums muting flags, ``top_k``, and ``seed``) to a live music stream so the model can nudge the soundtrack without restarting it. Each provided value is validated and range-checked, and a human-readable summary of exactly what changed is returned. Reaches the audio engine through the inference worker's ``ctx.adapter`` (:class:`core.proxy_adapter.ProxyPlatformAdapter`): it pulls the Discord client via :func:`tools._discord_helpers.get_discord_client`, then uses ``ctx.adapter.lyria_service`` (:class:`services.lyria_service.LyriaService`) — first calling its ``is_playing`` guard and finally its async ``steer`` to push the new prompts and ``config_updates`` to the running session. Weighted prompt strings are parsed with :func:`services.lyria_session.parse_prompts`, and ``google.genai.types`` is consulted to resolve the ``scale`` and ``music_generation_mode`` enums. As the single-tool module's entry point it is registered under ``TOOL_NAME`` (``steer_music``) by :func:`tool_loader.load_tools_from_directory` and invoked by the inference worker's tool-execution path; it is not called directly elsewhere in the repo. Args: guild_id: Discord guild (server) id of the active session; must parse to an int and have music currently playing. prompt: New weighted prompt string (e.g. ``"epic rock:1.5, calm strings:0.5"``); parsed into weighted prompts. 
bpm: New tempo in beats per minute. temperature: New creativity/sampling temperature. brightness: New tonal brightness. scale: Musical scale enum name resolved against ``types.Scale``. music_generation_mode: Generation mode enum name resolved against ``types.MusicGenerationMode`` (``QUALITY``/``DIVERSITY``/ ``VOCALIZATION``). guidance: Prompt-adherence strength; range-checked to ``0.0``–``6.0``. density: Note density; range-checked to ``0.0``–``1.0``. mute_bass: When set, reduce/mute the bass. mute_drums: When set, reduce/mute the drums. only_bass_and_drums: When set, restrict output to bass and drums. top_k: Sampling parameter; range-checked to ``1``–``1000``. seed: Random seed; range-checked to ``0``–``2147483647``. ctx: Tool context supplying ``adapter`` (with ``lyria_service``). Required; a missing adapter yields an error result. Returns: str: A human-readable success message naming every parameter that was steered, or an error string (missing adapter/client, invalid ``guild_id``, no active playback, an invalid value, or no parameters supplied). """ if ctx is None or not hasattr(ctx, "adapter") or ctx.adapter is None: return "Error: Discord adapter (ctx.adapter) is required." from tools._discord_helpers import get_discord_client from services.lyria_session import parse_prompts client = get_discord_client(ctx) if isinstance(client, str): return client lyria = ctx.adapter.lyria_service try: guild_id_int = int(guild_id) except (ValueError, TypeError): return "Error: Invalid guild_id." if not lyria.is_playing(guild_id_int): return "Error: Music is not currently playing in that server." from google.genai import types steered_params = [] config_updates: dict[str, Any] = {} if prompt: parsed_prompts = parse_prompts(prompt) if parsed_prompts: prompts = parsed_prompts else: return "Error: Invalid prompt." 
else: prompts = None if bpm is not None: config_updates["bpm"] = int(bpm) steered_params.append(f"BPM to {bpm}") if temperature is not None: config_updates["temperature"] = float(temperature) steered_params.append(f"temperature to {temperature}") if brightness is not None: config_updates["brightness"] = float(brightness) steered_params.append(f"brightness to {brightness}") if scale is not None: scale_enum = getattr(types.Scale, scale, None) if scale_enum is None: return f"Error: Invalid scale '{scale}'." config_updates["scale"] = scale_enum steered_params.append(f"scale to {scale}") if music_generation_mode is not None: mode_enum = getattr(types.MusicGenerationMode, music_generation_mode, None) if mode_enum is None: return f"Error: Invalid mode '{music_generation_mode}'." config_updates["music_generation_mode"] = mode_enum steered_params.append(f"generation mode to {music_generation_mode}") for param_name, param_val, lo, hi in [ ("guidance", guidance, 0.0, 6.0), ("density", density, 0.0, 1.0), ]: if param_val is not None: if not (lo <= param_val <= hi): return f"Error: {param_name} must be between {lo} and {hi}." config_updates[param_name] = float(param_val) steered_params.append(f"{param_name} to {param_val}") for bool_param_name, bool_val in [ ("mute_bass", mute_bass), ("mute_drums", mute_drums), ("only_bass_and_drums", only_bass_and_drums), ]: if bool_val is not None: config_updates[bool_param_name] = bool(bool_val) steered_params.append(f"{bool_param_name} to {bool_val}") if top_k is not None: if not (1 <= top_k <= 1000): return "Error: top_k must be between 1 and 1000." config_updates["top_k"] = int(top_k) steered_params.append(f"top_k to {top_k}") if seed is not None: if not (0 <= seed <= 2147483647): return "Error: seed must be between 0 and 2147483647." config_updates["seed"] = int(seed) steered_params.append(f"seed to {seed}") if not steered_params and prompts is None: return "No parameters were provided to steer the music." 
await lyria.steer( guild_id_int, prompts=prompts, config_updates=config_updates if config_updates else None, ) if prompts is not None: steered_params.insert(0, f"prompt to '{prompt}'") return f"Successfully steered music: set {', '.join(steered_params)}."