"""Gradio front-end for MusicGen text/melody-conditioned music generation.

Two UIs are provided: a batched one (selected when the ``SPACE_ID``
environment variable contains ``facebook/MusicGen``) and a full one that
exposes the model choice and sampling parameters.
"""
import argparse
from concurrent.futures import ProcessPoolExecutor
import os
import subprocess as sp
from tempfile import NamedTemporaryFile
import time
import warnings

import torch
import gradio as gr
from audiocraft.data.audio_utils import convert_audio
from audiocraft.data.audio import audio_write
from audiocraft.models import MusicGen
from gradio.themes.utils import sizes


theme = gr.themes.Default(radius_size=sizes.radius_none).set(
    block_label_text_color='#4D63FF',
    block_title_text_color='#4D63FF',
    button_primary_text_color='#4D63FF',
    button_primary_background_fill='#FFFFFF',
    button_primary_border_color='#4D63FF',
    button_primary_background_fill_hover='#EDEFFF',
)

MODEL = None  # Last used model
IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '')
MAX_BATCH_SIZE = 12
BATCHED_DURATION = 15
INTERRUPTING = False
# We have to wrap subprocess call to clean a bit the log when using gr.make_waveform
_old_call = sp.call


def _call_nostderr(*args, **kwargs):
    """Run the original ``subprocess.call`` with stdout and stderr silenced.

    Returns:
        The return code of the wrapped call, so that callers of the patched
        ``subprocess.call`` still see the exit status.
    """
    # Avoid ffmpeg vomiting on the logs.
    kwargs['stderr'] = sp.DEVNULL
    kwargs['stdout'] = sp.DEVNULL
    return _old_call(*args, **kwargs)


sp.call = _call_nostderr
# Preallocating the pool of processes.
pool = ProcessPoolExecutor(4)
pool.__enter__()


def interrupt():
    """Flag the running generation for interruption.

    The flag is checked by the progress callback installed in
    ``predict_full``, which raises a ``gr.Error`` when it is set.
    """
    global INTERRUPTING
    INTERRUPTING = True


def make_waveform(*args, **kwargs):
    """Call ``gr.make_waveform`` with warnings suppressed and print the time taken.

    All arguments are forwarded unchanged; the result of
    ``gr.make_waveform`` is returned as is.
    """
    # Further remove some warnings.
    be = time.time()
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        out = gr.make_waveform(*args, **kwargs)
    print("Make a video took", time.time() - be)
    return out


def load_model(version='melody'):
    """Make ``MODEL`` the pretrained MusicGen model named ``version``.

    The model is only (re)loaded when no model is loaded yet or when the
    loaded model's ``name`` differs from ``version``.
    """
    global MODEL
    print("Loading model", version)
    if MODEL is None or MODEL.name != version:
        MODEL = MusicGen.get_pretrained(version)


