diff --git a/.gitea/workflows/build.yaml b/.gitea/workflows/build.yaml
new file mode 100644
index 0000000..ba0d002
--- /dev/null
+++ b/.gitea/workflows/build.yaml
@@ -0,0 +1,47 @@
+name: Build
+run-name: ${{ github.actor }} is upgrading the release 🚀
+on: [push]
+env:
+ REPOSITORY: ${{ github.repository }}
+ COMMIT_ID: ${{ github.sha }}
+jobs:
+ Build-Deploy-Actions:
+ runs-on: ubuntu-latest
+ steps:
+ - run: echo "🎉 The job was automatically triggered by a ${{ github.event_name }} event."
+ - run: echo "🐧 This job is now running on a ${{ runner.os }} server hosted by Gitea!"
+ - run: echo "🔎 The name of your branch is ${{ github.ref }} and your repository is ${{ github.repository }}."
+ - name: Check out repository code
+ uses: actions/checkout@v3
+ -
+ name: Setup Git LFS
+ run: |
+ git lfs install
+ git lfs fetch
+ git lfs checkout
+ - name: List files in the repository
+ run: |
+ ls ${{ github.workspace }}
+ -
+ name: Docker Image Info
+ id: image-info
+ run: |
+ echo "::set-output name=image_name::$(echo $REPOSITORY | tr '[:upper:]' '[:lower:]')"
+ echo "::set-output name=image_tag::${COMMIT_ID:0:10}"
+ -
+ name: Login to Docker Hub
+ uses: docker/login-action@v2
+ with:
+ registry: artifacts.iflytek.com
+ username: ${{ secrets.DOCKERHUB_USERNAME }}
+ password: ${{ secrets.DOCKERHUB_TOKEN }}
+ - name: Set up Docker Buildx
+ uses: docker/setup-buildx-action@v2
+ -
+ name: Build and push
+ run: |
+ docker version
+ docker buildx build -t artifacts.iflytek.com/docker-private/atp/${{ steps.image-info.outputs.image_name }}:${{ steps.image-info.outputs.image_tag }} . --file ${{ github.workspace }}/Dockerfile --load
+ docker push artifacts.iflytek.com/docker-private/atp/${{ steps.image-info.outputs.image_name }}:${{ steps.image-info.outputs.image_tag }}
+ docker rmi artifacts.iflytek.com/docker-private/atp/${{ steps.image-info.outputs.image_name }}:${{ steps.image-info.outputs.image_tag }}
+ - run: echo "🍏 This job's status is ${{ job.status }}."
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..5943d24
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,12 @@
+FROM python:3.8.13
+
+WORKDIR /app
+
+COPY . /app
+
+RUN pip config set global.index-url https://pypi.mirrors.ustc.edu.cn/simple
+RUN apt-get update && apt-get -y upgrade && \
+    apt-get -y install ffmpeg
+RUN pip install -r requirements.txt
+
+CMD ["python", "app.py"]
diff --git a/app.py b/app.py
new file mode 100644
index 0000000..8ee7f22
--- /dev/null
+++ b/app.py
@@ -0,0 +1,304 @@
+import argparse
+from concurrent.futures import ProcessPoolExecutor
+import os
+import subprocess as sp
+from tempfile import NamedTemporaryFile
+import time
+import warnings
+
+import torch
+import gradio as gr
+
+from audiocraft.data.audio_utils import convert_audio
+from audiocraft.data.audio import audio_write
+from audiocraft.models import MusicGen
+from gradio.themes.utils import sizes
+
+
+theme = gr.themes.Default(radius_size=sizes.radius_none).set(
+ block_label_text_color = '#4D63FF',
+ block_title_text_color = '#4D63FF',
+ button_primary_text_color = '#4D63FF',
+ button_primary_background_fill='#FFFFFF',
+ button_primary_border_color='#4D63FF',
+ button_primary_background_fill_hover='#EDEFFF',
+)
+
+MODEL = None # Last used model
+IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '')
+MAX_BATCH_SIZE = 12
+BATCHED_DURATION = 15
+INTERRUPTING = False
+# We have to wrap the subprocess call to clean up the logs a bit when using gr.make_waveform.
+_old_call = sp.call
+
+
+def _call_nostderr(*args, **kwargs):
+ # Avoid ffmpeg vomiting on the logs.
+ kwargs['stderr'] = sp.DEVNULL
+ kwargs['stdout'] = sp.DEVNULL
+ _old_call(*args, **kwargs)
+
+
+sp.call = _call_nostderr
+# Preallocating the pool of processes.
+pool = ProcessPoolExecutor(4)
+pool.__enter__()
+
+
+def interrupt():
+ global INTERRUPTING
+ INTERRUPTING = True
+
+
+def make_waveform(*args, **kwargs):
+ # Further remove some warnings.
+ be = time.time()
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore')
+ out = gr.make_waveform(*args, **kwargs)
+ print("Make a video took", time.time() - be)
+ return out
+
+
+def load_model(version='melody'):
+ global MODEL
+ print("Loading model", version)
+ if MODEL is None or MODEL.name != version:
+ MODEL = MusicGen.get_pretrained(version)
+
+
+def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
+ MODEL.set_generation_params(duration=duration, **gen_kwargs)
+ print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
+ be = time.time()
+ processed_melodies = []
+ target_sr = 32000
+ target_ac = 1
+ for melody in melodies:
+ if melody is None:
+ processed_melodies.append(None)
+ else:
+ sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
+ if melody.dim() == 1:
+ melody = melody[None]
+ melody = melody[..., :int(sr * duration)]
+ melody = convert_audio(melody, sr, target_sr, target_ac)
+ processed_melodies.append(melody)
+
+ if any(m is not None for m in processed_melodies):
+ outputs = MODEL.generate_with_chroma(
+ descriptions=texts,
+ melody_wavs=processed_melodies,
+ melody_sample_rate=target_sr,
+ progress=progress,
+ )
+ else:
+ outputs = MODEL.generate(texts, progress=progress)
+
+ outputs = outputs.detach().cpu().float()
+ out_files = []
+ for output in outputs:
+ with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
+ audio_write(
+ file.name, output, MODEL.sample_rate, strategy="loudness",
+ loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
+ out_files.append(pool.submit(make_waveform, file.name))
+ res = [out_file.result() for out_file in out_files]
+ print("batch finished", len(texts), time.time() - be)
+ return res
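+
+# Illustrative sketch (not part of the upstream app) of the bare audiocraft calls that
+# _do_predictions wraps, left as comments so importing this module stays side-effect free;
+# 'out/demo' is a hypothetical output stem:
+#
+#     model = MusicGen.get_pretrained('small')
+#     model.set_generation_params(duration=10, top_k=250, temperature=1.0, cfg_coef=3.0)
+#     wavs = model.generate(['90s rock song with electric guitar'], progress=True)
+#     audio_write('out/demo', wavs[0].cpu(), model.sample_rate, strategy='loudness')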
+
+
+def predict_batched(texts, melodies):
+ max_text_length = 512
+ texts = [text[:max_text_length] for text in texts]
+ load_model('melody')
+ res = _do_predictions(texts, melodies, BATCHED_DURATION)
+ return [res]
+
+
+def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()):
+ global INTERRUPTING
+ INTERRUPTING = False
+ if temperature < 0:
+ raise gr.Error("Temperature must be >= 0.")
+ if topk < 0:
+ raise gr.Error("Topk must be non-negative.")
+ if topp < 0:
+ raise gr.Error("Topp must be non-negative.")
+
+ topk = int(topk)
+ load_model(model)
+
+ def _progress(generated, to_generate):
+ progress((generated, to_generate))
+ if INTERRUPTING:
+ raise gr.Error("Interrupted.")
+ MODEL.set_custom_progress_callback(_progress)
+
+ outs = _do_predictions(
+ [text], [melody], duration, progress=True,
+ top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
+ return outs[0]
+
+
+def ui_full(launch_kwargs):
+ with gr.Blocks(theme=theme, css="footer {visibility: hidden}") as interface:
+ gr.Markdown(
+ """
+ Music Generation
+ """
+ )
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ text = gr.Text(label="θΎε
₯ζζ¬", interactive=True)
+ melody = gr.Audio(source="upload", type="numpy", label="ζεΎ(ε―ι)", interactive=True)
+ with gr.Row():
+ submit = gr.Button("Submit")
+ # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
+ _ = gr.Button("δΈζ").click(fn=interrupt, queue=False)
+ with gr.Row():
+ model = gr.Radio(["melody", "medium", "small", "large"], label="Model", value="melody", interactive=True)
+ with gr.Row():
+ duration = gr.Slider(minimum=1, maximum=120, value=10, label="Duration", interactive=True)
+ with gr.Row():
+ topk = gr.Number(label="Top-k", value=250, interactive=True)
+ topp = gr.Number(label="Top-p", value=0, interactive=True)
+ temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
+ cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
+ with gr.Column():
+ output = gr.Video(label="ηζηι³δΉ")
+ submit.click(predict_full, inputs=[model, text, melody, duration, topk, topp, temperature, cfg_coef], outputs=[output])
+ gr.Examples(
+ fn=predict_full,
+ examples=[
+ [
+ "An 80s driving pop song with heavy drums and synth pads in the background",
+ "./assets/bach.mp3",
+ "melody"
+ ],
+ [
+ "A cheerful country song with acoustic guitars",
+ "./assets/bolero_ravel.mp3",
+ "melody"
+ ],
+ [
+ "90s rock song with electric guitar and heavy drums",
+ None,
+ "medium"
+ ],
+ [
+ "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
+ "./assets/bach.mp3",
+ "melody"
+ ],
+ [
+ "lofi slow bpm electro chill with organic samples",
+ None,
+ "medium",
+ ],
+ ],
+ inputs=[text, melody, model],
+ outputs=[output],
+ label="δΎε"
+ )
+
+ interface.queue().launch(**launch_kwargs)
+
+
+def ui_batched(launch_kwargs):
+ with gr.Blocks(theme=theme, css="footer {visibility: hidden}") as demo:
+ gr.Markdown(
+ """
+ Music Generation
+ """
+ )
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ text = gr.Text(label="Describe your music", lines=2, interactive=True)
+ melody = gr.Audio(source="upload", type="numpy", label="Condition on a melody (optional)", interactive=True)
+ with gr.Row():
+ submit = gr.Button("Generate")
+ with gr.Column():
+ output = gr.Video(label="Generated Music")
+ submit.click(predict_batched, inputs=[text, melody], outputs=[output], batch=True, max_batch_size=MAX_BATCH_SIZE)
+ gr.Examples(
+ fn=predict_batched,
+ examples=[
+ [
+ "An 80s driving pop song with heavy drums and synth pads in the background",
+ "./assets/bach.mp3",
+ ],
+ [
+ "A cheerful country song with acoustic guitars",
+ "./assets/bolero_ravel.mp3",
+ ],
+ [
+ "90s rock song with electric guitar and heavy drums",
+ None,
+ ],
+ [
+ "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
+ "./assets/bach.mp3",
+ ],
+ [
+ "lofi slow bpm electro chill with organic samples",
+ None,
+ ],
+ ],
+ inputs=[text, melody],
+ outputs=[output],
+ label="δΎε"
+ )
+
+ demo.queue(max_size=8 * 4).launch(**launch_kwargs)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ '--listen',
+ type=str,
+ default='0.0.0.0',
+ help='IP to listen on for connections to Gradio',
+ )
+ parser.add_argument(
+ '--username', type=str, default='', help='Username for authentication'
+ )
+ parser.add_argument(
+ '--password', type=str, default='', help='Password for authentication'
+ )
+ parser.add_argument(
+ '--server_port',
+ type=int,
+ default=0,
+ help='Port to run the server listener on',
+ )
+ parser.add_argument(
+ '--inbrowser', action='store_true', help='Open in browser'
+ )
+ parser.add_argument(
+ '--share', action='store_true', help='Share the gradio UI'
+ )
+
+ args = parser.parse_args()
+
+ launch_kwargs = {}
+ launch_kwargs['server_name'] = args.listen
+
+ if args.username and args.password:
+ launch_kwargs['auth'] = (args.username, args.password)
+ if args.server_port:
+ launch_kwargs['server_port'] = args.server_port
+ if args.inbrowser:
+ launch_kwargs['inbrowser'] = args.inbrowser
+ if args.share:
+ launch_kwargs['share'] = args.share
+
+ # Show the interface
+ if IS_BATCHED:
+ ui_batched(launch_kwargs)
+ else:
+ ui_full(launch_kwargs)
diff --git a/assets/bach.mp3 b/assets/bach.mp3
new file mode 100644
index 0000000..16d0da7
Binary files /dev/null and b/assets/bach.mp3 differ
diff --git a/assets/bolero_ravel.mp3 b/assets/bolero_ravel.mp3
new file mode 100644
index 0000000..cbec949
Binary files /dev/null and b/assets/bolero_ravel.mp3 differ
diff --git a/audiocraft/__init__.py b/audiocraft/__init__.py
new file mode 100644
index 0000000..6b8594f
--- /dev/null
+++ b/audiocraft/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# flake8: noqa
+from . import data, modules, models
+
+__version__ = '0.0.2a2'
diff --git a/audiocraft/data/__init__.py b/audiocraft/data/__init__.py
new file mode 100644
index 0000000..708a3dc
--- /dev/null
+++ b/audiocraft/data/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# flake8: noqa
+from . import audio, audio_dataset
diff --git a/audiocraft/data/audio.py b/audiocraft/data/audio.py
new file mode 100644
index 0000000..2048df6
--- /dev/null
+++ b/audiocraft/data/audio.py
@@ -0,0 +1,215 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Audio IO methods (info, read, write) are defined in this module.
+We rely on the av library for faster reads when possible, and otherwise fall back to torchaudio.
+"""
+
+from dataclasses import dataclass
+from pathlib import Path
+import logging
+import typing as tp
+
+import numpy as np
+import soundfile
+import torch
+from torch.nn import functional as F
+import torchaudio as ta
+
+import av
+
+from .audio_utils import f32_pcm, i16_pcm, normalize_audio
+
+
+_av_initialized = False
+
+
+def _init_av():
+ global _av_initialized
+ if _av_initialized:
+ return
+ logger = logging.getLogger('libav.mp3')
+ logger.setLevel(logging.ERROR)
+ _av_initialized = True
+
+
+@dataclass(frozen=True)
+class AudioFileInfo:
+ sample_rate: int
+ duration: float
+ channels: int
+
+
+def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
+ _init_av()
+ with av.open(str(filepath)) as af:
+ stream = af.streams.audio[0]
+ sample_rate = stream.codec_context.sample_rate
+ duration = float(stream.duration * stream.time_base)
+ channels = stream.channels
+ return AudioFileInfo(sample_rate, duration, channels)
+
+
+def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
+ info = soundfile.info(filepath)
+ return AudioFileInfo(info.samplerate, info.duration, info.channels)
+
+
+def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
+ # torchaudio no longer returns useful duration information for some formats like mp3.
+ filepath = Path(filepath)
+ if filepath.suffix in ['.flac', '.ogg']: # TODO: Validate .ogg can be safely read with av_info
+ # ffmpeg has some weird issue with flac.
+ return _soundfile_info(filepath)
+ else:
+ return _av_info(filepath)
+
+
+def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: float = -1.) -> tp.Tuple[torch.Tensor, int]:
+ """FFMPEG-based audio file reading using PyAV bindings.
+ Soundfile cannot read mp3 and av_read is more efficient than torchaudio.
+
+ Args:
+ filepath (str or Path): Path to audio file to read.
+ seek_time (float): Time at which to start reading in the file.
+ duration (float): Duration to read from the file. If set to -1, the whole file is read.
+ Returns:
+ Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate
+ """
+ _init_av()
+ with av.open(str(filepath)) as af:
+ stream = af.streams.audio[0]
+ sr = stream.codec_context.sample_rate
+ num_frames = int(sr * duration) if duration >= 0 else -1
+ frame_offset = int(sr * seek_time)
+ # we need a small negative offset otherwise we get some edge artifact
+ # from the mp3 decoder.
+ af.seek(int(max(0, (seek_time - 0.1)) / stream.time_base), stream=stream)
+ frames = []
+ length = 0
+ for frame in af.decode(streams=stream.index):
+ current_offset = int(frame.rate * frame.pts * frame.time_base)
+ strip = max(0, frame_offset - current_offset)
+ buf = torch.from_numpy(frame.to_ndarray())
+ if buf.shape[0] != stream.channels:
+ buf = buf.view(-1, stream.channels).t()
+ buf = buf[:, strip:]
+ frames.append(buf)
+ length += buf.shape[1]
+ if num_frames > 0 and length >= num_frames:
+ break
+ assert frames
+ # If the above assert fails, it is likely because we seeked past the end of file point,
+ # in which case ffmpeg returns a single frame with only zeros, and a weird timestamp.
+ # This will need proper debugging, in due time.
+ wav = torch.cat(frames, dim=1)
+ assert wav.shape[0] == stream.channels
+ if num_frames > 0:
+ wav = wav[:, :num_frames]
+ return f32_pcm(wav), sr
+
+
+def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
+ duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
+ """Read audio by picking the most appropriate backend tool based on the audio format.
+
+ Args:
+ filepath (str or Path): Path to audio file to read.
+ seek_time (float): Time at which to start reading in the file.
+ duration (float): Duration to read from the file. If set to -1, the whole file is read.
+ pad (bool): Pad output audio if not reaching expected duration.
+ Returns:
+ Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate.
+ """
+ fp = Path(filepath)
+ if fp.suffix in ['.flac', '.ogg']: # TODO: check if we can safely use av_read for .ogg
+ # There is some bug with ffmpeg and reading flac
+ info = _soundfile_info(filepath)
+ frames = -1 if duration <= 0 else int(duration * info.sample_rate)
+ frame_offset = int(seek_time * info.sample_rate)
+ wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32)
+ assert info.sample_rate == sr, f"Mismatch of sample rates {info.sample_rate} {sr}"
+ wav = torch.from_numpy(wav).t().contiguous()
+ if len(wav.shape) == 1:
+ wav = torch.unsqueeze(wav, 0)
+ elif (
+ fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats()
+ and duration <= 0 and seek_time == 0
+ ):
+ # Torchaudio is faster if we load an entire file at once.
+ wav, sr = ta.load(fp)
+ else:
+ wav, sr = _av_read(filepath, seek_time, duration)
+ if pad and duration > 0:
+ expected_frames = int(duration * sr)
+ wav = F.pad(wav, (0, expected_frames - wav.shape[-1]))
+ return wav, sr
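+
+# Illustrative usage sketch (not upstream code; 'track.mp3' is a placeholder path):
+#
+#     wav, sr = audio_read('track.mp3', seek_time=5.0, duration=2.0, pad=True)
+#     assert wav.shape[-1] == int(2.0 * sr)  # padded/truncated to exactly 2 seconds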
+
+
+def audio_write(stem_name: tp.Union[str, Path],
+ wav: torch.Tensor, sample_rate: int,
+ format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
+ strategy: str = 'peak', peak_clip_headroom_db: float = 1,
+ rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
+ loudness_compressor: bool = False,
+ log_clipping: bool = True, make_parent_dir: bool = True,
+ add_suffix: bool = True) -> Path:
+ """Convenience function for saving audio to disk. Returns the filename the audio was written to.
+
+ Args:
+ stem_name (str or Path): Filename without extension which will be added automatically.
+ format (str): Either "wav" or "mp3".
+ mp3_rate (int): kbps when using mp3s.
+ normalize (bool): if `True` (default), normalizes according to the prescribed
+ strategy (see after). If `False`, the strategy is only used in case clipping
+ would happen.
+ strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
+ i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
+ with extra headroom to avoid clipping. 'clip' just clips.
+ peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
+ rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
+ than the `peak_clip` one to avoid further clipping.
+ loudness_headroom_db (float): Target loudness for loudness normalization.
+ loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
+ log_clipping (bool): If True, basic logging on stderr when clipping still
+ occurs despite strategy (only for 'rms').
+ make_parent_dir (bool): Make parent directory if it doesn't exist.
+ Returns:
+ Path: Path of the saved audio.
+ """
+ assert wav.dtype.is_floating_point, "wav is not floating point"
+ if wav.dim() == 1:
+ wav = wav[None]
+ elif wav.dim() > 2:
+ raise ValueError("Input wav should be at most 2 dimension.")
+ assert wav.isfinite().all()
+ wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
+ rms_headroom_db, loudness_headroom_db, log_clipping=log_clipping,
+ sample_rate=sample_rate, stem_name=str(stem_name))
+ kwargs: dict = {}
+ if format == 'mp3':
+ suffix = '.mp3'
+ kwargs.update({"compression": mp3_rate})
+ elif format == 'wav':
+ wav = i16_pcm(wav)
+ suffix = '.wav'
+ kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16})
+ else:
+ raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
+ if not add_suffix:
+ suffix = ''
+ path = Path(str(stem_name) + suffix)
+ if make_parent_dir:
+ path.parent.mkdir(exist_ok=True, parents=True)
+ try:
+ ta.save(path, wav, sample_rate, **kwargs)
+ except Exception:
+ if path.exists():
+ # we do not want to leave half written files around.
+ path.unlink()
+ raise
+ return path
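+
+# Illustrative usage sketch (not upstream code): write a 1-second 440 Hz test tone,
+# loudness-normalized with the same settings app.py uses.
+#
+#     tone = torch.sin(2 * torch.pi * 440 * torch.arange(32000) / 32000)[None]
+#     out_path = audio_write('out/tone', tone, 32000, strategy='loudness',
+#                            loudness_headroom_db=16, loudness_compressor=True)
+#     # -> Path('out/tone.wav')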
diff --git a/audiocraft/data/audio_dataset.py b/audiocraft/data/audio_dataset.py
new file mode 100644
index 0000000..cf21422
--- /dev/null
+++ b/audiocraft/data/audio_dataset.py
@@ -0,0 +1,525 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import copy
+from concurrent.futures import ThreadPoolExecutor, Future
+from dataclasses import dataclass, fields
+from contextlib import ExitStack
+import gzip
+import json
+import logging
+import os
+from pathlib import Path
+import random
+import sys
+import typing as tp
+
+import torch
+import torch.nn.functional as F
+
+from .audio import audio_read, audio_info
+from .audio_utils import convert_audio
+from .zip import PathInZip
+
+try:
+ import dora
+except ImportError:
+ dora = None # type: ignore
+
+
+@dataclass(order=True)
+class BaseInfo:
+
+ @classmethod
+ def _dict2fields(cls, dictionary: dict):
+ return {
+ field.name: dictionary[field.name]
+ for field in fields(cls) if field.name in dictionary
+ }
+
+ @classmethod
+ def from_dict(cls, dictionary: dict):
+ _dictionary = cls._dict2fields(dictionary)
+ return cls(**_dictionary)
+
+ def to_dict(self):
+ return {
+ field.name: self.__getattribute__(field.name)
+ for field in fields(self)
+ }
+
+
+@dataclass(order=True)
+class AudioMeta(BaseInfo):
+ path: str
+ duration: float
+ sample_rate: int
+ amplitude: tp.Optional[float] = None
+ weight: tp.Optional[float] = None
+ # info_path is used to load additional information about the audio file that is stored in zip files.
+ info_path: tp.Optional[PathInZip] = None
+
+ @classmethod
+ def from_dict(cls, dictionary: dict):
+ base = cls._dict2fields(dictionary)
+ if 'info_path' in base and base['info_path'] is not None:
+ base['info_path'] = PathInZip(base['info_path'])
+ return cls(**base)
+
+ def to_dict(self):
+ d = super().to_dict()
+ if d['info_path'] is not None:
+ d['info_path'] = str(d['info_path'])
+ return d
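+
+# Illustrative round trip (sketch only): AudioMeta serializes to the flat dicts stored in the
+# .jsonl manifests consumed by load_audio_meta / produced by save_audio_meta below.
+#
+#     m = AudioMeta(path='/data/a.wav', duration=12.3, sample_rate=32000)
+#     assert AudioMeta.from_dict(m.to_dict()) == m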
+
+
+@dataclass(order=True)
+class SegmentInfo(BaseInfo):
+ meta: AudioMeta
+ seek_time: float
+ n_frames: int # actual number of frames without padding
+ total_frames: int # total number of frames, padding included
+ sample_rate: int # actual sample rate
+
+
+DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']
+
+logger = logging.getLogger(__name__)
+
+
+def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta:
+ """AudioMeta from a path to an audio file.
+
+ Args:
+ file_path (str): Resolved path of valid audio file.
+ minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
+ Returns:
+ AudioMeta: Audio file path and its metadata.
+ """
+ info = audio_info(file_path)
+ amplitude: tp.Optional[float] = None
+ if not minimal:
+ wav, sr = audio_read(file_path)
+ amplitude = wav.abs().max().item()
+ return AudioMeta(file_path, info.duration, info.sample_rate, amplitude)
+
+
+def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
+ """If Dora is available as a dependency, try to resolve potential relative paths
+ in list of AudioMeta. This method is expected to be used when loading meta from file.
+
+ Args:
+ m (AudioMeta): Audio meta to resolve.
+ fast (bool): If True, uses a very fast check to determine whether a file path is already absolute.
+ Only valid on Linux/Mac.
+ Returns:
+ AudioMeta: Audio meta with resolved path.
+ """
+ def is_abs(m):
+ if fast:
+ return str(m)[0] == '/'
+ else:
+ return os.path.isabs(str(m))
+
+ if not dora:
+ return m
+
+ if not is_abs(m.path):
+ m.path = dora.git_save.to_absolute_path(m.path)
+ if m.info_path is not None and not is_abs(m.info_path.zip_path):
+ m.info_path.zip_path = dora.git_save.to_absolute_path(m.info_path.zip_path)
+ return m
+
+
+def find_audio_files(path: tp.Union[Path, str],
+ exts: tp.List[str] = DEFAULT_EXTS,
+ resolve: bool = True,
+ minimal: bool = True,
+ progress: bool = False,
+ workers: int = 0) -> tp.List[AudioMeta]:
+ """Build a list of AudioMeta from a given path,
+ collecting relevant audio files and fetching meta info.
+
+ Args:
+ path (str or Path): Path to folder containing audio files.
+ exts (list of str): List of file extensions to consider for audio files.
+ minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
+ progress (bool): Whether to log progress on audio files collection.
+ workers (int): number of parallel workers, if 0, use only the current thread.
+ Returns:
+ List[AudioMeta]: List of audio file paths and their metadata.
+ """
+ audio_files = []
+ futures: tp.List[Future] = []
+ pool: tp.Optional[ThreadPoolExecutor] = None
+ with ExitStack() as stack:
+ if workers > 0:
+ pool = ThreadPoolExecutor(workers)
+ stack.enter_context(pool)
+
+ if progress:
+ print("Finding audio files...")
+ for root, folders, files in os.walk(path, followlinks=True):
+ for file in files:
+ full_path = Path(root) / file
+ if full_path.suffix.lower() in exts:
+ audio_files.append(full_path)
+ if pool is not None:
+ futures.append(pool.submit(_get_audio_meta, str(audio_files[-1]), minimal))
+ if progress:
+ print(format(len(audio_files), " 8d"), end='\r', file=sys.stderr)
+
+ if progress:
+ print("Getting audio metadata...")
+ meta: tp.List[AudioMeta] = []
+ for idx, file_path in enumerate(audio_files):
+ try:
+ if pool is None:
+ m = _get_audio_meta(str(file_path), minimal)
+ else:
+ m = futures[idx].result()
+ if resolve:
+ m = _resolve_audio_meta(m)
+ except Exception as err:
+ print("Error with", str(file_path), err, file=sys.stderr)
+ continue
+ meta.append(m)
+ if progress:
+ print(format((1 + idx) / len(audio_files), " 3.1%"), end='\r', file=sys.stderr)
+ meta.sort()
+ return meta
+
+
+def load_audio_meta(path: tp.Union[str, Path],
+ resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]:
+ """Load list of AudioMeta from an optionally compressed json file.
+
+ Args:
+ path (str or Path): Path to JSON file.
+ resolve (bool): Whether to resolve the path from AudioMeta (default=True).
+ fast (bool): activates some tricks to make things faster.
+ Returns:
+ List[AudioMeta]: List of audio file paths and their metadata.
+ """
+ open_fn = gzip.open if str(path).lower().endswith('.gz') else open
+ with open_fn(path, 'rb') as fp: # type: ignore
+ lines = fp.readlines()
+ meta = []
+ for line in lines:
+ d = json.loads(line)
+ m = AudioMeta.from_dict(d)
+ if resolve:
+ m = _resolve_audio_meta(m, fast=fast)
+ meta.append(m)
+ return meta
+
+
+def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]):
+ """Save the audio metadata to the file pointer as json.
+
+ Args:
+ path (str or Path): Path to JSON file.
+ meta (list of AudioMeta): List of audio meta to save.
+ """
+ Path(path).parent.mkdir(exist_ok=True, parents=True)
+ open_fn = gzip.open if str(path).lower().endswith('.gz') else open
+ with open_fn(path, 'wb') as fp: # type: ignore
+ for m in meta:
+ json_str = json.dumps(m.to_dict()) + '\n'
+ json_bytes = json_str.encode('utf-8')
+ fp.write(json_bytes)
+
+
+class AudioDataset:
+ """Base audio dataset.
+
+ The dataset takes a list of AudioMeta and creates a dataset composed of segments of audio
+ and potentially additional information, by creating random segments from the list of audio
+ files referenced in the metadata and applying minimal data pre-processing such as resampling,
+ mixing of channels, padding, etc.
+
+ If no segment_duration value is provided, the AudioDataset will return the full wav for each
+ audio file. Otherwise, it will randomly sample audio files and create a segment of the specified
+ duration, applying padding if required.
+
+ By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True
+ allows returning a tuple containing the torch Tensor and additional metadata on the segment and the
+ original audio meta.
+
+ Args:
+ meta (tp.List[AudioMeta]): List of audio files metadata.
+ segment_duration (float): Optional segment duration of audio to load.
+ If not specified, the dataset will load the full audio segment from the file.
+ shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
+ sample_rate (int): Target sample rate of the loaded audio samples.
+ channels (int): Target number of channels of the loaded audio samples.
+ sample_on_duration (bool): Set to `True` to sample segments with probability
+ dependent on audio file duration. This is only used if `segment_duration` is provided.
+ sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of
+ `AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product
+ of the file duration and file weight. This is only used if `segment_duration` is provided.
+ min_segment_ratio (float): Minimum segment ratio to use when the audio file
+ is shorter than the desired segment.
+ max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
+ return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
+ min_audio_duration (tp.Optional[float], optional): Minimum audio file duration, in seconds, if provided
+ audio shorter than this will be filtered out.
+ max_audio_duration (tp.Optional[float], optional): Maximal audio file duration in seconds, if provided
+ audio longer than this will be filtered out.
+ """
+ def __init__(self,
+ meta: tp.List[AudioMeta],
+ segment_duration: tp.Optional[float] = None,
+ shuffle: bool = True,
+ num_samples: int = 10_000,
+ sample_rate: int = 48_000,
+ channels: int = 2,
+ pad: bool = True,
+ sample_on_duration: bool = True,
+ sample_on_weight: bool = True,
+ min_segment_ratio: float = 0.5,
+ max_read_retry: int = 10,
+ return_info: bool = False,
+ min_audio_duration: tp.Optional[float] = None,
+ max_audio_duration: tp.Optional[float] = None
+ ):
+ assert len(meta) > 0, 'No audio meta provided to AudioDataset. Please check loading of audio meta.'
+ assert segment_duration is None or segment_duration > 0
+ assert segment_duration is None or min_segment_ratio >= 0
+ logging.debug(f'sample_on_duration: {sample_on_duration}')
+ logging.debug(f'sample_on_weight: {sample_on_weight}')
+ logging.debug(f'pad: {pad}')
+ logging.debug(f'min_segment_ratio: {min_segment_ratio}')
+
+ self.segment_duration = segment_duration
+ self.min_segment_ratio = min_segment_ratio
+ self.max_audio_duration = max_audio_duration
+ self.min_audio_duration = min_audio_duration
+ if self.min_audio_duration is not None and self.max_audio_duration is not None:
+ assert self.min_audio_duration <= self.max_audio_duration
+ self.meta: tp.List[AudioMeta] = self._filter_duration(meta)
+ assert len(self.meta) # Fail fast if all data has been filtered.
+ self.total_duration = sum(d.duration for d in self.meta)
+
+ if segment_duration is None:
+ num_samples = len(self.meta)
+ self.num_samples = num_samples
+ self.shuffle = shuffle
+ self.sample_rate = sample_rate
+ self.channels = channels
+ self.pad = pad
+ self.sample_on_weight = sample_on_weight
+ self.sample_on_duration = sample_on_duration
+ self.sampling_probabilities = self._get_sampling_probabilities()
+ self.max_read_retry = max_read_retry
+ self.return_info = return_info
+
+ def __len__(self):
+ return self.num_samples
+
+ def _get_sampling_probabilities(self, normalized: bool = True):
+ """Return the sampling probabilities for each file inside `self.meta`.
+ """
+ scores: tp.List[float] = []
+ for file_meta in self.meta:
+ score = 1.
+ if self.sample_on_weight and file_meta.weight is not None:
+ score *= file_meta.weight
+ if self.sample_on_duration:
+ score *= file_meta.duration
+ scores.append(score)
+ probabilities = torch.tensor(scores)
+ if normalized:
+ probabilities /= probabilities.sum()
+ return probabilities
+
+ def sample_file(self, rng: torch.Generator) -> AudioMeta:
+ """Sample a given file from `self.meta`. Can be overriden in subclasses.
+ This is only called if `segment_duration` is not None.
+
+ You must use the provided random number generator `rng` for reproducibility.
+ """
+ if not self.sample_on_weight and not self.sample_on_duration:
+ file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
+ else:
+ file_index = int(torch.multinomial(self.sampling_probabilities, 1, generator=rng).item())
+
+ return self.meta[file_index]
+
+ def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
+ if self.segment_duration is None:
+ file_meta = self.meta[index]
+ out, sr = audio_read(file_meta.path)
+ out = convert_audio(out, sr, self.sample_rate, self.channels)
+ n_frames = out.shape[-1]
+ segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
+ sample_rate=self.sample_rate)
+ else:
+ rng = torch.Generator()
+ if self.shuffle:
+ # We use index, plus extra randomness
+ rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
+ else:
+ # We only use index
+ rng.manual_seed(index)
+
+ for retry in range(self.max_read_retry):
+ file_meta = self.sample_file(rng)
+ # We add some variance in the file position even if the audio file is smaller than the segment,
+ # without ending up with empty segments.
+ max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
+ seek_time = torch.rand(1, generator=rng).item() * max_seek
+ try:
+ out, sr = audio_read(file_meta.path, seek_time, self.segment_duration, pad=False)
+ out = convert_audio(out, sr, self.sample_rate, self.channels)
+ n_frames = out.shape[-1]
+ target_frames = int(self.segment_duration * self.sample_rate)
+ if self.pad:
+ out = F.pad(out, (0, target_frames - n_frames))
+ segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
+ sample_rate=self.sample_rate)
+ except Exception as exc:
+ logger.warning("Error opening file %s: %r", file_meta.path, exc)
+ if retry == self.max_read_retry - 1:
+ raise
+ else:
+ break
+
+ if self.return_info:
+ # Returns the wav and additional information on the wave segment
+ return out, segment_info
+ else:
+ return out
+
+ def collater(self, samples):
+ """The collater function has to be provided to the dataloader
+ if AudioDataset has return_info=True in order to properly collate
+ the samples of a batch.
+ """
+ if self.segment_duration is None and len(samples) > 1:
+ assert self.pad, "Must allow padding when batching examples of different durations."
+
+ # In this case the audio reaching the collater is of variable length as segment_duration=None.
+ to_pad = self.segment_duration is None and self.pad
+ if to_pad:
+ max_len = max([wav.shape[-1] for wav, _ in samples])
+
+ def _pad_wav(wav):
+ return F.pad(wav, (0, max_len - wav.shape[-1]))
+
+ if self.return_info:
+ if len(samples) > 0:
+ assert len(samples[0]) == 2
+ assert isinstance(samples[0][0], torch.Tensor)
+ assert isinstance(samples[0][1], SegmentInfo)
+
+ wavs = [wav for wav, _ in samples]
+ segment_infos = [copy.deepcopy(info) for _, info in samples]
+
+ if to_pad:
+ # Each wav could be of a different duration as they are not segmented.
+ for i in range(len(samples)):
+ # Determines the total length of the signal with padding, so we update it here as we pad.
+ segment_infos[i].total_frames = max_len
+ wavs[i] = _pad_wav(wavs[i])
+
+ wav = torch.stack(wavs)
+ return wav, segment_infos
+ else:
+ assert isinstance(samples[0], torch.Tensor)
+ if to_pad:
+ samples = [_pad_wav(s) for s in samples]
+ return torch.stack(samples)
+
+ def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
+ """Filters out audio files with short durations.
+ Removes from meta the files whose durations do not allow sampling examples from them.
+ """
+ orig_len = len(meta)
+
+ # Filter data that is too short.
+ if self.min_audio_duration is not None:
+ meta = [m for m in meta if m.duration >= self.min_audio_duration]
+
+ # Filter data that is too long.
+ if self.max_audio_duration is not None:
+ meta = [m for m in meta if m.duration <= self.max_audio_duration]
+
+ filtered_len = len(meta)
+ removed_percentage = 100*(1-float(filtered_len)/orig_len)
+ msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage
+ if removed_percentage < 10:
+ logging.debug(msg)
+ else:
+ logging.warning(msg)
+ return meta
+
+ @classmethod
+ def from_meta(cls, root: tp.Union[str, Path], **kwargs):
+ """Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file.
+
+ Args:
+ root (str or Path): Path to root folder containing audio files.
+ kwargs: Additional keyword arguments for the AudioDataset.
+ """
+ root = Path(root)
+ if root.is_dir():
+ if (root / 'data.jsonl').exists():
+ root = root / 'data.jsonl'
+ elif (root / 'data.jsonl.gz').exists():
+ root = root / 'data.jsonl.gz'
+ else:
+ raise ValueError("Don't know where to read metadata from in the dir. "
+ "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
+ meta = load_audio_meta(root)
+ return cls(meta, **kwargs)
+
+ @classmethod
+ def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
+ exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
+ """Instantiate AudioDataset from a path containing (possibly nested) audio files.
+
+ Args:
+ root (str or Path): Path to root folder containing audio files.
+ minimal_meta (bool): Whether to only load minimal metadata or not.
+ exts (list of str): Extensions for audio files.
+ kwargs: Additional keyword arguments for the AudioDataset.
+ """
+ root = Path(root)
+ if root.is_file():
+ meta = load_audio_meta(root, resolve=True)
+ else:
+ meta = find_audio_files(root, exts, minimal=minimal_meta, resolve=True)
+ return cls(meta, **kwargs)
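+
+# Illustrative usage sketch (not upstream code; '/data/music' is a placeholder folder):
+#
+#     dataset = AudioDataset.from_path('/data/music', segment_duration=30,
+#                                      sample_rate=32000, channels=1, return_info=True)
+#     loader = torch.utils.data.DataLoader(dataset, batch_size=4, collate_fn=dataset.collater)
+#     wav, infos = next(iter(loader))  # wav: [4, 1, 30 * 32000], infos: list of SegmentInfo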
+
+
+def main():
+ logging.basicConfig(stream=sys.stderr, level=logging.INFO)
+ parser = argparse.ArgumentParser(
+ prog='audio_dataset',
+ description='Generate .jsonl files by scanning a folder.')
+ parser.add_argument('root', help='Root folder with all the audio files')
+ parser.add_argument('output_meta_file',
+ help='Output file to store the metadata, ')
+ parser.add_argument('--complete',
+ action='store_false', dest='minimal', default=True,
+ help='Retrieve all metadata, even those that are expensive '
+ 'to compute (e.g. normalization).')
+ parser.add_argument('--resolve',
+ action='store_true', default=False,
+ help='Resolve the paths to be absolute and with no symlinks.')
+ parser.add_argument('--workers',
+ default=10, type=int,
+ help='Number of workers.')
+ args = parser.parse_args()
+ meta = find_audio_files(args.root, DEFAULT_EXTS, progress=True,
+ resolve=args.resolve, minimal=args.minimal, workers=args.workers)
+ save_audio_meta(args.output_meta_file, meta)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/audiocraft/data/audio_utils.py b/audiocraft/data/audio_utils.py
new file mode 100644
index 0000000..76d4bc2
--- /dev/null
+++ b/audiocraft/data/audio_utils.py
@@ -0,0 +1,174 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import sys
+import typing as tp
+
+import julius
+import torch
+import torchaudio
+
+
+def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor:
+ """Convert audio to the given number of channels.
+
+ Args:
+ wav (torch.Tensor): Audio wave of shape [B, C, T].
+ channels (int): Expected number of channels as output.
+ Returns:
+ torch.Tensor: Downmixed or unchanged audio wave [B, C, T].
+ """
+ *shape, src_channels, length = wav.shape
+ if src_channels == channels:
+ pass
+ elif channels == 1:
+ # Case 1:
+ # The caller asked 1-channel audio, and the stream has multiple
+ # channels, downmix all channels.
+ wav = wav.mean(dim=-2, keepdim=True)
+ elif src_channels == 1:
+ # Case 2:
+ # The caller asked for multiple channels, but the input file has
+ # a single channel, replicate the audio over all channels.
+ wav = wav.expand(*shape, channels, length)
+ elif src_channels >= channels:
+ # Case 3:
+ # The caller asked for multiple channels, and the input file has
+ # more channels than requested. In that case return the first channels.
+ wav = wav[..., :channels, :]
+ else:
+ # Case 4: What is a reasonable choice here?
+ raise ValueError('The audio file has fewer channels than requested but is not mono.')
+ return wav
+
+
+def convert_audio(wav: torch.Tensor, from_rate: float,
+ to_rate: float, to_channels: int) -> torch.Tensor:
+ """Convert audio to new sample rate and number of audio channels.
+ """
+ wav = julius.resample_frac(wav, int(from_rate), int(to_rate))
+ wav = convert_audio_channels(wav, to_channels)
+ return wav
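+
+# Illustrative usage sketch: downmix a 3-second stereo 44.1 kHz clip to mono 32 kHz.
+#
+#     stereo = torch.randn(2, 3 * 44100)
+#     mono = convert_audio(stereo, from_rate=44100, to_rate=32000, to_channels=1)
+#     # mono.shape == (1, 3 * 32000)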
+
+
+def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
+ loudness_compressor: bool = False, energy_floor: float = 2e-3):
+ """Normalize an input signal to a user loudness in dB LKFS.
+ Audio loudness is defined according to the ITU-R BS.1770-4 recommendation.
+
+ Args:
+ wav (torch.Tensor): Input multichannel audio data.
+ sample_rate (int): Sample rate.
+ loudness_headroom_db (float): Target loudness of the output in dB LUFS.
+ loudness_compressor (bool): Uses tanh for soft clipping.
+ energy_floor (float): anything below that RMS level will not be rescaled.
+ Returns:
+ output (torch.Tensor): Loudness normalized output data.
+ """
+ energy = wav.pow(2).mean().sqrt().item()
+ if energy < energy_floor:
+ return wav
+ transform = torchaudio.transforms.Loudness(sample_rate)
+ input_loudness_db = transform(wav).item()
+ # calculate the gain needed to scale to the desired loudness level
+ delta_loudness = -loudness_headroom_db - input_loudness_db
+ gain = 10.0 ** (delta_loudness / 20.0)
+ output = gain * wav
+ if loudness_compressor:
+ output = torch.tanh(output)
+ assert output.isfinite().all(), (input_loudness_db, wav.pow(2).mean().sqrt())
+ return output
+
+
+def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None) -> None:
+ """Utility function to clip the audio with logging if specified."""
+ max_scale = wav.abs().max()
+ if log_clipping and max_scale > 1:
+ clamp_prob = (wav.abs() > 1).float().mean().item()
+ print(f"CLIPPING {stem_name or ''} happening with proba (a bit of clipping is okay):",
+ clamp_prob, "maximum scale: ", max_scale.item(), file=sys.stderr)
+ wav.clamp_(-1, 1)
+
+
+def normalize_audio(wav: torch.Tensor, normalize: bool = True,
+ strategy: str = 'peak', peak_clip_headroom_db: float = 1,
+ rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
+ loudness_compressor: bool = False, log_clipping: bool = False,
+ sample_rate: tp.Optional[int] = None,
+ stem_name: tp.Optional[str] = None) -> torch.Tensor:
+ """Normalize the audio according to the prescribed strategy (see after).
+
+ Args:
+ wav (torch.Tensor): Audio data.
+ normalize (bool): if `True` (default), normalizes according to the prescribed
+ strategy (see after). If `False`, the strategy is only used in case clipping
+ would happen.
+ strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
+ i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
+ with extra headroom to avoid clipping. 'clip' just clips.
+ peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
+ rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
+ than the `peak_clip` one to avoid further clipping.
+ loudness_headroom_db (float): Target loudness for loudness normalization.
+ loudness_compressor (bool): If True, uses tanh based soft clipping.
+ log_clipping (bool): If True, basic logging on stderr when clipping still
+ occurs despite strategy (only for 'rms').
+ sample_rate (int): Sample rate for the audio data (required for loudness).
+ stem_name (Optional[str]): Stem name for clipping logging.
+ Returns:
+ torch.Tensor: Normalized audio.
+ """
+ scale_peak = 10 ** (-peak_clip_headroom_db / 20)
+ scale_rms = 10 ** (-rms_headroom_db / 20)
+ if strategy == 'peak':
+ rescaling = (scale_peak / wav.abs().max())
+ if normalize or rescaling < 1:
+ wav = wav * rescaling
+ elif strategy == 'clip':
+ wav = wav.clamp(-scale_peak, scale_peak)
+ elif strategy == 'rms':
+ mono = wav.mean(dim=0)
+ rescaling = scale_rms / mono.pow(2).mean().sqrt()
+ if normalize or rescaling < 1:
+ wav = wav * rescaling
+ _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
+ elif strategy == 'loudness':
+ assert sample_rate is not None, "Loudness normalization requires sample rate."
+ wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor)
+ _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
+ else:
+ assert wav.abs().max() < 1
+ assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'"
+ return wav
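+
+# Illustrative usage sketch contrasting two strategies (values are arbitrary):
+#
+#     wav = 0.5 * torch.randn(1, 32000)
+#     peaked = normalize_audio(wav, strategy='peak', peak_clip_headroom_db=1)
+#     lufs = normalize_audio(wav, strategy='loudness', sample_rate=32000,
+#                            loudness_headroom_db=14)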
+
+
+def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
+ """Convert audio to float 32 bits PCM format.
+ """
+ if wav.dtype.is_floating_point:
+ return wav
+ else:
+ assert wav.dtype == torch.int16
+ return wav.float() / 2**15
+
+
+def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
+ """Convert audio to int 16 bits PCM format.
+
+ ..Warning:: There exist many formulas for doing this conversion. None are perfect
+ due to the asymmetry of the int16 range. One either has possible clipping, a DC offset,
+ or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom,
+ it is possible that `i16_pcm(f32_pcm(wav)) != wav`.
+ """
+ if wav.dtype.is_floating_point:
+ assert wav.abs().max() <= 1
+ candidate = (wav * 2 ** 15).round()
+ if candidate.max() >= 2 ** 15: # clipping would occur
+ candidate = (wav * (2 ** 15 - 1)).round()
+ return candidate.short()
+ else:
+ assert wav.dtype == torch.int16
+ return wav
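+
+# Illustrative round trip sketch: with enough headroom, f32 -> i16 -> f32 is lossless up to
+# the 1 / 2**15 quantization step (see the warning above).
+#
+#     f32 = 0.25 * torch.randn(1, 1000)
+#     i16 = i16_pcm(f32)   # torch.int16
+#     back = f32_pcm(i16)  # torch.float32, within 1 / 2**15 of f32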
diff --git a/audiocraft/data/zip.py b/audiocraft/data/zip.py
new file mode 100644
index 0000000..1f11542
--- /dev/null
+++ b/audiocraft/data/zip.py
@@ -0,0 +1,74 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing
+import zipfile
+
+from dataclasses import dataclass
+from functools import lru_cache
+from typing_extensions import Literal
+
+
+DEFAULT_SIZE = 32
+MODE = Literal['r', 'w', 'x', 'a']
+
+
+@dataclass(order=True)
+class PathInZip:
+ """Class for holding a path of file within a zip file.
+
+ Args:
+ path: The convention is :
+ Let's assume there is a zip file /some/location/foo.zip
+ and inside of it is a json file located at /data/file1.json,
+ Then we expect path = "/some/location/foo.zip:/data/file1.json"
+ """
+
+ INFO_PATH_SEP = ':'
+ zip_path: str
+ file_path: str
+
+ def __init__(self, path: str) -> None:
+ split_path = path.split(self.INFO_PATH_SEP)
+ assert len(split_path) == 2
+ self.zip_path, self.file_path = split_path
+
+ @classmethod
+ def from_paths(cls, zip_path: str, file_path: str):
+ return cls(zip_path + cls.INFO_PATH_SEP + file_path)
+
+ def __str__(self) -> str:
+ return self.zip_path + self.INFO_PATH_SEP + self.file_path
+
+
+def _open_zip(path: str, mode: MODE = 'r'):
+ return zipfile.ZipFile(path, mode)
+
+
+_cached_open_zip = lru_cache(DEFAULT_SIZE)(_open_zip)
+
+
+def set_zip_cache_size(max_size: int):
+ """Sets the maximal LRU caching for zip file opening.
+
+ Args:
+ max_size: the maximal LRU cache.
+ """
+ global _cached_open_zip
+ _cached_open_zip = lru_cache(max_size)(_open_zip)
+
+
+def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
+ """Opens a file stored inside a zip and returns a file-like object.
+
+ Args:
+ path_in_zip: A PathInZip object representing the file to return a file-like object of.
+ mode: The mode in which to open the file.
+ Returns:
+ A file-like object for PathInZip.
+ """
+ zf = _cached_open_zip(path_in_zip.zip_path)
+ return zf.open(path_in_zip.file_path)
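+
+# Illustrative usage sketch ('/some/location/foo.zip:/data/file1.json' follows the PathInZip
+# convention documented above; assumes `import json`):
+#
+#     pz = PathInZip.from_paths('/some/location/foo.zip', '/data/file1.json')
+#     with open_file_in_zip(pz) as fh:
+#         info = json.loads(fh.read())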
diff --git a/audiocraft/models/__init__.py b/audiocraft/models/__init__.py
new file mode 100644
index 0000000..92c7a48
--- /dev/null
+++ b/audiocraft/models/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# flake8: noqa
+from .musicgen import MusicGen
+from .lm import LMModel
+from .encodec import CompressionModel, EncodecModel
diff --git a/audiocraft/models/builders.py b/audiocraft/models/builders.py
new file mode 100644
index 0000000..77ee5f9
--- /dev/null
+++ b/audiocraft/models/builders.py
@@ -0,0 +1,218 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+All the functions to build the relevant models and modules
+from the Hydra config.
+"""
+
+import typing as tp
+import warnings
+
+import audiocraft
+import omegaconf
+import torch
+
+from .encodec import CompressionModel, EncodecModel, FlattenedCompressionModel # noqa
+from .lm import LMModel
+from ..modules.codebooks_patterns import (
+ CodebooksPatternProvider,
+ DelayedPatternProvider,
+ ParallelPatternProvider,
+ UnrolledPatternProvider,
+ VALLEPattern,
+ MusicLMPattern,
+)
+from ..modules.conditioners import (
+ BaseConditioner,
+ ConditioningProvider,
+ LUTConditioner,
+ T5Conditioner,
+ ConditionFuser,
+ ChromaStemConditioner,
+)
+from .. import quantization as qt
+from ..utils.utils import dict_from_config
+
+
+def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer:
+ klass = {
+ 'no_quant': qt.DummyQuantizer,
+ 'rvq': qt.ResidualVectorQuantizer
+ }[quantizer]
+ kwargs = dict_from_config(getattr(cfg, quantizer))
+ if quantizer != 'no_quant':
+ kwargs['dimension'] = dimension
+ return klass(**kwargs)
+
+
+def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
+ if encoder_name == 'seanet':
+ kwargs = dict_from_config(getattr(cfg, 'seanet'))
+ encoder_override_kwargs = kwargs.pop('encoder')
+ decoder_override_kwargs = kwargs.pop('decoder')
+ encoder_kwargs = {**kwargs, **encoder_override_kwargs}
+ decoder_kwargs = {**kwargs, **decoder_override_kwargs}
+ encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs)
+ decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs)
+ return encoder, decoder
+ else:
+ raise KeyError(f'Unexpected compression model {cfg.compression_model}')
+
+
+def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
+ """Instantiate a compression model.
+ """
+ if cfg.compression_model == 'encodec':
+ kwargs = dict_from_config(getattr(cfg, 'encodec'))
+ encoder_name = kwargs.pop('autoencoder')
+ quantizer_name = kwargs.pop('quantizer')
+ encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
+ quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
+ frame_rate = kwargs['sample_rate'] // encoder.hop_length
+ renormalize = kwargs.pop('renormalize', None)
+ renorm = kwargs.pop('renorm')
+ if renormalize is None:
+ renormalize = renorm is not None
+ warnings.warn("You are using a deprecated EnCodec model. Please migrate to new renormalization.")
+ return EncodecModel(encoder, decoder, quantizer,
+ frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
+ else:
+ raise KeyError(f'Unexpected compression model {cfg.compression_model}')
+
+
+def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
+ """Instantiate a transformer LM.
+ """
+ if cfg.lm_model == 'transformer_lm':
+ kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
+ n_q = kwargs['n_q']
+ q_modeling = kwargs.pop('q_modeling', None)
+ codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
+ attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
+ cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
+ cfg_prob, cfg_coef = cls_free_guidance["training_dropout"], cls_free_guidance["inference_coef"]
+ fuser = get_condition_fuser(cfg)
+ condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
+ if len(fuser.fuse2cond['cross']) > 0: # enforce cross-att programmatically
+ kwargs['cross_attention'] = True
+ if codebooks_pattern_cfg.modeling is None:
+ assert q_modeling is not None, \
+ 'LM model should either have a codebook pattern defined or transformer_lm.q_modeling'
+ codebooks_pattern_cfg = omegaconf.OmegaConf.create(
+ {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
+ )
+ pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
+ return LMModel(
+ pattern_provider=pattern_provider,
+ condition_provider=condition_provider,
+ fuser=fuser,
+ cfg_dropout=cfg_prob,
+ cfg_coef=cfg_coef,
+ attribute_dropout=attribute_dropout,
+ dtype=getattr(torch, cfg.dtype),
+ device=cfg.device,
+ **kwargs
+ ).to(cfg.device)
+ else:
+ raise KeyError(f'Unexpected LM model {cfg.lm_model}')
+
+
+def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
+ """Instantiate a conditioning model.
+ """
+ device = cfg.device
+ duration = cfg.dataset.segment_duration
+ cfg = getattr(cfg, "conditioners")
+ cfg = omegaconf.OmegaConf.create({}) if cfg is None else cfg
+ conditioners: tp.Dict[str, BaseConditioner] = {}
+ with omegaconf.open_dict(cfg):
+ condition_provider_args = cfg.pop('args', {})
+ for cond, cond_cfg in cfg.items():
+ model_type = cond_cfg["model"]
+ model_args = cond_cfg[model_type]
+ if model_type == "t5":
+ conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args)
+ elif model_type == "lut":
+ conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args)
+ elif model_type == "chroma_stem":
+ model_args.pop('cache_path', None)
+ conditioners[str(cond)] = ChromaStemConditioner(
+ output_dim=output_dim,
+ duration=duration,
+ device=device,
+ **model_args
+ )
+ else:
+ raise ValueError(f"unrecognized conditioning model: {model_type}")
+ conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args)
+ return conditioner
+
+
+def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
+ """Instantiate a condition fuser object.
+ """
+ fuser_cfg = getattr(cfg, "fuser")
+ fuser_methods = ["sum", "cross", "prepend", "input_interpolate"]
+ fuse2cond = {k: fuser_cfg[k] for k in fuser_methods}
+ kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
+ fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
+ return fuser
+
+
+def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
+ """Instantiate a codebooks pattern provider object.
+ """
+ pattern_providers = {
+ 'parallel': ParallelPatternProvider,
+ 'delay': DelayedPatternProvider,
+ 'unroll': UnrolledPatternProvider,
+ 'valle': VALLEPattern,
+ 'musiclm': MusicLMPattern,
+ }
+ name = cfg.modeling
+ kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
+ klass = pattern_providers[name]
+ return klass(n_q, **kwargs)
+
+
+def get_debug_compression_model(device='cpu'):
+ """Instantiate a debug compression model to be used for unit tests.
+ """
+ seanet_kwargs = {
+ 'n_filters': 4,
+ 'n_residual_layers': 1,
+ 'dimension': 32,
+ 'ratios': [10, 8, 16] # 25 Hz at 32kHz
+ }
+ encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs)
+ decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs)
+ quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4)
+ init_x = torch.randn(8, 32, 128)
+ quantizer(init_x, 1) # initialize kmeans etc.
+ compression_model = EncodecModel(
+ encoder, decoder, quantizer,
+ frame_rate=25, sample_rate=32000, channels=1).to(device)
+ return compression_model.eval()
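+
+# Illustrative usage sketch (tests only): round-trip one second of noise through the debug
+# EnCodec model built above.
+#
+#     model = get_debug_compression_model()
+#     wav = torch.randn(1, 1, 32000)      # [B, C, T], mono at 32 kHz
+#     codes, scale = model.encode(wav)
+#     recon = model.decode(codes, scale)  # decoded waveform, same shape up to frame rounding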
+
+
+def get_debug_lm_model(device='cpu'):
+ """Instantiate a debug LM to be used for unit tests.
+ """
+ pattern = DelayedPatternProvider(n_q=4)
+ dim = 16
+ providers = {
+ 'description': LUTConditioner(n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace"),
+ }
+ condition_provider = ConditioningProvider(providers)
+ fuser = ConditionFuser(
+ {'cross': ['description'], 'prepend': [],
+ 'sum': [], 'input_interpolate': []})
+ lm = LMModel(
+ pattern, condition_provider, fuser,
+ n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2,
+ cross_attention=True, causal=True)
+ return lm.to(device).eval()
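+
+# Rough usage sketch for the debug builders above (simplified; names such as `attrs` are
+# illustrative, and ConditioningAttributes comes from audiocraft.modules.conditioners):
+#
+# compression_model = get_debug_compression_model()   # 4 codebooks, 25 Hz codes at 32 kHz
+# lm = get_debug_lm_model()
+# wav = torch.randn(2, 1, 32000)
+# codes, _ = compression_model.encode(wav)             # int codes of shape [B, K, T]
+# attrs = [ConditioningAttributes(text={'description': 'test'}) for _ in range(2)]
+# out = lm.compute_predictions(codes, attrs)           # LMOutput with logits [B, K, T, 400]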
diff --git a/audiocraft/models/encodec.py b/audiocraft/models/encodec.py
new file mode 100644
index 0000000..69621a6
--- /dev/null
+++ b/audiocraft/models/encodec.py
@@ -0,0 +1,302 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from abc import ABC, abstractmethod
+import typing as tp
+
+from einops import rearrange
+import torch
+from torch import nn
+
+from .. import quantization as qt
+
+
+class CompressionModel(ABC, nn.Module):
+
+ @abstractmethod
+ def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
+ ...
+
+ @abstractmethod
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+ """See `EncodecModel.encode`"""
+ ...
+
+ @abstractmethod
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
+ """See `EncodecModel.decode`"""
+ ...
+
+ @property
+ @abstractmethod
+ def channels(self) -> int:
+ ...
+
+ @property
+ @abstractmethod
+ def frame_rate(self) -> int:
+ ...
+
+ @property
+ @abstractmethod
+ def sample_rate(self) -> int:
+ ...
+
+ @property
+ @abstractmethod
+ def cardinality(self) -> int:
+ ...
+
+ @property
+ @abstractmethod
+ def num_codebooks(self) -> int:
+ ...
+
+ @property
+ @abstractmethod
+ def total_codebooks(self) -> int:
+ ...
+
+ @abstractmethod
+ def set_num_codebooks(self, n: int):
+ """Set the active number of codebooks used by the quantizer.
+ """
+ ...
+
+
+class EncodecModel(CompressionModel):
+ """Encodec model operating on the raw waveform.
+
+ Args:
+ encoder (nn.Module): Encoder network.
+ decoder (nn.Module): Decoder network.
+ quantizer (qt.BaseQuantizer): Quantizer network.
+ frame_rate (int): Frame rate for the latent representation.
+ sample_rate (int): Audio sample rate.
+ channels (int): Number of audio channels.
+ causal (bool): Whether to use a causal version of the model.
+ renormalize (bool): Whether to renormalize the audio before running the model.
+ """
+ # we need assignment to override the property in the abstract class,
+ # I couldn't find a better way...
+ frame_rate: int = 0
+ sample_rate: int = 0
+ channels: int = 0
+
+ def __init__(self,
+ encoder: nn.Module,
+ decoder: nn.Module,
+ quantizer: qt.BaseQuantizer,
+ frame_rate: int,
+ sample_rate: int,
+ channels: int,
+ causal: bool = False,
+ renormalize: bool = False):
+ super().__init__()
+ self.encoder = encoder
+ self.decoder = decoder
+ self.quantizer = quantizer
+ self.frame_rate = frame_rate
+ self.sample_rate = sample_rate
+ self.channels = channels
+ self.renormalize = renormalize
+ self.causal = causal
+ if self.causal:
+ # we force disabling here to avoid handling linear overlap of segments
+ # as supported in original EnCodec codebase.
+ assert not self.renormalize, 'Causal model does not support renormalize'
+
+ @property
+ def total_codebooks(self):
+ """Total number of quantizer codebooks available.
+ """
+ return self.quantizer.total_codebooks
+
+ @property
+ def num_codebooks(self):
+ """Active number of codebooks used by the quantizer.
+ """
+ return self.quantizer.num_codebooks
+
+ def set_num_codebooks(self, n: int):
+ """Set the active number of codebooks used by the quantizer.
+ """
+ self.quantizer.set_num_codebooks(n)
+
+ @property
+ def cardinality(self):
+ """Cardinality of each codebook.
+ """
+ return self.quantizer.bins
+
+ def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+ scale: tp.Optional[torch.Tensor]
+ if self.renormalize:
+ mono = x.mean(dim=1, keepdim=True)
+ volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
+ scale = 1e-8 + volume
+ x = x / scale
+ scale = scale.view(-1, 1)
+ else:
+ scale = None
+ return x, scale
+
+ def postprocess(self,
+ x: torch.Tensor,
+ scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
+ if scale is not None:
+ assert self.renormalize
+ x = x * scale.view(-1, 1, 1)
+ return x
+
+ def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
+ assert x.dim() == 3
+ length = x.shape[-1]
+ x, scale = self.preprocess(x)
+
+ emb = self.encoder(x)
+ q_res = self.quantizer(emb, self.frame_rate)
+ out = self.decoder(q_res.x)
+
+ # remove extra padding added by the encoder and decoder
+ assert out.shape[-1] >= length, (out.shape[-1], length)
+ out = out[..., :length]
+
+ q_res.x = self.postprocess(out, scale)
+
+ return q_res
+
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+ """Encode the given input tensor to quantized representation along with scale parameter.
+
+ Args:
+ x (torch.Tensor): Float tensor of shape [B, C, T]
+
+ Returns:
+ codes, scale (tp.Tuple[torch.Tensor, torch.Tensor]): Tuple composed of:
+ codes, an int tensor of shape [B, K, T] with K the number of codebooks used and T the number of timesteps.
+ scale, a float tensor containing the scale for audio renormalization.
+ """
+ assert x.dim() == 3
+ x, scale = self.preprocess(x)
+ emb = self.encoder(x)
+ codes = self.quantizer.encode(emb)
+ return codes, scale
+
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
+ """Decode the given codes to a reconstructed representation, using the scale to perform
+ audio denormalization if needed.
+
+ Args:
+ codes (torch.Tensor): Int tensor of shape [B, K, T]
+ scale (tp.Optional[torch.Tensor]): Float tensor containing the scale value.
+
+ Returns:
+ out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
+ """
+ emb = self.quantizer.decode(codes)
+ out = self.decoder(emb)
+ out = self.postprocess(out, scale)
+ # out contains extra padding added by the encoder and decoder
+ return out
+
+
+class FlattenedCompressionModel(CompressionModel):
+ """Wraps a CompressionModel and flatten its codebooks, e.g.
+ instead of returning [B, K, T], return [B, S, T * (K // S)] with
+ S the number of codebooks per step, and `K // S` the number of 'virtual steps'
+ for each real time step.
+
+ Args:
+ model (CompressionModel): compression model to wrap.
+ codebooks_per_step (int): number of codebooks to keep per step,
+ this must divide the number of codebooks provided by the wrapped model.
+ extend_cardinality (bool): if True then, taking codebooks_per_step = 1 as an example,
+ if each codebook has a cardinality N, the first codebook will
+ use the range [0, N - 1], the second [N, 2 N - 1], etc.
+ On decoding, this can lead to potentially invalid sequences.
+ Any invalid entry will be silently remapped to the proper range
+ with a modulo.
+ """
+ def __init__(self, model: CompressionModel, codebooks_per_step: int = 1,
+ extend_cardinality: bool = True):
+ super().__init__()
+ self.model = model
+ self.codebooks_per_step = codebooks_per_step
+ self.extend_cardinality = extend_cardinality
+
+ @property
+ def total_codebooks(self):
+ return self.model.total_codebooks
+
+ @property
+ def num_codebooks(self):
+ """Active number of codebooks used by the quantizer.
+
+ ..Warning:: this reports the number of codebooks after the flattening
+ of the codebooks!
+ """
+ assert self.model.num_codebooks % self.codebooks_per_step == 0
+ return self.codebooks_per_step
+
+ def set_num_codebooks(self, n: int):
+ """Set the active number of codebooks used by the quantizer.
+
+ ..Warning:: this sets the number of codebooks **before** the flattening
+ of the codebooks.
+ """
+ assert n % self.codebooks_per_step == 0
+ self.model.set_num_codebooks(n)
+
+ @property
+ def num_virtual_steps(self) -> int:
+ """Return the number of virtual steps, e.g. one real step
+ will be split into that many steps.
+ """
+ return self.model.num_codebooks // self.codebooks_per_step
+
+ @property
+ def frame_rate(self) -> int:
+ return self.model.frame_rate * self.num_virtual_steps
+
+ @property
+ def sample_rate(self) -> int:
+ return self.model.sample_rate
+
+ @property
+ def channels(self) -> int:
+ return self.model.channels
+
+ @property
+ def cardinality(self):
+ """Cardinality of each codebook.
+ """
+ if self.extend_cardinality:
+ return self.model.cardinality * self.num_virtual_steps
+ else:
+ return self.model.cardinality
+
+ def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
+ raise NotImplementedError("Not supported, use encode and decode.")
+
+ def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
+ indices, scales = self.model.encode(x)
+ B, K, T = indices.shape
+ indices = rearrange(indices, 'b (k v) t -> b k t v', k=self.codebooks_per_step)
+ if self.extend_cardinality:
+ for virtual_step in range(1, self.num_virtual_steps):
+ indices[..., virtual_step] += self.model.cardinality * virtual_step
+ indices = rearrange(indices, 'b k t v -> b k (t v)')
+ return (indices, scales)
+
+ def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
+ B, K, T = codes.shape
+ assert T % self.num_virtual_steps == 0
+ codes = rearrange(codes, 'b k (t v) -> b (k v) t', v=self.num_virtual_steps)
+ # We silently ignore potential errors from the LM when
+ # using extend_cardinality.
+ codes = codes % self.model.cardinality
+ return self.model.decode(codes, scale)
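+
+# Worked example of the flattening above (illustrative numbers): wrapping a model with K=4
+# codebooks of cardinality N=1024 and codebooks_per_step=1 gives num_virtual_steps=4, so
+# encode() turns codes of shape [B, 4, T] into [B, 1, 4*T] and the reported frame rate is
+# multiplied by 4. With extend_cardinality=True the reported cardinality becomes 4*1024 and
+# virtual step v uses the id range [v*1024, (v+1)*1024 - 1]; decode() reverses the reshape
+# and applies a modulo so any out-of-range id falls back into [0, 1023].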
diff --git a/audiocraft/models/lm.py b/audiocraft/models/lm.py
new file mode 100644
index 0000000..c8aad8f
--- /dev/null
+++ b/audiocraft/models/lm.py
@@ -0,0 +1,527 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from functools import partial
+import logging
+import math
+import typing as tp
+
+import torch
+from torch import nn
+
+from ..utils import utils
+from ..modules.streaming import StreamingModule, State
+from ..modules.transformer import StreamingTransformer, create_norm_fn
+from ..modules.conditioners import (
+ ConditionFuser,
+ ClassifierFreeGuidanceDropout,
+ AttributeDropout,
+ ConditioningProvider,
+ ConditioningAttributes,
+ ConditionType,
+)
+from ..modules.codebooks_patterns import CodebooksPatternProvider
+from ..modules.activations import get_activation_fn
+
+
+logger = logging.getLogger(__name__)
+ConditionTensors = tp.Dict[str, ConditionType]
+CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]
+
+
+def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None):
+ """LM layer initialization.
+ Inspired from xlformers: https://github.com/fairinternal/xlformers
+
+ Args:
+ method (str): Method name for init function. Valid options are:
+ 'gaussian', 'uniform'.
+ input_dim (int): Input dimension of the initialized module.
+ init_depth (Optional[int]): Optional init depth value used to rescale
+ the standard deviation if defined.
+ """
+ # Compute std
+ std = 1 / math.sqrt(input_dim)
+ # Rescale with depth
+ if init_depth is not None:
+ std = std / math.sqrt(2 * init_depth)
+
+ if method == 'gaussian':
+ return partial(
+ torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std
+ )
+ elif method == 'uniform':
+ bound = math.sqrt(3) * std # ensure the standard deviation is `std`
+ return partial(torch.nn.init.uniform_, a=-bound, b=bound)
+ else:
+ raise ValueError("Unsupported layer initialization method")
+
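+# For instance, get_init_fn('gaussian', input_dim=1024, init_depth=24) returns a partial of
+# trunc_normal_ with std = 1/sqrt(1024)/sqrt(2*24), truncated at +/- 3*std; calling it on a
+# weight tensor initializes that tensor in place (this is what init_layer below does).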
+
+def init_layer(m: nn.Module,
+ method: str,
+ init_depth: tp.Optional[int] = None,
+ zero_bias_init: bool = False):
+ """Wrapper around ``get_init_fn`` for proper initialization of LM modules.
+
+ Args:
+ m (nn.Module): Module to initialize.
+ method (str): Method name for the init function.
+ init_depth (Optional[int]): Optional init depth value used to rescale
+ the standard deviation if defined.
+ zero_bias_init (bool): Whether to initialize the bias to 0 or not.
+ """
+ if isinstance(m, nn.Linear):
+ init_fn = get_init_fn(method, m.in_features, init_depth=init_depth)
+ if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
+ weight = m.weight.float()
+ init_fn(weight)
+ m.weight.data[:] = weight.half()
+ else:
+ init_fn(m.weight)
+ if zero_bias_init and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Embedding):
+ init_fn = get_init_fn(method, m.embedding_dim, init_depth=None)
+ if m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16:
+ weight = m.weight.float()
+ init_fn(weight)
+ m.weight.data[:] = weight.half()
+ else:
+ init_fn(m.weight)
+
+
+class ScaledEmbedding(nn.Embedding):
+ """Boost learning rate for embeddings (with `scale`).
+ """
+ def __init__(self, *args, lr=None, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.lr = lr
+
+ def make_optim_group(self):
+ group = {"params": list(self.parameters())}
+ if self.lr is not None:
+ group["lr"] = self.lr
+ return group
+
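+# Sketch of how the per-module `lr` can be consumed on the training side (hypothetical code,
+# not part of this file): collect the embedding groups and use the base lr everywhere else.
+#
+# emb_groups = [emb.make_optim_group() for emb in lm.emb]
+# other_params = [p for n, p in lm.named_parameters() if not n.startswith('emb.')]
+# optimizer = torch.optim.AdamW([{'params': other_params}, *emb_groups], lr=base_lr)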
+
+@dataclass
+class LMOutput:
+ # The logits are already re-aligned with the input codes
+ # hence no extra shift is required, e.g. when computing CE
+ logits: torch.Tensor # [B, K, T, card]
+ mask: torch.Tensor # [B, K, T]
+
+
+class LMModel(StreamingModule):
+ """Transformer-based language model on multiple streams of codes.
+
+ Args:
+ pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving.
+ condition_provider (ConditioningProvider): Conditioning provider from metadata.
+ fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input.
+ n_q (int): Number of parallel streams to model.
+ card (int): Cardinality, vocabulary size.
+ dim (int): Dimension of the transformer encoder.
+ num_heads (int): Number of heads for the transformer encoder.
+ hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
+ norm (str): Normalization method.
+ norm_first (bool): Use pre-norm instead of post-norm.
+ emb_lr (Optional[float]): Embedding-specific learning rate.
+ bias_proj (bool): Use bias for output projections.
+ weight_init (Optional[str]): Method for weight initialization.
+ depthwise_init (Optional[str]): Method for depthwise weight initialization.
+ zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros.
+ cfg_dropout (float): Classifier-free guidance dropout.
+ cfg_coef (float): Classifier-free guidance coefficient.
+ attribute_dropout (dict): Attribute dropout probabilities.
+ two_step_cfg (bool): Whether to run classifier-free guidance with 2 distinct steps.
+ **kwargs: Additional parameters for the transformer encoder.
+ """
+ def __init__(self, pattern_provider: CodebooksPatternProvider, condition_provider: ConditioningProvider,
+ fuser: ConditionFuser, n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8,
+ hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False,
+ emb_lr: tp.Optional[float] = None, bias_proj: bool = True,
+ weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
+ zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0,
+ attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {}, two_step_cfg: bool = False,
+ **kwargs):
+ super().__init__()
+ self.cfg_coef = cfg_coef
+ self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout)
+ self.att_dropout = AttributeDropout(p=attribute_dropout)
+ self.condition_provider = condition_provider
+ self.fuser = fuser
+ self.card = card
+ embed_dim = self.card + 1
+ self.n_q = n_q
+ self.dim = dim
+ self.pattern_provider = pattern_provider
+ self.two_step_cfg = two_step_cfg
+ self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)])
+ if 'activation' in kwargs:
+ kwargs['activation'] = get_activation_fn(kwargs['activation'])
+ self.transformer = StreamingTransformer(
+ d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim),
+ norm=norm, norm_first=norm_first, **kwargs)
+ self.out_norm: tp.Optional[nn.Module] = None
+ if norm_first:
+ self.out_norm = create_norm_fn(norm, dim)
+ self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=bias_proj) for _ in range(n_q)])
+ self._init_weights(weight_init, depthwise_init, zero_bias_init)
+ self._fsdp: tp.Optional[nn.Module]
+ self.__dict__['_fsdp'] = None
+
+ def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool):
+ """Initialization of the transformer module weights.
+
+ Args:
+ weight_init (Optional[str]): Weight initialization strategy. See ``get_init_fn`` for valid options.
+ depthwise_init (Optional[str]): Depthwise initialization strategy. The following options are valid:
+ 'current' where the depth corresponds to the current layer index or 'global' where the total number
+ of layers is used as the depth. If not set, no depthwise initialization strategy is used.
+ zero_bias_init (bool): Whether to initialize the bias to zero or not.
+ """
+ assert depthwise_init is None or depthwise_init in ['current', 'global']
+ assert depthwise_init is None or weight_init is not None, \
+ "If 'depthwise_init' is defined, a 'weight_init' method should be provided."
+ assert not zero_bias_init or weight_init is not None, \
+ "If 'zero_bias_init', a 'weight_init' method should be provided"
+
+ if weight_init is None:
+ return
+
+ for emb_layer in self.emb:
+ init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
+
+ for layer_idx, tr_layer in enumerate(self.transformer.layers):
+ depth = None
+ if depthwise_init == 'current':
+ depth = layer_idx + 1
+ elif depthwise_init == 'global':
+ depth = len(self.transformer.layers)
+ init_fn = partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init)
+ tr_layer.apply(init_fn)
+
+ for linear in self.linears:
+ init_layer(linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)
+
+ @property
+ def special_token_id(self) -> int:
+ return self.card
+
+ @property
+ def num_codebooks(self) -> int:
+ return self.n_q
+
+ def forward(self, sequence: torch.Tensor,
+ conditions: tp.List[ConditioningAttributes],
+ condition_tensors: tp.Optional[ConditionTensors] = None) -> torch.Tensor:
+ """Apply language model on sequence and conditions.
+ Given a sequence tensor of shape [B, K, S] with K the number of codebooks and
+ S the number of sequence steps, return the logits with shape [B, K, S, card].
+
+ Args:
+ sequence (torch.Tensor): indices of the codes to model, of shape [B, K, S].
+ conditions (list[ConditioningAttributes]): conditionings to use when modeling
+ the given codes. Note that when evaluating multiple times with the same conditioning
+ you should pre-compute those and pass them as `condition_tensors`.
+ condition_tensors (dict[str, ConditionType] or None): pre-computed conditioning
+ tensors, see `conditions`.
+ Returns:
+ torch.Tensor: Logits.
+ """
+ B, K, S = sequence.shape
+ assert K == self.num_codebooks, 'Sequence shape must match the specified number of codebooks'
+ input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
+ if condition_tensors is None:
+ assert not self._is_streaming, "Conditions tensors should be precomputed when streaming."
+ # apply dropout modules
+ conditions = self.cfg_dropout(conditions)
+ conditions = self.att_dropout(conditions)
+ tokenized = self.condition_provider.tokenize(conditions)
+ # encode conditions and fuse, both have a streaming cache to not recompute when generating.
+ condition_tensors = self.condition_provider(tokenized)
+ else:
+ assert not conditions, "Shouldn't pass both conditions and condition_tensors."
+
+ input_, cross_attention_input = self.fuser(input_, condition_tensors)
+
+ out = self.transformer(input_, cross_attention_src=cross_attention_input)
+ if self.out_norm:
+ out = self.out_norm(out)
+ logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1) # [B, K, S, card]
+
+ # remove the prefix from the model outputs
+ if len(self.fuser.fuse2cond['prepend']) > 0:
+ logits = logits[:, :, -S:]
+
+ return logits # [B, K, S, card]
+
+ def compute_predictions(
+ self, codes: torch.Tensor,
+ conditions: tp.List[ConditioningAttributes],
+ condition_tensors: tp.Optional[ConditionTensors] = None) -> LMOutput:
+ """Given an input tensor of codes [B, K, T] and list of conditions, runs the model
+ forward using the specified codes interleaving pattern.
+
+ Args:
+ codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
+ K the number of codebooks and T the number of timesteps.
+ conditions (list[ConditioningAttributes]): conditionings to use when modeling
+ the given codes. Note that when evaluating multiple times with the same conditioning
+ you should pre-compute those and pass them as `condition_tensors`.
+ condition_tensors (dict[str, ConditionType] or None): pre-computed conditioning
+ tensors, see `conditions`.
+ Returns:
+ LMOutput: Language model outputs
+ logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes,
+ i.e. the first item corresponds to logits to predict the first code, meaning that
+ no additional shifting of codes and logits is required.
+ mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions.
+ Given the specified interleaving strategies, parts of the logits and codes should
+ not be considered as valid predictions because of invalid context.
+ """
+ B, K, T = codes.shape
+ codes = codes.contiguous()
+ # map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
+ pattern = self.pattern_provider.get_pattern(T)
+ sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
+ codes, self.special_token_id, keep_only_valid_steps=True
+ )
+ # apply model on pattern sequence
+ model = self if self._fsdp is None else self._fsdp
+ logits = model(sequence_codes, conditions, condition_tensors) # [B, K, S, card]
+ # map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card]
+ # and provide the corresponding mask over invalid positions of tokens
+ logits = logits.permute(0, 3, 1, 2) # [B, card, K, S]
+ # note: we use nans as special token to make it obvious if we feed unexpected logits
+ logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
+ logits, float('nan'), keep_only_valid_steps=True
+ )
+ logits = logits.permute(0, 2, 3, 1) # [B, K, T, card]
+ logits_mask = logits_mask[None, :, :].expand(B, -1, -1) # [K, T] -> [B, K, T]
+ return LMOutput(logits, logits_mask)
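+ # Typical training-side use of this output (hedged sketch, not code from this file):
+ # out = self.compute_predictions(codes, conditions)
+ # ce = torch.nn.functional.cross_entropy(out.logits[out.mask], codes[out.mask])
+ # i.e. the boolean mask drops the positions whose context is invalid (their logits are NaN).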
+
+ def _sample_next_token(self,
+ sequence: torch.Tensor,
+ cfg_conditions: CFGConditions,
+ unconditional_state: State,
+ use_sampling: bool = False,
+ temp: float = 1.0,
+ top_k: int = 0,
+ top_p: float = 0.0,
+ cfg_coef: tp.Optional[float] = None) -> torch.Tensor:
+ """Sample next token from the model given a sequence and a set of conditions. The model supports
+ multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).
+
+ Args:
+ sequence (torch.Tensor): Current sequence of shape [B, K, S]
+ with K corresponding to the number of codebooks and S the number of sequence steps.
+ S = 1 in streaming mode, except for the first step that contains a bigger prompt.
+ cfg_conditions (Dict[str, ConditionType] or tuple): Set of conditions. If CFG is used,
+ should be twice the batch size, being the concatenation of the conditions + null conditions.
+ use_sampling (bool): Whether to use a sampling strategy or not.
+ temp (float): Sampling temperature.
+ top_k (int): K for "top-k" sampling.
+ top_p (float): P for "top-p" sampling.
+ cfg_coef (float): classifier free guidance coefficient
+ Returns:
+ next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
+ """
+ B = sequence.shape[0]
+ cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
+ model = self if self._fsdp is None else self._fsdp
+ if self.two_step_cfg and cfg_conditions != {}:
+ assert isinstance(cfg_conditions, tuple)
+ condition_tensors, null_condition_tensors = cfg_conditions
+ cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
+ state = self.get_streaming_state()
+ self.set_streaming_state(unconditional_state)
+ uncond_logits = model(sequence, conditions=[], condition_tensors=null_condition_tensors)
+ unconditional_state.update(self.get_streaming_state())
+ self.set_streaming_state(state)
+ logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
+ else:
+ assert isinstance(cfg_conditions, dict)
+ condition_tensors = cfg_conditions
+ if condition_tensors:
+ # Preparing for CFG, predicting both conditional and unconditional logits.
+ sequence = torch.cat([sequence, sequence], dim=0)
+ all_logits = model(
+ sequence,
+ conditions=[], condition_tensors=condition_tensors)
+ if condition_tensors:
+ cond_logits, uncond_logits = all_logits.split(B, dim=0) # [B, K, T, card]
+ logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
+ else:
+ logits = all_logits
+
+ logits = logits.permute(0, 1, 3, 2) # [B, K, card, T]
+ logits = logits[..., -1] # [B, K, card]
+
+ # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
+ if use_sampling and temp > 0.0:
+ probs = torch.softmax(logits / temp, dim=-1)
+ if top_p > 0.0:
+ next_token = utils.sample_top_p(probs, p=top_p)
+ elif top_k > 0:
+ next_token = utils.sample_top_k(probs, k=top_k)
+ else:
+ next_token = utils.multinomial(probs, num_samples=1)
+ else:
+ next_token = torch.argmax(logits, dim=-1, keepdim=True)
+
+ return next_token
+
+ @torch.no_grad()
+ def generate(self,
+ prompt: tp.Optional[torch.Tensor] = None,
+ conditions: tp.List[ConditioningAttributes] = [],
+ num_samples: tp.Optional[int] = None,
+ max_gen_len: int = 256,
+ use_sampling: bool = True,
+ temp: float = 1.0,
+ top_k: int = 250,
+ top_p: float = 0.0,
+ cfg_coef: tp.Optional[float] = None,
+ two_step_cfg: tp.Optional[bool] = None,
+ remove_prompts: bool = False,
+ check: bool = False,
+ callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> torch.Tensor:
+ """Generate tokens sampling from the model given a prompt or unconditionally. Generation can
+ be performed in a greedy fashion or using sampling with top K and top P strategies.
+
+ Args:
+ prompt (Optional[torch.Tensor]): Prompt tokens of shape [B, K, T].
+ conditions (list[ConditioningAttributes]): List of conditions, possibly empty.
+ num_samples (int or None): Number of samples to generate when no prompt and no conditions are given.
+ max_gen_len (int): Maximum generation length.
+ use_sampling (bool): Whether to use a sampling strategy or not.
+ temp (float): Sampling temperature.
+ top_k (int): K for "top-k" sampling.
+ top_p (float): P for "top-p" sampling.
+ remove_prompts (bool): Whether to remove prompts from generation or not.
+ Returns:
+ torch.Tensor: Generated tokens.
+ """
+ assert not self.training, "generation shouldn't be used in training mode."
+ first_param = next(iter(self.parameters()))
+ device = first_param.device
+
+ # Checking all input shapes are consistent.
+ possible_num_samples = []
+ if num_samples is not None:
+ possible_num_samples.append(num_samples)
+ elif prompt is not None:
+ possible_num_samples.append(prompt.shape[0])
+ elif conditions:
+ possible_num_samples.append(len(conditions))
+ else:
+ possible_num_samples.append(1)
+ assert all(x == possible_num_samples[0] for x in possible_num_samples), "Inconsistent input shapes"
+ num_samples = possible_num_samples[0]
+
+ # below we create set of conditions: one conditional and one unconditional
+ # to do that we merge the regular condition together with the null condition
+ # we then do 1 forward pass instead of 2.
+ # the reason for that is two-fold:
+ # 1. it is about x2 faster than doing 2 forward passes
+ # 2. avoid the streaming API treating the 2 passes as part of different time steps
+ # We also support doing two different passes, in particular to ensure that
+ # the padding structure is exactly the same between train and test.
+ # With a batch size of 1, this can be slower though.
+ cfg_conditions: CFGConditions
+ two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
+ if conditions:
+ null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
+ if two_step_cfg:
+ cfg_conditions = (
+ self.condition_provider(self.condition_provider.tokenize(conditions)),
+ self.condition_provider(self.condition_provider.tokenize(null_conditions)),
+ )
+ else:
+ conditions = conditions + null_conditions
+ tokenized = self.condition_provider.tokenize(conditions)
+ cfg_conditions = self.condition_provider(tokenized)
+ else:
+ cfg_conditions = {}
+
+ if prompt is None:
+ assert num_samples > 0
+ prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)
+
+ B, K, T = prompt.shape
+ start_offset = T
+ assert start_offset < max_gen_len
+
+ pattern = self.pattern_provider.get_pattern(max_gen_len)
+ # this token is used as default value for codes that are not generated yet
+ unknown_token = -1
+
+ # we generate codes up to the max_gen_len that will be mapped to the pattern sequence
+ gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
+ # filling the gen_codes with the prompt if needed
+ gen_codes[..., :start_offset] = prompt
+ # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
+ gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
+ # retrieve the start_offset in the sequence:
+ # it is the first sequence step that contains the `start_offset` timestep
+ start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
+ assert start_offset_sequence is not None
+
+ with self.streaming():
+ unconditional_state = self.get_streaming_state()
+ prev_offset = 0
+ gen_sequence_len = gen_sequence.shape[-1] # gen_sequence shape is [B, K, S]
+ for offset in range(start_offset_sequence, gen_sequence_len):
+ # get current sequence (note that the streaming API is providing the caching over previous offsets)
+ curr_sequence = gen_sequence[..., prev_offset:offset]
+ curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
+ if check:
+ # check coherence between mask and sequence
+ assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all()
+ # should never happen as gen_sequence is filled progressively
+ assert not (curr_sequence == unknown_token).any()
+ # sample next token from the model, next token shape is [B, K, 1]
+ next_token = self._sample_next_token(
+ curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
+ cfg_coef=cfg_coef)
+ # ensure the tokens that should be masked are properly set to special_token_id
+ # as the model never outputs special_token_id
+ valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
+ next_token[~valid_mask] = self.special_token_id
+ # ensure we don't overwrite prompt tokens, we only write over unknown tokens
+ # (then mask tokens should be left as is as well, which is correct)
+ gen_sequence[..., offset:offset+1] = torch.where(
+ gen_sequence[..., offset:offset+1] == unknown_token,
+ next_token, gen_sequence[..., offset:offset+1]
+ )
+ prev_offset = offset
+ if callback is not None:
+ callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
+ unconditional_state.clear()
+
+ # ensure sequence has been entirely filled
+ assert not (gen_sequence == unknown_token).any()
+ # ensure gen_sequence pattern and mask are matching
+ # which means the gen_sequence is valid according to the pattern
+ assert (
+ gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id)
+ ).all()
+ # get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps
+ out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)
+
+ # sanity checks over the returned codes and corresponding masks
+ assert (out_codes[..., :max_gen_len] != unknown_token).all()
+ assert (out_mask[..., :max_gen_len] == 1).all()
+
+ out_start_offset = start_offset if remove_prompts else 0
+ out_codes = out_codes[..., out_start_offset:max_gen_len]
+
+ # ensure the returned codes are all valid
+ assert (out_codes >= 0).all() and (out_codes <= self.card).all()
+ return out_codes
diff --git a/audiocraft/models/loaders.py b/audiocraft/models/loaders.py
new file mode 100644
index 0000000..19837d4
--- /dev/null
+++ b/audiocraft/models/loaders.py
@@ -0,0 +1,90 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Utility functions to load from the checkpoints.
+Each checkpoint is a dict saved with torch.save and holding the following keys:
+- 'xp.cfg': the hydra config as dumped during training. This should be used
+ to rebuild the object using the audiocraft.models.builders functions,
+- 'model_best_state': a readily loadable best state for the model, including
+ the conditioner. The model obtained from `xp.cfg` should be compatible
+ with this state dict. In the case of an LM, the EnCodec model is not
+ bundled along but instead provided separately.
+
+Those functions also support loading from a remote location with the Torch Hub API.
+They also support overriding some parameters, in particular the device and dtype
+of the returned model.
+"""
+
+from pathlib import Path
+from huggingface_hub import hf_hub_download
+import typing as tp
+import os
+
+from omegaconf import OmegaConf
+import torch
+
+from . import builders
+
+
+HF_MODEL_CHECKPOINTS_MAP = {
+ "small": "facebook/musicgen-small",
+ "medium": "facebook/musicgen-medium",
+ "large": "facebook/musicgen-large",
+ "melody": "facebook/musicgen-melody",
+}
+
+
+def _get_state_dict(
+ file_or_url_or_id: tp.Union[Path, str],
+ filename: tp.Optional[str] = None,
+ device='cpu',
+ cache_dir: tp.Optional[str] = None,
+):
+ # Return the state dict either from a file or url
+ file_or_url_or_id = str(file_or_url_or_id)
+ assert isinstance(file_or_url_or_id, str)
+
+ if os.path.isfile(file_or_url_or_id):
+ return torch.load(file_or_url_or_id, map_location=device)
+
+ elif file_or_url_or_id.startswith('https://'):
+ return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
+
+ elif file_or_url_or_id in HF_MODEL_CHECKPOINTS_MAP:
+ assert filename is not None, "filename needs to be defined if using HF checkpoints"
+
+ repo_id = HF_MODEL_CHECKPOINTS_MAP[file_or_url_or_id]
+ file = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir)
+ return torch.load(file, map_location=device)
+
+ else:
+ raise ValueError(f"{file_or_url_or_id} is not a valid name, path or link that can be loaded.")
+
+
+def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
+ pkg = _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir)
+ cfg = OmegaConf.create(pkg['xp.cfg'])
+ cfg.device = str(device)
+ model = builders.get_compression_model(cfg)
+ model.load_state_dict(pkg['best_state'])
+ model.eval()
+ return model
+
+
+def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
+ pkg = _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir)
+ cfg = OmegaConf.create(pkg['xp.cfg'])
+ cfg.device = str(device)
+ if cfg.device == 'cpu':
+ cfg.dtype = 'float32'
+ else:
+ cfg.dtype = 'float16'
+ model = builders.get_lm_model(cfg)
+ model.load_state_dict(pkg['best_state'])
+ model.eval()
+ model.cfg = cfg
+ return model
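+
+# Example (sketch): loading the two halves of a pretrained checkpoint by name; requires
+# access to the Hugging Face Hub. MusicGen.get_pretrained wraps exactly these two calls.
+#
+# compression_model = load_compression_model('small', device='cuda')
+# lm = load_lm_model('small', device='cuda')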
diff --git a/audiocraft/models/musicgen.py b/audiocraft/models/musicgen.py
new file mode 100644
index 0000000..2870b27
--- /dev/null
+++ b/audiocraft/models/musicgen.py
@@ -0,0 +1,361 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Main model for using MusicGen. This will combine all the required components
+and provide easy access to the generation API.
+"""
+
+import os
+import typing as tp
+
+import torch
+
+from .encodec import CompressionModel
+from .lm import LMModel
+from .builders import get_debug_compression_model, get_debug_lm_model
+from .loaders import load_compression_model, load_lm_model, HF_MODEL_CHECKPOINTS_MAP
+from ..data.audio_utils import convert_audio
+from ..modules.conditioners import ConditioningAttributes, WavCondition
+from ..utils.autocast import TorchAutocast
+
+
+MelodyList = tp.List[tp.Optional[torch.Tensor]]
+MelodyType = tp.Union[torch.Tensor, MelodyList]
+
+
+class MusicGen:
+ """MusicGen main model with convenient generation API.
+
+ Args:
+ name (str): name of the model.
+ compression_model (CompressionModel): Compression model
+ used to map audio to invertible discrete representations.
+ lm (LMModel): Language model over discrete representations.
+ """
+ def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
+ max_duration: float = 30):
+ self.name = name
+ self.compression_model = compression_model
+ self.lm = lm
+ self.max_duration = max_duration
+ self.device = next(iter(lm.parameters())).device
+ self.generation_params: dict = {}
+ self.set_generation_params(duration=15) # 15 seconds by default
+ self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
+ if self.device.type == 'cpu':
+ self.autocast = TorchAutocast(enabled=False)
+ else:
+ self.autocast = TorchAutocast(
+ enabled=True, device_type=self.device.type, dtype=torch.float16)
+
+ @property
+ def frame_rate(self) -> int:
+ """Roughly the number of AR steps per seconds."""
+ return self.compression_model.frame_rate
+
+ @property
+ def sample_rate(self) -> int:
+ """Sample rate of the generated audio."""
+ return self.compression_model.sample_rate
+
+ @property
+ def audio_channels(self) -> int:
+ """Audio channels of the generated audio."""
+ return self.compression_model.channels
+
+ @staticmethod
+ def get_pretrained(name: str = 'melody', device=None):
+ """Return pretrained model, we provide four models:
+ - small (300M), text to music, # see: https://huggingface.co/facebook/musicgen-small
+ - medium (1.5B), text to music, # see: https://huggingface.co/facebook/musicgen-medium
+ - melody (1.5B) text to music and text+melody to music, # see: https://huggingface.co/facebook/musicgen-melody
+ - large (3.3B), text to music, # see: https://huggingface.co/facebook/musicgen-large
+ """
+
+ if device is None:
+ if torch.cuda.device_count():
+ device = 'cuda'
+ else:
+ device = 'cpu'
+
+ if name == 'debug':
+ # used only for unit tests
+ compression_model = get_debug_compression_model(device)
+ lm = get_debug_lm_model(device)
+ return MusicGen(name, compression_model, lm)
+
+ if name not in HF_MODEL_CHECKPOINTS_MAP:
+ raise ValueError(
+ f"{name} is not a valid checkpoint name. "
+ f"Choose one of {', '.join(HF_MODEL_CHECKPOINTS_MAP.keys())}"
+ )
+
+ cache_dir = os.environ.get('MUSICGEN_ROOT', None)
+ compression_model = load_compression_model(name, device=device, cache_dir=cache_dir)
+ lm = load_lm_model(name, device=device, cache_dir=cache_dir)
+ if name == 'melody':
+ lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True
+
+ return MusicGen(name, compression_model, lm)
+
+ def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
+ top_p: float = 0.0, temperature: float = 1.0,
+ duration: float = 30.0, cfg_coef: float = 3.0,
+ two_step_cfg: bool = False, extend_stride: float = 18):
+ """Set the generation parameters for MusicGen.
+
+ Args:
+ use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
+ top_k (int, optional): top_k used for sampling. Defaults to 250.
+ top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
+ temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
+ duration (float, optional): Duration of the generated waveform. Defaults to 30.0.
+ cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
+ two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
+ instead of batching together the two. This has some impact on how things
+ are padded but seems to have little impact in practice.
+ extend_stride: when doing extended generation (i.e. more than 30 seconds), by how much
+ should we extend the audio each time. Larger values will mean less context is
+ preserved, and shorter values will require extra computations.
+ """
+ assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
+ self.extend_stride = extend_stride
+ self.duration = duration
+ self.generation_params = {
+ 'use_sampling': use_sampling,
+ 'temp': temperature,
+ 'top_k': top_k,
+ 'top_p': top_p,
+ 'cfg_coef': cfg_coef,
+ 'two_step_cfg': two_step_cfg,
+ }
+
+ def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
+ """Override the default progress callback."""
+ self._progress_callback = progress_callback
+
+ def generate_unconditional(self, num_samples: int, progress: bool = False) -> torch.Tensor:
+ """Generate samples in an unconditional manner.
+
+ Args:
+ num_samples (int): Number of samples to be generated.
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+ """
+ descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
+ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
+ return self._generate_tokens(attributes, prompt_tokens, progress)
+
+ def generate(self, descriptions: tp.List[str], progress: bool = False) -> torch.Tensor:
+ """Generate samples conditioned on text.
+
+ Args:
+ descriptions (tp.List[str]): A list of strings used as text conditioning.
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+ """
+ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, None)
+ assert prompt_tokens is None
+ return self._generate_tokens(attributes, prompt_tokens, progress)
+
+ def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType,
+ melody_sample_rate: int, progress: bool = False) -> torch.Tensor:
+ """Generate samples conditioned on text and melody.
+
+ Args:
+ descriptions (tp.List[str]): A list of strings used as text conditioning.
+ melody_wavs (torch.Tensor or list of Tensor): A batch of waveforms used as
+ melody conditioning. Should have shape [B, C, T] with B matching the description length,
+ C=1 or 2. It can be [C, T] if there is a single description. It can also be
+ a list of [C, T] tensors.
+ melody_sample_rate (int): Sample rate of the melody waveforms.
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+ """
+ if isinstance(melody_wavs, torch.Tensor):
+ if melody_wavs.dim() == 2:
+ melody_wavs = melody_wavs[None]
+ if melody_wavs.dim() != 3:
+ raise ValueError("Melody wavs should have a shape [B, C, T].")
+ melody_wavs = list(melody_wavs)
+ else:
+ for melody in melody_wavs:
+ if melody is not None:
+ assert melody.dim() == 2, "One melody in the list has the wrong number of dims."
+
+ melody_wavs = [
+ convert_audio(wav, melody_sample_rate, self.sample_rate, self.audio_channels)
+ if wav is not None else None
+ for wav in melody_wavs]
+ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions=descriptions, prompt=None,
+ melody_wavs=melody_wavs)
+ assert prompt_tokens is None
+ return self._generate_tokens(attributes, prompt_tokens, progress)
+
+ def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
+ descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
+ progress: bool = False) -> torch.Tensor:
+ """Generate samples conditioned on audio prompts.
+
+ Args:
+ prompt (torch.Tensor): A batch of waveforms used for continuation.
+ Prompt should be [B, C, T], or [C, T] if only one sample is generated.
+ prompt_sample_rate (int): Sampling rate of the given audio waveforms.
+ descriptions (tp.List[str], optional): A list of strings used as text conditioning. Defaults to None.
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+ """
+ if prompt.dim() == 2:
+ prompt = prompt[None]
+ if prompt.dim() != 3:
+ raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
+ prompt = convert_audio(prompt, prompt_sample_rate, self.sample_rate, self.audio_channels)
+ if descriptions is None:
+ descriptions = [None] * len(prompt)
+ attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, prompt)
+ assert prompt_tokens is not None
+ return self._generate_tokens(attributes, prompt_tokens, progress)
+
+ @torch.no_grad()
+ def _prepare_tokens_and_attributes(
+ self,
+ descriptions: tp.Sequence[tp.Optional[str]],
+ prompt: tp.Optional[torch.Tensor],
+ melody_wavs: tp.Optional[MelodyList] = None,
+ ) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
+ """Prepare model inputs.
+
+ Args:
+ descriptions (tp.List[str]): A list of strings used as text conditioning.
+ prompt (torch.Tensor): A batch of waveforms used for continuation.
+ melody_wavs (tp.Optional[torch.Tensor], optional): A batch of waveforms
+ used as melody conditioning. Defaults to None.
+ """
+ attributes = [
+ ConditioningAttributes(text={'description': description})
+ for description in descriptions]
+
+ if melody_wavs is None:
+ for attr in attributes:
+ attr.wav['self_wav'] = WavCondition(
+ torch.zeros((1, 1), device=self.device),
+ torch.tensor([0], device=self.device),
+ path='null_wav') # type: ignore
+ else:
+ if self.name != "melody":
+ raise RuntimeError("This model doesn't support melody conditioning. "
+ "Use the `melody` model.")
+ assert len(melody_wavs) == len(descriptions), \
+ f"number of melody wavs must match number of descriptions! " \
+ f"got melody len={len(melody_wavs)}, and descriptions len={len(descriptions)}"
+ for attr, melody in zip(attributes, melody_wavs):
+ if melody is None:
+ attr.wav['self_wav'] = WavCondition(
+ torch.zeros((1, 1), device=self.device),
+ torch.tensor([0], device=self.device),
+ path='null_wav') # type: ignore
+ else:
+ attr.wav['self_wav'] = WavCondition(
+ melody.to(device=self.device),
+ torch.tensor([melody.shape[-1]], device=self.device))
+
+ if prompt is not None:
+ if descriptions is not None:
+ assert len(descriptions) == len(prompt), "Number of descriptions doesn't match the number of prompts"
+ prompt = prompt.to(self.device)
+ prompt_tokens, scale = self.compression_model.encode(prompt)
+ assert scale is None
+ else:
+ prompt_tokens = None
+ return attributes, prompt_tokens
+
+ def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
+ prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
+ """Generate discrete audio tokens given audio prompt and/or conditions.
+
+ Args:
+ attributes (tp.List[ConditioningAttributes]): Conditions used for generation (text/melody).
+ prompt_tokens (tp.Optional[torch.Tensor]): Audio prompt used for continuation.
+ progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
+ Returns:
+ torch.Tensor: Generated audio, of shape [B, C, T], where T is determined by the generation params.
+ """
+ total_gen_len = int(self.duration * self.frame_rate)
+ max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
+ current_gen_offset: int = 0
+
+ def _progress_callback(generated_tokens: int, tokens_to_generate: int):
+ generated_tokens += current_gen_offset
+ if self._progress_callback is not None:
+ # Note that total_gen_len might be quite wrong depending on the
+ # codebook pattern used, but with delay it is almost accurate.
+ self._progress_callback(generated_tokens, total_gen_len)
+ else:
+ print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
+
+ if prompt_tokens is not None:
+ assert max_prompt_len >= prompt_tokens.shape[-1], \
+ "Prompt is longer than audio to generate"
+
+ callback = None
+ if progress:
+ callback = _progress_callback
+
+ if self.duration <= self.max_duration:
+ # generate by sampling from LM, simple case.
+ with self.autocast:
+ gen_tokens = self.lm.generate(
+ prompt_tokens, attributes,
+ callback=callback, max_gen_len=total_gen_len, **self.generation_params)
+
+ else:
+ # now this gets a bit messier, we need to handle prompts,
+ # melody conditioning etc.
+ ref_wavs = [attr.wav['self_wav'] for attr in attributes]
+ all_tokens = []
+ if prompt_tokens is None:
+ prompt_length = 0
+ else:
+ all_tokens.append(prompt_tokens)
+ prompt_length = prompt_tokens.shape[-1]
+
+ stride_tokens = int(self.frame_rate * self.extend_stride)
+
+ while current_gen_offset + prompt_length < total_gen_len:
+ time_offset = current_gen_offset / self.frame_rate
+ chunk_duration = min(self.duration - time_offset, self.max_duration)
+ max_gen_len = int(chunk_duration * self.frame_rate)
+ for attr, ref_wav in zip(attributes, ref_wavs):
+ wav_length = ref_wav.length.item()
+ if wav_length == 0:
+ continue
+ # We will extend the wav periodically if it is not long enough.
+ # we have to do it here rather than in conditioners.py as otherwise
+ # we wouldn't have the full wav.
+ initial_position = int(time_offset * self.sample_rate)
+ wav_target_length = int(self.max_duration * self.sample_rate)
+ print(initial_position / self.sample_rate, wav_target_length / self.sample_rate)
+ positions = torch.arange(initial_position,
+ initial_position + wav_target_length, device=self.device)
+ attr.wav['self_wav'] = WavCondition(
+ ref_wav[0][:, positions % wav_length],
+ torch.full_like(ref_wav[1], wav_target_length))
+ with self.autocast:
+ gen_tokens = self.lm.generate(
+ prompt_tokens, attributes,
+ callback=callback, max_gen_len=max_gen_len, **self.generation_params)
+ if prompt_tokens is None:
+ all_tokens.append(gen_tokens)
+ else:
+ all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
+ prompt_tokens = gen_tokens[:, :, stride_tokens:]
+ prompt_length = prompt_tokens.shape[-1]
+ current_gen_offset += stride_tokens
+
+ gen_tokens = torch.cat(all_tokens, dim=-1)
+
+ # generate audio
+ assert gen_tokens.dim() == 3
+ with torch.no_grad():
+ gen_audio = self.compression_model.decode(gen_tokens, None)
+ return gen_audio
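+
+# End-to-end usage sketch (parameter values are illustrative):
+#
+# model = MusicGen.get_pretrained('small')
+# model.set_generation_params(duration=8, top_k=250, cfg_coef=3.0)
+# wav = model.generate(['lo-fi beat with warm pads'], progress=True)  # [B, C, T] at model.sample_rate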
diff --git a/audiocraft/modules/__init__.py b/audiocraft/modules/__init__.py
new file mode 100644
index 0000000..81ba30f
--- /dev/null
+++ b/audiocraft/modules/__init__.py
@@ -0,0 +1,20 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# flake8: noqa
+from .conv import (
+ NormConv1d,
+ NormConv2d,
+ NormConvTranspose1d,
+ NormConvTranspose2d,
+ StreamableConv1d,
+ StreamableConvTranspose1d,
+ pad_for_conv1d,
+ pad1d,
+ unpad1d,
+)
+from .lstm import StreamableLSTM
+from .seanet import SEANetEncoder, SEANetDecoder
diff --git a/audiocraft/modules/activations.py b/audiocraft/modules/activations.py
new file mode 100644
index 0000000..8bd6f29
--- /dev/null
+++ b/audiocraft/modules/activations.py
@@ -0,0 +1,96 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from typing import Union, Callable
+
+
+class CustomGLU(nn.Module):
+ """Custom Gated Linear Unit activation.
+ Applies a modified gated linear unit :math:`a * f(b)` where :math:`a` is the first half
+ of the input matrices, :math:`b` is the second half, and :math:`f` is a provided activation
+ function (i.e. sigmoid, swish, etc.).
+
+ Args:
+ activation (nn.Module): The custom activation to apply in the Gated Linear Unit
+ dim (int): the dimension on which to split the input. Default: -1
+
+ Shape:
+ - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
+ dimensions
+ - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`
+
+ Examples::
+ >>> m = CustomGLU(nn.Sigmoid())
+ >>> input = torch.randn(4, 2)
+ >>> output = m(input)
+ """
+ def __init__(self, activation: nn.Module, dim: int = -1):
+ super(CustomGLU, self).__init__()
+ self.dim = dim
+ self.activation = activation
+
+ def forward(self, x: Tensor):
+ assert x.shape[self.dim] % 2 == 0 # M = N / 2
+ a, b = torch.chunk(x, 2, dim=self.dim)
+ return a * self.activation(b)
+
+
+class SwiGLU(CustomGLU):
+ """SiLU Gated Linear Unit activation.
+ Applies SiLU Gated Linear Unit :math:`a * SiLU(b)` where :math:`a` is
+ the first half of the input matrices, :math:`b` is the second half.
+
+ Args:
+ dim (int): the dimension on which to split the input. Default: -1
+ """
+ def __init__(self, dim: int = -1):
+ super(SwiGLU, self).__init__(nn.SiLU(), dim)
+
+
+class GeGLU(CustomGLU):
+ """GeLU Gated Linear Unit activation.
+ Applies GeLU Gated Linear Unit :math:`a * GELU(b)` where :math:`a` is
+ the first half of the input matrices, :math:`b` is the second half.
+
+ Args:
+ dim (int): the dimension on which to split the input. Default: -1
+ """
+ def __init__(self, dim: int = -1):
+ super(GeGLU, self).__init__(nn.GELU(), dim)
+
+
+class ReGLU(CustomGLU):
+ """ReLU Gated Linear Unit activation.
+ Applies ReLU Gated Linear Unit :math:`a * ReLU(b)` where :math:`a` is
+ the first half of the input matrices, :math:`b` is the second half.
+
+ Args:
+ dim (int): the dimension on which to split the input. Default: -1
+ """
+ def __init__(self, dim: int = -1):
+ super(ReGLU, self).__init__(nn.ReLU(), dim)
+
+
+def get_activation_fn(
+ activation: Union[str, Callable[[Tensor], Tensor]]
+) -> Union[str, Callable[[Tensor], Tensor]]:
+ """Helper function to map an activation string to the activation class.
+ If the supplied activation is not a string that is recognized, the activation is passed back.
+
+ Args:
+ activation (Union[str, Callable[[Tensor], Tensor]]): Activation to check
+ """
+ if isinstance(activation, str):
+ if activation == "reglu":
+ return ReGLU()
+ elif activation == "geglu":
+ return GeGLU()
+ elif activation == "swiglu":
+ return SwiGLU()
+ return activation
diff --git a/audiocraft/modules/codebooks_patterns.py b/audiocraft/modules/codebooks_patterns.py
new file mode 100644
index 0000000..c5b35cb
--- /dev/null
+++ b/audiocraft/modules/codebooks_patterns.py
@@ -0,0 +1,539 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections import namedtuple
+from dataclasses import dataclass
+from functools import lru_cache
+import logging
+import typing as tp
+
+from abc import ABC, abstractmethod
+import torch
+
+LayoutCoord = namedtuple('LayoutCoord', ['t', 'q']) # (timestep, codebook index)
+PatternLayout = tp.List[tp.List[LayoutCoord]] # Sequence of coordinates
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class Pattern:
+ """Base implementation of a pattern over a sequence with multiple codebooks.
+
+ The codebook pattern consists of a layout, defining for each sequence step
+ the list of coordinates of each codebook timestep in the resulting interleaved sequence.
+ The first item of the pattern is always an empty list in order to properly insert a special token
+ to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
+ and ``timesteps`` the number of timesteps corresponding to the original sequence.
+
+ The pattern provides convenient methods to build and revert interleaved sequences from it:
+ ``build_pattern_sequence`` maps a given dense input tensor of a multi-codebook sequence of shape [B, K, T]
+ to the interleaved sequence of shape [B, K, S] by applying the pattern, with B being the batch size,
+ K the number of codebooks, T the number of original timesteps and S the number of sequence steps
+ of the output sequence. The unfilled positions are replaced with a special token and the built sequence
+ is returned along with a mask indicating valid tokens.
+ ``revert_pattern_sequence`` maps an interleaved sequence of shape [B, K, S] back to the original alignment
+ of codebooks across timesteps, producing an output tensor of shape [B, K, T]; again a special token and a mask
+ are used to fill and flag invalid positions if needed.
+ See the dedicated methods for more details.
+ """
+ # Pattern layout, for each sequence step, we have a list of coordinates
+ # corresponding to the original codebook timestep and position.
+ # The first list is always an empty list in order to properly insert
+ # a special token to start with.
+ layout: PatternLayout
+ timesteps: int
+ n_q: int
+
+ def __post_init__(self):
+ assert len(self.layout) > 0
+ assert self.layout[0] == []
+ self._validate_layout()
+ self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
+ self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
+ logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
+
+ def _validate_layout(self):
+ """Runs checks on the layout to ensure a valid pattern is defined.
+ A pattern is considered invalid if:
+ - Multiple timesteps for the same codebook are defined in the same sequence step
+ - The timesteps for a given codebook are not in ascending order as we advance in the sequence
+ (this would mean that we have future timesteps before past timesteps).
+ """
+ q_timesteps = {q: 0 for q in range(self.n_q)}
+ for s, seq_coords in enumerate(self.layout):
+ if len(seq_coords) > 0:
+ qs = set()
+ for coord in seq_coords:
+ qs.add(coord.q)
+ last_q_timestep = q_timesteps[coord.q]
+ assert coord.t >= last_q_timestep, \
+ f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
+ q_timesteps[coord.q] = coord.t
+ # each sequence step contains at max 1 coordinate per codebook
+ assert len(qs) == len(seq_coords), \
+ f"Multiple entries for a same codebook are found at step {s}"
+
+ @property
+ def num_sequence_steps(self):
+ return len(self.layout) - 1
+
+ @property
+ def max_delay(self):
+ max_t_in_seq_coords = 0
+ for seq_coords in self.layout[1:]:
+ for coords in seq_coords:
+ max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
+ return max_t_in_seq_coords - self.timesteps
+
+ @property
+ def valid_layout(self):
+ valid_step = len(self.layout) - self.max_delay
+ return self.layout[:valid_step]
+
+ def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
+ """Get codebook coordinates in the layout that corresponds to the specified timestep t
+ and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
+ and the actual codebook coordinates.
+ """
+ assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
+ if q is not None:
+ assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
+ coords = []
+ for s, seq_codes in enumerate(self.layout):
+ for code in seq_codes:
+ if code.t == t and (q is None or code.q == q):
+ coords.append((s, code))
+ return coords
+
+ def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
+ return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]
+
+ def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
+ steps_with_timesteps = self.get_steps_with_timestep(t, q)
+ return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None
+
+ def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
+ device: tp.Union[torch.device, str] = 'cpu'):
+ """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.
+
+ Args:
+ timesteps (int): Maximum number of timesteps to consider.
+ keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
+ device (Union[torch.device, str]): Device for created tensors.
+ Returns:
+ indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
+ """
+ assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
+ assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
+ # use the proper layout based on whether we limit ourselves to valid steps only or not,
+ # note that using the valid_layout will result in a truncated sequence up to the valid steps
+ ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
+ # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
+ indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
+ mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
+ # fill indexes with last sequence step value that will correspond to our special token
+ # the last value is n_q * timesteps as we have flattened z and append special token as the last token
+ # which will correspond to the index: n_q * timesteps
+ indexes[:] = n_q * timesteps
+ # iterate over the pattern and fill scattered indexes and mask
+ for s, sequence_coords in enumerate(ref_layout):
+ for coords in sequence_coords:
+ if coords.t < timesteps:
+ indexes[coords.q, s] = coords.t + coords.q * timesteps
+ mask[coords.q, s] = 1
+ indexes = torch.from_numpy(indexes).to(device)
+ mask = torch.from_numpy(mask).to(device)
+ return indexes, mask
+
+ def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
+ """Build sequence corresponding to the pattern from the input tensor z.
+ The sequence is built using up to sequence_steps if specified, and non-pattern
+ coordinates are filled with the special token.
+
+ Args:
+ z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
+ special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
+ keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
+ Steps that are beyond valid steps will be replaced by the special_token in that case.
+ Returns:
+ values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
+ corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
+ indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
+ """
+ B, K, T = z.shape
+ indexes, mask = self._build_pattern_sequence_scatter_indexes(
+ T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
+ )
+ z = z.view(B, -1)
+ # we append the special token as the last index of our flattened z tensor
+ z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
+ values = z[:, indexes.view(-1)]
+ values = values.view(B, K, indexes.shape[-1])
+ return values, indexes, mask
+
+ def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
+ keep_only_valid_steps: bool = False,
+ is_model_output: bool = False,
+ device: tp.Union[torch.device, str] = 'cpu'):
+ """Builds scatter indexes required to retrieve the original multi-codebook sequence
+ from interleaving pattern.
+
+ Args:
+ sequence_steps (int): Sequence steps.
+ n_q (int): Number of codebooks.
+ keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
+ Steps that are beyond valid steps will be replaced by the special_token in that case.
+ is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
+ device (Union[torch.device, str]): Device for created tensors.
+ Returns:
+ torch.Tensor: Indexes for reconstructing the output, of shape [K, T].
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
+ """
+ ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
+ # TODO(jade): Do we want to further truncate to only valid timesteps here as well?
+ timesteps = self.timesteps
+ assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
+ assert sequence_steps <= len(ref_layout), \
+ f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
+
+ # ensure we take the appropriate indexes to keep the model output from the first special token as well
+ if is_model_output:
+ ref_layout = ref_layout[1:]
+
+ # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
+ indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
+ mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
+ # fill indexes with last sequence step value that will correspond to our special token
+ indexes[:] = n_q * sequence_steps
+ for s, sequence_codes in enumerate(ref_layout):
+ if s < sequence_steps:
+ for code in sequence_codes:
+ if code.t < timesteps:
+ indexes[code.q, code.t] = s + code.q * sequence_steps
+ mask[code.q, code.t] = 1
+ indexes = torch.from_numpy(indexes).to(device)
+ mask = torch.from_numpy(mask).to(device)
+ return indexes, mask
+
+ def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
+ """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
+ The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
+ are filled with the special token.
+
+ Args:
+ s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
+ special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
+ Returns:
+ values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
+ corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
+ indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
+ """
+ B, K, S = s.shape
+ indexes, mask = self._build_reverted_sequence_scatter_indexes(
+ S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
+ )
+ s = s.view(B, -1)
+ # we append the special token as the last index of our flattened z tensor
+ s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
+ values = s[:, indexes.view(-1)]
+ values = values.view(B, K, indexes.shape[-1])
+ return values, indexes, mask
+
+ def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
+ """Revert model logits obtained on a sequence built from the pattern
+ back to a tensor matching the original sequence.
+
+ This method is similar to ``revert_pattern_sequence`` with the following specificities:
+ 1. It is designed to work with the extra cardinality dimension
+ 2. We return the logits for the first sequence item that matches the special_token and
+ which matching target in the original sequence is the first item of the sequence,
+ while we skip the last logits as there is no matching target
+ """
+ B, card, K, S = logits.shape
+ indexes, mask = self._build_reverted_sequence_scatter_indexes(
+ S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
+ )
+ logits = logits.reshape(B, card, -1)
+ # we append the special token as the last index of our flattened z tensor
+ logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1) # [B, card, K x S]
+ values = logits[:, :, indexes.view(-1)]
+ values = values.view(B, card, K, indexes.shape[-1])
+ return values, indexes, mask
+
+
+class CodebooksPatternProvider(ABC):
+ """Abstraction around providing pattern for interleaving codebooks.
+
+ The CodebooksPatternProvider abstraction allows implementing various strategies to
+ define the interleaving pattern of sequences composed of multiple codebooks. For a given
+ number of codebooks `n_q`, the pattern provider can generate a specified pattern
+ corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
+ can be used to construct a new sequence from the original codes while respecting the specified
+ pattern. The pattern is defined as a list of lists of code coordinates, each coordinate
+ being a tuple of the original timestep and codebook used to build the new sequence.
+ Note that all patterns must start with an empty list that is then used to insert a first
+ sequence step of special tokens in the newly generated sequence.
+
+ Args:
+ n_q (int): number of codebooks.
+ cached (bool): if True, patterns for a given length are cached. In general
+ that should be true for efficiency reason to avoid synchronization points.
+ """
+ def __init__(self, n_q: int, cached: bool = True):
+ assert n_q > 0
+ self.n_q = n_q
+ if cached:
+ self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore
+
+ @abstractmethod
+ def get_pattern(self, timesteps: int) -> Pattern:
+ """Builds pattern with specific interleaving between codebooks.
+
+ Args:
+ timesteps (int): Total number of timesteps.
+ """
+ raise NotImplementedError()
+
+
+class DelayedPatternProvider(CodebooksPatternProvider):
+ """Provider for delayed pattern across delayed codebooks.
+ Codebooks are delayed in the sequence and sequence steps will contain codebooks
+ from different timesteps.
+
+ Example:
+ Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
+ [[1, 2, 3, 4],
+ [1, 2, 3, 4],
+ [1, 2, 3, 4]]
+ The resulting sequence obtained from the returned pattern is:
+ [[S, 1, 2, 3, 4],
+ [S, S, 1, 2, 3],
+ [S, S, S, 1, 2]]
+ (with S being a special token)
+
+ Args:
+ n_q (int): Number of codebooks.
+ delays (Optional[List[int]]): Delay for each of the codebooks.
+ If delays not defined, each codebook is delayed by 1 compared to the previous one.
+ flatten_first (int): Flatten the first N timesteps.
+ empty_initial (int): Prepend with N empty list of coordinates.
+ """
+ def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
+ flatten_first: int = 0, empty_initial: int = 0):
+ super().__init__(n_q)
+ if delays is None:
+ delays = list(range(n_q))
+ self.delays = delays
+ self.flatten_first = flatten_first
+ self.empty_initial = empty_initial
+ assert len(self.delays) == self.n_q
+ assert sorted(self.delays) == self.delays
+
+ def get_pattern(self, timesteps: int) -> Pattern:
+ out: PatternLayout = [[]]
+ max_delay = max(self.delays)
+ if self.empty_initial:
+ out += [[] for _ in range(self.empty_initial)]
+ if self.flatten_first:
+ for t in range(min(timesteps, self.flatten_first)):
+ for q in range(self.n_q):
+ out.append([LayoutCoord(t, q)])
+ for t in range(self.flatten_first, timesteps + max_delay):
+ v = []
+ for q, delay in enumerate(self.delays):
+ t_for_q = t - delay
+ if t_for_q >= self.flatten_first:
+ v.append(LayoutCoord(t_for_q, q))
+ out.append(v)
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+
+class ParallelPatternProvider(DelayedPatternProvider):
+ """Provider for parallel pattern across codebooks.
+ This pattern provider is a special case of the delayed pattern with actually no delay,
+ hence delays=repeat(0, n_q).
+
+ Args:
+ n_q (int): Number of codebooks.
+ """
+ def __init__(self, n_q: int):
+ super().__init__(n_q, [0] * n_q)
+
+
+class UnrolledPatternProvider(CodebooksPatternProvider):
+ """Provider for unrolling codebooks pattern.
+ This pattern provider allows the codebooks to be flattened, either completely or only to some extent,
+ while also specifying a given delay between the flattened codebook representations, making it possible to
+ unroll the codebooks in the sequence.
+
+ Example:
+ 1. Flattening of the codebooks.
+ By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q),
+ taking n_q = 3 and timesteps = 4:
+ [[1, 2, 3, 4],
+ [1, 2, 3, 4],
+ [1, 2, 3, 4]]
+ will result into:
+ [[S, S, 1, S, S, 2, S, S, 3, S, S, 4],
+ [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
+ [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
+ 2. Partial flattening of the codebooks. The ``flattening`` parameter specifies the inner step
+ for each of the codebooks, defining which codebooks to flatten (or keep in parallel), for example
+ taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]:
+ [[1, 2, 3, 4],
+ [1, 2, 3, 4],
+ [1, 2, 3, 4]]
+ will result into:
+ [[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
+ [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
+ [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
+ 3. Flattening with delay. The ``delay`` parameter further unrolls the sequence of codebooks
+ by specifying a delay per codebook. Note that the delays of codebooks flattened to the
+ same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1]
+ and delays = [0, 3, 3]:
+ [[1, 2, 3, 4],
+ [1, 2, 3, 4],
+ [1, 2, 3, 4]]
+ will result into:
+ [[S, S, S, 1, S, 2, S, 3, S, 4],
+ [S, S, S, 1, S, 2, S, 3, S, 4],
+ [1, 2, 3, S, 4, S, 5, S, 6, S]]
+
+ Args:
+ n_q (int): Number of codebooks.
+ flattening (Optional[List[int]]): Flattening schema over the codebooks. If not defined,
+ the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
+ have n_q extra steps for each timestep.
+ delays (Optional[List[int]]): Delay for each of the codebooks. If not defined,
+ no delay is added and therefore will default to [0] * ``n_q``.
+ Note that two codebooks that will be flattened to the same inner step
+ should have the same delay, otherwise the pattern is considered as invalid.
+ """
+ FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])
+
+ def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
+ delays: tp.Optional[tp.List[int]] = None):
+ super().__init__(n_q)
+ if flattening is None:
+ flattening = list(range(n_q))
+ if delays is None:
+ delays = [0] * n_q
+ assert len(flattening) == n_q
+ assert len(delays) == n_q
+ assert sorted(flattening) == flattening
+ assert sorted(delays) == delays
+ self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
+ self.max_delay = max(delays)
+
+ def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
+ """Build a flattened codebooks representation as a dictionary of inner step
+ and the actual codebook indices corresponding to the flattened codebook. For convenience, we
+ also store the delay associated to the flattened codebook to avoid maintaining an extra mapping.
+ """
+ flattened_codebooks: dict = {}
+ for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
+ if inner_step not in flattened_codebooks:
+ flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
+ else:
+ flat_codebook = flattened_codebooks[inner_step]
+ assert flat_codebook.delay == delay, (
+ "Delay and flattening between codebooks is inconsistent: ",
+ "two codebooks flattened to the same position should have the same delay."
+ )
+ flat_codebook.codebooks.append(q)
+ flattened_codebooks[inner_step] = flat_codebook
+ return flattened_codebooks
+
+ @property
+ def _num_inner_steps(self):
+ """Number of inner steps to unroll between timesteps in order to flatten the codebooks.
+ """
+ return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1
+
+ def num_virtual_steps(self, timesteps: int) -> int:
+ return timesteps * self._num_inner_steps + 1
+
+ def get_pattern(self, timesteps: int) -> Pattern:
+ """Builds pattern for delay across codebooks.
+
+ Args:
+ timesteps (int): Total number of timesteps.
+ """
+ # the PatternLayout is built as a tuple of sequence position and list of coordinates
+ # so that it can be reordered properly given the required delay between codebooks of given timesteps
+ indexed_out: list = [(-1, [])]
+ max_timesteps = timesteps + self.max_delay
+ for t in range(max_timesteps):
+ # for each timestep, we unroll the flattened codebooks,
+ # emitting the sequence step with the corresponding delay
+ for step in range(self._num_inner_steps):
+ if step in self._flattened_codebooks:
+ # we have codebooks at this virtual step to emit
+ step_codebooks = self._flattened_codebooks[step]
+ t_for_q = t + step_codebooks.delay
+ coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
+ if t_for_q < max_timesteps and t < max_timesteps:
+ indexed_out.append((t_for_q, coords))
+ else:
+ # there is no codebook in this virtual step so we emit an empty list
+ indexed_out.append((t, []))
+ out = [coords for _, coords in sorted(indexed_out)]
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+
+class VALLEPattern(CodebooksPatternProvider):
+ """Almost VALL-E style pattern. We futher allow some delays for the
+ codebooks other than the first one.
+
+ Args:
+ n_q (int): Number of codebooks.
+ delays (Optional[List[int]]): Delay for each of the codebooks except the first one.
+ If delays are not defined, no delay is applied (all delays default to 0).
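+
+ Example (illustrative sketch):
+ Taking timesteps=4, n_q=3 and delays=None, the multi-codebook sequence:
+ [[1, 2, 3, 4],
+ [1, 2, 3, 4],
+ [1, 2, 3, 4]]
+ results in the following interleaved sequence (S being a special token):
+ [[S, 1, 2, 3, 4, S, S, S, S],
+ [S, S, S, S, S, 1, 2, 3, 4],
+ [S, S, S, S, S, 1, 2, 3, 4]]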
+ """
+ def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
+ super().__init__(n_q)
+ if delays is None:
+ delays = [0] * (n_q - 1)
+ self.delays = delays
+ assert len(self.delays) == self.n_q - 1
+ assert sorted(self.delays) == self.delays
+
+ def get_pattern(self, timesteps: int) -> Pattern:
+ out: PatternLayout = [[]]
+ for t in range(timesteps):
+ out.append([LayoutCoord(t, 0)])
+ max_delay = max(self.delays)
+ for t in range(timesteps + max_delay):
+ v = []
+ for q, delay in enumerate(self.delays):
+ t_for_q = t - delay
+ if t_for_q >= 0:
+ v.append(LayoutCoord(t_for_q, q + 1))
+ out.append(v)
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
+
+
+class MusicLMPattern(CodebooksPatternProvider):
+ """Almost MusicLM style pattern. This is equivalent to full flattening
+ but in a different order.
+
+ Args:
+ n_q (int): Number of codebooks.
+ group_by (int): Number of codebooks to group together.
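+
+ Example (illustrative sketch):
+ With n_q=4, group_by=2 and timesteps=2, the codebooks are emitted group by group,
+ giving the sequence steps (after the initial empty step):
+ [(t=0, q=0)], [(t=0, q=1)], [(t=1, q=0)], [(t=1, q=1)],
+ [(t=0, q=2)], [(t=0, q=3)], [(t=1, q=2)], [(t=1, q=3)]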
+ """
+ def __init__(self, n_q: int, group_by: int = 2):
+ super().__init__(n_q)
+ self.group_by = group_by
+
+ def get_pattern(self, timesteps: int) -> Pattern:
+ out: PatternLayout = [[]]
+ for offset in range(0, self.n_q, self.group_by):
+ for t in range(timesteps):
+ for q in range(offset, offset + self.group_by):
+ out.append([LayoutCoord(t, q)])
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
diff --git a/audiocraft/modules/conditioners.py b/audiocraft/modules/conditioners.py
new file mode 100644
index 0000000..8279231
--- /dev/null
+++ b/audiocraft/modules/conditioners.py
@@ -0,0 +1,990 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections import defaultdict
+from copy import deepcopy
+from dataclasses import dataclass, field
+from itertools import chain
+import logging
+import math
+import random
+import re
+import typing as tp
+import warnings
+
+from einops import rearrange
+from num2words import num2words
+import spacy
+from transformers import T5EncoderModel, T5Tokenizer # type: ignore
+import torchaudio
+import torch
+from torch import nn
+from torch import Tensor
+import torch.nn.functional as F
+from torch.nn.utils.rnn import pad_sequence
+
+from .streaming import StreamingModule
+from .transformer import create_sin_embedding
+from ..data.audio_dataset import SegmentInfo
+from ..utils.autocast import TorchAutocast
+from ..utils.utils import hash_trick, length_to_mask, collate
+
+
+logger = logging.getLogger(__name__)
+TextCondition = tp.Optional[str] # a text condition can be a string or None (if doesn't exist)
+ConditionType = tp.Tuple[Tensor, Tensor] # condition, mask
+
+
+class WavCondition(tp.NamedTuple):
+ wav: Tensor
+ length: Tensor
+ path: tp.List[tp.Optional[str]] = []
+
+
+def nullify_condition(condition: ConditionType, dim: int = 1):
+ """This function transforms an input condition to a null condition.
+ This is done by converting the condition to a single zero vector, similarly
+ to how it is done inside WhiteSpaceTokenizer and NoopTokenizer.
+
+ Args:
+ condition (ConditionType): a tuple of condition and mask (tp.Tuple[Tensor, Tensor])
+ dim (int): the dimension that will be truncated (should be the time dimension)
+ WARNING!: dim should not be the batch dimension!
+ Returns:
+ ConditionType: a tuple of null condition and mask
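+
+ Example (illustrative): with dim=1, a condition of shape [B, T, D] with its mask
+ becomes a zero condition of shape [B, 1, D] with an all-zero mask of shape [B, 1].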
+ """
+ assert dim != 0, "dim cannot be the batch dimension!"
+ assert type(condition) == tuple and \
+ type(condition[0]) == Tensor and \
+ type(condition[1]) == Tensor, "'nullify_condition' got an unexpected input type!"
+ cond, mask = condition
+ B = cond.shape[0]
+ last_dim = cond.dim() - 1
+ out = cond.transpose(dim, last_dim)
+ out = 0. * out[..., :1]
+ out = out.transpose(dim, last_dim)
+ mask = torch.zeros((B, 1), device=out.device).int()
+ assert cond.dim() == out.dim()
+ return out, mask
+
+
+def nullify_wav(wav: Tensor) -> WavCondition:
+ """Create a nullified WavCondition from a wav tensor with appropriate shape.
+
+ Args:
+ wav (Tensor): tensor of shape [B, T]
+ Returns:
+ WavCondition: wav condition with nullified wav.
+ """
+ null_wav, _ = nullify_condition((wav, torch.zeros_like(wav)), dim=wav.dim() - 1)
+ return WavCondition(
+ wav=null_wav,
+ length=torch.tensor([0] * wav.shape[0], device=wav.device),
+ path=['null_wav'] * wav.shape[0]
+ )
+
+
+@dataclass
+class ConditioningAttributes:
+ text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
+ wav: tp.Dict[str, WavCondition] = field(default_factory=dict)
+
+ def __getitem__(self, item):
+ return getattr(self, item)
+
+ @property
+ def text_attributes(self):
+ return self.text.keys()
+
+ @property
+ def wav_attributes(self):
+ return self.wav.keys()
+
+ @property
+ def attributes(self):
+ return {"text": self.text_attributes, "wav": self.wav_attributes}
+
+ def to_flat_dict(self):
+ return {
+ **{f"text.{k}": v for k, v in self.text.items()},
+ **{f"wav.{k}": v for k, v in self.wav.items()},
+ }
+
+ @classmethod
+ def from_flat_dict(cls, x):
+ out = cls()
+ for k, v in x.items():
+ kind, att = k.split(".")
+ out[kind][att] = v
+ return out
+
+
+class SegmentWithAttributes(SegmentInfo):
+ """Base class for all dataclasses that are used for conditioning.
+ All child classes should implement `to_condition_attributes` that converts
+ the existing attributes to a dataclass of type ConditioningAttributes.
+ """
+ def to_condition_attributes(self) -> ConditioningAttributes:
+ raise NotImplementedError()
+
+
+class Tokenizer:
+ """Base class for all tokenizers
+ (in case we want to introduce more advanced tokenizers in the future).
+ """
+ def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[Tensor, Tensor]:
+ raise NotImplementedError()
+
+
+class WhiteSpaceTokenizer(Tokenizer):
+ """This tokenizer should be used for natural language descriptions.
+ For example:
+ ["he didn't, know he's going home.", 'shorter sentence'] =>
+ [[78, 62, 31, 4, 78, 25, 19, 34],
+ [59, 77, 0, 0, 0, 0, 0, 0]]
+ """
+ PUNCTUATIONS = "?:!.,;"
+
+ def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm",
+ lemma: bool = True, stopwords: bool = True) -> None:
+ self.n_bins = n_bins
+ self.pad_idx = pad_idx
+ self.lemma = lemma
+ self.stopwords = stopwords
+ try:
+ self.nlp = spacy.load(language)
+ except IOError:
+ spacy.cli.download(language) # type: ignore
+ self.nlp = spacy.load(language)
+
+ @tp.no_type_check
+ def __call__(
+ self,
+ texts: tp.List[tp.Optional[str]],
+ return_text: bool = False
+ ) -> tp.Tuple[Tensor, Tensor]:
+ """Take a list of strings and convert them to a tensor of indices.
+
+ Args:
+ texts (tp.List[str]): List of strings.
+ return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False.
+ Returns:
+ tp.Tuple[Tensor, Tensor]:
+ - Indices of words in the LUT.
+ - And a mask indicating where the padding tokens are
+ """
+ output, lengths = [], []
+ texts = deepcopy(texts)
+ for i, text in enumerate(texts):
+ # if current sample doesn't have a certain attribute, replace with pad token
+ if text is None:
+ output.append(Tensor([self.pad_idx]))
+ lengths.append(0)
+ continue
+
+ # convert numbers to words
+ text = re.sub(r"(\d+)", lambda x: num2words(int(x.group(0))), text) # type: ignore
+ # normalize text
+ text = self.nlp(text) # type: ignore
+ # remove stopwords
+ if self.stopwords:
+ text = [w for w in text if not w.is_stop] # type: ignore
+ # remove punctuations
+ text = [w for w in text if w.text not in self.PUNCTUATIONS] # type: ignore
+ # lemmatize if needed
+ text = [getattr(t, "lemma_" if self.lemma else "text") for t in text] # type: ignore
+
+ texts[i] = " ".join(text)
+ lengths.append(len(text))
+ # convert to tensor
+ tokens = Tensor([hash_trick(w, self.n_bins) for w in text])
+ output.append(tokens)
+
+ mask = length_to_mask(torch.IntTensor(lengths)).int()
+ padded_output = pad_sequence(output, padding_value=self.pad_idx).int().t()
+ if return_text:
+ return padded_output, mask, texts # type: ignore
+ return padded_output, mask
+
+
+class NoopTokenizer(Tokenizer):
+ """This tokenizer should be used for global conditioners such as: artist, genre, key, etc.
+ The difference between this and WhiteSpaceTokenizer is that NoopTokenizer does not split
+ strings, so "Jeff Buckley" will get it's own index. Whereas WhiteSpaceTokenizer will
+ split it to ["Jeff", "Buckley"] and return an index per word.
+
+ For example:
+ ["Queen", "ABBA", "Jeff Buckley"] => [43, 55, 101]
+ ["Metal", "Rock", "Classical"] => [0, 223, 51]
+ """
+ def __init__(self, n_bins: int, pad_idx: int = 0):
+ self.n_bins = n_bins
+ self.pad_idx = pad_idx
+
+ def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[Tensor, Tensor]:
+ output, lengths = [], []
+ for text in texts:
+ # if current sample doesn't have a certain attribute, replace with pad token
+ if text is None:
+ output.append(self.pad_idx)
+ lengths.append(0)
+ else:
+ output.append(hash_trick(text, self.n_bins))
+ lengths.append(1)
+
+ tokens = torch.LongTensor(output).unsqueeze(1)
+ mask = length_to_mask(torch.IntTensor(lengths)).int()
+ return tokens, mask
+
+
+class BaseConditioner(nn.Module):
+ """Base model for all conditioner modules. We allow the output dim to be different
+ than the hidden dim for two reasons: 1) keep our LUTs small when the vocab is large;
+ 2) make all condition dims consistent.
+
+ Args:
+ dim (int): Hidden dim of the model (text-encoder/LUT).
+ output_dim (int): Output dim of the conditioner.
+ """
+ def __init__(self, dim, output_dim):
+ super().__init__()
+ self.dim = dim
+ self.output_dim = output_dim
+ self.output_proj = nn.Linear(dim, output_dim)
+
+ def tokenize(self, *args, **kwargs) -> tp.Any:
+ """Should be any part of the processing that will lead to a synchronization
+ point, e.g. BPE tokenization with transfer to the GPU.
+
+ The returned value will be saved and return later when calling forward().
+ """
+ raise NotImplementedError()
+
+ def forward(self, inputs: tp.Any) -> ConditionType:
+ """Gets input that should be used as conditioning (e.g, genre, description or a waveform).
+ Outputs a ConditionType, after the input data was embedded as a dense vector.
+
+ Returns:
+ ConditionType:
+ - A tensor of size [B, T, D] where B is the batch size, T is the length of the
+ output embedding and D is the dimension of the embedding.
+ - And a mask indicating where the padding tokens are.
+ """
+ raise NotImplementedError()
+
+
+class TextConditioner(BaseConditioner):
+ ...
+
+
+class LUTConditioner(TextConditioner):
+ """Lookup table TextConditioner.
+
+ Args:
+ n_bins (int): Number of bins.
+ dim (int): Hidden dim of the model (text-encoder/LUT).
+ output_dim (int): Output dim of the conditioner.
+ tokenizer (str): Name of the tokenizer.
+ pad_idx (int, optional): Index for padding token. Defaults to 0.
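+
+ Example (illustrative sketch; parameter values are arbitrary)::
+
+ conditioner = LUTConditioner(n_bins=1024, dim=64, output_dim=256, tokenizer="noop")
+ tokens, mask = conditioner.tokenize(["rock", "jazz", None])
+ embeds, mask = conditioner((tokens, mask))  # embeds: [3, 1, 256]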
+ """
+ def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: str, pad_idx: int = 0):
+ super().__init__(dim, output_dim)
+ self.embed = nn.Embedding(n_bins, dim)
+ self.tokenizer: Tokenizer
+ if tokenizer == "whitespace":
+ self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx)
+ elif tokenizer == "noop":
+ self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx)
+ else:
+ raise ValueError(f"unrecognized tokenizer `{tokenizer}`.")
+
+ def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+ device = self.embed.weight.device
+ tokens, mask = self.tokenizer(x)
+ tokens, mask = tokens.to(device), mask.to(device)
+ return tokens, mask
+
+ def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> ConditionType:
+ tokens, mask = inputs
+ embeds = self.embed(tokens)
+ embeds = self.output_proj(embeds)
+ embeds = (embeds * mask.unsqueeze(-1))
+ return embeds, mask
+
+
+class T5Conditioner(TextConditioner):
+ """T5-based TextConditioner.
+
+ Args:
+ name (str): Name of the T5 model.
+ output_dim (int): Output dim of the conditioner.
+ finetune (bool): Whether to fine-tune T5 at train time.
+ device (str): Device for T5 Conditioner.
+ autocast_dtype (tp.Optional[str], optional): Autocast dtype.
+ word_dropout (float, optional): Word dropout probability.
+ normalize_text (bool, optional): Whether to apply text normalization.
+ """
+ MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
+ "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
+ "google/flan-t5-xl", "google/flan-t5-xxl"]
+ MODELS_DIMS = {
+ "t5-small": 512,
+ "t5-base": 768,
+ "t5-large": 1024,
+ "t5-3b": 1024,
+ "t5-11b": 1024,
+ "google/flan-t5-small": 512,
+ "google/flan-t5-base": 768,
+ "google/flan-t5-large": 1024,
+ "google/flan-t5-3b": 1024,
+ "google/flan-t5-11b": 1024,
+ }
+
+ def __init__(self, name: str, output_dim: int, finetune: bool, device: str,
+ autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0.,
+ normalize_text: bool = False):
+ assert name in self.MODELS, f"unrecognized t5 model name (should be in {self.MODELS})"
+ super().__init__(self.MODELS_DIMS[name], output_dim)
+ self.device = device
+ self.name = name
+ self.finetune = finetune
+ self.word_dropout = word_dropout
+
+ if autocast_dtype is None or self.device == 'cpu':
+ self.autocast = TorchAutocast(enabled=False)
+ if self.device != 'cpu':
+ logger.warning("T5 has no autocast, this might lead to NaN")
+ else:
+ dtype = getattr(torch, autocast_dtype)
+ assert isinstance(dtype, torch.dtype)
+ logger.info(f"T5 will be evaluated with autocast as {autocast_dtype}")
+ self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
+ # Let's disable logging temporarily because T5 will vomit some errors otherwise.
+ # thanks https://gist.github.com/simon-weber/7853144
+ previous_level = logging.root.manager.disable
+ logging.disable(logging.ERROR)
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ try:
+ self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
+ t5 = T5EncoderModel.from_pretrained(name).train(mode=finetune)
+ finally:
+ logging.disable(previous_level)
+ if finetune:
+ self.t5 = t5
+ else:
+ # this makes sure that the t5 models is not part
+ # of the saved checkpoint
+ self.__dict__["t5"] = t5.to(device)
+
+ self.normalize_text = normalize_text
+ if normalize_text:
+ self.text_normalizer = WhiteSpaceTokenizer(1, lemma=True, stopwords=True)
+
+ def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
+ # if current sample doesn't have a certain attribute, replace with empty string
+ entries: tp.List[str] = [xi if xi is not None else "" for xi in x]
+ if self.normalize_text:
+ _, _, entries = self.text_normalizer(entries, return_text=True)
+ if self.word_dropout > 0. and self.training:
+ new_entries = []
+ for entry in entries:
+ words = [word for word in entry.split(" ") if random.random() >= self.word_dropout]
+ new_entries.append(" ".join(words))
+ entries = new_entries
+
+ empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""])
+
+ inputs = self.t5_tokenizer(entries, return_tensors="pt", padding=True).to(self.device)
+ mask = inputs["attention_mask"]
+ mask[empty_idx, :] = 0 # zero-out index where the input is non-existent
+ return inputs
+
+ def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
+ mask = inputs["attention_mask"]
+ with torch.set_grad_enabled(self.finetune), self.autocast:
+ embeds = self.t5(**inputs).last_hidden_state
+ embeds = self.output_proj(embeds.to(self.output_proj.weight))
+ embeds = (embeds * mask.unsqueeze(-1))
+ return embeds, mask
+
+
+class WaveformConditioner(BaseConditioner):
+ """Base class for all conditioners that take a waveform as input.
+ Classes that inherit must implement `_get_wav_embedding` that outputs
+ a continuous tensor, and `_downsampling_factor` that returns the down-sampling
+ factor of the embedding model.
+
+ Args:
+ dim (int): The internal representation dimension.
+ output_dim (int): Output dimension.
+ device (tp.Union[torch.device, str]): Device.
+ """
+ def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]):
+ super().__init__(dim, output_dim)
+ self.device = device
+
+ def tokenize(self, wav_length: WavCondition) -> WavCondition:
+ wav, length, path = wav_length
+ assert length is not None
+ return WavCondition(wav.to(self.device), length.to(self.device), path)
+
+ def _get_wav_embedding(self, wav: Tensor) -> Tensor:
+ """Gets as input a wav and returns a dense vector of conditions."""
+ raise NotImplementedError()
+
+ def _downsampling_factor(self):
+ """Returns the downsampling factor of the embedding model."""
+ raise NotImplementedError()
+
+ def forward(self, inputs: WavCondition) -> ConditionType:
+ """
+ Args:
+ input (WavCondition): Tuple of (waveform, lengths).
+ Returns:
+ ConditionType: Dense vector representing the conditioning along with its mask.
+ """
+ wav, lengths, path = inputs
+ with torch.no_grad():
+ embeds = self._get_wav_embedding(wav)
+ embeds = embeds.to(self.output_proj.weight)
+ embeds = self.output_proj(embeds)
+
+ if lengths is not None:
+ lengths = lengths / self._downsampling_factor()
+ mask = length_to_mask(lengths, max_len=embeds.shape[1]).int() # type: ignore
+ else:
+ mask = torch.ones_like(embeds)
+ embeds = (embeds * mask.unsqueeze(2).to(self.device))
+
+ return embeds, mask
+
+
+class ChromaStemConditioner(WaveformConditioner):
+ """Chroma conditioner that uses DEMUCS to first filter out drums and bass. The is followed by
+ the insight the drums and bass often dominate the chroma, leading to the chroma not containing the
+ information about melody.
+
+ Args:
+ output_dim (int): Output dimension for the conditioner.
+ sample_rate (int): Sample rate for the chroma extractor.
+ n_chroma (int): Number of chroma for the chroma extractor.
+ radix2_exp (int): Radix2 exponent for the chroma extractor.
+ duration (float): Duration used during training. This is later used for correct padding
+ in case we are using chroma as prefix.
+ match_len_on_eval (bool, optional): If True then all chromas are padded to the training
+ duration. Defaults to True.
+ eval_wavs (str, optional): Path to a json egg with waveforms; these waveforms are used as
+ conditions during eval (for cases where we don't want to leak test conditions like MusicCaps).
+ Defaults to None.
+ n_eval_wavs (int, optional): Limits the number of waveforms used for conditioning. Defaults to 0.
+ device (tp.Union[torch.device, str], optional): Device for the conditioner.
+ **kwargs: Additional parameters for the chroma extractor.
+ """
+ def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int,
+ duration: float, match_len_on_eval: bool = True, eval_wavs: tp.Optional[str] = None,
+ n_eval_wavs: int = 0, device: tp.Union[torch.device, str] = "cpu", **kwargs):
+ from demucs import pretrained
+ super().__init__(dim=n_chroma, output_dim=output_dim, device=device)
+ self.autocast = TorchAutocast(enabled=device != "cpu", device_type=self.device, dtype=torch.float32)
+ self.sample_rate = sample_rate
+ self.match_len_on_eval = match_len_on_eval
+ self.duration = duration
+ self.__dict__["demucs"] = pretrained.get_model('htdemucs').to(device)
+ self.stem2idx = {'drums': 0, 'bass': 1, 'other': 2, 'vocal': 3}
+ self.stem_idx = torch.LongTensor([self.stem2idx['vocal'], self.stem2idx['other']]).to(device)
+ self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma, radix2_exp=radix2_exp,
+ device=device, **kwargs)
+ self.chroma_len = self._get_chroma_len()
+
+ def _downsampling_factor(self):
+ return self.chroma.winhop
+
+ def _get_chroma_len(self):
+ """Get length of chroma during training"""
+ dummy_wav = torch.zeros((1, int(self.sample_rate * self.duration)), device=self.device)
+ dummy_chr = self.chroma(dummy_wav)
+ return dummy_chr.shape[1]
+
+ @torch.no_grad()
+ def _get_filtered_wav(self, wav):
+ from demucs.apply import apply_model
+ from demucs.audio import convert_audio
+ with self.autocast:
+ wav = convert_audio(wav, self.sample_rate, self.demucs.samplerate, self.demucs.audio_channels)
+ stems = apply_model(self.demucs, wav, device=self.device)
+ stems = stems[:, self.stem_idx] # extract stem
+ stems = stems.sum(1) # merge extracted stems
+ stems = stems.mean(1, keepdim=True) # mono
+ stems = convert_audio(stems, self.demucs.samplerate, self.sample_rate, 1)
+ return stems
+
+ @torch.no_grad()
+ def _get_wav_embedding(self, wav):
+ # avoid 0-size tensors when we are working with null conds
+ if wav.shape[-1] == 1:
+ return self.chroma(wav)
+ stems = self._get_filtered_wav(wav)
+ chroma = self.chroma(stems)
+
+ if self.match_len_on_eval:
+ b, t, c = chroma.shape
+ if t > self.chroma_len:
+ chroma = chroma[:, :self.chroma_len]
+ logger.debug(f'chroma was truncated! ({t} -> {chroma.shape[1]})')
+ elif t < self.chroma_len:
+ # chroma = F.pad(chroma, (0, 0, 0, self.chroma_len - t))
+ n_repeat = int(math.ceil(self.chroma_len / t))
+ chroma = chroma.repeat(1, n_repeat, 1)
+ chroma = chroma[:, :self.chroma_len]
+ logger.debug(f'chroma was repeated to match length! ({t} -> {chroma.shape[1]})')
+ return chroma
+
+
+class ChromaExtractor(nn.Module):
+ """Chroma extraction class, handles chroma extraction and quantization.
+
+ Args:
+ sample_rate (int): Sample rate.
+ n_chroma (int): Number of chroma to consider.
+ radix2_exp (int): Radix2 exponent.
+ nfft (tp.Optional[int], optional): Number of FFT.
+ winlen (tp.Optional[int], optional): Window length.
+ winhop (tp.Optional[int], optional): Window hop size.
+ argmax (bool, optional): Whether to use argmax. Defaults to False.
+ norm (float, optional): Norm for chroma normalization. Defaults to inf.
+ device (tp.Union[torch.device, str], optional): Device to use. Defaults to cpu.
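+
+ Shape (illustrative): given a waveform of shape [B, 1, T], the extractor returns a
+ chroma tensor of shape [B, frames, n_chroma], where frames depends on winhop.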
+ """
+ def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12,
+ nfft: tp.Optional[int] = None, winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None,
+ argmax: bool = False, norm: float = torch.inf, device: tp.Union[torch.device, str] = "cpu"):
+ super().__init__()
+ from librosa import filters
+ self.device = device
+ self.autocast = TorchAutocast(enabled=device != "cpu", device_type=self.device, dtype=torch.float32)
+ self.winlen = winlen or 2 ** radix2_exp
+ self.nfft = nfft or self.winlen
+ self.winhop = winhop or (self.winlen // 4)
+ self.sr = sample_rate
+ self.n_chroma = n_chroma
+ self.norm = norm
+ self.argmax = argmax
+ self.window = torch.hann_window(self.winlen).to(device)
+ self.fbanks = torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0,
+ n_chroma=self.n_chroma)).to(device)
+ self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen,
+ hop_length=self.winhop, power=2, center=True,
+ pad=0, normalized=True).to(device)
+
+ def forward(self, wav):
+ with self.autocast:
+ T = wav.shape[-1]
+ # in case we are getting a wav that was dropped out (nullified)
+ # make sure wav length is no less that nfft
+ if T < self.nfft:
+ pad = self.nfft - T
+ r = 0 if pad % 2 == 0 else 1
+ wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0)
+ assert wav.shape[-1] == self.nfft, f'expected len {self.nfft} but got {wav.shape[-1]}'
+ spec = self.spec(wav).squeeze(1)
+ raw_chroma = torch.einsum("cf,...ft->...ct", self.fbanks, spec)
+ norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
+ norm_chroma = rearrange(norm_chroma, "b d t -> b t d")
+
+ if self.argmax:
+ idx = norm_chroma.argmax(-1, keepdims=True)
+ norm_chroma[:] = 0
+ norm_chroma.scatter_(dim=-1, index=idx, value=1)
+
+ return norm_chroma
+
+
+def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str):
+ """Utility function for nullifying an attribute inside an ConditioningAttributes object.
+ If the condition is of type "wav", then nullify it using "nullify_condition".
+ If the condition is of any other type, set its value to None.
+ Works in-place.
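+
+ Example (illustrative sketch; attribute names and ``wav_cond`` are arbitrary)::
+
+ sample = ConditioningAttributes(text={"genre": "Rock"}, wav={"self_wav": wav_cond})
+ dropout_condition(sample, "text", "genre")     # sample.text["genre"] is now None
+ dropout_condition(sample, "wav", "self_wav")   # sample.wav["self_wav"] is now a nullified wav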
+ """
+ if condition_type not in ["text", "wav"]:
+ raise ValueError(
+ "dropout_condition got an unexpected condition type!"
+ f" expected 'wav' or 'text' but got '{condition_type}'"
+ )
+
+ if condition not in getattr(sample, condition_type):
+ raise ValueError(
+ "dropout_condition received an unexpected condition!"
+ f" expected wav={sample.wav.keys()} and text={sample.text.keys()}"
+ f"but got '{condition}' of type '{condition_type}'!"
+ )
+
+ if condition_type == "wav":
+ wav, length, path = sample.wav[condition]
+ sample.wav[condition] = nullify_wav(wav)
+ else:
+ sample.text[condition] = None
+
+ return sample
+
+
+class DropoutModule(nn.Module):
+ """Base class for all dropout modules."""
+ def __init__(self, seed: int = 1234):
+ super().__init__()
+ self.rng = torch.Generator()
+ self.rng.manual_seed(seed)
+
+
+class AttributeDropout(DropoutModule):
+ """Applies dropout with a given probability per attribute. This is different from the behavior of
+ ClassifierFreeGuidanceDropout as this allows for attributes to be dropped out separately. For example,
+ "artist" can be dropped while "genre" remains. This is in contrast to ClassifierFreeGuidanceDropout
+ where if "artist" is dropped "genre" must also be dropped.
+
+ Args:
+ p (tp.Dict[str, float]): A dict mapping between attributes and dropout probability. For example:
+ ...
+ "genre": 0.1,
+ "artist": 0.5,
+ "wav": 0.25,
+ ...
+ active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False.
+ seed (int, optional): Random seed.
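+
+ Example (illustrative sketch; probabilities are arbitrary)::
+
+ dropout = AttributeDropout(p={"text": {"genre": 0.5}, "wav": {"self_wav": 0.25}})
+ samples = dropout(samples)  # only active in training mode unless active_on_eval=True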
+ """
+ def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234):
+ super().__init__(seed=seed)
+ self.active_on_eval = active_on_eval
+ # construct dicts that return the values from p, otherwise 0
+ self.p = {}
+ for condition_type, probs in p.items():
+ self.p[condition_type] = defaultdict(lambda: 0, probs)
+
+ def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
+ """
+ Args:
+ samples (tp.List[ConditioningAttributes]): List of conditions.
+ Returns:
+ tp.List[ConditioningAttributes]: List of conditions after certain attributes were set to None.
+ """
+ if not self.training and not self.active_on_eval:
+ return samples
+
+ samples = deepcopy(samples)
+
+ for condition_type, ps in self.p.items(): # for condition types [text, wav]
+ for condition, p in ps.items(): # for attributes of each type (e.g., [artist, genre])
+ if torch.rand(1, generator=self.rng).item() < p:
+ for sample in samples:
+ dropout_condition(sample, condition_type, condition)
+
+ return samples
+
+ def __repr__(self):
+ return f"AttributeDropout({dict(self.p)})"
+
+
+class ClassifierFreeGuidanceDropout(DropoutModule):
+ """Applies Classifier Free Guidance dropout, meaning all attributes
+ are dropped with the same probability.
+
+ Args:
+ p (float): Probability to apply condition dropout during training.
+ seed (int): Random seed.
+ """
+ def __init__(self, p: float, seed: int = 1234):
+ super().__init__(seed=seed)
+ self.p = p
+
+ def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
+ """
+ Args:
+ samples (tp.List[ConditioningAttributes]): List of conditions.
+ Returns:
+ tp.List[ConditioningAttributes]: List of conditions after all attributes were set to None.
+ """
+ if not self.training:
+ return samples
+
+ # decide on which attributes to drop in a batched fashion
+ drop = torch.rand(1, generator=self.rng).item() < self.p
+ if not drop:
+ return samples
+
+ # nullify conditions of all attributes
+ samples = deepcopy(samples)
+
+ for condition_type in ["wav", "text"]:
+ for sample in samples:
+ for condition in sample.attributes[condition_type]:
+ dropout_condition(sample, condition_type, condition)
+
+ return samples
+
+ def __repr__(self):
+ return f"ClassifierFreeGuidanceDropout(p={self.p})"
+
+
+class ConditioningProvider(nn.Module):
+ """Main class to provide conditions given all the supported conditioners.
+
+ Args:
+ conditioners (dict): Dictionary of conditioners.
+ merge_text_conditions_p (float, optional): Probability to merge all text sources
+ into a single text condition. Defaults to 0.
+ drop_desc_p (float, optional): Probability to drop the original description
+ when merging all text sources into a single text condition. Defaults to 0.
+ device (tp.Union[torch.device, str], optional): Device for conditioners and output condition types.
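+
+ Example (illustrative sketch; conditioner names and parameters are arbitrary)::
+
+ provider = ConditioningProvider({
+ "genre": LUTConditioner(n_bins=1024, dim=64, output_dim=256, tokenizer="noop"),
+ })
+ tokenized = provider.tokenize(attributes)  # attributes: list of ConditioningAttributes with text={"genre": ...}
+ conditions = provider(tokenized)           # {"genre": (embedding [B, 1, 256], mask [B, 1])}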
+ """
+ def __init__(
+ self,
+ conditioners: tp.Dict[str, BaseConditioner],
+ merge_text_conditions_p: float = 0,
+ drop_desc_p: float = 0,
+ device: tp.Union[torch.device, str] = "cpu",
+ ):
+ super().__init__()
+ self.device = device
+ self.merge_text_conditions_p = merge_text_conditions_p
+ self.drop_desc_p = drop_desc_p
+ self.conditioners = nn.ModuleDict(conditioners)
+
+ @property
+ def text_conditions(self):
+ return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]
+
+ @property
+ def wav_conditions(self):
+ return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)]
+
+ @property
+ def has_wav_condition(self):
+ return len(self.wav_conditions) > 0
+
+ def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
+ """Match attributes/wavs with existing conditioners in self, and compute tokenize them accordingly.
+ This should be called before starting any real GPU work to avoid synchronization points.
+ This will return a dict matching conditioner names to their arbitrary tokenized representations.
+
+ Args:
+ inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
+ text and wav conditions.
+ """
+ assert all([type(x) == ConditioningAttributes for x in inputs]), \
+ "got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]" \
+ f" but types were {set([type(x) for x in inputs])}"
+
+ output = {}
+ text = self._collate_text(inputs)
+ wavs = self._collate_wavs(inputs)
+
+ assert set(text.keys() | wavs.keys()).issubset(set(self.conditioners.keys())), \
+ f"got an unexpected attribute! Expected {self.conditioners.keys()}, got {text.keys(), wavs.keys()}"
+
+ for attribute, batch in chain(text.items(), wavs.items()):
+ output[attribute] = self.conditioners[attribute].tokenize(batch)
+ return output
+
+ def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
+ """Compute pairs of `(embedding, mask)` using the configured conditioners
+ and the tokenized representations. The output is for example:
+
+ {
+ "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
+ "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
+ ...
+ }
+
+ Args:
+ tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
+ """
+ output = {}
+ for attribute, inputs in tokenized.items():
+ condition, mask = self.conditioners[attribute](inputs)
+ output[attribute] = (condition, mask)
+ return output
+
+ def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
+ """Given a list of ConditioningAttributes objects, compile a dictionary where the keys
+ are the attributes and the values are the aggregated input per attribute.
+ For example:
+ Input:
+ [
+ ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
+ ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...),
+ ]
+ Output:
+ {
+ "genre": ["Rock", "Hip-hop"],
+ "description": ["A rock song with a guitar solo", "A hip-hop verse"]
+ }
+ """
+ batch_per_attribute: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)
+
+ def _merge_conds(cond, merge_text_conditions_p=0, drop_desc_p=0):
+ def is_valid(k, v):
+ k_valid = k in ['key', 'bpm', 'genre', 'moods', 'instrument']
+ v_valid = v is not None and isinstance(v, (int, float, str, list))
+ return k_valid and v_valid
+
+ def process_value(v):
+ if isinstance(v, (int, float, str)):
+ return v
+ if isinstance(v, list):
+ return ", ".join(v)
+ else:
+ RuntimeError(f"unknown type for text value! ({type(v), v})")
+
+ desc = cond.text['description']
+ meta_data = ""
+ if random.uniform(0, 1) < merge_text_conditions_p:
+ meta_pairs = [f'{k}: {process_value(v)}' for k, v in cond.text.items() if is_valid(k, v)]
+ random.shuffle(meta_pairs)
+ meta_data = ". ".join(meta_pairs)
+ desc = desc if not random.uniform(0, 1) < drop_desc_p else None
+
+ if desc is None:
+ desc = meta_data if len(meta_data) > 1 else None
+ else:
+ desc = desc.rstrip('.') + ". " + meta_data
+ cond.text['description'] = desc.strip() if desc else None
+
+ if self.training and self.merge_text_conditions_p:
+ for sample in samples:
+ _merge_conds(sample, self.merge_text_conditions_p, self.drop_desc_p)
+
+ texts = [x.text for x in samples]
+ for text in texts:
+ for condition in self.text_conditions:
+ batch_per_attribute[condition].append(text[condition])
+
+ return batch_per_attribute
+
+ def _collate_wavs(self, samples: tp.List[ConditioningAttributes]):
+ """Generate a dict where the keys are attributes by which we fetch similar wavs,
+ and the values are Tensors of wavs according to said attributes.
+
+ *Note*: by the time the samples reach this function, each sample should have some waveform
+ inside the "wav" attribute. It should be either:
+ 1. A real waveform
+ 2. A null waveform due to the sample having no similar waveforms (nullified by the dataset)
+ 3. A null waveform due to it being dropped in a dropout module (nullified by dropout)
+
+ Args:
+ samples (tp.List[ConditioningAttributes]): List of ConditioningAttributes samples.
+ Returns:
+ dict: A dictionary mapping an attribute name to wavs.
+ """
+ wavs = defaultdict(list)
+ lens = defaultdict(list)
+ paths = defaultdict(list)
+ out = {}
+
+ for sample in samples:
+ for attribute in self.wav_conditions:
+ wav, length, path = sample.wav[attribute]
+ wavs[attribute].append(wav.flatten())
+ lens[attribute].append(length)
+ paths[attribute].append(path)
+
+ # stack all wavs to a single tensor
+ for attribute in self.wav_conditions:
+ stacked_wav, _ = collate(wavs[attribute], dim=0)
+ out[attribute] = WavCondition(stacked_wav.unsqueeze(1),
+ torch.cat(lens[attribute]), paths[attribute]) # type: ignore
+
+ return out
+
+
+class ConditionFuser(StreamingModule):
+ """Condition fuser handles the logic to combine the different conditions
+ into the actual model input.
+
+ Args:
+ fuse2cond (tp.Dict[str, str]): A dictionary that says how to fuse
+ each condition. For example:
+ {
+ "prepend": ["description"],
+ "sum": ["genre", "bpm"],
+ "cross": ["description"],
+ }
+ cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
+ cross_attention_pos_emb_scale (float): Scale for positional embeddings in cross attention if used.
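+
+ Example (illustrative sketch of the forward call; ``transformer_input`` is arbitrary)::
+
+ fuser = ConditionFuser(fuse2cond={"prepend": ["genre"], "cross": ["description"]})
+ x, cross_inputs = fuser(transformer_input, conditions)  # conditions as returned by ConditioningProvider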
+ """
+ FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"]
+
+ def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
+ cross_attention_pos_emb_scale: float = 1.0):
+ super().__init__()
+ assert all(
+ [k in self.FUSING_METHODS for k in fuse2cond.keys()]
+ ), f"got invalid fuse method, allowed methods: {self.FUSING_MEHTODS}"
+ self.cross_attention_pos_emb = cross_attention_pos_emb
+ self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
+ self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
+ self.cond2fuse: tp.Dict[str, str] = {}
+ for fuse_method, conditions in fuse2cond.items():
+ for condition in conditions:
+ self.cond2fuse[condition] = fuse_method
+
+ def forward(
+ self,
+ input: Tensor,
+ conditions: tp.Dict[str, ConditionType]
+ ) -> tp.Tuple[Tensor, tp.Optional[Tensor]]:
+ """Fuse the conditions to the provided model input.
+
+ Args:
+ input (Tensor): Transformer input.
+ conditions (tp.Dict[str, ConditionType]): Dict of conditions.
+ Returns:
+ tp.Tuple[Tensor, Tensor]: The first tensor is the transformer input
+ after the conditions have been fused. The second output tensor is the tensor
+ used for cross-attention or None if no cross attention inputs exist.
+ """
+ B, T, _ = input.shape
+
+ if 'offsets' in self._streaming_state:
+ first_step = False
+ offsets = self._streaming_state['offsets']
+ else:
+ first_step = True
+ offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)
+
+ assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
+ f"given conditions contain unknown attributes for fuser, " \
+ f"expected {self.cond2fuse.keys()}, got {conditions.keys()}"
+ cross_attention_output = None
+ for cond_type, (cond, cond_mask) in conditions.items():
+ op = self.cond2fuse[cond_type]
+ if op == "sum":
+ input += cond
+ elif op == "input_interpolate":
+ cond = rearrange(cond, "b t d -> b d t")
+ cond = F.interpolate(cond, size=input.shape[1])
+ input += rearrange(cond, "b d t -> b t d")
+ elif op == "prepend":
+ if first_step:
+ input = torch.cat([cond, input], dim=1)
+ elif op == "cross":
+ if cross_attention_output is not None:
+ cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
+ else:
+ cross_attention_output = cond
+ else:
+ raise ValueError(f"unknown op ({op})")
+
+ if self.cross_attention_pos_emb and cross_attention_output is not None:
+ positions = torch.arange(
+ cross_attention_output.shape[1],
+ device=cross_attention_output.device
+ ).view(1, -1, 1)
+ pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1])
+ cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb
+
+ if self._is_streaming:
+ self._streaming_state['offsets'] = offsets + T
+
+ return input, cross_attention_output
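For orientation, here is a minimal sketch of how `ConditionFuser` combines pre-computed condition embeddings with the transformer input. The condition names, shapes, and the `audiocraft.modules.conditioners` import path are assumptions for illustration, not a prescribed configuration.

```python
# Hypothetical usage sketch of ConditionFuser with toy condition embeddings.
import torch
from audiocraft.modules.conditioners import ConditionFuser  # assumed module path

fuser = ConditionFuser(fuse2cond={
    "prepend": ["self_wav"],      # prepended along the time axis on the first step
    "cross": ["description"],     # routed to the cross-attention input
})

B, T, C = 2, 10, 64
transformer_input = torch.randn(B, T, C)
conditions = {
    # Each entry is an (embedding, mask) pair, matching ConditionType.
    "self_wav": (torch.randn(B, 4, C), torch.ones(B, 4)),
    "description": (torch.randn(B, 6, C), torch.ones(B, 6)),
}

fused, cross = fuser(transformer_input, conditions)
print(fused.shape)   # torch.Size([2, 14, 64]) -- 4 prepended time steps
print(cross.shape)   # torch.Size([2, 6, 64])  -- cross-attention source
```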
diff --git a/audiocraft/modules/conv.py b/audiocraft/modules/conv.py
new file mode 100644
index 0000000..972938a
--- /dev/null
+++ b/audiocraft/modules/conv.py
@@ -0,0 +1,245 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import typing as tp
+import warnings
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+from torch.nn.utils import spectral_norm, weight_norm
+
+
+CONV_NORMALIZATIONS = frozenset(['none', 'weight_norm', 'spectral_norm',
+ 'time_group_norm'])
+
+
+def apply_parametrization_norm(module: nn.Module, norm: str = 'none'):
+ assert norm in CONV_NORMALIZATIONS
+ if norm == 'weight_norm':
+ return weight_norm(module)
+ elif norm == 'spectral_norm':
+ return spectral_norm(module)
+ else:
+ # We already checked that norm is in CONV_NORMALIZATIONS, so any other choice
+ # doesn't need reparametrization.
+ return module
+
+
+def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs):
+ """Return the proper normalization module. If causal is True, this will ensure the returned
+ module is causal, or return an error if the normalization doesn't support causal evaluation.
+ """
+ assert norm in CONV_NORMALIZATIONS
+ if norm == 'time_group_norm':
+ if causal:
+ raise ValueError("GroupNorm doesn't support causal evaluation.")
+ assert isinstance(module, nn.modules.conv._ConvNd)
+ return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
+ else:
+ return nn.Identity()
+
+
+def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
+ padding_total: int = 0) -> int:
+ """See `pad_for_conv1d`.
+ """
+ length = x.shape[-1]
+ n_frames = (length - kernel_size + padding_total) / stride + 1
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
+ return ideal_length - length
+
+
+def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
+ """Pad for a convolution to make sure that the last window is full.
+ Extra padding is added at the end. This is required to ensure that we can rebuild
+ an output of the same length, as otherwise, even with padding, some time steps
+ might get removed.
+ For instance, with total padding = 4, kernel size = 4, stride = 2:
+ 0 0 1 2 3 4 5 0 0 # (0s are padding)
+ 1 2 3 # (output frames of a convolution, last 0 is never used)
+ 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
+ 1 2 3 4 # once the padding is removed, we are missing one time step!
+ """
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+ return F.pad(x, (0, extra_padding))
+
+
+def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
+ If this is the case, we insert extra 0 padding to the right before the reflection happen.
+ """
+ length = x.shape[-1]
+ padding_left, padding_right = paddings
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+ if mode == 'reflect':
+ max_pad = max(padding_left, padding_right)
+ extra_pad = 0
+ if length <= max_pad:
+ extra_pad = max_pad - length + 1
+ x = F.pad(x, (0, extra_pad))
+ padded = F.pad(x, paddings, mode, value)
+ end = padded.shape[-1] - extra_pad
+ return padded[..., :end]
+ else:
+ return F.pad(x, paddings, mode, value)
+
+
+def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
+ """Remove padding from x, handling properly zero padding. Only for 1d!
+ """
+ padding_left, padding_right = paddings
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+ assert (padding_left + padding_right) <= x.shape[-1]
+ end = x.shape[-1] - padding_right
+ return x[..., padding_left: end]
+
+
+class NormConv1d(nn.Module):
+ """Wrapper around Conv1d and normalization applied to this conv
+ to provide a uniform interface across normalization approaches.
+ """
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+ super().__init__()
+ self.conv = apply_parametrization_norm(nn.Conv1d(*args, **kwargs), norm)
+ self.norm = get_norm_module(self.conv, causal, norm, **norm_kwargs)
+ self.norm_type = norm
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.norm(x)
+ return x
+
+
+class NormConv2d(nn.Module):
+ """Wrapper around Conv2d and normalization applied to this conv
+ to provide a uniform interface across normalization approaches.
+ """
+ def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+ super().__init__()
+ self.conv = apply_parametrization_norm(nn.Conv2d(*args, **kwargs), norm)
+ self.norm = get_norm_module(self.conv, causal=False, norm=norm, **norm_kwargs)
+ self.norm_type = norm
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.norm(x)
+ return x
+
+
+class NormConvTranspose1d(nn.Module):
+ """Wrapper around ConvTranspose1d and normalization applied to this conv
+ to provide a uniform interface across normalization approaches.
+ """
+ def __init__(self, *args, causal: bool = False, norm: str = 'none',
+ norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+ super().__init__()
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose1d(*args, **kwargs), norm)
+ self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
+ self.norm_type = norm
+
+ def forward(self, x):
+ x = self.convtr(x)
+ x = self.norm(x)
+ return x
+
+
+class NormConvTranspose2d(nn.Module):
+ """Wrapper around ConvTranspose2d and normalization applied to this conv
+ to provide a uniform interface across normalization approaches.
+ """
+ def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
+ super().__init__()
+ self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
+ self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
+
+ def forward(self, x):
+ x = self.convtr(x)
+ x = self.norm(x)
+ return x
+
+
+class StreamableConv1d(nn.Module):
+ """Conv1d with some builtin handling of asymmetric or causal padding
+ and normalization.
+ """
+ def __init__(self, in_channels: int, out_channels: int,
+ kernel_size: int, stride: int = 1, dilation: int = 1,
+ groups: int = 1, bias: bool = True, causal: bool = False,
+ norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
+ pad_mode: str = 'reflect'):
+ super().__init__()
+ # warn user on unusual setup between dilation and stride
+ if stride > 1 and dilation > 1:
+ warnings.warn('StreamableConv1d has been initialized with stride > 1 and dilation > 1'
+ f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
+ self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
+ dilation=dilation, groups=groups, bias=bias, causal=causal,
+ norm=norm, norm_kwargs=norm_kwargs)
+ self.causal = causal
+ self.pad_mode = pad_mode
+
+ def forward(self, x):
+ B, C, T = x.shape
+ kernel_size = self.conv.conv.kernel_size[0]
+ stride = self.conv.conv.stride[0]
+ dilation = self.conv.conv.dilation[0]
+ kernel_size = (kernel_size - 1) * dilation + 1 # effective kernel size with dilations
+ padding_total = kernel_size - stride
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+ if self.causal:
+ # Left padding for causal
+ x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
+ else:
+ # Asymmetric padding required for odd strides
+ padding_right = padding_total // 2
+ padding_left = padding_total - padding_right
+ x = pad1d(x, (padding_left, padding_right + extra_padding), mode=self.pad_mode)
+ return self.conv(x)
+
+
+class StreamableConvTranspose1d(nn.Module):
+ """ConvTranspose1d with some builtin handling of asymmetric or causal padding
+ and normalization.
+ """
+ def __init__(self, in_channels: int, out_channels: int,
+ kernel_size: int, stride: int = 1, causal: bool = False,
+ norm: str = 'none', trim_right_ratio: float = 1.,
+ norm_kwargs: tp.Dict[str, tp.Any] = {}):
+ super().__init__()
+ self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
+ causal=causal, norm=norm, norm_kwargs=norm_kwargs)
+ self.causal = causal
+ self.trim_right_ratio = trim_right_ratio
+ assert self.causal or self.trim_right_ratio == 1., \
+ "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
+ assert self.trim_right_ratio >= 0. and self.trim_right_ratio <= 1.
+
+ def forward(self, x):
+ kernel_size = self.convtr.convtr.kernel_size[0]
+ stride = self.convtr.convtr.stride[0]
+ padding_total = kernel_size - stride
+
+ y = self.convtr(x)
+
+ # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
+ # removed at the very end, when keeping only the right length for the output,
+ # as removing it here would require also passing the length at the matching layer
+ # in the encoder.
+ if self.causal:
+ # Trim the padding on the right according to the specified ratio
+ # if trim_right_ratio = 1.0, trim everything from right
+ padding_right = math.ceil(padding_total * self.trim_right_ratio)
+ padding_left = padding_total - padding_right
+ y = unpad1d(y, (padding_left, padding_right))
+ else:
+ # Asymmetric padding required for odd strides
+ padding_right = padding_total // 2
+ padding_left = padding_total - padding_right
+ y = unpad1d(y, (padding_left, padding_right))
+ return y
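The padding helpers above are easiest to see in action through the two streamable wrappers. Below is a small sketch (all values chosen arbitrarily) checking that a causal stride-4 `StreamableConv1d` and a matching `StreamableConvTranspose1d` divide and restore the time dimension exactly.

```python
# Minimal sketch: causal downsample/upsample pair preserving the time length.
import torch
from audiocraft.modules.conv import StreamableConv1d, StreamableConvTranspose1d

down = StreamableConv1d(1, 8, kernel_size=8, stride=4, causal=True)
up = StreamableConvTranspose1d(8, 1, kernel_size=8, stride=4, causal=True)

x = torch.randn(2, 1, 1000)       # [batch, channels, time]
z = down(x)
y = up(z)
print(z.shape, y.shape)           # torch.Size([2, 8, 250]) torch.Size([2, 1, 1000])
```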
diff --git a/audiocraft/modules/lstm.py b/audiocraft/modules/lstm.py
new file mode 100644
index 0000000..c086617
--- /dev/null
+++ b/audiocraft/modules/lstm.py
@@ -0,0 +1,25 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from torch import nn
+
+
+class StreamableLSTM(nn.Module):
+ """LSTM without worrying about the hidden state, nor the layout of the data.
+ Expects input as convolutional layout.
+ """
+ def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
+ super().__init__()
+ self.skip = skip
+ self.lstm = nn.LSTM(dimension, dimension, num_layers)
+
+ def forward(self, x):
+ x = x.permute(2, 0, 1)
+ y, _ = self.lstm(x)
+ if self.skip:
+ y = y + x
+ y = y.permute(1, 2, 0)
+ return y
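A quick sketch of the layout contract: `StreamableLSTM` consumes and returns the convolutional `[batch, channels, time]` layout used by the surrounding SEANet blocks (sizes below are arbitrary).

```python
# Minimal sketch: StreamableLSTM keeps the [B, C, T] convolutional layout.
import torch
from audiocraft.modules.lstm import StreamableLSTM

lstm = StreamableLSTM(dimension=32, num_layers=2, skip=True)
x = torch.randn(4, 32, 100)    # [B, C, T]
y = lstm(x)
print(y.shape)                 # torch.Size([4, 32, 100])
```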
diff --git a/audiocraft/modules/rope.py b/audiocraft/modules/rope.py
new file mode 100644
index 0000000..4b8c70b
--- /dev/null
+++ b/audiocraft/modules/rope.py
@@ -0,0 +1,124 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+from torch import nn
+import torch
+
+
+class XPos(nn.Module):
+ """Length-extrapolatable positional embedding (xPos) from [Sun et al 2022](https://arxiv.org/abs/2212.10554v1).
+ This applies an exponential decay to the RoPE rotation matrix.
+
+ Args:
+ dim (int): Embedding dimension.
+ smoothing (float): Smoothing factor applied to the decay rates.
+ base_scale (int): Base decay rate, given in terms of scaling time.
+ device (torch.device or None): Device on which to initialize the module.
+ dtype (torch.dtype): dtype to use to generate the embedding.
+ """
+ def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int = 512,
+ device=None, dtype: torch.dtype = torch.float32):
+ super().__init__()
+ assert dim % 2 == 0
+ assert dtype in [torch.float64, torch.float32]
+ self.dtype = dtype
+ self.base_scale = base_scale
+
+ half_dim = dim // 2
+ adim = torch.arange(half_dim, device=device, dtype=dtype)
+ decay_rates = (adim / half_dim + smoothing) / (1.0 + smoothing)
+ self.register_buffer("decay_rates", decay_rates)
+ self.decay: tp.Optional[torch.Tensor] = None
+
+ def get_decay(self, start: int, end: int):
+ """Create complex decay tensor, cache values for fast computation.
+ """
+ if self.decay is None or end > self.decay.shape[0]:
+ assert isinstance(self.decay_rates, torch.Tensor) # Satisfy type checker.
+ idx = torch.arange(end, device=self.decay_rates.device, dtype=self.dtype)
+ power = idx / self.base_scale
+ scale = self.decay_rates ** power.unsqueeze(-1)
+ self.decay = torch.polar(scale, torch.zeros_like(scale))
+ return self.decay[start:end] # [T, C/2]
+
+
+class RotaryEmbedding(nn.Module):
+ """Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864).
+
+ Args:
+ dim (int): Embedding dimension (twice the number of frequencies).
+ max_period (float): Maximum period of the rotation frequencies.
+ xpos (bool): Use xPos, applies an exponential decay to rotation matrix.
+ scale (float): Scale of positional embedding, set to 0 to deactivate.
+ device (torch.device or None): Device on which to initialize the module.
+ dtype (torch.dtype): dtype to use to generate the embedding.
+ """
+ def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool = False,
+ scale: float = 1.0, device=None, dtype: torch.dtype = torch.float32):
+ super().__init__()
+ assert dim % 2 == 0
+ self.scale = scale
+ assert dtype in [torch.float64, torch.float32]
+ self.dtype = dtype
+
+ adim = torch.arange(0, dim, 2, device=device, dtype=dtype)[: (dim // 2)]
+ frequencies = 1.0 / (max_period ** (adim / dim))
+ self.register_buffer("frequencies", frequencies)
+ self.rotation: tp.Optional[torch.Tensor] = None
+
+ self.xpos = XPos(dim, device=device, dtype=dtype) if xpos else None
+
+ def get_rotation(self, start: int, end: int):
+ """Create complex rotation tensor, cache values for fast computation.
+ """
+ if self.rotation is None or end > self.rotation.shape[0]:
+ assert isinstance(self.frequencies, torch.Tensor) # Satisfy type checker.
+ idx = torch.arange(end, device=self.frequencies.device, dtype=self.dtype)
+ angles = torch.outer(idx, self.frequencies)
+ self.rotation = torch.polar(torch.ones_like(angles), angles)
+ return self.rotation[start:end]
+
+ def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool = False):
+ """Apply rope rotation to query or key tensor.
+ """
+ T = x.shape[1]
+ rotation = self.get_rotation(start, start + T).unsqueeze(0).unsqueeze(2)
+
+ if self.xpos:
+ decay = self.xpos.get_decay(start, start + T).unsqueeze(0).unsqueeze(2)
+ else:
+ decay = 1.0
+
+ if invert_decay:
+ decay = decay ** -1
+
+ x_complex = torch.view_as_complex(x.to(self.dtype).reshape(*x.shape[:-1], -1, 2))
+ scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale)
+ x_out = torch.view_as_real(x_complex * scaled_rotation).flatten(-2)
+
+ return x_out.type_as(x)
+
+ def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0):
+ """ Apply rope rotation to both query and key tensors.
+ Supports streaming mode, in which query and key are not expected to have the same shape.
+ In streaming mode, key will be of length [P + C] with P the cached past timesteps, but
+ query will be [C] (typically C == 1).
+
+ Args:
+ query (torch.Tensor): Query to rotate.
+ key (torch.Tensor): Key to rotate.
+ start (int): Start index of the sequence for time offset.
+ """
+ query_timesteps = query.shape[1]
+ key_timesteps = key.shape[1]
+ streaming_offset = key_timesteps - query_timesteps
+
+ query_out = self.rotate(query, start + streaming_offset)
+ key_out = self.rotate(key, start, invert_decay=True)
+
+ return query_out, key_out
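The sketch below applies `RotaryEmbedding.rotate_qk` to toy query/key tensors. The `[batch, time, heads, head_dim]` layout is assumed from how `rotate` indexes the time dimension; all sizes are illustrative.

```python
# Minimal sketch: applying RoPE to query/key tensors.
import torch
from audiocraft.modules.rope import RotaryEmbedding

rope = RotaryEmbedding(dim=64)            # dim == head_dim, must be even
q = torch.randn(2, 10, 8, 64)             # [B, T, H, D]
k = torch.randn(2, 10, 8, 64)
q_rot, k_rot = rope.rotate_qk(q, k, start=0)
print(q_rot.shape, k_rot.shape)           # both torch.Size([2, 10, 8, 64])
```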
diff --git a/audiocraft/modules/seanet.py b/audiocraft/modules/seanet.py
new file mode 100644
index 0000000..3e5998e
--- /dev/null
+++ b/audiocraft/modules/seanet.py
@@ -0,0 +1,258 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+import numpy as np
+import torch.nn as nn
+
+from .conv import StreamableConv1d, StreamableConvTranspose1d
+from .lstm import StreamableLSTM
+
+
+class SEANetResnetBlock(nn.Module):
+ """Residual block from SEANet model.
+
+ Args:
+ dim (int): Dimension of the input/output.
+ kernel_sizes (list): List of kernel sizes for the convolutions.
+ dilations (list): List of dilations for the convolutions.
+ activation (str): Activation function.
+ activation_params (dict): Parameters to provide to the activation function.
+ norm (str): Normalization method.
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+ causal (bool): Whether to use fully causal convolution.
+ pad_mode (str): Padding mode for the convolutions.
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+ true_skip (bool): Whether to use true skip connection or a simple
+ (streamable) convolution as the skip connection.
+ """
+ def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
+ activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
+ norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
+ pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
+ super().__init__()
+ assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
+ act = getattr(nn, activation)
+ hidden = dim // compress
+ block = []
+ for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
+ in_chs = dim if i == 0 else hidden
+ out_chs = dim if i == len(kernel_sizes) - 1 else hidden
+ block += [
+ act(**activation_params),
+ StreamableConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
+ norm=norm, norm_kwargs=norm_params,
+ causal=causal, pad_mode=pad_mode),
+ ]
+ self.block = nn.Sequential(*block)
+ self.shortcut: nn.Module
+ if true_skip:
+ self.shortcut = nn.Identity()
+ else:
+ self.shortcut = StreamableConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
+ causal=causal, pad_mode=pad_mode)
+
+ def forward(self, x):
+ return self.shortcut(x) + self.block(x)
+
+
+class SEANetEncoder(nn.Module):
+ """SEANet encoder.
+
+ Args:
+ channels (int): Audio channels.
+ dimension (int): Intermediate representation dimension.
+ n_filters (int): Base width for the model.
+ n_residual_layers (int): Number of residual layers.
+ ratios (Sequence[int]): Kernel size and stride ratios. The encoder uses downsampling ratios instead of
+ upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here,
+ which must match the decoder order. We use the decoder order as some models may only employ the decoder.
+ activation (str): Activation function.
+ activation_params (dict): Parameters to provide to the activation function.
+ norm (str): Normalization method.
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+ kernel_size (int): Kernel size for the initial convolution.
+ last_kernel_size (int): Kernel size for the final convolution.
+ residual_kernel_size (int): Kernel size for the residual layers.
+ dilation_base (int): How much to increase the dilation with each layer.
+ causal (bool): Whether to use fully causal convolution.
+ pad_mode (str): Padding mode for the convolutions.
+ true_skip (bool): Whether to use true skip connection or a simple
+ (streamable) convolution as the skip connection in the residual network blocks.
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+ lstm (int): Number of LSTM layers at the end of the encoder.
+ disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
+ For the encoder, it corresponds to the N first blocks.
+ """
+ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
+ ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
+ norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
+ last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
+ pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
+ disable_norm_outer_blocks: int = 0):
+ super().__init__()
+ self.channels = channels
+ self.dimension = dimension
+ self.n_filters = n_filters
+ self.ratios = list(reversed(ratios))
+ del ratios
+ self.n_residual_layers = n_residual_layers
+ self.hop_length = np.prod(self.ratios)
+ self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks
+ self.disable_norm_outer_blocks = disable_norm_outer_blocks
+ assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
+ "Number of blocks for which to disable norm is invalid." \
+ "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
+
+ act = getattr(nn, activation)
+ mult = 1
+ model: tp.List[nn.Module] = [
+ StreamableConv1d(channels, mult * n_filters, kernel_size,
+ norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
+ norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
+ ]
+ # Downsample to raw audio scale
+ for i, ratio in enumerate(self.ratios):
+ block_norm = 'none' if self.disable_norm_outer_blocks >= i + 2 else norm
+ # Add residual layers
+ for j in range(n_residual_layers):
+ model += [
+ SEANetResnetBlock(mult * n_filters, kernel_sizes=[residual_kernel_size, 1],
+ dilations=[dilation_base ** j, 1],
+ norm=block_norm, norm_params=norm_params,
+ activation=activation, activation_params=activation_params,
+ causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
+
+ # Add downsampling layers
+ model += [
+ act(**activation_params),
+ StreamableConv1d(mult * n_filters, mult * n_filters * 2,
+ kernel_size=ratio * 2, stride=ratio,
+ norm=block_norm, norm_kwargs=norm_params,
+ causal=causal, pad_mode=pad_mode),
+ ]
+ mult *= 2
+
+ if lstm:
+ model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
+
+ model += [
+ act(**activation_params),
+ StreamableConv1d(mult * n_filters, dimension, last_kernel_size,
+ norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
+ norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
+ ]
+
+ self.model = nn.Sequential(*model)
+
+ def forward(self, x):
+ return self.model(x)
+
+
+class SEANetDecoder(nn.Module):
+ """SEANet decoder.
+
+ Args:
+ channels (int): Audio channels.
+ dimension (int): Intermediate representation dimension.
+ n_filters (int): Base width for the model.
+ n_residual_layers (int): Number of residual layers.
+ ratios (Sequence[int]): kernel size and stride ratios.
+ activation (str): Activation function.
+ activation_params (dict): Parameters to provide to the activation function.
+ final_activation (str): Final activation function after all convolutions.
+ final_activation_params (dict): Parameters to provide to the activation function.
+ norm (str): Normalization method.
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+ kernel_size (int): Kernel size for the initial convolution.
+ last_kernel_size (int): Kernel size for the final convolution.
+ residual_kernel_size (int): Kernel size for the residual layers.
+ dilation_base (int): How much to increase the dilation with each layer.
+ causal (bool): Whether to use fully causal convolution.
+ pad_mode (str): Padding mode for the convolutions.
+ true_skip (bool): Whether to use true skip connection or a simple
+ (streamable) convolution as the skip connection in the residual network blocks.
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+ lstm (int): Number of LSTM layers after the initial convolution of the decoder.
+ disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
+ For the decoder, it corresponds to the N last blocks.
+ trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
+ If equal to 1.0, it means that all the trimming is done at the right.
+ """
+ def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
+ ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
+ final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
+ norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
+ last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
+ pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
+ disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0):
+ super().__init__()
+ self.dimension = dimension
+ self.channels = channels
+ self.n_filters = n_filters
+ self.ratios = ratios
+ del ratios
+ self.n_residual_layers = n_residual_layers
+ self.hop_length = np.prod(self.ratios)
+ self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks
+ self.disable_norm_outer_blocks = disable_norm_outer_blocks
+ assert self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks, \
+ "Number of blocks for which to disable norm is invalid." \
+ "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
+
+ act = getattr(nn, activation)
+ mult = int(2 ** len(self.ratios))
+ model: tp.List[nn.Module] = [
+ StreamableConv1d(dimension, mult * n_filters, kernel_size,
+ norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
+ norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
+ ]
+
+ if lstm:
+ model += [StreamableLSTM(mult * n_filters, num_layers=lstm)]
+
+ # Upsample to raw audio scale
+ for i, ratio in enumerate(self.ratios):
+ block_norm = 'none' if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) else norm
+ # Add upsampling layers
+ model += [
+ act(**activation_params),
+ StreamableConvTranspose1d(mult * n_filters, mult * n_filters // 2,
+ kernel_size=ratio * 2, stride=ratio,
+ norm=block_norm, norm_kwargs=norm_params,
+ causal=causal, trim_right_ratio=trim_right_ratio),
+ ]
+ # Add residual layers
+ for j in range(n_residual_layers):
+ model += [
+ SEANetResnetBlock(mult * n_filters // 2, kernel_sizes=[residual_kernel_size, 1],
+ dilations=[dilation_base ** j, 1],
+ activation=activation, activation_params=activation_params,
+ norm=block_norm, norm_params=norm_params, causal=causal,
+ pad_mode=pad_mode, compress=compress, true_skip=true_skip)]
+
+ mult //= 2
+
+ # Add final layers
+ model += [
+ act(**activation_params),
+ StreamableConv1d(n_filters, channels, last_kernel_size,
+ norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
+ norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
+ ]
+ # Add optional final activation to decoder (eg. tanh)
+ if final_activation is not None:
+ final_act = getattr(nn, final_activation)
+ final_activation_params = final_activation_params or {}
+ model += [
+ final_act(**final_activation_params)
+ ]
+ self.model = nn.Sequential(*model)
+
+ def forward(self, z):
+ y = self.model(z)
+ return y
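As a rough usage sketch, the encoder/decoder pair below round-trips one second of mono audio. The sample rate and hyperparameters are assumptions; the key point is that `hop_length` is the product of the ratios, so 16000 samples map to 16000 / 320 = 50 latent frames here.

```python
# Minimal sketch: SEANetEncoder/SEANetDecoder round trip (illustrative values).
import torch
from audiocraft.modules.seanet import SEANetEncoder, SEANetDecoder

ratios = [8, 5, 4, 2]                              # hop_length = 320
encoder = SEANetEncoder(channels=1, dimension=128, n_filters=32, ratios=ratios)
decoder = SEANetDecoder(channels=1, dimension=128, n_filters=32, ratios=ratios)

wav = torch.randn(1, 1, 16000)                     # [B, C, T], ~1 s at 16 kHz
latent = encoder(wav)
recon = decoder(latent)
print(latent.shape, recon.shape)                   # [1, 128, 50] and [1, 1, 16000]
```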
diff --git a/audiocraft/modules/streaming.py b/audiocraft/modules/streaming.py
new file mode 100644
index 0000000..fdbdf5e
--- /dev/null
+++ b/audiocraft/modules/streaming.py
@@ -0,0 +1,135 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Streaming module API that should be implemented by all Streaming components.
+"""
+
+from contextlib import contextmanager
+import typing as tp
+from torch import nn
+import torch
+
+
+State = tp.Dict[str, torch.Tensor]
+
+
+class StreamingModule(nn.Module):
+ """Common API for streaming components.
+
+ Each streaming component has a streaming state, which is just a dict[str, Tensor].
+ By convention, the first dim of each tensor must be the batch size.
+ Don't use dots in the key names, as this would clash with submodules
+ (like in state_dict).
+
+ If `self._is_streaming` is True, the component should use and remember
+ the proper state inside `self._streaming_state`.
+
+ To set a streaming component in streaming state, use
+
+ with module.streaming():
+ ...
+
+ This will automatically reset the streaming state when exiting the context manager.
+ This also automatically propagates to all streaming child modules.
+
+ Some modules might also implement the `StreamingModule.flush` method, although
+ this one is trickier, as all parent modules must be `StreamingModule` and implement
+ it as well for it to work properly. See `StreamingSequential` below.
+ """
+ def __init__(self) -> None:
+ super().__init__()
+ self._streaming_state: State = {}
+ self._is_streaming = False
+
+ def _apply_named_streaming(self, fn: tp.Any):
+ for name, module in self.named_modules():
+ if isinstance(module, StreamingModule):
+ fn(name, module)
+
+ def _set_streaming(self, streaming: bool):
+ def _set_streaming(name, module):
+ module._is_streaming = streaming
+ self._apply_named_streaming(_set_streaming)
+
+ @contextmanager
+ def streaming(self):
+ """Context manager to enter streaming mode. Reset streaming state on exit.
+ """
+ self._set_streaming(True)
+ try:
+ yield
+ finally:
+ self._set_streaming(False)
+ self.reset_streaming()
+
+ def reset_streaming(self):
+ """Reset the streaming state.
+ """
+ def _reset(name: str, module: StreamingModule):
+ module._streaming_state.clear()
+
+ self._apply_named_streaming(_reset)
+
+ def get_streaming_state(self) -> State:
+ """Return the streaming state, including that of sub-modules.
+ """
+ state: State = {}
+
+ def _add(name: str, module: StreamingModule):
+ if name:
+ name += "."
+ for key, value in module._streaming_state.items():
+ state[name + key] = value
+
+ self._apply_named_streaming(_add)
+ return state
+
+ def set_streaming_state(self, state: State):
+ """Set the streaming state, including that of sub-modules.
+ """
+ state = dict(state)
+
+ def _set(name: str, module: StreamingModule):
+ if name:
+ name += "."
+ module._streaming_state.clear()
+ for key, value in list(state.items()):
+ # complexity is not ideal here, but probably fine.
+ if key.startswith(name):
+ local_key = key[len(name):]
+ if '.' not in local_key:
+ module._streaming_state[local_key] = value
+ del state[key]
+
+ self._apply_named_streaming(_set)
+ assert len(state) == 0, list(state.keys())
+
+ def flush(self, x: tp.Optional[torch.Tensor] = None):
+ """Flush any remaining outputs that were waiting for completion.
+ Typically, for convolutions, this will add the final padding
+ and process the last buffer.
+
+ This should take an optional argument `x`, which will be provided
+ if a module before this one in the streaming pipeline has already
+ spit out a flushed buffer.
+ """
+ if x is None:
+ return None
+ else:
+ return self(x)
+
+
+class StreamingSequential(StreamingModule, nn.Sequential):
+ """A streaming compatible alternative of `nn.Sequential`.
+ """
+ def flush(self, x: tp.Optional[torch.Tensor] = None):
+ for module in self:
+ if isinstance(module, StreamingModule):
+ x = module.flush(x)
+ elif x is not None:
+ x = module(x)
+ return x
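To illustrate the API, here is a toy `StreamingModule` (hypothetical, not part of the codebase) that keeps a running frame count in its streaming state, showing how `with module.streaming():` scopes and then resets that state.

```python
# Minimal sketch: a toy StreamingModule using the streaming state dict.
import torch
from audiocraft.modules.streaming import StreamingModule

class FrameCounter(StreamingModule):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self._is_streaming:
            offset = self._streaming_state.get('offset', torch.zeros(1, dtype=torch.long))
            self._streaming_state['offset'] = offset + x.shape[-1]
        return x

counter = FrameCounter()
with counter.streaming():
    counter(torch.randn(1, 4, 100))
    counter(torch.randn(1, 4, 60))
    print(counter.get_streaming_state())   # {'offset': tensor([160])}
print(counter.get_streaming_state())        # {} -- state reset on exit
```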
diff --git a/audiocraft/modules/transformer.py b/audiocraft/modules/transformer.py
new file mode 100644
index 0000000..e69cca8
--- /dev/null
+++ b/audiocraft/modules/transformer.py
@@ -0,0 +1,747 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Transformer model, with streaming support, xformers attention support
+and easy causal attention with a potentially finite receptive field.
+
+See `StreamingTransformer` for more information.
+
+Unlike regular PyTorch Transformer, we make the hard choice that batches are first.
+"""
+
+import typing as tp
+
+from einops import rearrange
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from torch.utils.checkpoint import checkpoint as torch_checkpoint
+from xformers import ops
+
+from .rope import RotaryEmbedding
+from .streaming import StreamingModule
+
+_efficient_attention_backend: str = 'torch'
+
+
+def set_efficient_attention_backend(backend: str = 'torch'):
+ # Using torch by default, it seems a bit faster on older P100 GPUs (~20% faster).
+ global _efficient_attention_backend
+ assert backend in ['xformers', 'torch']
+ _efficient_attention_backend = backend
+
+
+def _get_attention_time_dimension() -> int:
+ if _efficient_attention_backend == 'torch':
+ return 2
+ else:
+ return 1
+
+
+def _is_profiled() -> bool:
+ # Return True if we are currently running with an xformers profiler activated.
+ try:
+ from xformers.profiler import profiler
+ except ImportError:
+ return False
+ return profiler._Profiler._CURRENT_PROFILER is not None
+
+
+def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
+ """Create normalization module for transformer encoder layer.
+
+ Args:
+ norm_type (str): Normalization method.
+ dim (int): Dimension of the normalized layer.
+ **kwargs (dict): Additional parameters for normalization layer.
+ Returns:
+ nn.Module: Normalization module.
+ """
+ if norm_type == 'layer_norm':
+ return nn.LayerNorm(dim, eps=1e-5, **kwargs)
+ else:
+ raise ValueError(f"Unknown norm type: {norm_type}")
+
+
+def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000,
+ dtype: torch.dtype = torch.float32) -> torch.Tensor:
+ """Create sinusoidal positional embedding, with shape `[B, T, C]`.
+
+ Args:
+ positions (torch.Tensor): LongTensor of positions.
+ dim (int): Dimension of the embedding.
+ max_period (float): Maximum period of the cosine/sine functions.
+ dtype (torch.dtype or str): dtype to use to generate the embedding.
+ Returns:
+ torch.Tensor: Sinusoidal positional embedding.
+ """
+ # We aim for BTC format
+ assert dim % 2 == 0
+ half_dim = dim // 2
+ positions = positions.to(dtype)
+ adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
+ max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype) # avoid sync point
+ phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
+ return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
+
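A quick shape check for `create_sin_embedding` (assuming the module is importable as `audiocraft.modules.transformer`, which also requires xformers since this file imports it): positions are expected as `[B, T, 1]` and the result is `[B, T, dim]`.

```python
# Shape sketch for create_sin_embedding (illustrative only).
import torch
from audiocraft.modules.transformer import create_sin_embedding  # assumed import path

positions = torch.arange(16).view(1, -1, 1)   # [B=1, T=16, 1]
emb = create_sin_embedding(positions, dim=64)
print(emb.shape)                               # torch.Size([1, 16, 64])
```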
+
+def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
+ """torch.repeat_interleave(x, dim=2, repeats=n_rep) from xlformers"""
+ if n_rep == 1:
+ return x
+ if _efficient_attention_backend == 'torch':
+ bs, n_kv_heads, slen, head_dim = x.shape
+ return (
+ x[:, :, None, :, :]
+ .expand(bs, n_kv_heads, n_rep, slen, head_dim)
+ .reshape(bs, n_kv_heads * n_rep, slen, head_dim)
+ )
+ else:
+ bs, slen, n_kv_heads, head_dim = x.shape
+ return (
+ x[:, :, :, None, :]
+ .expand(bs, slen, n_kv_heads, n_rep, head_dim)
+ .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
+ )
+
+
+class LayerScale(nn.Module):
+ """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
+ This diagonally rescales the residual outputs close to 0, with a learnt scale.
+
+ Args:
+ channels (int): Number of channels.
+ init (float): Initial scale.
+ channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`.
+ device (torch.device or None): Device on which to initialize the module.
+ dtype (torch.dtype or None): dtype to use to initialize the module.
+ """
+ def __init__(self, channels: int, init: float = 1e-4, channel_last: bool = True,
+ device=None, dtype=None):
+ super().__init__()
+ self.channel_last = channel_last
+ self.scale = nn.Parameter(
+ torch.full((channels,), init,
+ requires_grad=True, device=device, dtype=dtype))
+
+ def forward(self, x: torch.Tensor):
+ if self.channel_last:
+ return self.scale * x
+ else:
+ return self.scale[:, None] * x
+
+
+class StreamingMultiheadAttention(StreamingModule):
+ """Similar to `nn.MultiheadAttention` but with support for streaming, causal evaluation.
+
+ Args:
+ embed_dim (int): Dimension to project to.
+ num_heads (int): Number of heads.
+ dropout (float): Dropout level.
+ bias (bool): Use bias in projections.
+ causal (bool): Causal mask applied automatically.
+ past_context (int or None): Receptive field for the causal mask, infinite if None.
+ custom (bool): Use custom MHA implementation, for testing / benchmarking.
+ memory_efficient (bool): Use xformers based memory efficient attention.
+ attention_as_float32 (bool): Perform the attention as float32
+ (especially important with memory_efficient as autocast won't do this automatically).
+ rope (`RotaryEmbedding` or None): Rope embedding to use.
+ cross_attention: Should be True when used as a cross attention.
+ All keys and values must be available at once, streaming is only for the queries.
+ Cannot be used with `causal` or `rope` (as it wouldn't make sense to
+ interpret the time steps in the keys relative to those in the queries).
+ safe_streaming (bool): Bug fix, will go away with xformers update.
+ qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product.
+ kv_repeat (int): If > 1, will repeat keys and values multiple times (must divide num_heads).
+ This will lead to faster decoding time on A100 or other GPUs with tensorcore.
+ device (torch.device or None): Device on which to initialize.
+ dtype (torch.dtype or None): dtype to use.
+ """
+ def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True,
+ causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
+ memory_efficient: bool = False, attention_as_float32: bool = False,
+ rope: tp.Optional[RotaryEmbedding] = None, cross_attention: bool = False,
+ safe_streaming: bool = True, qk_layer_norm: bool = False, kv_repeat: int = 1,
+ device=None, dtype=None):
+ super().__init__()
+ factory_kwargs = {'device': device, 'dtype': dtype}
+ if past_context is not None:
+ assert causal
+
+ self.embed_dim = embed_dim
+ self.causal = causal
+ self.past_context = past_context
+ self.memory_efficient = memory_efficient
+ self.attention_as_float32 = attention_as_float32
+ self.rope = rope
+ self.cross_attention = cross_attention
+ self.safe_streaming = safe_streaming
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.kv_repeat = kv_repeat
+ if cross_attention:
+ assert not causal, "Causal cannot work with cross attention."
+ assert rope is None, "Rope cannot work with cross attention."
+
+ if memory_efficient:
+ _verify_xformers_memory_efficient_compat()
+
+ self.custom = _is_custom(custom, memory_efficient)
+ if self.custom:
+ out_dim = embed_dim
+ assert num_heads % kv_repeat == 0
+ assert not cross_attention or kv_repeat == 1
+ num_kv = num_heads // kv_repeat
+ kv_dim = (embed_dim // num_heads) * num_kv
+ out_dim += 2 * kv_dim
+ in_proj = nn.Linear(embed_dim, out_dim, bias=bias, **factory_kwargs)
+ # We try to follow the default PyTorch MHA convention, to easily compare results.
+ self.in_proj_weight = in_proj.weight
+ self.in_proj_bias = in_proj.bias
+ if bias:
+ self.in_proj_bias.data.zero_() # Following Pytorch convention
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
+ if bias:
+ self.out_proj.bias.data.zero_()
+ else:
+ assert not qk_layer_norm
+ assert kv_repeat == 1
+ self.mha = nn.MultiheadAttention(
+ embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True,
+ **factory_kwargs)
+ self.qk_layer_norm = qk_layer_norm
+ if qk_layer_norm:
+ assert self.custom
+ assert kv_repeat == 1
+ ln_dim = embed_dim
+ self.q_layer_norm = nn.LayerNorm(ln_dim)
+ self.k_layer_norm = nn.LayerNorm(ln_dim)
+
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
+ if not self.custom:
+ # Support compat with regular MHA
+ keys = [n for n, _ in self.mha.named_parameters()]
+ for key in keys:
+ if prefix + key in state_dict:
+ state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
+
+ def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype):
+ # Return a causal mask, accounting for potentially stored past keys/values
+ # We actually return a bias for the attention score, as this has the same
+ # convention both in the builtin MHA in Pytorch, and Xformers functions.
+ time_dim = _get_attention_time_dimension()
+ if self.memory_efficient:
+ from xformers.ops import LowerTriangularMask
+ if current_steps == 1:
+ # If we only have one step, then we do not need a mask.
+ return None
+ elif 'past_keys' in self._streaming_state:
+ raise RuntimeError('Not supported at the moment')
+ else:
+ # Then we can safely use a lower triangular mask
+ return LowerTriangularMask()
+ if self._streaming_state:
+ past_keys = self._streaming_state['past_keys']
+ past_steps = past_keys.shape[time_dim]
+ else:
+ past_steps = 0
+
+ queries_pos = torch.arange(
+ past_steps, current_steps + past_steps, device=device).view(-1, 1)
+ keys_pos = torch.arange(past_steps + current_steps, device=device).view(1, -1)
+ delta = queries_pos - keys_pos
+ valid = delta >= 0
+ if self.past_context is not None:
+ valid &= (delta <= self.past_context)
+ return torch.where(
+ valid,
+ torch.zeros([], device=device, dtype=dtype),
+ torch.full([], float('-inf'), device=device, dtype=dtype))
+
+ def _complete_kv(self, k, v):
+ time_dim = _get_attention_time_dimension()
+ if self.cross_attention:
+ # With cross attention we assume all keys and values
+ # are already available, and streaming is with respect
+ # to the queries only.
+ return k, v
+ # Complete the key/value pair using the streaming state.
+ if self._streaming_state:
+ pk = self._streaming_state['past_keys']
+ nk = torch.cat([pk, k], dim=time_dim)
+ if v is k:
+ nv = nk
+ else:
+ pv = self._streaming_state['past_values']
+ nv = torch.cat([pv, v], dim=time_dim)
+ else:
+ nk = k
+ nv = v
+
+ assert nk.shape[time_dim] == nv.shape[time_dim]
+ offset = 0
+ if self.past_context is not None:
+ offset = max(0, nk.shape[time_dim] - self.past_context)
+ if self._is_streaming:
+ self._streaming_state['past_keys'] = nk[:, offset:]
+ if v is not k:
+ self._streaming_state['past_values'] = nv[:, offset:]
+ if 'offset' in self._streaming_state:
+ self._streaming_state['offset'] += offset
+ else:
+ self._streaming_state['offset'] = torch.tensor(0)
+ return nk, nv
+
+ def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
+ # TODO: fix and verify layout.
+ assert _efficient_attention_backend == 'xformers', 'Rope not supported with torch attn.'
+ # Apply rope embeddings to query and key tensors.
+ assert self.rope is not None
+ if 'past_keys' in self._streaming_state:
+ past_keys_offset = self._streaming_state['past_keys'].shape[1]
+ else:
+ past_keys_offset = 0
+ if 'offset' in self._streaming_state:
+ past_context_offset = int(self._streaming_state['offset'].item())
+ else:
+ past_context_offset = 0
+ streaming_offset = past_context_offset + past_keys_offset
+ return self.rope.rotate_qk(query, key, start=streaming_offset)
+
+ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
+ key_padding_mask=None, need_weights=False, attn_mask=None,
+ average_attn_weights=True, is_causal=False):
+ assert attn_mask is None
+ assert not is_causal, ("new param added in torch 2.0.1 not supported, "
+ "use the causal args in the constructor.")
+
+ time_dim = _get_attention_time_dimension()
+ if time_dim == 2:
+ layout = "b h t d"
+ else:
+ layout = "b t h d"
+ dtype = query.dtype
+ if self._is_streaming:
+ assert self.causal or self.cross_attention, \
+ "Streaming only available for causal or cross attention"
+
+ if self.causal:
+ # At the moment we specialize only for the self-attention case.
+ assert query.shape[1] == key.shape[1], "Causal only for same length query / key / value"
+ assert value.shape[1] == key.shape[1], "Causal only for same length query / key / value"
+ attn_mask = self._get_mask(query.shape[1], query.device, query.dtype)
+
+ if self.custom:
+ # custom implementation
+ assert need_weights is False
+ assert key_padding_mask is None
+ if self.cross_attention:
+ # Different queries, keys, values; we have to split the weights manually
+ # before applying the linear.
+ dim = self.in_proj_weight.shape[0] // 3
+ if self.in_proj_bias is None:
+ bias_q, bias_k, bias_v = None, None, None
+ else:
+ bias_q = self.in_proj_bias[:dim]
+ bias_k = self.in_proj_bias[dim: 2 * dim]
+ bias_v = self.in_proj_bias[2 * dim:]
+ q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
+ # TODO: when streaming, we could actually save k, v and check that the shapes actually match.
+ k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k)
+ v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v)
+ if self.qk_layer_norm is True:
+ q = self.q_layer_norm(q)
+ k = self.k_layer_norm(k)
+ q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
+ else:
+ if not _is_profiled():
+ # profiling breaks that property somehow.
+ assert query is key, "specialized implementation"
+ assert value is key, "specialized implementation"
+ projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
+ if self.kv_repeat == 1:
+ if time_dim == 2:
+ bound_layout = "b h p t d"
+ else:
+ bound_layout = "b t p h d"
+ packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
+ q, k, v = ops.unbind(packed, dim=2)
+ else:
+ embed_dim = self.embed_dim
+ per_head_dim = (embed_dim // self.num_heads)
+ kv_heads = self.num_heads // self.kv_repeat
+ q = projected[:, :, :embed_dim]
+ start = embed_dim
+ end = start + per_head_dim * kv_heads
+ k = projected[:, :, start: end]
+ v = projected[:, :, end:]
+ q = rearrange(q, f"b t (h d) -> {layout}", h=self.num_heads)
+ k = rearrange(k, f"b t (h d) -> {layout}", h=kv_heads)
+ v = rearrange(v, f"b t (h d) -> {layout}", h=kv_heads)
+
+ if self.qk_layer_norm is True:
+ assert self.kv_repeat == 1
+ q, k = [rearrange(x, f"{layout} -> b t (h d)") for x in [q, k]]
+ q = self.q_layer_norm(q)
+ k = self.k_layer_norm(k)
+ q, k = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k]]
+ if self.rope:
+ q, k = self._apply_rope(q, k)
+ k, v = self._complete_kv(k, v)
+ if self.kv_repeat > 1:
+ k = expand_repeated_kv(k, self.kv_repeat)
+ v = expand_repeated_kv(v, self.kv_repeat)
+ if self.attention_as_float32:
+ q, k, v = [x.float() for x in [q, k, v]]
+ if self.memory_efficient:
+ p = self.dropout if self.training else 0
+ if _efficient_attention_backend == 'torch':
+ x = torch.nn.functional.scaled_dot_product_attention(
+ q, k, v, is_causal=attn_mask is not None, dropout_p=p)
+ else:
+ x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p)
+ else:
+ # We include the dot product as float32, for consistency
+ # with the other implementations that include that step
+ # as part of the attention. Note that when using `autocast`,
+ # the einsums would be done as bfloat16, but the softmax
+ # would be done as float32, so `attention_as_float32` will
+ # extend a bit the range of operations done in float32,
+ # although this should make no difference.
+ q = q / q.shape[-1] ** 0.5
+ key_layout = layout.replace('t', 'k')
+ query_layout = layout
+ if self._is_streaming and self.safe_streaming and q.device.type == 'cuda':
+ with torch.autocast(device_type=q.device.type, dtype=torch.float32):
+ pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
+ else:
+ pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
+ if attn_mask is not None:
+ pre_w = pre_w + attn_mask
+ w = torch.softmax(pre_w, dim=-1)
+ w = F.dropout(w, self.dropout, training=self.training).to(v)
+ # Key and value have the same format.
+ x = torch.einsum(f"b h t k, {key_layout} -> {layout}", w, v)
+ x = x.to(dtype)
+ x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
+ x = self.out_proj(x)
+ else:
+ key, value = self._complete_kv(key, value)
+ if self.attention_as_float32:
+ query, key, value = [x.float() for x in [query, key, value]]
+ x, _ = self.mha(
+ query, key, value, key_padding_mask,
+ need_weights, attn_mask, average_attn_weights)
+ x = x.to(dtype)
+
+ return x, None
+
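The following sketch exercises the streaming path of `StreamingMultiheadAttention`: inside `streaming()`, past keys/values are cached, so decoding one time step at a time should match a single full causal forward pass (model size and sequence length are arbitrary; `custom=True` selects the in-repo implementation).

```python
# Hypothetical check that step-by-step streaming matches a full causal pass.
import torch
from audiocraft.modules.transformer import StreamingMultiheadAttention

mha = StreamingMultiheadAttention(embed_dim=64, num_heads=4, causal=True, custom=True)
mha.eval()

x = torch.randn(1, 8, 64)                      # [B, T, C]
with torch.no_grad():
    full, _ = mha(x, x, x)                     # one pass over the whole sequence
    with mha.streaming():
        steps = []
        for i in range(x.shape[1]):
            xi = x[:, i:i + 1]                 # a single new time step
            steps.append(mha(xi, xi, xi)[0])
    streamed = torch.cat(steps, dim=1)
print(torch.allclose(full, streamed, atol=1e-5))   # True, up to numerics
```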
+
+class StreamingTransformerLayer(nn.TransformerEncoderLayer):
+ """TransformerLayer with Streaming / Causal support.
+ This also integrates cross_attention, when passing `cross_attention=True`,
+ rather than having two separate classes like in PyTorch.
+
+ Args:
+ d_model (int): Dimension of the data.
+ num_heads (int): Number of heads.
+ dim_feedforward (int): Intermediate dimension of FF module.
+ dropout (float): Dropout both for MHA and FF.
+ bias_ff (bool): Use bias for FF.
+ bias_attn (bool): Use bias for MHA.
+ causal (bool): Causal mask applied automatically.
+ past_context (int or None): Receptive field for the causal mask, infinite if None.
+ custom (bool): Use custom MHA implementation, for testing / benchmarking.
+ memory_efficient (bool): Use xformers based memory efficient attention.
+ attention_as_float32 (bool): Perform the attention as float32
+ (especially important with memory_efficient as autocast won't do this automatically).
+ qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product in attention.
+ qk_layer_norm_cross (bool): Same for the cross attention.
+ cross_attention (bool): If True, expect to get secondary input for cross-attention.
+ Cross attention will use the default MHA, as it typically won't require
+ special treatment.
+ layer_scale (float or None): If not None, LayerScale will be used with
+ the given value as initial scale.
+ rope (`RotaryEmbedding` or None): Rope embedding to use.
+ attention_dropout (float or None): If not None, use this value for the attention dropout,
+ separate from the FFN dropout.
+ kv_repeat (int): If > 1, will repeat keys and values multiple times (must divide num_heads).
+ This will lead to faster decoding time on A100 or other GPUs with tensorcore.
+ device (torch.device or None): Device on which to initialize.
+ dtype (torch.dtype or None): dtype to use.
+ **kwargs: See `nn.TransformerEncoderLayer`.
+ """
+ def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1,
+ bias_ff: bool = True, bias_attn: bool = True, causal: bool = False,
+ past_context: tp.Optional[int] = None, custom: bool = False,
+ memory_efficient: bool = False, attention_as_float32: bool = False,
+ qk_layer_norm: bool = False, qk_layer_norm_cross: bool = False,
+ cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
+ rope: tp.Optional[RotaryEmbedding] = None, attention_dropout: tp.Optional[float] = None,
+ kv_repeat: int = 1, norm: str = 'layer_norm', device=None, dtype=None, **kwargs):
+ super().__init__(d_model, num_heads, dim_feedforward, dropout,
+ device=device, dtype=dtype, batch_first=True, **kwargs)
+ factory_kwargs = {'device': device, 'dtype': dtype}
+ # Redefine self_attn to our streaming multi-head attention
+ attn_kwargs: tp.Dict[str, tp.Any] = {
+ 'embed_dim': d_model,
+ 'num_heads': num_heads,
+ 'dropout': dropout if attention_dropout is None else attention_dropout,
+ 'bias': bias_attn,
+ 'custom': custom,
+ 'memory_efficient': memory_efficient,
+ 'attention_as_float32': attention_as_float32,
+ }
+ self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention(
+ causal=causal, past_context=past_context, rope=rope, qk_layer_norm=qk_layer_norm,
+ kv_repeat=kv_repeat, **attn_kwargs, **factory_kwargs) # type: ignore
+ # Redefine feedforward layers to expose bias parameter
+ self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
+ self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)
+
+ self.layer_scale_1: nn.Module
+ self.layer_scale_2: nn.Module
+ if layer_scale is None:
+ self.layer_scale_1 = nn.Identity()
+ self.layer_scale_2 = nn.Identity()
+ else:
+ self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs)
+ self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs)
+
+ self.cross_attention: tp.Optional[nn.Module] = None
+ if cross_attention:
+ self.cross_attention = StreamingMultiheadAttention(
+ cross_attention=True, qk_layer_norm=qk_layer_norm_cross,
+ **attn_kwargs, **factory_kwargs)
+ # Norm and dropout
+ self.dropout_cross = nn.Dropout(dropout)
+ # eps value matching that used in PyTorch reference implementation.
+ self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
+ self.layer_scale_cross: nn.Module
+ if layer_scale is None:
+ self.layer_scale_cross = nn.Identity()
+ else:
+ self.layer_scale_cross = LayerScale(d_model, layer_scale, **factory_kwargs)
+ self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore
+ self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs) # type: ignore
+
+ def _cross_attention_block(self, src: torch.Tensor,
+ cross_attention_src: torch.Tensor) -> torch.Tensor:
+ assert self.cross_attention is not None
+ # queries are from src, keys and values from cross_attention_src.
+ x = self.cross_attention(
+ src, cross_attention_src, cross_attention_src, need_weights=False)[0]
+ return self.dropout_cross(x) # type: ignore
+
+ def forward(self, src: torch.Tensor, src_mask: tp.Optional[torch.Tensor] = None, # type: ignore
+ src_key_padding_mask: tp.Optional[torch.Tensor] = None,
+ cross_attention_src: tp.Optional[torch.Tensor] = None):
+ if self.cross_attention is None:
+ assert cross_attention_src is None
+ else:
+ assert cross_attention_src is not None
+ x = src
+ if self.norm_first:
+ x = x + self.layer_scale_1(
+ self._sa_block(self.norm1(x), src_mask, src_key_padding_mask))
+ if cross_attention_src is not None:
+ x = x + self.layer_scale_cross(
+ self._cross_attention_block(
+ self.norm_cross(x), cross_attention_src))
+ x = x + self.layer_scale_2(self._ff_block(self.norm2(x)))
+ else:
+ x = self.norm1(x + self.layer_scale_1(
+ self._sa_block(x, src_mask, src_key_padding_mask)))
+ if cross_attention_src is not None:
+ x = self.norm_cross(
+ x + self.layer_scale_cross(
+ self._cross_attention_block(src, cross_attention_src)))
+ x = self.norm2(x + self.layer_scale_2(self._ff_block(x)))
+ return x
+
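And a final sketch wiring a `StreamingTransformerLayer` with cross-attention: `cross_attention_src` carries the conditioning sequence (for instance, the tensor produced by the "cross" fuse method earlier); all dimensions here are illustrative.

```python
# Hypothetical StreamingTransformerLayer with cross-attention enabled.
import torch
from audiocraft.modules.transformer import StreamingTransformerLayer

layer = StreamingTransformerLayer(d_model=64, num_heads=4, dim_feedforward=256,
                                  causal=True, custom=True, cross_attention=True)
layer.eval()

x = torch.randn(2, 20, 64)         # [B, T, C] main transformer stream
cond = torch.randn(2, 6, 64)       # cross-attention source (e.g. text embeddings)
out = layer(x, cross_attention_src=cond)
print(out.shape)                   # torch.Size([2, 20, 64])
```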
+
+class StreamingTransformer(StreamingModule):
+ """Transformer with Streaming / Causal support.
+
+ Args:
+ d_model (int): Dimension of the data.
+ num_heads (int): Number of heads.
+ dim_feedforward (int): Intermediate dimension of FF module.
+ dropout (float): Dropout both for MHA and FF.
+ bias_ff (bool): Use bias for FF.
+ bias_attn (bool): Use bias for MHA.
+ causal (bool): Causal mask applied automatically.
+ past_context (int or None): Receptive field for the causal mask, infinite if None.
+ custom (bool): Use custom MHA implementation, for testing / benchmarking.
+ memory_efficient (bool): Use xformers based memory efficient attention.
+ attention_as_float32 (bool): Perform the attention as float32
+ (especially important with memory_efficient as autocast won't do this automatically).
+ cross_attention (bool): If True, expect to get secondary input for cross-attention.
+ layer_scale (float or None): If not None, LayerScale will be used
+ with the given value as initial scale.
+ positional_embedding (str): Positional embedding strategy (sin, rope, or sin_rope).
+ max_period (float): Maximum period of the time embedding.
+ positional_scale (float): Scale of positional embedding, set to 0 to deactivate.
+ xpos (bool): Apply xpos exponential decay to positional embedding (rope only).
+ lr (float or None): Learning rate override through the `make_optim_group` API.
+ weight_decay (float or None): Weight decay override through the `make_optim_group` API.
+ layer_class (subclass of `StreamingTransformerLayer`): Class to use
+ to initialize the layers, allowing further customization outside of Audiocraft.
+ checkpointing (str): Checkpointing strategy to reduce memory usage.
+ No checkpointing if set to 'none'. Per layer checkpointing using PyTorch
+ if set to 'torch' (entire layer checkpointed, i.e. linears are evaluated twice,
+ minimal memory usage, but maximal runtime). Finally, `xformers_default` provides
+ a policy for opting some operations, such as the linear layers and attention,
+ out of checkpointing, providing a middle ground between speed and memory.
+ device (torch.device or None): Device on which to initialize.
+ dtype (torch.dtype or None): dtype to use.
+ **kwargs: See `nn.TransformerEncoderLayer`.
+ """
+ def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048,
+ dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True,
+ causal: bool = False, past_context: tp.Optional[int] = None,
+ custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False,
+ cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
+ positional_embedding: str = 'sin', max_period: float = 10_000, positional_scale: float = 1.,
+ xpos: bool = False, lr: tp.Optional[float] = None, weight_decay: tp.Optional[float] = None,
+ layer_class: tp.Type[StreamingTransformerLayer] = StreamingTransformerLayer,
+ checkpointing: str = 'none', device=None, dtype=None, **kwargs):
+ super().__init__()
+ assert d_model % num_heads == 0
+
+ self.positional_embedding = positional_embedding
+ self.max_period = max_period
+ self.positional_scale = positional_scale
+ self.weight_decay = weight_decay
+ self.lr = lr
+
+ assert positional_embedding in ['sin', 'rope', 'sin_rope']
+ self.rope: tp.Optional[RotaryEmbedding] = None
+ if self.positional_embedding in ['rope', 'sin_rope']:
+ assert _is_custom(custom, memory_efficient)
+ self.rope = RotaryEmbedding(d_model // num_heads, max_period=max_period,
+ xpos=xpos, scale=positional_scale, device=device)
+
+ self.checkpointing = checkpointing
+
+ assert checkpointing in ['none', 'torch', 'xformers_default', 'xformers_mm']
+ if self.checkpointing.startswith('xformers'):
+ _verify_xformers_internal_compat()
+
+ self.layers = nn.ModuleList()
+ for idx in range(num_layers):
+ self.layers.append(
+ layer_class(
+ d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward,
+ dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn,
+ causal=causal, past_context=past_context, custom=custom,
+ memory_efficient=memory_efficient, attention_as_float32=attention_as_float32,
+ cross_attention=cross_attention, layer_scale=layer_scale, rope=self.rope,
+ device=device, dtype=dtype, **kwargs))
+
+ if self.checkpointing != 'none':
+ for layer in self.layers:
+ # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the
+ # backward hook inside of FSDP...
+ layer._magma_checkpointed = True # type: ignore
+ assert layer.layer_drop == 0., "Need further checking" # type: ignore
+
+ def _apply_layer(self, layer, *args, **kwargs):
+ method = self.checkpointing
+ if method == 'none':
+ return layer(*args, **kwargs)
+ elif method == 'torch':
+ return torch_checkpoint(layer, *args, use_reentrant=False, **kwargs)
+ elif method.startswith('xformers'):
+ from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy
+ if method == 'xformers_default':
+ # those operations will be saved, and not recomputed.
+ # According to Francisco we can get smarter policies but this is a good start.
+ allow_list = [
+ "xformers.efficient_attention_forward_cutlass.default",
+ "xformers_flash.flash_fwd.default",
+ "aten.addmm.default",
+ "aten.mm.default",
+ ]
+ elif method == 'xformers_mm':
+ # those operations will be saved, and not recomputed.
+ # According to Francisco we can get smarter policies but this is a good start.
+ allow_list = [
+ "aten.addmm.default",
+ "aten.mm.default",
+ ]
+ else:
+ raise ValueError(f"xformers checkpointing xformers policy {method} is not known.")
+ policy_fn = _get_default_policy(allow_list)
+ return checkpoint(layer, *args, policy_fn=policy_fn, **kwargs)
+ else:
+ raise ValueError(f"Checkpointing method {method} is unknown.")
+
+ def forward(self, x: torch.Tensor, *args, **kwargs):
+ B, T, C = x.shape
+
+ if 'offsets' in self._streaming_state:
+ offsets = self._streaming_state['offsets']
+ else:
+ offsets = torch.zeros(B, dtype=torch.long, device=x.device)
+
+ if self.positional_embedding in ['sin', 'sin_rope']:
+ positions = torch.arange(T, device=x.device).view(1, -1, 1)
+ positions = positions + offsets.view(-1, 1, 1)
+ pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
+ x = x + self.positional_scale * pos_emb
+
+ for layer in self.layers:
+ x = self._apply_layer(layer, x, *args, **kwargs)
+
+ if self._is_streaming:
+ self._streaming_state['offsets'] = offsets + T
+
+ return x
+
+ def make_optim_group(self):
+ group = {"params": list(self.parameters())}
+ if self.lr is not None:
+ group["lr"] = self.lr
+ if self.weight_decay is not None:
+ group["weight_decay"] = self.weight_decay
+ return group
+
+
+# special attention related functions
+
+def _verify_xformers_memory_efficient_compat():
+ try:
+ from xformers.ops import memory_efficient_attention, LowerTriangularMask # noqa
+ except ImportError:
+ raise ImportError(
+ "xformers is not installed. Please install it and try again.\n"
+ "To install on AWS and Azure, run \n"
+ "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n"
+ "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n"
+ "To install on FAIR Cluster, run \n"
+ "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n"
+ "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n")
+
+
+def _verify_xformers_internal_compat():
+ try:
+ from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy # noqa
+ except ImportError:
+ raise ImportError(
+ "Francisco's fairinternal xformers is not installed. Please install it and try again.\n"
+ "To install on AWS and Azure, run \n"
+ "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n"
+ "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n"
+ "To install on FAIR Cluster, run \n"
+ "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n"
+ "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n")
+
+
+def _is_custom(custom: bool, memory_efficient: bool):
+ return custom or memory_efficient
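+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative values only): run a tiny, non-causal
+    # StreamingTransformer on random data. Inputs follow the [B, T, C] layout used by
+    # forward(), with C == d_model.
+    model = StreamingTransformer(d_model=64, num_heads=4, num_layers=2, dim_feedforward=128)
+    x = torch.randn(2, 50, 64)
+    y = model(x)
+    print(y.shape)  # expected: torch.Size([2, 50, 64])
+    # make_optim_group() exposes per-module lr / weight_decay overrides to the optimizer.
+    optimizer = torch.optim.AdamW([model.make_optim_group()], lr=1e-4)
+    print(len(optimizer.param_groups))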
diff --git a/audiocraft/py.typed b/audiocraft/py.typed
new file mode 100644
index 0000000..e69de29
diff --git a/audiocraft/quantization/__init__.py b/audiocraft/quantization/__init__.py
new file mode 100644
index 0000000..836d6eb
--- /dev/null
+++ b/audiocraft/quantization/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# flake8: noqa
+from .vq import ResidualVectorQuantizer
+from .base import BaseQuantizer, DummyQuantizer, QuantizedResult
diff --git a/audiocraft/quantization/base.py b/audiocraft/quantization/base.py
new file mode 100644
index 0000000..1b16c13
--- /dev/null
+++ b/audiocraft/quantization/base.py
@@ -0,0 +1,107 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Base class for all quantizers.
+"""
+
+from dataclasses import dataclass, field
+import typing as tp
+
+import torch
+from torch import nn
+
+
+@dataclass
+class QuantizedResult:
+ x: torch.Tensor
+ codes: torch.Tensor
+ bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
+ penalty: tp.Optional[torch.Tensor] = None
+ metrics: dict = field(default_factory=dict)
+
+
+class BaseQuantizer(nn.Module):
+ """Base class for quantizers.
+ """
+
+ def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
+ """
+ Given the input tensor x, returns the quantized (or approximately quantized)
+ representation along with the quantized codes, the bandwidth, and any penalty term for the loss.
+ It also returns a dict of metrics to be used for logging etc.
+ The frame rate must be passed so that the bandwidth is properly computed.
+ """
+ raise NotImplementedError()
+
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
+ """Encode a given input tensor with the specified sample rate at the given bandwidth.
+ """
+ raise NotImplementedError()
+
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
+ """Decode the given codes to the quantized representation.
+ """
+ raise NotImplementedError()
+
+ @property
+ def total_codebooks(self):
+ """Total number of codebooks.
+ """
+ raise NotImplementedError()
+
+ @property
+ def num_codebooks(self):
+ """Number of active codebooks.
+ """
+ raise NotImplementedError()
+
+ def set_num_codebooks(self, n: int):
+ """Set the number of active codebooks.
+ """
+ raise NotImplementedError()
+
+
+class DummyQuantizer(BaseQuantizer):
+ """Fake quantizer that actually does not perform any quantization.
+ """
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x: torch.Tensor, frame_rate: int):
+ q = x.unsqueeze(1)
+ return QuantizedResult(x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x))
+
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
+ """Encode a given input tensor with the specified sample rate at the given bandwidth.
+ In the case of the DummyQuantizer, the codes are identical to the input
+ (and to the resulting quantized representation), as no quantization is done.
+ """
+ return x.unsqueeze(1)
+
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
+ """Decode the given codes to the quantized representation.
+ In the case of the DummyQuantizer, the codes are identical to the input
+ (and to the resulting quantized representation), as no quantization is done.
+ """
+ return codes.squeeze(1)
+
+ @property
+ def total_codebooks(self):
+ """Total number of codebooks.
+ """
+ return 1
+
+ @property
+ def num_codebooks(self):
+ """Total number of codebooks.
+ """
+ return self.total_codebooks
+
+ def set_num_codebooks(self, n: int):
+ """Set the number of active codebooks.
+ """
+ raise AttributeError("Cannot override the number of codebooks for the dummy quantizer")
diff --git a/audiocraft/quantization/core_vq.py b/audiocraft/quantization/core_vq.py
new file mode 100644
index 0000000..e1896bb
--- /dev/null
+++ b/audiocraft/quantization/core_vq.py
@@ -0,0 +1,400 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import typing as tp
+
+from einops import rearrange, repeat
+import flashy
+import torch
+from torch import nn, einsum
+import torch.nn.functional as F
+
+
+def exists(val: tp.Optional[tp.Any]) -> bool:
+ return val is not None
+
+
+def default(val: tp.Any, d: tp.Any) -> tp.Any:
+ return val if exists(val) else d
+
+
+def l2norm(t):
+ return F.normalize(t, p=2, dim=-1)
+
+
+def ema_inplace(moving_avg, new, decay: float):
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
+
+
+def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
+ return (x + epsilon) / (x.sum() + n_categories * epsilon)
+
+
+def uniform_init(*shape: int):
+ t = torch.empty(shape)
+ nn.init.kaiming_uniform_(t)
+ return t
+
+
+def sample_vectors(samples, num: int):
+ num_samples, device = samples.shape[0], samples.device
+
+ if num_samples >= num:
+ indices = torch.randperm(num_samples, device=device)[:num]
+ else:
+ indices = torch.randint(0, num_samples, (num,), device=device)
+
+ return samples[indices]
+
+
+def kmeans(samples, num_clusters: int, num_iters: int = 10):
+ dim, dtype = samples.shape[-1], samples.dtype
+
+ means = sample_vectors(samples, num_clusters)
+
+ for _ in range(num_iters):
+ diffs = rearrange(samples, "n d -> n () d") - rearrange(
+ means, "c d -> () c d"
+ )
+ dists = -(diffs ** 2).sum(dim=-1)
+
+ buckets = dists.max(dim=-1).indices
+ bins = torch.bincount(buckets, minlength=num_clusters)
+ zero_mask = bins == 0
+ bins_min_clamped = bins.masked_fill(zero_mask, 1)
+
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
+ new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
+ new_means = new_means / bins_min_clamped[..., None]
+
+ means = torch.where(zero_mask[..., None], means, new_means)
+
+ return means, bins
+
+
+def orthgonal_loss_fn(t):
+ # eq (2) from https://arxiv.org/abs/2112.00384
+ n = t.shape[0]
+ normed_codes = l2norm(t)
+ identity = torch.eye(n, device=t.device)
+ cosine_sim = einsum("i d, j d -> i j", normed_codes, normed_codes)
+ return ((cosine_sim - identity) ** 2).sum() / (n ** 2)
+
+
+class EuclideanCodebook(nn.Module):
+ """Codebook with Euclidean distance.
+
+ Args:
+ dim (int): Dimension.
+ codebook_size (int): Codebook size.
+ kmeans_init (bool): Whether to use k-means to initialize the codebooks.
+ If set to true, run the k-means algorithm on the first training batch and use
+ the learned centroids as initialization.
+ kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
+ decay (float): Decay for exponential moving average over the codebooks.
+ epsilon (float): Epsilon value for numerical stability.
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+ that have an exponential moving average cluster size less than the specified threshold with
+ a randomly selected vector from the current batch.
+ """
+ def __init__(
+ self,
+ dim: int,
+ codebook_size: int,
+ kmeans_init: bool = False,
+ kmeans_iters: int = 10,
+ decay: float = 0.8,
+ epsilon: float = 1e-5,
+ threshold_ema_dead_code: int = 2,
+ ):
+ super().__init__()
+ self.decay = decay
+ init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
+ embed = init_fn(codebook_size, dim)
+
+ self.codebook_size = codebook_size
+
+ self.kmeans_iters = kmeans_iters
+ self.epsilon = epsilon
+ self.threshold_ema_dead_code = threshold_ema_dead_code
+
+ self.register_buffer("inited", torch.Tensor([not kmeans_init]))
+ self.register_buffer("cluster_size", torch.zeros(codebook_size))
+ self.register_buffer("embed", embed)
+ self.register_buffer("embed_avg", embed.clone())
+
+ @torch.jit.ignore
+ def init_embed_(self, data):
+ if self.inited:
+ return
+
+ embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
+ self.embed.data.copy_(embed)
+ self.embed_avg.data.copy_(embed.clone())
+ self.cluster_size.data.copy_(cluster_size)
+ self.inited.data.copy_(torch.Tensor([True]))
+ # Make sure all buffers across workers are in sync after initialization
+ flashy.distrib.broadcast_tensors(self.buffers())
+
+ def replace_(self, samples, mask):
+ modified_codebook = torch.where(
+ mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
+ )
+ self.embed.data.copy_(modified_codebook)
+
+ def expire_codes_(self, batch_samples):
+ if self.threshold_ema_dead_code == 0:
+ return
+
+ expired_codes = self.cluster_size < self.threshold_ema_dead_code
+ if not torch.any(expired_codes):
+ return
+
+ batch_samples = rearrange(batch_samples, "... d -> (...) d")
+ self.replace_(batch_samples, mask=expired_codes)
+ flashy.distrib.broadcast_tensors(self.buffers())
+
+ def preprocess(self, x):
+ x = rearrange(x, "... d -> (...) d")
+ return x
+
+ def quantize(self, x):
+ embed = self.embed.t()
+ dist = -(
+ x.pow(2).sum(1, keepdim=True)
+ - 2 * x @ embed
+ + embed.pow(2).sum(0, keepdim=True)
+ )
+ embed_ind = dist.max(dim=-1).indices
+ return embed_ind
+
+ def postprocess_emb(self, embed_ind, shape):
+ return embed_ind.view(*shape[:-1])
+
+ def dequantize(self, embed_ind):
+ quantize = F.embedding(embed_ind, self.embed)
+ return quantize
+
+ def encode(self, x):
+ shape = x.shape
+ # pre-process
+ x = self.preprocess(x)
+ # quantize
+ embed_ind = self.quantize(x)
+ # post-process
+ embed_ind = self.postprocess_emb(embed_ind, shape)
+ return embed_ind
+
+ def decode(self, embed_ind):
+ quantize = self.dequantize(embed_ind)
+ return quantize
+
+ def forward(self, x):
+ shape, dtype = x.shape, x.dtype
+ x = self.preprocess(x)
+ self.init_embed_(x)
+
+ embed_ind = self.quantize(x)
+ embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
+ embed_ind = self.postprocess_emb(embed_ind, shape)
+ quantize = self.dequantize(embed_ind)
+
+ if self.training:
+ # We do the expiry of code at that point as buffers are in sync
+ # and all the workers will take the same decision.
+ self.expire_codes_(x)
+ ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
+ embed_sum = x.t() @ embed_onehot
+ ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
+ cluster_size = (
+ laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
+ * self.cluster_size.sum()
+ )
+ embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
+ self.embed.data.copy_(embed_normalized)
+
+ return quantize, embed_ind
+
+
+class VectorQuantization(nn.Module):
+ """Vector quantization implementation.
+ Currently supports only euclidean distance.
+
+ Args:
+ dim (int): Dimension
+ codebook_size (int): Codebook size
+ codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
+ decay (float): Decay for exponential moving average over the codebooks.
+ epsilon (float): Epsilon value for numerical stability.
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
+ channels_last (bool): Channels are the last dimension in the input tensors.
+ commitment_weight (float): Weight for commitment loss.
+ orthogonal_reg_weight (float): Orthogonal regularization weights.
+ orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
+ orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
+ for orthogonal regularization.
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+ that have an exponential moving average cluster size less than the specified threshold with
+ a randomly selected vector from the current batch.
+ """
+ def __init__(
+ self,
+ dim: int,
+ codebook_size: int,
+ codebook_dim: tp.Optional[int] = None,
+ decay: float = 0.8,
+ epsilon: float = 1e-5,
+ kmeans_init: bool = False,
+ kmeans_iters: int = 10,
+ threshold_ema_dead_code: int = 2,
+ channels_last: bool = False,
+ commitment_weight: float = 1.,
+ orthogonal_reg_weight: float = 0.0,
+ orthogonal_reg_active_codes_only: bool = False,
+ orthogonal_reg_max_codes: tp.Optional[int] = None,
+ ):
+ super().__init__()
+ _codebook_dim: int = default(codebook_dim, dim)
+
+ requires_projection = _codebook_dim != dim
+ self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity())
+ self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity())
+
+ self.epsilon = epsilon
+ self.commitment_weight = commitment_weight
+
+ self.orthogonal_reg_weight = orthogonal_reg_weight
+ self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
+ self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
+
+ self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
+ kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
+ decay=decay, epsilon=epsilon,
+ threshold_ema_dead_code=threshold_ema_dead_code)
+ self.codebook_size = codebook_size
+
+ self.channels_last = channels_last
+
+ @property
+ def codebook(self):
+ return self._codebook.embed
+
+ @property
+ def inited(self):
+ return self._codebook.inited
+
+ def _preprocess(self, x):
+ if not self.channels_last:
+ x = rearrange(x, "b d n -> b n d")
+ return x
+
+ def _postprocess(self, quantize):
+ if not self.channels_last:
+ quantize = rearrange(quantize, "b n d -> b d n")
+ return quantize
+
+ def encode(self, x):
+ x = self._preprocess(x)
+ x = self.project_in(x)
+ embed_in = self._codebook.encode(x)
+ return embed_in
+
+ def decode(self, embed_ind):
+ quantize = self._codebook.decode(embed_ind)
+ quantize = self.project_out(quantize)
+ quantize = self._postprocess(quantize)
+ return quantize
+
+ def forward(self, x):
+ device = x.device
+ x = self._preprocess(x)
+
+ x = self.project_in(x)
+ quantize, embed_ind = self._codebook(x)
+
+ if self.training:
+ quantize = x + (quantize - x).detach()
+
+ loss = torch.tensor([0.0], device=device, requires_grad=self.training)
+
+ if self.training:
+ if self.commitment_weight > 0:
+ commit_loss = F.mse_loss(quantize.detach(), x)
+ loss = loss + commit_loss * self.commitment_weight
+
+ if self.orthogonal_reg_weight > 0:
+ codebook = self.codebook
+
+ if self.orthogonal_reg_active_codes_only:
+ # only calculate orthogonal loss for the activated codes for this batch
+ unique_code_ids = torch.unique(embed_ind)
+ codebook = codebook[unique_code_ids]
+
+ num_codes = codebook.shape[0]
+ if exists(self.orthogonal_reg_max_codes) and num_codes > self.orthogonal_reg_max_codes:
+ rand_ids = torch.randperm(num_codes, device=device)[:self.orthogonal_reg_max_codes]
+ codebook = codebook[rand_ids]
+
+ orthogonal_reg_loss = orthgonal_loss_fn(codebook)
+ loss = loss + orthogonal_reg_loss * self.orthogonal_reg_weight
+
+ quantize = self.project_out(quantize)
+ quantize = self._postprocess(quantize)
+
+ return quantize, embed_ind, loss
+
+
+class ResidualVectorQuantization(nn.Module):
+ """Residual vector quantization implementation.
+
+ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
+ """
+ def __init__(self, *, num_quantizers, **kwargs):
+ super().__init__()
+ self.layers = nn.ModuleList(
+ [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
+ )
+
+ def forward(self, x, n_q: tp.Optional[int] = None):
+ quantized_out = 0.0
+ residual = x
+
+ all_losses = []
+ all_indices = []
+
+ n_q = n_q or len(self.layers)
+
+ for i, layer in enumerate(self.layers[:n_q]):
+ quantized, indices, loss = layer(residual)
+ residual = residual - quantized
+ quantized_out = quantized_out + quantized
+ all_indices.append(indices)
+ all_losses.append(loss)
+
+ out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
+ return quantized_out, out_indices, out_losses
+
+ def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
+ residual = x
+ all_indices = []
+ n_q = n_q or len(self.layers)
+ for layer in self.layers[:n_q]:
+ indices = layer.encode(residual)
+ quantized = layer.decode(indices)
+ residual = residual - quantized
+ all_indices.append(indices)
+ out_indices = torch.stack(all_indices)
+ return out_indices
+
+ def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
+ quantized_out = torch.tensor(0.0, device=q_indices.device)
+ for i, indices in enumerate(q_indices):
+ layer = self.layers[i]
+ quantized = layer.decode(indices)
+ quantized_out = quantized_out + quantized
+ return quantized_out
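+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative sizes only): quantize a random latent with 4
+    # residual codebooks. Inputs follow the [B, D, T] layout expected when
+    # channels_last=False (the default). eval() keeps the example away from the
+    # training-time EMA updates and distributed broadcasts.
+    rvq = ResidualVectorQuantization(num_quantizers=4, dim=32, codebook_size=256)
+    rvq.eval()
+    x = torch.randn(2, 32, 50)
+    quantized, codes, losses = rvq(x)
+    print(quantized.shape)  # torch.Size([2, 32, 50])
+    print(codes.shape)      # [n_q, B, T] -> torch.Size([4, 2, 50])
+    print(losses.shape)     # one commitment loss per quantizer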
diff --git a/audiocraft/quantization/vq.py b/audiocraft/quantization/vq.py
new file mode 100644
index 0000000..f67c3a0
--- /dev/null
+++ b/audiocraft/quantization/vq.py
@@ -0,0 +1,116 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import typing as tp
+
+import torch
+
+from .base import BaseQuantizer, QuantizedResult
+from .core_vq import ResidualVectorQuantization
+
+
+class ResidualVectorQuantizer(BaseQuantizer):
+ """Residual Vector Quantizer.
+
+ Args:
+ dimension (int): Dimension of the codebooks.
+ n_q (int): Number of residual vector quantizers used.
+ q_dropout (bool): Random quantizer drop out at train time.
+ bins (int): Codebook size.
+ decay (float): Decay for exponential moving average over the codebooks.
+ kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
+ kmeans_iters (int): Number of iterations used for kmeans initialization.
+ threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
+ that have an exponential moving average cluster size less than the specified threshold with
+ a randomly selected vector from the current batch.
+ orthogonal_reg_weight (float): Orthogonal regularization weights.
+ orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
+ orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
+ for orthogonal regularization.
+ """
+ def __init__(
+ self,
+ dimension: int = 256,
+ n_q: int = 8,
+ q_dropout: bool = False,
+ bins: int = 1024,
+ decay: float = 0.99,
+ kmeans_init: bool = True,
+ kmeans_iters: int = 10,
+ threshold_ema_dead_code: int = 2,
+ orthogonal_reg_weight: float = 0.0,
+ orthogonal_reg_active_codes_only: bool = False,
+ orthogonal_reg_max_codes: tp.Optional[int] = None,
+ ):
+ super().__init__()
+ self.max_n_q = n_q
+ self.n_q = n_q
+ self.q_dropout = q_dropout
+ self.dimension = dimension
+ self.bins = bins
+ self.decay = decay
+ self.kmeans_init = kmeans_init
+ self.kmeans_iters = kmeans_iters
+ self.threshold_ema_dead_code = threshold_ema_dead_code
+ self.orthogonal_reg_weight = orthogonal_reg_weight
+ self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
+ self.orthogonal_reg_max_codes = orthogonal_reg_max_codes
+ self.vq = ResidualVectorQuantization(
+ dim=self.dimension,
+ codebook_size=self.bins,
+ num_quantizers=self.n_q,
+ decay=self.decay,
+ kmeans_init=self.kmeans_init,
+ kmeans_iters=self.kmeans_iters,
+ threshold_ema_dead_code=self.threshold_ema_dead_code,
+ orthogonal_reg_weight=self.orthogonal_reg_weight,
+ orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only,
+ orthogonal_reg_max_codes=self.orthogonal_reg_max_codes,
+ channels_last=False
+ )
+
+ def forward(self, x: torch.Tensor, frame_rate: int):
+ n_q = self.n_q
+ if self.training and self.q_dropout:
+ n_q = int(torch.randint(1, self.n_q + 1, (1,)).item())
+ bw_per_q = math.log2(self.bins) * frame_rate / 1000
+ quantized, codes, commit_loss = self.vq(x, n_q=n_q)
+ codes = codes.transpose(0, 1)
+ # codes is [B, K, T], with T frames, K nb of codebooks.
+ bw = torch.tensor(n_q * bw_per_q).to(x)
+ return QuantizedResult(quantized, codes, bw, penalty=torch.mean(commit_loss))
+
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
+ """Encode a given input tensor with the specified frame rate at the given bandwidth.
+ The RVQ encode method sets the appropriate number of quantizers to use
+ and returns indices for each quantizer.
+ """
+ n_q = self.n_q
+ codes = self.vq.encode(x, n_q=n_q)
+ codes = codes.transpose(0, 1)
+ # codes is [B, K, T], with T frames, K nb of codebooks.
+ return codes
+
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
+ """Decode the given codes to the quantized representation.
+ """
+ # codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
+ codes = codes.transpose(0, 1)
+ quantized = self.vq.decode(codes)
+ return quantized
+
+ @property
+ def total_codebooks(self):
+ return self.max_n_q
+
+ @property
+ def num_codebooks(self):
+ return self.n_q
+
+ def set_num_codebooks(self, n: int):
+ assert n > 0 and n <= self.max_n_q
+ self.n_q = n
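+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative sizes only): encode a latent into discrete codes
+    # and decode it back. With an untrained codebook the reconstruction is meaningless;
+    # the point is the [B, D, T] -> [B, K, T] -> [B, D, T] shape contract.
+    # Bandwidth check: with bins=1024 and frame_rate=50 Hz, each quantizer costs
+    # log2(1024) * 50 / 1000 = 0.5 kb/s, so n_q=4 corresponds to 2.0 kb/s.
+    rvq = ResidualVectorQuantizer(dimension=32, n_q=4, bins=1024)
+    x = torch.randn(2, 32, 50)   # [B, D, T]
+    codes = rvq.encode(x)        # [B, K, T] with K == n_q
+    y = rvq.decode(codes)        # [B, D, T]
+    print(codes.shape, y.shape)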
diff --git a/audiocraft/utils/__init__.py b/audiocraft/utils/__init__.py
new file mode 100644
index 0000000..0952fcc
--- /dev/null
+++ b/audiocraft/utils/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/audiocraft/utils/autocast.py b/audiocraft/utils/autocast.py
new file mode 100644
index 0000000..ed64484
--- /dev/null
+++ b/audiocraft/utils/autocast.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+
+class TorchAutocast:
+ """TorchAutocast utility class.
+ Allows you to enable and disable autocast. This is specially useful
+ when dealing with different architectures and clusters with different
+ levels of support.
+
+ Args:
+ enabled (bool): Whether to enable torch.autocast or not.
+ args: Additional args for torch.autocast.
+ kwargs: Additional kwargs for torch.autocast
+ """
+ def __init__(self, enabled: bool, *args, **kwargs):
+ self.autocast = torch.autocast(*args, **kwargs) if enabled else None
+
+ def __enter__(self):
+ if self.autocast is None:
+ return
+ try:
+ self.autocast.__enter__()
+ except RuntimeError:
+ device = self.autocast.device
+ dtype = self.autocast.fast_dtype
+ raise RuntimeError(
+ f"There was an error autocasting with dtype={dtype} device={device}\n"
+ "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
+ )
+
+ def __exit__(self, *args, **kwargs):
+ if self.autocast is None:
+ return
+ self.autocast.__exit__(*args, **kwargs)
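+
+
+if __name__ == "__main__":
+    # Minimal usage sketch: run a matmul under CPU autocast when enabled, and as a plain
+    # no-op context when disabled (bfloat16 support depends on your PyTorch build).
+    x = torch.randn(8, 8)
+    with TorchAutocast(enabled=True, device_type='cpu', dtype=torch.bfloat16):
+        y = x @ x
+    print(y.dtype)  # typically torch.bfloat16 under CPU autocast
+    with TorchAutocast(enabled=False):
+        z = x @ x
+    print(z.dtype)  # torch.float32, autocast is bypassed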
diff --git a/audiocraft/utils/export.py b/audiocraft/utils/export.py
new file mode 100644
index 0000000..b513b52
--- /dev/null
+++ b/audiocraft/utils/export.py
@@ -0,0 +1,56 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Utility to export a training checkpoint to a lightweight release checkpoint.
+"""
+
+from pathlib import Path
+import typing as tp
+
+from omegaconf import OmegaConf, DictConfig
+import torch
+
+
+def _clean_lm_cfg(cfg: DictConfig):
+ OmegaConf.set_struct(cfg, False)
+ # This used to be set automatically in the LM solver, need a more robust solution
+ # for the future.
+ cfg['transformer_lm']['card'] = 2048
+ cfg['transformer_lm']['n_q'] = 4
+ # Experimental params no longer supported.
+ bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters',
+ 'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop']
+ for name in bad_params:
+ del cfg['transformer_lm'][name]
+ OmegaConf.set_struct(cfg, True)
+ return cfg
+
+
+def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
+ sig = Path(checkpoint_path).parent.name
+ assert len(sig) == 8, "Not a valid Dora signature"
+ pkg = torch.load(checkpoint_path, 'cpu')
+ new_pkg = {
+ 'best_state': pkg['ema']['state']['model'],
+ 'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
+ }
+ out_file = Path(out_folder) / f'{sig}.th'
+ torch.save(new_pkg, out_file)
+ return out_file
+
+
+def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
+ sig = Path(checkpoint_path).parent.name
+ assert len(sig) == 8, "Not a valid Dora signature"
+ pkg = torch.load(checkpoint_path, 'cpu')
+ new_pkg = {
+ 'best_state': pkg['fsdp_best_state']['model'],
+ 'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg']))
+ }
+ out_file = Path(out_folder) / f'{sig}.th'
+ torch.save(new_pkg, out_file)
+ return out_file
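+
+
+if __name__ == "__main__":
+    # Minimal usage sketch with a hypothetical checkpoint path (adjust to your own setup):
+    # the parent directory name is expected to be an 8-character Dora signature.
+    ckpt = Path('/checkpoints/abcd1234/checkpoint.th')
+    if ckpt.exists():
+        print('exported to', export_lm(ckpt, ckpt.parent))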
diff --git a/audiocraft/utils/notebook.py b/audiocraft/utils/notebook.py
new file mode 100644
index 0000000..019b9d1
--- /dev/null
+++ b/audiocraft/utils/notebook.py
@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+try:
+ import IPython.display as ipd # type: ignore
+except ImportError:
+ # Not in a notebook...
+ pass
+
+
+import torch
+
+
+def display_audio(samples: torch.Tensor, sample_rate: int):
+ """Renders an audio player for the given audio samples.
+
+ Args:
+ samples (torch.Tensor): a Tensor of decoded audio samples
+ with shapes [B, C, T] or [C, T]
+ sample_rate (int): sample rate audio should be displayed with.
+ """
+ assert samples.dim() == 2 or samples.dim() == 3
+
+ samples = samples.detach().cpu()
+ if samples.dim() == 2:
+ samples = samples[None, ...]
+
+ for audio in samples:
+ ipd.display(ipd.Audio(audio, rate=sample_rate))
diff --git a/audiocraft/utils/utils.py b/audiocraft/utils/utils.py
new file mode 100644
index 0000000..86e1448
--- /dev/null
+++ b/audiocraft/utils/utils.py
@@ -0,0 +1,234 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from concurrent.futures import ProcessPoolExecutor
+from functools import wraps
+import hashlib
+import logging
+import typing as tp
+
+import flashy
+import flashy.distrib
+import omegaconf
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+
+logger = logging.getLogger(__name__)
+
+
+def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
+ """Convenience function to map an omegaconf configuration to a dictionary.
+
+ Args:
+ cfg (omegaconf.DictConfig): Original configuration to map to dict.
+ Returns:
+ dict: Config as dictionary object.
+ """
+ dct = omegaconf.OmegaConf.to_container(cfg, resolve=True)
+ assert isinstance(dct, dict)
+ return dct
+
+
+def random_subset(dataset, max_samples: int, seed: int = 42) -> torch.utils.data.Subset:
+ if max_samples >= len(dataset):
+ return dataset
+
+ generator = torch.Generator().manual_seed(seed)
+ perm = torch.randperm(len(dataset), generator=generator)
+ return torch.utils.data.Subset(dataset, perm[:max_samples].tolist())
+
+
+def get_loader(dataset, num_samples: tp.Optional[int], batch_size: int,
+ num_workers: int, seed: int, **kwargs) -> torch.utils.data.DataLoader:
+ """Convenience function to load dataset into a dataloader with optional subset sampling.
+
+ Args:
+ dataset: Dataset to load.
+ num_samples (Optional[int]): Number of samples to limit subset size.
+ batch_size (int): Batch size.
+ num_workers (int): Number of workers for data loading.
+ seed (int): Random seed.
+ """
+ if num_samples is not None:
+ dataset = random_subset(dataset, num_samples, seed)
+
+ dataloader = flashy.distrib.loader(
+ dataset,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ **kwargs
+ )
+ return dataloader
+
+
+def get_dataset_from_loader(dataloader):
+ dataset = dataloader.dataset
+ if isinstance(dataset, torch.utils.data.Subset):
+ return dataset.dataset
+ else:
+ return dataset
+
+
+def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
+ """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
+
+ Args:
+ input (torch.Tensor): The input tensor containing probabilities.
+ num_samples (int): Number of samples to draw.
+ replacement (bool): Whether to draw with replacement or not.
+ Keywords args:
+ generator (torch.Generator): A pseudorandom number generator for sampling.
+ Returns:
+ torch.Tensor: Last dimension contains num_samples indices
+ sampled from the multinomial probability distribution
+ located in the last dimension of tensor input.
+ """
+ input_ = input.reshape(-1, input.shape[-1])
+ output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
+ output = output_.reshape(*list(input.shape[:-1]), -1)
+ return output
+
+
+def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
+ """Sample next token from top K values along the last dimension of the input probs tensor.
+
+ Args:
+ probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
+ k (int): The k in "top-k".
+ Returns:
+ torch.Tensor: Sampled tokens.
+ """
+ top_k_value, _ = torch.topk(probs, k, dim=-1)
+ min_value_top_k = top_k_value[..., [-1]]
+ probs *= (probs >= min_value_top_k).float()
+ probs.div_(probs.sum(dim=-1, keepdim=True))
+ next_token = multinomial(probs, num_samples=1)
+ return next_token
+
+
+def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
+ """Sample next token from top P probabilities along the last dimension of the input probs tensor.
+
+ Args:
+ probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
+ p (float): The p in "top-p".
+ Returns:
+ torch.Tensor: Sampled tokens.
+ """
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
+ mask = probs_sum - probs_sort > p
+ probs_sort *= (~mask).float()
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+ next_token = multinomial(probs_sort, num_samples=1)
+ next_token = torch.gather(probs_idx, -1, next_token)
+ return next_token
+
+
+class DummyPoolExecutor:
+ """Dummy pool executor to use when we actually have only 1 worker.
+ (e.g. instead of ProcessPoolExecutor).
+ """
+ class DummyResult:
+ def __init__(self, func, *args, **kwargs):
+ self.func = func
+ self.args = args
+ self.kwargs = kwargs
+
+ def result(self):
+ return self.func(*self.args, **self.kwargs)
+
+ def __init__(self, workers, mp_context=None):
+ pass
+
+ def submit(self, func, *args, **kwargs):
+ return DummyPoolExecutor.DummyResult(func, *args, **kwargs)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, exc_tb):
+ return
+
+
+def get_pool_executor(num_workers: int, mp_context=None):
+ return ProcessPoolExecutor(num_workers, mp_context) if num_workers > 1 else DummyPoolExecutor(1)
+
+
+def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> torch.Tensor:
+ """Utility function to convert a tensor of sequence lengths to a mask (useful when working on padded sequences).
+ For example: [3, 5] => [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
+
+ Args:
+ lengths (torch.Tensor): tensor with lengths
+ max_len (int): can set the max length manually. Defaults to None.
+ Returns:
+ torch.Tensor: mask with 0s where there are pad tokens, else 1s
+ """
+ assert len(lengths.shape) == 1, "Length shape should be 1 dimensional."
+ final_length = lengths.max().item() if not max_len else max_len
+ final_length = max(final_length, 1) # if all seqs are of len zero we don't want a zero-size tensor
+ return torch.arange(final_length)[None, :].to(lengths.device) < lengths[:, None]
+
+
+def hash_trick(word: str, vocab_size: int) -> int:
+ """Hash trick to pair each word with an index
+
+ Args:
+ word (str): word we wish to convert to an index
+ vocab_size (int): size of the vocabulary
+ Returns:
+ int: index of the word in the embedding LUT
+ """
+ hash = int(hashlib.sha256(word.encode("utf-8")).hexdigest(), 16)
+ return hash % vocab_size
+
+
+def with_rank_rng(base_seed: int = 1234):
+ """Decorator for a function so that the function will use a Random Number Generator
+ whose state depends on the GPU rank. The original RNG state is restored upon returning.
+
+ Args:
+ base_seed (int): Random seed.
+ """
+ def _decorator(fun: tp.Callable):
+ @wraps(fun)
+ def _decorated(*args, **kwargs):
+ state = torch.get_rng_state()
+ seed = base_seed ^ flashy.distrib.rank()
+ torch.manual_seed(seed)
+ logger.debug('Rank dependent seed set to %d', seed)
+ try:
+ return fun(*args, **kwargs)
+ finally:
+ torch.set_rng_state(state)
+ logger.debug('RNG state restored.')
+ return _decorated
+ return _decorator
+
+
+def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+ """Get a list of tensors and collate them to a single tensor. according to the following logic:
+ - `dim` specifies the time dimension which will be stacked and padded.
+ - The output will contain 1 new dimension (dimension index 0) which will be the size of
+ of the original list.
+
+ Args:
+ tensors (tp.List[torch.Tensor]): List of tensors to collate.
+ dim (int): Dimension which will be stacked and padded.
+ Returns:
+ tp.Tuple[torch.Tensor, torch.Tensor]:
+ torch.Tensor: Stacked and padded tensor. The output will contain 1 new dimension
+ (dimension index 0) which will be the size of the original list.
+ torch.Tensor: Tensor containing length of original tensor sizes (without padding).
+ """
+ tensors = [x.transpose(0, dim) for x in tensors]
+ lens = torch.LongTensor([len(x) for x in tensors])
+ padded_tensors = pad_sequence(tensors)
+ padded_tensors = padded_tensors.transpose(0, 1)
+ padded_tensors = padded_tensors.transpose(1, dim + 1)
+ return padded_tensors, lens
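+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative values only) for a few of the helpers above.
+    probs = torch.softmax(torch.randn(2, 4, 10), dim=-1)
+    tok_k = sample_top_k(probs.clone(), k=5)    # clone: sampling modifies probs in place
+    tok_p = sample_top_p(probs.clone(), p=0.9)
+    print(tok_k.shape, tok_p.shape)             # both torch.Size([2, 4, 1])
+
+    mask = length_to_mask(torch.tensor([3, 5]), max_len=5)
+    print(mask.int())                           # [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]
+
+    padded, lens = collate([torch.randn(3, 4), torch.randn(5, 4)], dim=0)
+    print(padded.shape, lens)                   # torch.Size([2, 5, 4]) tensor([3, 5])
+
+    print(hash_trick("music", vocab_size=1000))  # deterministic index in [0, 1000)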
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..aa3fa0d
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,20 @@
+# please make sure you already have a PyTorch install that is CUDA enabled!
+av
+einops
+flashy>=0.0.1
+hydra-core>=1.1
+hydra_colorlog
+julius
+num2words
+numpy
+sentencepiece
+spacy==3.5.2
+torch>=2.0.0
+torchaudio>=2.0.0
+huggingface_hub
+tqdm
+transformers
+xformers
+demucs
+librosa
+gradio