music-gen

This commit is contained in:
parent e619938f16
commit 7b6c296b49

@@ -0,0 +1,47 @@
name: Build
run-name: ${{ github.actor }} is upgrading the release 🚀
on: [push]
env:
  REPOSITORY: ${{ github.repository }}
  COMMIT_ID: ${{ github.sha }}
jobs:
  Build-Deploy-Actions:
    runs-on: ubuntu-latest
    steps:
      - run: echo "🎉 The job was automatically triggered by a ${{ github.event_name }} event."
      - run: echo "🐧 This job is now running on a ${{ runner.os }} server hosted by Gitea!"
      - run: echo "🔎 The name of your branch is ${{ github.ref }} and your repository is ${{ github.repository }}."
      - name: Check out repository code
        uses: actions/checkout@v3
      - name: Setup Git LFS
        run: |
          git lfs install
          git lfs fetch
          git lfs checkout
      - name: List files in the repository
        run: |
          ls ${{ github.workspace }}
      - name: Docker Image Info
        id: image-info
        run: |
          echo "::set-output name=image_name::$(echo $REPOSITORY | tr '[:upper:]' '[:lower:]')"
          echo "::set-output name=image_tag::${COMMIT_ID:0:10}"
      - name: Login to Docker Hub
        uses: docker/login-action@v2
        with:
          registry: artifacts.iflytek.com
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v2
      - name: Build and push
        run: |
          docker version
          docker buildx build -t artifacts.iflytek.com/docker-private/atp/${{ steps.image-info.outputs.image_name }}:${{ steps.image-info.outputs.image_tag }} . --file ${{ github.workspace }}/Dockerfile --load
          docker push artifacts.iflytek.com/docker-private/atp/${{ steps.image-info.outputs.image_name }}:${{ steps.image-info.outputs.image_tag }}
          docker rmi artifacts.iflytek.com/docker-private/atp/${{ steps.image-info.outputs.image_name }}:${{ steps.image-info.outputs.image_tag }}
      - run: echo "🍏 This job's status is ${{ job.status }}."
@@ -0,0 +1,12 @@
FROM python:3.8.13

WORKDIR /app

COPY . /app

RUN pip config set global.index-url https://pypi.mirrors.ustc.edu.cn/simple
RUN apt-get update && apt-get -y upgrade
RUN apt-get -y install ffmpeg
RUN pip install -r requirements.txt

CMD ["python", "app.py"]
@@ -0,0 +1,304 @@
import argparse
from concurrent.futures import ProcessPoolExecutor
import os
import subprocess as sp
from tempfile import NamedTemporaryFile
import time
import warnings

import torch
import gradio as gr

from audiocraft.data.audio_utils import convert_audio
from audiocraft.data.audio import audio_write
from audiocraft.models import MusicGen
from gradio.themes.utils import sizes


theme = gr.themes.Default(radius_size=sizes.radius_none).set(
    block_label_text_color='#4D63FF',
    block_title_text_color='#4D63FF',
    button_primary_text_color='#4D63FF',
    button_primary_background_fill='#FFFFFF',
    button_primary_border_color='#4D63FF',
    button_primary_background_fill_hover='#EDEFFF',
)

MODEL = None  # Last used model
IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '')
MAX_BATCH_SIZE = 12
BATCHED_DURATION = 15
INTERRUPTING = False
# We have to wrap subprocess calls to clean up the log a bit when using gr.make_waveform.
_old_call = sp.call


def _call_nostderr(*args, **kwargs):
    # Avoid ffmpeg spamming the logs.
    kwargs['stderr'] = sp.DEVNULL
    kwargs['stdout'] = sp.DEVNULL
    return _old_call(*args, **kwargs)


sp.call = _call_nostderr
# Preallocating the pool of processes.
pool = ProcessPoolExecutor(4)
pool.__enter__()


def interrupt():
    global INTERRUPTING
    INTERRUPTING = True


def make_waveform(*args, **kwargs):
    # Further remove some warnings.
    be = time.time()
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        out = gr.make_waveform(*args, **kwargs)
        print("Making a video took", time.time() - be)
        return out


def load_model(version='melody'):
    global MODEL
    print("Loading model", version)
    if MODEL is None or MODEL.name != version:
        MODEL = MusicGen.get_pretrained(version)


def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
    MODEL.set_generation_params(duration=duration, **gen_kwargs)
    print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
    be = time.time()
    processed_melodies = []
    target_sr = 32000
    target_ac = 1
    for melody in melodies:
        if melody is None:
            processed_melodies.append(None)
        else:
            sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
            if melody.dim() == 1:
                melody = melody[None]
            melody = melody[..., :int(sr * duration)]
            melody = convert_audio(melody, sr, target_sr, target_ac)
            processed_melodies.append(melody)

    if any(m is not None for m in processed_melodies):
        outputs = MODEL.generate_with_chroma(
            descriptions=texts,
            melody_wavs=processed_melodies,
            melody_sample_rate=target_sr,
            progress=progress,
        )
    else:
        outputs = MODEL.generate(texts, progress=progress)

    outputs = outputs.detach().cpu().float()
    out_files = []
    for output in outputs:
        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
            audio_write(
                file.name, output, MODEL.sample_rate, strategy="loudness",
                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
            out_files.append(pool.submit(make_waveform, file.name))
    res = [out_file.result() for out_file in out_files]
    print("batch finished", len(texts), time.time() - be)
    return res


def predict_batched(texts, melodies):
    max_text_length = 512
    texts = [text[:max_text_length] for text in texts]
    load_model('melody')
    res = _do_predictions(texts, melodies, BATCHED_DURATION)
    return [res]


def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()):
    global INTERRUPTING
    INTERRUPTING = False
    if temperature < 0:
        raise gr.Error("Temperature must be >= 0.")
    if topk < 0:
        raise gr.Error("Topk must be non-negative.")
    if topp < 0:
        raise gr.Error("Topp must be non-negative.")

    topk = int(topk)
    load_model(model)

    def _progress(generated, to_generate):
        progress((generated, to_generate))
        if INTERRUPTING:
            raise gr.Error("Interrupted.")
    MODEL.set_custom_progress_callback(_progress)

    outs = _do_predictions(
        [text], [melody], duration, progress=True,
        top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
    return outs[0]
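
A minimal, hypothetical driver for the pipeline above, bypassing the Gradio UI (model weights download on first use; the prompt text is illustrative):

# sketch: generate one 8-second clip from a text prompt only (no melody conditioning)
load_model('small')
video_paths = _do_predictions(['lofi chill beat with soft piano'], [None], duration=8)
print(video_paths[0])  # path of the rendered waveform video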


def ui_full(launch_kwargs):
    with gr.Blocks(theme=theme, css="footer {visibility: hidden}") as interface:
        gr.Markdown(
            """
            <div align='center' ><font size='60'>音乐生成</font></div>
            """
        )
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    text = gr.Text(label="输入文本", interactive=True)
                    melody = gr.Audio(source="upload", type="numpy", label="旋律(可选)", interactive=True)
                with gr.Row():
                    submit = gr.Button("Submit")
                    # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
                    _ = gr.Button("中断").click(fn=interrupt, queue=False)
                with gr.Row():
                    model = gr.Radio(["melody", "medium", "small", "large"], label="模型", value="melody", interactive=True)
                with gr.Row():
                    duration = gr.Slider(minimum=1, maximum=120, value=10, label="Duration", interactive=True)
                with gr.Row():
                    topk = gr.Number(label="Top-k", value=250, interactive=True)
                    topp = gr.Number(label="Top-p", value=0, interactive=True)
                    temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
                    cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
            with gr.Column():
                output = gr.Video(label="生成的音乐")
        submit.click(predict_full, inputs=[model, text, melody, duration, topk, topp, temperature, cfg_coef], outputs=[output])
        gr.Examples(
            fn=predict_full,
            examples=[
                [
                    "An 80s driving pop song with heavy drums and synth pads in the background",
                    "./assets/bach.mp3",
                    "melody"
                ],
                [
                    "A cheerful country song with acoustic guitars",
                    "./assets/bolero_ravel.mp3",
                    "melody"
                ],
                [
                    "90s rock song with electric guitar and heavy drums",
                    None,
                    "medium"
                ],
                [
                    "a light and cheerful EDM track, with syncopated drums, airy pads, and strong emotions",
                    "./assets/bach.mp3",
                    "melody"
                ],
                [
                    "lofi slow bpm electro chill with organic samples",
                    None,
                    "medium",
                ],
            ],
            inputs=[text, melody, model],
            outputs=[output],
            label="例子"
        )

        interface.queue().launch(**launch_kwargs)


def ui_batched(launch_kwargs):
    with gr.Blocks(theme=theme, css="footer {visibility: hidden}") as demo:
        gr.Markdown(
            """
            <div align='center' ><font size='60'>音乐生成</font></div>
            """
        )
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    text = gr.Text(label="Describe your music", lines=2, interactive=True)
                    melody = gr.Audio(source="upload", type="numpy", label="Condition on a melody (optional)", interactive=True)
                with gr.Row():
                    submit = gr.Button("Generate")
            with gr.Column():
                output = gr.Video(label="Generated Music")
        submit.click(predict_batched, inputs=[text, melody], outputs=[output], batch=True, max_batch_size=MAX_BATCH_SIZE)
        gr.Examples(
            fn=predict_batched,
            examples=[
                [
                    "An 80s driving pop song with heavy drums and synth pads in the background",
                    "./assets/bach.mp3",
                ],
                [
                    "A cheerful country song with acoustic guitars",
                    "./assets/bolero_ravel.mp3",
                ],
                [
                    "90s rock song with electric guitar and heavy drums",
                    None,
                ],
                [
                    "a light and cheerful EDM track, with syncopated drums, airy pads, and strong emotions bpm: 130",
                    "./assets/bach.mp3",
                ],
                [
                    "lofi slow bpm electro chill with organic samples",
                    None,
                ],
            ],
            inputs=[text, melody],
            outputs=[output],
            label="例子"
        )

        demo.queue(max_size=8 * 4).launch(**launch_kwargs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--listen',
        type=str,
        default='0.0.0.0',
        help='IP to listen on for connections to Gradio',
    )
    parser.add_argument(
        '--username', type=str, default='', help='Username for authentication'
    )
    parser.add_argument(
        '--password', type=str, default='', help='Password for authentication'
    )
    parser.add_argument(
        '--server_port',
        type=int,
        default=0,
        help='Port to run the server listener on',
    )
    parser.add_argument(
        '--inbrowser', action='store_true', help='Open in browser'
    )
    parser.add_argument(
        '--share', action='store_true', help='Share the gradio UI'
    )

    args = parser.parse_args()

    launch_kwargs = {}
    launch_kwargs['server_name'] = args.listen

    if args.username and args.password:
        launch_kwargs['auth'] = (args.username, args.password)
    if args.server_port:
        launch_kwargs['server_port'] = args.server_port
    if args.inbrowser:
        launch_kwargs['inbrowser'] = args.inbrowser
    if args.share:
        launch_kwargs['share'] = args.share

    # Show the interface.
    if IS_BATCHED:
        ui_batched(launch_kwargs)
    else:
        ui_full(launch_kwargs)
Binary file not shown.
Binary file not shown.

@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# flake8: noqa
from . import data, modules, models

__version__ = '0.0.2a2'
@@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# flake8: noqa
from . import audio, audio_dataset
@@ -0,0 +1,215 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Audio IO methods are defined in this module (info, read, write).
We rely on the av library for faster reads when possible, and fall back to torchaudio otherwise.
"""

from dataclasses import dataclass
from pathlib import Path
import logging
import typing as tp

import numpy as np
import soundfile
import torch
from torch.nn import functional as F
import torchaudio as ta

import av

from .audio_utils import f32_pcm, i16_pcm, normalize_audio


_av_initialized = False


def _init_av():
    global _av_initialized
    if _av_initialized:
        return
    logger = logging.getLogger('libav.mp3')
    logger.setLevel(logging.ERROR)
    _av_initialized = True


@dataclass(frozen=True)
class AudioFileInfo:
    sample_rate: int
    duration: float
    channels: int


def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
    _init_av()
    with av.open(str(filepath)) as af:
        stream = af.streams.audio[0]
        sample_rate = stream.codec_context.sample_rate
        duration = float(stream.duration * stream.time_base)
        channels = stream.channels
        return AudioFileInfo(sample_rate, duration, channels)


def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
    info = soundfile.info(filepath)
    return AudioFileInfo(info.samplerate, info.duration, info.channels)


def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
    # torchaudio no longer returns useful duration information for some formats like mp3s.
    filepath = Path(filepath)
    if filepath.suffix in ['.flac', '.ogg']:  # TODO: Validate .ogg can be safely read with _av_info
        # ffmpeg has some weird issue with flac.
        return _soundfile_info(filepath)
    else:
        return _av_info(filepath)


def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: float = -1.) -> tp.Tuple[torch.Tensor, int]:
    """FFMPEG-based audio file reading using PyAV bindings.
    Soundfile cannot read mp3, and _av_read is more efficient than torchaudio.

    Args:
        filepath (str or Path): Path to audio file to read.
        seek_time (float): Time at which to start reading in the file.
        duration (float): Duration to read from the file. If set to -1, the whole file is read.
    Returns:
        Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate.
    """
    _init_av()
    with av.open(str(filepath)) as af:
        stream = af.streams.audio[0]
        sr = stream.codec_context.sample_rate
        num_frames = int(sr * duration) if duration >= 0 else -1
        frame_offset = int(sr * seek_time)
        # We need a small negative offset, otherwise we get some edge artifact
        # from the mp3 decoder.
        af.seek(int(max(0, (seek_time - 0.1)) / stream.time_base), stream=stream)
        frames = []
        length = 0
        for frame in af.decode(streams=stream.index):
            current_offset = int(frame.rate * frame.pts * frame.time_base)
            strip = max(0, frame_offset - current_offset)
            buf = torch.from_numpy(frame.to_ndarray())
            if buf.shape[0] != stream.channels:
                buf = buf.view(-1, stream.channels).t()
            buf = buf[:, strip:]
            frames.append(buf)
            length += buf.shape[1]
            if num_frames > 0 and length >= num_frames:
                break
        assert frames
        # If the above assert fails, it is likely because we seeked past the end-of-file point,
        # in which case ffmpeg returns a single frame with only zeros, and a weird timestamp.
        # This will need proper debugging, in due time.
        wav = torch.cat(frames, dim=1)
        assert wav.shape[0] == stream.channels
        if num_frames > 0:
            wav = wav[:, :num_frames]
        return f32_pcm(wav), sr


def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
               duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
    """Read audio by picking the most appropriate backend tool based on the audio format.

    Args:
        filepath (str or Path): Path to audio file to read.
        seek_time (float): Time at which to start reading in the file.
        duration (float): Duration to read from the file. If set to -1, the whole file is read.
        pad (bool): Pad output audio if not reaching expected duration.
    Returns:
        Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate.
    """
    fp = Path(filepath)
    if fp.suffix in ['.flac', '.ogg']:  # TODO: check if we can safely use _av_read for .ogg
        # There is some bug with ffmpeg and reading flac.
        info = _soundfile_info(filepath)
        frames = -1 if duration <= 0 else int(duration * info.sample_rate)
        frame_offset = int(seek_time * info.sample_rate)
        wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32)
        assert info.sample_rate == sr, f"Mismatch of sample rates {info.sample_rate} {sr}"
        wav = torch.from_numpy(wav).t().contiguous()
        if len(wav.shape) == 1:
            wav = torch.unsqueeze(wav, 0)
    elif (
        fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats()
        and duration <= 0 and seek_time == 0
    ):
        # Torchaudio is faster if we load an entire file at once.
        wav, sr = ta.load(fp)
    else:
        wav, sr = _av_read(filepath, seek_time, duration)
    if pad and duration > 0:
        expected_frames = int(duration * sr)
        wav = F.pad(wav, (0, expected_frames - wav.shape[-1]))
    return wav, sr
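
A minimal usage sketch for audio_read (the file path is hypothetical): read two seconds starting half a second in, zero-padding if the file is shorter than requested.

# wav has shape [channels, frames]; pad=True guarantees frames == int(2.0 * sr)
wav, sr = audio_read('sample.wav', seek_time=0.5, duration=2.0, pad=True)
assert wav.shape[-1] == int(2.0 * sr)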


def audio_write(stem_name: tp.Union[str, Path],
                wav: torch.Tensor, sample_rate: int,
                format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
                strategy: str = 'peak', peak_clip_headroom_db: float = 1,
                rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
                loudness_compressor: bool = False,
                log_clipping: bool = True, make_parent_dir: bool = True,
                add_suffix: bool = True) -> Path:
    """Convenience function for saving audio to disk. Returns the filename the audio was written to.

    Args:
        stem_name (str or Path): Filename without extension, which will be added automatically.
        format (str): Either "wav" or "mp3".
        mp3_rate (int): kbps when using mp3s.
        normalize (bool): if `True` (default), normalizes according to the prescribed
            strategy (see after). If `False`, the strategy is only used in case clipping
            would happen.
        strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
            i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
            with extra headroom to avoid clipping. 'clip' just clips.
        peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
        rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
            than the `peak_clip` one to avoid further clipping.
        loudness_headroom_db (float): Target loudness for loudness normalization.
        loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
        log_clipping (bool): If True, basic logging on stderr when clipping still
            occurs despite strategy (only for 'rms').
        make_parent_dir (bool): Make parent directory if it doesn't exist.
    Returns:
        Path: Path of the saved audio.
    """
    assert wav.dtype.is_floating_point, "wav is not floating point"
    if wav.dim() == 1:
        wav = wav[None]
    elif wav.dim() > 2:
        raise ValueError("Input wav should be at most 2 dimensional.")
    assert wav.isfinite().all()
    wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
                          rms_headroom_db, loudness_headroom_db, log_clipping=log_clipping,
                          sample_rate=sample_rate, stem_name=str(stem_name))
    kwargs: dict = {}
    if format == 'mp3':
        suffix = '.mp3'
        kwargs.update({"compression": mp3_rate})
    elif format == 'wav':
        wav = i16_pcm(wav)
        suffix = '.wav'
        kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16})
    else:
        raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
    if not add_suffix:
        suffix = ''
    path = Path(str(stem_name) + suffix)
    if make_parent_dir:
        path.parent.mkdir(exist_ok=True, parents=True)
    try:
        ta.save(path, wav, sample_rate, **kwargs)
    except Exception:
        if path.exists():
            # We do not want to leave half-written files around.
            path.unlink()
        raise
    return path
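
A short sketch of audio_write: save a synthetic one-second 440 Hz tone ('tone' is a hypothetical stem name; the suffix is appended automatically).

sr = 32000
t = torch.arange(sr) / sr
wav = torch.sin(2 * torch.pi * 440 * t)[None]  # [1, T], float
out_path = audio_write('tone', wav, sr, strategy='peak')  # writes tone.wav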

@@ -0,0 +1,525 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import copy
from concurrent.futures import ThreadPoolExecutor, Future
from dataclasses import dataclass, fields
from contextlib import ExitStack
import gzip
import json
import logging
import os
from pathlib import Path
import random
import sys
import typing as tp

import torch
import torch.nn.functional as F

from .audio import audio_read, audio_info
from .audio_utils import convert_audio
from .zip import PathInZip

try:
    import dora
except ImportError:
    dora = None  # type: ignore


@dataclass(order=True)
class BaseInfo:

    @classmethod
    def _dict2fields(cls, dictionary: dict):
        return {
            field.name: dictionary[field.name]
            for field in fields(cls) if field.name in dictionary
        }

    @classmethod
    def from_dict(cls, dictionary: dict):
        _dictionary = cls._dict2fields(dictionary)
        return cls(**_dictionary)

    def to_dict(self):
        return {
            field.name: self.__getattribute__(field.name)
            for field in fields(self)
        }


@dataclass(order=True)
class AudioMeta(BaseInfo):
    path: str
    duration: float
    sample_rate: int
    amplitude: tp.Optional[float] = None
    weight: tp.Optional[float] = None
    # info_path is used to load additional information about the audio file that is stored in zip files.
    info_path: tp.Optional[PathInZip] = None

    @classmethod
    def from_dict(cls, dictionary: dict):
        base = cls._dict2fields(dictionary)
        if 'info_path' in base and base['info_path'] is not None:
            base['info_path'] = PathInZip(base['info_path'])
        return cls(**base)

    def to_dict(self):
        d = super().to_dict()
        if d['info_path'] is not None:
            d['info_path'] = str(d['info_path'])
        return d


@dataclass(order=True)
class SegmentInfo(BaseInfo):
    meta: AudioMeta
    seek_time: float
    n_frames: int  # actual number of frames without padding
    total_frames: int  # total number of frames, padding included
    sample_rate: int  # actual sample rate


DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']

logger = logging.getLogger(__name__)


def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta:
    """AudioMeta from a path to an audio file.

    Args:
        file_path (str): Resolved path of valid audio file.
        minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
    Returns:
        AudioMeta: Audio file path and its metadata.
    """
    info = audio_info(file_path)
    amplitude: tp.Optional[float] = None
    if not minimal:
        wav, sr = audio_read(file_path)
        amplitude = wav.abs().max().item()
    return AudioMeta(file_path, info.duration, info.sample_rate, amplitude)


def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
    """If Dora is available as a dependency, try to resolve potential relative paths
    in a list of AudioMeta. This method is expected to be used when loading meta from file.

    Args:
        m (AudioMeta): Audio meta to resolve.
        fast (bool): If True, uses a really fast check for determining whether a path
            is already absolute. Only valid on Linux/Mac.
    Returns:
        AudioMeta: Audio meta with resolved path.
    """
    def is_abs(m):
        if fast:
            return str(m)[0] == '/'
        else:
            return os.path.isabs(str(m))

    if not dora:
        return m

    if not is_abs(m.path):
        m.path = dora.git_save.to_absolute_path(m.path)
    if m.info_path is not None and not is_abs(m.info_path.zip_path):
        m.info_path.zip_path = dora.git_save.to_absolute_path(m.info_path.zip_path)
    return m


def find_audio_files(path: tp.Union[Path, str],
                     exts: tp.List[str] = DEFAULT_EXTS,
                     resolve: bool = True,
                     minimal: bool = True,
                     progress: bool = False,
                     workers: int = 0) -> tp.List[AudioMeta]:
    """Build a list of AudioMeta from a given path,
    collecting relevant audio files and fetching meta info.

    Args:
        path (str or Path): Path to folder containing audio files.
        exts (list of str): List of file extensions to consider for audio files.
        minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
        progress (bool): Whether to log progress on audio files collection.
        workers (int): Number of parallel workers; if 0, use only the current thread.
    Returns:
        List[AudioMeta]: List of audio file paths and their metadata.
    """
    audio_files = []
    futures: tp.List[Future] = []
    pool: tp.Optional[ThreadPoolExecutor] = None
    with ExitStack() as stack:
        if workers > 0:
            pool = ThreadPoolExecutor(workers)
            stack.enter_context(pool)

        if progress:
            print("Finding audio files...")
        for root, folders, files in os.walk(path, followlinks=True):
            for file in files:
                full_path = Path(root) / file
                if full_path.suffix.lower() in exts:
                    audio_files.append(full_path)
                    if pool is not None:
                        futures.append(pool.submit(_get_audio_meta, str(audio_files[-1]), minimal))
                    if progress:
                        print(format(len(audio_files), " 8d"), end='\r', file=sys.stderr)

        if progress:
            print("Getting audio metadata...")
        meta: tp.List[AudioMeta] = []
        for idx, file_path in enumerate(audio_files):
            try:
                if pool is None:
                    m = _get_audio_meta(str(file_path), minimal)
                else:
                    m = futures[idx].result()
                if resolve:
                    m = _resolve_audio_meta(m)
            except Exception as err:
                print("Error with", str(file_path), err, file=sys.stderr)
                continue
            meta.append(m)
            if progress:
                print(format((1 + idx) / len(audio_files), " 3.1%"), end='\r', file=sys.stderr)
    meta.sort()
    return meta


def load_audio_meta(path: tp.Union[str, Path],
                    resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]:
    """Load list of AudioMeta from an optionally compressed json file.

    Args:
        path (str or Path): Path to JSON file.
        resolve (bool): Whether to resolve the path from AudioMeta (default=True).
        fast (bool): Activates some tricks to make things faster.
    Returns:
        List[AudioMeta]: List of audio file paths and their metadata.
    """
    open_fn = gzip.open if str(path).lower().endswith('.gz') else open
    with open_fn(path, 'rb') as fp:  # type: ignore
        lines = fp.readlines()
    meta = []
    for line in lines:
        d = json.loads(line)
        m = AudioMeta.from_dict(d)
        if resolve:
            m = _resolve_audio_meta(m, fast=fast)
        meta.append(m)
    return meta


def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]):
    """Save the audio metadata to the file pointer as json.

    Args:
        path (str or Path): Path to JSON file.
        meta (list of AudioMeta): List of audio meta to save.
    """
    Path(path).parent.mkdir(exist_ok=True, parents=True)
    open_fn = gzip.open if str(path).lower().endswith('.gz') else open
    with open_fn(path, 'wb') as fp:  # type: ignore
        for m in meta:
            json_str = json.dumps(m.to_dict()) + '\n'
            json_bytes = json_str.encode('utf-8')
            fp.write(json_bytes)
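
A sketch of the manifest round trip (the 'dataset/' folder is hypothetical): scan a directory tree, persist one JSON line per file, and reload it later.

meta = find_audio_files('dataset/', progress=True, workers=4)
save_audio_meta('dataset/data.jsonl.gz', meta)   # gzip compression inferred from the .gz suffix
meta = load_audio_meta('dataset/data.jsonl.gz')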


class AudioDataset:
    """Base audio dataset.

    The dataset takes a list of AudioMeta and creates a dataset composed of segments of audio
    and potentially additional information, by creating random segments from the list of audio
    files referenced in the metadata and applying minimal data pre-processing such as resampling,
    mixing of channels, padding, etc.

    If no segment_duration value is provided, the AudioDataset will return the full wav for each
    audio file. Otherwise, it will randomly sample audio files and create a segment of the specified
    duration, applying padding if required.

    By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True
    allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
    original audio meta. See the usage sketch after the class for a typical DataLoader setup.

    Args:
        meta (tp.List[AudioMeta]): List of audio files metadata.
        segment_duration (float): Optional segment duration of audio to load.
            If not specified, the dataset will load the full audio segment from the file.
        shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
        sample_rate (int): Target sample rate of the loaded audio samples.
        channels (int): Target number of channels of the loaded audio samples.
        sample_on_duration (bool): Set to `True` to sample segments with probability
            dependent on audio file duration. This is only used if `segment_duration` is provided.
        sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of
            `AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product
            of the file duration and file weight. This is only used if `segment_duration` is provided.
        min_segment_ratio (float): Minimum segment ratio to use when the audio file
            is shorter than the desired segment.
        max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
        return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
        min_audio_duration (tp.Optional[float], optional): Minimum audio file duration, in seconds; if provided,
            audio shorter than this will be filtered out.
        max_audio_duration (tp.Optional[float], optional): Maximal audio file duration, in seconds; if provided,
            audio longer than this will be filtered out.
    """
    def __init__(self,
                 meta: tp.List[AudioMeta],
                 segment_duration: tp.Optional[float] = None,
                 shuffle: bool = True,
                 num_samples: int = 10_000,
                 sample_rate: int = 48_000,
                 channels: int = 2,
                 pad: bool = True,
                 sample_on_duration: bool = True,
                 sample_on_weight: bool = True,
                 min_segment_ratio: float = 0.5,
                 max_read_retry: int = 10,
                 return_info: bool = False,
                 min_audio_duration: tp.Optional[float] = None,
                 max_audio_duration: tp.Optional[float] = None
                 ):
        assert len(meta) > 0, 'No audio meta provided to AudioDataset. Please check loading of audio meta.'
        assert segment_duration is None or segment_duration > 0
        assert segment_duration is None or min_segment_ratio >= 0
        logging.debug(f'sample_on_duration: {sample_on_duration}')
        logging.debug(f'sample_on_weight: {sample_on_weight}')
        logging.debug(f'pad: {pad}')
        logging.debug(f'min_segment_ratio: {min_segment_ratio}')

        self.segment_duration = segment_duration
        self.min_segment_ratio = min_segment_ratio
        self.max_audio_duration = max_audio_duration
        self.min_audio_duration = min_audio_duration
        if self.min_audio_duration is not None and self.max_audio_duration is not None:
            assert self.min_audio_duration <= self.max_audio_duration
        self.meta: tp.List[AudioMeta] = self._filter_duration(meta)
        assert len(self.meta)  # Fail fast if all data has been filtered.
        self.total_duration = sum(d.duration for d in self.meta)

        if segment_duration is None:
            num_samples = len(self.meta)
        self.num_samples = num_samples
        self.shuffle = shuffle
        self.sample_rate = sample_rate
        self.channels = channels
        self.pad = pad
        self.sample_on_weight = sample_on_weight
        self.sample_on_duration = sample_on_duration
        self.sampling_probabilities = self._get_sampling_probabilities()
        self.max_read_retry = max_read_retry
        self.return_info = return_info

    def __len__(self):
        return self.num_samples

    def _get_sampling_probabilities(self, normalized: bool = True):
        """Return the sampling probabilities for each file inside `self.meta`.
        """
        scores: tp.List[float] = []
        for file_meta in self.meta:
            score = 1.
            if self.sample_on_weight and file_meta.weight is not None:
                score *= file_meta.weight
            if self.sample_on_duration:
                score *= file_meta.duration
            scores.append(score)
        probabilities = torch.tensor(scores)
        if normalized:
            probabilities /= probabilities.sum()
        return probabilities

    def sample_file(self, rng: torch.Generator) -> AudioMeta:
        """Sample a given file from `self.meta`. Can be overridden in subclasses.
        This is only called if `segment_duration` is not None.

        You must use the provided random number generator `rng` for reproducibility.
        """
        if not self.sample_on_weight and not self.sample_on_duration:
            file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
        else:
            file_index = int(torch.multinomial(self.sampling_probabilities, 1, generator=rng).item())

        return self.meta[file_index]

    def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
        if self.segment_duration is None:
            file_meta = self.meta[index]
            out, sr = audio_read(file_meta.path)
            out = convert_audio(out, sr, self.sample_rate, self.channels)
            n_frames = out.shape[-1]
            segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
                                       sample_rate=self.sample_rate)
        else:
            rng = torch.Generator()
            if self.shuffle:
                # We use index, plus extra randomness.
                rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
            else:
                # We only use index.
                rng.manual_seed(index)

            for retry in range(self.max_read_retry):
                file_meta = self.sample_file(rng)
                # We add some variance in the file position, even if the audio file is smaller
                # than the segment, without ending up with empty segments.
                max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
                seek_time = torch.rand(1, generator=rng).item() * max_seek
                try:
                    out, sr = audio_read(file_meta.path, seek_time, self.segment_duration, pad=False)
                    out = convert_audio(out, sr, self.sample_rate, self.channels)
                    n_frames = out.shape[-1]
                    target_frames = int(self.segment_duration * self.sample_rate)
                    if self.pad:
                        out = F.pad(out, (0, target_frames - n_frames))
                    segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
                                               sample_rate=self.sample_rate)
                except Exception as exc:
                    logger.warning("Error opening file %s: %r", file_meta.path, exc)
                    if retry == self.max_read_retry - 1:
                        raise
                else:
                    break

        if self.return_info:
            # Returns the wav and additional information on the wave segment.
            return out, segment_info
        else:
            return out

    def collater(self, samples):
        """The collater function has to be provided to the dataloader
        if AudioDataset has return_info=True in order to properly collate
        the samples of a batch.
        """
        if self.segment_duration is None and len(samples) > 1:
            assert self.pad, "Must allow padding when batching examples of different durations."

        # In this case the audio reaching the collater is of variable length as segment_duration=None.
        to_pad = self.segment_duration is None and self.pad
        if to_pad:
            max_len = max([wav.shape[-1] for wav, _ in samples])

            def _pad_wav(wav):
                return F.pad(wav, (0, max_len - wav.shape[-1]))

        if self.return_info:
            if len(samples) > 0:
                assert len(samples[0]) == 2
                assert isinstance(samples[0][0], torch.Tensor)
                assert isinstance(samples[0][1], SegmentInfo)

            wavs = [wav for wav, _ in samples]
            segment_infos = [copy.deepcopy(info) for _, info in samples]

            if to_pad:
                # Each wav could be of a different duration as they are not segmented.
                for i in range(len(samples)):
                    # Determines the total length of the signal with padding, so we update here as we pad.
                    segment_infos[i].total_frames = max_len
                    wavs[i] = _pad_wav(wavs[i])

            wav = torch.stack(wavs)
            return wav, segment_infos
        else:
            assert isinstance(samples[0], torch.Tensor)
            if to_pad:
                samples = [_pad_wav(s) for s in samples]
            return torch.stack(samples)

    def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
        """Filters out audio files with too short or too long durations.
        Removes from meta the files whose duration would not allow sampling examples from them.
        """
        orig_len = len(meta)

        # Filter data that is too short.
        if self.min_audio_duration is not None:
            meta = [m for m in meta if m.duration >= self.min_audio_duration]

        # Filter data that is too long.
        if self.max_audio_duration is not None:
            meta = [m for m in meta if m.duration <= self.max_audio_duration]

        filtered_len = len(meta)
        removed_percentage = 100 * (1 - float(filtered_len) / orig_len)
        msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage
        if removed_percentage < 10:
            logging.debug(msg)
        else:
            logging.warning(msg)
        return meta

    @classmethod
    def from_meta(cls, root: tp.Union[str, Path], **kwargs):
        """Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file.

        Args:
            root (str or Path): Path to root folder containing audio files.
            kwargs: Additional keyword arguments for the AudioDataset.
        """
        root = Path(root)
        if root.is_dir():
            if (root / 'data.jsonl').exists():
                root = root / 'data.jsonl'
            elif (root / 'data.jsonl.gz').exists():
                root = root / 'data.jsonl.gz'
            else:
                raise ValueError("Don't know where to read metadata from in the dir. "
                                 "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
        meta = load_audio_meta(root)
        return cls(meta, **kwargs)

    @classmethod
    def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
                  exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
        """Instantiate AudioDataset from a path containing (possibly nested) audio files.

        Args:
            root (str or Path): Path to root folder containing audio files.
            minimal_meta (bool): Whether to only load minimal metadata or not.
            exts (list of str): Extensions for audio files.
            kwargs: Additional keyword arguments for the AudioDataset.
        """
        root = Path(root)
        if root.is_file():
            meta = load_audio_meta(root, resolve=True)
        else:
            meta = find_audio_files(root, exts, minimal=minimal_meta, resolve=True)
        return cls(meta, **kwargs)
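
The usage sketch referenced above, assuming a hypothetical 'dataset/' folder of audio files: fixed 5-second mono segments at 32 kHz, batched with the custom collater so segment metadata survives batching.

from torch.utils.data import DataLoader

dataset = AudioDataset.from_path('dataset/', segment_duration=5.0,
                                 sample_rate=32000, channels=1, return_info=True)
loader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collater)
wav, infos = next(iter(loader))  # wav: [4, 1, 160000], infos: list of SegmentInfo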


def main():
    logging.basicConfig(stream=sys.stderr, level=logging.INFO)
    parser = argparse.ArgumentParser(
        prog='audio_dataset',
        description='Generate .jsonl files by scanning a folder.')
    parser.add_argument('root', help='Root folder with all the audio files')
    parser.add_argument('output_meta_file',
                        help='Output file to store the metadata.')
    parser.add_argument('--complete',
                        action='store_false', dest='minimal', default=True,
                        help='Retrieve all metadata, even the ones that are expensive '
                             'to compute (e.g. normalization).')
    parser.add_argument('--resolve',
                        action='store_true', default=False,
                        help='Resolve the paths to be absolute and with no symlinks.')
    parser.add_argument('--workers',
                        default=10, type=int,
                        help='Number of workers.')
    args = parser.parse_args()
    meta = find_audio_files(args.root, DEFAULT_EXTS, progress=True,
                            resolve=args.resolve, minimal=args.minimal, workers=args.workers)
    save_audio_meta(args.output_meta_file, meta)


if __name__ == '__main__':
    main()
@@ -0,0 +1,174 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import sys
import typing as tp

import julius
import torch
import torchaudio


def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor:
    """Convert audio to the given number of channels.

    Args:
        wav (torch.Tensor): Audio wave of shape [B, C, T].
        channels (int): Expected number of channels as output.
    Returns:
        torch.Tensor: Downmixed or unchanged audio wave [B, C, T].
    """
    *shape, src_channels, length = wav.shape
    if src_channels == channels:
        pass
    elif channels == 1:
        # Case 1:
        # The caller asked for 1-channel audio, and the stream has multiple
        # channels: downmix all channels.
        wav = wav.mean(dim=-2, keepdim=True)
    elif src_channels == 1:
        # Case 2:
        # The caller asked for multiple channels, but the input file has
        # a single channel: replicate the audio over all channels.
        wav = wav.expand(*shape, channels, length)
    elif src_channels >= channels:
        # Case 3:
        # The caller asked for multiple channels, and the input file has
        # more channels than requested. In that case, return the first channels.
        wav = wav[..., :channels, :]
    else:
        # Case 4: What is a reasonable choice here?
        raise ValueError('The audio file has fewer channels than requested but is not mono.')
    return wav


def convert_audio(wav: torch.Tensor, from_rate: float,
                  to_rate: float, to_channels: int) -> torch.Tensor:
    """Convert audio to a new sample rate and number of audio channels.
    """
    wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
    wav = convert_audio_channels(wav, to_channels)
    return wav
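
A quick sketch of convert_audio on random data: resample 44.1 kHz stereo down to 32 kHz mono.

wav = torch.randn(2, 44100)  # [C=2, T] at 44.1 kHz
wav = convert_audio(wav, from_rate=44100, to_rate=32000, to_channels=1)
print(wav.shape)  # torch.Size([1, 32000])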


def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
                       loudness_compressor: bool = False, energy_floor: float = 2e-3):
    """Normalize an input signal to a user loudness in dB LKFS.
    Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.

    Args:
        wav (torch.Tensor): Input multichannel audio data.
        sample_rate (int): Sample rate.
        loudness_headroom_db (float): Target loudness of the output in dB LUFS.
        loudness_compressor (bool): Uses tanh for soft clipping.
        energy_floor (float): Anything below that RMS level will not be rescaled.
    Returns:
        output (torch.Tensor): Loudness normalized output data.
    """
    energy = wav.pow(2).mean().sqrt().item()
    if energy < energy_floor:
        return wav
    transform = torchaudio.transforms.Loudness(sample_rate)
    input_loudness_db = transform(wav).item()
    # Calculate the gain needed to scale to the desired loudness level.
    delta_loudness = -loudness_headroom_db - input_loudness_db
    gain = 10.0 ** (delta_loudness / 20.0)
    output = gain * wav
    if loudness_compressor:
        output = torch.tanh(output)
    assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
    return output


def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None) -> None:
    """Utility function to clip the audio with logging if specified."""
    max_scale = wav.abs().max()
    if log_clipping and max_scale > 1:
        clamp_prob = (wav.abs() > 1).float().mean().item()
        print(f"CLIPPING {stem_name or ''} happening with probability (a bit of clipping is okay):",
              clamp_prob, "maximum scale: ", max_scale.item(), file=sys.stderr)
    wav.clamp_(-1, 1)


def normalize_audio(wav: torch.Tensor, normalize: bool = True,
                    strategy: str = 'peak', peak_clip_headroom_db: float = 1,
                    rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
                    loudness_compressor: bool = False, log_clipping: bool = False,
                    sample_rate: tp.Optional[int] = None,
                    stem_name: tp.Optional[str] = None) -> torch.Tensor:
    """Normalize the audio according to the prescribed strategy (see after).

    Args:
        wav (torch.Tensor): Audio data.
        normalize (bool): if `True` (default), normalizes according to the prescribed
            strategy (see after). If `False`, the strategy is only used in case clipping
            would happen.
        strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
            i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
            with extra headroom to avoid clipping. 'clip' just clips.
        peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
        rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
            than the `peak_clip` one to avoid further clipping.
        loudness_headroom_db (float): Target loudness for loudness normalization.
        loudness_compressor (bool): If True, uses tanh based soft clipping.
        log_clipping (bool): If True, basic logging on stderr when clipping still
            occurs despite strategy (only for 'rms').
        sample_rate (int): Sample rate for the audio data (required for loudness).
        stem_name (Optional[str]): Stem name for clipping logging.
    Returns:
        torch.Tensor: Normalized audio.
    """
    scale_peak = 10 ** (-peak_clip_headroom_db / 20)
    scale_rms = 10 ** (-rms_headroom_db / 20)
    if strategy == 'peak':
        rescaling = (scale_peak / wav.abs().max())
        if normalize or rescaling < 1:
            wav = wav * rescaling
    elif strategy == 'clip':
        wav = wav.clamp(-scale_peak, scale_peak)
    elif strategy == 'rms':
        mono = wav.mean(dim=0)
        rescaling = scale_rms / mono.pow(2).mean().sqrt()
        if normalize or rescaling < 1:
            wav = wav * rescaling
        _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
    elif strategy == 'loudness':
        assert sample_rate is not None, "Loudness normalization requires sample rate."
        wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor)
        _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
    else:
        assert wav.abs().max() < 1
        assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'"
    return wav
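
A small sketch contrasting two strategies on the same signal (values are illustrative):

wav = 0.5 * torch.randn(1, 32000)
peak_norm = normalize_audio(wav, strategy='peak')  # scaled so peak sits 1 dB below full scale
loud_norm = normalize_audio(wav, strategy='loudness', sample_rate=32000)  # -14 LUFS target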


def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
    """Convert audio to float 32 bits PCM format.
    """
    if wav.dtype.is_floating_point:
        return wav
    else:
        assert wav.dtype == torch.int16
        return wav.float() / 2**15


def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
    """Convert audio to int 16 bits PCM format.

    ..Warning:: There exist many formulas for doing this conversion. None are perfect
    due to the asymmetry of the int16 range. One either has possible clipping, DC offset,
    or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom,
    it is possible that `i16_pcm(f32_pcm(wav)) != wav`.
    """
    if wav.dtype.is_floating_point:
        assert wav.abs().max() <= 1
        candidate = (wav * 2 ** 15).round()
        if candidate.max() >= 2 ** 15:  # clipping would occur
            candidate = (wav * (2 ** 15 - 1)).round()
        return candidate.short()
    else:
        assert wav.dtype == torch.int16
        return wav
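
A round-trip sketch illustrating the warning above (pure PyTorch, no files involved):

wav = torch.rand(1, 1000) * 1.8 - 0.9  # float in [-0.9, 0.9], enough headroom
restored = f32_pcm(i16_pcm(wav))       # close to wav, but not bit-exact
print((restored - wav).abs().max())    # tiny quantization error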

@@ -0,0 +1,74 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing
import zipfile

from dataclasses import dataclass
from functools import lru_cache
from typing_extensions import Literal


DEFAULT_SIZE = 32
MODE = Literal['r', 'w', 'x', 'a']


@dataclass(order=True)
class PathInZip:
    """Class for holding a path to a file inside a zip file.

    Args:
        path: The convention is <path_to_zip>:<relative_path_inside_zip>.
            Let's assume there is a zip file /some/location/foo.zip
            and inside of it is a json file located at /data/file1.json,
            then we expect path = "/some/location/foo.zip:/data/file1.json".
    """

    INFO_PATH_SEP = ':'
    zip_path: str
    file_path: str

    def __init__(self, path: str) -> None:
        split_path = path.split(self.INFO_PATH_SEP)
        assert len(split_path) == 2
        self.zip_path, self.file_path = split_path

    @classmethod
    def from_paths(cls, zip_path: str, file_path: str):
        return cls(zip_path + cls.INFO_PATH_SEP + file_path)

    def __str__(self) -> str:
        return self.zip_path + self.INFO_PATH_SEP + self.file_path


def _open_zip(path: str, mode: MODE = 'r'):
    return zipfile.ZipFile(path, mode)


_cached_open_zip = lru_cache(DEFAULT_SIZE)(_open_zip)


def set_zip_cache_size(max_size: int):
    """Sets the maximal LRU caching for zip file opening.

    Args:
        max_size: The maximal LRU cache size.
    """
    global _cached_open_zip
    _cached_open_zip = lru_cache(max_size)(_open_zip)


def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
    """Opens a file stored inside a zip and returns a file-like object.

    Args:
        path_in_zip: A PathInZip object representing the file to return a file-like object of.
        mode: The mode in which to open the file.
    Returns:
        A file-like object for PathInZip.
    """
    zf = _cached_open_zip(path_in_zip.zip_path)
    return zf.open(path_in_zip.file_path)
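
A sketch of the path convention in action (the archive path is hypothetical); the underlying ZipFile handle is reused via the LRU cache:

p = PathInZip('/some/location/foo.zip:/data/file1.json')
with open_file_in_zip(p) as f:
    payload = f.read()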

@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# flake8: noqa
from .musicgen import MusicGen
from .lm import LMModel
from .encodec import CompressionModel, EncodecModel
@@ -0,0 +1,218 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
All the functions to build the relevant models and modules
from the Hydra config.
"""

import typing as tp
import warnings

import audiocraft
import omegaconf
import torch

from .encodec import CompressionModel, EncodecModel, FlattenedCompressionModel  # noqa
from .lm import LMModel
from ..modules.codebooks_patterns import (
    CodebooksPatternProvider,
    DelayedPatternProvider,
    ParallelPatternProvider,
    UnrolledPatternProvider,
    VALLEPattern,
    MusicLMPattern,
)
from ..modules.conditioners import (
    BaseConditioner,
    ConditioningProvider,
    LUTConditioner,
    T5Conditioner,
    ConditionFuser,
    ChromaStemConditioner,
)
from .. import quantization as qt
from ..utils.utils import dict_from_config


def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer:
    klass = {
        'no_quant': qt.DummyQuantizer,
        'rvq': qt.ResidualVectorQuantizer
    }[quantizer]
    kwargs = dict_from_config(getattr(cfg, quantizer))
    if quantizer != 'no_quant':
        kwargs['dimension'] = dimension
    return klass(**kwargs)
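
A hand-built config sketch for get_quantizer; the exact field names under 'rvq' are assumptions mirroring the Hydra config, not guaranteed by this file:

# cfg only needs a node named after the chosen quantizer
cfg = omegaconf.OmegaConf.create({'rvq': {'n_q': 4, 'bins': 1024}})
quantizer = get_quantizer('rvq', cfg, dimension=128)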


def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
    if encoder_name == 'seanet':
        kwargs = dict_from_config(getattr(cfg, 'seanet'))
        encoder_override_kwargs = kwargs.pop('encoder')
        decoder_override_kwargs = kwargs.pop('decoder')
        encoder_kwargs = {**kwargs, **encoder_override_kwargs}
        decoder_kwargs = {**kwargs, **decoder_override_kwargs}
        encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs)
        decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs)
        return encoder, decoder
    else:
        raise KeyError(f'Unexpected autoencoder {encoder_name}')


def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
    """Instantiate a compression model.
    """
    if cfg.compression_model == 'encodec':
        kwargs = dict_from_config(getattr(cfg, 'encodec'))
        encoder_name = kwargs.pop('autoencoder')
        quantizer_name = kwargs.pop('quantizer')
        encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
        quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
        frame_rate = kwargs['sample_rate'] // encoder.hop_length
        renormalize = kwargs.pop('renormalize', None)
        renorm = kwargs.pop('renorm')
        if renormalize is None:
            renormalize = renorm is not None
            warnings.warn("You are using a deprecated EnCodec model. Please migrate to new renormalization.")
        return EncodecModel(encoder, decoder, quantizer,
                            frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
    else:
        raise KeyError(f'Unexpected compression model {cfg.compression_model}')


def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
    """Instantiate a transformer LM.
    """
    if cfg.lm_model == 'transformer_lm':
        kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
        n_q = kwargs['n_q']
        q_modeling = kwargs.pop('q_modeling', None)
        codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
        attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
        cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
        cfg_prob, cfg_coef = cls_free_guidance["training_dropout"], cls_free_guidance["inference_coef"]
        fuser = get_condition_fuser(cfg)
        condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
        if len(fuser.fuse2cond['cross']) > 0:  # enforce cross-attention programmatically
            kwargs['cross_attention'] = True
        if codebooks_pattern_cfg.modeling is None:
            assert q_modeling is not None, \
                'LM model should either have a codebook pattern defined or transformer_lm.q_modeling'
            codebooks_pattern_cfg = omegaconf.OmegaConf.create(
                {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
            )
        pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
        return LMModel(
            pattern_provider=pattern_provider,
            condition_provider=condition_provider,
            fuser=fuser,
            cfg_dropout=cfg_prob,
            cfg_coef=cfg_coef,
            attribute_dropout=attribute_dropout,
            dtype=getattr(torch, cfg.dtype),
            device=cfg.device,
            **kwargs
        ).to(cfg.device)
    else:
        raise KeyError(f'Unexpected LM model {cfg.lm_model}')


def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
    """Instantiate a conditioning model.
    """
    device = cfg.device
    duration = cfg.dataset.segment_duration
    cfg = getattr(cfg, "conditioners")
    cfg = omegaconf.OmegaConf.create({}) if cfg is None else cfg
    conditioners: tp.Dict[str, BaseConditioner] = {}
    with omegaconf.open_dict(cfg):
        condition_provider_args = cfg.pop('args', {})
    for cond, cond_cfg in cfg.items():
        model_type = cond_cfg["model"]
        model_args = cond_cfg[model_type]
        if model_type == "t5":
            conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args)
        elif model_type == "lut":
conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args)
|
||||
elif model_type == "chroma_stem":
|
||||
model_args.pop('cache_path', None)
|
||||
conditioners[str(cond)] = ChromaStemConditioner(
|
||||
output_dim=output_dim,
|
||||
duration=duration,
|
||||
device=device,
|
||||
**model_args
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unrecognized conditioning model: {model_type}")
|
||||
conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args)
|
||||
return conditioner
|
||||
|
||||
|
||||
def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
|
||||
"""Instantiate a condition fuser object.
|
||||
"""
|
||||
fuser_cfg = getattr(cfg, "fuser")
|
||||
fuser_methods = ["sum", "cross", "prepend", "input_interpolate"]
|
||||
fuse2cond = {k: fuser_cfg[k] for k in fuser_methods}
|
||||
kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
|
||||
fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
|
||||
return fuser
|
||||
|
||||
|
||||
def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
|
||||
"""Instantiate a codebooks pattern provider object.
|
||||
"""
|
||||
pattern_providers = {
|
||||
'parallel': ParallelPatternProvider,
|
||||
'delay': DelayedPatternProvider,
|
||||
'unroll': UnrolledPatternProvider,
|
||||
'valle': VALLEPattern,
|
||||
'musiclm': MusicLMPattern,
|
||||
}
|
||||
name = cfg.modeling
|
||||
kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
|
||||
klass = pattern_providers[name]
|
||||
return klass(n_q, **kwargs)
|
||||
|
||||
|
||||
def get_debug_compression_model(device='cpu'):
|
||||
"""Instantiate a debug compression model to be used for unit tests.
|
||||
"""
|
||||
seanet_kwargs = {
|
||||
'n_filters': 4,
|
||||
'n_residual_layers': 1,
|
||||
'dimension': 32,
|
||||
'ratios': [10, 8, 16] # 25 Hz at 32kHz
|
||||
}
|
||||
encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs)
|
||||
decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs)
|
||||
quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4)
|
||||
init_x = torch.randn(8, 32, 128)
|
||||
quantizer(init_x, 1) # initialize kmeans etc.
|
||||
compression_model = EncodecModel(
|
||||
encoder, decoder, quantizer,
|
||||
frame_rate=25, sample_rate=32000, channels=1).to(device)
|
||||
return compression_model.eval()
|
||||
|
||||
|
||||
def get_debug_lm_model(device='cpu'):
|
||||
"""Instantiate a debug LM to be used for unit tests.
|
||||
"""
|
||||
pattern = DelayedPatternProvider(n_q=4)
|
||||
dim = 16
|
||||
providers = {
|
||||
'description': LUTConditioner(n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace"),
|
||||
}
|
||||
condition_provider = ConditioningProvider(providers)
|
||||
fuser = ConditionFuser(
|
||||
{'cross': ['description'], 'prepend': [],
|
||||
'sum': [], 'input_interpolate': []})
|
||||
lm = LMModel(
|
||||
pattern, condition_provider, fuser,
|
||||
n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2,
|
||||
cross_attention=True, causal=True)
|
||||
return lm.to(device).eval()
|
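For a quick smoke test of the builders above, a sketch using the debug models on CPU (illustrative only, not part of the commit):

# Illustrative sketch only: exercises the debug builders above.
import torch
from audiocraft.models.builders import get_debug_compression_model, get_debug_lm_model

compression = get_debug_compression_model('cpu')   # 25 Hz frame rate at 32 kHz, mono
lm = get_debug_lm_model('cpu')
wav = torch.randn(2, 1, 32000)                     # one second of random "audio"
codes, scale = compression.encode(wav)
print(codes.shape)                                 # expected: torch.Size([2, 4, 25])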
@ -0,0 +1,302 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod
import typing as tp

from einops import rearrange
import torch
from torch import nn

from .. import quantization as qt


class CompressionModel(ABC, nn.Module):

    @abstractmethod
    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
        ...

    @abstractmethod
    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """See `EncodecModel.encode`"""
        ...

    @abstractmethod
    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """See `EncodecModel.decode`"""
        ...

    @property
    @abstractmethod
    def channels(self) -> int:
        ...

    @property
    @abstractmethod
    def frame_rate(self) -> int:
        ...

    @property
    @abstractmethod
    def sample_rate(self) -> int:
        ...

    @property
    @abstractmethod
    def cardinality(self) -> int:
        ...

    @property
    @abstractmethod
    def num_codebooks(self) -> int:
        ...

    @property
    @abstractmethod
    def total_codebooks(self) -> int:
        ...

    @abstractmethod
    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer."""
        ...


class EncodecModel(CompressionModel):
    """Encodec model operating on the raw waveform.

    Args:
        encoder (nn.Module): Encoder network.
        decoder (nn.Module): Decoder network.
        quantizer (qt.BaseQuantizer): Quantizer network.
        frame_rate (int): Frame rate for the latent representation.
        sample_rate (int): Audio sample rate.
        channels (int): Number of audio channels.
        causal (bool): Whether to use a causal version of the model.
        renormalize (bool): Whether to renormalize the audio before running the model.
    """
    # We need assignment to override the property in the abstract class,
    # I couldn't find a better way...
    frame_rate: int = 0
    sample_rate: int = 0
    channels: int = 0

    def __init__(self,
                 encoder: nn.Module,
                 decoder: nn.Module,
                 quantizer: qt.BaseQuantizer,
                 frame_rate: int,
                 sample_rate: int,
                 channels: int,
                 causal: bool = False,
                 renormalize: bool = False):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.quantizer = quantizer
        self.frame_rate = frame_rate
        self.sample_rate = sample_rate
        self.channels = channels
        self.renormalize = renormalize
        self.causal = causal
        if self.causal:
            # we force disabling here to avoid handling linear overlap of segments
            # as supported in original EnCodec codebase.
            assert not self.renormalize, 'Causal model does not support renormalize'

    @property
    def total_codebooks(self):
        """Total number of quantizer codebooks available."""
        return self.quantizer.total_codebooks

    @property
    def num_codebooks(self):
        """Active number of codebooks used by the quantizer."""
        return self.quantizer.num_codebooks

    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer."""
        self.quantizer.set_num_codebooks(n)

    @property
    def cardinality(self):
        """Cardinality of each codebook."""
        return self.quantizer.bins

    def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        scale: tp.Optional[torch.Tensor]
        if self.renormalize:
            mono = x.mean(dim=1, keepdim=True)
            volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
            scale = 1e-8 + volume
            x = x / scale
            scale = scale.view(-1, 1)
        else:
            scale = None
        return x, scale

    def postprocess(self,
                    x: torch.Tensor,
                    scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        if scale is not None:
            assert self.renormalize
            x = x * scale.view(-1, 1, 1)
        return x

    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
        assert x.dim() == 3
        length = x.shape[-1]
        x, scale = self.preprocess(x)

        emb = self.encoder(x)
        q_res = self.quantizer(emb, self.frame_rate)
        out = self.decoder(q_res.x)

        # remove extra padding added by the encoder and decoder
        assert out.shape[-1] >= length, (out.shape[-1], length)
        out = out[..., :length]

        q_res.x = self.postprocess(out, scale)

        return q_res

    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Encode the given input tensor to quantized representation along with scale parameter.

        Args:
            x (torch.Tensor): Float tensor of shape [B, C, T]

        Returns:
            codes, scale (tp.Tuple[torch.Tensor, torch.Tensor]): Tuple composed of:
                codes, an int tensor of shape [B, K, T] with K the number of codebooks used and T the timestep.
                scale, a float tensor containing the scale for audio renormalization.
        """
        assert x.dim() == 3
        x, scale = self.preprocess(x)
        emb = self.encoder(x)
        codes = self.quantizer.encode(emb)
        return codes, scale

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """Decode the given codes to a reconstructed representation, using the scale to perform
        audio denormalization if needed.

        Args:
            codes (torch.Tensor): Int tensor of shape [B, K, T]
            scale (tp.Optional[torch.Tensor]): Float tensor containing the scale value.

        Returns:
            out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
        """
        emb = self.quantizer.decode(codes)
        out = self.decoder(emb)
        out = self.postprocess(out, scale)
        # out contains extra padding added by the encoder and decoder
        return out


class FlattenedCompressionModel(CompressionModel):
    """Wraps a CompressionModel and flattens its codebooks, e.g.
    instead of returning [B, K, T], return [B, S, T * (K // S)] with
    S the number of codebooks per step, and `K // S` the number of 'virtual steps'
    for each real time step.

    Args:
        model (CompressionModel): compression model to wrap.
        codebooks_per_step (int): number of codebooks to keep per step,
            this must divide the number of codebooks provided by the wrapped model.
        extend_cardinality (bool): if True, and for instance if codebooks_per_step = 1,
            if each codebook has a cardinality N, then the first codebook will
            use the range [0, N - 1], and the second [N, 2 N - 1] etc.
            On decoding, this can lead to potentially invalid sequences.
            Any invalid entry will be silently remapped to the proper range
            with a modulo.
    """
    def __init__(self, model: CompressionModel, codebooks_per_step: int = 1,
                 extend_cardinality: bool = True):
        super().__init__()
        self.model = model
        self.codebooks_per_step = codebooks_per_step
        self.extend_cardinality = extend_cardinality

    @property
    def total_codebooks(self):
        return self.model.total_codebooks

    @property
    def num_codebooks(self):
        """Active number of codebooks used by the quantizer.

        ..Warning:: this reports the number of codebooks after the flattening
        of the codebooks!
        """
        assert self.model.num_codebooks % self.codebooks_per_step == 0
        return self.codebooks_per_step

    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer.

        ..Warning:: this sets the number of codebooks **before** the flattening
        of the codebooks.
        """
        assert n % self.codebooks_per_step == 0
        self.model.set_num_codebooks(n)

    @property
    def num_virtual_steps(self) -> int:
        """Return the number of virtual steps, e.g. one real step
        will be split into that many steps.
        """
        return self.model.num_codebooks // self.codebooks_per_step

    @property
    def frame_rate(self) -> int:
        return self.model.frame_rate * self.num_virtual_steps

    @property
    def sample_rate(self) -> int:
        return self.model.sample_rate

    @property
    def channels(self) -> int:
        return self.model.channels

    @property
    def cardinality(self):
        """Cardinality of each codebook."""
        if self.extend_cardinality:
            return self.model.cardinality * self.num_virtual_steps
        else:
            return self.model.cardinality

    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
        raise NotImplementedError("Not supported, use encode and decode.")

    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        indices, scales = self.model.encode(x)
        B, K, T = indices.shape
        indices = rearrange(indices, 'b (k v) t -> b k t v', k=self.codebooks_per_step)
        if self.extend_cardinality:
            for virtual_step in range(1, self.num_virtual_steps):
                indices[..., virtual_step] += self.model.cardinality * virtual_step
        indices = rearrange(indices, 'b k t v -> b k (t v)')
        return (indices, scales)

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        B, K, T = codes.shape
        assert T % self.num_virtual_steps == 0
        codes = rearrange(codes, 'b k (t v) -> b (k v) t', v=self.num_virtual_steps)
        # We silently ignore potential errors from the LM when
        # using extend_cardinality.
        codes = codes % self.model.cardinality
        return self.model.decode(codes, scale)
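A round-trip sketch through EncodecModel.encode/decode, using the debug compression model from builders.py above (illustrative only, not part of the commit):

# Illustrative sketch only: encode audio to discrete codes, then reconstruct.
import torch
from audiocraft.models.builders import get_debug_compression_model

model = get_debug_compression_model('cpu')
wav = torch.randn(1, 1, 32000)          # [B, C, T], one second at 32 kHz
codes, scale = model.encode(wav)        # int codes [B, K, T']; scale is None (renormalize=False)
recon = model.decode(codes, scale)      # reconstructed float audio [B, C, T'']
print(codes.shape, recon.shape)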
@ -0,0 +1,527 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from functools import partial
import logging
import math
import typing as tp

import torch
from torch import nn

from ..utils import utils
from ..modules.streaming import StreamingModule, State
from ..modules.transformer import StreamingTransformer, create_norm_fn
from ..modules.conditioners import (
    ConditionFuser,
    ClassifierFreeGuidanceDropout,
    AttributeDropout,
    ConditioningProvider,
    ConditioningAttributes,
    ConditionType,
)
from ..modules.codebooks_patterns import CodebooksPatternProvider
from ..modules.activations import get_activation_fn


logger = logging.getLogger(__name__)
ConditionTensors = tp.Dict[str, ConditionType]
CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]


def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None):
    """LM layer initialization.
    Inspired by xlformers: https://github.com/fairinternal/xlformers

    Args:
        method (str): Method name for init function. Valid options are:
            'gaussian', 'uniform'.
        input_dim (int): Input dimension of the initialized module.
        init_depth (Optional[int]): Optional init depth value used to rescale
            the standard deviation if defined.
    """
    # Compute std
    std = 1 / math.sqrt(input_dim)
    # Rescale with depth
    if init_depth is not None:
        std = std / math.sqrt(2 * init_depth)

    if method == 'gaussian':
        return partial(
            torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std
        )
    elif method == 'uniform':
        bound = math.sqrt(3) * std  # ensure the standard deviation is `std`
        return partial(torch.nn.init.uniform_, a=-bound, b=bound)
    else:
        raise ValueError("Unsupported layer initialization method")


def init_layer(m: nn.Module,
               method: str,
               init_depth: tp.Optional[int] = None,
               zero_bias_init: bool = False):
    """Wrapper around ``get_init_fn`` for proper initialization of LM modules.

    Args:
        m (nn.Module): Module to initialize.
        method (str): Method name for the init function.
        init_depth (Optional[int]): Optional init depth value used to rescale
            the standard deviation if defined.
        zero_bias_init (bool): Whether to initialize the bias to 0 or not.
    """
    if isinstance(m, nn.Linear):
        init_fn = get_init_fn(method, m.in_features, init_depth=init_depth)
        if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
            weight = m.weight.float()
            init_fn(weight)
            m.weight.data[:] = weight.half()
        else:
            init_fn(m.weight)
        if zero_bias_init and m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Embedding):
        init_fn = get_init_fn(method, m.embedding_dim, init_depth=None)
        if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
            weight = m.weight.float()
            init_fn(weight)
            m.weight.data[:] = weight.half()
        else:
            init_fn(m.weight)


class ScaledEmbedding(nn.Embedding):
    """Boost learning rate for embeddings (with `scale`)."""
    def __init__(self, *args, lr=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.lr = lr

    def make_optim_group(self):
        group = {"params": list(self.parameters())}
        if self.lr is not None:
            group["lr"] = self.lr
        return group


@dataclass
class LMOutput:
    # The logits are already re-aligned with the input codes
    # hence no extra shift is required, e.g. when computing CE
    logits: torch.Tensor  # [B, K, T, card]
    mask: torch.Tensor  # [B, K, T]


class LMModel(StreamingModule):
    """Transformer-based language model on multiple streams of codes.

    Args:
        pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving.
        condition_provider (MusicConditioningProvider): Conditioning provider from metadata.
        fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input.
        n_q (int): Number of parallel streams to model.
        card (int): Cardinality, vocabulary size.
        dim (int): Dimension of the transformer encoder.
        num_heads (int): Number of heads for the transformer encoder.
        hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
        norm (str): Normalization method.
        norm_first (bool): Use pre-norm instead of post-norm.
        emb_lr (Optional[float]): Embedding-specific learning rate.
        bias_proj (bool): Use bias for output projections.
        weight_init (Optional[str]): Method for weight initialization.
        depthwise_init (Optional[str]): Method for depthwise weight initialization.
        zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros.
        cfg_dropout (float): Classifier-free guidance dropout.
        cfg_coef (float): Classifier-free guidance coefficient.
        attribute_dropout (dict): Attribute dropout probabilities.
        two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
        **kwargs: Additional parameters for the transformer encoder.
    """
    def __init__(self, pattern_provider: CodebooksPatternProvider, condition_provider: ConditioningProvider,
                 fuser: ConditionFuser, n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8,
                 hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False,
                 emb_lr: tp.Optional[float] = None, bias_proj: bool = True,
                 weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
                 zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0,
                 attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {}, two_step_cfg: bool = False,
                 **kwargs):
        super().__init__()
        self.cfg_coef = cfg_coef
        self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout)
        self.att_dropout = AttributeDropout(p=attribute_dropout)
        self.condition_provider = condition_provider
        self.fuser = fuser
        self.card = card
        embed_dim = self.card + 1
        self.n_q = n_q
        self.dim = dim
        self.pattern_provider = pattern_provider
        self.two_step_cfg = two_step_cfg
        self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)])
        if 'activation' in kwargs:
            kwargs['activation'] = get_activation_fn(kwargs['activation'])
        self.transformer = StreamingTransformer(
            d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim),
            norm=norm, norm_first=norm_first, **kwargs)
        self.out_norm: tp.Optional[nn.Module] = None
        if norm_first:
            self.out_norm = create_norm_fn(norm, dim)
        self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=bias_proj) for _ in range(n_q)])
        self._init_weights(weight_init, depthwise_init, zero_bias_init)
        self._fsdp: tp.Optional[nn.Module]
        self.__dict__['_fsdp'] = None

    def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool):
        """Initialization of the transformer module weights.

        Args:
            weight_init (Optional[str]): Weight initialization strategy. See ``get_init_fn`` for valid options.
            depthwise_init (Optional[str]): Depthwise initialization strategy. The following options are valid:
                'current' where the depth corresponds to the current layer index or 'global' where the total number
                of layers is used as depth. If not set, no depthwise initialization strategy is used.
            zero_bias_init (bool): Whether to initialize bias to zero or not.
        """
        assert depthwise_init is None or depthwise_init in ['current', 'global']
        assert depthwise_init is None or weight_init is not None, \
            "If 'depthwise_init' is defined, a 'weight_init' method should be provided."
        assert not zero_bias_init or weight_init is not None, \
            "If 'zero_bias_init', a 'weight_init' method should be provided"

        if weight_init is None:
            return

        for emb_layer in self.emb:
            init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)

        for layer_idx, tr_layer in enumerate(self.transformer.layers):
            depth = None
            if depthwise_init == 'current':
                depth = layer_idx + 1
            elif depthwise_init == 'global':
                depth = len(self.transformer.layers)
            init_fn = partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init)
            tr_layer.apply(init_fn)

        for linear in self.linears:
            init_layer(linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)

    @property
    def special_token_id(self) -> int:
        return self.card

    @property
    def num_codebooks(self) -> int:
        return self.n_q

    def forward(self, sequence: torch.Tensor,
                conditions: tp.List[ConditioningAttributes],
                condition_tensors: tp.Optional[ConditionTensors] = None) -> torch.Tensor:
        """Apply language model on sequence and conditions.
        Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and
        S the sequence steps, return the logits with shape [B, K, S, card].

        Args:
            sequence (torch.Tensor): Indices of the codes to model.
            conditions (list[ConditioningAttributes]): Conditionings to use when modeling
                the given codes. Note that when evaluating multiple times with the same conditioning
                you should pre-compute those and pass them as `condition_tensors`.
            condition_tensors (dict[str, ConditionType] or None): Pre-computed conditioning
                tensors, see `conditions`.
        Returns:
            torch.Tensor: Logits.
        """
        B, K, S = sequence.shape
        assert K == self.num_codebooks, 'Sequence shape must match the specified number of codebooks'
        input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
        if condition_tensors is None:
            assert not self._is_streaming, "Conditions tensors should be precomputed when streaming."
            # apply dropout modules
            conditions = self.cfg_dropout(conditions)
            conditions = self.att_dropout(conditions)
            tokenized = self.condition_provider.tokenize(conditions)
            # encode conditions and fuse, both have a streaming cache to not recompute when generating.
            condition_tensors = self.condition_provider(tokenized)
        else:
            assert not conditions, "Shouldn't pass both conditions and condition_tensors."

        input_, cross_attention_input = self.fuser(input_, condition_tensors)

        out = self.transformer(input_, cross_attention_src=cross_attention_input)
        if self.out_norm:
            out = self.out_norm(out)
        logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1)  # [B, K, S, card]

        # remove the prefix from the model outputs
        if len(self.fuser.fuse2cond['prepend']) > 0:
            logits = logits[:, :, -S:]

        return logits  # [B, K, S, card]

    def compute_predictions(
            self, codes: torch.Tensor,
            conditions: tp.List[ConditioningAttributes],
            condition_tensors: tp.Optional[ConditionTensors] = None) -> LMOutput:
        """Given an input tensor of codes [B, K, T] and list of conditions, runs the model
        forward using the specified codes interleaving pattern.

        Args:
            codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
                K the number of codebooks and T the number of timesteps.
            conditions (list[ConditioningAttributes]): Conditionings to use when modeling
                the given codes. Note that when evaluating multiple times with the same conditioning
                you should pre-compute those and pass them as `condition_tensors`.
            condition_tensors (dict[str, ConditionType] or None): Pre-computed conditioning
                tensors, see `conditions`.
        Returns:
            LMOutput: Language model outputs
                logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes,
                    i.e. the first item corresponds to logits to predict the first code, meaning that
                    no additional shifting of codes and logits is required.
                mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions.
                    Given the specified interleaving strategies, parts of the logits and codes should
                    not be considered as valid predictions because of invalid context.
        """
        B, K, T = codes.shape
        codes = codes.contiguous()
        # map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
        pattern = self.pattern_provider.get_pattern(T)
        sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
            codes, self.special_token_id, keep_only_valid_steps=True
        )
        # apply model on pattern sequence
        model = self if self._fsdp is None else self._fsdp
        logits = model(sequence_codes, conditions, condition_tensors)  # [B, K, S, card]
        # map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card]
        # and provide the corresponding mask over invalid positions of tokens
        logits = logits.permute(0, 3, 1, 2)  # [B, card, K, S]
        # note: we use nans as special token to make it obvious if we feed unexpected logits
        logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
            logits, float('nan'), keep_only_valid_steps=True
        )
        logits = logits.permute(0, 2, 3, 1)  # [B, K, T, card]
        logits_mask = logits_mask[None, :, :].expand(B, -1, -1)  # [K, T] -> [B, K, T]
        return LMOutput(logits, logits_mask)

    def _sample_next_token(self,
                           sequence: torch.Tensor,
                           cfg_conditions: CFGConditions,
                           unconditional_state: State,
                           use_sampling: bool = False,
                           temp: float = 1.0,
                           top_k: int = 0,
                           top_p: float = 0.0,
                           cfg_coef: tp.Optional[float] = None) -> torch.Tensor:
        """Sample next token from the model given a sequence and a set of conditions. The model supports
        multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).

        Args:
            sequence (torch.Tensor): Current sequence of shape [B, K, S]
                with K corresponding to the number of codebooks and S the number of sequence steps.
                S = 1 in streaming mode, except for the first step that contains a bigger prompt.
            cfg_conditions (CFGConditions): Set of conditions. If CFG is used,
                should be twice the batch size, being the concatenation of the conditions + null conditions.
            use_sampling (bool): Whether to use a sampling strategy or not.
            temp (float): Sampling temperature.
            top_k (int): K for "top-k" sampling.
            top_p (float): P for "top-p" sampling.
            cfg_coef (float): Classifier-free guidance coefficient.
        Returns:
            next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
        """
        B = sequence.shape[0]
        cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
        model = self if self._fsdp is None else self._fsdp
        if self.two_step_cfg and cfg_conditions != {}:
            assert isinstance(cfg_conditions, tuple)
            condition_tensors, null_condition_tensors = cfg_conditions
            cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
            state = self.get_streaming_state()
            self.set_streaming_state(unconditional_state)
            uncond_logits = model(sequence, conditions=[], condition_tensors=null_condition_tensors)
            unconditional_state.update(self.get_streaming_state())
            self.set_streaming_state(state)
            # use the local cfg_coef so per-call overrides apply in both CFG branches
            logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
        else:
            assert isinstance(cfg_conditions, dict)
            condition_tensors = cfg_conditions
            if condition_tensors:
                # Preparing for CFG, predicting both conditional and unconditional logits.
                sequence = torch.cat([sequence, sequence], dim=0)
            all_logits = model(
                sequence,
                conditions=[], condition_tensors=condition_tensors)
            if condition_tensors:
                cond_logits, uncond_logits = all_logits.split(B, dim=0)  # [B, K, T, card]
                logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
            else:
                logits = all_logits

        logits = logits.permute(0, 1, 3, 2)  # [B, K, card, T]
        logits = logits[..., -1]  # [B, K, card]

        # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
        if use_sampling and temp > 0.0:
            probs = torch.softmax(logits / temp, dim=-1)
            if top_p > 0.0:
                next_token = utils.sample_top_p(probs, p=top_p)
            elif top_k > 0:
                next_token = utils.sample_top_k(probs, k=top_k)
            else:
                next_token = utils.multinomial(probs, num_samples=1)
        else:
            next_token = torch.argmax(logits, dim=-1, keepdim=True)

        return next_token

    @torch.no_grad()
    def generate(self,
                 prompt: tp.Optional[torch.Tensor] = None,
                 conditions: tp.List[ConditioningAttributes] = [],
                 num_samples: tp.Optional[int] = None,
                 max_gen_len: int = 256,
                 use_sampling: bool = True,
                 temp: float = 1.0,
                 top_k: int = 250,
                 top_p: float = 0.0,
                 cfg_coef: tp.Optional[float] = None,
                 two_step_cfg: tp.Optional[bool] = None,  # None falls back to self.two_step_cfg
                 remove_prompts: bool = False,
                 check: bool = False,
                 callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> torch.Tensor:
        """Generate tokens sampling from the model given a prompt or unconditionally. Generation can
        be performed in a greedy fashion or using sampling with top K and top P strategies.

        Args:
            prompt (Optional[torch.Tensor]): Prompt tokens of shape [B, K, T].
            conditions (list[ConditioningAttributes]): List of conditions, possibly empty.
            num_samples (int or None): Number of samples to generate when no prompt and no conditions are given.
            max_gen_len (int): Maximum generation length.
            use_sampling (bool): Whether to use a sampling strategy or not.
            temp (float): Sampling temperature.
            top_k (int): K for "top-k" sampling.
            top_p (float): P for "top-p" sampling.
            cfg_coef (Optional[float]): Classifier-free guidance coefficient; defaults to self.cfg_coef.
            two_step_cfg (Optional[bool]): Whether to run CFG as two forward passes; defaults to self.two_step_cfg.
            remove_prompts (bool): Whether to remove prompts from generation or not.
            check (bool): Whether to run consistency checks on the generated sequence.
            callback (Optional[Callable[[int, int], None]]): Called after each step with
                (generated steps, total steps).
        Returns:
            torch.Tensor: Generated tokens.
        """
        assert not self.training, "generation shouldn't be used in training mode."
        first_param = next(iter(self.parameters()))
        device = first_param.device

        # Checking all input shapes are consistent.
        possible_num_samples = []
        if num_samples is not None:
            possible_num_samples.append(num_samples)
        elif prompt is not None:
            possible_num_samples.append(prompt.shape[0])
        elif conditions:
            possible_num_samples.append(len(conditions))
        else:
            possible_num_samples.append(1)
        assert all(x == possible_num_samples[0] for x in possible_num_samples), "Inconsistent inputs shapes"
        num_samples = possible_num_samples[0]

        # below we create set of conditions: one conditional and one unconditional
        # to do that we merge the regular condition together with the null condition
        # we then do 1 forward pass instead of 2.
        # the reason for that is two-fold:
        # 1. it is about x2 faster than doing 2 forward passes
        # 2. avoid the streaming API treating the 2 passes as part of different time steps
        # We also support doing two different passes, in particular to ensure that
        # the padding structure is exactly the same between train and test.
        # With a batch size of 1, this can be slower though.
        cfg_conditions: CFGConditions
        two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
        if conditions:
            null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
            if two_step_cfg:
                cfg_conditions = (
                    self.condition_provider(self.condition_provider.tokenize(conditions)),
                    self.condition_provider(self.condition_provider.tokenize(null_conditions)),
                )
            else:
                conditions = conditions + null_conditions
                tokenized = self.condition_provider.tokenize(conditions)
                cfg_conditions = self.condition_provider(tokenized)
        else:
            cfg_conditions = {}

        if prompt is None:
            assert num_samples > 0
            prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)

        B, K, T = prompt.shape
        start_offset = T
        assert start_offset < max_gen_len

        pattern = self.pattern_provider.get_pattern(max_gen_len)
        # this token is used as default value for codes that are not generated yet
        unknown_token = -1

        # we generate codes up to the max_gen_len that will be mapped to the pattern sequence
        gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
        # filling the gen_codes with the prompt if needed
        gen_codes[..., :start_offset] = prompt
        # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
        gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
        # retrieve the start_offset in the sequence:
        # it is the first sequence step that contains the `start_offset` timestep
        start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
        assert start_offset_sequence is not None

        with self.streaming():
            unconditional_state = self.get_streaming_state()
            prev_offset = 0
            gen_sequence_len = gen_sequence.shape[-1]  # gen_sequence shape is [B, K, S]
            for offset in range(start_offset_sequence, gen_sequence_len):
                # get current sequence (note that the streaming API is providing the caching over previous offsets)
                curr_sequence = gen_sequence[..., prev_offset:offset]
                curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
                if check:
                    # check coherence between mask and sequence
                    assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all()
                    # should never happen as gen_sequence is filled progressively
                    assert not (curr_sequence == unknown_token).any()
                # sample next token from the model, next token shape is [B, K, 1]
                next_token = self._sample_next_token(
                    curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
                    cfg_coef=cfg_coef)
                # ensure the tokens that should be masked are properly set to special_token_id
                # as the model never output special_token_id
                valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
                next_token[~valid_mask] = self.special_token_id
                # ensure we don't overwrite prompt tokens, we only write over unknown tokens
                # (then mask tokens should be left as is as well, which is correct)
                gen_sequence[..., offset:offset+1] = torch.where(
                    gen_sequence[..., offset:offset+1] == unknown_token,
                    next_token, gen_sequence[..., offset:offset+1]
                )
                prev_offset = offset
                if callback is not None:
                    callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
        unconditional_state.clear()

        # ensure sequence has been entirely filled
        assert not (gen_sequence == unknown_token).any()
        # ensure gen_sequence pattern and mask are matching
        # which means the gen_sequence is valid according to the pattern
        assert (
            gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id)
        ).all()
        # get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps
        out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)

        # sanity checks over the returned codes and corresponding masks
        assert (out_codes[..., :max_gen_len] != unknown_token).all()
        assert (out_mask[..., :max_gen_len] == 1).all()

        out_start_offset = start_offset if remove_prompts else 0
        out_codes = out_codes[..., out_start_offset:max_gen_len]

        # ensure the returned codes are all valid
        assert (out_codes >= 0).all() and (out_codes <= self.card).all()
        return out_codes
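A generation sketch with the debug LM from builders.py, in the spirit of how the unit tests drive LMModel.generate (illustrative only; the expected output shape is an assumption based on the trimming logic above):

# Illustrative sketch only: text-conditioned token generation on CPU.
from audiocraft.models.builders import get_debug_lm_model
from audiocraft.modules.conditioners import ConditioningAttributes

lm = get_debug_lm_model('cpu')
conditions = [ConditioningAttributes(text={'description': 'test'})]
codes = lm.generate(conditions=conditions, max_gen_len=8, use_sampling=True, top_k=10)
print(codes.shape)  # expected: [1, 4, 8] after the delayed pattern is reverted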
@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Utility functions to load from the checkpoints.
Each checkpoint is a dict saved with torch.save and has the following keys:
- 'xp.cfg': the hydra config as dumped during training. This should be used
    to rebuild the object using the audiocraft.models.builders functions,
- 'model_best_state': a readily loadable best state for the model, including
    the conditioner. The model obtained from `xp.cfg` should be compatible
    with this state dict. In the case of a LM, the encodec model would not be
    bundled along but instead provided separately.

Those functions also support loading from a remote location with the Torch Hub API.
They also support overriding some parameters, in particular the device and dtype
of the returned model.
"""

from pathlib import Path
from huggingface_hub import hf_hub_download
import typing as tp
import os

from omegaconf import OmegaConf
import torch

from . import builders


HF_MODEL_CHECKPOINTS_MAP = {
    "small": "facebook/musicgen-small",
    "medium": "facebook/musicgen-medium",
    "large": "facebook/musicgen-large",
    "melody": "facebook/musicgen-melody",
}


def _get_state_dict(
    file_or_url_or_id: tp.Union[Path, str],
    filename: tp.Optional[str] = None,
    device='cpu',
    cache_dir: tp.Optional[str] = None,
):
    # Return the state dict either from a file or url
    file_or_url_or_id = str(file_or_url_or_id)
    assert isinstance(file_or_url_or_id, str)

    if os.path.isfile(file_or_url_or_id):
        return torch.load(file_or_url_or_id, map_location=device)

    elif file_or_url_or_id.startswith('https://'):
        return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)

    elif file_or_url_or_id in HF_MODEL_CHECKPOINTS_MAP:
        assert filename is not None, "filename needs to be defined if using HF checkpoints"

        repo_id = HF_MODEL_CHECKPOINTS_MAP[file_or_url_or_id]
        file = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir)
        return torch.load(file, map_location=device)

    else:
        raise ValueError(f"{file_or_url_or_id} is not a valid name, path or link that can be loaded.")


def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
    pkg = _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir)
    cfg = OmegaConf.create(pkg['xp.cfg'])
    cfg.device = str(device)
    model = builders.get_compression_model(cfg)
    model.load_state_dict(pkg['best_state'])
    model.eval()
    return model


def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
    pkg = _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir)
    cfg = OmegaConf.create(pkg['xp.cfg'])
    cfg.device = str(device)
    if cfg.device == 'cpu':
        cfg.dtype = 'float32'
    else:
        cfg.dtype = 'float16'
    model = builders.get_lm_model(cfg)
    model.load_state_dict(pkg['best_state'])
    model.eval()
    model.cfg = cfg
    return model
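A loading sketch for the helpers above (illustrative only; 'small' resolves through HF_MODEL_CHECKPOINTS_MAP and downloads weights from the Hugging Face Hub on first use, so pass cache_dir or rely on the caller's cache directory to control where they land):

# Illustrative sketch only: load both halves of the 'small' checkpoint on CPU.
from audiocraft.models.loaders import load_compression_model, load_lm_model

compression_model = load_compression_model('small', device='cpu')
lm = load_lm_model('small', device='cpu')  # forces float32 on CPU, see above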
|
@ -0,0 +1,361 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Main model for using MusicGen. This will combine all the required components
|
||||
and provide easy access to the generation API.
|
||||
"""
|
||||
|
||||
import os
|
||||
import typing as tp
|
||||
|
||||
import torch
|
||||
|
||||
from .encodec import CompressionModel
|
||||
from .lm import LMModel
|
||||
from .builders import get_debug_compression_model, get_debug_lm_model
|
||||
from .loaders import load_compression_model, load_lm_model, HF_MODEL_CHECKPOINTS_MAP
|
||||
from ..data.audio_utils import convert_audio
|
||||
from ..modules.conditioners import ConditioningAttributes, WavCondition
|
||||
from ..utils.autocast import TorchAutocast
|
||||
|
||||
|
||||
MelodyList = tp.List[tp.Optional[torch.Tensor]]
|
||||
MelodyType = tp.Union[torch.Tensor, MelodyList]
|
||||
|
||||
|
||||
class MusicGen:
|
||||
"""MusicGen main model with convenient generation API.
|
||||
|
||||
Args:
|
||||
name (str): name of the model.
|
||||
compression_model (CompressionModel): Compression model
|
||||
used to map audio to invertible discrete representations.
|
||||
lm (LMModel): Language model over discrete representations.
|
||||
"""
|
||||
def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
|
||||
max_duration: float = 30):
|
||||
self.name = name
|
||||
self.compression_model = compression_model
|
||||
self.lm = lm
|
||||
self.max_duration = max_duration
|
||||
self.device = next(iter(lm.parameters())).device
|
||||
self.generation_params: dict = {}
|
||||
self.set_generation_params(duration=15) # 15 seconds by default
|
||||
self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
|
||||
if self.device.type == 'cpu':
|
||||
self.autocast = TorchAutocast(enabled=False)
|
||||
else:
|
||||
self.autocast = TorchAutocast(
|
||||
enabled=True, device_type=self.device.type, dtype=torch.float16)
|
||||
|
||||
@property
|
||||
def frame_rate(self) -> int:
|
||||
"""Roughly the number of AR steps per seconds."""
|
||||
return self.compression_model.frame_rate
|
||||
|
||||
@property
|
||||
def sample_rate(self) -> int:
|
||||
"""Sample rate of the generated audio."""
|
||||
return self.compression_model.sample_rate
|
||||
|
||||
@property
|
||||
def audio_channels(self) -> int:
|
||||
"""Audio channels of the generated audio."""
|
||||
return self.compression_model.channels
|
||||
|
||||
@staticmethod
|
||||
def get_pretrained(name: str = 'melody', device=None):
|
||||
"""Return pretrained model, we provide four models:
|
||||
- small (300M), text to music, # see: https://huggingface.co/facebook/musicgen-small
|
||||
- medium (1.5B), text to music, # see: https://huggingface.co/facebook/musicgen-medium
|
||||
- melody (1.5B) text to music and text+melody to music, # see: https://huggingface.co/facebook/musicgen-melody
|
||||
- large (3.3B), text to music, # see: https://huggingface.co/facebook/musicgen-large
|
||||
"""
|
||||
|
||||
if device is None:
|
||||
if torch.cuda.device_count():
|
||||
device = 'cuda'
|
||||
else:
|
||||
device = 'cpu'
|
||||
|
||||
if name == 'debug':
|
||||
# used only for unit tests
|
||||
compression_model = get_debug_compression_model(device)
|
||||
lm = get_debug_lm_model(device)
|
||||
return MusicGen(name, compression_model, lm)
|
||||
|
||||
if name not in HF_MODEL_CHECKPOINTS_MAP:
|
||||
raise ValueError(
|
||||
f"{name} is not a valid checkpoint name. "
|
||||
f"Choose one of {', '.join(HF_MODEL_CHECKPOINTS_MAP.keys())}"
|
||||
)
|
||||
|
||||
cache_dir = os.environ.get('MUSICGEN_ROOT', None)
|
||||
compression_model = load_compression_model(name, device=device, cache_dir=cache_dir)
|
||||
lm = load_lm_model(name, device=device, cache_dir=cache_dir)
|
||||
if name == 'melody':
|
||||
lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True
|
||||
|
||||
return MusicGen(name, compression_model, lm)
|
||||
|
||||
def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
|
||||
top_p: float = 0.0, temperature: float = 1.0,
|
||||
duration: float = 30.0, cfg_coef: float = 3.0,
|
||||
two_step_cfg: bool = False, extend_stride: float = 18):
|
||||
"""Set the generation parameters for MusicGen.
|
||||
|
||||
Args:
|
||||
use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
|
||||
top_k (int, optional): top_k used for sampling. Defaults to 250.
|
||||
top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
|
||||
temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
|
||||
duration (float, optional): Duration of the generated waveform. Defaults to 30.0.
|
||||
cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
|
||||
two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
|
||||
instead of batching together the two. This has some impact on how things
|
||||
are padded but seems to have little impact in practice.
|
||||
extend_stride: when doing extended generation (i.e. more than 30 seconds), by how much
|
||||
should we extend the audio each time. Larger values will mean less context is
|
||||
preserved, and shorter value will require extra computations.
|
||||
"""
|
||||
assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
|
||||
self.extend_stride = extend_stride
|
||||
self.duration = duration
|
||||
self.generation_params = {
|
||||
'use_sampling': use_sampling,
|
||||
'temp': temperature,
|
||||
'top_k': top_k,
|
||||
'top_p': top_p,
|
||||
'cfg_coef': cfg_coef,
|
||||
'two_step_cfg': two_step_cfg,
|
||||
}
|
||||
|
||||
def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
|
||||
"""Override the default progress callback."""
|
||||
self._progress_callback = progress_callback
|
||||
|
||||
def generate_unconditional(self, num_samples: int, progress: bool = False) -> torch.Tensor:
|
||||
"""Generate samples in an unconditional manner.
|
||||
|
||||
Args:
|
||||
num_samples (int): Number of samples to be generated.
|
||||
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
||||
"""
|
||||
descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
|
||||
attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
|
||||
return self._generate_tokens(attributes, prompt_tokens, progress)
|
||||
|
||||
def generate(self, descriptions: tp.List[str], progress: bool = False) -> torch.Tensor:
|
||||
"""Generate samples conditioned on text.
|
||||
|
||||
Args:
|
||||
descriptions (tp.List[str]): A list of strings used as text conditioning.
|
||||
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
||||
"""
|
||||
attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
|
||||
assert prompt_tokens is None
|
||||
return self._generate_tokens(attributes, prompt_tokens, progress)
|
||||
|
||||
def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType,
|
||||
melody_sample_rate: int, progress: bool = False) -> torch.Tensor:
|
||||
"""Generate samples conditioned on text and melody.
|
||||
|
||||
Args:
|
||||
descriptions (tp.List[str]): A list of strings used as text conditioning.
|
||||
melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as
|
||||
melody conditioning. Should have shape [B, C, T] with B matching the description length,
|
||||
C=1 or 2. It can be [C, T] if there is a single description. It can also be
|
||||
a list of [C, T] tensors.
|
||||
melody_sample_rate: (int): Sample rate of the melody waveforms.
|
||||
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
||||
"""
|
||||
if isinstance(melody_wavs, torch.Tensor):
|
||||
if melody_wavs.dim() == 2:
|
||||
melody_wavs = melody_wavs[None]
|
||||
if melody_wavs.dim() != 3:
|
||||
raise ValueError("Melody wavs should have a shape [B, C, T].")
|
||||
melody_wavs = list(melody_wavs)
|
||||
else:
|
||||
for melody in melody_wavs:
|
||||
if melody is not None:
|
||||
assert melody.dim() == 2, "One melody in the list has the wrong number of dims."
|
||||
|
||||
melody_wavs = [
|
||||
convert_audio(wav, melody_sample_rate, self.sample_rate, self.audio_channels)
|
||||
if wav is not None else None
|
||||
for wav in melody_wavs]
|
||||
attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None,
|
||||
melody_wavs=melody_wavs)
|
||||
assert prompt_tokens is None
|
||||
return self._generate_tokens(attributes, prompt_tokens, progress)
|
||||
|
||||
def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
|
||||
descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
|
||||
progress: bool = False) -> torch.Tensor:
|
||||
"""Generate samples conditioned on audio prompts.
|
||||
|
||||
Args:
|
||||
prompt (torch.Tensor): A batch of waveforms used for continuation.
|
||||
Prompt should be [B, C, T], or [C, T] if only one sample is generated.
|
||||
prompt_sample_rate (int): Sampling rate of the given audio waveforms.
|
||||
descriptions (tp.List[str], optional): A list of strings used as text conditioning. Defaults to None.
|
||||
progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
|
||||
"""
|
||||
if prompt.dim() == 2:
|
||||
prompt = prompt[None]
|
||||
if prompt.dim() != 3:
|
||||
raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
|
||||
prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
|
||||
if descriptions is None:
|
||||
descriptions = [None] * len(prompt)
|
||||
attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
|
||||
assert prompt_tokens is not None
|
||||
return self._generate_tokens(attributes, prompt_tokens, progress)

    @torch.no_grad()
    def _prepare_tokens_and_attributes(
            self,
            descriptions: tp.Sequence[tp.Optional[str]],
            prompt: tp.Optional[torch.Tensor],
            melody_wavs: tp.Optional[MelodyList] = None,
    ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
        """Prepare model inputs.

        Args:
            descriptions (tp.List[str]): A list of strings used as text conditioning.
            prompt (torch.Tensor): A batch of waveforms used for continuation.
            melody_wavs (tp.Optional[torch.Tensor], optional): A batch of waveforms
                used as melody conditioning. Defaults to None.
        """
        attributes = [
            ConditioningAttributes(text={'description': description})
            for description in descriptions]

        if melody_wavs is None:
            for attr in attributes:
                attr.wav['self_wav'] = WavCondition(
                    torch.zeros((1, 1), device=self.device),
                    torch.tensor([0], device=self.device),
                    path='null_wav')  # type: ignore
        else:
            if self.name != "melody":
                raise RuntimeError("This model doesn't support melody conditioning. "
                                   "Use the `melody` model.")
            assert len(melody_wavs) == len(descriptions), \
                f"number of melody wavs must match number of descriptions! " \
                f"got melody len={len(melody_wavs)}, and descriptions len={len(descriptions)}"
            for attr, melody in zip(attributes, melody_wavs):
                if melody is None:
                    attr.wav['self_wav'] = WavCondition(
                        torch.zeros((1, 1), device=self.device),
                        torch.tensor([0], device=self.device),
                        path='null_wav')  # type: ignore
                else:
                    attr.wav['self_wav'] = WavCondition(
                        melody.to(device=self.device),
                        torch.tensor([melody.shape[-1]], device=self.device))

        if prompt is not None:
            if descriptions is not None:
                assert len(descriptions) == len(prompt), "Prompt and nb. descriptions don't match"
            prompt = prompt.to(self.device)
            prompt_tokens, scale = self.compression_model.encode(prompt)
            assert scale is None
        else:
            prompt_tokens = None
        return attributes, prompt_tokens

    def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
                         prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
        """Generate discrete audio tokens given audio prompt and/or conditions.

        Args:
            attributes (tp.List[ConditioningAttributes]): Conditions used for generation (text/melody).
            prompt_tokens (tp.Optional[torch.Tensor]): Audio prompt used for continuation.
            progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
        Returns:
            torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
        """
        total_gen_len = int(self.duration * self.frame_rate)
        max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
        current_gen_offset: int = 0

        def _progress_callback(generated_tokens: int, tokens_to_generate: int):
            generated_tokens += current_gen_offset
            if self._progress_callback is not None:
                # Note that total_gen_len might be quite wrong depending on the
                # codebook pattern used, but with delay it is almost accurate.
                self._progress_callback(generated_tokens, total_gen_len)
            else:
                print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')

        if prompt_tokens is not None:
            assert max_prompt_len >= prompt_tokens.shape[-1], \
                "Prompt is longer than audio to generate"

        callback = None
        if progress:
            callback = _progress_callback

        if self.duration <= self.max_duration:
            # generate by sampling from the LM, simple case.
            with self.autocast:
                gen_tokens = self.lm.generate(
                    prompt_tokens, attributes,
                    callback=callback, max_gen_len=total_gen_len, **self.generation_params)

        else:
            # now this gets a bit messier: we need to handle prompts,
            # melody conditioning etc.
            ref_wavs = [attr.wav['self_wav'] for attr in attributes]
            all_tokens = []
            if prompt_tokens is None:
                prompt_length = 0
            else:
                all_tokens.append(prompt_tokens)
                prompt_length = prompt_tokens.shape[-1]

            stride_tokens = int(self.frame_rate * self.extend_stride)

            while current_gen_offset + prompt_length < total_gen_len:
                time_offset = current_gen_offset / self.frame_rate
                chunk_duration = min(self.duration - time_offset, self.max_duration)
                max_gen_len = int(chunk_duration * self.frame_rate)
                for attr, ref_wav in zip(attributes, ref_wavs):
                    wav_length = ref_wav.length.item()
                    if wav_length == 0:
                        continue
                    # We will extend the wav periodically if it is not long enough.
                    # We have to do it here rather than in conditioners.py, as otherwise
                    # we wouldn't have the full wav.
                    initial_position = int(time_offset * self.sample_rate)
                    wav_target_length = int(self.max_duration * self.sample_rate)
                    positions = torch.arange(initial_position,
                                             initial_position + wav_target_length, device=self.device)
                    attr.wav['self_wav'] = WavCondition(
                        ref_wav[0][:, positions % wav_length],
                        torch.full_like(ref_wav[1], wav_target_length))
                with self.autocast:
                    gen_tokens = self.lm.generate(
                        prompt_tokens, attributes,
                        callback=callback, max_gen_len=max_gen_len, **self.generation_params)
                if prompt_tokens is None:
                    all_tokens.append(gen_tokens)
                else:
                    all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
                prompt_tokens = gen_tokens[:, :, stride_tokens:]
                prompt_length = prompt_tokens.shape[-1]
                current_gen_offset += stride_tokens

            gen_tokens = torch.cat(all_tokens, dim=-1)

        # generate audio
        assert gen_tokens.dim() == 3
        with torch.no_grad():
            gen_audio = self.compression_model.decode(gen_tokens, None)
        return gen_audio
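
The windowed loop above is easiest to check with concrete numbers. A worked sketch (editor's illustration; the values are hypothetical, not defaults from this commit):

frame_rate, duration, max_duration, extend_stride = 50, 45.0, 30.0, 18.0
total_gen_len = int(duration * frame_rate)       # 2250 tokens requested overall
stride_tokens = int(frame_rate * extend_stride)  # 900 fresh tokens per outer iteration
# 1st window: 30 s -> 1500 tokens; the last 600 tokens (12 s) are kept as the next prompt.
# 2nd window: min(45 - 18, 30) = 27 s -> 1350 tokens, 600 of which re-decode the prompt,
# so 750 new tokens are appended: 1500 + 750 = 2250 and the loop stops.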
@@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# flake8: noqa
from .conv import (
    NormConv1d,
    NormConv2d,
    NormConvTranspose1d,
    NormConvTranspose2d,
    StreamableConv1d,
    StreamableConvTranspose1d,
    pad_for_conv1d,
    pad1d,
    unpad1d,
)
from .lstm import StreamableLSTM
from .seanet import SEANetEncoder, SEANetDecoder
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch import Tensor
from typing import Union, Callable


class CustomGLU(nn.Module):
    """Custom Gated Linear Unit activation.
    Applies a modified gated linear unit :math:`a * f(b)` where :math:`a` is the first half
    of the input matrices, :math:`b` is the second half, and :math:`f` is a provided activation
    function (i.e. sigmoid, swish, etc.).

    Args:
        activation (nn.Module): The custom activation to apply in the Gated Linear Unit.
        dim (int): the dimension on which to split the input. Default: -1

    Shape:
        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means any number of additional
          dimensions
        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`

    Examples::
        >>> m = CustomGLU(nn.Sigmoid())
        >>> input = torch.randn(4, 2)
        >>> output = m(input)
    """
    def __init__(self, activation: nn.Module, dim: int = -1):
        super(CustomGLU, self).__init__()
        self.dim = dim
        self.activation = activation

    def forward(self, x: Tensor):
        assert x.shape[self.dim] % 2 == 0  # M = N / 2
        a, b = torch.chunk(x, 2, dim=self.dim)
        return a * self.activation(b)


class SwiGLU(CustomGLU):
    """SiLU Gated Linear Unit activation.
    Applies SiLU Gated Linear Unit :math:`a * SiLU(b)` where :math:`a` is
    the first half of the input matrices, :math:`b` is the second half.

    Args:
        dim (int): the dimension on which to split the input. Default: -1
    """
    def __init__(self, dim: int = -1):
        super(SwiGLU, self).__init__(nn.SiLU(), dim)


class GeGLU(CustomGLU):
    """GeLU Gated Linear Unit activation.
    Applies GeLU Gated Linear Unit :math:`a * GELU(b)` where :math:`a` is
    the first half of the input matrices, :math:`b` is the second half.

    Args:
        dim (int): the dimension on which to split the input. Default: -1
    """
    def __init__(self, dim: int = -1):
        super(GeGLU, self).__init__(nn.GELU(), dim)


class ReGLU(CustomGLU):
    """ReLU Gated Linear Unit activation.
    Applies ReLU Gated Linear Unit :math:`a * ReLU(b)` where :math:`a` is
    the first half of the input matrices, :math:`b` is the second half.

    Args:
        dim (int): the dimension on which to split the input. Default: -1
    """
    def __init__(self, dim: int = -1):
        super(ReGLU, self).__init__(nn.ReLU(), dim)


def get_activation_fn(
    activation: Union[str, Callable[[Tensor], Tensor]]
) -> Union[str, Callable[[Tensor], Tensor]]:
    """Helper function to map an activation string to the activation class.
    If the supplied activation is not a string that is recognized, the activation is passed back.

    Args:
        activation (Union[str, Callable[[Tensor], Tensor]]): Activation to check
    """
    if isinstance(activation, str):
        if activation == "reglu":
            return ReGLU()
        elif activation == "geglu":
            return GeGLU()
        elif activation == "swiglu":
            return SwiGLU()
    return activation
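
A quick sanity check of this mapping (editor's illustration):

glu = get_activation_fn("swiglu")        # returns SwiGLU()
x = torch.randn(4, 8)
y = glu(x)                               # last dim is split in half: [4, 8] -> [4, 4]
assert y.shape == (4, 4)
relu = get_activation_fn("relu")         # unrecognized strings are passed back unchanged
assert relu == "relu"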
@@ -0,0 +1,539 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from collections import namedtuple
from dataclasses import dataclass
from functools import lru_cache
import logging
import typing as tp

from abc import ABC, abstractmethod
import torch

LayoutCoord = namedtuple('LayoutCoord', ['t', 'q'])  # (timestep, codebook index)
PatternLayout = tp.List[tp.List[LayoutCoord]]  # Sequence of coordinates
logger = logging.getLogger(__name__)


@dataclass
class Pattern:
    """Base implementation of a pattern over a sequence with multiple codebooks.

    The codebook pattern consists in a layout, defining for each sequence step
    the list of coordinates of each codebook timestep in the resulting interleaved sequence.
    The first item of the pattern is always an empty list in order to properly insert a special token
    to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
    and ``timesteps`` the number of timesteps corresponding to the original sequence.

    The pattern provides convenient methods to build and revert interleaved sequences from it:
    ``build_pattern_sequence`` maps a dense input tensor of multi-codebook sequences from [B, K, T]
    to the interleaved sequence of shape [B, K, S] applying the pattern, with B being the batch size,
    K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
    for the output sequence. The unfilled positions are replaced with a special token and the built sequence
    is returned along with a mask indicating valid tokens.
    ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
    of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
    to fill and specify invalid positions if needed.
    See the dedicated methods for more details.
    """
    # Pattern layout, for each sequence step, we have a list of coordinates
    # corresponding to the original codebook timestep and position.
    # The first list is always an empty list in order to properly insert
    # a special token to start with.
    layout: PatternLayout
    timesteps: int
    n_q: int

    def __post_init__(self):
        assert len(self.layout) > 0
        assert self.layout[0] == []
        self._validate_layout()
        self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
        self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
        logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))

    def _validate_layout(self):
        """Runs checks on the layout to ensure a valid pattern is defined.
        A pattern is considered invalid if:
            - Multiple timesteps for a same codebook are defined in the same sequence step
            - The timesteps for a given codebook are not in ascending order as we advance in the sequence
              (this would mean that we have future timesteps before past timesteps).
        """
        q_timesteps = {q: 0 for q in range(self.n_q)}
        for s, seq_coords in enumerate(self.layout):
            if len(seq_coords) > 0:
                qs = set()
                for coord in seq_coords:
                    qs.add(coord.q)
                    last_q_timestep = q_timesteps[coord.q]
                    assert coord.t >= last_q_timestep, \
                        f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
                    q_timesteps[coord.q] = coord.t
                # each sequence step contains at max 1 coordinate per codebook
                assert len(qs) == len(seq_coords), \
                    f"Multiple entries for a same codebook are found at step {s}"

    @property
    def num_sequence_steps(self):
        return len(self.layout) - 1

    @property
    def max_delay(self):
        max_t_in_seq_coords = 0
        for seq_coords in self.layout[1:]:
            for coords in seq_coords:
                max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
        return max_t_in_seq_coords - self.timesteps

    @property
    def valid_layout(self):
        valid_step = len(self.layout) - self.max_delay
        return self.layout[:valid_step]

    def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
        """Get codebook coordinates in the layout that correspond to the specified timestep t
        and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
        and the actual codebook coordinates.
        """
        assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
        if q is not None:
            assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
        coords = []
        for s, seq_codes in enumerate(self.layout):
            for code in seq_codes:
                if code.t == t and (q is None or code.q == q):
                    coords.append((s, code))
        return coords

    def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
        return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]

    def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
        steps_with_timesteps = self.get_steps_with_timestep(t, q)
        return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None

    def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
                                                device: tp.Union[torch.device, str] = 'cpu'):
        """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.

        Args:
            timesteps (int): Maximum number of timesteps to consider.
            keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
            device (Union[torch.device, str]): Device for created tensors.
        Returns:
            indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
        """
        assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
        assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
        # use the proper layout based on whether we limit ourselves to valid steps only or not,
        # note that using the valid_layout will result in a truncated sequence up to the valid steps
        ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
        # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
        indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
        mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
        # fill indexes with last sequence step value that will correspond to our special token
        # the last value is n_q * timesteps as we have flattened z and append special token as the last token
        # which will correspond to the index: n_q * timesteps
        indexes[:] = n_q * timesteps
        # iterate over the pattern and fill scattered indexes and mask
        for s, sequence_coords in enumerate(ref_layout):
            for coords in sequence_coords:
                if coords.t < timesteps:
                    indexes[coords.q, s] = coords.t + coords.q * timesteps
                    mask[coords.q, s] = 1
        indexes = torch.from_numpy(indexes).to(device)
        mask = torch.from_numpy(mask).to(device)
        return indexes, mask

    def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
        """Build sequence corresponding to the pattern from the input tensor z.
        The sequence is built using up to sequence_steps if specified, and non-pattern
        coordinates are filled with the special token.

        Args:
            z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
            special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
            keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
                Steps that are beyond valid steps will be replaced by the special_token in that case.
        Returns:
            values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
                corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
            indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
        """
        B, K, T = z.shape
        indexes, mask = self._build_pattern_sequence_scatter_indexes(
            T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
        )
        z = z.view(B, -1)
        # we append the special token as the last index of our flattened z tensor
        z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
        values = z[:, indexes.view(-1)]
        values = values.view(B, K, indexes.shape[-1])
        return values, indexes, mask

    def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
                                                 keep_only_valid_steps: bool = False,
                                                 is_model_output: bool = False,
                                                 device: tp.Union[torch.device, str] = 'cpu'):
        """Builds scatter indexes required to retrieve the original multi-codebook sequence
        from the interleaving pattern.

        Args:
            sequence_steps (int): Sequence steps.
            n_q (int): Number of codebooks.
            keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
                Steps that are beyond valid steps will be replaced by the special_token in that case.
            is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
            device (Union[torch.device, str]): Device for created tensors.
        Returns:
            indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
        """
        ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
        # TODO(jade): Do we want to further truncate to only valid timesteps here as well?
        timesteps = self.timesteps
        assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
        assert sequence_steps <= len(ref_layout), \
            f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"

        # ensure we take the appropriate indexes to keep the model output from the first special token as well
        if is_model_output:
            ref_layout = ref_layout[1:]

        # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
        indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
        mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
        # fill indexes with last sequence step value that will correspond to our special token
        indexes[:] = n_q * sequence_steps
        for s, sequence_codes in enumerate(ref_layout):
            if s < sequence_steps:
                for code in sequence_codes:
                    if code.t < timesteps:
                        indexes[code.q, code.t] = s + code.q * sequence_steps
                        mask[code.q, code.t] = 1
        indexes = torch.from_numpy(indexes).to(device)
        mask = torch.from_numpy(mask).to(device)
        return indexes, mask

    def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
        """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
        The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
        are filled with the special token.

        Args:
            s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
            special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
        Returns:
            values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
                corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
            indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
        """
        B, K, S = s.shape
        indexes, mask = self._build_reverted_sequence_scatter_indexes(
            S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
        )
        s = s.view(B, -1)
        # we append the special token as the last index of our flattened sequence tensor
        s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
        values = s[:, indexes.view(-1)]
        values = values.view(B, K, indexes.shape[-1])
        return values, indexes, mask

    def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
        """Revert model logits obtained on a sequence built from the pattern
        back to a tensor matching the original sequence.

        This method is similar to ``revert_pattern_sequence`` with the following specificities:
        1. It is designed to work with the extra cardinality dimension.
        2. We return the logits for the first sequence item that matches the special_token and
           whose matching target in the original sequence is the first item of the sequence,
           while we skip the last logits as there is no matching target.
        """
        B, card, K, S = logits.shape
        indexes, mask = self._build_reverted_sequence_scatter_indexes(
            S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
        )
        logits = logits.reshape(B, card, -1)
        # we append the special token as the last index of our flattened logits tensor
        logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1)  # [B, card, K x S]
        values = logits[:, :, indexes.view(-1)]
        values = values.view(B, card, K, indexes.shape[-1])
        return values, indexes, mask


class CodebooksPatternProvider(ABC):
    """Abstraction around providing patterns for interleaving codebooks.

    The CodebooksPatternProvider abstraction allows implementing various strategies to
    define interleaving patterns of sequences composed of multiple codebooks. For a given
    number of codebooks `n_q`, the pattern provider can generate a specified pattern
    corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
    can be used to construct a new sequence from the original codes respecting the specified
    pattern. The pattern is defined as a list of list of code coordinates, a code coordinate
    being a tuple with the original timestep and codebook to build the new sequence.
    Note that all patterns must start with an empty list that is then used to insert a first
    sequence step of special tokens in the newly generated sequence.

    Args:
        n_q (int): number of codebooks.
        cached (bool): if True, patterns for a given length are cached. In general
            that should be true for efficiency reasons to avoid synchronization points.
    """
    def __init__(self, n_q: int, cached: bool = True):
        assert n_q > 0
        self.n_q = n_q
        self.get_pattern = lru_cache(100)(self.get_pattern)  # type: ignore

    @abstractmethod
    def get_pattern(self, timesteps: int) -> Pattern:
        """Builds pattern with specific interleaving between codebooks.

        Args:
            timesteps (int): Total number of timesteps.
        """
        raise NotImplementedError()


class DelayedPatternProvider(CodebooksPatternProvider):
    """Provider for delayed pattern across delayed codebooks.
    Codebooks are delayed in the sequence and sequence steps will contain codebooks
    from different timesteps.

    Example:
        Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
        [[1, 2, 3, 4],
         [1, 2, 3, 4],
         [1, 2, 3, 4]]
        The resulting sequence obtained from the returned pattern is:
        [[S, 1, 2, 3, 4],
         [S, S, 1, 2, 3],
         [S, S, S, 1, 2]]
        (with S being a special token)

    Args:
        n_q (int): Number of codebooks.
        delays (Optional[List[int]]): Delay for each of the codebooks.
            If delays are not defined, each codebook is delayed by 1 compared to the previous one.
        flatten_first (int): Flatten the first N timesteps.
        empty_initial (int): Prepend with N empty lists of coordinates.
    """
    def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
                 flatten_first: int = 0, empty_initial: int = 0):
        super().__init__(n_q)
        if delays is None:
            delays = list(range(n_q))
        self.delays = delays
        self.flatten_first = flatten_first
        self.empty_initial = empty_initial
        assert len(self.delays) == self.n_q
        assert sorted(self.delays) == self.delays

    def get_pattern(self, timesteps: int) -> Pattern:
        out: PatternLayout = [[]]
        max_delay = max(self.delays)
        if self.empty_initial:
            out += [[] for _ in range(self.empty_initial)]
        if self.flatten_first:
            for t in range(min(timesteps, self.flatten_first)):
                for q in range(self.n_q):
                    out.append([LayoutCoord(t, q)])
        for t in range(self.flatten_first, timesteps + max_delay):
            v = []
            for q, delay in enumerate(self.delays):
                t_for_q = t - delay
                if t_for_q >= self.flatten_first:
                    v.append(LayoutCoord(t_for_q, q))
            out.append(v)
        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
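
The build/revert machinery is easiest to see on a tiny round trip. A minimal sketch (editor's illustration, assuming the classes in this file are importable):

import torch

provider = DelayedPatternProvider(n_q=3)
pattern = provider.get_pattern(timesteps=4)
z = torch.arange(1, 13).view(1, 3, 4)                      # [B=1, K=3, T=4]
seq, idx, mask = pattern.build_pattern_sequence(z, special_token=0)
print(seq.shape)                                           # [1, 3, 7]: 1 start step + T + max_delay
z_back, _, rev_mask = pattern.revert_pattern_sequence(seq, special_token=0)
assert torch.equal(z_back, z)                              # every [K, T] position is valid here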


class ParallelPatternProvider(DelayedPatternProvider):
    """Provider for parallel pattern across codebooks.
    This pattern provider is a special case of the delayed pattern with actually no delay,
    hence delays=repeat(0, n_q).

    Args:
        n_q (int): Number of codebooks.
    """
    def __init__(self, n_q: int):
        super().__init__(n_q, [0] * n_q)


class UnrolledPatternProvider(CodebooksPatternProvider):
    """Provider for unrolling codebooks pattern.
    This pattern provider enables representing the codebooks either fully flattened or flattened
    only to some extent, while also specifying a given delay between the flattened codebook
    representations, allowing to unroll the codebooks in the sequence.

    Example:
        1. Flattening of the codebooks.
        By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q),
        taking n_q = 3 and timesteps = 4:
        [[1, 2, 3, 4],
         [1, 2, 3, 4],
         [1, 2, 3, 4]]
        will result into:
        [[S, S, 1, S, S, 2, S, S, 3, S, S, 4],
         [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
         [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
        2. Partial flattening of the codebooks. The ``flattening`` parameter allows to specify the inner step
        for each of the codebooks, allowing to define which codebooks to flatten (or keep in parallel), for example
        taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]:
        [[1, 2, 3, 4],
         [1, 2, 3, 4],
         [1, 2, 3, 4]]
        will result into:
        [[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
         [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
         [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
        3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks,
        allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the
        same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1]
        and delays = [0, 3, 3]:
        [[1, 2, 3, 4],
         [1, 2, 3, 4],
         [1, 2, 3, 4]]
        will result into:
        [[S, S, S, 1, S, 2, S, 3, S, 4],
         [S, S, S, 1, S, 2, S, 3, S, 4],
         [1, 2, 3, S, 4, S, 5, S, 6, S]]

    Args:
        n_q (int): Number of codebooks.
        flattening (Optional[List[int]]): Flattening schema over the codebooks. If not defined,
            the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
            have n_q extra steps for each timestep.
        delays (Optional[List[int]]): Delay for each of the codebooks. If not defined,
            no delay is added and therefore will default to [0] * ``n_q``.
            Note that two codebooks that will be flattened to the same inner step
            should have the same delay, otherwise the pattern is considered as invalid.
    """
    FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])

    def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
                 delays: tp.Optional[tp.List[int]] = None):
        super().__init__(n_q)
        if flattening is None:
            flattening = list(range(n_q))
        if delays is None:
            delays = [0] * n_q
        assert len(flattening) == n_q
        assert len(delays) == n_q
        assert sorted(flattening) == flattening
        assert sorted(delays) == delays
        self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
        self.max_delay = max(delays)

    def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
        """Build a flattened codebooks representation as a dictionary of inner step
        and the actual codebook indices corresponding to the flattened codebook. For convenience, we
        also store the delay associated to the flattened codebook to avoid maintaining an extra mapping.
        """
        flattened_codebooks: dict = {}
        for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
            if inner_step not in flattened_codebooks:
                flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
            else:
                flat_codebook = flattened_codebooks[inner_step]
                assert flat_codebook.delay == delay, (
                    "Delay and flattening between codebooks is inconsistent: ",
                    "two codebooks flattened to the same position should have the same delay."
                )
                flat_codebook.codebooks.append(q)
            flattened_codebooks[inner_step] = flat_codebook
        return flattened_codebooks

    @property
    def _num_inner_steps(self):
        """Number of inner steps to unroll between timesteps in order to flatten the codebooks.
        """
        return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1

    def num_virtual_steps(self, timesteps: int) -> int:
        return timesteps * self._num_inner_steps + 1

    def get_pattern(self, timesteps: int) -> Pattern:
        """Builds pattern for delay across codebooks.

        Args:
            timesteps (int): Total number of timesteps.
        """
        # the PatternLayout is built as a tuple of sequence position and list of coordinates
        # so that it can be reordered properly given the required delay between codebooks of given timesteps
        indexed_out: list = [(-1, [])]
        max_timesteps = timesteps + self.max_delay
        for t in range(max_timesteps):
            # for each timestep, we unroll the flattened codebooks,
            # emitting the sequence step with the corresponding delay
            for step in range(self._num_inner_steps):
                if step in self._flattened_codebooks:
                    # we have codebooks at this virtual step to emit
                    step_codebooks = self._flattened_codebooks[step]
                    t_for_q = t + step_codebooks.delay
                    coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
                    if t_for_q < max_timesteps and t < max_timesteps:
                        indexed_out.append((t_for_q, coords))
                else:
                    # there is no codebook in this virtual step so we emit an empty list
                    indexed_out.append((t, []))
        out = [coords for _, coords in sorted(indexed_out)]
        return Pattern(out, n_q=self.n_q, timesteps=timesteps)


class VALLEPattern(CodebooksPatternProvider):
    """Almost VALL-E style pattern. We further allow some delays for the
    codebooks other than the first one.

    Args:
        n_q (int): Number of codebooks.
        delays (Optional[List[int]]): Delay for each of the codebooks.
            If delays are not defined, each codebook is delayed by 1 compared to the previous one.
    """
    def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
        super().__init__(n_q)
        if delays is None:
            delays = [0] * (n_q - 1)
        self.delays = delays
        assert len(self.delays) == self.n_q - 1
        assert sorted(self.delays) == self.delays

    def get_pattern(self, timesteps: int) -> Pattern:
        out: PatternLayout = [[]]
        for t in range(timesteps):
            out.append([LayoutCoord(t, 0)])
        max_delay = max(self.delays)
        for t in range(timesteps + max_delay):
            v = []
            for q, delay in enumerate(self.delays):
                t_for_q = t - delay
                if t_for_q >= 0:
                    v.append(LayoutCoord(t_for_q, q + 1))
            out.append(v)
        return Pattern(out, n_q=self.n_q, timesteps=timesteps)


class MusicLMPattern(CodebooksPatternProvider):
    """Almost MusicLM style pattern. This is equivalent to full flattening
    but in a different order.

    Args:
        n_q (int): Number of codebooks.
        group_by (int): Number of codebooks to group together.
    """
    def __init__(self, n_q: int, group_by: int = 2):
        super().__init__(n_q)
        self.group_by = group_by

    def get_pattern(self, timesteps: int) -> Pattern:
        out: PatternLayout = [[]]
        for offset in range(0, self.n_q, self.group_by):
            for t in range(timesteps):
                for q in range(offset, offset + self.group_by):
                    out.append([LayoutCoord(t, q)])
        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
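
The providers mainly differ in how long the interleaved sequence becomes for the same input. A small comparison sketch (editor's illustration):

T, K = 4, 4
for provider in [ParallelPatternProvider(K), DelayedPatternProvider(K), MusicLMPattern(K, group_by=2)]:
    pattern = provider.get_pattern(T)
    print(type(provider).__name__, pattern.num_sequence_steps)
# Parallel keeps S == T (4 steps), delayed pays max_delay extra steps (7),
# and MusicLM-style grouping is fully flattened: K * T == 16 steps.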
@@ -0,0 +1,990 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass, field
from itertools import chain
import logging
import math
import random
import re
import typing as tp
import warnings

from einops import rearrange
from num2words import num2words
import spacy
from transformers import T5EncoderModel, T5Tokenizer  # type: ignore
import torchaudio
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from .streaming import StreamingModule
from .transformer import create_sin_embedding
from ..data.audio_dataset import SegmentInfo
from ..utils.autocast import TorchAutocast
from ..utils.utils import hash_trick, length_to_mask, collate


logger = logging.getLogger(__name__)
TextCondition = tp.Optional[str]  # a text condition can be a string or None (if it doesn't exist)
ConditionType = tp.Tuple[Tensor, Tensor]  # condition, mask


class WavCondition(tp.NamedTuple):
    wav: Tensor
    length: Tensor
    path: tp.List[tp.Optional[str]] = []


def nullify_condition(condition: ConditionType, dim: int = 1):
    """This function transforms an input condition to a null condition.
    This is done by converting it to a single zero vector, similarly
    to how it is done inside WhiteSpaceTokenizer and NoopTokenizer.

    Args:
        condition (ConditionType): a tuple of condition and mask (tp.Tuple[Tensor, Tensor])
        dim (int): the dimension that will be truncated (should be the time dimension)
            WARNING!: dim should not be the batch dimension!
    Returns:
        ConditionType: a tuple of null condition and mask
    """
    assert dim != 0, "dim cannot be the batch dimension!"
    assert type(condition) == tuple and \
        type(condition[0]) == Tensor and \
        type(condition[1]) == Tensor, "'nullify_condition' got an unexpected input type!"
    cond, mask = condition
    B = cond.shape[0]
    last_dim = cond.dim() - 1
    out = cond.transpose(dim, last_dim)
    out = 0. * out[..., :1]
    out = out.transpose(dim, last_dim)
    mask = torch.zeros((B, 1), device=out.device).int()
    assert cond.dim() == out.dim()
    return out, mask


def nullify_wav(wav: Tensor) -> WavCondition:
    """Create a nullified WavCondition from a wav tensor with appropriate shape.

    Args:
        wav (Tensor): tensor of shape [B, T]
    Returns:
        WavCondition: wav condition with nullified wav.
    """
    null_wav, _ = nullify_condition((wav, torch.zeros_like(wav)), dim=wav.dim() - 1)
    return WavCondition(
        wav=null_wav,
        length=torch.tensor([0] * wav.shape[0], device=wav.device),
        path=['null_wav'] * wav.shape[0]
    )
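
A small sketch of what the null condition looks like (editor's illustration):

wav = torch.randn(2, 32000)              # [B, T]
null_cond = nullify_wav(wav)
print(null_cond.wav.shape)               # torch.Size([2, 1]): a single zero frame per item
print(null_cond.length)                  # tensor([0, 0]): zero length marks "no wav"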


@dataclass
class ConditioningAttributes:
    text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
    wav: tp.Dict[str, WavCondition] = field(default_factory=dict)

    def __getitem__(self, item):
        return getattr(self, item)

    @property
    def text_attributes(self):
        return self.text.keys()

    @property
    def wav_attributes(self):
        return self.wav.keys()

    @property
    def attributes(self):
        return {"text": self.text_attributes, "wav": self.wav_attributes}

    def to_flat_dict(self):
        return {
            **{f"text.{k}": v for k, v in self.text.items()},
            **{f"wav.{k}": v for k, v in self.wav.items()},
        }

    @classmethod
    def from_flat_dict(cls, x):
        out = cls()
        for k, v in x.items():
            kind, att = k.split(".")
            out[kind][att] = v
        return out


class SegmentWithAttributes(SegmentInfo):
    """Base class for all dataclasses that are used for conditioning.
    All child classes should implement `to_condition_attributes` that converts
    the existing attributes to a dataclass of type ConditioningAttributes.
    """
    def to_condition_attributes(self) -> ConditioningAttributes:
        raise NotImplementedError()


class Tokenizer:
    """Base class for all tokenizers
    (in case we want to introduce more advanced tokenizers in the future).
    """
    def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[Tensor, Tensor]:
        raise NotImplementedError()


class WhiteSpaceTokenizer(Tokenizer):
    """This tokenizer should be used for natural language descriptions.
    For example:
        ["he didn't, know he's going home.", 'shorter sentence'] =>
        [[78, 62, 31, 4, 78, 25, 19, 34],
         [59, 77, 0, 0, 0, 0, 0, 0]]
    """
    PUNCTUATIONS = "?:!.,;"

    def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm",
                 lemma: bool = True, stopwords: bool = True) -> None:
        self.n_bins = n_bins
        self.pad_idx = pad_idx
        self.lemma = lemma
        self.stopwords = stopwords
        try:
            self.nlp = spacy.load(language)
        except IOError:
            spacy.cli.download(language)  # type: ignore
            self.nlp = spacy.load(language)

    @tp.no_type_check
    def __call__(
        self,
        texts: tp.List[tp.Optional[str]],
        return_text: bool = False
    ) -> tp.Tuple[Tensor, Tensor]:
        """Take a list of strings and convert them to a tensor of indices.

        Args:
            texts (tp.List[str]): List of strings.
            return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False.
        Returns:
            tp.Tuple[Tensor, Tensor]:
                - Indices of words in the LUT.
                - And a mask indicating where the padding tokens are.
        """
        output, lengths = [], []
        texts = deepcopy(texts)
        for i, text in enumerate(texts):
            # if current sample doesn't have a certain attribute, replace with pad token
            if text is None:
                output.append(Tensor([self.pad_idx]))
                lengths.append(0)
                continue

            # convert numbers to words
            text = re.sub(r"(\d+)", lambda x: num2words(int(x.group(0))), text)  # type: ignore
            # normalize text
            text = self.nlp(text)  # type: ignore
            # remove stopwords
            if self.stopwords:
                text = [w for w in text if not w.is_stop]  # type: ignore
            # remove punctuation
            text = [w for w in text if w.text not in self.PUNCTUATIONS]  # type: ignore
            # lemmatize if needed
            text = [getattr(t, "lemma_" if self.lemma else "text") for t in text]  # type: ignore

            texts[i] = " ".join(text)
            lengths.append(len(text))
            # convert to tensor
            tokens = Tensor([hash_trick(w, self.n_bins) for w in text])
            output.append(tokens)

        mask = length_to_mask(torch.IntTensor(lengths)).int()
        padded_output = pad_sequence(output, padding_value=self.pad_idx).int().t()
        if return_text:
            return padded_output, mask, texts  # type: ignore
        return padded_output, mask
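
For intuition, a usage sketch (editor's illustration; this needs the en_core_web_sm spaCy model, which the constructor downloads if missing):

tokenizer = WhiteSpaceTokenizer(n_bins=1000)
tokens, mask = tokenizer(["happy rock song", None])
print(tokens.shape)   # [2, max_words]: hashed word indices, pad_idx elsewhere
print(mask)           # 1 for real words, 0 for padding / missing entries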


class NoopTokenizer(Tokenizer):
    """This tokenizer should be used for global conditioners such as: artist, genre, key, etc.
    The difference between this and WhiteSpaceTokenizer is that NoopTokenizer does not split
    strings, so "Jeff Buckley" will get its own index. Whereas WhiteSpaceTokenizer will
    split it to ["Jeff", "Buckley"] and return an index per word.

    For example:
        ["Queen", "ABBA", "Jeff Buckley"] => [43, 55, 101]
        ["Metal", "Rock", "Classical"] => [0, 223, 51]
    """
    def __init__(self, n_bins: int, pad_idx: int = 0):
        self.n_bins = n_bins
        self.pad_idx = pad_idx

    def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[Tensor, Tensor]:
        output, lengths = [], []
        for text in texts:
            # if current sample doesn't have a certain attribute, replace with pad token
            if text is None:
                output.append(self.pad_idx)
                lengths.append(0)
            else:
                output.append(hash_trick(text, self.n_bins))
                lengths.append(1)

        tokens = torch.LongTensor(output).unsqueeze(1)
        mask = length_to_mask(torch.IntTensor(lengths)).int()
        return tokens, mask


class BaseConditioner(nn.Module):
    """Base model for all conditioner modules. We allow the output dim to be different
    than the hidden dim for two reasons: 1) keep our LUTs small when the vocab is large;
    2) make all condition dims consistent.

    Args:
        dim (int): Hidden dim of the model (text-encoder/LUT).
        output_dim (int): Output dim of the conditioner.
    """
    def __init__(self, dim, output_dim):
        super().__init__()
        self.dim = dim
        self.output_dim = output_dim
        self.output_proj = nn.Linear(dim, output_dim)

    def tokenize(self, *args, **kwargs) -> tp.Any:
        """Should be any part of the processing that will lead to a synchronization
        point, e.g. BPE tokenization with transfer to the GPU.

        The returned value will be saved and returned later when calling forward().
        """
        raise NotImplementedError()

    def forward(self, inputs: tp.Any) -> ConditionType:
        """Gets input that should be used as conditioning (e.g., genre, description or a waveform).
        Outputs a ConditionType, after the input data was embedded as a dense vector.

        Returns:
            ConditionType:
                - A tensor of size [B, T, D] where B is the batch size, T is the length of the
                  output embedding and D is the dimension of the embedding.
                - And a mask indicating where the padding tokens are.
        """
        raise NotImplementedError()


class TextConditioner(BaseConditioner):
    ...


class LUTConditioner(TextConditioner):
    """Lookup table TextConditioner.

    Args:
        n_bins (int): Number of bins.
        dim (int): Hidden dim of the model (text-encoder/LUT).
        output_dim (int): Output dim of the conditioner.
        tokenizer (str): Name of the tokenizer.
        pad_idx (int, optional): Index for padding token. Defaults to 0.
    """
    def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: str, pad_idx: int = 0):
        super().__init__(dim, output_dim)
        self.embed = nn.Embedding(n_bins, dim)
        self.tokenizer: Tokenizer
        if tokenizer == "whitespace":
            self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx)
        elif tokenizer == "noop":
            self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx)
        else:
            raise ValueError(f"unrecognized tokenizer `{tokenizer}`.")

    def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        device = self.embed.weight.device
        tokens, mask = self.tokenizer(x)
        tokens, mask = tokens.to(device), mask.to(device)
        return tokens, mask

    def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> ConditionType:
        tokens, mask = inputs
        embeds = self.embed(tokens)
        embeds = self.output_proj(embeds)
        embeds = (embeds * mask.unsqueeze(-1))
        return embeds, mask
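
End to end, the LUT path is short; a sketch (editor's illustration, using the 'noop' tokenizer so each string maps to a single index):

cond = LUTConditioner(n_bins=1000, dim=64, output_dim=128, tokenizer="noop")
inputs = cond.tokenize(["rock", None])   # tokens + mask, moved to the embedding's device
embeds, mask = cond(inputs)
print(embeds.shape)                      # [2, 1, 128]; masked entries are zeroed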


class T5Conditioner(TextConditioner):
    """T5-based TextConditioner.

    Args:
        name (str): Name of the T5 model.
        output_dim (int): Output dim of the conditioner.
        finetune (bool): Whether to fine-tune T5 at train time.
        device (str): Device for the T5 Conditioner.
        autocast_dtype (tp.Optional[str], optional): Autocast dtype.
        word_dropout (float, optional): Word dropout probability.
        normalize_text (bool, optional): Whether to apply text normalization.
    """
    MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
              "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
              "google/flan-t5-xl", "google/flan-t5-xxl"]
    # keys kept in sync with MODELS (the original listed flan-t5-3b/11b here,
    # which would KeyError for the xl/xxl names allowed above)
    MODELS_DIMS = {
        "t5-small": 512,
        "t5-base": 768,
        "t5-large": 1024,
        "t5-3b": 1024,
        "t5-11b": 1024,
        "google/flan-t5-small": 512,
        "google/flan-t5-base": 768,
        "google/flan-t5-large": 1024,
        "google/flan-t5-xl": 2048,
        "google/flan-t5-xxl": 4096,
    }

    def __init__(self, name: str, output_dim: int, finetune: bool, device: str,
                 autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0.,
                 normalize_text: bool = False):
        assert name in self.MODELS, f"unrecognized t5 model name (should be in {self.MODELS})"
        super().__init__(self.MODELS_DIMS[name], output_dim)
        self.device = device
        self.name = name
        self.finetune = finetune
        self.word_dropout = word_dropout

        if autocast_dtype is None or self.device == 'cpu':
            self.autocast = TorchAutocast(enabled=False)
            if self.device != 'cpu':
                logger.warning("T5 has no autocast, this might lead to NaN")
        else:
            dtype = getattr(torch, autocast_dtype)
            assert isinstance(dtype, torch.dtype)
            logger.info(f"T5 will be evaluated with autocast as {autocast_dtype}")
            self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
        # Let's disable logging temporarily because T5 will vomit some errors otherwise.
        # thanks https://gist.github.com/simon-weber/7853144
        previous_level = logging.root.manager.disable
        logging.disable(logging.ERROR)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
                t5 = T5EncoderModel.from_pretrained(name).train(mode=finetune)
            finally:
                logging.disable(previous_level)
        if finetune:
            self.t5 = t5
        else:
            # this makes sure that the t5 model is not part
            # of the saved checkpoint
            self.__dict__["t5"] = t5.to(device)

        self.normalize_text = normalize_text
        if normalize_text:
            self.text_normalizer = WhiteSpaceTokenizer(1, lemma=True, stopwords=True)

    def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
        # if current sample doesn't have a certain attribute, replace with empty string
        entries: tp.List[str] = [xi if xi is not None else "" for xi in x]
        if self.normalize_text:
            _, _, entries = self.text_normalizer(entries, return_text=True)
        if self.word_dropout > 0. and self.training:
            new_entries = []
            for entry in entries:
                words = [word for word in entry.split(" ") if random.random() >= self.word_dropout]
                new_entries.append(" ".join(words))
            entries = new_entries

        empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""])

        inputs = self.t5_tokenizer(entries, return_tensors="pt", padding=True).to(self.device)
        mask = inputs["attention_mask"]
        mask[empty_idx, :] = 0  # zero-out indices where the input is non-existent
        return inputs

    def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
        mask = inputs["attention_mask"]
        with torch.set_grad_enabled(self.finetune), self.autocast:
            embeds = self.t5(**inputs).last_hidden_state
        embeds = self.output_proj(embeds.to(self.output_proj.weight))
        embeds = (embeds * mask.unsqueeze(-1))
        return embeds, mask
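
The T5 path mirrors the LUT one, with tokenize() doing the synchronizing work. A sketch (editor's illustration; this downloads t5-small from the Hugging Face hub on first use):

t5_cond = T5Conditioner(name="t5-small", output_dim=128, finetune=False, device="cpu")
inputs = t5_cond.tokenize(["calm piano", None])  # None becomes "" and is masked out
embeds, mask = t5_cond(inputs)
print(embeds.shape)                              # [2, seq_len, 128]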
|
||||
|
||||
|
||||
class WaveformConditioner(BaseConditioner):
|
||||
"""Base class for all conditioners that take a waveform as input.
|
||||
Classes that inherit must implement `_get_wav_embedding` that outputs
|
||||
a continuous tensor, and `_downsampling_factor` that returns the down-sampling
|
||||
factor of the embedding model.
|
||||
|
||||
Args:
|
||||
dim (int): The internal representation dimension.
|
||||
output_dim (int): Output dimension.
|
||||
device (tp.Union[torch.device, str]): Device.
|
||||
"""
|
||||
def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]):
|
||||
super().__init__(dim, output_dim)
|
||||
self.device = device
|
||||
|
||||
def tokenize(self, wav_length: WavCondition) -> WavCondition:
|
||||
wav, length, path = wav_length
|
||||
assert length is not None
|
||||
return WavCondition(wav.to(self.device), length.to(self.device), path)
|
||||
|
||||
def _get_wav_embedding(self, wav: Tensor) -> Tensor:
|
||||
"""Gets as input a wav and returns a dense vector of conditions."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def _downsampling_factor(self):
|
||||
"""Returns the downsampling factor of the embedding model."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def forward(self, inputs: WavCondition) -> ConditionType:
|
||||
"""
|
||||
Args:
|
||||
input (WavCondition): Tuple of (waveform, lengths).
|
||||
Returns:
|
||||
ConditionType: Dense vector representing the conditioning along with its' mask.
|
||||
"""
|
||||
wav, lengths, path = inputs
|
||||
with torch.no_grad():
|
||||
embeds = self._get_wav_embedding(wav)
|
||||
embeds = embeds.to(self.output_proj.weight)
|
||||
embeds = self.output_proj(embeds)
|
||||
|
||||
if lengths is not None:
|
||||
lengths = lengths / self._downsampling_factor()
|
||||
mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore
|
||||
else:
|
||||
mask = torch.ones_like(embeds)
|
||||
embeds = (embeds * mask.unsqueeze(2).to(self.device))
|
||||
|
||||
return embeds, mask
|
||||
|
||||
|
||||
class ChromaStemConditioner(WaveformConditioner):
    """Chroma conditioner that uses DEMUCS to first filter out drums and bass, following
    the insight that drums and bass often dominate the chroma, leaving little room in the
    chroma for information about the melody.

    Args:
        output_dim (int): Output dimension for the conditioner.
        sample_rate (int): Sample rate for the chroma extractor.
        n_chroma (int): Number of chroma for the chroma extractor.
        radix2_exp (int): Radix2 exponent for the chroma extractor.
        duration (float): Duration used during training. This is later used for correct padding
            in case we are using chroma as prefix.
        match_len_on_eval (bool, optional): If True then all chromas are padded to the training
            duration. Defaults to True.
        eval_wavs (str, optional): Path to a json egg with waveforms; these waveforms are used as
            conditions during eval (for cases where we don't want to leak test conditions like MusicCaps).
            Defaults to None.
        n_eval_wavs (int, optional): Limits the number of waveforms used for conditioning. Defaults to 0.
        device (tp.Union[torch.device, str], optional): Device for the conditioner.
        **kwargs: Additional parameters for the chroma extractor.
    """
    def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int,
                 duration: float, match_len_on_eval: bool = True, eval_wavs: tp.Optional[str] = None,
                 n_eval_wavs: int = 0, device: tp.Union[torch.device, str] = "cpu", **kwargs):
        from demucs import pretrained
        super().__init__(dim=n_chroma, output_dim=output_dim, device=device)
        self.autocast = TorchAutocast(enabled=device != "cpu", device_type=self.device, dtype=torch.float32)
        self.sample_rate = sample_rate
        self.match_len_on_eval = match_len_on_eval
        self.duration = duration
        self.__dict__["demucs"] = pretrained.get_model('htdemucs').to(device)
        self.stem2idx = {'drums': 0, 'bass': 1, 'other': 2, 'vocal': 3}
        self.stem_idx = torch.LongTensor([self.stem2idx['vocal'], self.stem2idx['other']]).to(device)
        self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma, radix2_exp=radix2_exp,
                                      device=device, **kwargs)
        self.chroma_len = self._get_chroma_len()

    def _downsampling_factor(self):
        return self.chroma.winhop

    def _get_chroma_len(self):
        """Get length of chroma during training."""
        dummy_wav = torch.zeros((1, int(self.sample_rate * self.duration)), device=self.device)
        dummy_chr = self.chroma(dummy_wav)
        return dummy_chr.shape[1]

    @torch.no_grad()
    def _get_filtered_wav(self, wav):
        from demucs.apply import apply_model
        from demucs.audio import convert_audio
        with self.autocast:
            wav = convert_audio(wav, self.sample_rate, self.demucs.samplerate, self.demucs.audio_channels)
            stems = apply_model(self.demucs, wav, device=self.device)
            stems = stems[:, self.stem_idx]  # extract relevant stems
            stems = stems.sum(1)  # merge extracted stems
            stems = stems.mean(1, keepdim=True)  # mono
            stems = convert_audio(stems, self.demucs.samplerate, self.sample_rate, 1)
            return stems

    @torch.no_grad()
    def _get_wav_embedding(self, wav):
        # avoid 0-size tensors when we are working with null conds
        if wav.shape[-1] == 1:
            return self.chroma(wav)
        stems = self._get_filtered_wav(wav)
        chroma = self.chroma(stems)

        if self.match_len_on_eval:
            b, t, c = chroma.shape
            if t > self.chroma_len:
                chroma = chroma[:, :self.chroma_len]
                logger.debug(f'chroma was truncated! ({t} -> {chroma.shape[1]})')
            elif t < self.chroma_len:
                # chroma = F.pad(chroma, (0, 0, 0, self.chroma_len - t))
                n_repeat = int(math.ceil(self.chroma_len / t))
                chroma = chroma.repeat(1, n_repeat, 1)
                chroma = chroma[:, :self.chroma_len]
                logger.debug(f'chroma was tiled to match length! ({t} -> {chroma.shape[1]})')
        return chroma
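The tile-and-truncate branch above replaces the earlier zero padding (left commented out). A standalone sketch of the same arithmetic: with a target `chroma_len` of 235 and an input of length 100, `n_repeat = ceil(235 / 100) = 3`, so the chroma is repeated to length 300 and sliced back to 235:

import math
import torch

chroma = torch.randn(1, 100, 12)  # [B, T, n_chroma], shorter than the training length
chroma_len = 235
n_repeat = int(math.ceil(chroma_len / chroma.shape[1]))  # 3
tiled = chroma.repeat(1, n_repeat, 1)[:, :chroma_len]
assert tiled.shape == (1, 235, 12)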
class ChromaExtractor(nn.Module):
    """Chroma extraction class, handles chroma extraction and quantization.

    Args:
        sample_rate (int): Sample rate.
        n_chroma (int): Number of chroma to consider.
        radix2_exp (int): Radix2 exponent.
        nfft (tp.Optional[int], optional): Number of FFT bins.
        winlen (tp.Optional[int], optional): Window length.
        winhop (tp.Optional[int], optional): Window hop size.
        argmax (bool, optional): Whether to use argmax. Defaults to False.
        norm (float, optional): Norm for chroma normalization. Defaults to inf.
        device (tp.Union[torch.device, str], optional): Device to use. Defaults to cpu.
    """
    def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12,
                 nfft: tp.Optional[int] = None, winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None,
                 argmax: bool = False, norm: float = torch.inf, device: tp.Union[torch.device, str] = "cpu"):
        super().__init__()
        from librosa import filters
        self.device = device
        self.autocast = TorchAutocast(enabled=device != "cpu", device_type=self.device, dtype=torch.float32)
        self.winlen = winlen or 2 ** radix2_exp
        self.nfft = nfft or self.winlen
        self.winhop = winhop or (self.winlen // 4)
        self.sr = sample_rate
        self.n_chroma = n_chroma
        self.norm = norm
        self.argmax = argmax
        self.window = torch.hann_window(self.winlen).to(device)
        self.fbanks = torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0,
                                                      n_chroma=self.n_chroma)).to(device)
        self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen,
                                                      hop_length=self.winhop, power=2, center=True,
                                                      pad=0, normalized=True).to(device)

    def forward(self, wav):
        with self.autocast:
            T = wav.shape[-1]
            # in case we are getting a wav that was dropped out (nullified),
            # make sure wav length is no less than nfft
            if T < self.nfft:
                pad = self.nfft - T
                r = 0 if pad % 2 == 0 else 1
                wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0)
                assert wav.shape[-1] == self.nfft, f'expected len {self.nfft} but got {wav.shape[-1]}'
            spec = self.spec(wav).squeeze(1)
            raw_chroma = torch.einsum("cf,...ft->...ct", self.fbanks, spec)
            norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
            norm_chroma = rearrange(norm_chroma, "b d t -> b t d")

            if self.argmax:
                idx = norm_chroma.argmax(-1, keepdim=True)
                norm_chroma[:] = 0
                norm_chroma.scatter_(dim=-1, index=idx, value=1)

            return norm_chroma
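A quick shape check for `ChromaExtractor`, assuming `librosa` and `torchaudio` are installed. With the defaults, `winlen = 2 ** 12 = 4096` and `winhop = 1024`, so one second of audio at 32 kHz should yield 32 frames:

extractor = ChromaExtractor(sample_rate=32000)
wav = torch.randn(1, 1, 32000)  # one second of audio, [B, C, T]
chroma = extractor(wav)
print(chroma.shape)             # expected torch.Size([1, 32, 12]), i.e. [B, frames, n_chroma]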
def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str):
    """Utility function for nullifying an attribute inside a ConditioningAttributes object.
    If the condition is of type "wav", then nullify it using "nullify_condition".
    If the condition is of any other type, set its value to None.
    Works in-place.
    """
    if condition_type not in ["text", "wav"]:
        raise ValueError(
            "dropout_condition got an unexpected condition type!"
            f" expected 'wav' or 'text' but got '{condition_type}'"
        )

    if condition not in getattr(sample, condition_type):
        raise ValueError(
            "dropout_condition received an unexpected condition!"
            f" expected wav={sample.wav.keys()} and text={sample.text.keys()}"
            f" but got '{condition}' of type '{condition_type}'!"
        )

    if condition_type == "wav":
        wav, length, path = sample.wav[condition]
        sample.wav[condition] = nullify_wav(wav)
    else:
        sample.text[condition] = None

    return sample


class DropoutModule(nn.Module):
    """Base class for all dropout modules."""
    def __init__(self, seed: int = 1234):
        super().__init__()
        self.rng = torch.Generator()
        self.rng.manual_seed(seed)


class AttributeDropout(DropoutModule):
    """Applies dropout with a given probability per attribute. This is different from the behavior of
    ClassifierFreeGuidanceDropout as this allows for attributes to be dropped out separately. For example,
    "artist" can be dropped while "genre" remains. This is in contrast to ClassifierFreeGuidanceDropout
    where if "artist" is dropped "genre" must also be dropped.

    Args:
        p (tp.Dict[str, tp.Dict[str, float]]): A dict mapping condition types to per-attribute
            dropout probabilities. For example:
            {
                "text": {"genre": 0.1, "artist": 0.5},
                "wav": {"self_wav": 0.25},
            }
        active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False.
        seed (int, optional): Random seed.
    """
    def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234):
        super().__init__(seed=seed)
        self.active_on_eval = active_on_eval
        # construct a dict that returns the values from p, and 0 otherwise
        self.p = {}
        for condition_type, probs in p.items():
            self.p[condition_type] = defaultdict(lambda: 0, probs)

    def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
        """
        Args:
            samples (tp.List[ConditioningAttributes]): List of conditions.
        Returns:
            tp.List[ConditioningAttributes]: List of conditions after certain attributes were set to None.
        """
        if not self.training and not self.active_on_eval:
            return samples

        samples = deepcopy(samples)

        for condition_type, ps in self.p.items():  # for condition types [text, wav]
            for condition, p in ps.items():  # for attributes of each type (e.g., [artist, genre])
                if torch.rand(1, generator=self.rng).item() < p:
                    for sample in samples:
                        dropout_condition(sample, condition_type, condition)

        return samples

    def __repr__(self):
        return f"AttributeDropout({dict(self.p)})"
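A sketch of `AttributeDropout` in use; it assumes `ConditioningAttributes` can be built from plain `text`/`wav` dicts as in the docstring examples elsewhere in this file (attribute names are illustrative):

dropout = AttributeDropout(p={"text": {"genre": 0.5}, "wav": {}})
dropout.train()  # inactive at eval unless active_on_eval=True
attrs = [ConditioningAttributes(text={"genre": "Rock", "description": "A rock song"})]
out = dropout(attrs)  # "genre" alone is dropped with probability 0.5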
class ClassifierFreeGuidanceDropout(DropoutModule):
    """Applies Classifier Free Guidance dropout, meaning all attributes
    are dropped with the same probability.

    Args:
        p (float): Probability to apply condition dropout during training.
        seed (int): Random seed.
    """
    def __init__(self, p: float, seed: int = 1234):
        super().__init__(seed=seed)
        self.p = p

    def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
        """
        Args:
            samples (tp.List[ConditioningAttributes]): List of conditions.
        Returns:
            tp.List[ConditioningAttributes]: List of conditions after all attributes were set to None.
        """
        if not self.training:
            return samples

        # decide whether to drop all attributes, in a batched fashion
        drop = torch.rand(1, generator=self.rng).item() < self.p
        if not drop:
            return samples

        # nullify conditions of all attributes
        samples = deepcopy(samples)

        for condition_type in ["wav", "text"]:
            for sample in samples:
                for condition in sample.attributes[condition_type]:
                    dropout_condition(sample, condition_type, condition)

        return samples

    def __repr__(self):
        return f"ClassifierFreeGuidanceDropout(p={self.p})"
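During training this module nullifies the entire conditioning set with probability `p`, which is what later enables classifier-free guidance at sampling time (contrasting a conditional and an unconditional forward pass). A sketch, under the same `ConditioningAttributes` assumption as above:

cfg_dropout = ClassifierFreeGuidanceDropout(p=0.1)
cfg_dropout.train()  # no-op in eval mode
attrs = [ConditioningAttributes(text={"description": "lofi beat"})]
maybe_dropped = cfg_dropout(attrs)  # all attributes nullified ~10% of the time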
class ConditioningProvider(nn.Module):
    """Main class to provide conditions given all the supported conditioners.

    Args:
        conditioners (dict): Dictionary of conditioners.
        merge_text_conditions_p (float, optional): Probability to merge all text sources
            into a single text condition. Defaults to 0.
        drop_desc_p (float, optional): Probability to drop the original description
            when merging all text sources into a single text condition. Defaults to 0.
        device (tp.Union[torch.device, str], optional): Device for conditioners and output condition types.
    """
    def __init__(
        self,
        conditioners: tp.Dict[str, BaseConditioner],
        merge_text_conditions_p: float = 0,
        drop_desc_p: float = 0,
        device: tp.Union[torch.device, str] = "cpu",
    ):
        super().__init__()
        self.device = device
        self.merge_text_conditions_p = merge_text_conditions_p
        self.drop_desc_p = drop_desc_p
        self.conditioners = nn.ModuleDict(conditioners)

    @property
    def text_conditions(self):
        return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]

    @property
    def wav_conditions(self):
        return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)]

    @property
    def has_wav_condition(self):
        return len(self.wav_conditions) > 0

    def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
        """Match attributes/wavs with existing conditioners in self, and tokenize them accordingly.
        This should be called before starting any real GPU work to avoid synchronization points.
        This will return a dict matching conditioner names to their arbitrary tokenized representations.

        Args:
            inputs (tp.List[ConditioningAttributes]): List of ConditioningAttributes objects containing
                text and wav conditions.
        """
        assert all([type(x) == ConditioningAttributes for x in inputs]), \
            "got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]" \
            f" but types were {set([type(x) for x in inputs])}"

        output = {}
        text = self._collate_text(inputs)
        wavs = self._collate_wavs(inputs)

        assert set(text.keys() | wavs.keys()).issubset(set(self.conditioners.keys())), \
            f"got an unexpected attribute! Expected {self.conditioners.keys()}, got {text.keys(), wavs.keys()}"

        for attribute, batch in chain(text.items(), wavs.items()):
            output[attribute] = self.conditioners[attribute].tokenize(batch)
        return output

    def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
        """Compute pairs of `(embedding, mask)` using the configured conditioners
        and the tokenized representations. The output is for example:

            {
                "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
                "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
                ...
            }

        Args:
            tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
        """
        output = {}
        for attribute, inputs in tokenized.items():
            condition, mask = self.conditioners[attribute](inputs)
            output[attribute] = (condition, mask)
        return output
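The `tokenize`/`forward` split keeps cheap CPU-side tokenization separate from GPU embedding work. A sketch of the flow, assuming `provider` was built with a text conditioner keyed "description":

attrs = [ConditioningAttributes(text={"description": "energetic EDM, 128 bpm"})]
tokenized = provider.tokenize(attrs)   # CPU-side: plain tokenizer outputs, no GPU sync
conditions = provider(tokenized)       # GPU-side: dict of (embedding, mask) pairs
emb, mask = conditions["description"]  # emb: [B, T_desc, D], mask: [B, T_desc]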
    def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
        """Given a list of ConditioningAttributes objects, compile a dictionary where the keys
        are the attributes and the values are the aggregated input per attribute.
        For example:
        Input:
        [
            ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
            ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...),
        ]
        Output:
        {
            "genre": ["Rock", "Hip-hop"],
            "description": ["A rock song with a guitar solo", "A hip-hop verse"]
        }
        """
        batch_per_attribute: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)

        def _merge_conds(cond, merge_text_conditions_p=0, drop_desc_p=0):
            def is_valid(k, v):
                k_valid = k in ['key', 'bpm', 'genre', 'moods', 'instrument']
                v_valid = v is not None and isinstance(v, (int, float, str, list))
                return k_valid and v_valid

            def process_value(v):
                if isinstance(v, (int, float, str)):
                    return v
                if isinstance(v, list):
                    return ", ".join(v)
                else:
                    raise RuntimeError(f"unknown type for text value! ({type(v), v})")

            desc = cond.text['description']
            meta_data = ""
            if random.uniform(0, 1) < merge_text_conditions_p:
                meta_pairs = [f'{k}: {process_value(v)}' for k, v in cond.text.items() if is_valid(k, v)]
                random.shuffle(meta_pairs)
                meta_data = ". ".join(meta_pairs)
                desc = desc if not random.uniform(0, 1) < drop_desc_p else None

            if desc is None:
                desc = meta_data if len(meta_data) > 1 else None
            else:
                desc = desc.rstrip('.') + ". " + meta_data
            cond.text['description'] = desc.strip() if desc else None

        if self.training and self.merge_text_conditions_p:
            for sample in samples:
                _merge_conds(sample, self.merge_text_conditions_p, self.drop_desc_p)

        texts = [x.text for x in samples]
        for text in texts:
            for condition in self.text_conditions:
                batch_per_attribute[condition].append(text[condition])

        return batch_per_attribute
    def _collate_wavs(self, samples: tp.List[ConditioningAttributes]):
        """Generate a dict where the keys are attributes by which we fetch similar wavs,
        and the values are Tensors of wavs according to said attributes.

        *Note*: by the time the samples reach this function, each sample should have some waveform
        inside the "wav" attribute. It should be either:
        1. A real waveform
        2. A null waveform due to the sample having no similar waveforms (nullified by the dataset)
        3. A null waveform due to it being dropped in a dropout module (nullified by dropout)

        Args:
            samples (tp.List[ConditioningAttributes]): List of ConditioningAttributes samples.
        Returns:
            dict: A dictionary mapping an attribute name to wavs.
        """
        wavs = defaultdict(list)
        lens = defaultdict(list)
        paths = defaultdict(list)
        out = {}

        for sample in samples:
            for attribute in self.wav_conditions:
                wav, length, path = sample.wav[attribute]
                wavs[attribute].append(wav.flatten())
                lens[attribute].append(length)
                paths[attribute].append(path)

        # stack all wavs to a single tensor
        for attribute in self.wav_conditions:
            stacked_wav, _ = collate(wavs[attribute], dim=0)
            out[attribute] = WavCondition(stacked_wav.unsqueeze(1),
                                          torch.cat(lens[attribute]), paths[attribute])  # type: ignore

        return out
class ConditionFuser(StreamingModule):
    """Condition fuser handles the logic to combine the different conditions
    to the actual model input.

    Args:
        fuse2cond (tp.Dict[str, tp.List[str]]): A dictionary that says how to fuse
            each condition. For example:
            {
                "prepend": ["description"],
                "sum": ["genre", "bpm"],
                "cross": ["description"],
            }
        cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
        cross_attention_pos_emb_scale (int): Scale for positional embeddings in cross attention if used.
    """
    FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"]

    def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
                 cross_attention_pos_emb_scale: float = 1.0):
        super().__init__()
        assert all(
            [k in self.FUSING_METHODS for k in fuse2cond.keys()]
        ), f"got invalid fuse method, allowed methods: {self.FUSING_METHODS}"
        self.cross_attention_pos_emb = cross_attention_pos_emb
        self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
        self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
        self.cond2fuse: tp.Dict[str, str] = {}
        for fuse_method, conditions in fuse2cond.items():
            for condition in conditions:
                self.cond2fuse[condition] = fuse_method

    def forward(
        self,
        input: Tensor,
        conditions: tp.Dict[str, ConditionType]
    ) -> tp.Tuple[Tensor, tp.Optional[Tensor]]:
        """Fuse the conditions to the provided model input.

        Args:
            input (Tensor): Transformer input.
            conditions (tp.Dict[str, ConditionType]): Dict of conditions.
        Returns:
            tp.Tuple[Tensor, tp.Optional[Tensor]]: The first tensor is the transformer input
                after the conditions have been fused. The second output tensor is the tensor
                used for cross-attention, or None if no cross-attention inputs exist.
        """
        B, T, _ = input.shape

        if 'offsets' in self._streaming_state:
            first_step = False
            offsets = self._streaming_state['offsets']
        else:
            first_step = True
            offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)

        assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
            f"given conditions contain unknown attributes for fuser, " \
            f"expected {self.cond2fuse.keys()}, got {conditions.keys()}"
        cross_attention_output = None
        for cond_type, (cond, cond_mask) in conditions.items():
            op = self.cond2fuse[cond_type]
            if op == "sum":
                input += cond
            elif op == "input_interpolate":
                cond = rearrange(cond, "b t d -> b d t")
                cond = F.interpolate(cond, size=input.shape[1])
                input += rearrange(cond, "b d t -> b t d")
            elif op == "prepend":
                if first_step:
                    input = torch.cat([cond, input], dim=1)
            elif op == "cross":
                if cross_attention_output is not None:
                    cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
                else:
                    cross_attention_output = cond
            else:
                raise ValueError(f"unknown op ({op})")

        if self.cross_attention_pos_emb and cross_attention_output is not None:
            positions = torch.arange(
                cross_attention_output.shape[1],
                device=cross_attention_output.device
            ).view(1, -1, 1)
            pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1])
            cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb

        if self._is_streaming:
            self._streaming_state['offsets'] = offsets + T

        return input, cross_attention_output
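A sketch of the fuser on toy tensors, showing how "prepend" grows the sequence while the cross-attention output stays None when no "cross" condition is given:

fuser = ConditionFuser(fuse2cond={"prepend": ["description"]})
x = torch.randn(2, 10, 64)                               # [B, T, D] transformer input
description = (torch.randn(2, 5, 64), torch.ones(2, 5))  # (embedding, mask) from the provider
fused, cross = fuser(x, {"description": description})
print(fused.shape, cross)                                # torch.Size([2, 15, 64]) None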
@@ -0,0 +1,245 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
import typing as tp
import warnings

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm, weight_norm


CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
                                 'time_group_norm'])


def apply_parametrization_norm(module: nn.Module, norm: str = 'none'):
    assert norm in CONV_NORMALIZATIONS
    if norm == 'weight_norm':
        return weight_norm(module)
    elif norm == 'spectral_norm':
        return spectral_norm(module)
    else:
        # We already checked that norm is in CONV_NORMALIZATIONS, so any other choice
        # doesn't need reparametrization.
        return module


def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs):
    """Return the proper normalization module. If causal is True, this will ensure the returned
    module is causal, or return an error if the normalization doesn't support causal evaluation.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm == 'time_group_norm':
        if causal:
            raise ValueError("GroupNorm doesn't support causal evaluation.")
        assert isinstance(module, nn.modules.conv._ConvNd)
        return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
    else:
        return nn.Identity()


def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """See `pad_for_conv1d`."""
    length = x.shape[-1]
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length


def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
    """Pad for a convolution to make sure that the last window is full.
    Extra padding is added at the end. This is required to ensure that we can rebuild
    an output of the same length, as otherwise, even with padding, some time steps
    might get removed.
    For instance, with total padding = 4, kernel size = 4, stride = 2:
        0 0 1 2 3 4 5 0 0   # (0s are padding)
        1   2   3           # (output frames of a convolution, last 0 is never used)
        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
            1 2 3 4         # once you removed padding, we are missing one time step!
    """
    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    return F.pad(x, (0, extra_padding))
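A worked check of the padding arithmetic: with `length = 9`, `kernel_size = 4`, `stride = 2` and `padding_total = 2`, `n_frames = (9 - 4 + 2) / 2 + 1 = 4.5`, so `ideal_length = (5 - 1) * 2 + (4 - 2) = 10` and one extra sample of right padding is needed:

x = torch.zeros(1, 1, 9)
extra = get_extra_padding_for_conv1d(x, kernel_size=4, stride=2, padding_total=2)
assert extra == 1
assert pad_for_conv1d(x, kernel_size=4, stride=2, padding_total=2).shape[-1] == 10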
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
    """Tiny wrapper around F.pad, just to allow for reflect padding on small inputs.
    If this is the case, we insert extra 0 padding to the right before the reflection happens.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == 'reflect':
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, paddings, mode, value)


def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove padding from x, properly handling zero padding. Only for 1d!"""
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    assert (padding_left + padding_right) <= x.shape[-1]
    end = x.shape[-1] - padding_right
    return x[..., padding_left: end]
class NormConv1d(nn.Module):
    """Wrapper around Conv1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class NormConv2d(nn.Module):
    """Wrapper around Conv2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        return x


class NormConvTranspose1d(nn.Module):
    """Wrapper around ConvTranspose1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x


class NormConvTranspose2d(nn.Module):
    """Wrapper around ConvTranspose2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """
    def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)

    def forward(self, x):
        x = self.convtr(x)
        x = self.norm(x)
        return x
class StreamableConv1d(nn.Module):
    """Conv1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, dilation: int = 1,
                 groups: int = 1, bias: bool = True, causal: bool = False,
                 norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
                 pad_mode: str = 'reflect'):
        super().__init__()
        # warn user on unusual setup between dilation and stride
        if stride > 1 and dilation > 1:
            warnings.warn('StreamableConv1d has been initialized with stride > 1 and dilation > 1'
                          f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
                               dilation=dilation, groups=groups, bias=bias, causal=causal,
                               norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.pad_mode = pad_mode

    def forward(self, x):
        B, C, T = x.shape
        kernel_size = self.conv.conv.kernel_size[0]
        stride = self.conv.conv.stride[0]
        dilation = self.conv.conv.dilation[0]
        kernel_size = (kernel_size - 1) * dilation + 1  # effective kernel size with dilations
        padding_total = kernel_size - stride
        extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
        if self.causal:
            # Left padding for causal
            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
        return self.conv(x)
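A sketch contrasting the two padding modes: both keep the output length at `ceil(T / stride)`, the causal variant simply places all fixed padding on the left:

conv_causal = StreamableConv1d(1, 8, kernel_size=7, stride=2, causal=True)
conv_center = StreamableConv1d(1, 8, kernel_size=7, stride=2, causal=False)
x = torch.randn(1, 1, 100)
print(conv_causal(x).shape, conv_center(x).shape)  # both torch.Size([1, 8, 50])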
class StreamableConvTranspose1d(nn.Module):
    """ConvTranspose1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, causal: bool = False,
                 norm: str = 'none', trim_right_ratio: float = 1.,
                 norm_kwargs: tp.Dict[str, tp.Any] = {}):
        super().__init__()
        self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
                                          causal=causal, norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert self.causal or self.trim_right_ratio == 1., \
            "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.

    def forward(self, x):
        kernel_size = self.convtr.convtr.kernel_size[0]
        stride = self.convtr.convtr.stride[0]
        padding_total = kernel_size - stride

        y = self.convtr(x)

        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
        # removed at the very end, when keeping only the right length for the output,
        # as removing it here would require also passing the length at the matching layer
        # in the encoder.
        if self.causal:
            # Trim the padding on the right according to the specified ratio;
            # if trim_right_ratio = 1.0, trim everything from the right.
            padding_right = math.ceil(padding_total * self.trim_right_ratio)
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        else:
            # Asymmetric padding required for odd strides
            padding_right = padding_total // 2
            padding_left = padding_total - padding_right
            y = unpad1d(y, (padding_left, padding_right))
        return y
@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from torch import nn


class StreamableLSTM(nn.Module):
    """LSTM without worrying about the hidden state, nor the layout of the data.
    Expects input in convolutional layout, i.e. [B, C, T].
    """
    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
        super().__init__()
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers)

    def forward(self, x):
        x = x.permute(2, 0, 1)  # [B, C, T] -> [T, B, C] for nn.LSTM
        y, _ = self.lstm(x)
        if self.skip:
            y = y + x
        y = y.permute(1, 2, 0)  # back to [B, C, T]
        return y
@@ -0,0 +1,124 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

import torch
from torch import nn


class XPos(nn.Module):
    """Length-extrapolatable positional embedding (xPos) from [Sun et al 2022](https://arxiv.org/abs/2212.10554v1).
    This applies an exponential decay to the RoPE rotation matrix.

    Args:
        dim (int): Embedding dimension.
        smoothing (float): Smoothing factor applied to the decay rates.
        base_scale (int): Base decay rate, given in terms of scaling time.
        device (torch.device or None): Device on which to initialize the module.
        dtype (torch.dtype): dtype to use to generate the embedding.
    """
    def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int = 512,
                 device=None, dtype: torch.dtype = torch.float32):
        super().__init__()
        assert dim % 2 == 0
        assert dtype in [torch.float64, torch.float32]
        self.dtype = dtype
        self.base_scale = base_scale

        half_dim = dim // 2
        adim = torch.arange(half_dim, device=device, dtype=dtype)
        decay_rates = (adim / half_dim + smoothing) / (1.0 + smoothing)
        self.register_buffer("decay_rates", decay_rates)
        self.decay: tp.Optional[torch.Tensor] = None

    def get_decay(self, start: int, end: int):
        """Create complex decay tensor, cache values for fast computation."""
        if self.decay is None or end > self.decay.shape[0]:
            assert isinstance(self.decay_rates, torch.Tensor)  # Satisfy type checker.
            idx = torch.arange(end, device=self.decay_rates.device, dtype=self.dtype)
            power = idx / self.base_scale
            scale = self.decay_rates ** power.unsqueeze(-1)
            self.decay = torch.polar(scale, torch.zeros_like(scale))
        return self.decay[start:end]  # [T, C/2]
class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864).

    Args:
        dim (int): Embedding dimension (twice the number of frequencies).
        max_period (float): Maximum period of the rotation frequencies.
        xpos (bool): Use xPos, applies an exponential decay to rotation matrix.
        scale (float): Scale of positional embedding, set to 0 to deactivate.
        device (torch.device or None): Device on which to initialize the module.
        dtype (torch.dtype): dtype to use to generate the embedding.
    """
    def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool = False,
                 scale: float = 1.0, device=None, dtype: torch.dtype = torch.float32):
        super().__init__()
        assert dim % 2 == 0
        self.scale = scale
        assert dtype in [torch.float64, torch.float32]
        self.dtype = dtype

        adim = torch.arange(0, dim, 2, device=device, dtype=dtype)[: (dim // 2)]
        frequencies = 1.0 / (max_period ** (adim / dim))
        self.register_buffer("frequencies", frequencies)
        self.rotation: tp.Optional[torch.Tensor] = None

        self.xpos = XPos(dim, device=device, dtype=dtype) if xpos else None

    def get_rotation(self, start: int, end: int):
        """Create complex rotation tensor, cache values for fast computation."""
        if self.rotation is None or end > self.rotation.shape[0]:
            assert isinstance(self.frequencies, torch.Tensor)  # Satisfy type checker.
            idx = torch.arange(end, device=self.frequencies.device, dtype=self.dtype)
            angles = torch.outer(idx, self.frequencies)
            self.rotation = torch.polar(torch.ones_like(angles), angles)
        return self.rotation[start:end]

    def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool = False):
        """Apply rope rotation to query or key tensor."""
        T = x.shape[1]
        rotation = self.get_rotation(start, start + T).unsqueeze(0).unsqueeze(2)

        if self.xpos:
            decay = self.xpos.get_decay(start, start + T).unsqueeze(0).unsqueeze(2)
        else:
            decay = 1.0

        if invert_decay:
            decay = decay ** -1

        x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2))
        scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale)
        x_out = torch.view_as_real(x_complex * scaled_rotation).flatten(-2)

        return x_out.type_as(x)

    def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0):
        """Apply rope rotation to both query and key tensors.
        Supports streaming mode, in which query and key are not expected to have the same shape.
        In streaming mode, key will be of length [P + C] with P the cached past timesteps, but
        query will be [C] (typically C == 1).

        Args:
            query (torch.Tensor): Query to rotate.
            key (torch.Tensor): Key to rotate.
            start (int): Start index of the sequence for time offset.
        """
        query_timesteps = query.shape[1]
        key_timesteps = key.shape[1]
        streaming_offset = key_timesteps - query_timesteps

        query_out = self.rotate(query, start + streaming_offset)
        key_out = self.rotate(key, start, invert_decay=True)

        return query_out, key_out
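A sketch of `rotate_qk` on `[B, T, heads, dim_per_head]` tensors; with equal query and key lengths the streaming offset is zero and shapes are preserved:

rope = RotaryEmbedding(dim=64)
q = torch.randn(2, 10, 8, 64)  # [B, T, heads, dim_per_head]
k = torch.randn(2, 10, 8, 64)
q_rot, k_rot = rope.rotate_qk(q, k)
assert q_rot.shape == q.shape and k_rot.shape == k.shape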
@@ -0,0 +1,258 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

import numpy as np
import torch.nn as nn

from .conv import StreamableConv1d, StreamableConvTranspose1d
from .lstm import StreamableLSTM


class SEANetResnetBlock(nn.Module):
    """Residual block from SEANet model.

    Args:
        dim (int): Dimension of the input/output.
        kernel_sizes (list): List of kernel sizes for the convolutions.
        dilations (list): List of dilations for the convolutions.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection.
    """
    def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
                 activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
                 pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
        super().__init__()
        assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
        act = getattr(nn, activation)
        hidden = dim // compress
        block = []
        for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
            in_chs = dim if i == 0 else hidden
            out_chs = dim if i == len(kernel_sizes) - 1 else hidden
            block += [
                act(**activation_params),
                StreamableConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
                                 norm=norm, norm_kwargs=norm_params,
                                 causal=causal, pad_mode=pad_mode),
            ]
        self.block = nn.Sequential(*block)
        self.shortcut: nn.Module
        if true_skip:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = StreamableConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
                                             causal=causal, pad_mode=pad_mode)

    def forward(self, x):
        return self.shortcut(x) + self.block(x)
class SEANetEncoder(nn.Module):
    """SEANet encoder.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): Number of residual layers.
        ratios (Sequence[int]): Kernel size and stride ratios. The encoder uses downsampling ratios instead of
            upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here,
            which must match the decoder order. We use the decoder order as some models may only employ the decoder.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the final convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers at the end of the encoder.
        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
            For the encoder, it corresponds to the N first blocks.
    """
    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
                 disable_norm_outer_blocks: int = 0):
        super().__init__()
        self.channels = channels
        self.dimension = dimension
        self.n_filters = n_filters
        self.ratios = list(reversed(ratios))
        del ratios
        self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)
        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
        self.disable_norm_outer_blocks = disable_norm_outer_blocks
        assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
            "Number of blocks for which to disable norm is invalid. " \
            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."

        act = getattr(nn, activation)
        mult = 1
        model: tp.List[nn.Module] = [
            StreamableConv1d(channels, mult * n_filters, kernel_size,
                             norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]
        # Downsample to raw audio scale
        for i, ratio in enumerate(self.ratios):
            block_norm = 'none' if self.disable_norm_outer_blocks >= i + 2 else norm
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
                                      dilations=[dilation_base ** j, 1],
                                      norm=block_norm, norm_params=norm_params,
                                      activation=activation, activation_params=activation_params,
                                      causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]

            # Add downsampling layers
            model += [
                act(**activation_params),
                StreamableConv1d(mult * n_filters, mult * n_filters * 2,
                                 kernel_size=ratio * 2, stride=ratio,
                                 norm=block_norm, norm_kwargs=norm_params,
                                 causal=causal, pad_mode=pad_mode),
            ]
            mult *= 2

        if lstm:
            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]

        model += [
            act(**activation_params),
            StreamableConv1d(mult * n_filters, dimension, last_kernel_size,
                             norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)
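A sketch of the resulting downsampling factor, assuming `torch` is imported alongside this module: with the default `ratios = [8, 5, 4, 2]` the hop length is 8 * 5 * 4 * 2 = 320, i.e. one latent frame per 320 input samples:

import torch

encoder = SEANetEncoder(channels=1, dimension=128, n_filters=32, ratios=[8, 5, 4, 2])
x = torch.randn(1, 1, 32000)        # raw mono audio, [B, C, T]
z = encoder(x)
print(encoder.hop_length, z.shape)  # 320 torch.Size([1, 128, 100])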
class SEANetDecoder(nn.Module):
    """SEANet decoder.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): Number of residual layers.
        ratios (Sequence[int]): Kernel size and stride ratios.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        final_activation (str): Final activation function after all convolutions.
        final_activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the final convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers at the beginning of the decoder.
        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
            For the decoder, it corresponds to the N last blocks.
        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
            If equal to 1.0, it means that all the trimming is done at the right.
    """
    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
                 final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
                 disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0):
        super().__init__()
        self.dimension = dimension
        self.channels = channels
        self.n_filters = n_filters
        self.ratios = ratios
        del ratios
        self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)
        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
        self.disable_norm_outer_blocks = disable_norm_outer_blocks
        assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
            "Number of blocks for which to disable norm is invalid. " \
            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."

        act = getattr(nn, activation)
        mult = int(2 ** len(self.ratios))
        model: tp.List[nn.Module] = [
            StreamableConv1d(dimension, mult * n_filters, kernel_size,
                             norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]

        if lstm:
            model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]

        # Upsample to raw audio scale
        for i, ratio in enumerate(self.ratios):
            block_norm = 'none' if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) else norm
            # Add upsampling layers
            model += [
                act(**activation_params),
                StreamableConvTranspose1d(mult * n_filters, mult * n_filters // 2,
                                          kernel_size=ratio * 2, stride=ratio,
                                          norm=block_norm, norm_kwargs=norm_params,
                                          causal=causal, trim_right_ratio=trim_right_ratio),
            ]
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1],
                                      dilations=[dilation_base ** j, 1],
                                      activation=activation, activation_params=activation_params,
                                      norm=block_norm, norm_params=norm_params, causal=causal,
                                      pad_mode=pad_mode, compress=compress, true_skip=true_skip)]

            mult //= 2

        # Add final layers
        model += [
            act(**activation_params),
            StreamableConv1d(n_filters, channels, last_kernel_size,
                             norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
                             norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        ]
        # Add optional final activation to decoder (e.g. tanh)
        if final_activation is not None:
            final_act = getattr(nn, final_activation)
            final_activation_params = final_activation_params or {}
            model += [
                final_act(**final_activation_params)
            ]
        self.model = nn.Sequential(*model)

    def forward(self, z):
        y = self.model(z)
        return y
@@ -0,0 +1,135 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Streaming module API that should be implemented by all Streaming components.
"""

from contextlib import contextmanager
import typing as tp

import torch
from torch import nn


State = tp.Dict[str, torch.Tensor]


class StreamingModule(nn.Module):
    """Common API for streaming components.

    Each streaming component has a streaming state, which is just a dict[str, Tensor].
    By convention, the first dim of each tensor must be the batch size.
    Don't use dots in the key names, as this would clash with submodules
    (like in state_dict).

    If `self._is_streaming` is True, the component should use and remember
    the proper state inside `self._streaming_state`.

    To set a streaming component in streaming state, use

        with module.streaming():
            ...

    This will automatically reset the streaming state when exiting the context manager.
    This also automatically propagates to all streaming children modules.

    Some modules might also implement the `StreamingModule.flush` method, although
    this one is trickier, as all parent modules must be StreamingModule and implement
    it as well for it to work properly. See `StreamingSequential` below.
    """
    def __init__(self) -> None:
        super().__init__()
        self._streaming_state: State = {}
        self._is_streaming = False

    def _apply_named_streaming(self, fn: tp.Any):
        for name, module in self.named_modules():
            if isinstance(module, StreamingModule):
                fn(name, module)

    def _set_streaming(self, streaming: bool):
        def _set_streaming(name, module):
            module._is_streaming = streaming
        self._apply_named_streaming(_set_streaming)

    @contextmanager
    def streaming(self):
        """Context manager to enter streaming mode. Reset streaming state on exit."""
        self._set_streaming(True)
        try:
            yield
        finally:
            self._set_streaming(False)
            self.reset_streaming()

    def reset_streaming(self):
        """Reset the streaming state."""
        def _reset(name: str, module: StreamingModule):
            module._streaming_state.clear()

        self._apply_named_streaming(_reset)

    def get_streaming_state(self) -> State:
        """Return the streaming state, including that of sub-modules."""
        state: State = {}

        def _add(name: str, module: StreamingModule):
            if name:
                name += "."
            for key, value in module._streaming_state.items():
                state[name + key] = value

        self._apply_named_streaming(_add)
        return state

    def set_streaming_state(self, state: State):
        """Set the streaming state, including that of sub-modules."""
        state = dict(state)

        def _set(name: str, module: StreamingModule):
            if name:
                name += "."
            module._streaming_state.clear()
            for key, value in list(state.items()):
                # complexity is not ideal here, but probably fine.
                if key.startswith(name):
                    local_key = key[len(name):]
                    if '.' not in local_key:
                        module._streaming_state[local_key] = value
                        del state[key]

        self._apply_named_streaming(_set)
        assert len(state) == 0, list(state.keys())

    def flush(self, x: tp.Optional[torch.Tensor] = None):
        """Flush any remaining outputs that were waiting for completion.
        Typically, for convolutions, this will add the final padding
        and process the last buffer.

        This should take an optional argument `x`, which will be provided
        if a module before this one in the streaming pipeline has already
        spit out a flushed buffer.
        """
        if x is None:
            return None
        else:
            return self(x)


class StreamingSequential(StreamingModule, nn.Sequential):
    """A streaming compatible alternative of `nn.Sequential`."""
    def flush(self, x: tp.Optional[torch.Tensor] = None):
        for module in self:
            if isinstance(module, StreamingModule):
                x = module.flush(x)
            elif x is not None:
                x = module(x)
        return x
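A sketch of the streaming API with a toy subclass: state accumulates across calls inside the context manager and is cleared on exit:

class FrameCounter(StreamingModule):
    """Toy module: counts frames seen per batch item while streaming."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._is_streaming:
            seen = self._streaming_state.get('seen', torch.zeros(x.shape[0], dtype=torch.long))
            self._streaming_state['seen'] = seen + x.shape[-1]
        return x


counter = FrameCounter()
with counter.streaming():
    counter(torch.randn(1, 4, 10))
    counter(torch.randn(1, 4, 6))
    print(counter.get_streaming_state())  # {'seen': tensor([16])}
assert counter.get_streaming_state() == {}  # reset on exit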
@@ -0,0 +1,747 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Transformer model, with streaming support, xformer attention support
|
||||
and easy causal attention with a potentially finite receptive field.
|
||||
|
||||
See `StreamingTransformer` for more information.
|
||||
|
||||
Unlike regular PyTorch Transformer, we make the hard choice that batches are first.
|
||||
"""
|
||||
|
||||
import typing as tp
|
||||
|
||||
from einops import rearrange
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from torch.utils.checkpoint import checkpoint as torch_checkpoint
|
||||
from xformers import ops
|
||||
|
||||
from .rope import RotaryEmbedding
|
||||
from .streaming import StreamingModule
|
||||
|
||||
_efficient_attention_backend: str = 'torch'
|
||||
|
||||
|
||||
def set_efficient_attention_backend(backend: str = 'torch'):
|
||||
# Using torch by default, it seems a bit faster on older P100 GPUs (~20% faster).
|
||||
global _efficient_attention_backend
|
||||
assert backend in ['xformers', 'torch']
|
||||
_efficient_attention_backend = backend
|
||||
|
||||
|
||||
def _get_attention_time_dimension() -> int:
|
||||
if _efficient_attention_backend == 'torch':
|
||||
return 2
|
||||
else:
|
||||
return 1
|
||||
|
||||
|
||||
def _is_profiled() -> bool:
|
||||
# Return True if we are currently running with an xformers profiler activated.
|
||||
try:
|
||||
from xformers.profiler import profiler
|
||||
except ImportError:
|
||||
return False
|
||||
return profiler._Profiler._CURRENT_PROFILER is not None
|
||||
|
||||
|
||||
def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
|
||||
"""Create normalization module for transformer encoder layer.
|
||||
|
||||
Args:
|
||||
norm_type (str): Normalization method.
|
||||
dim (int): Dimension of the normalized layer.
|
||||
**kwargs (dict): Additional parameters for normalization layer.
|
||||
Returns:
|
||||
nn.Module: Normalization module.
|
||||
"""
|
||||
if norm_type == 'layer_norm':
|
||||
return nn.LayerNorm(dim, eps=1e-5, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown norm type: {norm_type}")
|
||||
|
||||
|
||||
def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000,
|
||||
dtype: torch.dtype = torch.float32) -> torch.Tensor:
|
||||
"""Create sinusoidal positional embedding, with shape `[B, T, C]`.
|
||||
|
||||
Args:
|
||||
positions (torch.Tensor): LongTensor of positions.
|
||||
dim (int): Dimension of the embedding.
|
||||
max_period (float): Maximum period of the cosine/sine functions.
|
||||
dtype (torch.dtype or str): dtype to use to generate the embedding.
|
||||
Returns:
|
||||
torch.Tensor: Sinusoidal positional embedding.
|
||||
"""
|
||||
# We aim for BTC format
|
||||
assert dim % 2 == 0
|
||||
half_dim = dim // 2
|
||||
positions = positions.to(dtype)
|
||||
adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
|
||||
max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype) # avoid sync point
|
||||
phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
|
||||
return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
|
||||
|
||||
|
||||
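# Worked example (sketch): positions shaped [B, T, 1] produce an embedding of
# shape [B, T, dim], cosine terms in the first half, sine terms in the second.
def _sketch_sin_embedding():
    positions = torch.arange(8).view(1, -1, 1)      # [1, 8, 1]
    emb = create_sin_embedding(positions, dim=16)   # [1, 8, 16]
    assert emb.shape == (1, 8, 16)
    return emb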
def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers"""
|
||||
if n_rep == 1:
|
||||
return x
|
||||
if _efficient_attention_backend == 'torch':
|
||||
bs, n_kv_heads, slen, head_dim = x.shape
|
||||
return (
|
||||
x[:, :, None, :, :]
|
||||
.expand(bs, n_kv_heads, n_rep, slen, head_dim)
|
||||
.reshape(bs, n_kv_heads * n_rep, slen, head_dim)
|
||||
)
|
||||
else:
|
||||
bs, slen, n_kv_heads, head_dim = x.shape
|
||||
return (
|
||||
x[:, :, :, None, :]
|
||||
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
|
||||
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
|
||||
)
|
||||
|
||||
|
||||
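# Quick check (sketch), assuming the default 'torch' backend layout
# [B, n_kv_heads, T, head_dim]: repeating 2 kv heads 3 times yields 6 heads and
# matches torch.repeat_interleave on the head dimension.
def _sketch_expand_repeated_kv():
    x = torch.randn(1, 2, 5, 4)
    y = expand_repeated_kv(x, n_rep=3)              # [1, 6, 5, 4]
    assert torch.equal(y, torch.repeat_interleave(x, 3, dim=1))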
class LayerScale(nn.Module):
|
||||
"""Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
|
||||
This diagonally rescales the residual outputs close to 0, with a learnt scale.
|
||||
|
||||
Args:
|
||||
channels (int): Number of channels.
|
||||
init (float): Initial scale.
|
||||
channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`.
|
||||
device (torch.device or None): Device on which to initialize the module.
|
||||
dtype (torch.dtype or None): dtype to use to initialize the module.
|
||||
"""
|
||||
def __init__(self, channels: int, init: float = 1e-4, channel_last: bool = True,
|
||||
device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.channel_last = channel_last
|
||||
self.scale = nn.Parameter(
|
||||
torch.full((channels,), init,
|
||||
requires_grad=True, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
if self.channel_last:
|
||||
return self.scale * x
|
||||
else:
|
||||
return self.scale[:, None] * x
|
||||
|
||||
|
||||
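# Usage sketch: LayerScale starts near zero, so a freshly initialized residual
# branch contributes almost nothing and can grow during training.
def _sketch_layer_scale():
    ls = LayerScale(channels=8, init=1e-4)
    x = torch.randn(2, 3, 8)                        # [B, T, C], channel_last
    assert ls(x).abs().max() <= 1e-4 * x.abs().max() + 1e-12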
class StreamingMultiheadAttention(StreamingModule):
|
||||
"""Similar to `nn.MultiheadAttention` but with support for streaming, causal evaluation.
|
||||
|
||||
Args:
|
||||
embed_dim (int): Dimension to project to.
|
||||
num_heads (int): Number of heads.
|
||||
dropout (float): Dropout level.
|
||||
bias (bool): Use bias in projections.
|
||||
causal (bool): Causal mask applied automatically.
|
||||
past_context (int or None): Receptive field for the causal mask, infinite if None.
|
||||
custom (bool): Use custom MHA implementation, for testing / benchmarking.
|
||||
memory_efficient (bool): Use xformers based memory efficient attention.
|
||||
attention_as_float32 (bool): Perform the attention as float32
|
||||
(especially important with memory_efficient as autocast won't do this automatically).
|
||||
rope (`RotaryEmbedding` or None): Rope embedding to use.
|
||||
cross_attention: Should be true when used as a cross attention.
|
||||
All keys and values must be available at once, streaming is only for the queries.
|
||||
Cannot be used with `causal` or `rope` (as it wouldn't make sense to
|
||||
interpret the time steps in the keys relative to those in the queries).
|
||||
safe_streaming (bool): Bug fix, will go away with xformers update.
|
||||
qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product.
|
||||
kv_repeat (int): If > 1, will repeat keys and values multiple times (must divide num_heads).
|
||||
This will lead to faster decoding time on A100 or other GPUs with tensorcore.
|
||||
device (torch.device or None): Device on which to initialize.
|
||||
dtype (torch.dtype or None): dtype to use.
|
||||
"""
|
||||
def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True,
|
||||
causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
|
||||
memory_efficient: bool = False, attention_as_float32: bool = False,
|
||||
rope: tp.Optional[RotaryEmbedding] = None, cross_attention: bool = False,
|
||||
safe_streaming: bool = True, qk_layer_norm: bool = False, kv_repeat: int = 1,
|
||||
device=None, dtype=None):
|
||||
super().__init__()
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
if past_context is not None:
|
||||
assert causal
|
||||
|
||||
self.embed_dim = embed_dim
|
||||
self.causal = causal
|
||||
self.past_context = past_context
|
||||
self.memory_efficient = memory_efficient
|
||||
self.attention_as_float32 = attention_as_float32
|
||||
self.rope = rope
|
||||
self.cross_attention = cross_attention
|
||||
self.safe_streaming = safe_streaming
|
||||
self.num_heads = num_heads
|
||||
self.dropout = dropout
|
||||
self.kv_repeat = kv_repeat
|
||||
if cross_attention:
|
||||
assert not causal, "Causal cannot work with cross attention."
|
||||
assert rope is None, "Rope cannot work with cross attention."
|
||||
|
||||
if memory_efficient:
|
||||
_verify_xformers_memory_efficient_compat()
|
||||
|
||||
self.custom = _is_custom(custom, memory_efficient)
|
||||
if self.custom:
|
||||
out_dim = embed_dim
|
||||
assert num_heads % kv_repeat == 0
|
||||
assert not cross_attention or kv_repeat == 1
|
||||
num_kv = num_heads // kv_repeat
|
||||
kv_dim = (embed_dim // num_heads) * num_kv
|
||||
out_dim += 2 * kv_dim
|
||||
in_proj = nn.Linear(embed_dim, out_dim, bias=bias, **factory_kwargs)
|
||||
# We try to follow the default PyTorch MHA convention, to easily compare results.
|
||||
self.in_proj_weight = in_proj.weight
|
||||
self.in_proj_bias = in_proj.bias
|
||||
if bias:
|
||||
self.in_proj_bias.data.zero_() # Following Pytorch convention
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
|
||||
if bias:
|
||||
self.out_proj.bias.data.zero_()
|
||||
else:
|
||||
assert not qk_layer_norm
|
||||
assert kv_repeat == 1
|
||||
self.mha = nn.MultiheadAttention(
|
||||
embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True,
|
||||
**factory_kwargs)
|
||||
self.qk_layer_norm = qk_layer_norm
|
||||
if qk_layer_norm:
|
||||
assert self.custom
|
||||
assert kv_repeat == 1
|
||||
ln_dim = embed_dim
|
||||
self.q_layer_norm = nn.LayerNorm(ln_dim)
|
||||
self.k_layer_norm = nn.LayerNorm(ln_dim)
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
|
||||
if not self.custom:
|
||||
# Support compat with regular MHA
|
||||
keys = [n for n, _ in self.mha.named_parameters()]
|
||||
for key in keys:
|
||||
if prefix + key in state_dict:
|
||||
state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
|
||||
super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
|
||||
|
||||
def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype):
|
||||
# Return a causal mask, accounting for potentially stored past keys/values
|
||||
# We actually return a bias for the attention score, as this has the same
|
||||
# convention both in the builtin MHA in Pytorch, and Xformers functions.
|
||||
time_dim = _get_attention_time_dimension()
|
||||
if self.memory_efficient:
|
||||
from xformers.ops import LowerTriangularMask
|
||||
if current_steps == 1:
|
||||
# If we only have one step, then we do not need a mask.
|
||||
return None
|
||||
elif 'past_keys' in self._streaming_state:
|
||||
raise RuntimeError('Not supported at the moment')
|
||||
else:
|
||||
# Then we can safely use a lower triangular mask
|
||||
return LowerTriangularMask()
|
||||
if self._streaming_state:
|
||||
past_keys = self._streaming_state['past_keys']
|
||||
past_steps = past_keys.shape[time_dim]
|
||||
else:
|
||||
past_steps = 0
|
||||
|
||||
queries_pos = torch.arange(
|
||||
past_steps, current_steps + past_steps, device=device).view(-1, 1)
|
||||
keys_pos = torch.arange(past_steps + current_steps, device=device).view(1, -1)
|
||||
delta = queries_pos - keys_pos
|
||||
valid = delta >= 0
|
||||
if self.past_context is not None:
|
||||
valid &= (delta <= self.past_context)
|
||||
return torch.where(
|
||||
valid,
|
||||
torch.zeros([], device=device, dtype=dtype),
|
||||
torch.full([], float('-inf'), device=device, dtype=dtype))
|
||||
|
||||
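# Worked example (sketch): with past_steps=2 and current_steps=2, `delta`
# (queries_pos - keys_pos) is
#   [[2, 1, 0, -1],
#    [3, 2, 1, 0]]
# so `valid = delta >= 0` masks exactly the future keys, and the returned bias
# is 0 where attending is allowed and -inf elsewhere.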
def _complete_kv(self, k, v):
|
||||
time_dim = _get_attention_time_dimension()
|
||||
if self.cross_attention:
|
||||
# With cross attention we assume all keys and values
|
||||
# are already available, and streaming is with respect
|
||||
# to the queries only.
|
||||
return k, v
|
||||
# Complete the key/value pair using the streaming state.
|
||||
if self._streaming_state:
|
||||
pk = self._streaming_state['past_keys']
|
||||
nk = torch.cat([pk, k], dim=time_dim)
|
||||
if v is k:
|
||||
nv = nk
|
||||
else:
|
||||
pv = self._streaming_state['past_values']
|
||||
nv = torch.cat([pv, v], dim=time_dim)
|
||||
else:
|
||||
nk = k
|
||||
nv = v
|
||||
|
||||
assert nk.shape[time_dim] == nv.shape[time_dim]
|
||||
offset = 0
|
||||
if self.past_context is not None:
|
||||
offset = max(0, nk.shape[time_dim] - self.past_context)
|
||||
if self._is_streaming:
|
||||
self._streaming_state['past_keys'] = nk[:, offset:]
|
||||
if v is not k:
|
||||
self._streaming_state['past_values'] = nv[:, offset:]
|
||||
if 'offset' in self._streaming_state:
|
||||
self._streaming_state['offset'] += offset
|
||||
else:
|
||||
self._streaming_state['offset'] = torch.tensor(offset)  # count steps dropped on the very first chunk too
|
||||
return nk, nv
|
||||
|
||||
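# Worked example (sketch): while streaming with past_context=4, two successive
# 3-step chunks leave min(3 + 3, 4) = 4 keys in the cache; `offset` records the
# 2 dropped steps so downstream consumers (e.g. rope) keep absolute positions.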
def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
|
||||
# TODO: fix and verify layout.
|
||||
assert _efficient_attention_backend == 'xformers', 'Rope not supported with torch attn.'
|
||||
# Apply rope embeddings to query and key tensors.
|
||||
assert self.rope is not None
|
||||
if 'past_keys' in self._streaming_state:
|
||||
past_keys_offset = self._streaming_state['past_keys'].shape[1]
|
||||
else:
|
||||
past_keys_offset = 0
|
||||
if 'offset' in self._streaming_state:
|
||||
past_context_offset = int(self._streaming_state['offset'].item())
|
||||
else:
|
||||
past_context_offset = 0
|
||||
streaming_offset = past_context_offset + past_keys_offset
|
||||
return self.rope.rotate_qk(query, key, start=streaming_offset)
|
||||
|
||||
def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
|
||||
key_padding_mask=None, need_weights=False, attn_mask=None,
|
||||
average_attn_weights=True, is_causal=False):
|
||||
assert attn_mask is None
|
||||
assert not is_causal, ("new param added in torch 2.0.1 not supported, "
|
||||
"use the causal args in the constructor.")
|
||||
|
||||
time_dim = _get_attention_time_dimension()
|
||||
if time_dim == 2:
|
||||
layout = "b h t d"
|
||||
else:
|
||||
layout = "b t h d"
|
||||
dtype = query.dtype
|
||||
if self._is_streaming:
|
||||
assert self.causal or self.cross_attention, \
|
||||
"Streaming only available for causal or cross attention"
|
||||
|
||||
if self.causal:
|
||||
# At the moment we specialize only for the self-attention case.
|
||||
assert query.shape[1] == key.shape[1], "Causal only for same length query / key / value"
|
||||
assert value.shape[1] == key.shape[1], "Causal only for same length query / key / value"
|
||||
attn_mask = self._get_mask(query.shape[1], query.device, query.dtype)
|
||||
|
||||
if self.custom:
|
||||
# custom implementation
|
||||
assert need_weights is False
|
||||
assert key_padding_mask is None
|
||||
if self.cross_attention:
|
||||
# Different queries, keys, values, so we have to split the weights manually
|
||||
# before applying the linear.
|
||||
dim = self.in_proj_weight.shape[0] // 3
|
||||
if self.in_proj_bias is None:
|
||||
bias_q, bias_k, bias_v = None, None, None
|
||||
else:
|
||||
bias_q = self.in_proj_bias[:dim]
|
||||
bias_k = self.in_proj_bias[dim: 2 * dim]
|
||||
bias_v = self.in_proj_bias[2 * dim:]
|
||||
q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
|
||||
# todo: when streaming, we could actually save k, v and check that the shapes actually match.
|
||||
k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k)
|
||||
v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v)
|
||||
if self.qk_layer_norm is True:
|
||||
q = self.q_layer_norm(q)
|
||||
k = self.k_layer_norm(k)
|
||||
q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
|
||||
else:
|
||||
if not _is_profiled():
|
||||
# profiling breaks that property somehow.
|
||||
assert query is key, "specialized implementation"
|
||||
assert value is key, "specialized implementation"
|
||||
projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
|
||||
if self.kv_repeat == 1:
|
||||
if time_dim == 2:
|
||||
bound_layout = "b h p t d"
|
||||
else:
|
||||
bound_layout = "b t p h d"
|
||||
packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
|
||||
q, k, v = ops.unbind(packed, dim=2)
|
||||
else:
|
||||
embed_dim = self.embed_dim
|
||||
per_head_dim = (embed_dim // self.num_heads)
|
||||
kv_heads = self.num_heads // self.kv_repeat
|
||||
q = projected[:, :, :embed_dim]
|
||||
start = embed_dim
|
||||
end = start + per_head_dim * kv_heads
|
||||
k = projected[:, :, start: end]
|
||||
v = projected[:, :, end:]
|
||||
q = rearrange(q, f"b t (h d) -> {layout}", h=self.num_heads)
|
||||
k = rearrange(k, f"b t (h d) -> {layout}", h=kv_heads)
|
||||
v = rearrange(v, f"b t (h d) -> {layout}", h=kv_heads)
|
||||
|
||||
if self.qk_layer_norm is True:
|
||||
assert self.kv_repeat == 1
|
||||
q, k = [rearrange(x, f"{layout} -> b t (h d)") for x in [q, k]]
|
||||
q = self.q_layer_norm(q)
|
||||
k = self.k_layer_norm(k)
|
||||
q, k = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k]]
|
||||
if self.rope:
|
||||
q, k = self._apply_rope(q, k)
|
||||
k, v = self._complete_kv(k, v)
|
||||
if self.kv_repeat > 1:
|
||||
k = expand_repeated_kv(k, self.kv_repeat)
|
||||
v = expand_repeated_kv(v, self.kv_repeat)
|
||||
if self.attention_as_float32:
|
||||
q, k, v = [x.float() for x in [q, k, v]]
|
||||
if self.memory_efficient:
|
||||
p = self.dropout if self.training else 0
|
||||
if _efficient_attention_backend == 'torch':
|
||||
x = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v, is_causal=attn_mask is not None, dropout_p=p)
|
||||
else:
|
||||
x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p)
|
||||
else:
|
||||
# We include the dot product as float32, for consistency
|
||||
# with the other implementations that include that step
|
||||
# as part of the attention. Note that when using `autocast`,
|
||||
# the einsums would be done as bfloat16, but the softmax
|
||||
# would be done as float32, so `attention_as_float32` will
|
||||
# extend a bit the range of operations done in float32,
|
||||
# although this should make no difference.
|
||||
q = q / q.shape[-1] ** 0.5
|
||||
key_layout = layout.replace('t', 'k')
|
||||
query_layout = layout
|
||||
if self._is_streaming and self.safe_streaming and q.device.type == 'cuda':
|
||||
with torch.autocast(device_type=q.device.type, dtype=torch.float32):
|
||||
pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
|
||||
else:
|
||||
pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
|
||||
if attn_mask is not None:
|
||||
pre_w = pre_w + attn_mask
|
||||
w = torch.softmax(pre_w, dim=-1)
|
||||
w = F.dropout(w, self.dropout, training=self.training).to(v)
|
||||
# Key and value have the same format.
|
||||
x = torch.einsum(f"b h t k, {key_layout} -> {layout}", w, v)
|
||||
x = x.to(dtype)
|
||||
x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
|
||||
x = self.out_proj(x)
|
||||
else:
|
||||
key, value = self._complete_kv(key, value)
|
||||
if self.attention_as_float32:
|
||||
query, key, value = [x.float() for x in [query, key, value]]
|
||||
x, _ = self.mha(
|
||||
query, key, value, key_padding_mask,
|
||||
need_weights, attn_mask, average_attn_weights)
|
||||
x = x.to(dtype)
|
||||
|
||||
return x, None
|
||||
|
||||
|
||||
class StreamingTransformerLayer(nn.TransformerEncoderLayer):
|
||||
"""TransformerLayer with Streaming / Causal support.
|
||||
This also integrates cross_attention, when passing `cross_attention=True`,
|
||||
rather than having two separate classes like in PyTorch.
|
||||
|
||||
Args:
|
||||
d_model (int): Dimension of the data.
|
||||
num_heads (int): Number of heads.
|
||||
dim_feedforward (int): Intermediate dimension of FF module.
|
||||
dropout (float): Dropout both for MHA and FF.
|
||||
bias_ff (bool): Use bias for FF.
|
||||
bias_attn (bool): Use bias for MHA.
|
||||
causal (bool): Causal mask applied automatically.
|
||||
past_context (int or None): Receptive field for the causal mask, infinite if None.
|
||||
custom (bool): Use custom MHA implementation, for testing / benchmarking.
|
||||
memory_efficient (bool): Use xformers based memory efficient attention.
|
||||
attention_as_float32 (bool): Perform the attention as float32
|
||||
(especially important with memory_efficient as autocast won't do this automatically).
|
||||
qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product in attention.
|
||||
qk_layer_norm_cross (bool): Same for the cross attention.
|
||||
cross_attention (bool): If True, expect to get secondary input for cross-attention.
|
||||
Cross attention will use the default MHA, as it typically won't require
|
||||
special treatment.
|
||||
layer_scale (float or None): If not None, LayerScale will be used with
|
||||
the given value as initial scale.
|
||||
rope (`RotaryEmbedding` or None): Rope embedding to use.
|
||||
attention_dropout (float or None): If not None, use a separate dropout value
|
||||
for the attention, distinct from the FFN dropout.
|
||||
kv_repeat (int): If > 1, will repeat keys and values multiple times (must divide num_heads).
|
||||
This will lead to faster decoding time on A100 or other GPUs with tensorcore.
|
||||
device (torch.device or None): Device on which to initialize.
|
||||
dtype (torch.dtype or None): dtype to use.
|
||||
**kwargs: See `nn.TransformerEncoderLayer`.
|
||||
"""
|
||||
def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1,
|
||||
bias_ff: bool = True, bias_attn: bool = True, causal: bool = False,
|
||||
past_context: tp.Optional[int] = None, custom: bool = False,
|
||||
memory_efficient: bool = False, attention_as_float32: bool = False,
|
||||
qk_layer_norm: bool = False, qk_layer_norm_cross: bool = False,
|
||||
cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
|
||||
rope: tp.Optional[RotaryEmbedding] = None, attention_dropout: tp.Optional[float] = None,
|
||||
kv_repeat: int = 1, norm: str = 'layer_norm', device=None, dtype=None, **kwargs):
|
||||
super().__init__(d_model, num_heads, dim_feedforward, dropout,
|
||||
device=device, dtype=dtype, batch_first=True, **kwargs)
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
# Redefine self_attn to our streaming multi-head attention
|
||||
attn_kwargs: tp.Dict[str, tp.Any] = {
|
||||
'embed_dim': d_model,
|
||||
'num_heads': num_heads,
|
||||
'dropout': dropout if attention_dropout is None else attention_dropout,
|
||||
'bias': bias_attn,
|
||||
'custom': custom,
|
||||
'memory_efficient': memory_efficient,
|
||||
'attention_as_float32': attention_as_float32,
|
||||
}
|
||||
self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention(
|
||||
causal=causal, past_context=past_context, rope=rope, qk_layer_norm=qk_layer_norm,
|
||||
kv_repeat=kv_repeat, **attn_kwargs, **factory_kwargs) # type: ignore
|
||||
# Redefine feedforward layers to expose bias parameter
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)
|
||||
|
||||
self.layer_scale_1: nn.Module
|
||||
self.layer_scale_2: nn.Module
|
||||
if layer_scale is None:
|
||||
self.layer_scale_1 = nn.Identity()
|
||||
self.layer_scale_2 = nn.Identity()
|
||||
else:
|
||||
self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs)
|
||||
self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs)
|
||||
|
||||
self.cross_attention: tp.Optional[nn.Module] = None
|
||||
if cross_attention:
|
||||
self.cross_attention = StreamingMultiheadAttention(
|
||||
cross_attention=True, qk_layer_norm=qk_layer_norm_cross,
|
||||
**attn_kwargs, **factory_kwargs)
|
||||
# Norm and dropout
|
||||
self.dropout_cross = nn.Dropout(dropout)
|
||||
# eps value matching that used in PyTorch reference implementation.
|
||||
self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
|
||||
self.layer_scale_cross: nn.Module
|
||||
if layer_scale is None:
|
||||
self.layer_scale_cross = nn.Identity()
|
||||
else:
|
||||
self.layer_scale_cross = LayerScale(d_model, layer_scale, **factory_kwargs)
|
||||
self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore
|
||||
self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore
|
||||
|
||||
def _cross_attention_block(self, src: torch.Tensor,
|
||||
cross_attention_src: torch.Tensor) -> torch.Tensor:
|
||||
assert self.cross_attention is not None
|
||||
# queries are from src, keys and values from cross_attention_src.
|
||||
x = self.cross_attention(
|
||||
src, cross_attention_src, cross_attention_src, need_weights=False)[0]
|
||||
return self.dropout_cross(x) # type: ignore
|
||||
|
||||
def forward(self, src: torch.Tensor, src_mask: tp.Optional[torch.Tensor] = None, # type: ignore
|
||||
src_key_padding_mask: tp.Optional[torch.Tensor] = None,
|
||||
cross_attention_src: tp.Optional[torch.Tensor] = None):
|
||||
if self.cross_attention is None:
|
||||
assert cross_attention_src is None
|
||||
else:
|
||||
assert cross_attention_src is not None
|
||||
x = src
|
||||
if self.norm_first:
|
||||
x = x + self.layer_scale_1(
|
||||
self._sa_block(self.norm1(x), src_mask, src_key_padding_mask))
|
||||
if cross_attention_src is not None:
|
||||
x = x + self.layer_scale_cross(
|
||||
self._cross_attention_block(
|
||||
self.norm_cross(x), cross_attention_src))
|
||||
x = x + self.layer_scale_2(self._ff_block(self.norm2(x)))
|
||||
else:
|
||||
x = self.norm1(x + self.layer_scale_1(
|
||||
self._sa_block(x, src_mask, src_key_padding_mask)))
|
||||
if cross_attention_src is not None:
|
||||
x = self.norm_cross(
|
||||
x + self.layer_scale_cross(
|
||||
self._cross_attention_block(src, cross_attention_src)))
|
||||
x = self.norm2(x + self.layer_scale_2(self._ff_block(x)))
|
||||
return x
|
||||
|
||||
|
||||
class StreamingTransformer(StreamingModule):
|
||||
"""Transformer with Streaming / Causal support.
|
||||
|
||||
Args:
|
||||
d_model (int): Dimension of the data.
|
||||
num_heads (int): Number of heads.
|
||||
dim_feedforward (int): Intermediate dimension of FF module.
|
||||
dropout (float): Dropout both for MHA and FF.
|
||||
bias_ff (bool): Use bias for FF.
|
||||
bias_attn (bool): Use bias for MHA.
|
||||
causal (bool): Causal mask applied automatically.
|
||||
past_context (int or None): Receptive field for the causal mask, infinite if None.
|
||||
custom (bool): Use custom MHA implementation, for testing / benchmarking.
|
||||
memory_efficient (bool): Use xformers based memory efficient attention.
|
||||
attention_as_float32 (bool): Perform the attention as float32
|
||||
(especially important with memory_efficient as autocast won't do this automatically).
|
||||
cross_attention (bool): If True, expect to get secondary input for cross-attention.
|
||||
layer_scale (float or None): If not None, LayerScale will be used
|
||||
with the given value as initial scale.
|
||||
positional_embedding (str): Positional embedding strategy (sin, rope, or sin_rope).
|
||||
max_period (float): Maximum period of the time embedding.
|
||||
positional_scale (float): Scale of positional embedding, set to 0 to deactivate.
|
||||
xpos (bool): Apply xpos exponential decay to positional embedding (rope only).
|
||||
lr (float or None): learning rate override through the `make_optim_group` API.
|
||||
weight_decay (float or None): Weight decay override through the `make_optim_group` API.
|
||||
layer_class (subclass of `StreamingTransformerLayer`): class to use
|
||||
to initialize the layers, allowing further customization outside of Audiocraft.
|
||||
checkpointing (str): Checkpointing strategy to reduce memory usage.
|
||||
No checkpointing if set to 'none'. Per layer checkpointing using PyTorch
|
||||
if set to 'torch' (entire layer checkpointed, i.e. linears are evaluated twice,
|
||||
minimal memory usage, but maximal runtime). Finally, `xformers_default` provides
|
||||
a policy for opting-out some operations of the checkpointing like
|
||||
linear layers and attention, providing a middle ground between speed and memory.
|
||||
device (torch.device or None): Device on which to initialize.
|
||||
dtype (torch.dtype or None): dtype to use.
|
||||
**kwargs: See `nn.TransformerEncoderLayer`.
|
||||
"""
|
||||
def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048,
|
||||
dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True,
|
||||
causal: bool = False, past_context: tp.Optional[int] = None,
|
||||
custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False,
|
||||
cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
|
||||
positional_embedding: str = 'sin', max_period: float = 10_000, positional_scale: float = 1.,
|
||||
xpos: bool = False, lr: tp.Optional[float] = None, weight_decay: tp.Optional[float] = None,
|
||||
layer_class: tp.Type[StreamingTransformerLayer] = StreamingTransformerLayer,
|
||||
checkpointing: str = 'none', device=None, dtype=None, **kwargs):
|
||||
super().__init__()
|
||||
assert d_model % num_heads == 0
|
||||
|
||||
self.positional_embedding = positional_embedding
|
||||
self.max_period = max_period
|
||||
self.positional_scale = positional_scale
|
||||
self.weight_decay = weight_decay
|
||||
self.lr = lr
|
||||
|
||||
assert positional_embedding in ['sin', 'rope', 'sin_rope']
|
||||
self.rope: tp.Optional[RotaryEmbedding] = None
|
||||
if self.positional_embedding in ['rope', 'sin_rope']:
|
||||
assert _is_custom(custom, memory_efficient)
|
||||
self.rope = RotaryEmbedding(d_model // num_heads, max_period=max_period,
|
||||
xpos=xpos, scale=positional_scale, device=device)
|
||||
|
||||
self.checkpointing = checkpointing
|
||||
|
||||
assert checkpointing in ['none', 'torch', 'xformers_default', 'xformers_mm']
|
||||
if self.checkpointing.startswith('xformers'):
|
||||
_verify_xformers_internal_compat()
|
||||
|
||||
self.layers = nn.ModuleList()
|
||||
for idx in range(num_layers):
|
||||
self.layers.append(
|
||||
layer_class(
|
||||
d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward,
|
||||
dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn,
|
||||
causal=causal, past_context=past_context, custom=custom,
|
||||
memory_efficient=memory_efficient, attention_as_float32=attention_as_float32,
|
||||
cross_attention=cross_attention, layer_scale=layer_scale, rope=self.rope,
|
||||
device=device, dtype=dtype, **kwargs))
|
||||
|
||||
if self.checkpointing != 'none':
|
||||
for layer in self.layers:
|
||||
# see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the
|
||||
# backward hook inside of FSDP...
|
||||
layer._magma_checkpointed = True # type: ignore
|
||||
assert layer.layer_drop == 0., "Need further checking" # type: ignore
|
||||
|
||||
def _apply_layer(self, layer, *args, **kwargs):
|
||||
method = self.checkpointing
|
||||
if method == 'none':
|
||||
return layer(*args, **kwargs)
|
||||
elif method == 'torch':
|
||||
return torch_checkpoint(layer, *args, use_reentrant=False, **kwargs)
|
||||
elif method.startswith('xformers'):
|
||||
from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy
|
||||
if method == 'xformers_default':
|
||||
# those operations will be saved, and not recomputed.
|
||||
# According to Francisco we can get smarter policies but this is a good start.
|
||||
allow_list = [
|
||||
"xformers.efficient_attention_forward_cutlass.default",
|
||||
"xformers_flash.flash_fwd.default",
|
||||
"aten.addmm.default",
|
||||
"aten.mm.default",
|
||||
]
|
||||
elif method == 'xformers_mm':
|
||||
# those operations will be saved, and not recomputed.
|
||||
# According to Francisco we can get smarter policies but this is a good start.
|
||||
allow_list = [
|
||||
"aten.addmm.default",
|
||||
"aten.mm.default",
|
||||
]
|
||||
else:
|
||||
raise ValueError(f"xformers checkpointing xformers policy {method} is not known.")
|
||||
policy_fn = _get_default_policy(allow_list)
|
||||
return checkpoint(layer, *args, policy_fn=policy_fn, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"Checkpointing method {method} is unknown.")
|
||||
|
||||
def forward(self, x: torch.Tensor, *args, **kwargs):
|
||||
B, T, C = x.shape
|
||||
|
||||
if 'offsets' in self._streaming_state:
|
||||
offsets = self._streaming_state['offsets']
|
||||
else:
|
||||
offsets = torch.zeros(B, dtype=torch.long, device=x.device)
|
||||
|
||||
if self.positional_embedding in ['sin', 'sin_rope']:
|
||||
positions = torch.arange(T, device=x.device).view(1, -1, 1)
|
||||
positions = positions + offsets.view(-1, 1, 1)
|
||||
pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
|
||||
x = x + self.positional_scale * pos_emb
|
||||
|
||||
for layer in self.layers:
|
||||
x = self._apply_layer(layer, x, *args, **kwargs)
|
||||
|
||||
if self._is_streaming:
|
||||
self._streaming_state['offsets'] = offsets + T
|
||||
|
||||
return x
|
||||
|
||||
def make_optim_group(self):
|
||||
group = {"params": list(self.parameters())}
|
||||
if self.lr is not None:
|
||||
group["lr"] = self.lr
|
||||
if self.weight_decay is not None:
|
||||
group["weight_decay"] = self.weight_decay
|
||||
return group
|
||||
|
||||
|
||||
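# Usage sketch (not part of the original file, assumes xformers is installed):
# for a causal transformer in eval mode, chunked streaming inference should
# match full-sequence evaluation up to numerical noise. All sizes below are
# hypothetical.
def _sketch_streaming_equivalence():
    tr = StreamingTransformer(d_model=32, num_heads=4, num_layers=2,
                              causal=True, custom=True).eval()
    x = torch.randn(1, 8, 32)
    with torch.no_grad():
        full = tr(x)
        with tr.streaming():
            chunked = torch.cat([tr(x[:, :4]), tr(x[:, 4:])], dim=1)
    assert torch.allclose(full, chunked, atol=1e-5)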
# special attention related functions
|
||||
|
||||
def _verify_xformers_memory_efficient_compat():
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention, LowerTriangularMask # noqa
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"xformers is not installed. Please install it and try again.\n"
|
||||
"To install on AWS and Azure, run \n"
|
||||
"FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n"
|
||||
"pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n"
|
||||
"To install on FAIR Cluster, run \n"
|
||||
"FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n"
|
||||
"pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n")
|
||||
|
||||
|
||||
def _verify_xformers_internal_compat():
|
||||
try:
|
||||
from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy # noqa
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Francisco's fairinternal xformers is not installed. Please install it and try again.\n"
|
||||
"To install on AWS and Azure, run \n"
|
||||
"FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n"
|
||||
"pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n"
|
||||
"To install on FAIR Cluster, run \n"
|
||||
"FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n"
|
||||
"pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n")
|
||||
|
||||
|
||||
def _is_custom(custom: bool, memory_efficient: bool):
|
||||
return custom or memory_efficient
|
|
@@ -0,0 +1,9 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# flake8: noqa
|
||||
from .vq import ResidualVectorQuantizer
|
||||
from .base import BaseQuantizer, DummyQuantizer, QuantizedResult
|
|
@@ -0,0 +1,107 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Base class for all quantizers.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
import typing as tp
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuantizedResult:
|
||||
x: torch.Tensor
|
||||
codes: torch.Tensor
|
||||
bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
|
||||
penalty: tp.Optional[torch.Tensor] = None
|
||||
metrics: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
class BaseQuantizer(nn.Module):
|
||||
"""Base class for quantizers.
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
|
||||
"""
|
||||
Given input tensor x, returns first the quantized (or approximately quantized)
|
||||
representation along with quantized codes, bandwidth, and any penalty term for the loss.
|
||||
Finally, this returns a dict of metrics to update logging etc.
|
||||
Frame rate must be passed so that the bandwidth is properly computed.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Encode a given input tensor with the specified sample rate at the given bandwidth.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def decode(self, codes: torch.Tensor) -> torch.Tensor:
|
||||
"""Decode the given codes to the quantized representation.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def total_codebooks(self):
|
||||
"""Total number of codebooks.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def num_codebooks(self):
|
||||
"""Number of active codebooks.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def set_num_codebooks(self, n: int):
|
||||
"""Set the number of active codebooks.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DummyQuantizer(BaseQuantizer):
|
||||
"""Fake quantizer that actually does not perform any quantization.
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x: torch.Tensor, frame_rate: int):
|
||||
q = x.unsqueeze(1)
|
||||
return QuantizedResult(x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x))
|
||||
|
||||
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Encode a given input tensor with the specified sample rate at the given bandwidth.
|
||||
In the case of the DummyQuantizer, the codes are actually identical
|
||||
to the input and resulting quantized representation as no quantization is done.
|
||||
"""
|
||||
return x.unsqueeze(1)
|
||||
|
||||
def decode(self, codes: torch.Tensor) -> torch.Tensor:
|
||||
"""Decode the given codes to the quantized representation.
|
||||
In the case of the DummyQuantizer, the codes are actually identical
|
||||
to the input and resulting quantized representation as no quantization is done.
|
||||
"""
|
||||
return codes.squeeze(1)
|
||||
|
||||
@property
|
||||
def total_codebooks(self):
|
||||
"""Total number of codebooks.
|
||||
"""
|
||||
return 1
|
||||
|
||||
@property
|
||||
def num_codebooks(self):
|
||||
"""Total number of codebooks.
|
||||
"""
|
||||
return self.total_codebooks
|
||||
|
||||
def set_num_codebooks(self, n: int):
|
||||
"""Set the number of active codebooks.
|
||||
"""
|
||||
raise AttributeError("Cannot override the number of codebooks for the dummy quantizer")
|
|
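# Worked example (sketch): for the DummyQuantizer the bandwidth formula above
# yields q.numel() * 32 * frame_rate / 1000 / B. With x of shape [B=2, C=4,
# T=50] and frame_rate=50, q.numel() = 400, so the result is 320.0.
def _sketch_dummy_bandwidth():
    res = DummyQuantizer()(torch.randn(2, 4, 50), frame_rate=50)
    assert float(res.bandwidth) == 320.0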
@@ -0,0 +1,400 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import typing as tp
|
||||
|
||||
from einops import rearrange, repeat
|
||||
import flashy
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def exists(val: tp.Optional[tp.Any]) -> bool:
|
||||
return val is not None
|
||||
|
||||
|
||||
def default(val: tp.Any, d: tp.Any) -> tp.Any:
|
||||
return val if exists(val) else d
|
||||
|
||||
|
||||
def l2norm(t):
|
||||
return F.normalize(t, p=2, dim=-1)
|
||||
|
||||
|
||||
def ema_inplace(moving_avg, new, decay: float):
|
||||
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
||||
|
||||
|
||||
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
|
||||
return (x + epsilon) / (x.sum() + n_categories * epsilon)
|
||||
|
||||
|
||||
def uniform_init(*shape: int):
|
||||
t = torch.empty(shape)
|
||||
nn.init.kaiming_uniform_(t)
|
||||
return t
|
||||
|
||||
|
||||
def sample_vectors(samples, num: int):
|
||||
num_samples, device = samples.shape[0], samples.device
|
||||
|
||||
if num_samples >= num:
|
||||
indices = torch.randperm(num_samples, device=device)[:num]
|
||||
else:
|
||||
indices = torch.randint(0, num_samples, (num,), device=device)
|
||||
|
||||
return samples[indices]
|
||||
|
||||
|
||||
def kmeans(samples, num_clusters: int, num_iters: int = 10):
|
||||
dim, dtype = samples.shape[-1], samples.dtype
|
||||
|
||||
means = sample_vectors(samples, num_clusters)
|
||||
|
||||
for _ in range(num_iters):
|
||||
diffs = rearrange(samples, "n d -> n () d") - rearrange(
|
||||
means, "c d -> () c d"
|
||||
)
|
||||
dists = -(diffs ** 2).sum(dim=-1)
|
||||
|
||||
buckets = dists.max(dim=-1).indices
|
||||
bins = torch.bincount(buckets, minlength=num_clusters)
|
||||
zero_mask = bins == 0
|
||||
bins_min_clamped = bins.masked_fill(zero_mask, 1)
|
||||
|
||||
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
|
||||
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
|
||||
new_means = new_means / bins_min_clamped[..., None]
|
||||
|
||||
means = torch.where(zero_mask[..., None], means, new_means)
|
||||
|
||||
return means, bins
|
||||
|
||||
|
||||
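# Usage sketch: k-means returns both the centroids and the per-cluster counts;
# clusters that end up empty keep their previous mean (see `zero_mask` above).
def _sketch_kmeans():
    samples = torch.randn(1000, 16)
    means, bins = kmeans(samples, num_clusters=8, num_iters=5)
    assert means.shape == (8, 16) and int(bins.sum()) == 1000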
def orthogonal_loss_fn(t):
|
||||
# eq (2) from https://arxiv.org/abs/2112.00384
|
||||
n = t.shape[0]
|
||||
normed_codes = l2norm(t)
|
||||
identity = torch.eye(n, device=t.device)
|
||||
cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes)
|
||||
return ((cosine_sim - identity) ** 2).sum() / (n ** 2)
|
||||
|
||||
|
||||
class EuclideanCodebook(nn.Module):
|
||||
"""Codebook with Euclidean distance.
|
||||
|
||||
Args:
|
||||
dim (int): Dimension.
|
||||
codebook_size (int): Codebook size.
|
||||
kmeans_init (bool): Whether to use k-means to initialize the codebooks.
|
||||
If set to true, run the k-means algorithm on the first training batch and use
|
||||
the learned centroids as initialization.
|
||||
kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
|
||||
decay (float): Decay for exponential moving average over the codebooks.
|
||||
epsilon (float): Epsilon value for numerical stability.
|
||||
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
||||
that have an exponential moving average cluster size less than the specified threshold with
|
||||
a randomly selected vector from the current batch.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
codebook_size: int,
|
||||
kmeans_init: bool = False,
|
||||
kmeans_iters: int = 10,
|
||||
decay: float = 0.8,
|
||||
epsilon: float = 1e-5,
|
||||
threshold_ema_dead_code: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
self.decay = decay
|
||||
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
|
||||
embed = init_fn(codebook_size, dim)
|
||||
|
||||
self.codebook_size = codebook_size
|
||||
|
||||
self.kmeans_iters = kmeans_iters
|
||||
self.epsilon = epsilon
|
||||
self.threshold_ema_dead_code = threshold_ema_dead_code
|
||||
|
||||
self.register_buffer("inited", torch.Tensor([not kmeans_init]))
|
||||
self.register_buffer("cluster_size", torch.zeros(codebook_size))
|
||||
self.register_buffer("embed", embed)
|
||||
self.register_buffer("embed_avg", embed.clone())
|
||||
|
||||
@torch.jit.ignore
|
||||
def init_embed_(self, data):
|
||||
if self.inited:
|
||||
return
|
||||
|
||||
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
|
||||
self.embed.data.copy_(embed)
|
||||
self.embed_avg.data.copy_(embed.clone())
|
||||
self.cluster_size.data.copy_(cluster_size)
|
||||
self.inited.data.copy_(torch.Tensor([True]))
|
||||
# Make sure all buffers across workers are in sync after initialization
|
||||
flashy.distrib.broadcast_tensors(self.buffers())
|
||||
|
||||
def replace_(self, samples, mask):
|
||||
modified_codebook = torch.where(
|
||||
mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
|
||||
)
|
||||
self.embed.data.copy_(modified_codebook)
|
||||
|
||||
def expire_codes_(self, batch_samples):
|
||||
if self.threshold_ema_dead_code == 0:
|
||||
return
|
||||
|
||||
expired_codes = self.cluster_size < self.threshold_ema_dead_code
|
||||
if not torch.any(expired_codes):
|
||||
return
|
||||
|
||||
batch_samples = rearrange(batch_samples, "... d -> (...) d")
|
||||
self.replace_(batch_samples, mask=expired_codes)
|
||||
flashy.distrib.broadcast_tensors(self.buffers())
|
||||
|
||||
def preprocess(self, x):
|
||||
x = rearrange(x, "... d -> (...) d")
|
||||
return x
|
||||
|
||||
def quantize(self, x):
|
||||
embed = self.embed.t()
|
||||
dist = -(
|
||||
x.pow(2).sum(1, keepdim=True)
|
||||
- 2 * x @ embed
|
||||
+ embed.pow(2).sum(0, keepdim=True)
|
||||
)
|
||||
embed_ind = dist.max(dim=-1).indices
|
||||
return embed_ind
|
||||
|
||||
def postprocess_emb(self, embed_ind, shape):
|
||||
return embed_ind.view(*shape[:-1])
|
||||
|
||||
def dequantize(self, embed_ind):
|
||||
quantize = F.embedding(embed_ind, self.embed)
|
||||
return quantize
|
||||
|
||||
def encode(self, x):
|
||||
shape = x.shape
|
||||
# pre-process
|
||||
x = self.preprocess(x)
|
||||
# quantize
|
||||
embed_ind = self.quantize(x)
|
||||
# post-process
|
||||
embed_ind = self.postprocess_emb(embed_ind, shape)
|
||||
return embed_ind
|
||||
|
||||
def decode(self, embed_ind):
|
||||
quantize = self.dequantize(embed_ind)
|
||||
return quantize
|
||||
|
||||
def forward(self, x):
|
||||
shape, dtype = x.shape, x.dtype
|
||||
x = self.preprocess(x)
|
||||
self.init_embed_(x)
|
||||
|
||||
embed_ind = self.quantize(x)
|
||||
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
|
||||
embed_ind = self.postprocess_emb(embed_ind, shape)
|
||||
quantize = self.dequantize(embed_ind)
|
||||
|
||||
if self.training:
|
||||
# We do the expiry of code at that point as buffers are in sync
|
||||
# and all the workers will take the same decision.
|
||||
self.expire_codes_(x)
|
||||
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
|
||||
embed_sum = x.t() @ embed_onehot
|
||||
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
|
||||
cluster_size = (
|
||||
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
|
||||
* self.cluster_size.sum()
|
||||
)
|
||||
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
|
||||
self.embed.data.copy_(embed_normalized)
|
||||
|
||||
return quantize, embed_ind
|
||||
|
||||
|
||||
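# Sanity sketch: `quantize` expands -||x - e||^2 as -(||x||^2 - 2 x.e + ||e||^2),
# so its argmax agrees with a brute-force nearest-neighbour search (floating
# point ties aside).
def _sketch_nearest_code():
    cb = EuclideanCodebook(dim=8, codebook_size=16)
    x = torch.randn(32, 8)
    brute = torch.cdist(x, cb.embed).argmin(dim=-1)
    assert torch.equal(cb.quantize(x), brute)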
class VectorQuantization(nn.Module):
|
||||
"""Vector quantization implementation.
|
||||
Currently supports only euclidean distance.
|
||||
|
||||
Args:
|
||||
dim (int): Dimension
|
||||
codebook_size (int): Codebook size
|
||||
codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
|
||||
decay (float): Decay for exponential moving average over the codebooks.
|
||||
epsilon (float): Epsilon value for numerical stability.
|
||||
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
|
||||
kmeans_iters (int): Number of iterations used for kmeans initialization.
|
||||
channels_last (bool): Channels are the last dimension in the input tensors.
|
||||
commitment_weight (float): Weight for commitment loss.
|
||||
orthogonal_reg_weight (float): Orthogonal regularization weight.
|
||||
orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
|
||||
orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
|
||||
for orthogonal regularization.
|
||||
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
||||
that have an exponential moving average cluster size less than the specified threshold with
|
||||
a randomly selected vector from the current batch.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
codebook_size: int,
|
||||
codebook_dim: tp.Optional[int] = None,
|
||||
decay: float = 0.8,
|
||||
epsilon: float = 1e-5,
|
||||
kmeans_init: bool = False,
|
||||
kmeans_iters: int = 10,
|
||||
threshold_ema_dead_code: int = 2,
|
||||
channels_last: bool = False,
|
||||
commitment_weight: float = 1.,
|
||||
orthogonal_reg_weight: float = 0.0,
|
||||
orthogonal_reg_active_codes_only: bool = False,
|
||||
orthogonal_reg_max_codes: tp.Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
_codebook_dim: int = default(codebook_dim, dim)
|
||||
|
||||
requires_projection = _codebook_dim != dim
|
||||
self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
|
||||
self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
|
||||
|
||||
self.epsilon = epsilon
|
||||
self.commitment_weight = commitment_weight
|
||||
|
||||
self.orthogonal_reg_weight = orthogonal_reg_weight
|
||||
self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
|
||||
self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
|
||||
|
||||
self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
|
||||
kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
|
||||
decay=decay, epsilon=epsilon,
|
||||
threshold_ema_dead_code=threshold_ema_dead_code)
|
||||
self.codebook_size = codebook_size
|
||||
|
||||
self.channels_last = channels_last
|
||||
|
||||
@property
|
||||
def codebook(self):
|
||||
return self._codebook.embed
|
||||
|
||||
@property
|
||||
def inited(self):
|
||||
return self._codebook.inited
|
||||
|
||||
def _preprocess(self, x):
|
||||
if not self.channels_last:
|
||||
x = rearrange(x, "b d n -> b n d")
|
||||
return x
|
||||
|
||||
def _postprocess(self, quantize):
|
||||
if not self.channels_last:
|
||||
quantize = rearrange(quantize, "b n d -> b d n")
|
||||
return quantize
|
||||
|
||||
def encode(self, x):
|
||||
x = self._preprocess(x)
|
||||
x = self.project_in(x)
|
||||
embed_in = self._codebook.encode(x)
|
||||
return embed_in
|
||||
|
||||
def decode(self, embed_ind):
|
||||
quantize = self._codebook.decode(embed_ind)
|
||||
quantize = self.project_out(quantize)
|
||||
quantize = self._postprocess(quantize)
|
||||
return quantize
|
||||
|
||||
def forward(self, x):
|
||||
device = x.device
|
||||
x = self._preprocess(x)
|
||||
|
||||
x = self.project_in(x)
|
||||
quantize, embed_ind = self._codebook(x)
|
||||
|
||||
if self.training:
|
||||
quantize = x + (quantize - x).detach()
|
||||
|
||||
loss = torch.tensor([0.0], device=device, requires_grad=self.training)
|
||||
|
||||
if self.training:
|
||||
if self.commitment_weight > 0:
|
||||
commit_loss = F.mse_loss(quantize.detach(), x)
|
||||
loss = loss + commit_loss * self.commitment_weight
|
||||
|
||||
if self.orthogonal_reg_weight > 0:
|
||||
codebook = self.codebook
|
||||
|
||||
if self.orthogonal_reg_active_codes_only:
|
||||
# only calculate orthogonal loss for the activated codes for this batch
|
||||
unique_code_ids = torch.unique(embed_ind)
|
||||
codebook = codebook[unique_code_ids]
|
||||
|
||||
num_codes = codebook.shape[0]
|
||||
if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
|
||||
rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes]
|
||||
codebook = codebook[rand_ids]
|
||||
|
||||
orthogonal_reg_loss = orthogonal_loss_fn(codebook)
|
||||
loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight
|
||||
|
||||
quantize = self.project_out(quantize)
|
||||
quantize = self._postprocess(quantize)
|
||||
|
||||
return quantize, embed_ind, loss
|
||||
|
||||
|
||||
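# Note (sketch): `x + (quantize - x).detach()` above is the straight-through
# estimator: the forward pass returns the code, but gradients flow to x as if
# quantization were the identity.
def _sketch_straight_through():
    x = torch.randn(4, 3, requires_grad=True)
    q = torch.randn(4, 3)   # stand-in for a codebook lookup
    out = x + (q - x).detach()
    out.sum().backward()
    assert torch.equal(x.grad, torch.ones_like(x))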
class ResidualVectorQuantization(nn.Module):
|
||||
"""Residual vector quantization implementation.
|
||||
|
||||
Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
|
||||
"""
|
||||
def __init__(self, *, num_quantizers, **kwargs):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList(
|
||||
[VectorQuantization(**kwargs) for _ in range(num_quantizers)]
|
||||
)
|
||||
|
||||
def forward(self, x, n_q: tp.Optional[int] = None):
|
||||
quantized_out = 0.0
|
||||
residual = x
|
||||
|
||||
all_losses = []
|
||||
all_indices = []
|
||||
|
||||
n_q = n_q or len(self.layers)
|
||||
|
||||
for i, layer in enumerate(self.layers[:n_q]):
|
||||
quantized, indices, loss = layer(residual)
|
||||
residual = residual - quantized
|
||||
quantized_out = quantized_out + quantized
|
||||
all_indices.append(indices)
|
||||
all_losses.append(loss)
|
||||
|
||||
out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
|
||||
return quantized_out, out_indices, out_losses
|
||||
|
||||
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
|
||||
residual = x
|
||||
all_indices = []
|
||||
n_q = n_q or len(self.layers)
|
||||
for layer in self.layers[:n_q]:
|
||||
indices = layer.encode(residual)
|
||||
quantized = layer.decode(indices)
|
||||
residual = residual - quantized
|
||||
all_indices.append(indices)
|
||||
out_indices = torch.stack(all_indices)
|
||||
return out_indices
|
||||
|
||||
def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
|
||||
quantized_out = torch.tensor(0.0, device=q_indices.device)
|
||||
for i, indices in enumerate(q_indices):
|
||||
layer = self.layers[i]
|
||||
quantized = layer.decode(indices)
|
||||
quantized_out = quantized_out + quantized
|
||||
return quantized_out
|
|
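# Round-trip sketch: encode stacks one index tensor per quantizer ([K, B, T])
# and decode sums the per-layer reconstructions back up.
def _sketch_rvq_roundtrip():
    rvq = ResidualVectorQuantization(num_quantizers=4, dim=8, codebook_size=32)
    x = torch.randn(2, 8, 10)    # [B, D, T], channels_last=False by default
    codes = rvq.encode(x)        # [4, 2, 10]
    recon = rvq.decode(codes)    # [2, 8, 10]
    assert codes.shape == (4, 2, 10) and recon.shape == x.shape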
@@ -0,0 +1,116 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
import typing as tp
|
||||
|
||||
import torch
|
||||
|
||||
from .base import BaseQuantizer, QuantizedResult
|
||||
from .core_vq import ResidualVectorQuantization
|
||||
|
||||
|
||||
class ResidualVectorQuantizer(BaseQuantizer):
|
||||
"""Residual Vector Quantizer.
|
||||
|
||||
Args:
|
||||
dimension (int): Dimension of the codebooks.
|
||||
n_q (int): Number of residual vector quantizers used.
|
||||
q_dropout (bool): Random quantizer drop out at train time.
|
||||
bins (int): Codebook size.
|
||||
decay (float): Decay for exponential moving average over the codebooks.
|
||||
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
|
||||
kmeans_iters (int): Number of iterations used for kmeans initialization.
|
||||
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
||||
that have an exponential moving average cluster size less than the specified threshold with
|
||||
a randomly selected vector from the current batch.
|
||||
orthogonal_reg_weight (float): Orthogonal regularization weight.
|
||||
orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
|
||||
orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
|
||||
for orthogonal regularization.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
dimension: int = 256,
|
||||
n_q: int = 8,
|
||||
q_dropout: bool = False,
|
||||
bins: int = 1024,
|
||||
decay: float = 0.99,
|
||||
kmeans_init: bool = True,
|
||||
kmeans_iters: int = 10,
|
||||
threshold_ema_dead_code: int = 2,
|
||||
orthogonal_reg_weight: float = 0.0,
|
||||
orthogonal_reg_active_codes_only: bool = False,
|
||||
orthogonal_reg_max_codes: tp.Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.max_n_q = n_q
|
||||
self.n_q = n_q
|
||||
self.q_dropout = q_dropout
|
||||
self.dimension = dimension
|
||||
self.bins = bins
|
||||
self.decay = decay
|
||||
self.kmeans_init = kmeans_init
|
||||
self.kmeans_iters = kmeans_iters
|
||||
self.threshold_ema_dead_code = threshold_ema_dead_code
|
||||
self.orthogonal_reg_weight = orthogonal_reg_weight
|
||||
self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
|
||||
self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
|
||||
self.vq = ResidualVectorQuantization(
|
||||
dim=self.dimension,
|
||||
codebook_size=self.bins,
|
||||
num_quantizers=self.n_q,
|
||||
decay=self.decay,
|
||||
kmeans_init=self.kmeans_init,
|
||||
kmeans_iters=self.kmeans_iters,
|
||||
threshold_ema_dead_code=self.threshold_ema_dead_code,
|
||||
orthogonal_reg_weight=self.orthogonal_reg_weight,
|
||||
orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only,
|
||||
orthogonal_reg_max_codes=self.orthogonal_reg_max_codes,
|
||||
channels_last=False
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, frame_rate: int):
|
||||
n_q = self.n_q
|
||||
if self.training and self.q_dropout:
|
||||
n_q = int(torch.randint(1, self.n_q + 1, (1,)).item())
|
||||
bw_per_q = math.log2(self.bins) * frame_rate / 1000
|
||||
quantized, codes, commit_loss = self.vq(x, n_q=n_q)
|
||||
codes = codes.transpose(0, 1)
|
||||
# codes is [B, K, T], with T frames, K nb of codebooks.
|
||||
bw = torch.tensor(n_q * bw_per_q).to(x)
|
||||
return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
|
||||
|
||||
def encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Encode a given input tensor with the specified frame rate at the given bandwidth.
|
||||
The RVQ encode method sets the appropriate number of quantizer to use
|
||||
and returns indices for each quantizer.
|
||||
"""
|
||||
n_q = self.n_q
|
||||
codes = self.vq.encode(x, n_q=n_q)
|
||||
codes = codes.transpose(0, 1)
|
||||
# codes is [B, K, T], with T frames, K nb of codebooks.
|
||||
return codes
|
||||
|
||||
def decode(self, codes: torch.Tensor) -> torch.Tensor:
|
||||
"""Decode the given codes to the quantized representation.
|
||||
"""
|
||||
# codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
|
||||
codes = codes.transpose(0, 1)
|
||||
quantized = self.vq.decode(codes)
|
||||
return quantized
|
||||
|
||||
@property
|
||||
def total_codebooks(self):
|
||||
return self.max_n_q
|
||||
|
||||
@property
|
||||
def num_codebooks(self):
|
||||
return self.n_q
|
||||
|
||||
def set_num_codebooks(self, n: int):
|
||||
assert n > 0 and n <= self.max_n_q
|
||||
self.n_q = n
|
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch


class TorchAutocast:
    """TorchAutocast utility class.
    Allows you to enable and disable autocast. This is especially useful
    when dealing with different architectures and clusters with different
    levels of support.

    Args:
        enabled (bool): Whether to enable torch.autocast or not.
        args: Additional args for torch.autocast.
        kwargs: Additional kwargs for torch.autocast.
    """
    def __init__(self, enabled: bool, *args, **kwargs):
        self.autocast = torch.autocast(*args, **kwargs) if enabled else None

    def __enter__(self):
        if self.autocast is None:
            return
        try:
            self.autocast.__enter__()
        except RuntimeError:
            device = self.autocast.device
            dtype = self.autocast.fast_dtype
            raise RuntimeError(
                f"There was an error autocasting with dtype={dtype} device={device}\n"
                "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
            )

    def __exit__(self, *args, **kwargs):
        if self.autocast is None:
            return
        self.autocast.__exit__(*args, **kwargs)
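
A minimal usage sketch (editor's illustration; model and x are hypothetical). The wrapper only constructs torch.autocast when enabled, so the same code path runs unchanged on machines without mixed-precision support:

    import torch
    with TorchAutocast(enabled=torch.cuda.is_available(),
                       device_type='cuda', dtype=torch.float16):
        y = model(x)    # runs under autocast only when enabled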
@@ -0,0 +1,56 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Utility to export a training checkpoint to a lightweight release checkpoint.
"""

from pathlib import Path
import typing as tp

from omegaconf import OmegaConf, DictConfig
import torch


def _clean_lm_cfg(cfg: DictConfig):
    OmegaConf.set_struct(cfg, False)
    # This used to be set automatically in the LM solver, need a more robust solution
    # for the future.
    cfg['transformer_lm']['card'] = 2048
    cfg['transformer_lm']['n_q'] = 4
    # Experimental params no longer supported.
    bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters',
                  'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop']
    for name in bad_params:
        del cfg['transformer_lm'][name]
    OmegaConf.set_struct(cfg, True)
    return cfg


def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
    sig = Path(checkpoint_path).parent.name
    assert len(sig) == 8, "Not a valid Dora signature"
    pkg = torch.load(checkpoint_path, 'cpu')
    new_pkg = {
        'best_state': pkg['ema']['state']['model'],
        'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
    }
    out_file = Path(out_folder) / f'{sig}.th'
    torch.save(new_pkg, out_file)
    return out_file


def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
    sig = Path(checkpoint_path).parent.name
    assert len(sig) == 8, "Not a valid Dora signature"
    pkg = torch.load(checkpoint_path, 'cpu')
    new_pkg = {
        'best_state': pkg['fsdp_best_state']['model'],
        'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg']))
    }
    out_file = Path(out_folder) / f'{sig}.th'
    torch.save(new_pkg, out_file)
    return out_file
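
A usage sketch (editor's illustration; the paths are hypothetical). Both exporters expect the checkpoint to sit in a directory named with its 8-character Dora signature, and write <sig>.th containing only the model weights and the config as YAML:

    out = export_lm('/checkpoints/abcd1234/checkpoint.th', '/release')
    # -> Path('/release/abcd1234.th') with keys 'best_state' and 'xp.cfg'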
@@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

try:
    import IPython.display as ipd  # type: ignore
except ImportError:
    # Not in a notebook...
    pass


import torch


def display_audio(samples: torch.Tensor, sample_rate: int):
    """Renders an audio player for the given audio samples.

    Args:
        samples (torch.Tensor): a Tensor of decoded audio samples
            with shapes [B, C, T] or [C, T].
        sample_rate (int): sample rate audio should be displayed with.
    """
    assert samples.dim() == 2 or samples.dim() == 3

    samples = samples.detach().cpu()
    if samples.dim() == 2:
        samples = samples[None, ...]

    for audio in samples:
        ipd.display(ipd.Audio(audio, rate=sample_rate))
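
A usage sketch (editor's illustration; it assumes a loaded MusicGen model as elsewhere in this repo, so model and its sample_rate attribute are assumptions): generation returns a [B, C, T] tensor, and this helper renders one audio player per batch item:

    wav = model.generate(['lo-fi hip hop beat'])      # [B, C, T]
    display_audio(wav, sample_rate=model.sample_rate)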
@@ -0,0 +1,234 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from concurrent.futures import ProcessPoolExecutor
from functools import wraps
import hashlib
import logging
import typing as tp

import flashy
import flashy.distrib
import omegaconf
import torch
from torch.nn.utils.rnn import pad_sequence


logger = logging.getLogger(__name__)


def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
    """Convenience function to map an omegaconf configuration to a dictionary.

    Args:
        cfg (omegaconf.DictConfig): Original configuration to map to dict.
    Returns:
        dict: Config as dictionary object.
    """
    dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
    assert isinstance(dct, dict)
    return dct


def random_subset(dataset, max_samples: int, seed: int = 42) -> torch.utils.data.Subset:
    if max_samples >= len(dataset):
        return dataset

    generator = torch.Generator().manual_seed(seed)
    perm = torch.randperm(len(dataset), generator=generator)
    return torch.utils.data.Subset(dataset, perm[:max_samples].tolist())


def get_loader(dataset, num_samples: tp.Optional[int], batch_size: int,
               num_workers: int, seed: int, **kwargs) -> torch.utils.data.DataLoader:
    """Convenience function to load dataset into a dataloader with optional subset sampling.

    Args:
        dataset: Dataset to load.
        num_samples (Optional[int]): Number of samples to limit subset size.
        batch_size (int): Batch size.
        num_workers (int): Number of workers for data loading.
        seed (int): Random seed.
    """
    if num_samples is not None:
        dataset = random_subset(dataset, num_samples, seed)

    dataloader = flashy.distrib.loader(
        dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        **kwargs
    )
    return dataloader


def get_dataset_from_loader(dataloader):
    dataset = dataloader.dataset
    if isinstance(dataset, torch.utils.data.Subset):
        return dataset.dataset
    else:
        return dataset


def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
    """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.

    Args:
        input (torch.Tensor): The input tensor containing probabilities.
        num_samples (int): Number of samples to draw.
        replacement (bool): Whether to draw with replacement or not.
    Keyword args:
        generator (torch.Generator): A pseudorandom number generator for sampling.
    Returns:
        torch.Tensor: Last dimension contains num_samples indices
            sampled from the multinomial probability distribution
            located in the last dimension of tensor input.
    """
    input_ = input.reshape(-1, input.shape[-1])
    output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
    output = output_.reshape(*list(input.shape[:-1]), -1)
    return output
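
# Editor's note, a minimal sketch of the shape contract (values hypothetical):
#     probs = torch.softmax(torch.randn(2, 3, 10), dim=-1)  # [B=2, K=3, V=10]
#     idx = multinomial(probs, num_samples=1)               # -> shape [2, 3, 1]
# Plain torch.multinomial only accepts 1D or 2D input; the reshape above is
# what lifts it to arbitrary leading dimensions.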


def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
    """Sample next token from top K values along the last dimension of the input probs tensor.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        k (int): The k in “top-k”.
    Returns:
        torch.Tensor: Sampled tokens.
    """
    top_k_value, _ = torch.topk(probs, k, dim=-1)
    min_value_top_k = top_k_value[..., [-1]]
    probs *= (probs >= min_value_top_k).float()
    probs.div_(probs.sum(dim=-1, keepdim=True))
    next_token = multinomial(probs, num_samples=1)
    return next_token
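
# Editor's sketch (assumed shapes): for probs of shape [B, V] summing to 1 on
# the last dimension,
#     tok = sample_top_k(probs, k=250)   # -> [B, 1] token indices
# Note that probs is modified in place: entries below the k-th largest value
# are zeroed, then the remainder is renormalized before sampling.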


def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Sample next token from top P probabilities along the last dimension of the input probs tensor.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        p (float): The p in “top-p”.
    Returns:
        torch.Tensor: Sampled tokens.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort *= (~mask).float()
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
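
# Editor's sketch (assumed shapes): nucleus sampling keeps the smallest prefix
# of sorted probabilities whose cumulative mass reaches p,
#     tok = sample_top_p(probs, p=0.9)   # probs: [B, V] -> tok: [B, 1]
# The final gather maps indices in sorted order back to vocabulary order.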


class DummyPoolExecutor:
    """Dummy pool executor to use when we actually have only 1 worker.
    (e.g. instead of ProcessPoolExecutor).
    """
    class DummyResult:
        def __init__(self, func, *args, **kwargs):
            self.func = func
            self.args = args
            self.kwargs = kwargs

        def result(self):
            return self.func(*self.args, **self.kwargs)

    def __init__(self, workers, mp_context=None):
        pass

    def submit(self, func, *args, **kwargs):
        return DummyPoolExecutor.DummyResult(func, *args, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        return


def get_pool_executor(num_workers: int, mp_context=None):
    return ProcessPoolExecutor(num_workers, mp_context) if num_workers > 1 else DummyPoolExecutor(1)
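
# Editor's sketch: call sites stay identical whether or not real worker
# processes are used (num_workers and fn are hypothetical here):
#     with get_pool_executor(num_workers) as pool:
#         future = pool.submit(fn, arg)
#         out = future.result()
# With num_workers <= 1, DummyPoolExecutor defers fn until .result() and runs
# it synchronously in-process, avoiding process-spawn overhead.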


def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> torch.Tensor:
    """Utility function to convert a tensor of sequence lengths to a mask (useful when working on padded sequences).
    For example: [3, 5] => [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]

    Args:
        lengths (torch.Tensor): tensor with lengths.
        max_len (int): can set the max length manually. Defaults to None.
    Returns:
        torch.Tensor: mask with 0s where there are pad tokens else 1s.
    """
    assert len(lengths.shape) == 1, "Length shape should be 1 dimensional."
    final_length = lengths.max().item() if not max_len else max_len
    final_length = max(final_length, 1)  # if all seqs are of len zero we don't want a zero-size tensor
    return torch.arange(final_length)[None, :].to(lengths.device) < lengths[:, None]
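
# Editor's sketch, matching the docstring example:
#     length_to_mask(torch.tensor([3, 5]))
#     # -> tensor([[ True,  True,  True, False, False],
#     #            [ True,  True,  True,  True,  True]])
# (The mask is boolean; cast with .int() if 0/1 values are needed.)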


def hash_trick(word: str, vocab_size: int) -> int:
    """Hash trick to pair each word with an index.

    Args:
        word (str): word we wish to convert to an index.
        vocab_size (int): size of the vocabulary.
    Returns:
        int: index of the word in the embedding LUT.
    """
    hash = int(hashlib.sha256(word.encode("utf-8")).hexdigest(), 16)
    return hash % vocab_size
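
# Editor's sketch: the mapping is deterministic across processes and runs,
# unlike Python's built-in hash() under hash randomization:
#     hash_trick("guitar", vocab_size=1000)   # always the same int in [0, 1000)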


def with_rank_rng(base_seed: int = 1234):
    """Decorator for a function so that the function will use a Random Number Generator
    whose state depends on the GPU rank. The original RNG state is restored upon returning.

    Args:
        base_seed (int): Random seed.
    """
    def _decorator(fun: tp.Callable):
        @wraps(fun)
        def _decorated(*args, **kwargs):
            state = torch.get_rng_state()
            seed = base_seed ^ flashy.distrib.rank()
            torch.manual_seed(seed)
            logger.debug('Rank dependent seed set to %d', seed)
            try:
                return fun(*args, **kwargs)
            finally:
                torch.set_rng_state(state)
                logger.debug('RNG state restored.')
        return _decorated
    return _decorator
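
# Editor's sketch (hypothetical function): each distributed rank seeds torch
# with base_seed ^ rank for the duration of the call, then restores the RNG:
#     @with_rank_rng()
#     def sample_noise(shape):
#         return torch.randn(shape)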


def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tensor, torch.Tensor]:
    """Get a list of tensors and collate them to a single tensor, according to the following logic:
    - `dim` specifies the time dimension which will be stacked and padded.
    - The output will contain 1 new dimension (dimension index 0) which will be the size
      of the original list.

    Args:
        tensors (tp.List[torch.Tensor]): List of tensors to collate.
        dim (int): Dimension which will be stacked and padded.
    Returns:
        tp.Tuple[torch.Tensor, torch.Tensor]:
            torch.Tensor: Stacked and padded tensor. The output will contain 1 new dimension
                (dimension index 0) which will be the size of the original list.
            torch.Tensor: Tensor containing length of original tensor sizes (without padding).
    """
    tensors = [x.transpose(0, dim) for x in tensors]
    lens = torch.LongTensor([len(x) for x in tensors])
    padded_tensors = pad_sequence(tensors)
    padded_tensors = padded_tensors.transpose(0, 1)
    padded_tensors = padded_tensors.transpose(1, dim + 1)
    return padded_tensors, lens
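
A worked example (editor's illustration): two sequences of lengths 4 and 6 with 2 channels each are padded to a common length and stacked along a new batch dimension:

    import torch
    a, b = torch.randn(4, 2), torch.randn(6, 2)
    padded, lens = collate([a, b], dim=0)
    # padded: [2, 6, 2] (batch, padded time, channels); lens: tensor([4, 6])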
@@ -0,0 +1,20 @@
# please make sure you already have a CUDA-enabled pytorch install!
av
einops
flashy>=0.0.1
hydra-core>=1.1
hydra_colorlog
julius
num2words
numpy
sentencepiece
spacy==3.5.2
torch>=2.0.0
torchaudio>=2.0.0
huggingface_hub
tqdm
transformers
xformers
demucs
librosa
gradio