def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
    """Generate one clip per text prompt and render each to a waveform video.

    Args:
        texts: List of text descriptions.
        melodies: List parallel to ``texts``; each entry is either ``None`` or
            a ``(sample_rate, numpy_array)`` pair.
        duration: Target duration passed to the model; melodies are also
            truncated to ``sample_rate * duration`` samples.
        progress: Forwarded to the model's generate call.
        **gen_kwargs: Extra generation parameters forwarded to
            ``MODEL.set_generation_params``.

    Returns:
        A list with the ``make_waveform`` result for each generated output.
    """
    MODEL.set_generation_params(duration=duration, **gen_kwargs)
    print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
    be = time.time()
    processed_melodies = []
    target_sr = 32000
    target_ac = 1
    for melody in melodies:
        if melody is None:
            processed_melodies.append(None)
        else:
            # Transpose the incoming array; presumably Gradio delivers it as
            # (samples, channels) -- TODO confirm.
            sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
            if melody.dim() == 1:
                # 1-D audio gets a leading dimension.
                melody = melody[None]
            melody = melody[..., :int(sr * duration)]
            melody = convert_audio(melody, sr, target_sr, target_ac)
            processed_melodies.append(melody)

    if any(m is not None for m in processed_melodies):
        outputs = MODEL.generate_with_chroma(
            descriptions=texts,
            melody_wavs=processed_melodies,
            melody_sample_rate=target_sr,
            progress=progress,
        )
    else:
        outputs = MODEL.generate(texts, progress=progress)

    outputs = outputs.detach().cpu().float()
    out_files = []
    for output in outputs:
        # delete=False: the file must outlive this block because the waveform
        # rendering runs later in a worker process.
        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
            audio_write(
                file.name, output, MODEL.sample_rate, strategy="loudness",
                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
            out_files.append(pool.submit(make_waveform, file.name))
    res = [out_file.result() for out_file in out_files]
    print("batch finished", len(texts), time.time() - be)
    return res


def predict_batched(texts, melodies):
    """Batched prediction entry point used by ``ui_batched``.

    Prompts are truncated to 512 characters, the ``melody`` model is used and
    every clip lasts ``BATCHED_DURATION`` seconds.

    Returns:
        A one-element list holding the list of results (presumably one entry
        per output component, as Gradio's batch mode expects -- TODO confirm).
    """
    max_text_length = 512
    texts = [text[:max_text_length] for text in texts]
    load_model('melody')
    res = _do_predictions(texts, melodies, BATCHED_DURATION)
    return [res]


def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coef,
                 progress=gr.Progress()):
    """Single-request prediction entry point used by ``ui_full``.

    Args:
        model: Name of the pretrained model to load.
        text: Text description.
        melody: ``None`` or a ``(sample_rate, numpy_array)`` pair.
        duration: Target duration of the generated clip.
        topk: Top-k sampling value (converted to ``int``).
        topp: Top-p sampling value.
        temperature: Sampling temperature.
        cfg_coef: Classifier-free guidance coefficient.
        progress: Gradio progress tracker updated during generation.

    Returns:
        The ``make_waveform`` result for the single generated clip.

    Raises:
        gr.Error: If ``temperature``, ``topk`` or ``topp`` is negative, or if
            the generation is interrupted via ``interrupt``.
    """
    global INTERRUPTING
    INTERRUPTING = False
    if temperature < 0:
        raise gr.Error("Temperature must be >= 0.")
    if topk < 0:
        raise gr.Error("Topk must be non-negative.")
    if topp < 0:
        raise gr.Error("Topp must be non-negative.")

    topk = int(topk)
    load_model(model)

    def _progress(generated, to_generate):
        # Report progress and abort as soon as an interruption was requested.
        progress((generated, to_generate))
        if INTERRUPTING:
            raise gr.Error("Interrupted.")

    MODEL.set_custom_progress_callback(_progress)

    outs = _do_predictions(
        [text], [melody], duration, progress=True,
        top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
    return outs[0]


def ui_full(launch_kwargs):
    """Build and launch the full UI (model choice and sampling parameters).

    Args:
        launch_kwargs: Keyword arguments forwarded to ``launch``.
    """
    with gr.Blocks(theme=theme, css="footer {visibility: hidden}") as interface:
        gr.Markdown(
            """
音乐生成
"""
        )
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    text = gr.Text(label="输入文本", interactive=True)
                    melody = gr.Audio(source="upload", type="numpy", label="旋律(可选)", interactive=True)
                with gr.Row():
                    submit = gr.Button("Submit")
                    # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
                    _ = gr.Button("中断").click(fn=interrupt, queue=False)
                with gr.Row():
                    model = gr.Radio(["melody", "medium", "small", "large"], label="模型", value="melody",
                                     interactive=True)
                with gr.Row():
                    duration = gr.Slider(minimum=1, maximum=120, value=10, label="Duration", interactive=True)
                with gr.Row():
                    topk = gr.Number(label="Top-k", value=250, interactive=True)
                    topp = gr.Number(label="Top-p", value=0, interactive=True)
                    temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
                    cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
            with gr.Column():
                output = gr.Video(label="生成的音乐")
        submit.click(
            predict_full,
            inputs=[model, text, melody, duration, topk, topp, temperature, cfg_coef],
            outputs=[output])
        gr.Examples(
            fn=predict_full,
            examples=[
                [
                    "An 80s driving pop song with heavy drums and synth pads in the background",
                    "./assets/bach.mp3",
                    "melody"
                ],
                [
                    "A cheerful country song with acoustic guitars",
                    "./assets/bolero_ravel.mp3",
                    "melody"
                ],
                [
                    "90s rock song with electric guitar and heavy drums",
                    None,
                    "medium"
                ],
                [
                    "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
                    "./assets/bach.mp3",
                    "melody"
                ],
                [
                    "lofi slow bpm electro chill with organic samples",
                    None,
                    "medium",
                ],
            ],
            inputs=[text, melody, model],
            outputs=[output],
            label="例子"
        )

        interface.queue().launch(**launch_kwargs)


def ui_batched(launch_kwargs):
    """Build and launch the batched UI (text and optional melody only).

    Args:
        launch_kwargs: Keyword arguments forwarded to ``launch``.
    """
    with gr.Blocks(theme=theme, css="footer {visibility: hidden}") as demo:
        gr.Markdown(
            """
音乐生成
"""
        )
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    text = gr.Text(label="Describe your music", lines=2, interactive=True)
                    melody = gr.Audio(source="upload", type="numpy", label="Condition on a melody (optional)",
                                      interactive=True)
                with gr.Row():
                    submit = gr.Button("Generate")
            with gr.Column():
                output = gr.Video(label="Generated Music")
        submit.click(
            predict_batched,
            inputs=[text, melody],
            outputs=[output],
            batch=True,
            max_batch_size=MAX_BATCH_SIZE)
        gr.Examples(
            fn=predict_batched,
            examples=[
                [
                    "An 80s driving pop song with heavy drums and synth pads in the background",
                    "./assets/bach.mp3",
                ],
                [
                    "A cheerful country song with acoustic guitars",
                    "./assets/bolero_ravel.mp3",
                ],
                [
                    "90s rock song with electric guitar and heavy drums",
                    None,
                ],
                [
                    "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
                    "./assets/bach.mp3",
                ],
                [
                    "lofi slow bpm electro chill with organic samples",
                    None,
                ],
            ],
            inputs=[text, melody],
            outputs=[output],
            label="例子"
        )

        demo.queue(max_size=8 * 4).launch(**launch_kwargs)


def main():
    """Parse the command-line options and launch the appropriate UI."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--listen',
        type=str,
        # NOTE(review): the default binds to all interfaces whether or not
        # SPACE_ID is set; confirm this is intended for local runs.
        default='0.0.0.0',
        help='IP to listen on for connections to Gradio',
    )
    parser.add_argument(
        '--username', type=str, default='', help='Username for authentication'
    )
    parser.add_argument(
        '--password', type=str, default='', help='Password for authentication'
    )
    parser.add_argument(
        '--server_port',
        type=int,
        default=0,
        help='Port to run the server listener on',
    )
    parser.add_argument(
        '--inbrowser', action='store_true', help='Open in browser'
    )
    parser.add_argument(
        '--share', action='store_true', help='Share the gradio UI'
    )

    args = parser.parse_args()

    # Only options that were actually provided are forwarded to launch().
    launch_kwargs = {}
    launch_kwargs['server_name'] = args.listen

    if args.username and args.password:
        launch_kwargs['auth'] = (args.username, args.password)
    if args.server_port:
        launch_kwargs['server_port'] = args.server_port
    if args.inbrowser:
        launch_kwargs['inbrowser'] = args.inbrowser
    if args.share:
        launch_kwargs['share'] = args.share

    # Show the interface
    if IS_BATCHED:
        ui_batched(launch_kwargs)
    else:
        ui_full(launch_kwargs)


if __name__ == "__main__":
    main()