import argparse
from concurrent.futures import ProcessPoolExecutor
import os
import subprocess as sp
from tempfile import NamedTemporaryFile
import time
import warnings

import torch
import gradio as gr

from audiocraft.data.audio_utils import convert_audio
from audiocraft.data.audio import audio_write
from audiocraft.models import MusicGen
from gradio.themes.utils import sizes


theme = gr.themes.Default(radius_size=sizes.radius_none).set(
    block_label_text_color='#4D63FF',
    block_title_text_color='#4D63FF',
    button_primary_text_color='#4D63FF',
    button_primary_background_fill='#FFFFFF',
    button_primary_border_color='#4D63FF',
    button_primary_background_fill_hover='#EDEFFF',
)

MODEL = None  # Last used model
IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '')
MAX_BATCH_SIZE = 12
BATCHED_DURATION = 15
INTERRUPTING = False

# Wrap subprocess.call so ffmpeg does not spam the logs when gr.make_waveform runs.
_old_call = sp.call


def _call_nostderr(*args, **kwargs):
    # Keep ffmpeg from vomiting on the logs.
    kwargs['stderr'] = sp.DEVNULL
    kwargs['stdout'] = sp.DEVNULL
    # Return the exit code so callers of sp.call still behave as expected.
    return _old_call(*args, **kwargs)


sp.call = _call_nostderr

# Preallocate the pool of processes used to render waveform videos.
pool = ProcessPoolExecutor(4)
pool.__enter__()


def interrupt():
    # Signal the progress callback in predict_full to abort generation.
    global INTERRUPTING
    INTERRUPTING = True


def make_waveform(*args, **kwargs):
    # Silence the warnings gr.make_waveform tends to emit.
    be = time.time()
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        out = gr.make_waveform(*args, **kwargs)
        print("Making the video took", time.time() - be)
        return out


def load_model(version='melody'):
    # Cache the model globally; only reload when the requested version changes.
    global MODEL
    print("Loading model", version)
    if MODEL is None or MODEL.name != version:
        MODEL = MusicGen.get_pretrained(version)


def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
    MODEL.set_generation_params(duration=duration, **gen_kwargs)
    print("new batch", len(texts), texts,
          [None if m is None else (m[0], m[1].shape) for m in melodies])
    be = time.time()
    processed_melodies = []
    target_sr = 32000
    target_ac = 1
    for melody in melodies:
        if melody is None:
            processed_melodies.append(None)
        else:
            # Gradio hands us a (sample_rate, numpy array) tuple; convert it to
            # a mono 32 kHz tensor trimmed to the requested duration.
            sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
            if melody.dim() == 1:
                melody = melody[None]
            melody = melody[..., :int(sr * duration)]
            melody = convert_audio(melody, sr, target_sr, target_ac)
            processed_melodies.append(melody)

    if any(m is not None for m in processed_melodies):
        outputs = MODEL.generate_with_chroma(
            descriptions=texts,
            melody_wavs=processed_melodies,
            melody_sample_rate=target_sr,
            progress=progress,
        )
    else:
        outputs = MODEL.generate(texts, progress=progress)

    outputs = outputs.detach().cpu().float()
    out_files = []
    for output in outputs:
        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
            audio_write(
                file.name, output, MODEL.sample_rate, strategy="loudness",
                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
            out_files.append(pool.submit(make_waveform, file.name))
    res = [out_file.result() for out_file in out_files]
    print("batch finished", len(texts), time.time() - be)
    return res


def predict_batched(texts, melodies):
    # Truncate prompts so a single oversized prompt cannot blow up the batch.
    max_text_length = 512
    texts = [text[:max_text_length] for text in texts]
    load_model('melody')
    res = _do_predictions(texts, melodies, BATCHED_DURATION)
    return [res]


def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coef,
                 progress=gr.Progress()):
    global INTERRUPTING
    INTERRUPTING = False
    if temperature < 0:
        raise gr.Error("Temperature must be >= 0.")
    if topk < 0:
        raise gr.Error("Topk must be non-negative.")
    if topp < 0:
        raise gr.Error("Topp must be non-negative.")

    topk = int(topk)
    load_model(model)

    def _progress(generated, to_generate):
        progress((generated, to_generate))
        if INTERRUPTING:
            raise gr.Error("Interrupted.")
    MODEL.set_custom_progress_callback(_progress)

    outs = _do_predictions(
        [text], [melody], duration, progress=True,
        top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
    return outs[0]
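
# Standalone usage sketch (not part of the original app): this mirrors the
# text-only path predict_full takes through _do_predictions, without the
# Gradio plumbing. The prompt and parameter values are illustrative
# assumptions, and the helper name is hypothetical.
def _example_generate(prompt="90s rock song with loud guitars"):
    load_model('melody')
    MODEL.set_generation_params(duration=10, top_k=250, top_p=0.0,
                                temperature=1.0, cfg_coef=3.0)
    # Returns a [batch, channels, samples] float tensor at MODEL.sample_rate.
    return MODEL.generate([prompt], progress=True)
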
gr.Error("Topp must be non-negative.") topk = int(topk) load_model(model) def _progress(generated, to_generate): progress((generated, to_generate)) if INTERRUPTING: raise gr.Error("Interrupted.") MODEL.set_custom_progress_callback(_progress) outs = _do_predictions( [text], [melody], duration, progress=True, top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef) return outs[0] def ui_full(launch_kwargs): with gr.Blocks(theme=theme, css="footer {visibility: hidden}") as interface: gr.Markdown( """