music-gen
Build-Deploy-Actions Details

This commit is contained in:
songw 2023-06-16 15:24:12 +08:00
parent e619938f16
commit 7b6c296b49
38 changed files with 7075 additions and 0 deletions

View File

@ -0,0 +1,47 @@
# Builds the repository's Docker image and pushes it to the private registry on every push.
name: Build
run-name: ${{ github.actor }} is upgrade release 🚀
on: [push]
env:
  REPOSITORY: ${{ github.repository }}
  COMMIT_ID: ${{ github.sha }}
jobs:
  Build-Deploy-Actions:
    runs-on: ubuntu-latest
    steps:
      - run: echo "🎉 The job was automatically triggered by a ${{ github.event_name }} event."
      - run: echo "🐧 This job is now running on a ${{ runner.os }} server hosted by Gitea!"
      - run: echo "🔎 The name of your branch is ${{ github.ref }} and your repository is ${{ github.repository }}."
      - name: Check out repository code
        uses: actions/checkout@v3
      - name: Setup Git LFS
        run: |
          git lfs install
          git lfs fetch
          git lfs checkout
      - name: List files in the repository
        run: |
          ls ${{ github.workspace }}
      - name: Docker Image Info
        id: image-info
        # Step outputs are written to $GITHUB_OUTPUT; the `::set-output` workflow command is deprecated.
        # image_name: lower-cased repository name; image_tag: first 10 characters of the commit SHA.
        run: |
          echo "image_name=$(echo "$REPOSITORY" | tr '[:upper:]' '[:lower:]')" >> "$GITHUB_OUTPUT"
          echo "image_tag=${COMMIT_ID:0:10}" >> "$GITHUB_OUTPUT"
      - name: Login to Docker Hub
        uses: docker/login-action@v2
        with:
          registry: artifacts.iflytek.com
          username: ${{ secrets.DOCKERHUB_USERNAME }}
          password: ${{ secrets.DOCKERHUB_TOKEN }}
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v2
      - name: Build and push
        # Build into the local image store (--load), push, then remove the local copy.
        run: |
          docker version
          docker buildx build -t artifacts.iflytek.com/docker-private/atp/${{ steps.image-info.outputs.image_name }}:${{ steps.image-info.outputs.image_tag }} . --file ${{ github.workspace }}/Dockerfile --load
          docker push artifacts.iflytek.com/docker-private/atp/${{ steps.image-info.outputs.image_name }}:${{ steps.image-info.outputs.image_tag }}
          docker rmi artifacts.iflytek.com/docker-private/atp/${{ steps.image-info.outputs.image_name }}:${{ steps.image-info.outputs.image_tag }}
      - run: echo "🍏 This job's status is ${{ job.status }}."

12
Dockerfile Normal file
View File

@ -0,0 +1,12 @@
# Image that runs the Gradio application (app.py).
FROM python:3.8.13
WORKDIR /app
# Refresh the package index, upgrade and install ffmpeg in ONE layer: a separately cached
# `update` layer could otherwise be combined with a later install against stale package lists.
# The layer comes before COPY so source changes do not invalidate it; the lists are removed
# afterwards to keep the layer small.
RUN apt-get -y update \
    && apt-get -y upgrade \
    && apt-get -y install ffmpeg \
    && rm -rf /var/lib/apt/lists/*
COPY . /app
RUN pip config set global.index-url https://pypi.mirrors.ustc.edu.cn/simple
# --no-cache-dir keeps pip's download cache out of the image.
RUN pip install --no-cache-dir -r requirements.txt
CMD ["python", "app.py"]

304
app.py Normal file
View File

@ -0,0 +1,304 @@
import argparse
from concurrent.futures import ProcessPoolExecutor
import os
import subprocess as sp
from tempfile import NamedTemporaryFile
import time
import warnings
import torch
import gradio as gr
from audiocraft.data.audio_utils import convert_audio
from audiocraft.data.audio import audio_write
from audiocraft.models import MusicGen
from gradio.themes.utils import sizes
# Gradio theme shared by both UIs: the default theme with `radius_none` corners and a
# '#4D63FF' accent on block labels, block titles and primary buttons.
theme = gr.themes.Default(radius_size=sizes.radius_none).set(
    block_label_text_color = '#4D63FF',
    block_title_text_color = '#4D63FF',
    button_primary_text_color = '#4D63FF',
    button_primary_background_fill='#FFFFFF',
    button_primary_border_color='#4D63FF',
    button_primary_background_fill_hover='#EDEFFF',
)
MODEL = None # Last used model
# True when the SPACE_ID environment variable contains "facebook/MusicGen"; selects the batched UI.
IS_BATCHED = "facebook/MusicGen" in os.environ.get('SPACE_ID', '')
# Passed as `max_batch_size` to the batched UI's submit handler.
MAX_BATCH_SIZE = 12
# Duration passed to `_do_predictions` by `predict_batched`.
BATCHED_DURATION = 15
# Set to True by `interrupt()`, reset to False at the start of `predict_full`.
INTERRUPTING = False
# We have to wrap the subprocess call to clean up the log a bit when using gr.make_waveform.
_old_call = sp.call


def _call_nostderr(*args, **kwargs):
    """Drop-in replacement for ``subprocess.call`` that discards the child's output.

    Both ``stdout`` and ``stderr`` are forced to ``DEVNULL``, overriding whatever the
    caller passed for them; all other arguments are forwarded unchanged.

    Returns:
        The return code produced by the original ``subprocess.call``, so callers that
        check the exit status keep working.
    """
    # Avoid ffmpeg vomiting on the logs.
    kwargs['stderr'] = sp.DEVNULL
    kwargs['stdout'] = sp.DEVNULL
    return _old_call(*args, **kwargs)


sp.call = _call_nostderr
# Preallocating the pool of processes (4 workers); `_do_predictions` submits `make_waveform` to it.
# NOTE(review): the executor is entered manually and no matching __exit__/shutdown is visible
# in this file — confirm that relying on interpreter exit is intended.
pool = ProcessPoolExecutor(4)
pool.__enter__()
def interrupt():
    """Request that the generation in progress be stopped.

    Only raises the module-level ``INTERRUPTING`` flag; the progress callback installed
    by ``predict_full`` checks the flag and aborts the generation.
    """
    global INTERRUPTING
    INTERRUPTING = True
def make_waveform(*args, **kwargs):
    """Call ``gr.make_waveform`` with warnings silenced and print how long it took.

    All arguments are forwarded unchanged; the value returned by ``gr.make_waveform``
    is returned as is.
    """
    started_at = time.time()
    # Further remove some warnings.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        video = gr.make_waveform(*args, **kwargs)
        elapsed = time.time() - started_at
        print("Make a video took", elapsed)
        return video
def load_model(version='melody'):
    """Ensure the module-level ``MODEL`` holds the requested MusicGen version.

    The pretrained model is fetched only when no model is loaded yet or when the loaded
    model's ``name`` differs from ``version``.
    """
    global MODEL
    print("Loading model", version)
    already_loaded = MODEL is not None and MODEL.name == version
    if not already_loaded:
        MODEL = MusicGen.get_pretrained(version)
def _do_predictions(texts, melodies, duration, progress=False, **gen_kwargs):
    """Generate one clip per text prompt and return one waveform video per clip.

    Args:
        texts: list of text descriptions.
        melodies: list aligned with `texts`; each entry is None or a `(sample_rate, array)`
            pair (the array is converted with `torch.from_numpy`).
        duration: generation duration, also used to trim each melody.
        progress: forwarded to the model's generate call.
        **gen_kwargs: extra parameters forwarded to `MODEL.set_generation_params`.
    Returns:
        list with the results of `make_waveform`, one per generated output.
    """
    MODEL.set_generation_params(duration=duration, **gen_kwargs)
    print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
    be = time.time()
    processed_melodies = []
    # Sample rate and channel count every melody is converted to before conditioning.
    target_sr = 32000
    target_ac = 1
    for melody in melodies:
        if melody is None:
            processed_melodies.append(None)
        else:
            # Transposed after loading — presumably (samples, channels) -> (channels, samples); TODO confirm.
            sr, melody = melody[0], torch.from_numpy(melody[1]).to(MODEL.device).float().t()
            if melody.dim() == 1:
                # 1-D input: add a leading dimension.
                melody = melody[None]
            # Keep at most `duration` seconds of the melody.
            melody = melody[..., :int(sr * duration)]
            melody = convert_audio(melody, sr, target_sr, target_ac)
            processed_melodies.append(melody)

    # Use melody conditioning as soon as at least one prompt comes with a melody.
    if any(m is not None for m in processed_melodies):
        outputs = MODEL.generate_with_chroma(
            descriptions=texts,
            melody_wavs=processed_melodies,
            melody_sample_rate=target_sr,
            progress=progress,
        )
    else:
        outputs = MODEL.generate(texts, progress=progress)

    outputs = outputs.detach().cpu().float()
    out_files = []
    for output in outputs:
        # delete=False: the file must outlive this `with` block because the pool worker reads it.
        # NOTE(review): no cleanup of these temporary .wav files is visible here.
        with NamedTemporaryFile("wb", suffix=".wav", delete=False) as file:
            audio_write(
                file.name, output, MODEL.sample_rate, strategy="loudness",
                loudness_headroom_db=16, loudness_compressor=True, add_suffix=False)
            out_files.append(pool.submit(make_waveform, file.name))
    # Block until every waveform rendering submitted to the pool has finished.
    res = [out_file.result() for out_file in out_files]
    print("batch finished", len(texts), time.time() - be)
    return res
def predict_batched(texts, melodies):
    """Batched prediction entry point used by ``ui_batched``.

    Each prompt is truncated to 512 characters, the 'melody' model is loaded and the whole
    batch is generated with ``BATCHED_DURATION``. The list of results is returned wrapped
    in a one-element list.
    """
    max_text_length = 512
    truncated_texts = [prompt[:max_text_length] for prompt in texts]
    load_model('melody')
    videos = _do_predictions(truncated_texts, melodies, BATCHED_DURATION)
    return [videos]
def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coef, progress=gr.Progress()):
    """Single-request prediction entry point used by ``ui_full``.

    Resets the interruption flag, validates the sampling parameters, loads the requested
    model and generates one clip, reporting progress through ``progress``.

    Raises:
        gr.Error: if ``temperature``, ``topk`` or ``topp`` is negative, or (from the
            progress callback) when ``interrupt()`` was called during generation.
    """
    global INTERRUPTING
    INTERRUPTING = False

    # Checked in this order; the first negative value aborts the request.
    non_negative = (
        (temperature, "Temperature must be >= 0."),
        (topk, "Topk must be non-negative."),
        (topp, "Topp must be non-negative."),
    )
    for value, message in non_negative:
        if value < 0:
            raise gr.Error(message)

    topk = int(topk)
    load_model(model)

    def _progress(generated, to_generate):
        # Report progress first, then honour a pending interruption request.
        progress((generated, to_generate))
        if INTERRUPTING:
            raise gr.Error("Interrupted.")

    MODEL.set_custom_progress_callback(_progress)

    results = _do_predictions(
        [text], [melody], duration, progress=True,
        top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
    return results[0]
def ui_full(launch_kwargs):
    """Build the full (non-batched) Gradio interface and launch it.

    Args:
        launch_kwargs (dict): keyword arguments forwarded to `launch()`.
    """
    with gr.Blocks(theme=theme, css="footer {visibility: hidden}") as interface:
        gr.Markdown(
            """
<div align='center' ><font size='60'>音乐生成</font></div>
"""
        )
        with gr.Row():
            # Left column: inputs and generation settings.
            with gr.Column():
                with gr.Row():
                    text = gr.Text(label="输入文本", interactive=True)
                    melody = gr.Audio(source="upload", type="numpy", label="旋律(可选)", interactive=True)
                with gr.Row():
                    submit = gr.Button("Submit")
                    # Adapted from https://github.com/rkfg/audiocraft/blob/long/app.py, MIT license.
                    # queue=False on the interrupt button's click handler.
                    _ = gr.Button("中断").click(fn=interrupt, queue=False)
                with gr.Row():
                    model = gr.Radio(["melody", "medium", "small", "large"], label="模型", value="melody", interactive=True)
                with gr.Row():
                    duration = gr.Slider(minimum=1, maximum=120, value=10, label="Duration", interactive=True)
                with gr.Row():
                    topk = gr.Number(label="Top-k", value=250, interactive=True)
                    topp = gr.Number(label="Top-p", value=0, interactive=True)
                    temperature = gr.Number(label="Temperature", value=1.0, interactive=True)
                    cfg_coef = gr.Number(label="Classifier Free Guidance", value=3.0, interactive=True)
            # Right column: the generated result, shown as a video.
            with gr.Column():
                output = gr.Video(label="生成的音乐")
        # The input order must match the positional parameters of `predict_full`.
        submit.click(predict_full, inputs=[model, text, melody, duration, topk, topp, temperature, cfg_coef], outputs=[output])
        # Each example row is [text, melody file or None, model], matching `inputs` below.
        gr.Examples(
            fn=predict_full,
            examples=[
                [
                    "An 80s driving pop song with heavy drums and synth pads in the background",
                    "./assets/bach.mp3",
                    "melody"
                ],
                [
                    "A cheerful country song with acoustic guitars",
                    "./assets/bolero_ravel.mp3",
                    "melody"
                ],
                [
                    "90s rock song with electric guitar and heavy drums",
                    None,
                    "medium"
                ],
                [
                    "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions",
                    "./assets/bach.mp3",
                    "melody"
                ],
                [
                    "lofi slow bpm electro chill with organic samples",
                    None,
                    "medium",
                ],
            ],
            inputs=[text, melody, model],
            outputs=[output],
            label="例子"
        )
    interface.queue().launch(**launch_kwargs)
def ui_batched(launch_kwargs):
    """Build the batched Gradio interface (text and optional melody only) and launch it.

    Args:
        launch_kwargs (dict): keyword arguments forwarded to `launch()`.
    """
    with gr.Blocks(theme=theme, css="footer {visibility: hidden}") as demo:
        gr.Markdown(
            """
<div align='center' ><font size='60'>音乐生成</font></div>
"""
        )
        with gr.Row():
            # Left column: inputs.
            with gr.Column():
                with gr.Row():
                    text = gr.Text(label="Describe your music", lines=2, interactive=True)
                    melody = gr.Audio(source="upload", type="numpy", label="Condition on a melody (optional)", interactive=True)
                with gr.Row():
                    submit = gr.Button("Generate")
            # Right column: the generated result, shown as a video.
            with gr.Column():
                output = gr.Video(label="Generated Music")
        # batch=True: the handler is registered as a batch function with up to MAX_BATCH_SIZE requests.
        submit.click(predict_batched, inputs=[text, melody], outputs=[output], batch=True, max_batch_size=MAX_BATCH_SIZE)
        # Each example row is [text, melody file or None], matching `inputs` below.
        gr.Examples(
            fn=predict_batched,
            examples=[
                [
                    "An 80s driving pop song with heavy drums and synth pads in the background",
                    "./assets/bach.mp3",
                ],
                [
                    "A cheerful country song with acoustic guitars",
                    "./assets/bolero_ravel.mp3",
                ],
                [
                    "90s rock song with electric guitar and heavy drums",
                    None,
                ],
                [
                    "a light and cheerly EDM track, with syncopated drums, aery pads, and strong emotions bpm: 130",
                    "./assets/bach.mp3",
                ],
                [
                    "lofi slow bpm electro chill with organic samples",
                    None,
                ],
            ],
            inputs=[text, melody],
            outputs=[output],
            label="例子"
        )
    # Request queue capped at 8 * 4 entries.
    demo.queue(max_size=8 * 4).launch(**launch_kwargs)
if __name__ == "__main__":
    # Command-line entry point: parse the launch options and start the matching Gradio UI.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--listen',
        type=str,
        # Same default whether or not SPACE_ID is set (both original branches were '0.0.0.0').
        default='0.0.0.0',
        help='IP to listen on for connections to Gradio',
    )
    parser.add_argument('--username', type=str, default='', help='Username for authentication')
    parser.add_argument('--password', type=str, default='', help='Password for authentication')
    parser.add_argument(
        '--server_port',
        type=int,
        default=0,
        help='Port to run the server listener on',
    )
    parser.add_argument('--inbrowser', action='store_true', help='Open in browser')
    parser.add_argument('--share', action='store_true', help='Share the gradio UI')
    args = parser.parse_args()

    launch_kwargs = {'server_name': args.listen}
    # Authentication is enabled only when both a username and a password are given.
    if args.username and args.password:
        launch_kwargs['auth'] = (args.username, args.password)
    # The remaining options are forwarded only when set (non-zero port, flags present).
    for option in ('server_port', 'inbrowser', 'share'):
        option_value = getattr(args, option)
        if option_value:
            launch_kwargs[option] = option_value

    # Show the interface
    launch_ui = ui_batched if IS_BATCHED else ui_full
    launch_ui(launch_kwargs)

BIN
assets/bach.mp3 Normal file

Binary file not shown.

BIN
assets/bolero_ravel.mp3 Normal file

Binary file not shown.

10
audiocraft/__init__.py Normal file
View File

@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# flake8: noqa
from . import data, modules, models
__version__ = '0.0.2a2'

View File

@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# flake8: noqa
from . import audio, audio_dataset

215
audiocraft/data/audio.py Normal file
View File

@ -0,0 +1,215 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Audio IO methods are defined in this module (info, read, write),
We rely on av library for faster read when possible, otherwise on torchaudio.
"""
from dataclasses import dataclass
from pathlib import Path
import logging
import typing as tp
import numpy as np
import soundfile
import torch
from torch.nn import functional as F
import torchaudio as ta
import av
from .audio_utils import f32_pcm, i16_pcm, normalize_audio
_av_initialized = False


def _init_av():
    """One-time setup before using PyAV: restrict the 'libav.mp3' logger to ERROR level.

    Subsequent calls are no-ops thanks to the module-level ``_av_initialized`` flag.
    """
    global _av_initialized
    if not _av_initialized:
        logging.getLogger('libav.mp3').setLevel(logging.ERROR)
        _av_initialized = True
@dataclass(frozen=True)
class AudioFileInfo:
    """Immutable summary of an audio file as returned by the `*_info` helpers."""
    sample_rate: int  # sample rate reported by the backend
    duration: float  # duration reported by the backend — presumably in seconds; TODO confirm
    channels: int  # number of audio channels
def _av_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
    """Read sample rate, duration and channel count of the first audio stream using PyAV.

    Args:
        filepath (str or Path): Path to the audio file.
    Returns:
        AudioFileInfo: Information on the first audio stream of the file.
    """
    _init_av()
    with av.open(str(filepath)) as container:
        audio_stream = container.streams.audio[0]
        return AudioFileInfo(
            audio_stream.codec_context.sample_rate,
            # Stream duration is expressed in `time_base` units.
            float(audio_stream.duration * audio_stream.time_base),
            audio_stream.channels,
        )
def _soundfile_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
    """Read sample rate, duration and channel count of a file using soundfile.

    Args:
        filepath (str or Path): Path to the audio file.
    Returns:
        AudioFileInfo: Information reported by ``soundfile.info``.
    """
    sf_info = soundfile.info(filepath)
    return AudioFileInfo(
        sample_rate=sf_info.samplerate,
        duration=sf_info.duration,
        channels=sf_info.channels,
    )
def audio_info(filepath: tp.Union[str, Path]) -> AudioFileInfo:
    """Return basic information on an audio file, picking the backend from its extension.

    Args:
        filepath (str or Path): Path to the audio file.
    Returns:
        AudioFileInfo: Sample rate, duration and number of channels.
    """
    # torchaudio no longer returns useful duration information for some formats like mp3s.
    filepath = Path(filepath)
    # The suffix is compared lower-cased so that e.g. '.FLAC' takes the same route as '.flac'
    # (`find_audio_files` also matches extensions case-insensitively).
    if filepath.suffix.lower() in ['.flac', '.ogg']:  # TODO: Validate .ogg can be safely read with av_info
        # ffmpeg has some weird issue with flac.
        return _soundfile_info(filepath)
    return _av_info(filepath)
def _av_read(filepath: tp.Union[str, Path], seek_time: float = 0, duration: float = -1.) -> tp.Tuple[torch.Tensor, int]:
    """FFMPEG-based audio file reading using PyAV bindings.
    Soundfile cannot read mp3 and av_read is more efficient than torchaudio.

    Args:
        filepath (str or Path): Path to audio file to read.
        seek_time (float): Time at which to start reading in the file.
        duration (float): Duration to read from the file. If set to -1, the whole file is read.
    Returns:
        Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate
    """
    _init_av()
    with av.open(str(filepath)) as af:
        stream = af.streams.audio[0]
        sr = stream.codec_context.sample_rate
        # Number of samples to return; -1 means "no limit".
        num_frames = int(sr * duration) if duration >= 0 else -1
        # Index of the first sample to keep.
        frame_offset = int(sr * seek_time)
        # we need a small negative offset otherwise we get some edge artifact
        # from the mp3 decoder.
        af.seek(int(max(0, (seek_time - 0.1)) / stream.time_base), stream=stream)
        frames = []
        length = 0
        for frame in af.decode(streams=stream.index):
            # Sample index at which this decoded frame starts.
            current_offset = int(frame.rate * frame.pts * frame.time_base)
            # Samples of this frame lying before the requested start are dropped.
            strip = max(0, frame_offset - current_offset)
            buf = torch.from_numpy(frame.to_ndarray())
            if buf.shape[0] != stream.channels:
                # First dimension is not the channel count: reshape to one row per channel
                # (presumably an interleaved/packed layout — TODO confirm).
                buf = buf.view(-1, stream.channels).t()
            buf = buf[:, strip:]
            frames.append(buf)
            length += buf.shape[1]
            # Stop decoding once enough samples were collected.
            if num_frames > 0 and length >= num_frames:
                break
        assert frames
        # If the above assert fails, it is likely because we seeked past the end of file point,
        # in which case ffmpeg returns a single frame with only zeros, and a weird timestamp.
        # This will need proper debugging, in due time.
        wav = torch.cat(frames, dim=1)
        assert wav.shape[0] == stream.channels
        if num_frames > 0:
            # Trim the overshoot of the last decoded frame.
            wav = wav[:, :num_frames]
        return f32_pcm(wav), sr
def audio_read(filepath: tp.Union[str, Path], seek_time: float = 0.,
               duration: float = -1., pad: bool = False) -> tp.Tuple[torch.Tensor, int]:
    """Read audio by picking the most appropriate backend tool based on the audio format.

    Args:
        filepath (str or Path): Path to audio file to read.
        seek_time (float): Time at which to start reading in the file.
        duration (float): Duration to read from the file. If set to -1, the whole file is read.
        pad (bool): Pad output audio if not reaching expected duration.
    Returns:
        Tuple[torch.Tensor, int]: Tuple containing audio data and sample rate.
    """
    fp = Path(filepath)
    # The flac/ogg test is case-insensitive so that e.g. '.FLAC' is not sent to ffmpeg
    # (`find_audio_files` also matches extensions case-insensitively).
    if fp.suffix.lower() in ['.flac', '.ogg']:  # TODO: check if we can safely use av_read for .ogg
        # There is some bug with ffmpeg and reading flac
        info = _soundfile_info(filepath)
        # A non-positive duration means "read everything" (-1 frames for soundfile).
        frames = -1 if duration <= 0 else int(duration * info.sample_rate)
        frame_offset = int(seek_time * info.sample_rate)
        wav, sr = soundfile.read(filepath, start=frame_offset, frames=frames, dtype=np.float32)
        assert info.sample_rate == sr, f"Mismatch of sample rates {info.sample_rate} {sr}"
        wav = torch.from_numpy(wav).t().contiguous()
        if len(wav.shape) == 1:
            # 1-D result: add the leading dimension.
            wav = torch.unsqueeze(wav, 0)
    elif (
        fp.suffix in ['.wav', '.mp3'] and fp.suffix[1:] in ta.utils.sox_utils.list_read_formats()
        and duration <= 0 and seek_time == 0
    ):
        # Torchaudio is faster if we load an entire file at once.
        wav, sr = ta.load(fp)
    else:
        wav, sr = _av_read(filepath, seek_time, duration)
    if pad and duration > 0:
        expected_frames = int(duration * sr)
        wav = F.pad(wav, (0, expected_frames - wav.shape[-1]))
    return wav, sr
def audio_write(stem_name: tp.Union[str, Path],
                wav: torch.Tensor, sample_rate: int,
                format: str = 'wav', mp3_rate: int = 320, normalize: bool = True,
                strategy: str = 'peak', peak_clip_headroom_db: float = 1,
                rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
                loudness_compressor: bool = False,
                log_clipping: bool = True, make_parent_dir: bool = True,
                add_suffix: bool = True) -> Path:
    """Convenience function for saving audio to disk. Returns the filename the audio was written to.

    Args:
        stem_name (str or Path): Filename without extension which will be added automatically.
        wav (torch.Tensor): Floating point audio with at most 2 dimensions; a 1-D tensor
            gets a leading dimension added.
        sample_rate (int): Sample rate of `wav`.
        format (str): Either "wav" or "mp3".
        mp3_rate (int): kbps when using mp3s.
        normalize (bool): if `True` (default), normalizes according to the prescribed
            strategy (see after). If `False`, the strategy is only used in case clipping
            would happen.
        strategy (str): Can be either 'clip', 'peak', or 'rms'. Default is 'peak',
            i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
            with extra headroom to avoid clipping. 'clip' just clips.
        peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
        rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
            than the `peak_clip` one to avoid further clipping.
        loudness_headroom_db (float): Target loudness for loudness normalization.
        loudness_compressor (bool): Uses tanh for soft clipping when strategy is 'loudness'.
        log_clipping (bool): If True, basic logging on stderr when clipping still
            occurs despite strategy (only for 'rms').
        make_parent_dir (bool): Make parent directory if it doesn't exist.
        add_suffix (bool): If False, no extension is appended to `stem_name`.
    Returns:
        Path: Path of the saved audio.
    """
    assert wav.dtype.is_floating_point, "wav is not floating point"
    if wav.dim() == 1:
        wav = wav[None]
    elif wav.dim() > 2:
        raise ValueError("Input wav should be at most 2 dimension.")
    assert wav.isfinite().all()
    # NOTE(review): `loudness_compressor` is accepted above but not forwarded to
    # `normalize_audio` here — confirm against `normalize_audio`'s signature whether it should be.
    wav = normalize_audio(wav, normalize, strategy, peak_clip_headroom_db,
                          rms_headroom_db, loudness_headroom_db, log_clipping=log_clipping,
                          sample_rate=sample_rate, stem_name=str(stem_name))
    # Format-specific keyword arguments for `ta.save`.
    kwargs: dict = {}
    if format == 'mp3':
        suffix = '.mp3'
        kwargs.update({"compression": mp3_rate})
    elif format == 'wav':
        wav = i16_pcm(wav)
        suffix = '.wav'
        kwargs.update({"encoding": "PCM_S", "bits_per_sample": 16})
    else:
        raise RuntimeError(f"Invalid format {format}. Only wav or mp3 are supported.")
    if not add_suffix:
        suffix = ''
    path = Path(str(stem_name) + suffix)
    if make_parent_dir:
        path.parent.mkdir(exist_ok=True, parents=True)
    try:
        ta.save(path, wav, sample_rate, **kwargs)
    except Exception:
        if path.exists():
            # we do not want to leave half written files around.
            path.unlink()
        raise
    return path

View File

@ -0,0 +1,525 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import copy
from concurrent.futures import ThreadPoolExecutor, Future
from dataclasses import dataclass, fields
from contextlib import ExitStack
import gzip
import json
import logging
import os
from pathlib import Path
import random
import sys
import typing as tp
import torch
import torch.nn.functional as F
from .audio import audio_read, audio_info
from .audio_utils import convert_audio
from .zip import PathInZip
try:
import dora
except ImportError:
dora = None # type: ignore
@dataclass(order=True)
class BaseInfo:
    """Base dataclass providing dictionary conversion helpers for its subclasses."""

    @classmethod
    def _dict2fields(cls, dictionary: dict):
        """Return the entries of `dictionary` whose key is a declared field of `cls`."""
        selected = {}
        for declared in fields(cls):
            key = declared.name
            if key in dictionary:
                selected[key] = dictionary[key]
        return selected

    @classmethod
    def from_dict(cls, dictionary: dict):
        """Build an instance from `dictionary`, ignoring keys that are not fields."""
        return cls(**cls._dict2fields(dictionary))

    def to_dict(self):
        """Return a dict mapping every field name to its current value."""
        return {declared.name: getattr(self, declared.name) for declared in fields(self)}
@dataclass(order=True)
class AudioMeta(BaseInfo):
    """Metadata of one audio file, convertible to and from plain dictionaries."""
    path: str
    duration: float
    sample_rate: int
    amplitude: tp.Optional[float] = None
    weight: tp.Optional[float] = None
    # info_path is used to load additional information about the audio file that is stored in zip files.
    info_path: tp.Optional[PathInZip] = None

    @classmethod
    def from_dict(cls, dictionary: dict):
        """Build an AudioMeta from `dictionary`, wrapping a non-None 'info_path' in PathInZip."""
        kwargs = cls._dict2fields(dictionary)
        raw_info_path = kwargs.get('info_path')
        if raw_info_path is not None:
            kwargs['info_path'] = PathInZip(raw_info_path)
        return cls(**kwargs)

    def to_dict(self):
        """Return the fields as a dict, with a non-None 'info_path' converted to `str`."""
        as_dict = super().to_dict()
        info_path = as_dict['info_path']
        if info_path is not None:
            as_dict['info_path'] = str(info_path)
        return as_dict
@dataclass(order=True)
class SegmentInfo(BaseInfo):
    """Describes one audio segment returned by the dataset together with its source file."""
    meta: AudioMeta  # metadata of the file the segment was read from
    seek_time: float  # position in the file at which reading started
    n_frames: int  # actual number of frames without padding
    total_frames: int  # total number of frames, padding included
    sample_rate: int  # actual sample rate
# File extensions (lower case, with the dot) that `find_audio_files` collects by default.
DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']
# Module-level logger.
logger = logging.getLogger(__name__)
def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta:
    """Build the AudioMeta of a single audio file.

    Args:
        file_path (str): Resolved path of valid audio file.
        minimal (bool): When True only the file information is read; when False the whole
            file is also loaded to measure its peak absolute amplitude (slower).
    Returns:
        AudioMeta: Audio file path and its metadata.
    """
    info = audio_info(file_path)
    if minimal:
        peak: tp.Optional[float] = None
    else:
        samples, _ = audio_read(file_path)
        peak = samples.abs().max().item()
    return AudioMeta(file_path, info.duration, info.sample_rate, peak)
def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
    """If Dora is available as a dependency, try to resolve potential relative paths
    in list of AudioMeta. This method is expected to be used when loading meta from file.

    Args:
        m (AudioMeta): Audio meta to resolve. It is modified in place.
        fast (bool): If True, uses a really fast check for determining if a file is already absolute or not.
            Only valid on Linux/Mac.
    Returns:
        AudioMeta: Audio meta with resolved path.
    """
    def is_abs(candidate):
        # Both branches must return the answer: the slow branch used to fall through and
        # return None, which made every path look relative.
        if fast:
            # `startswith` also copes with an empty string, unlike indexing `[0]`.
            return str(candidate).startswith('/')
        return os.path.isabs(str(candidate))

    if not dora:
        # Without Dora there is nothing to resolve against.
        return m

    if not is_abs(m.path):
        m.path = dora.git_save.to_absolute_path(m.path)
    if m.info_path is not None and not is_abs(m.info_path.zip_path):
        # Resolve the zip path itself; it was previously overwritten with the audio file path.
        m.info_path.zip_path = dora.git_save.to_absolute_path(m.info_path.zip_path)
    return m
def find_audio_files(path: tp.Union[Path, str],
                     exts: tp.List[str] = DEFAULT_EXTS,
                     resolve: bool = True,
                     minimal: bool = True,
                     progress: bool = False,
                     workers: int = 0) -> tp.List[AudioMeta]:
    """Build a list of AudioMeta from a given path,
    collecting relevant audio files and fetching meta info.

    Args:
        path (str or Path): Path to folder containing audio files.
        exts (list of str): List of file extensions to consider for audio files.
        resolve (bool): Whether to pass each AudioMeta through `_resolve_audio_meta`.
        minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
        progress (bool): Whether to log progress on audio files collection.
        workers (int): number of parallel workers, if 0, use only the current thread.
    Returns:
        List[AudioMeta]: List of audio file path and its metadata.
    """
    audio_files = []
    futures: tp.List[Future] = []
    pool: tp.Optional[ThreadPoolExecutor] = None
    # The ExitStack lets the thread pool be entered only when workers were requested.
    with ExitStack() as stack:
        if workers > 0:
            pool = ThreadPoolExecutor(workers)
            stack.enter_context(pool)

        if progress:
            print("Finding audio files...")
        for root, folders, files in os.walk(path, followlinks=True):
            for file in files:
                full_path = Path(root) / file
                # Extension match is case-insensitive on the file side.
                if full_path.suffix.lower() in exts:
                    audio_files.append(full_path)
                    if pool is not None:
                        # futures[i] corresponds to audio_files[i].
                        futures.append(pool.submit(_get_audio_meta, str(audio_files[-1]), minimal))
                    if progress:
                        print(format(len(audio_files), " 8d"), end='\r', file=sys.stderr)

        if progress:
            print("Getting audio metadata...")
        meta: tp.List[AudioMeta] = []
        for idx, file_path in enumerate(audio_files):
            try:
                if pool is None:
                    m = _get_audio_meta(str(file_path), minimal)
                else:
                    m = futures[idx].result()
                if resolve:
                    m = _resolve_audio_meta(m)
            except Exception as err:
                # Best effort: a file whose metadata cannot be read is reported and skipped.
                print("Error with", str(file_path), err, file=sys.stderr)
                continue
            meta.append(m)
            if progress:
                print(format((1 + idx) / len(audio_files), " 3.1%"), end='\r', file=sys.stderr)
    meta.sort()
    return meta
def load_audio_meta(path: tp.Union[str, Path],
                    resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]:
    """Load a list of AudioMeta from a file with one JSON document per line.

    The file is read through gzip when its name ends with '.gz' (case-insensitive).

    Args:
        path (str or Path): Path to JSON file.
        resolve (bool): Whether to resolve the path from AudioMeta (default=True).
        fast (bool): Forwarded to `_resolve_audio_meta`.
    Returns:
        List[AudioMeta]: One AudioMeta per line of the file, in file order.
    """
    is_gzipped = str(path).lower().endswith('.gz')
    opener = gzip.open if is_gzipped else open
    with opener(path, 'rb') as fp:  # type: ignore
        raw_lines = fp.readlines()
    entries = []
    for raw_line in raw_lines:
        entry = AudioMeta.from_dict(json.loads(raw_line))
        if resolve:
            entry = _resolve_audio_meta(entry, fast=fast)
        entries.append(entry)
    return entries
def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]):
    """Write the audio metadata to `path`, one JSON document per line (UTF-8).

    The parent directory is created if needed and the file is gzip-compressed when its
    name ends with '.gz' (case-insensitive).

    Args:
        path (str or Path): Path to JSON file.
        meta (list of AudioMeta): List of audio meta to save.
    """
    Path(path).parent.mkdir(exist_ok=True, parents=True)
    is_gzipped = str(path).lower().endswith('.gz')
    opener = gzip.open if is_gzipped else open
    with opener(path, 'wb') as fp:  # type: ignore
        for entry in meta:
            serialized = json.dumps(entry.to_dict()) + '\n'
            fp.write(serialized.encode('utf-8'))
class AudioDataset:
"""Base audio dataset.
The dataset takes a list of AudioMeta and create a dataset composed of segments of audio
and potentially additional information, by creating random segments from the list of audio
files referenced in the metadata and applying minimal data pre-processing such as resampling,
mixing of channels, padding, etc.
If no segment_duration value is provided, the AudioDataset will return the full wav for each
audio file. Otherwise, it will randomly sample audio files and create a segment of the specified
duration, applying padding if required.
By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True
allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
original audio meta.
Args:
meta (tp.List[AudioMeta]): List of audio files metadata.
segment_duration (float): Optional segment duration of audio to load.
If not specified, the dataset will load the full audio segment from the file.
shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
sample_rate (int): Target sample rate of the loaded audio samples.
channels (int): Target number of channels of the loaded audio samples.
sample_on_duration (bool): Set to `True` to sample segments with probability
dependent on audio file duration. This is only used if `segment_duration` is provided.
sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of
`AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product
of the file duration and file weight. This is only used if `segment_duration` is provided.
min_segment_ratio (float): Minimum segment ratio to use when the audio file
is shorter than the desired segment.
max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
min_audio_duration (tp.Optional[float], optional): Minimum audio file duration, in seconds, if provided
audio shorter than this will be filtered out.
max_audio_duration (tp.Optional[float], optional): Maximal audio file duration in seconds, if provided
audio longer than this will be filtered out.
"""
def __init__(self,
             meta: tp.List[AudioMeta],
             segment_duration: tp.Optional[float] = None,
             shuffle: bool = True,
             num_samples: int = 10_000,
             sample_rate: int = 48_000,
             channels: int = 2,
             pad: bool = True,
             sample_on_duration: bool = True,
             sample_on_weight: bool = True,
             min_segment_ratio: float = 0.5,
             max_read_retry: int = 10,
             return_info: bool = False,
             min_audio_duration: tp.Optional[float] = None,
             max_audio_duration: tp.Optional[float] = None
             ):
    # Validates the arguments, filters `meta` by duration, then stores the configuration
    # and precomputes the per-file sampling probabilities.
    assert len(meta) > 0, 'No audio meta provided to AudioDataset. Please check loading of audio meta.'
    assert segment_duration is None or segment_duration > 0
    assert segment_duration is None or min_segment_ratio >= 0
    logging.debug(f'sample_on_duration: {sample_on_duration}')
    logging.debug(f'sample_on_weight: {sample_on_weight}')
    logging.debug(f'pad: {pad}')
    logging.debug(f'min_segment_ratio: {min_segment_ratio}')

    self.segment_duration = segment_duration
    self.min_segment_ratio = min_segment_ratio
    # The duration bounds must be set before `_filter_duration` is called below.
    self.max_audio_duration = max_audio_duration
    self.min_audio_duration = min_audio_duration
    if self.min_audio_duration is not None and self.max_audio_duration is not None:
        assert self.min_audio_duration <= self.max_audio_duration
    self.meta: tp.List[AudioMeta] = self._filter_duration(meta)
    assert len(self.meta)  # Fail fast if all data has been filtered.
    self.total_duration = sum(d.duration for d in self.meta)

    if segment_duration is None:
        # Without segments, one sample per file: `num_samples` is overridden.
        num_samples = len(self.meta)
    self.num_samples = num_samples
    self.shuffle = shuffle
    self.sample_rate = sample_rate
    self.channels = channels
    self.pad = pad
    self.sample_on_weight = sample_on_weight
    self.sample_on_duration = sample_on_duration
    # Depends on `meta`, `sample_on_weight` and `sample_on_duration` set above.
    self.sampling_probabilities = self._get_sampling_probabilities()
    self.max_read_retry = max_read_retry
    self.return_info = return_info
def __len__(self):
    """Return `num_samples` (set in `__init__`; equals the number of files when no segment duration is given)."""
    return self.num_samples
def _get_sampling_probabilities(self, normalized: bool = True):
    """Return the sampling probabilities for each file inside `self.meta`.

    Each file starts with a score of 1, multiplied by its weight (when `sample_on_weight`
    is set and the weight is not None) and by its duration (when `sample_on_duration` is set).
    With `normalized=True` the scores are divided by their sum.
    """
    scores: tp.List[float] = []
    for entry in self.meta:
        use_weight = self.sample_on_weight and entry.weight is not None
        weight_factor = entry.weight if use_weight else 1.
        duration_factor = entry.duration if self.sample_on_duration else 1.
        scores.append(1. * weight_factor * duration_factor)
    probabilities = torch.tensor(scores)
    if not normalized:
        return probabilities
    probabilities /= probabilities.sum()
    return probabilities
def sample_file(self, rng: torch.Generator) -> AudioMeta:
    """Sample a given file from `self.meta`. Can be overridden in subclasses.
    This is only called if `segment_duration` is not None.

    You must use the provided random number generator `rng` for reproducibility.
    """
    weighted = self.sample_on_weight or self.sample_on_duration
    if weighted:
        # Draw according to the precomputed per-file probabilities.
        drawn = torch.multinomial(self.sampling_probabilities, 1, generator=rng)
    else:
        # Uniform draw over all files.
        drawn = torch.randint(len(self.sampling_probabilities), (1,), generator=rng)
    return self.meta[int(drawn.item())]
def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
    """Return the sample at `index`.

    Without `segment_duration`, the whole file `self.meta[index]` is read and converted.
    Otherwise a file and a seek position are drawn at random and a segment is read,
    retrying up to `max_read_retry` times on read errors.

    Returns:
        The wav tensor, or a `(wav, SegmentInfo)` tuple when `return_info` is set.
    """
    if self.segment_duration is None:
        file_meta = self.meta[index]
        out, sr = audio_read(file_meta.path)
        out = convert_audio(out, sr, self.sample_rate, self.channels)
        n_frames = out.shape[-1]
        segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
                                   sample_rate=self.sample_rate)
    else:
        rng = torch.Generator()
        if self.shuffle:
            # We use index, plus extra randomness
            rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
        else:
            # We only use index
            rng.manual_seed(index)

        for retry in range(self.max_read_retry):
            file_meta = self.sample_file(rng)
            # We add some variance in the file position even if audio file is smaller than segment
            # without ending up with empty segments
            max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
            seek_time = torch.rand(1, generator=rng).item() * max_seek
            try:
                # pad=False: padding, if enabled, is applied below after conversion.
                out, sr = audio_read(file_meta.path, seek_time, self.segment_duration, pad=False)
                out = convert_audio(out, sr, self.sample_rate, self.channels)
                n_frames = out.shape[-1]
                target_frames = int(self.segment_duration * self.sample_rate)
                if self.pad:
                    out = F.pad(out, (0, target_frames - n_frames))
                segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
                                           sample_rate=self.sample_rate)
            except Exception as exc:
                logger.warning("Error opening file %s: %r", file_meta.path, exc)
                # Re-raise only once the last allowed attempt has failed.
                if retry == self.max_read_retry - 1:
                    raise
            else:
                # Successful read: stop retrying.
                break

    if self.return_info:
        # Returns the wav and additional information on the wave segment
        return out, segment_info
    else:
        return out
def collater(self, samples):
    """Collate the samples of a batch into a single stacked tensor.

    The collater function has to be provided to the dataloader
    if AudioDataset has return_info=True in order to properly collate
    the samples of a batch.

    Args:
        samples: List of items as returned by `__getitem__`: `(wav, SegmentInfo)` pairs
            when `self.return_info` is set, bare wav tensors otherwise.
    Returns:
        The stacked wav tensor, or `(stacked wav, list of SegmentInfo)` when
        `self.return_info` is set. The returned infos are deep copies of the inputs.
    """
    if self.segment_duration is None and len(samples) > 1:
        assert self.pad, "Must allow padding when batching examples of different durations."

    # In this case the audio reaching the collater is of variable length as segment_duration=None.
    to_pad = self.segment_duration is None and self.pad
    if to_pad:
        # Samples are pairs only when return_info is set. Unpacking bare tensors as pairs
        # would iterate over their first dimension and fail for mono audio.
        all_wavs = [sample[0] for sample in samples] if self.return_info else samples
        max_len = max([wav.shape[-1] for wav in all_wavs])

        def _pad_wav(wav):
            return F.pad(wav, (0, max_len - wav.shape[-1]))

    if self.return_info:
        if len(samples) > 0:
            assert len(samples[0]) == 2
            assert isinstance(samples[0][0], torch.Tensor)
            assert isinstance(samples[0][1], SegmentInfo)

        wavs = [wav for wav, _ in samples]
        segment_infos = [copy.deepcopy(info) for _, info in samples]

        if to_pad:
            # Each wav could be of a different duration as they are not segmented.
            for i in range(len(samples)):
                # Determines the total length of the signal with padding, so we update here as we pad.
                segment_infos[i].total_frames = max_len
                wavs[i] = _pad_wav(wavs[i])

        wav = torch.stack(wavs)
        return wav, segment_infos
    else:
        assert isinstance(samples[0], torch.Tensor)
        if to_pad:
            samples = [_pad_wav(s) for s in samples]
        return torch.stack(samples)
def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
"""Filters out audio files with short durations.
Removes from meta files that have durations that will not allow to samples examples from them.
"""
orig_len = len(meta)
# Filter data that is too short.
if self.min_audio_duration is not None:
meta = [m for m in meta if m.duration >= self.min_audio_duration]
# Filter data that is too long.
if self.max_audio_duration is not None:
meta = [m for m in meta if m.duration <= self.max_audio_duration]
filtered_len = len(meta)
removed_percentage = 100*(1-float(filtered_len)/orig_len)
msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage
if removed_percentage < 10:
logging.debug(msg)
else:
logging.warning(msg)
return meta
@classmethod
def from_meta(cls, root: tp.Union[str, Path], **kwargs):
    """Build the dataset from a jsonl manifest, or from a directory holding one.

    Args:
        root (str or Path): Path to a manifest file, or to a folder that contains
            either `data.jsonl` or `data.jsonl.gz` (looked up in that order).
        kwargs: Additional keyword arguments forwarded to the class constructor.
    Raises:
        ValueError: If `root` is a directory containing neither manifest file.
    """
    manifest = Path(root)
    if manifest.is_dir():
        for candidate in ('data.jsonl', 'data.jsonl.gz'):
            if (manifest / candidate).exists():
                manifest = manifest / candidate
                break
        else:
            # Neither manifest name was found in the directory.
            raise ValueError("Don't know where to read metadata from in the dir. "
                             "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
    return cls(load_audio_meta(manifest), **kwargs)
@classmethod
def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
              exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
    """Build the dataset from a manifest file or from a folder of audio files.

    Args:
        root (str or Path): A manifest file, or a root folder scanned for audio files.
        minimal_meta (bool): Forwarded as `minimal` when scanning a folder.
        exts (list of str): Extensions for audio files, used when scanning a folder.
        kwargs: Additional keyword arguments forwarded to the class constructor.
    """
    source = Path(root)
    if not source.is_file():
        # Anything that is not a regular file is scanned as a folder.
        entries = find_audio_files(source, exts, minimal=minimal_meta, resolve=True)
    else:
        entries = load_audio_meta(source, resolve=True)
    return cls(entries, **kwargs)
def main():
    """Command line entry point: scan a folder for audio files and save their metadata."""
    logging.basicConfig(stream=sys.stderr, level=logging.INFO)

    cli = argparse.ArgumentParser(
        prog='audio_dataset',
        description='Generate .jsonl files by scanning a folder.')
    cli.add_argument('root', help='Root folder with all the audio files')
    cli.add_argument('output_meta_file',
                     help='Output file to store the metadata, ')
    # `--complete` switches the `minimal` flag off (it is on by default).
    cli.add_argument('--complete',
                     action='store_false', dest='minimal', default=True,
                     help='Retrieve all metadata, even the one that are expansive '
                          'to compute (e.g. normalization).')
    cli.add_argument('--resolve',
                     action='store_true', default=False,
                     help='Resolve the paths to be absolute and with no symlinks.')
    cli.add_argument('--workers',
                     default=10, type=int,
                     help='Number of workers.')
    options = cli.parse_args()

    found = find_audio_files(options.root, DEFAULT_EXTS, progress=True,
                             resolve=options.resolve, minimal=options.minimal, workers=options.workers)
    save_audio_meta(options.output_meta_file, found)


if __name__ == '__main__':
    main()

View File

@ -0,0 +1,174 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import sys
import typing as tp
import julius
import torch
import torchaudio
def convert_audio_channels(wav: torch.Tensor, channels: int = 2) -> torch.Tensor:
    """Return `wav` with the requested number of channels.

    The channel axis is the second to last dimension and time is the last one.

    Args:
        wav (torch.Tensor): Audio wave of shape [..., C, T].
        channels (int): Expected number of channels as output.
    Returns:
        torch.Tensor: Audio wave of shape [..., channels, T].
    Raises:
        ValueError: If the input has several channels but fewer than requested.
    """
    *lead_dims, n_source, n_samples = wav.shape
    if n_source == channels:
        # Already in the requested layout.
        return wav
    if channels == 1:
        # Mono requested from a multichannel stream: average all the channels.
        return wav.mean(dim=-2, keepdim=True)
    if n_source == 1:
        # Several channels requested from a mono stream: replicate it over every channel.
        return wav.expand(*lead_dims, channels, n_samples)
    if n_source >= channels:
        # More channels available than requested: keep the first ones.
        return wav[..., :channels, :]
    # Fewer (but more than one) channels than requested: no sensible mapping.
    raise ValueError('The audio file has less channels than requested but is not mono.')
def convert_audio(wav: torch.Tensor, from_rate: float,
                  to_rate: float, to_channels: int) -> torch.Tensor:
    """Resample `wav` to `to_rate`, then convert it to `to_channels` channels.

    Both rates are truncated to integers before resampling.
    """
    resampled = julius.resample_frac(wav, int(from_rate), int(to_rate))
    return convert_audio_channels(resampled, to_channels)
def normalize_loudness(wav: torch.Tensor, sample_rate: int, loudness_headroom_db: float = 14,
                       loudness_compressor: bool = False, energy_floor: float = 2e-3):
    """Rescale a signal so that its measured loudness becomes `-loudness_headroom_db` dB.

    Loudness is measured with `torchaudio.transforms.Loudness`.

    Args:
        wav (torch.Tensor): Input multichannel audio data.
        sample_rate (int): Sample rate.
        loudness_headroom_db (float): Headroom in dB; the target loudness is its negation.
        loudness_compressor (bool): Uses tanh for soft clipping.
        energy_floor (float): anything below that RMS level will not be rescaled.
    Returns:
        output (torch.Tensor): Loudness normalized output data (the input itself when
            its RMS is under `energy_floor`).
    """
    rms = wav.pow(2).mean().sqrt().item()
    if rms < energy_floor:
        # Too quiet to measure: hand the signal back untouched.
        return wav
    meter = torchaudio.transforms.Loudness(sample_rate)
    measured_db = meter(wav).item()
    # Gain (in dB) needed to move from the measured loudness to the target one.
    gain_db = -loudness_headroom_db - measured_db
    linear_gain = 10.0 ** (gain_db / 20.0)
    output = linear_gain * wav
    if loudness_compressor:
        output = torch.tanh(output)
    assert output.isfinite().all(), (measured_db, wav.pow(2).mean().sqrt())
    return output
def _clip_wav(wav: torch.Tensor, log_clipping: bool = False, stem_name: tp.Optional[str] = None) -> None:
"""Utility function to clip the audio with logging if specified."""
max_scale = wav.abs().max()
if log_clipping and max_scale > 1:
clamp_prob = (wav.abs() > 1).float().mean().item()
print(f"CLIPPING {stem_name or ''} happening with proba (a bit of clipping is okay):",
clamp_prob, "maximum scale: ", max_scale.item(), file=sys.stderr)
wav.clamp_(-1, 1)
def normalize_audio(wav: torch.Tensor, normalize: bool = True,
                    strategy: str = 'peak', peak_clip_headroom_db: float = 1,
                    rms_headroom_db: float = 18, loudness_headroom_db: float = 14,
                    loudness_compressor: bool = False, log_clipping: bool = False,
                    sample_rate: tp.Optional[int] = None,
                    stem_name: tp.Optional[str] = None) -> torch.Tensor:
    """Normalize the audio according to the prescribed strategy (see after).

    Args:
        wav (torch.Tensor): Audio data.
        normalize (bool): if `True` (default), normalizes according to the prescribed
            strategy (see after). If `False`, the strategy is only used in case clipping
            would happen.
        strategy (str): Can be either 'clip', 'peak', 'rms' or 'loudness'. Default is 'peak',
            i.e. audio is normalized by its largest value. RMS normalizes by root-mean-square
            with extra headroom to avoid clipping. 'clip' just clips. 'loudness' delegates to
            `normalize_loudness`. '' and 'none' leave the audio untouched but require it to
            be strictly within [-1, 1].
        peak_clip_headroom_db (float): Headroom in dB when doing 'peak' or 'clip' strategy.
        rms_headroom_db (float): Headroom in dB when doing 'rms' strategy. This must be much larger
            than the `peak_clip` one to avoid further clipping.
        loudness_headroom_db (float): Target loudness for loudness normalization.
        loudness_compressor (bool): If True, uses tanh based soft clipping.
        log_clipping (bool): If True, basic logging on stderr when clipping still
            occurs despite strategy (only for 'rms' and 'loudness').
        sample_rate (int): Sample rate for the audio data (required for loudness).
        stem_name (Optional[str]): Stem name for clipping logging.
    Returns:
        torch.Tensor: Normalized audio.
    Raises:
        AssertionError: For an unknown strategy, for 'loudness' without `sample_rate`, or
            when no strategy is requested and the audio is not strictly within [-1, 1].
    """
    scale_peak = 10 ** (-peak_clip_headroom_db / 20)
    scale_rms = 10 ** (-rms_headroom_db / 20)
    if strategy == 'peak':
        peak = wav.abs().max()
        # A silent signal has a zero peak: rescaling it would compute 0 * inf = NaN.
        if peak > 0:
            rescaling = scale_peak / peak
            if normalize or rescaling < 1:
                wav = wav * rescaling
    elif strategy == 'clip':
        wav = wav.clamp(-scale_peak, scale_peak)
    elif strategy == 'rms':
        mono = wav.mean(dim=0)
        rms = mono.pow(2).mean().sqrt()
        # Same guard as for 'peak': leave silent audio as is instead of producing NaN.
        if rms > 0:
            rescaling = scale_rms / rms
            if normalize or rescaling < 1:
                wav = wav * rescaling
        _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
    elif strategy == 'loudness':
        assert sample_rate is not None, "Loudness normalization requires sample rate."
        wav = normalize_loudness(wav, sample_rate, loudness_headroom_db, loudness_compressor)
        _clip_wav(wav, log_clipping=log_clipping, stem_name=stem_name)
    else:
        assert wav.abs().max() < 1
        assert strategy == '' or strategy == 'none', f"Unexpected strategy: '{strategy}'"
    return wav
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
    """Return `wav` as floating point PCM.

    Floating point input is returned as is; int16 input is converted to float and divided
    by 2**15. Any other dtype fails the assertion.
    """
    if not wav.dtype.is_floating_point:
        assert wav.dtype == torch.int16
        return wav.float() / 2**15
    return wav
def i16_pcm(wav: torch.Tensor) -> torch.Tensor:
    """Convert audio to int 16 bits PCM format.

    ..Warning:: There exist many formulas for doing this conversion. None are perfect
    due to the asymmetry of the int16 range. One either has possible clipping, DC offset,
    or inconsistencies with f32_pcm. If the given wav doesn't have enough headroom,
    it is possible that `i16_pcm(f32_pcm)) != Identity`.

    int16 input is returned as is; floating point input must lie within [-1, 1].
    """
    if not wav.dtype.is_floating_point:
        assert wav.dtype == torch.int16
        return wav
    assert wav.abs().max() <= 1
    full_scale = 2 ** 15
    quantized = (wav * full_scale).round()
    if quantized.max() >= full_scale:  # clipping would occur
        # Fall back to the largest scale that keeps +1.0 inside the int16 range.
        quantized = (wav * (full_scale - 1)).round()
    return quantized.short()

74
audiocraft/data/zip.py Normal file
View File

@ -0,0 +1,74 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing
import zipfile
from dataclasses import dataclass
from functools import lru_cache
from typing_extensions import Literal
# Default maximum number of entries of the LRU cache wrapping `_open_zip`.
DEFAULT_SIZE = 32
# Mode strings accepted by `_open_zip` (used as the annotation of its `mode` argument).
MODE = Literal['r', 'w', 'x', 'a']
@dataclass(order=True)
class PathInZip:
    """Location of a file stored inside a zip archive.

    Args:
        path: A string of the form <path_to_zip>:<relative_path_inside_zip>.
            For a zip file /some/location/foo.zip containing a json file located at
            /data/file1.json, the expected value is
            "/some/location/foo.zip:/data/file1.json".
            The string must contain the separator exactly once.
    """

    # Separator between the archive path and the path inside the archive.
    INFO_PATH_SEP = ':'
    zip_path: str
    file_path: str

    def __init__(self, path: str) -> None:
        parts = path.split(self.INFO_PATH_SEP)
        assert len(parts) == 2
        self.zip_path = parts[0]
        self.file_path = parts[1]

    @classmethod
    def from_paths(cls, zip_path: str, file_path: str):
        """Build an instance from the archive path and the inner file path."""
        return cls(cls.INFO_PATH_SEP.join((zip_path, file_path)))

    def __str__(self) -> str:
        return self.INFO_PATH_SEP.join((self.zip_path, self.file_path))
def _open_zip(path: str, mode: MODE = 'r'):
    """Open the zip archive at `path` with the given mode and return the `ZipFile`."""
    archive = zipfile.ZipFile(path, mode)
    return archive


# Memoized variant of `_open_zip`, keeping at most DEFAULT_SIZE results.
# NOTE(review): nothing here closes the cached archives -- confirm this is intended.
_cached_open_zip = lru_cache(maxsize=DEFAULT_SIZE)(_open_zip)
def set_zip_cache_size(max_size: int):
    """Replace the zip opening cache with a new, empty one of the given size.

    Args:
        max_size: the maximal number of entries of the new LRU cache.
    """
    global _cached_open_zip
    make_cached = lru_cache(max_size)
    # Rebinding the module-level name drops the previous cache and its entries.
    _cached_open_zip = make_cached(_open_zip)
def open_file_in_zip(path_in_zip: PathInZip, mode: str = 'r') -> typing.IO:
    """Opens a file stored inside a zip and returns a file-like object.

    The archive itself is opened through the module's LRU cache.

    Args:
        path_in_zip: A PathInZip object representing the file to return a file-like object of.
        mode: Currently unused: the archive and the member are both opened with their
            default mode.
    Returns:
        A file-like object for PathInZip.
    """
    archive = _cached_open_zip(path_in_zip.zip_path)
    member = archive.open(path_in_zip.file_path)
    return member

View File

@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# flake8: noqa
from .musicgen import MusicGen
from .lm import LMModel
from .encodec import CompressionModel, EncodecModel

View File

@ -0,0 +1,218 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
All the functions to build the relevant models and modules
from the Hydra config.
"""
import typing as tp
import warnings
import audiocraft
import omegaconf
import torch
from .encodec import CompressionModel, EncodecModel, FlattenedCompressionModel # noqa
from .lm import LMModel
from ..modules.codebooks_patterns import (
CodebooksPatternProvider,
DelayedPatternProvider,
ParallelPatternProvider,
UnrolledPatternProvider,
VALLEPattern,
MusicLMPattern,
)
from ..modules.conditioners import (
BaseConditioner,
ConditioningProvider,
LUTConditioner,
T5Conditioner,
ConditionFuser,
ChromaStemConditioner,
)
from .. import quantization as qt
from ..utils.utils import dict_from_config
def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer:
    """Instantiate the quantizer named `quantizer` from its section of the config.

    Args:
        quantizer (str): Either 'no_quant' or 'rvq'; also the name of the config section read.
        cfg (omegaconf.DictConfig): Config holding a section named after `quantizer`.
        dimension (int): Passed as `dimension` to every quantizer except 'no_quant'.
    Raises:
        KeyError: If `quantizer` is not a known name.
    """
    quantizer_classes = {
        'no_quant': qt.DummyQuantizer,
        'rvq': qt.ResidualVectorQuantizer,
    }
    quantizer_cls = quantizer_classes[quantizer]
    options = dict_from_config(getattr(cfg, quantizer))
    if quantizer == 'no_quant':
        # The dummy quantizer takes no `dimension` argument.
        return quantizer_cls(**options)
    options['dimension'] = dimension
    return quantizer_cls(**options)
def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
    """Instantiate the encoder and decoder of an EnCodec model.

    The `seanet` config section holds shared arguments plus `encoder` and `decoder`
    sub-sections whose entries override the shared ones for the respective module.

    Args:
        encoder_name (str): Name of the autoencoder; only 'seanet' is supported.
        cfg (omegaconf.DictConfig): Config holding the `seanet` section.
    Returns:
        The `(encoder, decoder)` pair.
    Raises:
        KeyError: If `encoder_name` is not supported.
    """
    if encoder_name != 'seanet':
        # Report the name that was actually rejected (not the compression model name).
        raise KeyError(f'Unexpected autoencoder {encoder_name}')
    kwargs = dict_from_config(getattr(cfg, 'seanet'))
    encoder_override_kwargs = kwargs.pop('encoder')
    decoder_override_kwargs = kwargs.pop('decoder')
    encoder_kwargs = {**kwargs, **encoder_override_kwargs}
    decoder_kwargs = {**kwargs, **decoder_override_kwargs}
    encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs)
    decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs)
    return encoder, decoder
def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
    """Instantiate a compression model.

    Only `cfg.compression_model == 'encodec'` is supported. The `encodec` config section
    provides the autoencoder and quantizer names; its remaining entries are forwarded to
    `EncodecModel`. The frame rate is derived from the sample rate and the encoder hop length.

    Args:
        cfg (omegaconf.DictConfig): Full model config.
    Returns:
        CompressionModel: The model, moved to `cfg.device`.
    Raises:
        KeyError: If `cfg.compression_model` is not supported.
    """
    if cfg.compression_model == 'encodec':
        kwargs = dict_from_config(getattr(cfg, 'encodec'))
        encoder_name = kwargs.pop('autoencoder')
        quantizer_name = kwargs.pop('quantizer')
        encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
        quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
        frame_rate = kwargs['sample_rate'] // encoder.hop_length
        renormalize = kwargs.pop('renormalize', None)
        # `renorm` is the deprecated option: tolerate configs that no longer define it,
        # and always drop it so that it is not forwarded to EncodecModel.
        renorm = kwargs.pop('renorm', None)
        if renormalize is None:
            renormalize = renorm is not None
            warnings.warn("You are using a deprecated EnCodec model. Please migrate to new renormalization.")
        return EncodecModel(encoder, decoder, quantizer,
                            frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
    else:
        raise KeyError(f'Unexpected compression model {cfg.compression_model}')
def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
    """Instantiate a transformer LM.

    Only `cfg.lm_model == 'transformer_lm'` is supported. The `transformer_lm` config section
    is forwarded to `LMModel`, together with the pattern provider, the condition provider,
    the fuser and the dropout / classifier free guidance settings read from the config.

    Args:
        cfg (omegaconf.DictConfig): Full model config.
    Returns:
        LMModel: The model, moved to `cfg.device`.
    Raises:
        KeyError: If `cfg.lm_model` is not supported.
    """
    if cfg.lm_model == 'transformer_lm':
        kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
        n_q = kwargs['n_q']
        # `q_modeling` is removed from kwargs: it is only used as a fallback pattern below.
        q_modeling = kwargs.pop('q_modeling', None)
        codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
        attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
        cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
        cfg_prob, cfg_coef = cls_free_guidance["training_dropout"], cls_free_guidance["inference_coef"]
        fuser = get_condition_fuser(cfg)
        condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
        if len(fuser.fuse2cond['cross']) > 0:  # enforce cross-att programatically
            kwargs['cross_attention'] = True
        if codebooks_pattern_cfg.modeling is None:
            assert q_modeling is not None, \
                'LM model should either have a codebook pattern defined or transformer_lm.q_modeling'
            # Fall back to a pattern config built from `q_modeling`, with delays 0..n_q-1.
            codebooks_pattern_cfg = omegaconf.OmegaConf.create(
                {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
            )
        pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
        return LMModel(
            pattern_provider=pattern_provider,
            condition_provider=condition_provider,
            fuser=fuser,
            cfg_dropout=cfg_prob,
            cfg_coef=cfg_coef,
            attribute_dropout=attribute_dropout,
            # `cfg.dtype` is the name of a torch attribute, looked up with getattr.
            dtype=getattr(torch, cfg.dtype),
            device=cfg.device,
            **kwargs
        ).to(cfg.device)
    else:
        raise KeyError(f'Unexpected LM model {cfg.lm_model}')
def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
    """Instantiate a conditioning model.

    Builds one conditioner per entry of the `conditioners` config section (an optional
    `args` entry holds extra arguments for the provider itself) and wraps them in a
    `ConditioningProvider`.

    Args:
        output_dim (int): Output dimension passed to every conditioner.
        cfg (omegaconf.DictConfig): Full model config.
    Returns:
        ConditioningProvider: Provider wrapping all the configured conditioners.
    Raises:
        ValueError: If a conditioner declares an unknown `model` type.
    """
    device = cfg.device
    duration = cfg.dataset.segment_duration
    # From here on `cfg` only refers to the `conditioners` section.
    cfg = getattr(cfg, "conditioners")
    cfg = omegaconf.OmegaConf.create({}) if cfg is None else cfg
    conditioners: tp.Dict[str, BaseConditioner] = {}
    # The config is modified below (entries are popped), hence the open_dict context.
    with omegaconf.open_dict(cfg):
        condition_provider_args = cfg.pop('args', {})
    for cond, cond_cfg in cfg.items():
        model_type = cond_cfg["model"]
        # The arguments of a conditioner live under a key named after its model type.
        model_args = cond_cfg[model_type]
        if model_type == "t5":
            conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args)
        elif model_type == "lut":
            conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args)
        elif model_type == "chroma_stem":
            # `cache_path` is dropped and never forwarded to the conditioner.
            model_args.pop('cache_path', None)
            conditioners[str(cond)] = ChromaStemConditioner(
                output_dim=output_dim,
                duration=duration,
                device=device,
                **model_args
            )
        else:
            raise ValueError(f"unrecognized conditioning model: {model_type}")
    conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args)
    return conditioner
def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
    """Build the `ConditionFuser` described by the `fuser` section of the config.

    The entries named after a fusing method make up the `fuse2cond` mapping; every other
    entry of the section is forwarded to `ConditionFuser` as a keyword argument.
    """
    fuser_cfg = getattr(cfg, "fuser")
    method_names = ["sum", "cross", "prepend", "input_interpolate"]

    fuse2cond = {}
    for method in method_names:
        fuse2cond[method] = fuser_cfg[method]

    extra_kwargs = {}
    for key, value in fuser_cfg.items():
        if key in method_names:
            continue
        extra_kwargs[key] = value

    return ConditionFuser(fuse2cond=fuse2cond, **extra_kwargs)
def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
    """Build the codebooks pattern provider selected by `cfg.modeling`.

    When the config has a section named after the selected pattern, its content is passed
    to the provider as keyword arguments.

    Raises:
        KeyError: If `cfg.modeling` is not a known pattern name.
    """
    providers_by_name = {
        'parallel': ParallelPatternProvider,
        'delay': DelayedPatternProvider,
        'unroll': UnrolledPatternProvider,
        'valle': VALLEPattern,
        'musiclm': MusicLMPattern,
    }
    pattern_name = cfg.modeling
    if hasattr(cfg, pattern_name):
        provider_kwargs = dict_from_config(cfg.get(pattern_name))
    else:
        provider_kwargs = {}
    provider_cls = providers_by_name[pattern_name]
    return provider_cls(n_q, **provider_kwargs)
def get_debug_compression_model(device='cpu'):
    """Build a tiny EnCodec model, in eval mode, to be used for unit tests."""
    tiny_seanet = {
        'n_filters': 4,
        'n_residual_layers': 1,
        'dimension': 32,
        'ratios': [10, 8, 16]  # 25 Hz at 32kHz
    }
    encoder = audiocraft.modules.SEANetEncoder(**tiny_seanet)
    decoder = audiocraft.modules.SEANetDecoder(**tiny_seanet)
    quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4)
    # One forward pass on random data to initialize kmeans etc.
    warmup = torch.randn(8, 32, 128)
    quantizer(warmup, 1)
    model = EncodecModel(
        encoder, decoder, quantizer,
        frame_rate=25, sample_rate=32000, channels=1)
    model = model.to(device)
    return model.eval()
def get_debug_lm_model(device='cpu'):
    """Build a tiny LM with a single 'description' conditioner, in eval mode, for unit tests."""
    pattern = DelayedPatternProvider(n_q=4)
    dim = 16
    description_conditioner = LUTConditioner(n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace")
    condition_provider = ConditioningProvider({'description': description_conditioner})
    # The description is the only condition, fused through cross attention.
    fuse2cond = {'cross': ['description'], 'prepend': [],
                 'sum': [], 'input_interpolate': []}
    fuser = ConditionFuser(fuse2cond)
    lm = LMModel(
        pattern, condition_provider, fuser,
        n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2,
        cross_attention=True, causal=True)
    lm = lm.to(device)
    return lm.eval()

View File

@ -0,0 +1,302 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod
import typing as tp
from einops import rearrange
import torch
from torch import nn
from .. import quantization as qt
class CompressionModel(ABC, nn.Module):
    """Abstract interface of the audio compression models defined in this module.

    Concrete subclasses must provide every method and property declared below.
    """

    @abstractmethod
    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
        """Run the model on `x` and return the quantization result."""
        ...

    @abstractmethod
    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """See `EncodecModel.encode`"""
        ...

    @abstractmethod
    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """See `EncodecModel.decode`"""
        ...

    @property
    @abstractmethod
    def channels(self) -> int:
        """Number of audio channels."""
        ...

    @property
    @abstractmethod
    def frame_rate(self) -> int:
        """Frame rate of the latent representation."""
        ...

    @property
    @abstractmethod
    def sample_rate(self) -> int:
        """Audio sample rate."""
        ...

    @property
    @abstractmethod
    def cardinality(self) -> int:
        """Cardinality of each codebook."""
        ...

    @property
    @abstractmethod
    def num_codebooks(self) -> int:
        """Active number of codebooks."""
        ...

    @property
    @abstractmethod
    def total_codebooks(self) -> int:
        """Total number of codebooks available."""
        ...

    @abstractmethod
    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer.
        """
        ...
class EncodecModel(CompressionModel):
    """Encodec model operating on the raw waveform.

    Args:
        encoder (nn.Module): Encoder network.
        decoder (nn.Module): Decoder network.
        quantizer (qt.BaseQuantizer): Quantizer network.
        frame_rate (int): Frame rate for the latent representation.
        sample_rate (int): Audio sample rate.
        channels (int): Number of audio channels.
        causal (bool): Whether to use a causal version of the model.
        renormalize (bool): Whether to renormalize the audio before running the model.
    """
    # These class-level assignments override the abstract properties of the parent
    # class, so that the attributes can be set as plain values in `__init__`.
    frame_rate: int = 0
    sample_rate: int = 0
    channels: int = 0

    def __init__(self,
                 encoder: nn.Module,
                 decoder: nn.Module,
                 quantizer: qt.BaseQuantizer,
                 frame_rate: int,
                 sample_rate: int,
                 channels: int,
                 causal: bool = False,
                 renormalize: bool = False):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.quantizer = quantizer
        self.frame_rate = frame_rate
        self.sample_rate = sample_rate
        self.channels = channels
        self.renormalize = renormalize
        self.causal = causal
        if self.causal:
            # we force disabling here to avoid handling linear overlap of segments
            # as supported in original EnCodec codebase.
            assert not self.renormalize, 'Causal model does not support renormalize'

    @property
    def total_codebooks(self):
        """Total number of quantizer codebooks available.
        """
        return self.quantizer.total_codebooks

    @property
    def num_codebooks(self):
        """Active number of codebooks used by the quantizer.
        """
        return self.quantizer.num_codebooks

    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer.
        """
        self.quantizer.set_num_codebooks(n)

    @property
    def cardinality(self):
        """Cardinality of each codebook.
        """
        return self.quantizer.bins

    def preprocess(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Optionally rescale `x` by its volume before encoding.

        When `self.renormalize` is set, `x` is divided by the RMS of its channel average
        (computed over the last dimension, plus 1e-8) and that scale is returned reshaped
        to two dimensions. Otherwise `x` is returned unchanged with a None scale.
        """
        scale: tp.Optional[torch.Tensor]
        if self.renormalize:
            mono = x.mean(dim=1, keepdim=True)
            volume = mono.pow(2).mean(dim=2, keepdim=True).sqrt()
            # The small constant avoids dividing by zero on silent input.
            scale = 1e-8 + volume
            x = x / scale
            scale = scale.view(-1, 1)
        else:
            scale = None
        return x, scale

    def postprocess(self,
                    x: torch.Tensor,
                    scale: tp.Optional[torch.Tensor] = None) -> torch.Tensor:
        """Undo the rescaling of `preprocess`: multiply `x` by `scale` when one is given."""
        if scale is not None:
            assert self.renormalize
            x = x * scale.view(-1, 1, 1)
        return x

    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
        """Encode, quantize and decode `x`.

        Args:
            x (torch.Tensor): Tensor with 3 dimensions.
        Returns:
            qt.QuantizedResult: The quantizer result, whose `x` field is replaced by the
                reconstructed audio trimmed to the input length.
        """
        assert x.dim() == 3
        length = x.shape[-1]
        x, scale = self.preprocess(x)

        emb = self.encoder(x)
        q_res = self.quantizer(emb, self.frame_rate)
        out = self.decoder(q_res.x)

        # remove extra padding added by the encoder and decoder
        assert out.shape[-1] >= length, (out.shape[-1], length)
        out = out[..., :length]

        q_res.x = self.postprocess(out, scale)

        return q_res

    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Encode the given input tensor to quantized representation along with scale parameter.

        Args:
            x (torch.Tensor): Float tensor of shape [B, C, T]

        Returns:
            codes, scale (tp.Tuple[torch.Tensor, torch.Tensor]): Tuple composed of:
                codes a tensor of shape [B, K, T] with K the number of codebooks used and T the timestep
                    (presumably integer codebook indices, as expected by `decode` -- TODO confirm).
                scale a float tensor containing the scale for audio renormalization
                    (None when `self.renormalize` is not set).
        """
        assert x.dim() == 3
        x, scale = self.preprocess(x)
        emb = self.encoder(x)
        codes = self.quantizer.encode(emb)
        return codes, scale

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """Decode the given codes to a reconstructed representation, using the scale to perform
        audio denormalization if needed.

        Args:
            codes (torch.Tensor): Int tensor of shape [B, K, T]
            scale (tp.Optional[torch.Tensor]): Float tensor containing the scale value.

        Returns:
            out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
        """
        emb = self.quantizer.decode(codes)
        out = self.decoder(emb)
        out = self.postprocess(out, scale)
        # out contains extra padding added by the encoder and decoder
        return out
class FlattenedCompressionModel(CompressionModel):
    """Wraps a CompressionModel and flatten its codebooks, e.g.
    instead of returning [B, K, T], return [B, S, T * (K // S)] with
    S the number of codebooks per step, and `K // S` the number of 'virtual steps'
    for each real time step.

    Args:
        model (CompressionModel): compression model to wrap.
        codebooks_per_step (int): number of codebooks to keep per step,
            this must divide the number of codebooks provided by the wrapped model.
        extend_cardinality (bool): if True, and for instance if codebooks_per_step = 1,
            if each codebook has a cardinality N, then the first codebook will
            use the range [0, N - 1], and the second [N, 2 N - 1] etc.
            On decoding, this can lead to potentially invalid sequences.
            Any invalid entry will be silently remapped to the proper range
            with a modulo.
    """
    def __init__(self, model: CompressionModel, codebooks_per_step: int = 1,
                 extend_cardinality: bool = True):
        super().__init__()
        self.model = model
        self.codebooks_per_step = codebooks_per_step
        self.extend_cardinality = extend_cardinality

    @property
    def total_codebooks(self):
        """Total number of codebooks of the wrapped model."""
        return self.model.total_codebooks

    @property
    def num_codebooks(self):
        """Active number of codebooks used by the quantizer.

        ..Warning:: this reports the number of codebooks after the flattening
        of the codebooks!
        """
        assert self.model.num_codebooks % self.codebooks_per_step == 0
        return self.codebooks_per_step

    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer.

        ..Warning:: this sets the number of codebooks **before** the flattening
        of the codebooks.
        """
        assert n % self.codebooks_per_step == 0
        self.model.set_num_codebooks(n)

    @property
    def num_virtual_steps(self) -> int:
        """Return the number of virtual steps, e.g. one real step
        will be split into that many steps.
        """
        return self.model.num_codebooks // self.codebooks_per_step

    @property
    def frame_rate(self) -> int:
        """Frame rate of the wrapped model multiplied by the number of virtual steps."""
        return self.model.frame_rate * self.num_virtual_steps

    @property
    def sample_rate(self) -> int:
        """Sample rate of the wrapped model."""
        return self.model.sample_rate

    @property
    def channels(self) -> int:
        """Number of channels of the wrapped model."""
        return self.model.channels

    @property
    def cardinality(self):
        """Cardinality of each codebook.
        """
        if self.extend_cardinality:
            return self.model.cardinality * self.num_virtual_steps
        else:
            return self.model.cardinality

    def forward(self, x: torch.Tensor) -> qt.QuantizedResult:
        """Not supported by this wrapper; always raises NotImplementedError."""
        raise NotImplementedError("Not supported, use encode and decode.")

    def encode(self, x: torch.Tensor) -> tp.Tuple[torch.Tensor, tp.Optional[torch.Tensor]]:
        """Encode `x` with the wrapped model and flatten the resulting codebooks.

        Returns:
            The flattened indices, rearranged from [B, K, T] to [B, S, T * V] with
            S = `codebooks_per_step` and V the number of virtual steps, and the scales
            returned by the wrapped model.
        """
        indices, scales = self.model.encode(x)
        B, K, T = indices.shape
        # Split the codebook axis into (kept codebooks, virtual steps).
        indices = rearrange(indices, 'b (k v) t -> b k t v', k=self.codebooks_per_step)
        if self.extend_cardinality:
            # Shift each virtual step into its own range of values (the first one is unchanged).
            for virtual_step in range(1, self.num_virtual_steps):
                indices[..., virtual_step] += self.model.cardinality * virtual_step
        # Interleave the virtual steps along the time axis.
        indices = rearrange(indices, 'b k t v -> b k (t v)')
        return (indices, scales)

    def decode(self, codes: torch.Tensor, scale: tp.Optional[torch.Tensor] = None):
        """Unflatten `codes` and decode them with the wrapped model.

        The last dimension of `codes` must be a multiple of the number of virtual steps.
        """
        B, K, T = codes.shape
        assert T % self.num_virtual_steps == 0
        codes = rearrange(codes, 'b k (t v) -> b (k v) t', v=self.num_virtual_steps)
        # We silently ignore potential errors from the LM when
        # using extend_cardinality.
        codes = codes % self.model.cardinality
        return self.model.decode(codes, scale)

527
audiocraft/models/lm.py Normal file
View File

@ -0,0 +1,527 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from functools import partial
import logging
import math
import typing as tp
import torch
from torch import nn
from ..utils import utils
from ..modules.streaming import StreamingModule, State
from ..modules.transformer import StreamingTransformer, create_norm_fn
from ..modules.conditioners import (
ConditionFuser,
ClassifierFreeGuidanceDropout,
AttributeDropout,
ConditioningProvider,
ConditioningAttributes,
ConditionType,
)
from ..modules.codebooks_patterns import CodebooksPatternProvider
from ..modules.activations import get_activation_fn
logger = logging.getLogger(__name__)
ConditionTensors = tp.Dict[str, ConditionType]
CFGConditions = tp.Union[ConditionTensors, tp.Tuple[ConditionTensors, ConditionTensors]]
def get_init_fn(method: str, input_dim: int, init_depth: tp.Optional[int] = None):
    """Return the in-place initialization function for an LM layer.

    Inspired from xlformers: https://github.com/fairinternal/xlformers

    The base standard deviation is `1 / sqrt(input_dim)`, further divided by
    `sqrt(2 * init_depth)` when a depth is given.

    Args:
        method (str): Method name for init function. Valid options are:
            'gaussian', 'uniform'.
        input_dim (int): Input dimension of the initialized module.
        init_depth (Optional[int]): Optional init depth value used to rescale
            the standard deviation if defined.
    Returns:
        A partial of `torch.nn.init.trunc_normal_` (truncated at 3 standard deviations)
        or of `torch.nn.init.uniform_`.
    Raises:
        ValueError: If `method` is not supported.
    """
    std = 1 / math.sqrt(input_dim)
    if init_depth is not None:
        # Rescale with depth
        std = std / math.sqrt(2 * init_depth)

    if method == 'gaussian':
        return partial(
            torch.nn.init.trunc_normal_, mean=0.0, std=std, a=-3 * std, b=3 * std
        )
    if method == 'uniform':
        # Bounds chosen so that the standard deviation of the uniform law is `std`.
        bound = math.sqrt(3) * std
        return partial(torch.nn.init.uniform_, a=-bound, b=bound)
    raise ValueError("Unsupported layer initialization method")
def init_layer(m: nn.Module,
               method: str,
               init_depth: tp.Optional[int] = None,
               zero_bias_init: bool = False):
    """Initialize the weights of `m` in place using ``get_init_fn``.

    Only `nn.Linear` and `nn.Embedding` modules are handled; any other module is left
    untouched.

    Args:
        m (nn.Module): Module to initialize.
        method (str): Method name for the init function.
        init_depth (Optional[int]): Optional init depth value used to rescale
            the standard deviation if defined. Ignored for embeddings.
        zero_bias_init (bool): Whether to initialize the bias of linear layers to 0 or not.
    """
    def _init_weight(init_fn):
        on_cpu_in_half = m.weight.device.type == 'cpu' and m.weight.dtype == torch.float16
        if on_cpu_in_half:
            # Initialize a float32 copy and write it back as half precision
            # (presumably because the init is not available for half on CPU -- TODO confirm).
            weight = m.weight.float()
            init_fn(weight)
            m.weight.data[:] = weight.half()
        else:
            init_fn(m.weight)

    if isinstance(m, nn.Linear):
        _init_weight(get_init_fn(method, m.in_features, init_depth=init_depth))
        if zero_bias_init and m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Embedding):
        # Embeddings never use the depth rescaling.
        _init_weight(get_init_fn(method, m.embedding_dim, init_depth=None))
class ScaledEmbedding(nn.Embedding):
    """Embedding that can carry its own learning rate (``lr``).

    Args:
        *args: Positional arguments forwarded to ``nn.Embedding``.
        lr (Optional[float]): Learning rate to advertise for this embedding, if any.
        **kwargs: Keyword arguments forwarded to ``nn.Embedding``.
    """
    def __init__(self, *args, lr=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.lr = lr

    def make_optim_group(self):
        """Return an optimizer parameter group for this embedding.

        The group always has a "params" entry; "lr" is added only when a specific
        learning rate was provided at construction.
        """
        params = list(self.parameters())
        if self.lr is None:
            return {"params": params}
        return {"params": params, "lr": self.lr}
@dataclass
class LMOutput:
    """Output of the language model: logits plus a mask over positions."""
    # Logits are already re-aligned with the input codes, so no extra shift
    # is needed, e.g. when computing a cross-entropy loss.
    # Shape: [B, K, T, card]
    logits: torch.Tensor
    # Mask associated with the logits. Shape: [B, K, T]
    mask: torch.Tensor
class LMModel(StreamingModule):
    """Transformer-based language model on multiple streams of codes.

    Args:
        pattern_provider (CodebooksPatternProvider): Pattern provider for codebook interleaving.
        condition_provider (MusicConditioningProvider): Conditioning provider from metadata.
        fuser (ConditionFuser): Fuser handling the fusing of conditions with language model input.
        n_q (int): Number of parallel streams to model.
        card (int): Cardinality, vocabulary size.
        dim (int): Dimension of the transformer encoder.
        num_heads (int): Number of heads for the transformer encoder.
        hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
        norm (str): Normalization method.
        norm_first (bool): Use pre-norm instead of post-norm.
        emb_lr (Optional[float]): Embedding-specific learning rate.
        bias_proj (bool): Use bias for output projections.
        weight_init (Optional[str]): Method for weight initialization.
        depthwise_init (Optional[str]): Method for depthwise weight initialization.
        zero_bias_init (bool): If true and bias in Linears, initialize bias to zeros.
        cfg_dropout (float): Classifier-free guidance dropout.
        cfg_coef (float): Classifier-free guidance coefficient.
        attribute_dropout (dict): Attribute dropout probabilities.
        two_step_cfg (bool): Whether to run classifier free-guidance with 2 distinct steps.
        **kwargs: Additional parameters for the transformer encoder.
    """
    def __init__(self, pattern_provider: CodebooksPatternProvider, condition_provider: ConditioningProvider,
                 fuser: ConditionFuser, n_q: int = 8, card: int = 1024, dim: int = 128, num_heads: int = 8,
                 hidden_scale: int = 4, norm: str = 'layer_norm', norm_first: bool = False,
                 emb_lr: tp.Optional[float] = None, bias_proj: bool = True,
                 weight_init: tp.Optional[str] = None, depthwise_init: tp.Optional[str] = None,
                 zero_bias_init: bool = False, cfg_dropout: float = 0, cfg_coef: float = 1.0,
                 attribute_dropout: tp.Dict[str, tp.Dict[str, float]] = {}, two_step_cfg: bool = False,
                 **kwargs):
        super().__init__()
        self.cfg_coef = cfg_coef
        self.cfg_dropout = ClassifierFreeGuidanceDropout(p=cfg_dropout)
        # NOTE(review): the default `attribute_dropout` dict is shared between calls; it is only
        # forwarded here, confirm that AttributeDropout never mutates it.
        self.att_dropout = AttributeDropout(p=attribute_dropout)
        self.condition_provider = condition_provider
        self.fuser = fuser
        self.card = card
        # one extra embedding entry for the special token (see `special_token_id`)
        embed_dim = self.card + 1
        self.n_q = n_q
        self.dim = dim
        self.pattern_provider = pattern_provider
        self.two_step_cfg = two_step_cfg
        self.emb = nn.ModuleList([ScaledEmbedding(embed_dim, dim, lr=emb_lr) for _ in range(n_q)])
        if 'activation' in kwargs:
            kwargs['activation'] = get_activation_fn(kwargs['activation'])
        self.transformer = StreamingTransformer(
            d_model=dim, num_heads=num_heads, dim_feedforward=int(hidden_scale * dim),
            norm=norm, norm_first=norm_first, **kwargs)
        self.out_norm: tp.Optional[nn.Module] = None
        if norm_first:
            self.out_norm = create_norm_fn(norm, dim)
        self.linears = nn.ModuleList([nn.Linear(dim, self.card, bias=bias_proj) for _ in range(n_q)])
        self._init_weights(weight_init, depthwise_init, zero_bias_init)
        self._fsdp: tp.Optional[nn.Module]
        # written through `__dict__` so that the attribute bypasses `nn.Module.__setattr__`
        self.__dict__['_fsdp'] = None

    def _init_weights(self, weight_init: tp.Optional[str], depthwise_init: tp.Optional[str], zero_bias_init: bool):
        """Initialization of the transformer module weights.

        Args:
            weight_init (Optional[str]): Weight initialization strategy. See ``get_init_fn`` for valid options.
            depthwise_init (Optional[str]): Depthwise initialization strategy. The following options are valid:
                'current' where the depth corresponds to the current layer index or 'global' where the total number
                of layer is used as depth. If not set, no depthwise initialization strategy is used.
            zero_bias_init (bool): Whether to initialize bias to zero or not.
        """
        assert depthwise_init is None or depthwise_init in ['current', 'global']
        assert depthwise_init is None or weight_init is not None, \
            "If 'depthwise_init' is defined, a 'weight_init' method should be provided."
        assert not zero_bias_init or weight_init is not None, \
            "If 'zero_bias_init', a 'weight_init' method should be provided"

        if weight_init is None:
            return

        for emb_layer in self.emb:
            init_layer(emb_layer, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)

        for layer_idx, tr_layer in enumerate(self.transformer.layers):
            depth = None
            if depthwise_init == 'current':
                depth = layer_idx + 1
            elif depthwise_init == 'global':
                depth = len(self.transformer.layers)
            init_fn = partial(init_layer, method=weight_init, init_depth=depth, zero_bias_init=zero_bias_init)
            tr_layer.apply(init_fn)

        for linear in self.linears:
            init_layer(linear, method=weight_init, init_depth=None, zero_bias_init=zero_bias_init)

    @property
    def special_token_id(self) -> int:
        """Id of the special token, equal to the cardinality (one past the last regular code)."""
        return self.card

    @property
    def num_codebooks(self) -> int:
        """Number of parallel code streams handled by the model."""
        return self.n_q

    def forward(self, sequence: torch.Tensor,
                conditions: tp.List[ConditioningAttributes],
                condition_tensors: tp.Optional[ConditionTensors] = None) -> torch.Tensor:
        """Apply language model on sequence and conditions.
        Given a tensor of sequence of shape [B, K, S] with K the number of codebooks and
        S the sequence steps, return the logits with shape [B, K, S, card].

        Args:
            sequence (torch.Tensor): indices of the codes to model, of shape [B, K, S].
            conditions (list[ConditioningAttributes]): conditionings to use when modeling
                the given codes. Note that when evaluating multiple time with the same conditioning
                you should pre-compute those and pass them as `condition_tensors`.
            condition_tensors (dict[str, ConditionType] or None): pre-computed conditioning
                tensors, see `conditions`.
        Returns:
            torch.Tensor: Logits of shape [B, K, S, card].
        """
        B, K, S = sequence.shape
        assert K == self.num_codebooks, 'Sequence shape must match the specified number of codebooks'
        # each codebook has its own embedding table, the embeddings are summed over codebooks
        input_ = sum([self.emb[k](sequence[:, k]) for k in range(K)])
        if condition_tensors is None:
            assert not self._is_streaming, "Conditions tensors should be precomputed when streaming."
            # apply dropout modules
            conditions = self.cfg_dropout(conditions)
            conditions = self.att_dropout(conditions)
            tokenized = self.condition_provider.tokenize(conditions)
            # encode conditions and fuse, both have a streaming cache to not recompute when generating.
            condition_tensors = self.condition_provider(tokenized)
        else:
            assert not conditions, "Shouldn't pass both conditions and condition_tensors."

        input_, cross_attention_input = self.fuser(input_, condition_tensors)

        out = self.transformer(input_, cross_attention_src=cross_attention_input)
        if self.out_norm is not None:
            out = self.out_norm(out)
        logits = torch.stack([self.linears[k](out) for k in range(K)], dim=1)  # [B, K, S, card]

        # remove the prefix from the model outputs
        if len(self.fuser.fuse2cond['prepend']) > 0:
            logits = logits[:, :, -S:]

        return logits  # [B, K, S, card]

    def compute_predictions(
            self, codes: torch.Tensor,
            conditions: tp.List[ConditioningAttributes],
            condition_tensors: tp.Optional[ConditionTensors] = None) -> LMOutput:
        """Given an input tensor of codes [B, K, T] and list of conditions, runs the model
        forward using the specified codes interleaving pattern.

        Args:
            codes (torch.Tensor): Input codes of shape [B, K, T] with B the batch size,
                K the number of codebooks and T the number of timesteps.
            conditions (list[ConditioningAttributes]): conditionings to use when modeling
                the given codes. Note that when evaluating multiple time with the same conditioning
                you should pre-compute those and pass them as `condition_tensors`.
            condition_tensors (dict[str, ConditionType] or None): pre-computed conditioning
                tensors, see `conditions`.
        Returns:
            LMOutput: Language model outputs
                logits (torch.Tensor) of shape [B, K, T, card] corresponding to the provided codes,
                    i.e. the first item corresponds to logits to predict the first code, meaning that
                    no additional shifting of codes and logits is required.
                mask (torch.Tensor) of shape [B, K, T], mask over valid and invalid positions.
                    Given the specified interleaving strategies, parts of the logits and codes should
                    not be considered as valid predictions because of invalid context.
        """
        B, K, T = codes.shape
        codes = codes.contiguous()
        # map codes [B, K, T] into pattern sequence [B, K, S] using special_token_id for masked tokens
        pattern = self.pattern_provider.get_pattern(T)
        sequence_codes, sequence_indexes, sequence_mask = pattern.build_pattern_sequence(
            codes, self.special_token_id, keep_only_valid_steps=True
        )
        # apply model on pattern sequence
        model = self if self._fsdp is None else self._fsdp
        logits = model(sequence_codes, conditions, condition_tensors)  # [B, K, S, card]
        # map back the logits on pattern sequence to logits on original codes: [B, K, S, card] -> [B, K, T, card]
        # and provide the corresponding mask over invalid positions of tokens
        logits = logits.permute(0, 3, 1, 2)  # [B, card, K, S]
        # note: we use nans as special token to make it obvious if we feed unexpected logits
        logits, logits_indexes, logits_mask = pattern.revert_pattern_logits(
            logits, float('nan'), keep_only_valid_steps=True
        )
        logits = logits.permute(0, 2, 3, 1)  # [B, K, T, card]
        logits_mask = logits_mask[None, :, :].expand(B, -1, -1)  # [K, T] -> [B, K, T]
        return LMOutput(logits, logits_mask)

    def _sample_next_token(self,
                           sequence: torch.Tensor,
                           cfg_conditions: CFGConditions,
                           unconditional_state: State,
                           use_sampling: bool = False,
                           temp: float = 1.0,
                           top_k: int = 0,
                           top_p: float = 0.0,
                           cfg_coef: tp.Optional[float] = None) -> torch.Tensor:
        """Sample next token from the model given a sequence and a set of conditions. The model supports
        multiple sampling strategies (greedy sampling, softmax, top-k, top-p...).

        Args:
            sequence (torch.Tensor): Current sequence of shape [B, K, S]
                with K corresponding to the number of codebooks and S the number of sequence steps.
                S = 1 in streaming mode, except for the first step that contains a bigger prompt.
            cfg_conditions (CFGConditions): Set of conditions. Either a dict of condition tensors, which,
                if CFG is used, should be twice the batch size, being the concatenation of the
                conditions + null conditions; or a tuple (conditions, null conditions) in which case
                classifier free guidance is run with two distinct forward passes.
            unconditional_state (State): Streaming state used for the unconditional pass when running
                two-step classifier free guidance. Updated in place.
            use_sampling (bool): Whether to use a sampling strategy or not.
            temp (float): Sampling temperature.
            top_k (int): K for "top-k" sampling.
            top_p (float): P for "top-p" sampling.
            cfg_coef (float): classifier free guidance coefficient. Defaults to `self.cfg_coef` when None.
        Returns:
            next_token (torch.Tensor): Next token tensor of shape [B, K, 1].
        """
        B = sequence.shape[0]
        cfg_coef = self.cfg_coef if cfg_coef is None else cfg_coef
        model = self if self._fsdp is None else self._fsdp
        # The two-step mode is selected from the kind of conditions received rather than from
        # `self.two_step_cfg`, so that it always agrees with how `generate` built them.
        if isinstance(cfg_conditions, tuple):
            condition_tensors, null_condition_tensors = cfg_conditions
            cond_logits = model(sequence, conditions=[], condition_tensors=condition_tensors)
            # swap the streaming state so that the unconditional pass uses its own cache
            state = self.get_streaming_state()
            self.set_streaming_state(unconditional_state)
            uncond_logits = model(sequence, conditions=[], condition_tensors=null_condition_tensors)
            unconditional_state.update(self.get_streaming_state())
            self.set_streaming_state(state)
            # use the resolved coefficient so that the `cfg_coef` override is honored here too
            logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
        else:
            assert isinstance(cfg_conditions, dict)
            condition_tensors = cfg_conditions
            if condition_tensors:
                # Preparing for CFG, predicting both conditional and unconditional logits.
                sequence = torch.cat([sequence, sequence], dim=0)
            all_logits = model(
                sequence,
                conditions=[], condition_tensors=condition_tensors)
            if condition_tensors:
                cond_logits, uncond_logits = all_logits.split(B, dim=0)  # [B, K, T, card]
                logits = uncond_logits + (cond_logits - uncond_logits) * cfg_coef
            else:
                logits = all_logits

        logits = logits.permute(0, 1, 3, 2)  # [B, K, card, T]
        logits = logits[..., -1]  # [B x K x card]

        # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
        if use_sampling and temp > 0.0:
            probs = torch.softmax(logits / temp, dim=-1)
            if top_p > 0.0:
                next_token = utils.sample_top_p(probs, p=top_p)
            elif top_k > 0:
                next_token = utils.sample_top_k(probs, k=top_k)
            else:
                next_token = utils.multinomial(probs, num_samples=1)
        else:
            next_token = torch.argmax(logits, dim=-1, keepdim=True)

        return next_token

    @torch.no_grad()
    def generate(self,
                 prompt: tp.Optional[torch.Tensor] = None,
                 conditions: tp.List[ConditioningAttributes] = [],
                 num_samples: tp.Optional[int] = None,
                 max_gen_len: int = 256,
                 use_sampling: bool = True,
                 temp: float = 1.0,
                 top_k: int = 250,
                 top_p: float = 0.0,
                 cfg_coef: tp.Optional[float] = None,
                 two_step_cfg: tp.Optional[bool] = None,
                 remove_prompts: bool = False,
                 check: bool = False,
                 callback: tp.Optional[tp.Callable[[int, int], None]] = None) -> torch.Tensor:
        """Generate tokens sampling from the model given a prompt or unconditionally. Generation can
        be perform in a greedy fashion or using sampling with top K and top P strategies.

        Args:
            prompt (Optional[torch.Tensor]): Prompt tokens of shape [B, K, T].
            conditions (list[ConditioningAttributes]): List of conditions, can be empty.
            num_samples (int or None): Number of samples to generate when no prompt and no conditions are given.
            max_gen_len (int): Maximum generation length.
            use_sampling (bool): Whether to use a sampling strategy or not.
            temp (float): Sampling temperature.
            top_k (int): K for "top-k" sampling.
            top_p (float): P for "top-p" sampling.
            cfg_coef (Optional[float]): Classifier free guidance coefficient, defaults to the model one.
            two_step_cfg (Optional[bool]): Whether to run classifier free guidance with two distinct
                forward passes. Defaults to the model setting when None.
            remove_prompts (bool): Whether to remove prompts from generation or not.
            check (bool): Whether to run extra consistency checks at every generation step.
            callback (Optional[Callable[[int, int], None]]): Called after every step with the number of
                steps done and the total number of steps.
        Returns:
            torch.Tensor: Generated tokens.
        """
        assert not self.training, "generation shouldn't be used in training mode."
        first_param = next(iter(self.parameters()))
        device = first_param.device

        # Checking all input shapes are consistent.
        possible_num_samples = []
        if num_samples is not None:
            possible_num_samples.append(num_samples)
        elif prompt is not None:
            possible_num_samples.append(prompt.shape[0])
        elif conditions:
            possible_num_samples.append(len(conditions))
        else:
            possible_num_samples.append(1)
        # `all` is required here: asserting on the list itself would always succeed
        assert all(x == possible_num_samples[0] for x in possible_num_samples), "Inconsitent inputs shapes"
        num_samples = possible_num_samples[0]

        # below we create set of conditions: one conditional and one unconditional
        # to do that we merge the regular condition together with the null condition
        # we then do 1 forward pass instead of 2.
        # the reason for that is two-fold:
        # 1. it is about x2 faster than doing 2 forward passes
        # 2. avoid the streaming API treating the 2 passes as part of different time steps
        # We also support doing two different passes, in particular to ensure that
        # the padding structure is exactly the same between train and test.
        # With a batch size of 1, this can be slower though.
        cfg_conditions: CFGConditions
        two_step_cfg = self.two_step_cfg if two_step_cfg is None else two_step_cfg
        if conditions:
            null_conditions = ClassifierFreeGuidanceDropout(p=1.0)(conditions)
            if two_step_cfg:
                cfg_conditions = (
                    self.condition_provider(self.condition_provider.tokenize(conditions)),
                    self.condition_provider(self.condition_provider.tokenize(null_conditions)),
                )
            else:
                conditions = conditions + null_conditions
                tokenized = self.condition_provider.tokenize(conditions)
                cfg_conditions = self.condition_provider(tokenized)
        else:
            cfg_conditions = {}

        if prompt is None:
            assert num_samples > 0
            prompt = torch.zeros((num_samples, self.num_codebooks, 0), dtype=torch.long, device=device)

        B, K, T = prompt.shape
        start_offset = T
        assert start_offset < max_gen_len

        pattern = self.pattern_provider.get_pattern(max_gen_len)
        # this token is used as default value for codes that are not generated yet
        unknown_token = -1

        # we generate codes up to the max_gen_len that will be mapped to the pattern sequence
        gen_codes = torch.full((B, K, max_gen_len), unknown_token, dtype=torch.long, device=device)
        # filling the gen_codes with the prompt if needed
        gen_codes[..., :start_offset] = prompt
        # create the gen_sequence with proper interleaving from the pattern: [B, K, S]
        gen_sequence, indexes, mask = pattern.build_pattern_sequence(gen_codes, self.special_token_id)
        # retrieve the start_offset in the sequence:
        # it is the first sequence step that contains the `start_offset` timestep
        start_offset_sequence = pattern.get_first_step_with_timesteps(start_offset)
        assert start_offset_sequence is not None

        with self.streaming():
            unconditional_state = self.get_streaming_state()
            prev_offset = 0
            gen_sequence_len = gen_sequence.shape[-1]  # gen_sequence shape is [B, K, S]
            for offset in range(start_offset_sequence, gen_sequence_len):
                # get current sequence (note that the streaming API is providing the caching over previous offsets)
                curr_sequence = gen_sequence[..., prev_offset:offset]
                curr_mask = mask[None, ..., prev_offset:offset].expand(B, -1, -1)
                if check:
                    # check coherence between mask and sequence
                    assert (curr_sequence == torch.where(curr_mask, curr_sequence, self.special_token_id)).all()
                    # should never happen as gen_sequence is filled progressively
                    assert not (curr_sequence == unknown_token).any()
                # sample next token from the model, next token shape is [B, K, 1]
                next_token = self._sample_next_token(
                    curr_sequence, cfg_conditions, unconditional_state, use_sampling, temp, top_k, top_p,
                    cfg_coef=cfg_coef)
                # ensure the tokens that should be masked are properly set to special_token_id
                # as the model never output special_token_id
                valid_mask = mask[..., offset:offset+1].expand(B, -1, -1)
                next_token[~valid_mask] = self.special_token_id
                # ensure we don't overwrite prompt tokens, we only write over unknown tokens
                # (then mask tokens should be left as is as well, which is correct)
                gen_sequence[..., offset:offset+1] = torch.where(
                    gen_sequence[..., offset:offset+1] == unknown_token,
                    next_token, gen_sequence[..., offset:offset+1]
                )
                prev_offset = offset
                if callback is not None:
                    callback(1 + offset - start_offset_sequence, gen_sequence_len - start_offset_sequence)
        unconditional_state.clear()

        # ensure sequence has been entirely filled
        assert not (gen_sequence == unknown_token).any()
        # ensure gen_sequence pattern and mask are matching
        # which means the gen_sequence is valid according to the pattern
        assert (
            gen_sequence == torch.where(mask[None, ...].expand(B, -1, -1), gen_sequence, self.special_token_id)
        ).all()
        # get back the codes, trimming the prompt if needed and cutting potentially incomplete timesteps
        out_codes, out_indexes, out_mask = pattern.revert_pattern_sequence(gen_sequence, special_token=unknown_token)

        # sanity checks over the returned codes and corresponding masks
        assert (out_codes[..., :max_gen_len] != unknown_token).all()
        assert (out_mask[..., :max_gen_len] == 1).all()

        out_start_offset = start_offset if remove_prompts else 0
        out_codes = out_codes[..., out_start_offset:max_gen_len]

        # ensure the returned codes are all valid
        assert (out_codes >= 0).all() and (out_codes <= self.card).all()
        return out_codes

View File

@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Utility functions to load from the checkpoints.
Each checkpoint is a torch.saved dict with the following keys:
- 'xp.cfg': the hydra config as dumped during training. This should be used
to rebuild the object using the audiocraft.models.builders functions,
- 'best_state': a readily loadable best state for the model, including
the conditioner. The model obtained from `xp.cfg` should be compatible
with this state dict. In the case of a LM, the encodec model would not be
bundled along but instead provided separately.
Those functions also support loading from a remote location with the Torch Hub API.
They also support overriding some parameters, in particular the device and dtype
of the returned model.
"""
from pathlib import Path
from huggingface_hub import hf_hub_download
import typing as tp
import os
from omegaconf import OmegaConf
import torch
from . import builders
# Short model names accepted by the loaders, mapped to the Hugging Face Hub
# repository id that `_get_state_dict` passes to `hf_hub_download` as `repo_id`.
HF_MODEL_CHECKPOINTS_MAP = {
"small": "facebook/musicgen-small",
"medium": "facebook/musicgen-medium",
"large": "facebook/musicgen-large",
"melody": "facebook/musicgen-melody",
}
def _get_state_dict(
    file_or_url_or_id: tp.Union[Path, str],
    filename: tp.Optional[str] = None,
    device='cpu',
    cache_dir: tp.Optional[str] = None,
):
    """Load a checkpoint from a local file, an https URL or a known model name.

    Args:
        file_or_url_or_id (Path or str): Path to an existing file, an 'https://' URL,
            or one of the keys of `HF_MODEL_CHECKPOINTS_MAP`.
        filename (Optional[str]): File to fetch from the Hugging Face repository. Required
            when a model name is given, ignored otherwise.
        device: Passed as `map_location` when loading the checkpoint.
        cache_dir (Optional[str]): Cache directory used for Hugging Face downloads.
    Returns:
        The object stored in the checkpoint, as returned by the torch loading function.
    Raises:
        ValueError: If the argument is neither an existing file, an https URL nor a known name.
    """
    # NOTE(review): torch.load relies on pickle; only point this at trusted checkpoints.
    file_or_url_or_id = str(file_or_url_or_id)

    if os.path.isfile(file_or_url_or_id):
        return torch.load(file_or_url_or_id, map_location=device)

    if file_or_url_or_id.startswith('https://'):
        return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)

    if file_or_url_or_id in HF_MODEL_CHECKPOINTS_MAP:
        assert filename is not None, "filename needs to be defined if using HF checkpoints"
        local_path = hf_hub_download(
            repo_id=HF_MODEL_CHECKPOINTS_MAP[file_or_url_or_id],
            filename=filename,
            cache_dir=cache_dir,
        )
        return torch.load(local_path, map_location=device)

    raise ValueError(f"{file_or_url_or_id} is not a valid name, path or link that can be loaded.")
def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
    """Build the compression model described by a checkpoint and load its weights.

    Args:
        file_or_url_or_id (Path or str): Checkpoint location, see `_get_state_dict`.
        device: Device written to the config before building the model.
        cache_dir (Optional[str]): Cache directory forwarded to `_get_state_dict`.
    Returns:
        The compression model, in eval mode.
    """
    checkpoint = _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir)
    config = OmegaConf.create(checkpoint['xp.cfg'])
    config.device = str(device)
    compression_model = builders.get_compression_model(config)
    compression_model.load_state_dict(checkpoint['best_state'])
    compression_model.eval()
    return compression_model
def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
    """Build the language model described by a checkpoint and load its weights.

    The config dtype is set to 'float32' on CPU and to 'float16' on any other device.

    Args:
        file_or_url_or_id (Path or str): Checkpoint location, see `_get_state_dict`.
        device: Device written to the config before building the model.
        cache_dir (Optional[str]): Cache directory forwarded to `_get_state_dict`.
    Returns:
        The language model, in eval mode, with the config attached as `model.cfg`.
    """
    checkpoint = _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir)
    config = OmegaConf.create(checkpoint['xp.cfg'])
    config.device = str(device)
    config.dtype = 'float32' if config.device == 'cpu' else 'float16'
    lm = builders.get_lm_model(config)
    lm.load_state_dict(checkpoint['best_state'])
    lm.eval()
    lm.cfg = config
    return lm

View File

@ -0,0 +1,361 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Main model for using MusicGen. This will combine all the required components
and provide easy access to the generation API.
"""
import os
import typing as tp
import torch
from .encodec import CompressionModel
from .lm import LMModel
from .builders import get_debug_compression_model, get_debug_lm_model
from .loaders import load_compression_model, load_lm_model, HF_MODEL_CHECKPOINTS_MAP
from ..data.audio_utils import convert_audio
from ..modules.conditioners import ConditioningAttributes, WavCondition
from ..utils.autocast import TorchAutocast
MelodyList = tp.List[tp.Optional[torch.Tensor]]
MelodyType = tp.Union[torch.Tensor, MelodyList]
class MusicGen:
"""MusicGen main model with convenient generation API.
Args:
name (str): name of the model.
compression_model (CompressionModel): Compression model
used to map audio to invertible discrete representations.
lm (LMModel): Language model over discrete representations.
"""
def __init__(self, name: str, compression_model: CompressionModel, lm: LMModel,
             max_duration: float = 30):
    """Store the models and set up default generation parameters and autocast.

    Args:
        name (str): name of the model.
        compression_model (CompressionModel): Compression model.
        lm (LMModel): Language model; its first parameter defines `self.device`.
        max_duration (float): Upper bound checked against `extend_stride` in
            `set_generation_params`.
    """
    self.name = name
    self.compression_model = compression_model
    self.lm = lm
    self.max_duration = max_duration
    first_param = next(iter(lm.parameters()))
    self.device = first_param.device
    self.generation_params: dict = {}
    # 15 seconds by default
    self.set_generation_params(duration=15)
    self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
    if self.device.type != 'cpu':
        # float16 autocast on every non-CPU device
        self.autocast = TorchAutocast(
            enabled=True, device_type=self.device.type, dtype=torch.float16)
    else:
        self.autocast = TorchAutocast(enabled=False)
@property
def frame_rate(self) -> int:
    """Frame rate of the compression model, roughly the number of AR steps per second."""
    codec = self.compression_model
    return codec.frame_rate
@property
def sample_rate(self) -> int:
    """Sample rate of the generated audio, as reported by the compression model."""
    codec = self.compression_model
    return codec.sample_rate
@property
def audio_channels(self) -> int:
    """Number of channels of the generated audio, as reported by the compression model."""
    codec = self.compression_model
    return codec.channels
@staticmethod
def get_pretrained(name: str = 'melody', device=None):
    """Return a pretrained MusicGen model.

    Available names:
    - small (300M), text to music, # see: https://huggingface.co/facebook/musicgen-small
    - medium (1.5B), text to music, # see: https://huggingface.co/facebook/musicgen-medium
    - melody (1.5B) text to music and text+melody to music, # see: https://huggingface.co/facebook/musicgen-melody
    - large (3.3B), text to music, # see: https://huggingface.co/facebook/musicgen-large

    The special name 'debug' builds small debug models and is used only for unit tests.
    When `device` is None, 'cuda' is picked if at least one CUDA device is found, else 'cpu'.
    The environment variable MUSICGEN_ROOT, when set, is used as the download cache directory.
    """
    if device is None:
        device = 'cuda' if torch.cuda.device_count() else 'cpu'

    if name == 'debug':
        # used only for unit tests
        debug_compression_model = get_debug_compression_model(device)
        debug_lm = get_debug_lm_model(device)
        return MusicGen(name, debug_compression_model, debug_lm)

    if name not in HF_MODEL_CHECKPOINTS_MAP:
        raise ValueError(
            f"{name} is not a valid checkpoint name. "
            f"Choose one of {', '.join(HF_MODEL_CHECKPOINTS_MAP.keys())}"
        )

    cache_dir = os.environ.get('MUSICGEN_ROOT', None)
    compression_model = load_compression_model(name, device=device, cache_dir=cache_dir)
    lm = load_lm_model(name, device=device, cache_dir=cache_dir)
    if name == 'melody':
        lm.condition_provider.conditioners['self_wav'].match_len_on_eval = True

    return MusicGen(name, compression_model, lm)
def set_generation_params(self, use_sampling: bool = True, top_k: int = 250,
                          top_p: float = 0.0, temperature: float = 1.0,
                          duration: float = 30.0, cfg_coef: float = 3.0,
                          two_step_cfg: bool = False, extend_stride: float = 18):
    """Set the generation parameters for MusicGen.

    Args:
        use_sampling (bool, optional): Use sampling if True, else do argmax decoding. Defaults to True.
        top_k (int, optional): top_k used for sampling. Defaults to 250.
        top_p (float, optional): top_p used for sampling, when set to 0 top_k is used. Defaults to 0.0.
        temperature (float, optional): Softmax temperature parameter. Defaults to 1.0.
        duration (float, optional): Duration of the generated waveform. Defaults to 30.0.
        cfg_coef (float, optional): Coefficient used for classifier free guidance. Defaults to 3.0.
        two_step_cfg (bool, optional): If True, performs 2 forward for Classifier Free Guidance,
            instead of batching together the two. This has some impact on how things
            are padded but seems to have little impact in practice.
        extend_stride: when doing extended generation (i.e. more than `max_duration`), by how much
            should we extend the audio each time. Larger values will mean less context is
            preserved, and shorter value will require extra computations. Must be strictly
            smaller than `max_duration`.
    """
    assert extend_stride < self.max_duration, "Cannot stride by more than max generation duration."
    # keyword arguments later forwarded to the language model; `temperature` is stored as 'temp'
    lm_params = {
        'use_sampling': use_sampling,
        'temp': temperature,
        'top_k': top_k,
        'top_p': top_p,
        'cfg_coef': cfg_coef,
        'two_step_cfg': two_step_cfg,
    }
    self.extend_stride = extend_stride
    self.duration = duration
    self.generation_params = lm_params
def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
    """Install a custom progress callback, or remove it when called with None.

    Args:
        progress_callback: Callable taking two ints, stored as-is on the instance.
    """
    self._progress_callback = progress_callback
def generate_unconditional(self, num_samples: int, progress: bool = False) -> torch.Tensor:
    """Generate samples without any text, melody or audio prompt.

    Args:
        num_samples (int): Number of samples to be generated.
        progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
    """
    # one empty (None) description per requested sample
    empty_descriptions: tp.List[tp.Optional[str]] = [None] * num_samples
    prepared = self._prepare_tokens_and_attributes(empty_descriptions, None)
    attributes, prompt_tokens = prepared
    return self._generate_tokens(attributes, prompt_tokens, progress)
def generate(self, descriptions: tp.List[str], progress: bool = False) -> torch.Tensor:
    """Generate samples conditioned on text only.

    Args:
        descriptions (tp.List[str]): A list of strings used as text conditioning.
        progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
    """
    prepared = self._prepare_tokens_and_attributes(descriptions, None)
    attributes, prompt_tokens = prepared
    # no audio prompt was given, so no prompt tokens are expected back
    assert prompt_tokens is None
    return self._generate_tokens(attributes, prompt_tokens, progress)
def generate_with_chroma(self, descriptions: tp.List[str], melody_wavs: MelodyType,
                         melody_sample_rate: int, progress: bool = False) -> torch.Tensor:
    """Generate samples conditioned on text and melody.

    Args:
        descriptions (tp.List[str]): A list of strings used as text conditioning.
        melody_wavs: (torch.Tensor or list of Tensor): A batch of waveforms used as
            melody conditioning. Should have shape [B, C, T] with B matching the description length,
            C=1 or 2. It can be [C, T] if there is a single description. It can also be
            a list of [C, T] tensors, where entries may be None.
        melody_sample_rate: (int): Sample rate of the melody waveforms.
        progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
    Raises:
        ValueError: If a tensor is given that is neither 2D nor 3D.
    """
    if not isinstance(melody_wavs, torch.Tensor):
        for melody in melody_wavs:
            if melody is not None:
                assert melody.dim() == 2, "One melody in the list has the wrong number of dims."
    else:
        if melody_wavs.dim() == 2:
            # single [C, T] waveform: add the batch dimension
            melody_wavs = melody_wavs[None]
        if melody_wavs.dim() != 3:
            raise ValueError("Melody wavs should have a shape [B, C, T].")
        melody_wavs = list(melody_wavs)

    # bring every provided melody to the model sample rate and channel count
    converted = []
    for wav in melody_wavs:
        if wav is None:
            converted.append(None)
        else:
            converted.append(convert_audio(wav, melody_sample_rate, self.sample_rate, self.audio_channels))
    melody_wavs = converted

    attributes, prompt_tokens = self._prepare_tokens_and_attributes(
        descriptions=descriptions, prompt=None, melody_wavs=melody_wavs)
    assert prompt_tokens is None
    return self._generate_tokens(attributes, prompt_tokens, progress)
def generate_continuation(self, prompt: torch.Tensor, prompt_sample_rate: int,
                          descriptions: tp.Optional[tp.List[tp.Optional[str]]] = None,
                          progress: bool = False) -> torch.Tensor:
    """Generate samples that continue the given audio prompts.

    Args:
        prompt (torch.Tensor): A batch of waveforms used for continuation.
            Prompt should be [B, C, T], or [C, T] if only one sample is generated.
        prompt_sample_rate (int): Sampling rate of the given audio waveforms.
        descriptions (tp.List[str], optional): A list of strings used as text conditioning. Defaults to None.
        progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
    Raises:
        ValueError: If the prompt is neither 2D nor 3D.
    """
    # a single [C, T] waveform gets a batch dimension
    batched = prompt[None] if prompt.dim() == 2 else prompt
    if batched.dim() != 3:
        raise ValueError("prompt should have 3 dimensions: [B, C, T] (C = 1).")
    batched = convert_audio(batched, prompt_sample_rate, self.sample_rate, self.audio_channels)
    if descriptions is None:
        # no text conditioning: one empty description per prompt
        descriptions = [None] * len(batched)
    attributes, prompt_tokens = self._prepare_tokens_and_attributes(descriptions, batched)
    assert prompt_tokens is not None
    return self._generate_tokens(attributes, prompt_tokens, progress)
@torch.no_grad()
def _prepare_tokens_and_attributes(
        self,
        descriptions: tp.Sequence[tp.Optional[str]],
        prompt: tp.Optional[torch.Tensor],
        melody_wavs: tp.Optional[MelodyList] = None,
) -> tp.Tuple[tp.List[ConditioningAttributes], tp.Optional[torch.Tensor]]:
    """Prepare model inputs.

    Args:
        descriptions (tp.List[str]): A list of strings used as text conditioning.
        prompt (torch.Tensor): A batch of waveforms used for continuation.
        melody_wavs (tp.Optional[MelodyList], optional): One waveform (or None) per description,
            used as melody conditioning. Defaults to None.
    Returns:
        The list of conditioning attributes, one per description, and the prompt tokens
        produced by the compression model (None when no prompt is given).
    Raises:
        RuntimeError: If melody waveforms are given but the model name is not "melody".
    """
    def null_wav():
        # placeholder condition used whenever no melody is available for an item
        return WavCondition(
            torch.zeros((1, 1), device=self.device),
            torch.tensor([0], device=self.device),
            path='null_wav')  # type: ignore

    attributes = [
        ConditioningAttributes(text={'description': description})
        for description in descriptions]

    if melody_wavs is not None:
        if self.name != "melody":
            raise RuntimeError("This model doesn't support melody conditioning. "
                               "Use the `melody` model.")
        assert len(melody_wavs) == len(descriptions), \
            f"number of melody wavs must match number of descriptions! " \
            f"got melody len={len(melody_wavs)}, and descriptions len={len(descriptions)}"
        for attr, melody in zip(attributes, melody_wavs):
            if melody is not None:
                attr.wav['self_wav'] = WavCondition(
                    melody.to(device=self.device),
                    torch.tensor([melody.shape[-1]], device=self.device))
            else:
                attr.wav['self_wav'] = null_wav()
    else:
        for attr in attributes:
            attr.wav['self_wav'] = null_wav()

    if prompt is None:
        return attributes, None

    if descriptions is not None:
        assert len(descriptions) == len(prompt), "Prompt and nb. descriptions doesn't match"
    prompt = prompt.to(self.device)
    prompt_tokens, scale = self.compression_model.encode(prompt)
    assert scale is None
    return attributes, prompt_tokens
def _generate_tokens(self, attributes: tp.List[ConditioningAttributes],
                     prompt_tokens: tp.Optional[torch.Tensor], progress: bool = False) -> torch.Tensor:
    """Generate discrete audio tokens given audio prompt and/or conditions.

    When the requested duration fits in ``self.max_duration`` a single call to the
    language model is made. Otherwise the audio is generated chunk by chunk: each
    chunk is prompted with the end of the previous one and the generation window
    advances by ``self.extend_stride`` seconds at every iteration.

    Args:
        attributes (tp.List[ConditioningAttributes]): Conditions used for generation (text/melody).
        prompt_tokens (tp.Optional[torch.Tensor]): Audio prompt used for continuation.
        progress (bool, optional): Flag to display progress of the generation process. Defaults to False.
    Returns:
        torch.Tensor: Generated audio, of shape [B, C, T], T is defined by the generation params.
    """
    total_gen_len = int(self.duration * self.frame_rate)
    max_prompt_len = int(min(self.duration, self.max_duration) * self.frame_rate)
    current_gen_offset: int = 0

    def _progress_callback(generated_tokens: int, tokens_to_generate: int):
        # `current_gen_offset` is looked up in the enclosing scope when the callback
        # runs, so the count includes the tokens of the chunks already generated.
        generated_tokens += current_gen_offset
        if self._progress_callback is not None:
            # Note that total_gen_len might be quite wrong depending on the
            # codebook pattern used, but with delay it is almost accurate.
            self._progress_callback(generated_tokens, total_gen_len)
        else:
            print(f'{generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')

    if prompt_tokens is not None:
        assert max_prompt_len >= prompt_tokens.shape[-1], \
            "Prompt is longer than audio to generate"

    callback = None
    if progress:
        callback = _progress_callback

    if self.duration <= self.max_duration:
        # generate by sampling from LM, simple case.
        with self.autocast:
            gen_tokens = self.lm.generate(
                prompt_tokens, attributes,
                callback=callback, max_gen_len=total_gen_len, **self.generation_params)
    else:
        # now this gets a bit messier, we need to handle prompts,
        # melody conditioning etc.
        # Keep the original wav conditions: `attributes` is overwritten for each chunk.
        ref_wavs = [attr.wav['self_wav'] for attr in attributes]
        all_tokens = []
        if prompt_tokens is None:
            prompt_length = 0
        else:
            all_tokens.append(prompt_tokens)
            prompt_length = prompt_tokens.shape[-1]

        stride_tokens = int(self.frame_rate * self.extend_stride)

        while current_gen_offset + prompt_length < total_gen_len:
            time_offset = current_gen_offset / self.frame_rate
            chunk_duration = min(self.duration - time_offset, self.max_duration)
            max_gen_len = int(chunk_duration * self.frame_rate)
            for attr, ref_wav in zip(attributes, ref_wavs):
                wav_length = ref_wav.length.item()
                if wav_length == 0:
                    # null condition, nothing to window
                    continue
                # We will extend the wav periodically if it not long enough.
                # we have to do it here rather than in conditioners.py as otherwise
                # we wouldn't have the full wav.
                initial_position = int(time_offset * self.sample_rate)
                wav_target_length = int(self.max_duration * self.sample_rate)
                positions = torch.arange(initial_position,
                                         initial_position + wav_target_length, device=self.device)
                # The modulo wraps the window around the end of the reference wav.
                attr.wav['self_wav'] = WavCondition(
                    ref_wav[0][:, positions % wav_length],
                    torch.full_like(ref_wav[1], wav_target_length))
            with self.autocast:
                gen_tokens = self.lm.generate(
                    prompt_tokens, attributes,
                    callback=callback, max_gen_len=max_gen_len, **self.generation_params)
            if prompt_tokens is None:
                all_tokens.append(gen_tokens)
            else:
                # Only keep what comes after the prompt, which is already stored.
                all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
            # Everything past the stride becomes the prompt of the next chunk.
            prompt_tokens = gen_tokens[:, :, stride_tokens:]
            prompt_length = prompt_tokens.shape[-1]
            current_gen_offset += stride_tokens

        gen_tokens = torch.cat(all_tokens, dim=-1)

    # generate audio
    assert gen_tokens.dim() == 3
    with torch.no_grad():
        gen_audio = self.compression_model.decode(gen_tokens, None)
    return gen_audio

View File

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# flake8: noqa
from .conv import (
NormConv1d,
NormConv2d,
NormConvTranspose1d,
NormConvTranspose2d,
StreamableConv1d,
StreamableConvTranspose1d,
pad_for_conv1d,
pad1d,
unpad1d,
)
from .lstm import StreamableLSTM
from .seanet import SEANetEncoder, SEANetDecoder

View File

@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from torch import Tensor
from typing import Union, Callable
class CustomGLU(nn.Module):
    r"""Custom Gated Linear Unit activation.

    Applies a modified gated linear unit :math:`a * f(b)` where :math:`a` is the first half
    of the input matrices, :math:`b` is the second half, and :math:`f` is a provided activation
    function (i.e. sigmoid, swish, etc.).

    Args:
        activation (nn.Module): The custom activation to apply in the Gated Linear Unit
        dim (int): the dimension on which to split the input. Default: -1

    Shape:
        - Input: :math:`(\ast_1, N, \ast_2)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(\ast_1, M, \ast_2)` where :math:`M=N/2`

    Examples::
        >>> m = CustomGLU(nn.Sigmoid())
        >>> input = torch.randn(4, 2)
        >>> output = m(input)
    """
    # The docstring above is a raw string: in a regular string literal the `\a`
    # of `\ast` is the BEL escape and would put control characters in `__doc__`.

    def __init__(self, activation: nn.Module, dim: int = -1):
        super().__init__()
        self.dim = dim
        self.activation = activation

    def forward(self, x: Tensor):
        """Split `x` in two halves along `self.dim` and return `a * activation(b)`."""
        # The split dimension must be even so both halves have the same size (M = N / 2).
        assert x.shape[self.dim] % 2 == 0, "the size of the split dimension must be even"
        a, b = torch.chunk(x, 2, dim=self.dim)
        return a * self.activation(b)
class SwiGLU(CustomGLU):
    """Gated Linear Unit gated by SiLU.

    Computes :math:`a * SiLU(b)` where :math:`a` and :math:`b` are respectively
    the first and second half of the input along ``dim``.

    Args:
        dim (int): the dimension on which to split the input. Default: -1
    """

    def __init__(self, dim: int = -1):
        super().__init__(activation=nn.SiLU(), dim=dim)
class GeGLU(CustomGLU):
    """Gated Linear Unit gated by GELU.

    Computes :math:`a * GELU(b)` where :math:`a` and :math:`b` are respectively
    the first and second half of the input along ``dim``.

    Args:
        dim (int): the dimension on which to split the input. Default: -1
    """

    def __init__(self, dim: int = -1):
        super().__init__(activation=nn.GELU(), dim=dim)
class ReGLU(CustomGLU):
    """Gated Linear Unit gated by ReLU.

    Computes :math:`a * ReLU(b)` where :math:`a` and :math:`b` are respectively
    the first and second half of the input along ``dim``.

    Args:
        dim (int): the dimension on which to split the input. Default: -1
    """

    def __init__(self, dim: int = -1):
        super().__init__(activation=nn.ReLU(), dim=dim)
def get_activation_fn(
    activation: Union[str, Callable[[Tensor], Tensor]]
) -> Union[str, Callable[[Tensor], Tensor]]:
    """Resolve the name of a gated activation to an instance of its module.

    The names "reglu", "geglu" and "swiglu" are mapped to a new instance of the
    matching class. Anything else (an unrecognized string or a non-string value)
    is handed back untouched.

    Args:
        activation (Union[str, Callable[[Tensor], Tensor]]): Activation to check
    """
    # Non-string values are never looked up.
    if not isinstance(activation, str):
        return activation
    if activation == "reglu":
        return ReGLU()
    if activation == "geglu":
        return GeGLU()
    if activation == "swiglu":
        return SwiGLU()
    # Unknown name: let the caller deal with it.
    return activation

View File

@ -0,0 +1,539 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import namedtuple
from dataclasses import dataclass
from functools import lru_cache
import logging
import typing as tp
from abc import ABC, abstractmethod
import torch
# Coordinate of one codebook entry in the original (non-interleaved) sequence:
# `t` is the timestep and `q` the codebook index.
LayoutCoord = namedtuple('LayoutCoord', ['t', 'q'])  # (timestep, codebook index)
# A layout lists, for each step of the interleaved sequence, the coordinates
# that are placed at that step.
PatternLayout = tp.List[tp.List[LayoutCoord]]  # Sequence of coordinates
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@dataclass
class Pattern:
    """Base implementation of a pattern over a sequence with multiple codebooks.

    The codebook pattern consists in a layout, defining for each sequence step
    the list of coordinates of each codebook timestep in the resulting interleaved sequence.
    The first item of the pattern is always an empty list in order to properly insert a special token
    to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
    and ``timesteps`` the number of timesteps corresponding to the original sequence.

    The pattern provides convenient methods to build and revert interleaved sequences from it:
    ``build_pattern_sequence`` maps a given a dense input tensor of multi-codebook sequence from [B, K, T]
    to the interleaved sequence of shape [B, K, S] applying the pattern, with B being the batch size,
    K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
    for the output sequence. The unfilled positions are replaced with a special token and the built sequence
    is returned along with a mask indicating valid tokens.
    ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
    of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
    to fill and specify invalid positions if needed.
    See the dedicated methods for more details.
    """
    # Pattern layout, for each sequence step, we have a list of coordinates
    # corresponding to the original codebook timestep and position.
    # The first list is always an empty list in order to properly insert
    # a special token to start with.
    layout: PatternLayout
    # Number of timesteps of the original (non-interleaved) sequence.
    timesteps: int
    # Number of codebooks.
    n_q: int

    def __post_init__(self):
        """Validate the layout and set up per-instance caches for the index builders."""
        assert len(self.layout) > 0
        assert self.layout[0] == []
        self._validate_layout()
        # The bound methods are wrapped (rather than decorating the functions), so
        # each instance owns its caches and `self` is not part of the cache key.
        self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
        self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
        logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))

    def _validate_layout(self):
        """Runs checks on the layout to ensure a valid pattern is defined.

        A pattern is considered invalid if:
            - Multiple timesteps for a same codebook are defined in the same sequence step
            - The timesteps for a given codebook are not in ascending order as we advance in the sequence
              (this would mean that we have future timesteps before past timesteps).

        Raises:
            AssertionError: If one of the conditions above is violated.
        """
        # last timestep seen so far for each codebook
        q_timesteps = {q: 0 for q in range(self.n_q)}
        for s, seq_coords in enumerate(self.layout):
            if len(seq_coords) > 0:
                qs = set()
                for coord in seq_coords:
                    qs.add(coord.q)
                    last_q_timestep = q_timesteps[coord.q]
                    assert coord.t >= last_q_timestep, \
                        f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
                    q_timesteps[coord.q] = coord.t
                # each sequence step contains at max 1 coordinate per codebook
                assert len(qs) == len(seq_coords), \
                    f"Multiple entries for a same codebook are found at step {s}"

    @property
    def num_sequence_steps(self):
        """Number of steps in the layout, not counting the initial empty step."""
        return len(self.layout) - 1

    @property
    def max_delay(self):
        """Difference between the largest ``t + 1`` found in the layout and ``timesteps``."""
        max_t_in_seq_coords = 0
        for seq_coords in self.layout[1:]:
            for coords in seq_coords:
                max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
        return max_t_in_seq_coords - self.timesteps

    @property
    def valid_layout(self):
        """Layout truncated to its first ``len(layout) - max_delay`` steps."""
        valid_step = len(self.layout) - self.max_delay
        return self.layout[:valid_step]

    def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
        """Get codebook coordinates in the layout that corresponds to the specified timestep t
        and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
        and the actual codebook coordinates.

        Args:
            t (int): Timestep to look for.
            q (tp.Optional[int]): If given, only coordinates of this codebook are returned.
        Returns:
            list: ``(sequence step, coordinate)`` tuples, in layout order.
        """
        assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
        if q is not None:
            assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
        coords = []
        for s, seq_codes in enumerate(self.layout):
            for code in seq_codes:
                if code.t == t and (q is None or code.q == q):
                    coords.append((s, code))
        return coords

    def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
        """Return the sequence steps holding timestep `t` (optionally restricted to codebook `q`)."""
        return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]

    def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
        """Return the first sequence step holding timestep `t`, or None if there is none."""
        steps_with_timesteps = self.get_steps_with_timestep(t, q)
        return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None

    def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
                                                device: tp.Union[torch.device, str] = 'cpu'):
        """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.

        Args:
            timesteps (int): Maximum number of timesteps steps to consider.
            n_q (int): Number of codebooks, must be equal to the one of the pattern.
            keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
            device (Union[torch.device, str]): Device for created tensors.
        Returns:
            indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
        """
        assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
        assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
        # use the proper layout based on whether we limit ourselves to valid steps only or not,
        # note that using the valid_layout will result in a truncated sequence up to the valid steps
        ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
        # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
        indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
        mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
        # fill indexes with last sequence step value that will correspond to our special token
        # the last value is n_q * timesteps as we have flattened z and append special token as the last token
        # which will correspond to the index: n_q * timesteps
        indexes[:] = n_q * timesteps
        # iterate over the pattern and fill scattered indexes and mask
        for s, sequence_coords in enumerate(ref_layout):
            for coords in sequence_coords:
                # coordinates past the provided timesteps keep pointing to the special token
                if coords.t < timesteps:
                    indexes[coords.q, s] = coords.t + coords.q * timesteps
                    mask[coords.q, s] = 1
        indexes = torch.from_numpy(indexes).to(device)
        mask = torch.from_numpy(mask).to(device)
        return indexes, mask

    def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
        """Build sequence corresponding to the pattern from the input tensor z.
        Positions of the sequence that do not map to a coordinate of the input
        are filled with the special token.

        Args:
            z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
            special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
            keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
                The sequence is truncated to the steps of ``valid_layout`` in that case.
        Returns:
            values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
                the number of steps of the layout that is used.
            indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
        """
        B, K, T = z.shape
        # the device is passed as a string to the cached index builder
        indexes, mask = self._build_pattern_sequence_scatter_indexes(
            T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
        )
        z = z.view(B, -1)
        # we append the special token as the last index of our flattened z tensor
        z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
        values = z[:, indexes.view(-1)]
        values = values.view(B, K, indexes.shape[-1])
        return values, indexes, mask

    def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
                                                 keep_only_valid_steps: bool = False,
                                                 is_model_output: bool = False,
                                                 device: tp.Union[torch.device, str] = 'cpu'):
        """Builds scatter indexes required to retrieve the original multi-codebook sequence
        from interleaving pattern.

        Args:
            sequence_steps (int): Sequence steps.
            n_q (int): Number of codebooks.
            keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
                Steps that are beyond valid steps will be replaced by the special_token in that case.
            is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
            device (Union[torch.device, str]): Device for created tensors.
        Returns:
            torch.Tensor: Indexes for reconstructing the output, of shape [K, T].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
        """
        ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
        # TODO(jade): Do we want to further truncate to only valid timesteps here as well?
        timesteps = self.timesteps
        assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
        assert sequence_steps <= len(ref_layout), \
            f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"

        # ensure we take the appropriate indexes to keep the model output from the first special token as well
        if is_model_output:
            ref_layout = ref_layout[1:]

        # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
        indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
        mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
        # fill indexes with last sequence step value that will correspond to our special token
        indexes[:] = n_q * sequence_steps
        for s, sequence_codes in enumerate(ref_layout):
            if s < sequence_steps:
                for code in sequence_codes:
                    if code.t < timesteps:
                        indexes[code.q, code.t] = s + code.q * sequence_steps
                        mask[code.q, code.t] = 1
        indexes = torch.from_numpy(indexes).to(device)
        mask = torch.from_numpy(mask).to(device)
        return indexes, mask

    def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
        """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
        Coordinates of the original sequence that are not covered by the provided
        sequence are filled with the special token.

        Args:
            s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
            special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
            keep_only_valid_steps (bool): Only use the valid (= fully defined) steps of the pattern.
        Returns:
            values (torch.Tensor): Reverted sequence, of shape [B, K, T] with T the number of
                timesteps of the pattern.
            indexes (torch.Tensor): Indexes used to build the reverted sequence, of shape [K, T].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
        """
        B, K, S = s.shape
        indexes, mask = self._build_reverted_sequence_scatter_indexes(
            S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
        )
        s = s.view(B, -1)
        # we append the special token as the last index of our flattened z tensor
        s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
        values = s[:, indexes.view(-1)]
        values = values.view(B, K, indexes.shape[-1])
        return values, indexes, mask

    def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
        """Revert model logits obtained on a sequence built from the pattern
        back to a tensor matching the original sequence.

        This method is similar to ``revert_pattern_sequence`` with the following specificities:
        1. It is designed to work with the extra cardinality dimension
        2. We return the logits for the first sequence item that matches the special_token and
        which matching target in the original sequence is the first item of the sequence,
        while we skip the last logits as there is no matching target

        Args:
            logits (torch.Tensor): Logits of shape [B, card, K, S].
            special_token (float): Value used to fill the positions without matching logits.
            keep_only_valid_steps (bool): Only use the valid (= fully defined) steps of the pattern.
        Returns:
            values (torch.Tensor): Reverted logits, of shape [B, card, K, T] with T the number of
                timesteps of the pattern.
            indexes (torch.Tensor): Indexes used to build the reverted logits, of shape [K, T].
            mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
        """
        B, card, K, S = logits.shape
        indexes, mask = self._build_reverted_sequence_scatter_indexes(
            S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
        )
        logits = logits.reshape(B, card, -1)
        # we append the special token as the last index of our flattened z tensor
        logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1)  # [B, card, K x S + 1]
        values = logits[:, :, indexes.view(-1)]
        values = values.view(B, card, K, indexes.shape[-1])
        return values, indexes, mask
class CodebooksPatternProvider(ABC):
    """Abstraction around providing pattern for interleaving codebooks.

    The CodebooksPatternProvider abstraction allows to implement various strategies to
    define interleaving pattern of sequences composed of multiple codebooks. For a given
    number of codebooks `n_q`, the pattern provider can generate a specified pattern
    corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
    can be used to construct a new sequence from the original codes respecting the specified
    pattern. The pattern is defined as a list of list of code coordinates, code coordinate
    being a tuple with the original timestep and codebook to build the new sequence.
    Note that all patterns must start with an empty list that is then used to insert a first
    sequence step of special tokens in the newly generated sequence.

    Args:
        n_q (int): number of codebooks.
        cached (bool): if True, patterns for a given length are cached. In general
            that should be true for efficiency reason to avoid synchronization points.
    """

    def __init__(self, n_q: int, cached: bool = True):
        assert n_q > 0
        self.n_q = n_q
        if cached:
            # Wrap the bound method, so the cache belongs to this instance and is
            # keyed by the call arguments only (up to 100 distinct entries).
            self.get_pattern = lru_cache(100)(self.get_pattern)  # type: ignore

    @abstractmethod
    def get_pattern(self, timesteps: int) -> Pattern:
        """Builds pattern with specific interleaving between codebooks.

        Args:
            timesteps (int): Total number of timesteps.
        """
        raise NotImplementedError()
class DelayedPatternProvider(CodebooksPatternProvider):
    """Pattern provider shifting each codebook by its own delay.

    Because of the delays, a given sequence step holds codebooks coming
    from different timesteps.

    Example:
        Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
        [[1, 2, 3, 4],
        [1, 2, 3, 4],
        [1, 2, 3, 4]]
        The resulting sequence obtained from the returned pattern is:
        [[S, 1, 2, 3, 4],
        [S, S, 1, 2, 3],
        [S, S, S, 1, 2]]
        (with S being a special token; only the first ``timesteps + 1`` steps are
        shown, the layout holds ``max(delays)`` more steps after those)

    Args:
        n_q (int): Number of codebooks.
        delays (Optional[List[int]]): Delay for each of the codebooks.
            If delays not defined, each codebook is delayed by 1 compared to the previous one.
        flatten_first (int): Flatten the first N timesteps.
        empty_initial (int): Prepend with N empty list of coordinates.
    """

    def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
                 flatten_first: int = 0, empty_initial: int = 0):
        super().__init__(n_q)
        # default: codebook q is delayed by q steps
        self.delays = list(range(n_q)) if delays is None else delays
        self.flatten_first = flatten_first
        self.empty_initial = empty_initial
        assert len(self.delays) == self.n_q
        assert sorted(self.delays) == self.delays

    def get_pattern(self, timesteps: int) -> Pattern:
        """Build the delayed pattern for a sequence of `timesteps` timesteps."""
        layout: PatternLayout = [[]]
        largest_delay = max(self.delays)
        if self.empty_initial:
            layout.extend([] for _ in range(self.empty_initial))
        if self.flatten_first:
            # the first timesteps are emitted one codebook per sequence step
            for t in range(min(timesteps, self.flatten_first)):
                layout.extend([LayoutCoord(t, q)] for q in range(self.n_q))
        for t in range(self.flatten_first, timesteps + largest_delay):
            # a codebook only shows up once its delay has elapsed
            layout.append([
                LayoutCoord(t - delay, q)
                for q, delay in enumerate(self.delays)
                if t - delay >= self.flatten_first
            ])
        return Pattern(layout, n_q=self.n_q, timesteps=timesteps)
class ParallelPatternProvider(DelayedPatternProvider):
    """Pattern provider keeping all the codebooks in parallel.

    This is the delayed pattern with a delay of 0 for every codebook,
    i.e. delays=repeat(0, n_q).

    Args:
        n_q (int): Number of codebooks.
    """

    def __init__(self, n_q: int):
        no_delays = [0] * n_q
        super().__init__(n_q, no_delays)
class UnrolledPatternProvider(CodebooksPatternProvider):
    """Provider for unrolling codebooks pattern.

    This pattern provider enables to represent the codebook flattened completely or only to some extent
    while also specifying a given delay between the flattened codebooks representation, allowing to
    unroll the codebooks in the sequence.

    NOTE(review): the examples below are kept as originally written; they are
    illustrative and their exact layout should be confirmed against ``get_pattern``.

    Example:
        1. Flattening of the codebooks.
        By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q),
        taking n_q = 3 and timesteps = 4:
        [[1, 2, 3, 4],
        [1, 2, 3, 4],
        [1, 2, 3, 4]]
        will result into:
        [[S, S, 1, S, S, 2, S, S, 3, S, S, 4],
        [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
        [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
        2. Partial flattening of the codebooks. The ``flattening`` parameter allows to specify the inner step
        for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example
        taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]:
        [[1, 2, 3, 4],
        [1, 2, 3, 4],
        [1, 2, 3, 4]]
        will result into:
        [[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
        [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
        [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
        3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks
        allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the
        same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1]
        and delays = [0, 3, 3]:
        [[1, 2, 3, 4],
        [1, 2, 3, 4],
        [1, 2, 3, 4]]
        will result into:
        [[S, S, S, 1, S, 2, S, 3, S, 4],
        [S, S, S, 1, S, 2, S, 3, S, 4],
        [1, 2, 3, S, 4, S, 5, S, 6, S]]

    Args:
        n_q (int): Number of codebooks.
        flattening (Optional[List[int]]): Flattening schema over the codebooks. If not defined,
            the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
            have n_q extra steps for each timestep.
        delays (Optional[List[int]]): Delay for each of the codebooks. If not defined,
            no delay is added and therefore will default to [0] * ``n_q``.
            Note that two codebooks that will be flattened to the same inner step
            should have the same delay, otherwise the pattern is considered as invalid.
    """
    # Codebooks sharing one inner step, together with the delay they share.
    FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])

    def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
                 delays: tp.Optional[tp.List[int]] = None):
        super().__init__(n_q)
        if flattening is None:
            flattening = list(range(n_q))
        if delays is None:
            delays = [0] * n_q
        assert len(flattening) == n_q
        assert len(delays) == n_q
        assert sorted(flattening) == flattening
        assert sorted(delays) == delays
        self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
        self.max_delay = max(delays)

    def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
        """Build a flattened codebooks representation as a dictionary of inner step
        and the actual codebook indices corresponding to the flattened codebook. For convenience, we
        also store the delay associated to the flattened codebook to avoid maintaining an extra mapping.

        Raises:
            AssertionError: If two codebooks flattened to the same inner step have different delays.
        """
        flattened_codebooks: dict = {}
        for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
            if inner_step not in flattened_codebooks:
                flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
            else:
                flat_codebook = flattened_codebooks[inner_step]
                # A single message string: the original passed a 2-tuple by
                # mistake, so a failure printed the repr of a tuple.
                assert flat_codebook.delay == delay, (
                    "Delay and flattening between codebooks is inconsistent: "
                    "two codebooks flattened to the same position should have the same delay."
                )
                flat_codebook.codebooks.append(q)
            flattened_codebooks[inner_step] = flat_codebook
        return flattened_codebooks

    @property
    def _num_inner_steps(self):
        """Number of inner steps to unroll between timesteps in order to flatten the codebooks.
        """
        # iterating a dict yields its keys, i.e. the inner steps
        return max(self._flattened_codebooks) + 1

    def num_virtual_steps(self, timesteps: int) -> int:
        """Return ``timesteps * number of inner steps + 1``."""
        return timesteps * self._num_inner_steps + 1

    def get_pattern(self, timesteps: int) -> Pattern:
        """Builds pattern for delay across codebooks.

        Args:
            timesteps (int): Total number of timesteps.
        """
        # the PatternLayout is built as a tuple of sequence position and list of coordinates
        # so that it can be reordered properly given the required delay between codebooks of given timesteps
        indexed_out: list = [(-1, [])]
        max_timesteps = timesteps + self.max_delay
        # loop invariant, the property recomputes a max() at every access
        num_inner_steps = self._num_inner_steps
        for t in range(max_timesteps):
            # for each timestep, we unroll the flattened codebooks,
            # emitting the sequence step with the corresponding delay
            for step in range(num_inner_steps):
                if step in self._flattened_codebooks:
                    # we have codebooks at this virtual step to emit
                    step_codebooks = self._flattened_codebooks[step]
                    t_for_q = t + step_codebooks.delay
                    coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
                    # `t` itself is always below max_timesteps inside this loop
                    if t_for_q < max_timesteps:
                        indexed_out.append((t_for_q, coords))
                else:
                    # there is no codebook in this virtual step so we emit an empty list
                    indexed_out.append((t, []))
        # tuples are ordered by position first, equal positions are ordered by
        # comparing their coordinate lists
        out = [coords for _, coords in sorted(indexed_out)]
        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
class VALLEPattern(CodebooksPatternProvider):
    """Almost VALL-E style pattern. We further allow some delays for the
    codebooks other than the first one.

    The first codebook is emitted alone for every timestep, then the remaining
    codebooks are emitted together, each one shifted by its own delay.

    Args:
        n_q (int): Number of codebooks.
        delays (Optional[List[int]]): Delay for each of the codebooks other than the
            first one, hence ``n_q - 1`` values. If delays not defined, no delay is
            applied to any of them.
    """

    def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
        super().__init__(n_q)
        if delays is None:
            delays = [0] * (n_q - 1)
        self.delays = delays
        assert len(self.delays) == self.n_q - 1
        assert sorted(self.delays) == self.delays

    def get_pattern(self, timesteps: int) -> Pattern:
        """Build the pattern for a sequence of `timesteps` timesteps."""
        out: PatternLayout = [[]]
        # first stage: the first codebook only
        for t in range(timesteps):
            out.append([LayoutCoord(t, 0)])
        if not self.delays:
            # Single codebook (n_q == 1): there is no second stage, and calling
            # max() on the empty list of delays would raise a ValueError.
            return Pattern(out, n_q=self.n_q, timesteps=timesteps)
        # second stage: the remaining codebooks, in parallel up to their delays
        max_delay = max(self.delays)
        for t in range(timesteps + max_delay):
            v = []
            for q, delay in enumerate(self.delays):
                t_for_q = t - delay
                if t_for_q >= 0:
                    # `delays` starts at the second codebook, hence the `q + 1`
                    v.append(LayoutCoord(t_for_q, q + 1))
            out.append(v)
        return Pattern(out, n_q=self.n_q, timesteps=timesteps)
class MusicLMPattern(CodebooksPatternProvider):
    """Almost MusicLM style pattern. This is equivalent to full flattening
    but in a different order.

    Codebooks are processed by groups of `group_by`: all the timesteps of a group
    are emitted, one codebook per sequence step, before moving to the next group.

    Args:
        n_q (int): Number of codebooks.
        group_by (int): Number of codebooks to group together. The last group is
            smaller when `n_q` is not a multiple of `group_by`.
    """

    def __init__(self, n_q: int, group_by: int = 2):
        super().__init__(n_q)
        self.group_by = group_by

    def get_pattern(self, timesteps: int) -> Pattern:
        """Build the pattern for a sequence of `timesteps` timesteps."""
        out: PatternLayout = [[]]
        for offset in range(0, self.n_q, self.group_by):
            # Clamp the end of the group: without it, codebook indices >= n_q are
            # emitted for the last group when n_q is not a multiple of group_by.
            group_end = min(offset + self.group_by, self.n_q)
            for t in range(timesteps):
                for q in range(offset, group_end):
                    out.append([LayoutCoord(t, q)])
        return Pattern(out, n_q=self.n_q, timesteps=timesteps)

View File

@ -0,0 +1,990 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import defaultdict
from copy import deepcopy
from dataclasses import dataclass, field
from itertools import chain
import logging
import math
import random
import re
import typing as tp
import warnings
from einops import rearrange
from num2words import num2words
import spacy
from transformers import T5EncoderModel, T5Tokenizer # type: ignore
import torchaudio
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from .streaming import StreamingModule
from .transformer import create_sin_embedding
from ..data.audio_dataset import SegmentInfo
from ..utils.autocast import TorchAutocast
from ..utils.utils import hash_trick, length_to_mask, collate
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Type aliases used by the conditioning code in this module.
TextCondition = tp.Optional[str]  # a text condition can be a string or None (if doesn't exist)
ConditionType = tp.Tuple[Tensor, Tensor]  # condition, mask
class WavCondition(tp.NamedTuple):
    """Waveform condition: a wav tensor, its length and the path(s) it comes from."""
    # Waveform tensor.
    wav: Tensor
    # Length tensor; `nullify_wav` uses 0 for a null condition.
    length: Tensor
    # NOTE(review): this default list is created once, at class definition, and is
    # shared by every instance built without `path`; it should never be mutated.
    path: tp.List[tp.Optional[str]] = []
def nullify_condition(condition: ConditionType, dim: int = 1):
    """This function transforms an input condition to a null condition.
    The way it is done by converting it to a single zero vector similarly
    to how it is done inside WhiteSpaceTokenizer and NoopTokenizer.

    Args:
        condition (ConditionType): a tuple of condition and mask (tp.Tuple[Tensor, Tensor])
        dim (int): the dimension that will be truncated (should be the time dimension)
        WARNING!: dim should not be the batch dimension!
    Returns:
        ConditionType: a tuple of null condition and mask
    Raises:
        AssertionError: If `dim` is 0 or if `condition` is not a tuple of two tensors.
    """
    assert dim != 0, "dim cannot be the batch dimension!"
    # isinstance (rather than an exact type comparison) also accepts subclasses,
    # e.g. named tuples and nn.Parameter.
    assert isinstance(condition, tuple) and \
        isinstance(condition[0], Tensor) and \
        isinstance(condition[1], Tensor), "'nullify_condition' got an unexpected input type!"
    cond, mask = condition
    B = cond.shape[0]
    last_dim = cond.dim() - 1
    # Move `dim` to the last position, keep a single zeroed entry along it,
    # then move it back to where it was.
    out = cond.transpose(dim, last_dim)
    out = 0. * out[..., :1]
    out = out.transpose(dim, last_dim)
    # The returned mask is all zeros, with a single step.
    mask = torch.zeros((B, 1), device=out.device).int()
    assert cond.dim() == out.dim()
    return out, mask
def nullify_wav(wav: Tensor) -> WavCondition:
    """Build a nullified WavCondition matching the batch size of the given wav tensor.

    Args:
        wav (Tensor): tensor of shape [B, T]
    Returns:
        WavCondition: wav condition whose wav is nullified along its last dimension,
            with zero lengths and 'null_wav' as path for every batch item.
    """
    batch_size = wav.shape[0]
    time_dim = wav.dim() - 1
    # the mask passed here is only a placeholder, the returned mask is discarded
    silenced, _ = nullify_condition((wav, torch.zeros_like(wav)), dim=time_dim)
    zero_lengths = torch.tensor([0] * batch_size, device=wav.device)
    null_paths = ['null_wav'] * batch_size
    return WavCondition(wav=silenced, length=zero_lengths, path=null_paths)
@dataclass
class ConditioningAttributes:
    """Container for the conditioning attributes of one sample.

    Attributes are grouped by kind: `text` maps an attribute name to an optional
    string and `wav` maps an attribute name to a WavCondition.
    """
    text: tp.Dict[str, tp.Optional[str]] = field(default_factory=dict)
    wav: tp.Dict[str, WavCondition] = field(default_factory=dict)

    def __getitem__(self, item):
        """Return the attribute dict of the given kind, e.g. `attributes["text"]`."""
        return getattr(self, item)

    @property
    def text_attributes(self):
        """Names of the text attributes."""
        return self.text.keys()

    @property
    def wav_attributes(self):
        """Names of the wav attributes."""
        return self.wav.keys()

    @property
    def attributes(self):
        """Attribute names grouped by kind ("text" and "wav")."""
        return {"text": self.text_attributes, "wav": self.wav_attributes}

    def to_flat_dict(self):
        """Flatten to a single dict whose keys are "<kind>.<attribute>"."""
        return {
            **{f"text.{k}": v for k, v in self.text.items()},
            **{f"wav.{k}": v for k, v in self.wav.items()},
        }

    @classmethod
    def from_flat_dict(cls, x):
        """Rebuild an instance from a dict produced by `to_flat_dict`."""
        out = cls()
        for k, v in x.items():
            # Split on the first dot only: the attribute name itself may contain dots.
            kind, att = k.split(".", 1)
            out[kind][att] = v
        return out
class SegmentWithAttributes(SegmentInfo):
    """Common parent of the dataclasses used for conditioning.

    Subclasses are expected to override `to_condition_attributes` in order to expose
    their attributes as a ConditioningAttributes dataclass.
    """

    def to_condition_attributes(self) -> ConditioningAttributes:
        """Convert the attributes of this segment to ConditioningAttributes (abstract)."""
        raise NotImplementedError
class Tokenizer:
    """Common interface of the tokenizers
    (leaves room for more advanced tokenizers to be added later on).
    """

    def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[Tensor, Tensor]:
        """Tokenize a list of optional strings; must be overridden by subclasses."""
        raise NotImplementedError
class WhiteSpaceTokenizer(Tokenizer):
    """Tokenizer meant for natural language descriptions.

    For example:
    ["he didn't, know he's going home.", 'shorter sentence'] =>
    [[78, 62, 31,  4, 78, 25, 19, 34],
    [59, 77,  0,  0,  0,  0,  0,  0]]
    """
    PUNCTUATIONS = "?:!.,;"

    def __init__(self, n_bins: int, pad_idx: int = 0, language: str = "en_core_web_sm",
                 lemma: bool = True, stopwords: bool = True) -> None:
        self.n_bins = n_bins
        self.pad_idx = pad_idx
        self.lemma = lemma
        self.stopwords = stopwords
        try:
            self.nlp = spacy.load(language)
        except IOError:
            # the language model is not available locally: download it and retry
            spacy.cli.download(language)  # type: ignore
            self.nlp = spacy.load(language)

    @tp.no_type_check
    def __call__(
        self,
        texts: tp.List[tp.Optional[str]],
        return_text: bool = False
    ) -> tp.Tuple[Tensor, Tensor]:
        """Convert a list of strings to a padded tensor of word indices.

        Args:
            texts (tp.List[str]): List of strings.
            return_text (bool, optional): Whether to return text as additional tuple item. Defaults to False.
        Returns:
            tp.Tuple[Tensor, Tensor]:
                - Indices of words in the LUT.
                - And a mask indicating where the padding tokens are
        """
        token_rows, lengths = [], []
        texts = deepcopy(texts)  # the normalized text is written back, keep the caller's list intact
        word_attr = "lemma_" if self.lemma else "text"
        for idx, raw in enumerate(texts):
            if raw is None:
                # missing attribute for this sample: a single pad token with length 0
                token_rows.append(Tensor([self.pad_idx]))
                lengths.append(0)
                continue
            # digits are spelled out as words before going through the NLP pipeline
            spelled = re.sub(r"(\d+)", lambda m: num2words(int(m.group(0))), raw)  # type: ignore
            words = self.nlp(spelled)  # type: ignore
            if self.stopwords:
                words = [w for w in words if not w.is_stop]  # type: ignore
            # drop punctuation, then keep either the lemma or the raw text of each word
            words = [getattr(w, word_attr) for w in words if w.text not in self.PUNCTUATIONS]  # type: ignore
            texts[idx] = " ".join(words)
            lengths.append(len(words))
            token_rows.append(Tensor([hash_trick(w, self.n_bins) for w in words]))
        mask = length_to_mask(torch.IntTensor(lengths)).int()
        padded_output = pad_sequence(token_rows, padding_value=self.pad_idx).int().t()
        if not return_text:
            return padded_output, mask
        return padded_output, mask, texts  # type: ignore
class NoopTokenizer(Tokenizer):
    """Tokenizer meant for global conditioners such as: artist, genre, key, etc.

    Unlike WhiteSpaceTokenizer, strings are not split: "Jeff Buckley" gets its own
    index, whereas WhiteSpaceTokenizer would split it to ["Jeff", "Buckley"] and
    return an index per word.

    For example:
    ["Queen", "ABBA", "Jeff Buckley"] => [43, 55, 101]
    ["Metal", "Rock", "Classical"] => [0, 223, 51]
    """

    def __init__(self, n_bins: int, pad_idx: int = 0):
        self.n_bins = n_bins
        self.pad_idx = pad_idx

    def __call__(self, texts: tp.List[tp.Optional[str]]) -> tp.Tuple[Tensor, Tensor]:
        """Hash every string to one index; None entries map to the pad index with length 0."""
        indices: tp.List[int] = []
        lengths: tp.List[int] = []
        for entry in texts:
            missing = entry is None
            indices.append(self.pad_idx if missing else hash_trick(entry, self.n_bins))
            lengths.append(0 if missing else 1)
        tokens = torch.LongTensor(indices).unsqueeze(1)
        mask = length_to_mask(torch.IntTensor(lengths)).int()
        return tokens, mask
class BaseConditioner(nn.Module):
    """Parent class of every conditioner module.

    The output dim may differ from the hidden dim for two reasons:
    1) keep our LUTs small when the vocab is large; 2) make all condition dims consistent.

    Args:
        dim (int): Hidden dim of the model (text-encoder/LUT).
        output_dim (int): Output dim of the conditioner.
    """

    def __init__(self, dim, output_dim):
        super().__init__()
        self.dim, self.output_dim = dim, output_dim
        # linear projection from the hidden dim to the conditioner output dim
        self.output_proj = nn.Linear(dim, output_dim)

    def tokenize(self, *args, **kwargs) -> tp.Any:
        """Run the part of the processing that leads to a synchronization point,
        e.g. BPE tokenization with transfer to the GPU.

        The returned value will be saved and passed back later when calling forward().
        """
        raise NotImplementedError

    def forward(self, inputs: tp.Any) -> ConditionType:
        """Embed the conditioning input (e.g. genre, description or a waveform) as a dense vector.

        Returns:
            ConditionType:
                - A tensor of size [B, T, D] where B is the batch size, T is the length of the
                  output embedding and D is the dimension of the embedding.
                - And a mask indicating where the padding tokens are.
        """
        raise NotImplementedError
class TextConditioner(BaseConditioner):
    """Marker base class shared by the conditioners that take text as input."""
class LUTConditioner(TextConditioner):
    """TextConditioner backed by a lookup table.

    Args:
        n_bins (int): Number of bins.
        dim (int): Hidden dim of the model (text-encoder/LUT).
        output_dim (int): Output dim of the conditioner.
        tokenizer (str): Name of the tokenizer ("whitespace" or "noop").
        pad_idx (int, optional): Index for padding token. Defaults to 0.
    """

    def __init__(self, n_bins: int, dim: int, output_dim: int, tokenizer: str, pad_idx: int = 0):
        super().__init__(dim, output_dim)
        self.embed = nn.Embedding(n_bins, dim)
        self.tokenizer: Tokenizer
        if tokenizer == "noop":
            self.tokenizer = NoopTokenizer(n_bins, pad_idx=pad_idx)
        elif tokenizer == "whitespace":
            self.tokenizer = WhiteSpaceTokenizer(n_bins, pad_idx=pad_idx)
        else:
            raise ValueError(f"unrecognized tokenizer `{tokenizer}`.")

    def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize the texts and move tokens and mask to the device of the embedding table."""
        target = self.embed.weight.device
        tokens, mask = self.tokenizer(x)
        return tokens.to(target), mask.to(target)

    def forward(self, inputs: tp.Tuple[torch.Tensor, torch.Tensor]) -> ConditionType:
        """Embed and project the tokens, zeroing out the masked positions."""
        tokens, mask = inputs
        projected = self.output_proj(self.embed(tokens))
        return projected * mask.unsqueeze(-1), mask
class T5Conditioner(TextConditioner):
    """T5-based TextConditioner.

    Args:
        name (str): Name of the T5 model (one of `MODELS`).
        output_dim (int): Output dim of the conditioner.
        finetune (bool): Whether to fine-tune T5 at train time.
        device (str): Device for T5 Conditioner.
        autocast_dtype (tp.Optional[str], optional): Autocast dtype.
        word_dropout (float, optional): Word dropout probability.
        normalize_text (bool, optional): Whether to apply text normalization.
    """
    MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
              "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
              "google/flan-t5-xl", "google/flan-t5-xxl"]
    # Hidden size of each supported model; every entry of MODELS must have a key here
    # since the hidden size is looked up by model name in __init__.
    MODELS_DIMS = {
        "t5-small": 512,
        "t5-base": 768,
        "t5-large": 1024,
        "t5-3b": 1024,
        "t5-11b": 1024,
        "google/flan-t5-small": 512,
        "google/flan-t5-base": 768,
        "google/flan-t5-large": 1024,
        "google/flan-t5-xl": 2048,
        "google/flan-t5-xxl": 4096,
        # Legacy keys, kept so that existing lookups by these names keep working.
        "google/flan-t5-3b": 1024,
        "google/flan-t5-11b": 1024,
    }

    def __init__(self, name: str, output_dim: int, finetune: bool, device: str,
                 autocast_dtype: tp.Optional[str] = 'float32', word_dropout: float = 0.,
                 normalize_text: bool = False):
        assert name in self.MODELS, f"unrecognized t5 model name (should in {self.MODELS})"
        super().__init__(self.MODELS_DIMS[name], output_dim)
        self.device = device
        self.name = name
        self.finetune = finetune
        self.word_dropout = word_dropout

        if autocast_dtype is None or self.device == 'cpu':
            self.autocast = TorchAutocast(enabled=False)
            if self.device != 'cpu':
                logger.warning("T5 has no autocast, this might lead to NaN")
        else:
            dtype = getattr(torch, autocast_dtype)
            assert isinstance(dtype, torch.dtype)
            logger.info(f"T5 will be evaluated with autocast as {autocast_dtype}")
            self.autocast = TorchAutocast(enabled=True, device_type=self.device, dtype=dtype)
        # Let's disable logging temporarily because T5 will vomit some errors otherwise.
        # thanks https://gist.github.com/simon-weber/7853144
        previous_level = logging.root.manager.disable
        logging.disable(logging.ERROR)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                self.t5_tokenizer = T5Tokenizer.from_pretrained(name)
                t5 = T5EncoderModel.from_pretrained(name).train(mode=finetune)
            finally:
                # always restore the previous logging level, even if loading failed
                logging.disable(previous_level)
        if finetune:
            self.t5 = t5
        else:
            # this makes sure that the t5 models is not part
            # of the saved checkpoint
            self.__dict__["t5"] = t5.to(device)

        self.normalize_text = normalize_text
        if normalize_text:
            self.text_normalizer = WhiteSpaceTokenizer(1, lemma=True, stopwords=True)

    def tokenize(self, x: tp.List[tp.Optional[str]]) -> tp.Dict[str, torch.Tensor]:
        """Tokenize the texts with the T5 tokenizer and move the result to `self.device`.

        The attention mask of empty entries (missing text, or text emptied by
        normalization / word dropout) is zeroed out entirely.
        """
        # if current sample doesn't have a certain attribute, replace with empty string
        entries: tp.List[str] = [xi if xi is not None else "" for xi in x]
        if self.normalize_text:
            _, _, entries = self.text_normalizer(entries, return_text=True)
        if self.word_dropout > 0. and self.training:
            # independently drop each word with probability `word_dropout`
            new_entries = []
            for entry in entries:
                words = [word for word in entry.split(" ") if random.random() >= self.word_dropout]
                new_entries.append(" ".join(words))
            entries = new_entries

        empty_idx = torch.LongTensor([i for i, xi in enumerate(entries) if xi == ""])

        inputs = self.t5_tokenizer(entries, return_tensors="pt", padding=True).to(self.device)
        mask = inputs["attention_mask"]
        mask[empty_idx, :] = 0  # zero-out index where the input is non-existent
        return inputs

    def forward(self, inputs: tp.Dict[str, torch.Tensor]) -> ConditionType:
        """Encode the tokenized texts with T5, project them and apply the attention mask."""
        mask = inputs["attention_mask"]
        # gradients only flow through T5 when fine-tuning
        with torch.set_grad_enabled(self.finetune), self.autocast:
            embeds = self.t5(**inputs).last_hidden_state
        embeds = self.output_proj(embeds.to(self.output_proj.weight))
        embeds = (embeds * mask.unsqueeze(-1))
        return embeds, mask
class WaveformConditioner(BaseConditioner):
    """Base class for all conditioners that take a waveform as input.

    Classes that inherit must implement `_get_wav_embedding` that outputs
    a continuous tensor, and `_downsampling_factor` that returns the down-sampling
    factor of the embedding model.

    Args:
        dim (int): The internal representation dimension.
        output_dim (int): Output dimension.
        device (tp.Union[torch.device, str]): Device.
    """

    def __init__(self, dim: int, output_dim: int, device: tp.Union[torch.device, str]):
        super().__init__(dim, output_dim)
        self.device = device

    def tokenize(self, wav_length: WavCondition) -> WavCondition:
        """Move the waveform and its lengths to `self.device`; the paths are passed through."""
        wav, length, path = wav_length
        assert length is not None
        return WavCondition(wav.to(self.device), length.to(self.device), path)

    def _get_wav_embedding(self, wav: Tensor) -> Tensor:
        """Gets as input a wav and returns a dense vector of conditions."""
        raise NotImplementedError()

    def _downsampling_factor(self):
        """Returns the downsampling factor of the embedding model."""
        raise NotImplementedError()

    def forward(self, inputs: WavCondition) -> ConditionType:
        """
        Args:
            inputs (WavCondition): Tuple of (waveform, lengths, paths).
        Returns:
            ConditionType: Dense vector representing the conditioning along with its mask.
        """
        wav, lengths, _ = inputs
        # the embedding model itself is never trained through this path
        with torch.no_grad():
            embeds = self._get_wav_embedding(wav)
        embeds = embeds.to(self.output_proj.weight)
        embeds = self.output_proj(embeds)

        if lengths is not None:
            # convert lengths from samples to embedding frames
            lengths = lengths / self._downsampling_factor()
            mask = length_to_mask(lengths, max_len=embeds.shape[1]).int()  # type: ignore
        else:
            # No lengths given: every frame is valid. The mask must be [B, T] (one entry per
            # frame), not [B, T, D], so that it broadcasts correctly once unsqueezed below.
            mask = torch.ones_like(embeds[..., 0], dtype=torch.int)
        embeds = (embeds * mask.unsqueeze(2).to(self.device))
        return embeds, mask
class ChromaStemConditioner(WaveformConditioner):
    """Chroma conditioner that uses DEMUCS to first filter out drums and bass. This follows
    the insight that drums and bass often dominate the chroma, leading to the chroma not
    containing the information about melody.

    Args:
        output_dim (int): Output dimension for the conditioner.
        sample_rate (int): Sample rate for the chroma extractor.
        n_chroma (int): Number of chroma for the chroma extractor.
        radix2_exp (int): Radix2 exponent for the chroma extractor.
        duration (float): Duration used during training. This is later used for correct padding
            in case we are using chroma as prefix.
        match_len_on_eval (bool, optional): If True then all chromas are truncated or repeated
            to match the training duration. Defaults to True.
        eval_wavs (str, optional): Path to a json egg with waveforms, these waveforms are used as
            conditions during eval (for cases where we don't want to leak test conditions like MusicCaps).
            Defaults to None. Not used by the code of this class.
        n_eval_wavs (int, optional): Limits the number of waveforms used for conditioning. Defaults to 0.
            Not used by the code of this class.
        device (tp.Union[torch.device, str], optional): Device for the conditioner.
        **kwargs: Additional parameters for the chroma extractor.
    """

    def __init__(self, output_dim: int, sample_rate: int, n_chroma: int, radix2_exp: int,
                 duration: float, match_len_on_eval: bool = True, eval_wavs: tp.Optional[str] = None,
                 n_eval_wavs: int = 0, device: tp.Union[torch.device, str] = "cpu", **kwargs):
        from demucs import pretrained
        super().__init__(dim=n_chroma, output_dim=output_dim, device=device)
        self.autocast = TorchAutocast(enabled=device != "cpu", device_type=self.device, dtype=torch.float32)
        self.sample_rate = sample_rate
        self.match_len_on_eval = match_len_on_eval
        self.duration = duration
        # stored through __dict__ so that demucs is not registered as a submodule
        self.__dict__["demucs"] = pretrained.get_model('htdemucs').to(device)
        self.stem2idx = {'drums': 0, 'bass': 1, 'other': 2, 'vocal': 3}
        # only the vocal and other stems are kept when filtering
        self.stem_idx = torch.LongTensor([self.stem2idx['vocal'], self.stem2idx['other']]).to(device)
        self.chroma = ChromaExtractor(sample_rate=sample_rate, n_chroma=n_chroma, radix2_exp=radix2_exp,
                                      device=device, **kwargs)
        self.chroma_len = self._get_chroma_len()

    def _downsampling_factor(self):
        """Return the hop size of the chroma extractor."""
        return self.chroma.winhop

    def _get_chroma_len(self):
        """Get length of chroma during training"""
        # `duration` is a float, tensor sizes must be integers
        n_samples = int(self.sample_rate * self.duration)
        dummy_wav = torch.zeros((1, n_samples), device=self.device)
        dummy_chr = self.chroma(dummy_wav)
        return dummy_chr.shape[1]

    @torch.no_grad()
    def _get_filtered_wav(self, wav):
        """Separate the stems with demucs and return a mono mix of the selected stems."""
        from demucs.apply import apply_model
        from demucs.audio import convert_audio
        with self.autocast:
            wav = convert_audio(wav, self.sample_rate, self.demucs.samplerate, self.demucs.audio_channels)
            stems = apply_model(self.demucs, wav, device=self.device)
            stems = stems[:, self.stem_idx]  # extract stem
            stems = stems.sum(1)  # merge extracted stems
            stems = stems.mean(1, keepdim=True)  # mono
            stems = convert_audio(stems, self.demucs.samplerate, self.sample_rate, 1)
            return stems

    @torch.no_grad()
    def _get_wav_embedding(self, wav):
        """Compute the chroma of the filtered wav, optionally matched to the training length."""
        # avoid 0-size tensors when we are working with null conds
        if wav.shape[-1] == 1:
            return self.chroma(wav)
        stems = self._get_filtered_wav(wav)
        chroma = self.chroma(stems)

        if self.match_len_on_eval:
            _, t, _ = chroma.shape
            if t > self.chroma_len:
                chroma = chroma[:, :self.chroma_len]
                logger.debug(f'chroma was truncated! ({t} -> {chroma.shape[1]})')
            elif t < self.chroma_len:
                # too short: tile the chroma along time, then cut to the training length
                n_repeat = int(math.ceil(self.chroma_len / t))
                chroma = chroma.repeat(1, n_repeat, 1)
                chroma = chroma[:, :self.chroma_len]
                logger.debug(f'chroma was repeated to match length! ({t} -> {chroma.shape[1]})')
        return chroma
class ChromaExtractor(nn.Module):
    """Extracts (and optionally quantizes) chroma features from a waveform.

    Args:
        sample_rate (int): Sample rate.
        n_chroma (int): Number of chroma to consider.
        radix2_exp (int): Radix2 exponent.
        nfft (tp.Optional[int], optional): Number of FFT.
        winlen (tp.Optional[int], optional): Window length.
        winhop (tp.Optional[int], optional): Window hop size.
        argmax (bool, optional): Whether to use argmax. Defaults to False.
        norm (float, optional): Norm for chroma normalization. Defaults to inf.
        device (tp.Union[torch.device, str], optional): Device to use. Defaults to cpu.
    """

    def __init__(self, sample_rate: int, n_chroma: int = 12, radix2_exp: int = 12,
                 nfft: tp.Optional[int] = None, winlen: tp.Optional[int] = None, winhop: tp.Optional[int] = None,
                 argmax: bool = False, norm: float = torch.inf, device: tp.Union[torch.device, str] = "cpu"):
        super().__init__()
        from librosa import filters
        self.device = device
        self.autocast = TorchAutocast(enabled=device != "cpu", device_type=self.device, dtype=torch.float32)
        # window length falls back to 2 ** radix2_exp, nfft to the window length
        # and the hop size to a quarter of the window length
        self.winlen = winlen or 2 ** radix2_exp
        self.nfft = nfft or self.winlen
        self.winhop = winhop or (self.winlen // 4)
        self.sr = sample_rate
        self.n_chroma = n_chroma
        self.norm = norm
        self.argmax = argmax
        self.window = torch.hann_window(self.winlen).to(device)
        chroma_filters = filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0, n_chroma=self.n_chroma)
        self.fbanks = torch.from_numpy(chroma_filters).to(device)
        spectrogram = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen,
                                                        hop_length=self.winhop, power=2, center=True,
                                                        pad=0, normalized=True)
        self.spec = spectrogram.to(device)

    def forward(self, wav):
        """Return the normalized chroma of `wav`, rearranged as "b t d"."""
        with self.autocast:
            n_samples = wav.shape[-1]
            # in case we are getting a wav that was dropped out (nullified)
            # make sure wav length is no less than nfft
            if n_samples < self.nfft:
                missing = self.nfft - n_samples
                left = missing // 2
                right = missing - left  # takes the extra sample when `missing` is odd
                wav = F.pad(wav, (left, right), 'constant', 0)
                assert wav.shape[-1] == self.nfft, f'expected len {self.nfft} but got {wav.shape[-1]}'
            spec = self.spec(wav).squeeze(1)
            raw_chroma = torch.einsum("cf,...ft->...ct", self.fbanks, spec)
            norm_chroma = F.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
            norm_chroma = rearrange(norm_chroma, "b d t -> b t d")
            if self.argmax:
                # one-hot encode the strongest chroma bin of each frame
                idx = norm_chroma.argmax(-1, keepdims=True)
                norm_chroma[:] = 0
                norm_chroma.scatter_(dim=-1, index=idx, value=1)
            return norm_chroma
def dropout_condition(sample: ConditioningAttributes, condition_type: str, condition: str):
    """Utility function for nullifying an attribute inside a ConditioningAttributes object.

    If the condition is of type "wav", then nullify it using "nullify_wav".
    If the condition is of type "text", set its value to None.
    Works in-place.

    Args:
        sample (ConditioningAttributes): Sample to modify.
        condition_type (str): Either "text" or "wav".
        condition (str): Name of the attribute to nullify.
    Returns:
        ConditioningAttributes: The same `sample` object, modified.
    Raises:
        ValueError: If the condition type is unknown or the attribute is absent from the sample.
    """
    if condition_type not in ["text", "wav"]:
        raise ValueError(
            "dropout_condition got an unexpected condition type!"
            f" expected 'wav' or 'text' but got '{condition_type}'"
        )

    if condition not in getattr(sample, condition_type):
        raise ValueError(
            "dropout_condition received an unexpected condition!"
            f" expected wav={sample.wav.keys()} and text={sample.text.keys()}"
            f" but got '{condition}' of type '{condition_type}'!"
        )

    if condition_type == "wav":
        # only the waveform is needed: lengths and paths are rebuilt by nullify_wav
        sample.wav[condition] = nullify_wav(sample.wav[condition].wav)
    else:
        sample.text[condition] = None

    return sample
class DropoutModule(nn.Module):
    """Parent class of the dropout modules; owns a seeded random generator in `self.rng`."""

    def __init__(self, seed: int = 1234):
        super().__init__()
        generator = torch.Generator()
        generator.manual_seed(seed)
        self.rng = generator
class AttributeDropout(DropoutModule):
    """Dropout applied with its own probability for each attribute.

    Contrary to ClassifierFreeGuidanceDropout, attributes are dropped independently from
    each other: "artist" can be dropped while "genre" remains, whereas with
    ClassifierFreeGuidanceDropout dropping "artist" implies dropping "genre" as well.

    Args:
        p (tp.Dict[str, tp.Dict[str, float]]): A dict mapping each condition type to a dict
            mapping attributes to their dropout probability. For example:
            ...
            "genre": 0.1,
            "artist": 0.5,
            "wav": 0.25,
            ...
        active_on_eval (bool, optional): Whether the dropout is active at eval. Default to False.
        seed (int, optional): Random seed.
    """

    def __init__(self, p: tp.Dict[str, tp.Dict[str, float]], active_on_eval: bool = False, seed: int = 1234):
        super().__init__(seed=seed)
        self.active_on_eval = active_on_eval
        # attributes without an explicit probability fall back to 0
        self.p = {
            condition_type: defaultdict(lambda: 0, probs)
            for condition_type, probs in p.items()
        }

    def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
        """
        Args:
            samples (tp.List[ConditioningAttributes]): List of conditions.
        Returns:
            tp.List[ConditioningAttributes]: List of conditions after certain attributes were set to None.
        """
        if not (self.training or self.active_on_eval):
            return samples

        samples = deepcopy(samples)  # never modify the caller's samples
        for condition_type, probabilities in self.p.items():  # condition types [text, wav]
            for condition, prob in probabilities.items():  # attributes of the type (e.g., [artist, genre])
                # one draw per attribute: the decision applies to every sample of the batch
                dropped = torch.rand(1, generator=self.rng).item() < prob
                if not dropped:
                    continue
                for sample in samples:
                    dropout_condition(sample, condition_type, condition)
        return samples

    def __repr__(self):
        return f"AttributeDropout({dict(self.p)})"
class ClassifierFreeGuidanceDropout(DropoutModule):
    """Classifier Free Guidance dropout: either all the attributes
    of the batch are dropped, or none of them is.

    Args:
        p (float): Probability to apply condition dropout during training.
        seed (int): Random seed.
    """

    def __init__(self, p: float, seed: int = 1234):
        super().__init__(seed=seed)
        self.p = p

    def forward(self, samples: tp.List[ConditioningAttributes]) -> tp.List[ConditioningAttributes]:
        """
        Args:
            samples (tp.List[ConditioningAttributes]): List of conditions.
        Returns:
            tp.List[ConditioningAttributes]: List of conditions after all attributes were set to None.
        """
        if not self.training:
            return samples

        # a single draw decides for the whole batch
        should_drop = torch.rand(1, generator=self.rng).item() < self.p
        if not should_drop:
            return samples

        # work on a copy and nullify every attribute of every sample
        nullified = deepcopy(samples)
        for condition_type in ["wav", "text"]:
            for sample in nullified:
                for condition in sample.attributes[condition_type]:
                    dropout_condition(sample, condition_type, condition)
        return nullified

    def __repr__(self):
        return f"ClassifierFreeGuidanceDropout(p={self.p})"
class ConditioningProvider(nn.Module):
    """Main class to provide conditions given all the supported conditioners.

    Args:
        conditioners (dict): Dictionary of conditioners.
        merge_text_conditions_p (float, optional): Probability to merge all text sources
            into a single text condition. Defaults to 0.
        drop_desc_p (float, optional): Probability to drop the original description
            when merging all text sources into a single text condition. Defaults to 0.
        device (tp.Union[torch.device, str], optional): Device for conditioners and output condition types.
    """

    def __init__(
        self,
        conditioners: tp.Dict[str, BaseConditioner],
        merge_text_conditions_p: float = 0,
        drop_desc_p: float = 0,
        device: tp.Union[torch.device, str] = "cpu",
    ):
        super().__init__()
        self.device = device
        self.merge_text_conditions_p = merge_text_conditions_p
        self.drop_desc_p = drop_desc_p
        self.conditioners = nn.ModuleDict(conditioners)

    @property
    def text_conditions(self):
        """Names of the conditioners that are TextConditioner instances."""
        return [k for k, v in self.conditioners.items() if isinstance(v, TextConditioner)]

    @property
    def wav_conditions(self):
        """Names of the conditioners that are WaveformConditioner instances."""
        return [k for k, v in self.conditioners.items() if isinstance(v, WaveformConditioner)]

    @property
    def has_wav_condition(self):
        """Whether at least one waveform conditioner is configured."""
        return len(self.wav_conditions) > 0

    def tokenize(self, inputs: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.Any]:
        """Match attributes/wavs with existing conditioners in self, and tokenize them accordingly.
        This should be called before starting any real GPU work to avoid synchronization points.
        This will return a dict matching conditioner names to their arbitrary tokenized representations.

        Args:
            inputs (list[ConditioningAttributes]): List of ConditioningAttributes objects containing
                text and wav conditions.
        """
        assert all(isinstance(x, ConditioningAttributes) for x in inputs), \
            "got unexpected types input for conditioner! should be tp.List[ConditioningAttributes]" \
            f" but types were {set([type(x) for x in inputs])}"

        output = {}
        text = self._collate_text(inputs)
        wavs = self._collate_wavs(inputs)

        assert set(text.keys() | wavs.keys()).issubset(set(self.conditioners.keys())), \
            f"got an unexpected attribute! Expected {self.conditioners.keys()}, got {text.keys(), wavs.keys()}"

        for attribute, batch in chain(text.items(), wavs.items()):
            output[attribute] = self.conditioners[attribute].tokenize(batch)
        return output

    def forward(self, tokenized: tp.Dict[str, tp.Any]) -> tp.Dict[str, ConditionType]:
        """Compute pairs of `(embedding, mask)` using the configured conditioners
        and the tokenized representations. The output is for example:

            {
                "genre": (torch.Tensor([B, 1, D_genre]), torch.Tensor([B, 1])),
                "description": (torch.Tensor([B, T_desc, D_desc]), torch.Tensor([B, T_desc])),
                ...
            }

        Args:
            tokenized (dict): Dict of tokenized representations as returned by `tokenize()`.
        """
        output = {}
        for attribute, inputs in tokenized.items():
            condition, mask = self.conditioners[attribute](inputs)
            output[attribute] = (condition, mask)
        return output

    def _collate_text(self, samples: tp.List[ConditioningAttributes]) -> tp.Dict[str, tp.List[tp.Optional[str]]]:
        """Given a list of ConditioningAttributes objects, compile a dictionary where the keys
        are the attributes and the values are the aggregated input per attribute.
        For example:
        Input:
        [
            ConditioningAttributes(text={"genre": "Rock", "description": "A rock song with a guitar solo"}, wav=...),
            ConditioningAttributes(text={"genre": "Hip-hop", "description": "A hip-hop verse"}, wav=...),
        ]
        Output:
        {
            "genre": ["Rock", "Hip-hop"],
            "description": ["A rock song with a guitar solo", "A hip-hop verse"]
        }

        When training with a non-zero `merge_text_conditions_p`, the 'description' of each
        sample may first be modified in-place (metadata merged in, description dropped).
        """
        batch_per_attribute: tp.Dict[str, tp.List[tp.Optional[str]]] = defaultdict(list)

        def _merge_conds(cond, merge_text_conditions_p=0, drop_desc_p=0):
            """Randomly merge the valid metadata of `cond` into its description (in-place)."""
            def is_valid(k, v):
                k_valid = k in ['key', 'bpm', 'genre', 'moods', 'instrument']
                v_valid = v is not None and isinstance(v, (int, float, str, list))
                return k_valid and v_valid

            def process_value(v):
                if isinstance(v, (int, float, str)):
                    return v
                if isinstance(v, list):
                    return ", ".join(v)
                # is_valid() filters the value types, so reaching this point is a programming error
                raise RuntimeError(f"unknown type for text value! ({type(v), v})")

            desc = cond.text['description']
            meta_data = ""
            if random.uniform(0, 1) < merge_text_conditions_p:
                meta_pairs = [f'{k}: {process_value(v)}' for k, v in cond.text.items() if is_valid(k, v)]
                random.shuffle(meta_pairs)
                meta_data = ". ".join(meta_pairs)
                # the original description may only be dropped when metadata is merged in
                desc = desc if not random.uniform(0, 1) < drop_desc_p else None

            if desc is None:
                desc = meta_data if len(meta_data) > 1 else None
            else:
                desc = desc.rstrip('.') + ". " + meta_data
            cond.text['description'] = desc.strip() if desc else None

        if self.training and self.merge_text_conditions_p:
            for sample in samples:
                _merge_conds(sample, self.merge_text_conditions_p, self.drop_desc_p)

        texts = [x.text for x in samples]
        for text in texts:
            for condition in self.text_conditions:
                batch_per_attribute[condition].append(text[condition])

        return batch_per_attribute

    def _collate_wavs(self, samples: tp.List[ConditioningAttributes]):
        """Generate a dict where the keys are attributes by which we fetch similar wavs,
        and the values are Tensors of wavs according to said attributes.

        *Note*: by the time the samples reach this function, each sample should have some waveform
        inside the "wav" attribute. It should be either:
        1. A real waveform
        2. A null waveform due to the sample having no similar waveforms (nullified by the dataset)
        3. A null waveform due to it being dropped in a dropout module (nullified by dropout)

        Args:
            samples (tp.List[ConditioningAttributes]): List of ConditioningAttributes samples.
        Returns:
            dict: A dictionary mapping an attribute name to wavs.
        """
        wavs = defaultdict(list)
        lens = defaultdict(list)
        paths = defaultdict(list)
        out = {}

        for sample in samples:
            for attribute in self.wav_conditions:
                wav, length, path = sample.wav[attribute]
                wavs[attribute].append(wav.flatten())
                lens[attribute].append(length)
                # NOTE(review): `path` is itself a list per sample, so this builds a list of lists;
                # confirm whether consumers expect a flat list before changing it.
                paths[attribute].append(path)

        # stack all wavs to a single tensor
        for attribute in self.wav_conditions:
            stacked_wav, _ = collate(wavs[attribute], dim=0)
            # lengths are gathered per attribute, like the wavs and the paths
            out[attribute] = WavCondition(stacked_wav.unsqueeze(1),
                                          torch.cat(lens[attribute]), paths[attribute])  # type: ignore

        return out
class ConditionFuser(StreamingModule):
    """Condition fuser handles the logic to combine the different conditions
    to the actual model input.

    Args:
        fuse2cond (tp.Dict[str, tp.List[str]]): A dictionary that says how to fuse
            each condition. For example:
            {
                "prepend": ["description"],
                "sum": ["genre", "bpm"],
                "cross": ["description"],
            }
        cross_attention_pos_emb (bool, optional): Use positional embeddings in cross attention.
        cross_attention_pos_emb_scale (float): Scale for positional embeddings in cross attention if used.
    """
    FUSING_METHODS = ["sum", "prepend", "cross", "input_interpolate"]

    def __init__(self, fuse2cond: tp.Dict[str, tp.List[str]], cross_attention_pos_emb: bool = False,
                 cross_attention_pos_emb_scale: float = 1.0):
        super().__init__()
        assert all(
            [k in self.FUSING_METHODS for k in fuse2cond.keys()]
        ), f"got invalid fuse method, allowed methods: {self.FUSING_METHODS}"
        self.cross_attention_pos_emb = cross_attention_pos_emb
        self.cross_attention_pos_emb_scale = cross_attention_pos_emb_scale
        self.fuse2cond: tp.Dict[str, tp.List[str]] = fuse2cond
        # reverse mapping: condition name -> fuse method
        self.cond2fuse: tp.Dict[str, str] = {}
        for fuse_method, conditions in fuse2cond.items():
            for condition in conditions:
                self.cond2fuse[condition] = fuse_method

    def forward(
        self,
        input: Tensor,
        conditions: tp.Dict[str, ConditionType]
    ) -> tp.Tuple[Tensor, tp.Optional[Tensor]]:
        """Fuse the conditions to the provided model input.

        Args:
            input (Tensor): Transformer input.
            conditions (tp.Dict[str, ConditionType]): Dict of conditions.
        Returns:
            tp.Tuple[Tensor, Tensor]: The first tensor is the transformer input
                after the conditions have been fused. The second output tensor is the tensor
                used for cross-attention or None if no cross attention inputs exist.
        """
        _, T, _ = input.shape

        # the presence of 'offsets' in the streaming state tells whether this is the first step
        if 'offsets' in self._streaming_state:
            first_step = False
            offsets = self._streaming_state['offsets']
        else:
            first_step = True
            offsets = torch.zeros(input.shape[0], dtype=torch.long, device=input.device)

        assert set(conditions.keys()).issubset(set(self.cond2fuse.keys())), \
            f"given conditions contain unknown attributes for fuser, " \
            f"expected {self.cond2fuse.keys()}, got {conditions.keys()}"
        cross_attention_output = None
        for cond_type, (cond, cond_mask) in conditions.items():
            op = self.cond2fuse[cond_type]
            if op == "sum":
                input += cond
            elif op == "input_interpolate":
                # resample the condition along time to the length of the input before summing
                cond = rearrange(cond, "b t d -> b d t")
                cond = F.interpolate(cond, size=input.shape[1])
                input += rearrange(cond, "b d t -> b t d")
            elif op == "prepend":
                # the condition is only prepended on the first step
                if first_step:
                    input = torch.cat([cond, input], dim=1)
            elif op == "cross":
                # all cross-attention conditions are concatenated along dim 1
                if cross_attention_output is not None:
                    cross_attention_output = torch.cat([cross_attention_output, cond], dim=1)
                else:
                    cross_attention_output = cond
            else:
                raise ValueError(f"unknown op ({op})")

        if self.cross_attention_pos_emb and cross_attention_output is not None:
            positions = torch.arange(
                cross_attention_output.shape[1],
                device=cross_attention_output.device
            ).view(1, -1, 1)
            pos_emb = create_sin_embedding(positions, cross_attention_output.shape[-1])
            cross_attention_output = cross_attention_output + self.cross_attention_pos_emb_scale * pos_emb

        if self._is_streaming:
            self._streaming_state['offsets'] = offsets + T

        return input, cross_attention_output

245
audiocraft/modules/conv.py Normal file
View File

@ -0,0 +1,245 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import typing as tp
import warnings
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm, weight_norm
# Closed set of normalization names accepted by the helpers and wrappers in this module.
CONV_NORMALIZATIONS = frozenset({
    'none',
    'weight_norm',
    'spectral_norm',
    'time_group_norm',
})
def apply_parametrization_norm(module: nn.Module, norm: str = 'none'):
    """Apply the weight reparametrization selected by `norm` to `module`.

    Args:
        module (nn.Module): Module to (optionally) reparametrize.
        norm (str): Normalization name, must be in `CONV_NORMALIZATIONS`.
    Returns:
        `weight_norm(module)` or `spectral_norm(module)` when `norm` names one of
        those, otherwise `module` itself.
    """
    assert norm in CONV_NORMALIZATIONS
    reparametrizations = {'weight_norm': weight_norm, 'spectral_norm': spectral_norm}
    wrap = reparametrizations.get(norm)
    # `norm` was validated above, so a missing entry simply means that this
    # choice does not need any reparametrization.
    return module if wrap is None else wrap(module)
def get_norm_module(module: nn.Module, causal: bool = False, norm: str = 'none', **norm_kwargs):
    """Build the normalization module to apply on the output of `module`.

    Args:
        module (nn.Module): Convolution whose output will be normalized.
        causal (bool): Whether the caller requires causal evaluation.
        norm (str): Normalization name, must be in `CONV_NORMALIZATIONS`.
        **norm_kwargs: Extra keyword arguments forwarded to `nn.GroupNorm`.
    Returns:
        nn.Module: A single-group `nn.GroupNorm` over `module.out_channels` for
        `'time_group_norm'`, and `nn.Identity()` for every other accepted name.
    Raises:
        ValueError: If `'time_group_norm'` is requested together with `causal=True`.
    """
    assert norm in CONV_NORMALIZATIONS
    if norm != 'time_group_norm':
        return nn.Identity()
    if causal:
        raise ValueError("GroupNorm doesn't support causal evaluation.")
    assert isinstance(module, nn.modules.conv._ConvNd)
    return nn.GroupNorm(1, module.out_channels, **norm_kwargs)
def get_extra_padding_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int,
                                 padding_total: int = 0) -> int:
    """Return how many samples must be appended to `x` so that the last window is full.

    See `pad_for_conv1d`.

    Args:
        x (torch.Tensor): Input whose last dimension is the one being convolved.
        kernel_size (int): Kernel size of the convolution.
        stride (int): Stride of the convolution.
        padding_total (int): Total amount of fixed padding already planned.
    Returns:
        int: Number of extra samples to add at the end.
    """
    n_samples = x.shape[-1]
    # Part of the kernel that is not already covered by the fixed padding.
    span = kernel_size - padding_total
    frame_count = math.ceil((n_samples - span) / stride + 1)
    return (frame_count - 1) * stride + span - n_samples
def pad_for_conv1d(x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0):
    """Right-pad `x` with zeros so that the last convolution window is full.

    The extra padding is added at the end. Without it, even with the fixed padding,
    some time steps may be dropped and an output of the same length could not be rebuilt.

    For instance, with total padding = 4, kernel size = 4, stride = 2:
        0 0 1 2 3 4 5 0 0   # (0s are padding)
        1   2   3           # (output frames of a convolution, last 0 is never used)
        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
            1 2 3 4         # once you removed padding, we are missing one time step !
    """
    missing = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    return F.pad(x, (0, missing))
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
    """Thin wrapper around `F.pad` that also supports reflect padding on short inputs.

    When `mode == 'reflect'` and the input is not longer than the largest requested
    padding, zeros are first appended on the right so that the reflection can be
    computed, and those extra samples are dropped again from the result.

    Args:
        x (torch.Tensor): Tensor padded along its last dimension.
        paddings (tuple of int): Non-negative `(left, right)` padding amounts.
        mode (str): Padding mode forwarded to `F.pad`.
        value (float): Fill value forwarded to `F.pad`.
    Returns:
        torch.Tensor: The padded tensor.
    """
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    if mode != 'reflect':
        return F.pad(x, paddings, mode, value)

    n_samples = x.shape[-1]
    widest = max(left, right)
    # Temporary right filler needed to make the input longer than the widest padding.
    filler = widest - n_samples + 1 if n_samples <= widest else 0
    if filler:
        x = F.pad(x, (0, filler))
    reflected = F.pad(x, paddings, mode, value)
    return reflected[..., :reflected.shape[-1] - filler]
def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Strip `(left, right)` samples from the last dimension of `x`. Only for 1d!

    A zero right padding is handled properly: the end index is computed explicitly
    instead of relying on a negative slice bound.

    Args:
        x (torch.Tensor): Tensor to trim along its last dimension.
        paddings (tuple of int): Non-negative `(left, right)` amounts to remove;
            their sum must not exceed the length of the last dimension.
    Returns:
        torch.Tensor: The trimmed tensor.
    """
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    n_samples = x.shape[-1]
    assert (left + right) <= n_samples
    return x[..., left: n_samples - right]
class NormConv1d(nn.Module):
    """`nn.Conv1d` bundled with the normalization selected by `norm`.

    Gives every normalization approach the same interface: positional and extra
    keyword arguments are forwarded to `nn.Conv1d`, `norm` picks both the weight
    reparametrization and the output normalization module, and `norm_kwargs` is
    forwarded to the latter.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        conv = nn.Conv1d(*args, **kwargs)
        self.conv = apply_parametrization_norm(conv, norm)
        self.norm = get_norm_module(self.conv, causal=causal, norm=norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        """Run the convolution, then the normalization."""
        return self.norm(self.conv(x))
class NormConv2d(nn.Module):
    """`nn.Conv2d` bundled with the normalization selected by `norm`.

    Gives every normalization approach the same interface: positional and extra
    keyword arguments are forwarded to `nn.Conv2d`, `norm` picks both the weight
    reparametrization and the output normalization module (always built as
    non-causal), and `norm_kwargs` is forwarded to the latter.
    """
    def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        conv = nn.Conv2d(*args, **kwargs)
        self.conv = apply_parametrization_norm(conv, norm)
        self.norm = get_norm_module(self.conv, False, norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        """Run the convolution, then the normalization."""
        return self.norm(self.conv(x))
class NormConvTranspose1d(nn.Module):
    """`nn.ConvTranspose1d` bundled with the normalization selected by `norm`.

    Gives every normalization approach the same interface: positional and extra
    keyword arguments are forwarded to `nn.ConvTranspose1d`, `norm` picks both the
    weight reparametrization and the output normalization module, and `norm_kwargs`
    is forwarded to the latter.
    """
    def __init__(self, *args, causal: bool = False, norm: str = 'none',
                 norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        convtr = nn.ConvTranspose1d(*args, **kwargs)
        self.convtr = apply_parametrization_norm(convtr, norm)
        self.norm = get_norm_module(self.convtr, causal=causal, norm=norm, **norm_kwargs)
        self.norm_type = norm

    def forward(self, x):
        """Run the transposed convolution, then the normalization."""
        return self.norm(self.convtr(x))
class NormConvTranspose2d(nn.Module):
    """Wrapper around ConvTranspose2d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.

    Args:
        *args: Positional arguments forwarded to `nn.ConvTranspose2d`.
        norm (str): Normalization name, must be in `CONV_NORMALIZATIONS`.
        norm_kwargs (dict): Keyword arguments forwarded to the normalization module.
        **kwargs: Keyword arguments forwarded to `nn.ConvTranspose2d`.
    """
    def __init__(self, *args, norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {}, **kwargs):
        super().__init__()
        self.convtr = apply_parametrization_norm(nn.ConvTranspose2d(*args, **kwargs), norm)
        self.norm = get_norm_module(self.convtr, causal=False, norm=norm, **norm_kwargs)
        # Exposed like on NormConv1d / NormConv2d / NormConvTranspose1d so that
        # all four wrappers offer the same attributes.
        self.norm_type = norm

    def forward(self, x):
        """Run the transposed convolution, then the normalization."""
        x = self.convtr(x)
        x = self.norm(x)
        return x
class StreamableConv1d(nn.Module):
    """1d convolution that takes care of its own (asymmetric or causal) padding
    and of its normalization.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int): Kernel size of the convolution.
        stride (int): Stride of the convolution.
        dilation (int): Dilation of the convolution.
        groups (int): Number of groups of the convolution.
        bias (bool): Whether the convolution has a bias.
        causal (bool): If True, all the fixed padding is placed on the left.
        norm (str): Normalization method.
        norm_kwargs (dict): Parameters forwarded to the normalization module.
        pad_mode (str): Padding mode given to `pad1d`.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, dilation: int = 1,
                 groups: int = 1, bias: bool = True, causal: bool = False,
                 norm: str = 'none', norm_kwargs: tp.Dict[str, tp.Any] = {},
                 pad_mode: str = 'reflect'):
        super().__init__()
        # Combining striding and dilation is unusual, let the user know.
        if stride > 1 and dilation > 1:
            warnings.warn('StreamableConv1d has been initialized with stride > 1 and dilation > 1'
                          f' (kernel_size={kernel_size} stride={stride}, dilation={dilation}).')
        self.conv = NormConv1d(in_channels, out_channels, kernel_size, stride,
                               dilation=dilation, groups=groups, bias=bias, causal=causal,
                               norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.pad_mode = pad_mode

    def forward(self, x):
        """Pad `x` as required by the convolution geometry, then convolve it."""
        _batch, _channels, _steps = x.shape  # also enforces a 3-dimensional input
        inner = self.conv.conv
        stride = inner.stride[0]
        # Effective kernel size once dilation is taken into account.
        effective_kernel = (inner.kernel_size[0] - 1) * inner.dilation[0] + 1
        padding_total = effective_kernel - stride
        extra_padding = get_extra_padding_for_conv1d(x, effective_kernel, stride, padding_total)
        if self.causal:
            # Causal: the whole fixed padding goes to the left.
            paddings = (padding_total, extra_padding)
        else:
            # Asymmetric split, needed for odd strides.
            padding_right = padding_total // 2
            paddings = (padding_total - padding_right, padding_right + extra_padding)
        return self.conv(pad1d(x, paddings, mode=self.pad_mode))
class StreamableConvTranspose1d(nn.Module):
    """1d transposed convolution that trims its own (asymmetric or causal) padding
    and takes care of its normalization.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int): Kernel size of the transposed convolution.
        stride (int): Stride of the transposed convolution.
        causal (bool): Whether trimming follows the causal scheme.
        norm (str): Normalization method.
        trim_right_ratio (float): Share of the fixed padding trimmed on the right in
            the causal case, in [0, 1]. Must be 1.0 for non causal convolutions.
        norm_kwargs (dict): Parameters forwarded to the normalization module.
    """
    def __init__(self, in_channels: int, out_channels: int,
                 kernel_size: int, stride: int = 1, causal: bool = False,
                 norm: str = 'none', trim_right_ratio: float = 1.,
                 norm_kwargs: tp.Dict[str, tp.Any] = {}):
        super().__init__()
        self.convtr = NormConvTranspose1d(in_channels, out_channels, kernel_size, stride,
                                          causal=causal, norm=norm, norm_kwargs=norm_kwargs)
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert self.causal or self.trim_right_ratio == 1., \
            "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert 0. <= self.trim_right_ratio <= 1.

    def forward(self, x):
        """Apply the transposed convolution and remove the fixed padding from its output."""
        inner = self.convtr.convtr
        padding_total = inner.kernel_size[0] - inner.stride[0]
        y = self.convtr(x)
        # Only the fixed padding is trimmed here. The extra padding from
        # `pad_for_conv1d` is removed at the very end, when the output is cut to
        # the right length: removing it here would require knowing the length
        # seen by the matching encoder layer.
        if self.causal:
            # Trim on the right according to the ratio; with a ratio of 1.0
            # everything is trimmed from the right.
            padding_right = math.ceil(padding_total * self.trim_right_ratio)
        else:
            # Asymmetric split, needed for odd strides.
            padding_right = padding_total // 2
        return unpad1d(y, (padding_total - padding_right, padding_right))

View File

@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from torch import nn
class StreamableLSTM(nn.Module):
    """LSTM that hides the hidden state handling and the data layout.

    The input is expected in convolutional layout: it is permuted with `(2, 0, 1)`
    before entering the LSTM and the result is permuted back with `(1, 2, 0)`.

    Args:
        dimension (int): Input and hidden size of the LSTM.
        num_layers (int): Number of LSTM layers.
        skip (bool): If True, add the LSTM input to its output.
    """
    def __init__(self, dimension: int, num_layers: int = 2, skip: bool = True):
        super().__init__()
        self.skip = skip
        self.lstm = nn.LSTM(dimension, dimension, num_layers)

    def forward(self, x):
        sequence = x.permute(2, 0, 1)
        output, _ = self.lstm(sequence)
        if self.skip:
            output = output + sequence
        return output.permute(1, 2, 0)

124
audiocraft/modules/rope.py Normal file
View File

@ -0,0 +1,124 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
from torch import nn
import torch
class XPos(nn.Module):
    """Length-extrapolatable positional embedding (xPos) from [Sun et al 2022](https://arxiv.org/abs/2212.10554v1).
    It provides the exponential decay that is applied to the RoPE rotation matrix.

    Args:
        dim (int): Embedding dimension, must be even.
        smoothing (float): Smoothing factor applied to the decay rates.
        base_scale (int): Base decay rate, given in terms of scaling time.
        device (torch.device or None): Device on which to initialize the module.
        dtype (torch.dtype): dtype to use to generate the embedding (float32 or float64).
    """
    def __init__(self, dim: int, smoothing: float = 0.4, base_scale: int = 512,
                 device=None, dtype: torch.dtype = torch.float32):
        super().__init__()
        assert dim % 2 == 0
        assert dtype in [torch.float64, torch.float32]
        self.dtype = dtype
        self.base_scale = base_scale

        half_dim = dim // 2
        channel_idx = torch.arange(half_dim, device=device, dtype=dtype)
        self.register_buffer("decay_rates", (channel_idx / half_dim + smoothing) / (1.0 + smoothing))
        # Lazily built cache of the decay for positions [0, N).
        self.decay: tp.Optional[torch.Tensor] = None

    def get_decay(self, start: int, end: int):
        """Return the complex decay for positions `[start, end)`.

        The decay is computed for all positions `[0, end)` and cached; the cache is
        rebuilt only when `end` goes beyond what it currently covers.
        """
        cached = self.decay
        if cached is None or end > cached.shape[0]:
            assert isinstance(self.decay_rates, torch.Tensor)  # Satisfy type checker.
            positions = torch.arange(end, device=self.decay_rates.device, dtype=self.dtype)
            exponent = (positions / self.base_scale).unsqueeze(-1)
            magnitude = self.decay_rates ** exponent
            # Complex number with the decay as magnitude and a zero angle.
            cached = torch.polar(magnitude, torch.zeros_like(magnitude))
            self.decay = cached
        return cached[start:end]  # [T, C/2]
class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864).

    Args:
        dim (int): Embedding dimension (twice the number of frequencies), must be even.
        max_period (float): Maximum period of the rotation frequencies.
        xpos (bool): Use xPos, applies an exponential decay to rotation matrix.
        scale (float): Scale of positional embedding, set to 0 to deactivate.
        device (torch.device or None): Device on which to initialize the module.
        dtype (torch.dtype): dtype to use to generate the embedding (float32 or float64).
    """
    def __init__(self, dim: int, max_period: float = 10000.0, xpos: bool = False,
                 scale: float = 1.0, device=None, dtype: torch.dtype = torch.float32):
        super().__init__()
        assert dim % 2 == 0
        self.scale = scale
        assert dtype in [torch.float64, torch.float32]
        self.dtype = dtype

        even_idx = torch.arange(0, dim, 2, device=device, dtype=dtype)[: (dim // 2)]
        self.register_buffer("frequencies", 1.0 / (max_period ** (even_idx / dim)))
        # Lazily built cache of the rotation for positions [0, N).
        self.rotation: tp.Optional[torch.Tensor] = None
        self.xpos = XPos(dim, device=device, dtype=dtype) if xpos else None

    def get_rotation(self, start: int, end: int):
        """Return the complex rotation for positions `[start, end)`.

        The rotation is computed for all positions `[0, end)` and cached; the cache
        is rebuilt only when `end` goes beyond what it currently covers.
        """
        cached = self.rotation
        if cached is None or end > cached.shape[0]:
            assert isinstance(self.frequencies, torch.Tensor)  # Satisfy type checker.
            positions = torch.arange(end, device=self.frequencies.device, dtype=self.dtype)
            angles = torch.outer(positions, self.frequencies)
            # Unit-magnitude complex numbers carrying the rotation angles.
            cached = torch.polar(torch.ones_like(angles), angles)
            self.rotation = cached
        return cached[start:end]

    def rotate(self, x: torch.Tensor, start: int = 0, invert_decay: bool = False):
        """Apply rope rotation to query or key tensor.

        Args:
            x (torch.Tensor): Tensor to rotate, time is read from dimension 1 and the
                last dimension is interpreted as interleaved (real, imag) pairs.
            start (int): Position of the first time step of `x`.
            invert_decay (bool): If True, use the inverse of the xPos decay.
        Returns:
            torch.Tensor: Rotated tensor, with the same shape and dtype as `x`.
        """
        end = start + x.shape[1]
        # Insert broadcast dimensions at positions 0 and 2: [T, C/2] -> [1, T, 1, C/2].
        rotation = self.get_rotation(start, end)[None, :, None]
        decay = self.xpos.get_decay(start, end)[None, :, None] if self.xpos else 1.0
        if invert_decay:
            decay = decay ** -1

        as_pairs = x.to(self.dtype).reshape(*x.shape[:-1], -1, 2)
        x_complex = torch.view_as_complex(as_pairs)
        # `scale` interpolates between no rotation (0) and the full rotation (1).
        scaled_rotation = (rotation * decay) * self.scale + (1.0 - self.scale)
        rotated = torch.view_as_real(x_complex * scaled_rotation).flatten(-2)
        return rotated.type_as(x)

    def rotate_qk(self, query: torch.Tensor, key: torch.Tensor, start: int = 0):
        """Apply rope rotation to both query and key tensors.

        Supports streaming mode, in which query and key are not expected to have the same shape.
        In streaming mode, key will be of length [P + C] with P the cached past timesteps, but
        query will be [C] (typically C == 1).

        Args:
            query (torch.Tensor): Query to rotate.
            key (torch.Tensor): Key to rotate.
            start (int): Start index of the sequence for time offset.
        Returns:
            tuple: The rotated query and the rotated key.
        """
        # The query covers only the last steps of the key, shift it accordingly.
        streaming_offset = key.shape[1] - query.shape[1]
        rotated_query = self.rotate(query, start + streaming_offset)
        rotated_key = self.rotate(key, start, invert_decay=True)
        return rotated_query, rotated_key

View File

@ -0,0 +1,258 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
import numpy as np
import torch.nn as nn
from .conv import StreamableConv1d, StreamableConvTranspose1d
from .lstm import StreamableLSTM
class SEANetResnetBlock(nn.Module):
    """Residual block from SEANet model.

    The residual branch is a stack of (activation, convolution) pairs, one per
    kernel size, whose output is added to the shortcut branch.

    Args:
        dim (int): Dimension of the input/output.
        kernel_sizes (list): List of kernel sizes for the convolutions.
        dilations (list): List of dilations for the convolutions, same length as `kernel_sizes`.
        activation (str): Name of the activation class, looked up on `torch.nn`.
        activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection.
    """
    def __init__(self, dim: int, kernel_sizes: tp.List[int] = [3, 1], dilations: tp.List[int] = [1, 1],
                 activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, causal: bool = False,
                 pad_mode: str = 'reflect', compress: int = 2, true_skip: bool = True):
        super().__init__()
        assert len(kernel_sizes) == len(dilations), 'Number of kernel sizes should match number of dilations'
        act = getattr(nn, activation)
        hidden = dim // compress
        last = len(kernel_sizes) - 1

        layers = []
        for idx, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
            # Only the outer ends of the branch work at full width, the inside is compressed.
            in_chs = hidden if idx else dim
            out_chs = dim if idx == last else hidden
            layers.append(act(**activation_params))
            layers.append(StreamableConv1d(in_chs, out_chs, kernel_size=kernel_size, dilation=dilation,
                                           norm=norm, norm_kwargs=norm_params,
                                           causal=causal, pad_mode=pad_mode))
        self.block = nn.Sequential(*layers)

        self.shortcut: nn.Module = (
            nn.Identity() if true_skip
            else StreamableConv1d(dim, dim, kernel_size=1, norm=norm, norm_kwargs=norm_params,
                                  causal=causal, pad_mode=pad_mode)
        )

    def forward(self, x):
        """Return the sum of the shortcut branch and the residual branch."""
        skipped = self.shortcut(x)
        return skipped + self.block(x)
class SEANetEncoder(nn.Module):
    """SEANet encoder.

    Layout: an initial convolution, then for each ratio a group of residual blocks
    followed by a strided (downsampling) convolution that doubles the width, then an
    optional LSTM, and a final convolution projecting to `dimension`.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): nb of residual layers.
        ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
            upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
            that must match the decoder order. We use the decoder order as some models may only employ the decoder.
        activation (str): Name of the activation class, looked up on `torch.nn`.
        activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the final convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers inserted before the final convolution, 0 to disable.
        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
            For the encoder, it corresponds to the N first blocks.
    """
    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
                 disable_norm_outer_blocks: int = 0):
        super().__init__()
        self.channels = channels
        self.dimension = dimension
        self.n_filters = n_filters
        # Kept in encoder (downsampling) order, i.e. reversed w.r.t. the argument.
        self.ratios = list(reversed(ratios))
        del ratios  # avoid using the decoder-ordered argument by mistake below
        self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)
        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
        self.disable_norm_outer_blocks = disable_norm_outer_blocks
        assert 0 <= self.disable_norm_outer_blocks <= self.n_blocks, \
            "Number of blocks for which to disable norm is invalid." \
            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."

        act = getattr(nn, activation)
        # Options shared by every convolution created directly by the encoder.
        conv_opts = dict(norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        width = n_filters  # current number of channels

        layers: tp.List[nn.Module] = [
            StreamableConv1d(channels, width, kernel_size,
                             norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
                             **conv_opts)
        ]
        for i, ratio in enumerate(self.ratios):
            # Block `i` of the loop is block number `i + 2` of the network.
            block_norm = 'none' if self.disable_norm_outer_blocks >= i + 2 else norm
            # Residual layers, with a dilation growing with the layer index.
            for j in range(n_residual_layers):
                layers.append(
                    SEANetResnetBlock(width, kernel_sizes=[residual_kernel_size, 1],
                                      dilations=[dilation_base ** j, 1],
                                      norm=block_norm, norm_params=norm_params,
                                      activation=activation, activation_params=activation_params,
                                      causal=causal, pad_mode=pad_mode, compress=compress, true_skip=true_skip))
            # Downsampling layer: strided convolution doubling the width.
            layers.append(act(**activation_params))
            layers.append(
                StreamableConv1d(width, width * 2,
                                 kernel_size=ratio * 2, stride=ratio,
                                 norm=block_norm, **conv_opts))
            width *= 2

        if lstm:
            layers.append(StreamableLSTM(width, num_layers=lstm))

        layers.append(act(**activation_params))
        layers.append(
            StreamableConv1d(width, dimension, last_kernel_size,
                             norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
                             **conv_opts))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run `x` through the whole encoder stack."""
        return self.model(x)
class SEANetDecoder(nn.Module):
    """SEANet decoder.

    Layout: an initial convolution, an optional LSTM, then for each ratio a transposed
    (upsampling) convolution that halves the width followed by a group of residual
    blocks, then a final convolution projecting to `channels` and an optional final
    activation.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): nb of residual layers.
        ratios (Sequence[int]): kernel size and stride ratios.
        activation (str): Name of the activation class, looked up on `torch.nn`.
        activation_params (dict): Parameters to provide to the activation function.
        final_activation (str): Final activation function after all convolutions.
        final_activation_params (dict): Parameters to provide to the final activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the final convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        lstm (int): Number of LSTM layers inserted after the initial convolution, 0 to disable.
        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
            For the decoder, it corresponds to the N last blocks.
        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
            If equal to 1.0, it means that all the trimming is done at the right.
    """
    def __init__(self, channels: int = 1, dimension: int = 128, n_filters: int = 32, n_residual_layers: int = 3,
                 ratios: tp.List[int] = [8, 5, 4, 2], activation: str = 'ELU', activation_params: dict = {'alpha': 1.0},
                 final_activation: tp.Optional[str] = None, final_activation_params: tp.Optional[dict] = None,
                 norm: str = 'none', norm_params: tp.Dict[str, tp.Any] = {}, kernel_size: int = 7,
                 last_kernel_size: int = 7, residual_kernel_size: int = 3, dilation_base: int = 2, causal: bool = False,
                 pad_mode: str = 'reflect', true_skip: bool = True, compress: int = 2, lstm: int = 0,
                 disable_norm_outer_blocks: int = 0, trim_right_ratio: float = 1.0):
        super().__init__()
        self.dimension = dimension
        self.channels = channels
        self.n_filters = n_filters
        self.ratios = ratios
        del ratios  # only `self.ratios` is used from here on
        self.n_residual_layers = n_residual_layers
        self.hop_length = np.prod(self.ratios)
        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
        self.disable_norm_outer_blocks = disable_norm_outer_blocks
        assert 0 <= self.disable_norm_outer_blocks <= self.n_blocks, \
            "Number of blocks for which to disable norm is invalid." \
            "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."

        act = getattr(nn, activation)
        # Options shared by the plain convolutions created directly by the decoder.
        conv_opts = dict(norm_kwargs=norm_params, causal=causal, pad_mode=pad_mode)
        # Start from the widest representation; the width is halved by each upsampling.
        width = int(2 ** len(self.ratios)) * n_filters

        layers: tp.List[nn.Module] = [
            StreamableConv1d(dimension, width, kernel_size,
                             norm='none' if self.disable_norm_outer_blocks == self.n_blocks else norm,
                             **conv_opts)
        ]
        if lstm:
            layers.append(StreamableLSTM(width, num_layers=lstm))

        for i, ratio in enumerate(self.ratios):
            # Norm is disabled counting from the end of the network for the decoder.
            block_norm = 'none' if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1) else norm
            narrower = width // 2
            # Upsampling layer: transposed convolution halving the width.
            layers.append(act(**activation_params))
            layers.append(
                StreamableConvTranspose1d(width, narrower,
                                          kernel_size=ratio * 2, stride=ratio,
                                          norm=block_norm, norm_kwargs=norm_params,
                                          causal=causal, trim_right_ratio=trim_right_ratio))
            # Residual layers, with a dilation growing with the layer index.
            for j in range(n_residual_layers):
                layers.append(
                    SEANetResnetBlock(narrower, kernel_sizes=[residual_kernel_size, 1],
                                      dilations=[dilation_base ** j, 1],
                                      activation=activation, activation_params=activation_params,
                                      norm=block_norm, norm_params=norm_params, causal=causal,
                                      pad_mode=pad_mode, compress=compress, true_skip=true_skip))
            width = narrower

        # Final projection back to the audio channels.
        layers.append(act(**activation_params))
        layers.append(
            StreamableConv1d(n_filters, channels, last_kernel_size,
                             norm='none' if self.disable_norm_outer_blocks >= 1 else norm,
                             **conv_opts))
        # Optional final activation (eg. tanh).
        if final_activation is not None:
            final_act = getattr(nn, final_activation)
            layers.append(final_act(**(final_activation_params or {})))
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Run `z` through the whole decoder stack."""
        return self.model(z)

View File

@ -0,0 +1,135 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Streaming module API that should be implemented by all Streaming components,
"""
from contextlib import contextmanager
import typing as tp
from torch import nn
import torch
# A streaming state maps a key to the tensor remembered under that key.
State = tp.Dict[str, torch.Tensor]


class StreamingModule(nn.Module):
    """Common API for streaming components.

    Each streaming component has a streaming state, which is just a dict[str, Tensor].
    By convention, the first dim of each tensor must be the batch size.
    Don't use dots in the key names, as this would clash with submodules
    (like in state_dict).

    If `self._is_streaming` is True, the component should use and remember
    the proper state inside `self._streaming_state`.

    To set a streaming component in streaming state, use

        with module.streaming():
            ...

    This will automatically reset the streaming state when exiting the context manager.
    This also automatically propagates to all streaming children module.

    Some module might also implement the `StreamingModule.flush` method, although
    this one is trickier, as all parents module must be StreamingModule and implement
    it as well for it to work properly. See `StreamingSequential` after.
    """
    def __init__(self) -> None:
        super().__init__()
        self._streaming_state: State = {}
        self._is_streaming = False

    def _apply_named_streaming(self, fn: tp.Any):
        """Call `fn(name, module)` on every `StreamingModule` yielded by `named_modules()`."""
        for name, module in self.named_modules():
            if not isinstance(module, StreamingModule):
                continue
            fn(name, module)

    def _set_streaming(self, streaming: bool):
        """Set the `_is_streaming` flag on this module and all its streaming sub-modules."""
        def _flag(name, module):
            module._is_streaming = streaming

        self._apply_named_streaming(_flag)

    @contextmanager
    def streaming(self):
        """Context manager to enter streaming mode. Reset streaming state on exit.
        """
        self._set_streaming(True)
        try:
            yield
        finally:
            # Runs even when the body raised, so no stale state is left behind.
            self._set_streaming(False)
            self.reset_streaming()

    def reset_streaming(self):
        """Reset the streaming state, including that of sub-modules.
        """
        def _clear(name: str, module: StreamingModule):
            module._streaming_state.clear()

        self._apply_named_streaming(_clear)

    def get_streaming_state(self) -> State:
        """Return the streaming state, including that of sub-modules.

        Keys of sub-modules are prefixed with the sub-module name followed by a dot.
        """
        collected: State = {}

        def _collect(name: str, module: StreamingModule):
            prefix = name + "." if name else name
            for key, value in module._streaming_state.items():
                collected[prefix + key] = value

        self._apply_named_streaming(_collect)
        return collected

    def set_streaming_state(self, state: State):
        """Set the streaming state, including that of sub-modules.

        Expects the key layout produced by `get_streaming_state`. Every entry must be
        consumed by some module, otherwise an `AssertionError` listing the leftover
        keys is raised.
        """
        remaining = dict(state)  # work on a copy, entries are removed once assigned

        def _assign(name: str, module: StreamingModule):
            prefix = name + "." if name else name
            module._streaming_state.clear()
            # complexity is not ideal here, but probably fine.
            for key in list(remaining):
                if not key.startswith(prefix):
                    continue
                local_key = key[len(prefix):]
                # A remaining dot means the entry belongs to a deeper sub-module.
                if '.' in local_key:
                    continue
                module._streaming_state[local_key] = remaining.pop(key)

        self._apply_named_streaming(_assign)
        assert len(remaining) == 0, list(remaining.keys())

    def flush(self, x: tp.Optional[torch.Tensor] = None):
        """Flush any remaining outputs that were waiting for completion.
        Typically, for convolutions, this will add the final padding
        and process the last buffer.

        This should take an optional argument `x`, which will be provided
        if a module before this one in the streaming pipeline has already
        spitted out a flushed out buffer.

        This default implementation returns None when `x` is None and
        `self(x)` otherwise.
        """
        return None if x is None else self(x)
class StreamingSequential(StreamingModule, nn.Sequential):
    """A streaming compatible alternative of `nn.Sequential`.
    """
    def flush(self, x: tp.Optional[torch.Tensor] = None):
        """Flush every layer in order, chaining the flushed output through the stack.

        Streaming layers are always flushed (possibly with `x=None`); regular layers
        are only applied once some earlier layer has produced an output.
        """
        for layer in self:
            if isinstance(layer, StreamingModule):
                x = layer.flush(x)
                continue
            if x is None:
                # Nothing flushed so far, a regular layer has nothing to process.
                continue
            x = layer(x)
        return x

View File

@ -0,0 +1,747 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Transformer model, with streaming support, xformer attention support
and easy causal attention with a potentially finite receptive field.
See `StreamingTransformer` for more information.
Unlike regular PyTorch Transformer, we make the hard choice that batches are first.
"""
import typing as tp
from einops import rearrange
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint as torch_checkpoint
from xformers import ops
from .rope import RotaryEmbedding
from .streaming import StreamingModule
# Attention backend currently selected for this module: 'torch' or 'xformers'.
_efficient_attention_backend: str = 'torch'


def set_efficient_attention_backend(backend: str = 'torch'):
    """Select the backend used for efficient attention.

    Args:
        backend (str): Either 'xformers' or 'torch'.
    Raises:
        AssertionError: If `backend` is not a supported value; the current
            selection is left untouched in that case.
    """
    # Using torch by default, it seems a bit faster on older P100 GPUs (~20% faster).
    global _efficient_attention_backend
    # Validate the requested value, not the one currently stored.
    assert backend in ['xformers', 'torch'], f"Unknown attention backend: {backend}"
    _efficient_attention_backend = backend
def _get_attention_time_dimension() -> int:
    """Return the index of the time dimension for the selected attention backend.

    This is 2 for the 'torch' backend and 1 otherwise.
    """
    return 2 if _efficient_attention_backend == 'torch' else 1
def _is_profiled() -> bool:
# Return true if we are currently running with a xformers profiler activated.
try:
from xformers.profiler import profiler
except ImportError:
return False
return profiler._Profiler._CURRENT_PROFILER is not None
def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
    """Create normalization module for transformer encoder layer.

    Args:
        norm_type (str): Normalization method, only 'layer_norm' is supported.
        dim (int): Dimension of the normalized layer.
        **kwargs (dict): Additional parameters for normalization layer.
    Returns:
        nn.Module: Normalization module.
    Raises:
        ValueError: If `norm_type` is not supported.
    """
    if norm_type != 'layer_norm':
        raise ValueError(f"Unknown norm type: {norm_type}")
    return nn.LayerNorm(dim, eps=1e-5, **kwargs)
def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000,
                         dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Create sinusoidal positional embedding, with shape `[B, T, C]`.

    Args:
        positions (torch.Tensor): LongTensor of positions.
        dim (int): Dimension of the embedding, must be even.
        max_period (float): Maximum period of the cosine/sine functions.
        dtype (torch.dtype or str): dtype to use to generate the embedding.
    Returns:
        torch.Tensor: Sinusoidal positional embedding, cosines first then sines
        along the last dimension.
    """
    # We aim for BTC format
    assert dim % 2 == 0
    half_dim = dim // 2
    positions = positions.to(dtype)
    adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
    max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype)  # avoid sync point
    # With `dim == 2` there is a single frequency and `half_dim - 1` is 0: clamp the
    # denominator so the exponent is 0 instead of 0 / 0 (which made the output NaN).
    # For `dim >= 4` the denominator is unchanged.
    denominator = max(half_dim - 1, 1)
    phase = positions / (max_period_tensor ** (adim / denominator))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head `n_rep` times along the heads axis.

    Equivalent to `torch.repeat_interleave` over the heads axis (from xlformers).
    The heads axis is axis 1 with the 'torch' backend and axis 2 otherwise.

    Args:
        x (torch.Tensor): 4-dimensional keys or values.
        n_rep (int): Number of repetitions; `x` is returned untouched when 1.
    Returns:
        torch.Tensor: Tensor with `n_rep` times more heads.
    """
    if n_rep == 1:
        return x
    heads_first = _efficient_attention_backend == 'torch'
    if heads_first:
        batch, kv_heads, steps, dim_per_head = x.shape
        expanded = x.unsqueeze(2).expand(batch, kv_heads, n_rep, steps, dim_per_head)
        return expanded.reshape(batch, kv_heads * n_rep, steps, dim_per_head)
    batch, steps, kv_heads, dim_per_head = x.shape
    expanded = x.unsqueeze(3).expand(batch, steps, kv_heads, n_rep, dim_per_head)
    return expanded.reshape(batch, steps, kv_heads * n_rep, dim_per_head)
class LayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).

    Multiplies the residual outputs by a learnt per-channel scale, initialized
    close to 0.

    Args:
        channels (int): Number of channels.
        init (float): Initial scale.
        channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`.
        device (torch.device or None): Device on which to initialize the module.
        dtype (torch.dtype or None): dtype to use to initialize the module.
    """
    def __init__(self, channels: int, init: float = 1e-4, channel_last: bool = True,
                 device=None, dtype=None):
        super().__init__()
        self.channel_last = channel_last
        initial_scale = torch.full(
            (channels,), init, requires_grad=True, device=device, dtype=dtype)
        self.scale = nn.Parameter(initial_scale)

    def forward(self, x: torch.Tensor):
        # With channels on the second to last axis, add a trailing axis so the
        # scale broadcasts over time.
        scale = self.scale if self.channel_last else self.scale[:, None]
        return scale * x
class StreamingMultiheadAttention(StreamingModule):
    """Similar to `nn.MultiheadAttention` but with support for streaming, causal evaluation.

    Args:
        embed_dim (int): Dimension to project to.
        num_heads (int): Number of heads.
        dropout (float): Dropout level.
        bias (bool): Use bias in projections.
        causal (bool): Causal mask applied automatically.
        past_context (int or None): Receptive field for the causal mask, infinite if None.
        custom (bool): Use custom MHA implementation, for testing / benchmarking.
        memory_efficient (bool): Use xformers based memory efficient attention.
        attention_as_float32 (bool): Perform the attention as float32
            (especially important with memory_efficient as autocast won't do this automatically).
        rope (`RotaryEmbedding` or None): Rope embedding to use.
        cross_attention: Should be true when used as a cross attention.
            All keys and values must be available at once, streaming is only for the queries.
            Cannot be used with `causal` or `rope` (as it wouldn't make sense to
            interpret the time steps in the keys relative to those in the queries).
        safe_streaming (bool): Bug fix, will go away with xformers update.
        qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product.
        kv_repeat (int): If > 1, will repeat keys and values multiple times (need to divide num_heads).
            This will lead to faster decoding time on A100 or other GPUs with tensorcore.
        device (torch.device or None): Device on which to initialize.
        dtype (torch.dtype or None): dtype to use.
    """
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True,
                 causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
                 memory_efficient: bool = False, attention_as_float32: bool = False,
                 rope: tp.Optional[RotaryEmbedding] = None, cross_attention: bool = False,
                 safe_streaming: bool = True, qk_layer_norm: bool = False, kv_repeat: int = 1,
                 device=None, dtype=None):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        if past_context is not None:
            assert causal

        self.embed_dim = embed_dim
        self.causal = causal
        self.past_context = past_context
        self.memory_efficient = memory_efficient
        self.attention_as_float32 = attention_as_float32
        self.rope = rope
        self.cross_attention = cross_attention
        self.safe_streaming = safe_streaming
        self.num_heads = num_heads
        self.dropout = dropout
        self.kv_repeat = kv_repeat
        if cross_attention:
            assert not causal, "Causal cannot work with cross attention."
            assert rope is None, "Rope cannot work with cross attention."

        if memory_efficient:
            _verify_xformers_memory_efficient_compat()

        self.custom = _is_custom(custom, memory_efficient)
        if self.custom:
            # Single packed projection: `embed_dim` outputs for the queries, followed
            # by the keys and values which may use fewer heads when `kv_repeat > 1`.
            out_dim = embed_dim
            assert num_heads % kv_repeat == 0
            assert not cross_attention or kv_repeat == 1
            num_kv = num_heads // kv_repeat
            kv_dim = (embed_dim // num_heads) * num_kv
            out_dim += 2 * kv_dim
            in_proj = nn.Linear(embed_dim, out_dim, bias=bias, **factory_kwargs)
            # We try to follow the default PyTorch MHA convention, to easily compare results.
            self.in_proj_weight = in_proj.weight
            self.in_proj_bias = in_proj.bias
            if bias:
                self.in_proj_bias.data.zero_()  # Following Pytorch convention
            self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
            if bias:
                self.out_proj.bias.data.zero_()
        else:
            assert not qk_layer_norm
            assert kv_repeat == 1
            self.mha = nn.MultiheadAttention(
                embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True,
                **factory_kwargs)
        self.qk_layer_norm = qk_layer_norm
        if qk_layer_norm:
            assert self.custom
            assert kv_repeat == 1
            ln_dim = embed_dim
            self.q_layer_norm = nn.LayerNorm(ln_dim)
            self.k_layer_norm = nn.LayerNorm(ln_dim)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        if not self.custom:
            # Support compat with regular MHA: keys stored without the `mha.` prefix
            # are moved under it before delegating to the default loader.
            keys = [n for n, _ in self.mha.named_parameters()]
            for key in keys:
                if prefix + key in state_dict:
                    state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype):
        """Return the causal attention bias for `current_steps` new query steps.

        Returns None or an xformers `LowerTriangularMask` with `memory_efficient`,
        otherwise a `[current_steps, past_steps + current_steps]` tensor holding
        0 for visible keys and -inf for masked ones.
        """
        # Return a causal mask, accounting for potentially stored past keys/values
        # We actually return a bias for the attention score, as this has the same
        # convention both in the builtin MHA in Pytorch, and Xformers functions.
        time_dim = _get_attention_time_dimension()
        if self.memory_efficient:
            from xformers.ops import LowerTriangularMask
            if current_steps == 1:
                # If we only have one step, then we do not need a mask.
                return None
            elif 'past_keys' in self._streaming_state:
                raise RuntimeError('Not supported at the moment')
            else:
                # Then we can safely use a lower triangular mask
                return LowerTriangularMask()
        if self._streaming_state:
            past_keys = self._streaming_state['past_keys']
            past_steps = past_keys.shape[time_dim]
        else:
            past_steps = 0

        queries_pos = torch.arange(
            past_steps, current_steps + past_steps, device=device).view(-1, 1)
        keys_pos = torch.arange(past_steps + current_steps, device=device).view(1, -1)
        delta = queries_pos - keys_pos
        valid = delta >= 0
        if self.past_context is not None:
            valid &= (delta <= self.past_context)
        return torch.where(
            valid,
            torch.zeros([], device=device, dtype=dtype),
            torch.full([], float('-inf'), device=device, dtype=dtype))

    def _complete_kv(self, k, v):
        """Prepend the cached keys/values to `k` and `v` and refresh the cache.

        With cross attention the inputs are returned untouched. When streaming,
        the cache keeps at most `past_context` steps along the time axis.
        """
        time_dim = _get_attention_time_dimension()
        if self.cross_attention:
            # With cross attention we assume all keys and values
            # are already available, and streaming is with respect
            # to the queries only.
            return k, v
        # Complete the key/value pair using the streaming state.
        if self._streaming_state:
            pk = self._streaming_state['past_keys']
            nk = torch.cat([pk, k], dim=time_dim)
            if v is k:
                nv = nk
            else:
                pv = self._streaming_state['past_values']
                nv = torch.cat([pv, v], dim=time_dim)
        else:
            nk = k
            nv = v

        assert nk.shape[time_dim] == nv.shape[time_dim]
        offset = 0
        if self.past_context is not None:
            offset = max(0, nk.shape[time_dim] - self.past_context)
        if self._is_streaming:
            # Trim along the time axis. The previous `nk[:, offset:]` always cut
            # axis 1, which is the heads axis when the time axis is 2.
            keep = (slice(None),) * time_dim + (slice(offset, None),)
            self._streaming_state['past_keys'] = nk[keep]
            if v is not k:
                self._streaming_state['past_values'] = nv[keep]
            if 'offset' in self._streaming_state:
                self._streaming_state['offset'] += offset
            else:
                # NOTE(review): the first call stores 0 even when `offset > 0`;
                # confirm this is intended for the rope position computation.
                self._streaming_state['offset'] = torch.tensor(0)
        return nk, nv

    def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
        """Rotate `query` and `key` with the rope embedding, offset by the streamed steps."""
        # TODO: fix and verify layout.
        assert _efficient_attention_backend == 'xformers', 'Rope not supported with torch attn.'
        # Apply rope embeddings to query and key tensors.
        assert self.rope is not None
        if 'past_keys' in self._streaming_state:
            past_keys_offset = self._streaming_state['past_keys'].shape[1]
        else:
            past_keys_offset = 0
        if 'offset' in self._streaming_state:
            past_context_offset = int(self._streaming_state['offset'].item())
        else:
            past_context_offset = 0
        streaming_offset = past_context_offset + past_keys_offset
        return self.rope.rotate_qk(query, key, start=streaming_offset)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                key_padding_mask=None, need_weights=False, attn_mask=None,
                average_attn_weights=True, is_causal=False):
        """Compute the attention output.

        `attn_mask` and `is_causal` must be left to their defaults, causality is
        configured in the constructor. Returns a tuple `(output, None)`.
        """
        assert attn_mask is None
        assert not is_causal, ("new param added in torch 2.0.1 not supported, "
                               "use the causal args in the constructor.")

        time_dim = _get_attention_time_dimension()
        if time_dim == 2:
            layout = "b h t d"
        else:
            layout = "b t h d"
        dtype = query.dtype
        if self._is_streaming:
            assert self.causal or self.cross_attention, \
                "Streaming only available for causal or cross attention"

        if self.causal:
            # At the moment we specialize only for the self-attention case.
            assert query.shape[1] == key.shape[1], "Causal only for same length query / key / value"
            assert value.shape[1] == key.shape[1], "Causal only for same length query / key / value"
            attn_mask = self._get_mask(query.shape[1], query.device, query.dtype)

        if self.custom:
            # custom implementation
            assert need_weights is False
            assert key_padding_mask is None
            if self.cross_attention:
                # Different queries, keys, values, we have to split manually the weights
                # before applying the linear.
                dim = self.in_proj_weight.shape[0] // 3
                if self.in_proj_bias is None:
                    bias_q, bias_k, bias_v = None, None, None
                else:
                    bias_q = self.in_proj_bias[:dim]
                    bias_k = self.in_proj_bias[dim: 2 * dim]
                    bias_v = self.in_proj_bias[2 * dim:]
                q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
                # todo: when streaming, we could actually save k, v and check the shape actually match.
                k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k)
                v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v)
                if self.qk_layer_norm is True:
                    q = self.q_layer_norm(q)
                    k = self.k_layer_norm(k)
                q, k, v = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k, v]]
            else:
                if not _is_profiled():
                    # profiling breaks that property somehow.
                    assert query is key, "specialized implementation"
                    assert value is key, "specialized implementation"
                projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
                if self.kv_repeat == 1:
                    if time_dim == 2:
                        bound_layout = "b h p t d"
                    else:
                        bound_layout = "b t p h d"
                    packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
                    q, k, v = ops.unbind(packed, dim=2)
                else:
                    # Keys and values use fewer heads than the queries, so the packed
                    # projection is sliced by hand.
                    embed_dim = self.embed_dim
                    per_head_dim = (embed_dim // self.num_heads)
                    kv_heads = self.num_heads // self.kv_repeat
                    q = projected[:, :, :embed_dim]
                    start = embed_dim
                    end = start + per_head_dim * kv_heads
                    k = projected[:, :, start: end]
                    v = projected[:, :, end:]
                    q = rearrange(q, f"b t (h d) -> {layout}", h=self.num_heads)
                    k = rearrange(k, f"b t (h d) -> {layout}", h=kv_heads)
                    v = rearrange(v, f"b t (h d) -> {layout}", h=kv_heads)

                if self.qk_layer_norm is True:
                    assert self.kv_repeat == 1
                    q, k = [rearrange(x, f"{layout} -> b t (h d)") for x in [q, k]]
                    q = self.q_layer_norm(q)
                    k = self.k_layer_norm(k)
                    q, k = [rearrange(x, f"b t (h d) -> {layout}", h=self.num_heads) for x in [q, k]]
                if self.rope:
                    q, k = self._apply_rope(q, k)
                k, v = self._complete_kv(k, v)
                if self.kv_repeat > 1:
                    k = expand_repeated_kv(k, self.kv_repeat)
                    v = expand_repeated_kv(v, self.kv_repeat)
            if self.attention_as_float32:
                q, k, v = [x.float() for x in [q, k, v]]
            if self.memory_efficient:
                p = self.dropout if self.training else 0
                if _efficient_attention_backend == 'torch':
                    x = torch.nn.functional.scaled_dot_product_attention(
                        q, k, v, is_causal=attn_mask is not None, dropout_p=p)
                else:
                    x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p)
            else:
                # We include the dot product as float32, for consistency
                # with the other implementations that include that step
                # as part of the attention. Note that when using `autocast`,
                # the einsums would be done as bfloat16, but the softmax
                # would be done as float32 (per autocast casting rules, TODO confirm),
                # so `attention_as_float32` will extend a bit the range of
                # operations done in float32, although this should make no difference.
                q = q / q.shape[-1] ** 0.5
                key_layout = layout.replace('t', 'k')
                query_layout = layout
                if self._is_streaming and self.safe_streaming and q.device.type == 'cuda':
                    with torch.autocast(device_type=q.device.type, dtype=torch.float32):
                        pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
                else:
                    pre_w = torch.einsum(f"{query_layout},{key_layout}-> b h t k", q, k)
                if attn_mask is not None:
                    pre_w = pre_w + attn_mask
                w = torch.softmax(pre_w, dim=-1)
                w = F.dropout(w, self.dropout, training=self.training).to(v)
                # Key and value have the same format.
                x = torch.einsum(f"b h t k, {key_layout} -> {layout}", w, v)
            x = x.to(dtype)
            x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
            x = self.out_proj(x)
        else:
            # NOTE(review): here key/value are the raw `[B, T, C]` inputs (time on
            # axis 1, see the shape asserts above), while `_complete_kv` uses the
            # backend time axis (2 with 'torch'). Confirm streaming with the
            # non custom implementation is only used with the 'xformers' backend.
            key, value = self._complete_kv(key, value)
            if self.attention_as_float32:
                query, key, value = [x.float() for x in [query, key, value]]
            x, _ = self.mha(
                query, key, value, key_padding_mask,
                need_weights, attn_mask, average_attn_weights)
            x = x.to(dtype)

        return x, None
class StreamingTransformerLayer(nn.TransformerEncoderLayer):
    """Transformer encoder layer with streaming / causal support.

    Cross attention is built into this layer when `cross_attention=True`,
    instead of living in a separate decoder layer class like in PyTorch.

    Args:
        d_model (int): Dimension of the data.
        num_heads (int): Number of heads.
        dim_feedforward (int): Intermediate dimension of FF module.
        dropout (float): Dropout both for MHA and FF.
        bias_ff (bool): Use bias for FF.
        bias_attn (bool): Use bias for MHA.
        causal (bool): Causal mask applied automatically.
        past_context (int or None): Receptive field for the causal mask, infinite if None.
        custom (bool): Use custom MHA implementation, for testing / benchmarking.
        memory_efficient (bool): Use xformers based memory efficient attention.
        attention_as_float32 (bool): Perform the attention as float32
            (especially important with memory_efficient as autocast won't do this automatically).
        qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product in attention.
        qk_layer_norm_cross (bool): Same for the cross attention.
        cross_attention (bool): If True, expect to get secondary input for cross-attention.
            The cross attention module receives the same `custom`, `memory_efficient`,
            `attention_as_float32`, dropout and bias settings as the self attention.
        layer_scale (float or None): If not None, LayerScale will be used with
            the given value as initial scale.
        rope (`RotaryEmbedding` or None): Rope embedding to use (self attention only).
        attention_dropout (float or None): If not None, separate the value of the dimension dropout
            in FFN and of the attention dropout.
        kv_repeat (int): If > 1, passed to the self attention as `kv_repeat` (need to divide num_heads).
        norm (str): Normalization method passed to `create_norm_fn`.
        device (torch.device or None): Device on which to initialize.
        dtype (torch.dtype or None): dtype to use.
        **kwargs: See `nn.TransformerEncoderLayer`.
    """
    def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 bias_ff: bool = True, bias_attn: bool = True, causal: bool = False,
                 past_context: tp.Optional[int] = None, custom: bool = False,
                 memory_efficient: bool = False, attention_as_float32: bool = False,
                 qk_layer_norm: bool = False, qk_layer_norm_cross: bool = False,
                 cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
                 rope: tp.Optional[RotaryEmbedding] = None, attention_dropout: tp.Optional[float] = None,
                 kv_repeat: int = 1, norm: str = 'layer_norm', device=None, dtype=None, **kwargs):
        super().__init__(d_model, num_heads, dim_feedforward, dropout,
                         device=device, dtype=dtype, batch_first=True, **kwargs)
        factory_kwargs = dict(device=device, dtype=dtype)
        attn_dropout = attention_dropout if attention_dropout is not None else dropout
        # Settings shared by the self attention and the optional cross attention.
        shared_attn_kwargs: tp.Dict[str, tp.Any] = dict(
            embed_dim=d_model,
            num_heads=num_heads,
            dropout=attn_dropout,
            bias=bias_attn,
            custom=custom,
            memory_efficient=memory_efficient,
            attention_as_float32=attention_as_float32,
        )

        def build_layer_scale() -> nn.Module:
            # Identity when layer scale is disabled, a fresh LayerScale otherwise.
            if layer_scale is None:
                return nn.Identity()
            return LayerScale(d_model, layer_scale, **factory_kwargs)

        # Replace the parent's self attention with the streaming one.
        self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention(
            causal=causal, past_context=past_context, rope=rope, qk_layer_norm=qk_layer_norm,
            kv_repeat=kv_repeat, **shared_attn_kwargs, **factory_kwargs)  # type: ignore
        # Replace the feedforward linears so that the bias can be turned off.
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)

        self.layer_scale_1: nn.Module = build_layer_scale()
        self.layer_scale_2: nn.Module = build_layer_scale()

        self.cross_attention: tp.Optional[nn.Module] = None
        if cross_attention:
            self.cross_attention = StreamingMultiheadAttention(
                cross_attention=True, qk_layer_norm=qk_layer_norm_cross,
                **shared_attn_kwargs, **factory_kwargs)
            # Dropout and norm dedicated to the cross attention branch, the eps
            # value matches the one used in the PyTorch reference implementation.
            self.dropout_cross = nn.Dropout(dropout)
            self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
            self.layer_scale_cross: nn.Module = build_layer_scale()
        self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs)  # type: ignore
        self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs)  # type: ignore

    def _cross_attention_block(self, src: torch.Tensor,
                               cross_attention_src: torch.Tensor) -> torch.Tensor:
        """Attend from `src` (queries) to `cross_attention_src` (keys and values), then apply dropout."""
        assert self.cross_attention is not None
        attended, _ = self.cross_attention(
            src, cross_attention_src, cross_attention_src, need_weights=False)
        return self.dropout_cross(attended)  # type: ignore

    def forward(self, src: torch.Tensor, src_mask: tp.Optional[torch.Tensor] = None,  # type: ignore
                src_key_padding_mask: tp.Optional[torch.Tensor] = None,
                cross_attention_src: tp.Optional[torch.Tensor] = None):
        """Apply self attention, optional cross attention and the feedforward block.

        `cross_attention_src` must be given if and only if the layer was built
        with `cross_attention=True`.
        """
        if self.cross_attention is not None:
            assert cross_attention_src is not None
        else:
            assert cross_attention_src is None
        x = src
        if self.norm_first:
            # Pre-norm: every sub-block sees a normalized input and is added to the residual.
            attended = self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
            x = x + self.layer_scale_1(attended)
            if cross_attention_src is not None:
                crossed = self._cross_attention_block(self.norm_cross(x), cross_attention_src)
                x = x + self.layer_scale_cross(crossed)
            x = x + self.layer_scale_2(self._ff_block(self.norm2(x)))
            return x
        # Post-norm: the normalization is applied after each residual sum.
        attended = self._sa_block(x, src_mask, src_key_padding_mask)
        x = self.norm1(x + self.layer_scale_1(attended))
        if cross_attention_src is not None:
            # The queries are taken from the layer input `src`, not from `x`.
            crossed = self._cross_attention_block(src, cross_attention_src)
            x = self.norm_cross(x + self.layer_scale_cross(crossed))
        x = self.norm2(x + self.layer_scale_2(self._ff_block(x)))
        return x
class StreamingTransformer(StreamingModule):
    """Transformer with Streaming / Causal support.

    Args:
        d_model (int): Dimension of the data.
        num_heads (int): Number of heads.
        num_layers (int): Number of layers.
        dim_feedforward (int): Intermediate dimension of FF module.
        dropout (float): Dropout both for MHA and FF.
        bias_ff (bool): Use bias for FF.
        bias_attn (bool): Use bias for MHA.
        causal (bool): Causal mask applied automatically.
        past_context (int or None): Receptive field for the causal mask, infinite if None.
        custom (bool): Use custom MHA implementation, for testing / benchmarking.
        memory_efficient (bool): Use xformers based memory efficient attention.
        attention_as_float32 (bool): Perform the attention as float32
            (especially important with memory_efficient as autocast won't do this automatically).
        cross_attention (bool): If True, expect to get secondary input for cross-attention.
        layer_scale (float or None): If not None, LayerScale will be used
            with the given value as initial scale.
        positional_embedding (str): Positional embedding strategy (sin, rope, or sin_rope).
        max_period (float): Maximum period of the time embedding.
        positional_scale (float): Scale of positional embedding, set to 0 to deactivate.
        xpos (bool): Apply xpos exponential decay to positional embedding (rope only).
        lr (float or None): learning rate override through the `make_optim_group` API.
        weight_decay (float or None): Weight_decay override through the `make_optim_group` API.
        layer_class: (subclass of `StreamingTransformerLayer`): class to use
            to initialize the layers, allowing further customization outside of Audiocraft.
        checkpointing (str): Checkpointing strategy to reduce memory usage.
            No checkpointing if set to 'none'. Per layer checkpointing using PyTorch
            if set to 'torch' (entire layer checkpointed, i.e. linears are evaluated twice,
            minimal memory usage, but maximal runtime). Finally, `xformers_default` provide
            a policy for opting-out some operations of the checkpointing like
            linear layers and attention, providing a middle ground between speed and memory.
            `xformers_mm` is also accepted and only opts out the matrix multiplications.
        device (torch.device or None): Device on which to initialize.
        dtype (torch.dtype or None): dtype to use.
        **kwargs: See `nn.TransformerEncoderLayer`.
    """
    def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048,
                 dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True,
                 causal: bool = False, past_context: tp.Optional[int] = None,
                 custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False,
                 cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
                 positional_embedding: str = 'sin', max_period: float = 10_000, positional_scale: float = 1.,
                 xpos: bool = False, lr: tp.Optional[float] = None, weight_decay: tp.Optional[float] = None,
                 layer_class: tp.Type[StreamingTransformerLayer] = StreamingTransformerLayer,
                 checkpointing: str = 'none', device=None, dtype=None, **kwargs):
        super().__init__()
        assert d_model % num_heads == 0

        self.positional_embedding = positional_embedding
        self.max_period = max_period
        self.positional_scale = positional_scale
        self.weight_decay = weight_decay
        self.lr = lr

        assert positional_embedding in ['sin', 'rope', 'sin_rope']
        self.rope: tp.Optional[RotaryEmbedding] = None
        if self.positional_embedding in ['rope', 'sin_rope']:
            # Rope is only implemented by the custom attention.
            assert _is_custom(custom, memory_efficient)
            self.rope = RotaryEmbedding(d_model // num_heads, max_period=max_period,
                                        xpos=xpos, scale=positional_scale, device=device)

        self.checkpointing = checkpointing

        assert checkpointing in ['none', 'torch', 'xformers_default', 'xformers_mm']
        if self.checkpointing.startswith('xformers'):
            _verify_xformers_internal_compat()

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                layer_class(
                    d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward,
                    dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn,
                    causal=causal, past_context=past_context, custom=custom,
                    memory_efficient=memory_efficient, attention_as_float32=attention_as_float32,
                    cross_attention=cross_attention, layer_scale=layer_scale, rope=self.rope,
                    device=device, dtype=dtype, **kwargs))

        if self.checkpointing != 'none':
            for layer in self.layers:
                # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the
                # backward hook inside of FSDP...
                layer._magma_checkpointed = True  # type: ignore
                # `StreamingTransformerLayer` defines no `layer_drop` attribute, so a
                # plain attribute access raised AttributeError as soon as checkpointing
                # was enabled. Layers without the attribute are treated as not dropping.
                assert getattr(layer, 'layer_drop', 0.) == 0., "Need further checking"  # type: ignore

    def _apply_layer(self, layer, *args, **kwargs):
        """Run `layer`, wrapped in the configured checkpointing strategy."""
        method = self.checkpointing
        if method == 'none':
            return layer(*args, **kwargs)
        elif method == 'torch':
            return torch_checkpoint(layer, *args, use_reentrant=False, **kwargs)
        elif method.startswith('xformers'):
            from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy
            if method == 'xformers_default':
                # those operations will be saved, and not recomputed.
                # According to Francisco we can get smarter policies but this is a good start.
                allow_list = [
                    "xformers.efficient_attention_forward_cutlass.default",
                    "xformers_flash.flash_fwd.default",
                    "aten.addmm.default",
                    "aten.mm.default",
                ]
            elif method == 'xformers_mm':
                # those operations will be saved, and not recomputed.
                # According to Francisco we can get smarter policies but this is a good start.
                allow_list = [
                    "aten.addmm.default",
                    "aten.mm.default",
                ]
            else:
                raise ValueError(f"xformers checkpointing xformers policy {method} is not known.")
            policy_fn = _get_default_policy(allow_list)
            return checkpoint(layer, *args, policy_fn=policy_fn, **kwargs)
        else:
            raise ValueError(f"Checkpointing method {method} is unknown.")

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """Run all layers over `x` of shape `[B, T, C]`; extra arguments are passed to every layer."""
        B, T, C = x.shape

        # Per batch item number of steps already processed when streaming.
        if 'offsets' in self._streaming_state:
            offsets = self._streaming_state['offsets']
        else:
            offsets = torch.zeros(B, dtype=torch.long, device=x.device)

        if self.positional_embedding in ['sin', 'sin_rope']:
            positions = torch.arange(T, device=x.device).view(1, -1, 1)
            positions = positions + offsets.view(-1, 1, 1)
            pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
            x = x + self.positional_scale * pos_emb

        for layer in self.layers:
            x = self._apply_layer(layer, x, *args, **kwargs)

        if self._is_streaming:
            self._streaming_state['offsets'] = offsets + T

        return x

    def make_optim_group(self):
        """Return an optimizer param group with the `lr` / `weight_decay` overrides, when set."""
        group = {"params": list(self.parameters())}
        if self.lr is not None:
            group["lr"] = self.lr
        if self.weight_decay is not None:
            group["weight_decay"] = self.weight_decay
        return group
# special attention related functions
def _verify_xformers_memory_efficient_compat():
try:
from xformers.ops import memory_efficient_attention, LowerTriangularMask # noqa
except ImportError:
raise ImportError(
"xformers is not installed. Please install it and try again.\n"
"To install on AWS and Azure, run \n"
"FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n"
"pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n"
"To install on FAIR Cluster, run \n"
"FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n"
"pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n")
def _verify_xformers_internal_compat():
try:
from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy # noqa
except ImportError:
raise ImportError(
"Francisco's fairinternal xformers is not installed. Please install it and try again.\n"
"To install on AWS and Azure, run \n"
"FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n"
"pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n"
"To install on FAIR Cluster, run \n"
"FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n"
"pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n")
def _is_custom(custom: bool, memory_efficient: bool):
return custom or memory_efficient

0
audiocraft/py.typed Normal file
View File

View File

@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# flake8: noqa
from .vq import ResidualVectorQuantizer
from .base import BaseQuantizer, DummyQuantizer, QuantizedResult

View File

@ -0,0 +1,107 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Base class for all quantizers.
"""
from dataclasses import dataclass, field
import typing as tp
import torch
from torch import nn
@dataclass
class QuantizedResult:
    """Values returned by the forward pass of a quantizer."""
    # Quantized (or approximately quantized) representation.
    x: torch.Tensor
    # Quantized codes.
    codes: torch.Tensor
    bandwidth: torch.Tensor  # bandwidth in kb/s used, per batch item.
    # Optional penalty term for the loss.
    penalty: tp.Optional[torch.Tensor] = None
    # Metrics used to update logging etc.
    metrics: dict = field(default_factory=dict)
class BaseQuantizer(nn.Module):
    """Interface shared by all quantizers.

    Every method raises `NotImplementedError` and must be provided by subclasses.
    """

    def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
        """Quantize `x`.

        Returns the quantized (or approximately quantized) representation along
        with the quantized codes, the bandwidth, any penalty term for the loss
        and a dict of metrics to update logging etc.
        The frame rate is required so that the bandwidth is properly computed.
        """
        raise NotImplementedError

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode the given input tensor into codes."""
        raise NotImplementedError

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation."""
        raise NotImplementedError

    @property
    def total_codebooks(self):
        """Total number of codebooks."""
        raise NotImplementedError

    @property
    def num_codebooks(self):
        """Number of active codebooks."""
        raise NotImplementedError

    def set_num_codebooks(self, n: int):
        """Set the number of active codebooks."""
        raise NotImplementedError
class DummyQuantizer(BaseQuantizer):
    """Pass-through quantizer: no quantization is performed at all."""

    def __init__(self):
        super().__init__()

    def forward(self, x: torch.Tensor, frame_rate: int):
        """Return `x` unchanged, with codes equal to `x` plus an extra axis at position 1."""
        codes = x.unsqueeze(1)
        # 32 is presumably the number of bits per (float32) value; the result
        # is divided by 1000 and by the batch size.
        kbps = codes.numel() * 32 * frame_rate / 1000 / len(x)
        bandwidth = torch.tensor(kbps).to(x)
        return QuantizedResult(x, codes, bandwidth)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode the given input tensor into codes.

        As no quantization is done, the codes are the input itself with an
        extra axis inserted at position 1.
        """
        return x.unsqueeze(1)

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation.

        As no quantization is done, this only removes the axis at position 1
        added by `encode`.
        """
        return codes.squeeze(1)

    @property
    def total_codebooks(self):
        """Total number of codebooks, always 1."""
        return 1

    @property
    def num_codebooks(self):
        """Number of active codebooks, always equal to `total_codebooks`."""
        return self.total_codebooks

    def set_num_codebooks(self, n: int):
        """Not supported: the dummy quantizer has a fixed number of codebooks."""
        raise AttributeError("Cannot override the number of codebooks for the dummy quantizer")

View File

@ -0,0 +1,400 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import typing as tp
from einops import rearrange, repeat
import flashy
import torch
from torch import nn, einsum
import torch.nn.functional as F
def exists(val: tp.Optional[tp.Any]) -> bool:
    """Return True when `val` is not None."""
    if val is None:
        return False
    return True
def default(val: tp.Any, d: tp.Any) -> tp.Any:
    """Return `val`, or the fallback `d` when `val` is None."""
    if exists(val):
        return val
    return d
def l2norm(t):
    """Normalize `t` to unit L2 norm along its last dimension."""
    unit = F.normalize(t, p=2, dim=-1)
    return unit
def ema_inplace(moving_avg, new, decay: float):
    """Update `moving_avg` in place to `decay * moving_avg + (1 - decay) * new`."""
    average = moving_avg.data
    average.mul_(decay)
    average.add_(new, alpha=(1 - decay))
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
    """Apply additive smoothing to `x`, adding `epsilon` to each of the `n_categories` entries."""
    smoothed = x + epsilon
    total = x.sum() + n_categories * epsilon
    return smoothed / total
def uniform_init(*shape: int):
    """Return a new tensor of the given shape filled with `nn.init.kaiming_uniform_`."""
    return nn.init.kaiming_uniform_(torch.empty(shape))
def sample_vectors(samples, num: int):
    """Randomly pick `num` rows of `samples`.

    Rows are drawn without replacement when there are at least `num` of them,
    with replacement otherwise.
    """
    total = samples.shape[0]
    device = samples.device
    if total < num:
        picked = torch.randint(0, total, (num,), device=device)
    else:
        picked = torch.randperm(total, device=device)[:num]
    return samples[picked]
def kmeans(samples, num_clusters: int, num_iters: int = 10):
    """Run k-means on `samples` of shape `[n, d]`.

    The centroids are initialized from randomly sampled vectors. A cluster
    that receives no sample during an iteration keeps its previous centroid.

    Returns:
        tuple: the `[num_clusters, d]` centroids and the number of samples
            assigned to each cluster during the last iteration.
    """
    dim, dtype = samples.shape[-1], samples.dtype
    centroids = sample_vectors(samples, num_clusters)

    for _ in range(num_iters):
        # Negated squared distances, so that the closest centroid is the max.
        offsets = rearrange(samples, "n d -> n () d") - rearrange(
            centroids, "c d -> () c d"
        )
        neg_sq_dists = -(offsets ** 2).sum(dim=-1)
        assignment = neg_sq_dists.max(dim=-1).indices

        counts = torch.bincount(assignment, minlength=num_clusters)
        empty = counts == 0
        # Avoid dividing by zero for clusters without any sample.
        safe_counts = counts.masked_fill(empty, 1)

        sums = assignment.new_zeros(num_clusters, dim, dtype=dtype)
        sums.scatter_add_(0, repeat(assignment, "n -> n d", d=dim), samples)
        updated = sums / safe_counts[..., None]

        centroids = torch.where(empty[..., None], centroids, updated)

    return centroids, counts
def orthgonal_loss_fn(t):
    """Mean squared difference between the cosine similarity matrix of the rows of `t` and the identity.

    See eq (2) from https://arxiv.org/abs/2112.00384
    """
    num_codes = t.shape[0]
    unit_codes = l2norm(t)
    eye = torch.eye(num_codes, device=t.device)
    similarity = einsum("i d, j d -> i j", unit_codes, unit_codes)
    deviation = similarity - eye
    return (deviation ** 2).sum() / (num_codes ** 2)
class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.

    The codebook is stored in buffers (not parameters) and is updated with an
    exponential moving average of the inputs assigned to each code while the
    module is in training mode.

    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        kmeans_init (bool): Whether to use k-means to initialize the codebooks.
            If set to true, run the k-means algorithm on the first training batch and use
            the learned centroids as initialization.
        kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        decay: float = 0.8,
        epsilon: float = 1e-5,
        threshold_ema_dead_code: int = 2,
    ):
        super().__init__()
        self.decay = decay
        # With k-means init the codebook starts as zeros and is filled in later by
        # `init_embed_`; otherwise it is initialized right away with `uniform_init`.
        init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
        embed = init_fn(codebook_size, dim)

        self.codebook_size = codebook_size

        self.kmeans_iters = kmeans_iters
        self.epsilon = epsilon
        self.threshold_ema_dead_code = threshold_ema_dead_code

        # `inited` holds 1.0 once the codebook is usable (immediately when k-means init is off).
        self.register_buffer("inited", torch.Tensor([not kmeans_init]))
        # EMA of the number of inputs assigned to each code.
        self.register_buffer("cluster_size", torch.zeros(codebook_size))
        # Current codebook, and the EMA of the summed inputs assigned to each code.
        self.register_buffer("embed", embed)
        self.register_buffer("embed_avg", embed.clone())

    @torch.jit.ignore
    def init_embed_(self, data):
        """Initialize the codebook with k-means on ``data``; no-op once ``inited`` is set."""
        if self.inited:
            return

        embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
        self.embed.data.copy_(embed)
        self.embed_avg.data.copy_(embed.clone())
        self.cluster_size.data.copy_(cluster_size)
        self.inited.data.copy_(torch.Tensor([True]))
        # Make sure all buffers across workers are in sync after initialization
        flashy.distrib.broadcast_tensors(self.buffers())

    def replace_(self, samples, mask):
        """Overwrite the codes selected by ``mask`` with vectors sampled from ``samples``."""
        # One candidate vector is drawn per code; only those under the mask are actually used.
        modified_codebook = torch.where(
            mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
        )
        self.embed.data.copy_(modified_codebook)

    def expire_codes_(self, batch_samples):
        """Replace codes whose EMA cluster size fell below ``threshold_ema_dead_code``.

        A threshold of 0 disables expiration entirely.
        """
        if self.threshold_ema_dead_code == 0:
            return

        expired_codes = self.cluster_size < self.threshold_ema_dead_code
        if not torch.any(expired_codes):
            return

        batch_samples = rearrange(batch_samples, "... d -> (...) d")
        self.replace_(batch_samples, mask=expired_codes)
        flashy.distrib.broadcast_tensors(self.buffers())

    def preprocess(self, x):
        """Flatten all leading dimensions of ``x`` so that it becomes [N, d]."""
        x = rearrange(x, "... d -> (...) d")
        return x

    def quantize(self, x):
        """Return, for each row of ``x``, the index of the closest code."""
        embed = self.embed.t()
        # Negated squared Euclidean distance, expanded as -(|x|^2 - 2 x.e + |e|^2),
        # so the nearest code is the arg-max.
        dist = -(
            x.pow(2).sum(1, keepdim=True)
            - 2 * x @ embed
            + embed.pow(2).sum(0, keepdim=True)
        )
        embed_ind = dist.max(dim=-1).indices
        return embed_ind

    def postprocess_emb(self, embed_ind, shape):
        """Reshape flat indices back to ``shape`` without its last dimension."""
        return embed_ind.view(*shape[:-1])

    def dequantize(self, embed_ind):
        """Look up the code vectors for the given indices."""
        quantize = F.embedding(embed_ind, self.embed)
        return quantize

    def encode(self, x):
        """Map ``x`` ([..., d]) to code indices of shape ``x.shape[:-1]``; does not update the codebook."""
        shape = x.shape
        # pre-process
        x = self.preprocess(x)
        # quantize
        embed_ind = self.quantize(x)
        # post-process
        embed_ind = self.postprocess_emb(embed_ind, shape)
        return embed_ind

    def decode(self, embed_ind):
        """Map code indices back to their code vectors."""
        quantize = self.dequantize(embed_ind)
        return quantize

    def forward(self, x):
        """Quantize ``x`` and, in training mode, update the codebook statistics.

        Args:
            x (torch.Tensor): Input whose last dimension is the code dimension.
        Returns:
            tuple: ``(quantize, embed_ind)``: the selected code vectors and their
                indices (shape ``x.shape[:-1]``).
        """
        shape, dtype = x.shape, x.dtype
        x = self.preprocess(x)
        # Runs regardless of training mode; it is a no-op once `inited` is set.
        self.init_embed_(x)

        embed_ind = self.quantize(x)
        embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
        embed_ind = self.postprocess_emb(embed_ind, shape)
        # The returned vectors come from the codebook as it is *before* the updates below.
        quantize = self.dequantize(embed_ind)

        if self.training:
            # We do the expiry of code at that point as buffers are in sync
            # and all the workers will take the same decision.
            self.expire_codes_(x)
            # EMA of per-code assignment counts and of per-code summed inputs.
            ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
            embed_sum = x.t() @ embed_onehot
            ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
            # Smoothed counts, rescaled to keep the same total, avoid dividing by zero.
            cluster_size = (
                laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
                * self.cluster_size.sum()
            )
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
            self.embed.data.copy_(embed_normalized)

        return quantize, embed_ind
class VectorQuantization(nn.Module):
    """Single-codebook vector quantizer (Euclidean distance only).

    Wraps an ``EuclideanCodebook`` with optional linear projections in and out
    of the codebook space, a straight-through gradient estimator, and the
    commitment / orthogonal-regularization losses.

    Args:
        dim (int): Dimension of the inputs and outputs.
        codebook_size (int): Number of codes.
        codebook_dim (int, optional): Dimension of the codes. Defaults to ``dim``;
            when it differs, linear projections are inserted around the codebook.
        decay (float): Decay for the exponential moving average over the codebook.
        epsilon (float): Epsilon value for numerical stability.
        kmeans_init (bool): Whether to use k-means to initialize the codebook.
        kmeans_iters (int): Number of iterations used for k-means initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Codes whose
            exponential moving average cluster size falls below it are replaced with
            randomly selected vectors from the current batch.
        channels_last (bool): Whether channels are the last dimension of the inputs
            ([B, N, D]) rather than the middle one ([B, D, N]).
        commitment_weight (float): Weight for the commitment loss.
        orthogonal_reg_weight (float): Weight for the orthogonal regularization.
        orthogonal_reg_active_codes_only (bool): Apply the orthogonal regularization
            only to the codes used by the current batch.
        orthogonal_reg_max_codes (int, optional): Maximum number of codes considered
            for the orthogonal regularization.
    """
    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.8,
        epsilon: float = 1e-5,
        kmeans_init: bool = False,
        kmeans_iters: int = 10,
        threshold_ema_dead_code: int = 2,
        channels_last: bool = False,
        commitment_weight: float = 1.,
        orthogonal_reg_weight: float = 0.0,
        orthogonal_reg_active_codes_only: bool = False,
        orthogonal_reg_max_codes: tp.Optional[int] = None,
    ):
        super().__init__()
        inner_dim: int = default(codebook_dim, dim)

        # Submodules are created in the order project_in, project_out, codebook.
        if inner_dim != dim:
            self.project_in = nn.Linear(dim, inner_dim)
            self.project_out = nn.Linear(inner_dim, dim)
        else:
            self.project_in = nn.Identity()
            self.project_out = nn.Identity()

        self.epsilon = epsilon
        self.commitment_weight = commitment_weight
        self.orthogonal_reg_weight = orthogonal_reg_weight
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes

        self._codebook = EuclideanCodebook(
            dim=inner_dim,
            codebook_size=codebook_size,
            kmeans_init=kmeans_init,
            kmeans_iters=kmeans_iters,
            decay=decay,
            epsilon=epsilon,
            threshold_ema_dead_code=threshold_ema_dead_code,
        )
        self.codebook_size = codebook_size
        self.channels_last = channels_last

    @property
    def codebook(self):
        """Current code vectors of the underlying codebook."""
        return self._codebook.embed

    @property
    def inited(self):
        """Initialization flag of the underlying codebook."""
        return self._codebook.inited

    def _preprocess(self, x):
        """Bring the input to channels-last layout ([B, N, D])."""
        if self.channels_last:
            return x
        return rearrange(x, "b d n -> b n d")

    def _postprocess(self, quantize):
        """Bring the quantized output back to the caller's layout."""
        if self.channels_last:
            return quantize
        return rearrange(quantize, "b n d -> b d n")

    def encode(self, x):
        """Return the code indices selected for ``x``."""
        projected = self.project_in(self._preprocess(x))
        return self._codebook.encode(projected)

    def decode(self, embed_ind):
        """Reconstruct the quantized representation from code indices."""
        codes = self._codebook.decode(embed_ind)
        return self._postprocess(self.project_out(codes))

    def forward(self, x):
        """Quantize ``x``.

        Args:
            x (torch.Tensor): Input tensor, laid out according to ``channels_last``.
        Returns:
            tuple: ``(quantize, embed_ind, loss)``: the quantized tensor in the input
                layout, the selected code indices, and a one-element loss tensor
                (zero outside of training mode).
        """
        device = x.device
        x = self.project_in(self._preprocess(x))
        quantize, embed_ind = self._codebook(x)

        training = self.training
        if training:
            # Straight-through estimator: forward value is the code, gradient flows to x.
            quantize = x + (quantize - x).detach()

        loss = torch.tensor([0.0], device=device, requires_grad=training)

        if training:
            if self.commitment_weight > 0:
                commit_loss = F.mse_loss(quantize.detach(), x)
                loss = loss + commit_loss * self.commitment_weight

            if self.orthogonal_reg_weight > 0:
                codes = self.codebook
                if self.orthogonal_reg_active_codes_only:
                    # only calculate orthogonal loss for the activated codes for this batch
                    codes = codes[torch.unique(embed_ind)]

                num_codes = codes.shape[0]
                max_codes = self.orthogonal_reg_max_codes
                if exists(max_codes) and num_codes > max_codes:
                    # Random subset of codes to bound the cost of the penalty.
                    keep = torch.randperm(num_codes, device=device)[:max_codes]
                    codes = codes[keep]

                loss = loss + orthgonal_loss_fn(codes) * self.orthogonal_reg_weight

        quantize = self._postprocess(self.project_out(quantize))
        return quantize, embed_ind, loss
class ResidualVectorQuantization(nn.Module):
    """Residual vector quantization implementation.

    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf: each layer
    quantizes what the previous layers left unexplained, and the layer outputs
    are summed.

    Args:
        num_quantizers (int): Number of ``VectorQuantization`` layers.
        **kwargs: Forwarded unchanged to every ``VectorQuantization`` layer.
    """
    def __init__(self, *, num_quantizers, **kwargs):
        super().__init__()
        stages = []
        for _ in range(num_quantizers):
            stages.append(VectorQuantization(**kwargs))
        self.layers = nn.ModuleList(stages)

    def forward(self, x, n_q: tp.Optional[int] = None):
        """Quantize ``x`` with the first ``n_q`` layers (all layers when ``n_q`` is falsy).

        Returns:
            tuple: ``(quantized_out, out_indices, out_losses)``: the sum of the layer
                outputs, and the per-layer indices and losses stacked on a new
                leading dimension.
        """
        n_q = n_q or len(self.layers)
        quantized_out = 0.0
        residual = x
        layer_losses = []
        layer_indices = []

        for layer in self.layers[:n_q]:
            quantized, indices, loss = layer(residual)
            residual = residual - quantized
            quantized_out = quantized_out + quantized
            layer_indices.append(indices)
            layer_losses.append(loss)

        out_losses = torch.stack(layer_losses)
        out_indices = torch.stack(layer_indices)
        return quantized_out, out_indices, out_losses

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        """Return the indices picked by the first ``n_q`` layers, stacked on dimension 0."""
        n_q = n_q or len(self.layers)
        residual = x
        layer_indices = []
        for layer in self.layers[:n_q]:
            indices = layer.encode(residual)
            # Remove what this layer explains before handing over to the next one.
            residual = residual - layer.decode(indices)
            layer_indices.append(indices)
        return torch.stack(layer_indices)

    def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
        """Sum the decoded outputs of each layer; ``q_indices[i]`` is decoded by layer ``i``."""
        quantized_out = torch.tensor(0.0, device=q_indices.device)
        for stage_idx, stage_indices in enumerate(q_indices):
            quantized_out = quantized_out + self.layers[stage_idx].decode(stage_indices)
        return quantized_out

View File

@ -0,0 +1,116 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import typing as tp
import torch
from .base import BaseQuantizer, QuantizedResult
from .core_vq import ResidualVectorQuantization
class ResidualVectorQuantizer(BaseQuantizer):
    """Residual Vector Quantizer.

    Args:
        dimension (int): Dimension of the codebooks.
        n_q (int): Number of residual vector quantizers used.
        q_dropout (bool): Random quantizer drop out at train time.
        bins (int): Codebook size.
        decay (float): Decay for exponential moving average over the codebooks.
        kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
        kmeans_iters (int): Number of iterations used for kmeans initialization.
        threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
            that have an exponential moving average cluster size less than the specified threshold with
            randomly selected vector from the current batch.
        orthogonal_reg_weight (float): Orthogonal regularization weights.
        orthogonal_reg_active_codes_only (bool): Apply orthogonal regularization only on active codes.
        orthogonal_reg_max_codes (optional int): Maximum number of codes to consider
            for orthogonal regularization.
    """
    def __init__(
        self,
        dimension: int = 256,
        n_q: int = 8,
        q_dropout: bool = False,
        bins: int = 1024,
        decay: float = 0.99,
        kmeans_init: bool = True,
        kmeans_iters: int = 10,
        threshold_ema_dead_code: int = 2,
        orthogonal_reg_weight: float = 0.0,
        orthogonal_reg_active_codes_only: bool = False,
        orthogonal_reg_max_codes: tp.Optional[int] = None,
    ):
        super().__init__()
        # `max_n_q` is the number of codebooks built; `n_q` is how many are currently in use.
        self.max_n_q = n_q
        self.n_q = n_q
        self.q_dropout = q_dropout
        self.dimension = dimension
        self.bins = bins
        self.decay = decay
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.orthogonal_reg_weight = orthogonal_reg_weight
        self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only
        self.orthogonal_reg_max_codes = orthogonal_reg_max_codes

        vq_kwargs = dict(
            dim=self.dimension,
            codebook_size=self.bins,
            num_quantizers=self.n_q,
            decay=self.decay,
            kmeans_init=self.kmeans_init,
            kmeans_iters=self.kmeans_iters,
            threshold_ema_dead_code=self.threshold_ema_dead_code,
            orthogonal_reg_weight=self.orthogonal_reg_weight,
            orthogonal_reg_active_codes_only=self.orthogonal_reg_active_codes_only,
            orthogonal_reg_max_codes=self.orthogonal_reg_max_codes,
            channels_last=False,
        )
        self.vq = ResidualVectorQuantization(**vq_kwargs)

    def forward(self, x: torch.Tensor, frame_rate: int):
        """Quantize ``x`` and report the bandwidth used.

        Args:
            x (torch.Tensor): Input to quantize.
            frame_rate (int): Frame rate of ``x``, used to compute the bandwidth.
        Returns:
            QuantizedResult: Quantized tensor, codes of shape [B, K, T], bandwidth,
                and the mean of the per-quantizer losses as ``penalty``.
        """
        active_q = self.n_q
        if self.training and self.q_dropout:
            # Quantizer dropout: use a random number of codebooks in [1, n_q].
            active_q = int(torch.randint(1, self.n_q + 1, (1,)).item())
        bw_per_q = math.log2(self.bins) * frame_rate / 1000
        quantized, codes, commit_loss = self.vq(x, n_q=active_q)
        # codes is [B, K, T], with T frames, K nb of codebooks.
        codes = codes.transpose(0, 1)
        bandwidth = torch.tensor(active_q * bw_per_q).to(x)
        penalty = torch.mean(commit_loss)
        return QuantizedResult(quantized, codes, bandwidth, penalty=penalty)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode the input tensor with the currently active number of quantizers.

        Returns:
            torch.Tensor: Indices for each quantizer, of shape [B, K, T].
        """
        stacked = self.vq.encode(x, n_q=self.n_q)
        # codes is [B, K, T], with T frames, K nb of codebooks.
        return stacked.transpose(0, 1)

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation.
        """
        # codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
        per_codebook = codes.transpose(0, 1)
        return self.vq.decode(per_codebook)

    @property
    def total_codebooks(self):
        """Number of codebooks that were built."""
        return self.max_n_q

    @property
    def num_codebooks(self):
        """Number of codebooks currently in use."""
        return self.n_q

    def set_num_codebooks(self, n: int):
        """Set how many codebooks are in use; ``n`` must be in [1, total_codebooks]."""
        assert n > 0 and n <= self.max_n_q
        self.n_q = n

View File

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

View File

@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
class TorchAutocast:
    """TorchAutocast utility class.

    Allows you to enable and disable autocast. This is specially useful
    when dealing with different architectures and clusters with different
    levels of support.

    Args:
        enabled (bool): Whether to enable torch.autocast or not.
        args: Additional args for torch.autocast.
        kwargs: Additional kwargs for torch.autocast
    """
    def __init__(self, enabled: bool, *args, **kwargs):
        # `None` marks the disabled state: entering and exiting then do nothing.
        self.autocast = torch.autocast(*args, **kwargs) if enabled else None

    def __enter__(self):
        """Enter the wrapped autocast context, if any.

        Raises:
            RuntimeError: If the wrapped autocast context fails to start; the
                original error is attached as the cause.
        """
        if self.autocast is None:
            return
        try:
            self.autocast.__enter__()
        except RuntimeError as err:
            device = self.autocast.device
            dtype = self.autocast.fast_dtype
            # Chain explicitly so the underlying failure stays visible in the traceback.
            raise RuntimeError(
                f"There was an error autocasting with dtype={dtype} device={device}\n"
                "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
            ) from err

    def __exit__(self, *args, **kwargs):
        """Exit the wrapped autocast context, if any."""
        if self.autocast is None:
            return
        self.autocast.__exit__(*args, **kwargs)

View File

@ -0,0 +1,56 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Utility to export a training checkpoint to a lightweight release checkpoint.
"""
from pathlib import Path
import typing as tp
from omegaconf import OmegaConf, DictConfig
import torch
def _clean_lm_cfg(cfg: DictConfig):
    """Normalize a language-model training config for release.

    Forces ``transformer_lm.card`` and ``transformer_lm.n_q`` to fixed values
    and drops experimental parameters that are no longer supported. The config
    is modified in place and left in struct mode.

    Args:
        cfg (DictConfig): Experiment configuration containing a ``transformer_lm`` section.
    Returns:
        DictConfig: The same ``cfg`` object, cleaned.
    """
    OmegaConf.set_struct(cfg, False)
    try:
        # This used to be set automatically in the LM solver, need a more robust solution
        # for the future.
        cfg['transformer_lm']['card'] = 2048
        cfg['transformer_lm']['n_q'] = 4
        # Experimental params no longer supported.
        bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters',
                      'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop']
        for name in bad_params:
            # Configs that never had (or already lost) a parameter are fine as they are.
            if name in cfg['transformer_lm']:
                del cfg['transformer_lm'][name]
    finally:
        # Always re-enable struct mode, even if the clean-up failed half-way.
        OmegaConf.set_struct(cfg, True)
    return cfg
def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
    """Export a compression-model training checkpoint to a release checkpoint.

    The checkpoint's parent folder name is used as the signature and must be
    8 characters long. The exported file keeps only the EMA model state
    (``pkg['ema']['state']['model']``) and the experiment config as YAML.

    Args:
        checkpoint_path (Path or str): Training checkpoint to read.
        out_folder (Path or str): Folder in which ``<sig>.th`` is written.
    Returns:
        Path: Path of the written file.
    """
    sig = Path(checkpoint_path).parent.name
    assert len(sig) == 8, "Not a valid Dora signature"
    pkg = torch.load(checkpoint_path, 'cpu')
    model_state = pkg['ema']['state']['model']
    cfg_yaml = OmegaConf.to_yaml(pkg['xp.cfg'])
    out_file = Path(out_folder) / f'{sig}.th'
    torch.save({'best_state': model_state, 'xp.cfg': cfg_yaml}, out_file)
    return out_file
def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
    """Export a language-model training checkpoint to a release checkpoint.

    The checkpoint's parent folder name is used as the signature and must be
    8 characters long. The exported file keeps only the best model state
    (``pkg['fsdp_best_state']['model']``) and the cleaned experiment config as YAML.

    Args:
        checkpoint_path (Path or str): Training checkpoint to read.
        out_folder (Path or str): Folder in which ``<sig>.th`` is written.
    Returns:
        Path: Path of the written file.
    """
    sig = Path(checkpoint_path).parent.name
    assert len(sig) == 8, "Not a valid Dora signature"
    pkg = torch.load(checkpoint_path, 'cpu')
    model_state = pkg['fsdp_best_state']['model']
    cfg_yaml = OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg']))
    out_file = Path(out_folder) / f'{sig}.th'
    torch.save({'best_state': model_state, 'xp.cfg': cfg_yaml}, out_file)
    return out_file

View File

@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
try:
import IPython.display as ipd # type: ignore
except ImportError:
# Not in a notebook...
pass
import torch
def display_audio(samples: torch.Tensor, sample_rate: int):
    """Render one audio player per item of the given audio samples.

    Args:
        samples (torch.Tensor): a Tensor of decoded audio samples
            with shapes [B, C, T] or [C, T]
        sample_rate (int): sample rate audio should be displayed with.
    """
    assert samples.dim() in (2, 3)
    batch = samples.detach().cpu()
    if batch.dim() == 2:
        # A single [C, T] clip: add the missing batch dimension.
        batch = batch.unsqueeze(0)
    for audio in batch:
        player = ipd.Audio(audio, rate=sample_rate)
        ipd.display(player)

234
audiocraft/utils/utils.py Normal file
View File

@ -0,0 +1,234 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from concurrent.futures import ProcessPoolExecutor
from functools import wraps
import hashlib
import logging
import typing as tp
import flashy
import flashy.distrib
import omegaconf
import torch
from torch.nn.utils.rnn import pad_sequence
logger = logging.getLogger(__name__)
def dict_from_config(cfg: omegaconf.DictConfig) -> dict:
    """Convert an omegaconf configuration into a plain dictionary.

    Interpolations are resolved during the conversion.

    Args:
        cfg (omegaconf.DictConfig): Original configuration to map to dict.
    Returns:
        dict: Config as dictionary object.
    """
    container = omegaconf.OmegaConf.to_container(cfg, resolve=True)
    assert isinstance(container, dict)
    return container
def random_subset(dataset, max_samples: int, seed: int = 42) -> torch.utils.data.Subset:
    """Pick a reproducible random subset of at most ``max_samples`` items.

    Args:
        dataset: Dataset supporting ``len()`` and indexing.
        max_samples (int): Maximum number of items to keep.
        seed (int): Seed of the dedicated generator used to draw the subset.
    Returns:
        The dataset itself when it already has at most ``max_samples`` items,
        otherwise a ``torch.utils.data.Subset`` with ``max_samples`` random items.
    """
    total = len(dataset)
    if total <= max_samples:
        return dataset

    rng = torch.Generator()
    rng.manual_seed(seed)
    chosen = torch.randperm(total, generator=rng)[:max_samples]
    return torch.utils.data.Subset(dataset, chosen.tolist())
def get_loader(dataset, num_samples: tp.Optional[int], batch_size: int,
               num_workers: int, seed: int, **kwargs) -> torch.utils.data.DataLoader:
    """Build a dataloader over ``dataset``, optionally restricted to a random subset.

    Args:
        dataset: Dataset to load.
        num_samples (Optional[int]): Number of samples to limit subset size;
            ``None`` keeps the whole dataset.
        batch_size (int): Batch size.
        num_workers (int): Number of workers for data loading.
        seed (int): Random seed used to draw the subset.
        **kwargs: Extra arguments forwarded to ``flashy.distrib.loader``.
    Returns:
        torch.utils.data.DataLoader: The loader created by ``flashy.distrib.loader``.
    """
    source = dataset
    if num_samples is not None:
        source = random_subset(dataset, num_samples, seed)

    return flashy.distrib.loader(
        source,
        batch_size=batch_size,
        num_workers=num_workers,
        **kwargs
    )
def get_dataset_from_loader(dataloader):
    """Return the dataset behind a dataloader, unwrapping one ``Subset`` level.

    Args:
        dataloader: Object exposing a ``dataset`` attribute.
    Returns:
        The wrapped dataset when ``dataloader.dataset`` is a
        ``torch.utils.data.Subset``, otherwise ``dataloader.dataset`` itself.
    """
    dataset = dataloader.dataset
    is_subset = isinstance(dataset, torch.utils.data.Subset)
    return dataset.dataset if is_subset else dataset
def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
    """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.

    Args:
        input (torch.Tensor): The input tensor containing probabilities.
        num_samples (int): Number of samples to draw.
        replacement (bool): Whether to draw with replacement or not.
    Keywords args:
        generator (torch.Generator): A pseudorandom number generator for sampling.
    Returns:
        torch.Tensor: Last dimension contains num_samples indices
            sampled from the multinomial probability distribution
            located in the last dimension of tensor input.
    """
    leading_shape = input.shape[:-1]
    # Collapse every leading dimension into one batch dimension for torch.multinomial.
    flat_probs = input.reshape(-1, input.shape[-1])
    flat_draws = torch.multinomial(
        flat_probs,
        num_samples=num_samples,
        replacement=replacement,
        generator=generator,
    )
    return flat_draws.reshape(*leading_shape, -1)
def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
    """Sample next token from top K values along the last dimension of the input probs tensor.

    The input tensor is left untouched: filtering and renormalization are done
    on a copy.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        k (int): The k in top-k.
    Returns:
        torch.Tensor: Sampled tokens.
    """
    top_k_value, _ = torch.topk(probs, k, dim=-1)
    # Smallest probability that still belongs to the top k.
    min_value_top_k = top_k_value[..., [-1]]
    keep = (probs >= min_value_top_k).to(probs.dtype)
    # Out-of-place on purpose, so the caller's tensor is not clobbered.
    filtered = probs * keep
    filtered = filtered / filtered.sum(dim=-1, keepdim=True)
    next_token = multinomial(filtered, num_samples=1)
    return next_token
def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Sample next token from top P probabilities along the last dimension of the input probs tensor.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        p (float): The p in top-p.
    Returns:
        torch.Tensor: Sampled tokens.
    """
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # A candidate is dropped when the mass strictly before it already exceeds p.
    outside_nucleus = cumulative - sorted_probs > p
    sorted_probs *= (~outside_nucleus).float()
    sorted_probs.div_(sorted_probs.sum(dim=-1, keepdim=True))
    choice = multinomial(sorted_probs, num_samples=1)
    # Map positions in the sorted order back to the original token ids.
    return torch.gather(sorted_idx, -1, choice)
class DummyPoolExecutor:
    """Dummy pool executor to use when we actually have only 1 worker.
    (e.g. instead of ProcessPoolExecutor).
    """
    class DummyResult:
        """Future-like handle that runs the submitted call when ``result()`` is requested."""
        def __init__(self, func, *args, **kwargs):
            self.func, self.args, self.kwargs = func, args, kwargs

        def result(self):
            """Execute the stored call in the current process and return its value."""
            positional, named = self.args, self.kwargs
            return self.func(*positional, **named)

    def __init__(self, workers, mp_context=None):
        """Accept, and ignore, the arguments a real executor would take."""

    def submit(self, func, *args, **kwargs):
        """Wrap the call in a ``DummyResult``; nothing is executed yet."""
        return DummyPoolExecutor.DummyResult(func, *args, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Nothing to release; exceptions are not suppressed.
        return None
def get_pool_executor(num_workers: int, mp_context=None):
    """Return a process pool for several workers, or an in-process stand-in otherwise.

    Args:
        num_workers (int): Requested number of workers.
        mp_context: Multiprocessing context handed to ``ProcessPoolExecutor``.
    Returns:
        ``ProcessPoolExecutor`` when ``num_workers > 1``, else ``DummyPoolExecutor``.
    """
    if num_workers > 1:
        return ProcessPoolExecutor(num_workers, mp_context)
    return DummyPoolExecutor(1)
def length_to_mask(lengths: torch.Tensor, max_len: tp.Optional[int] = None) -> torch.Tensor:
    """Utility function to convert a tensor of sequence lengths to a mask (useful when working on padded sequences).
    For example: [3, 5] => [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]

    Args:
        lengths (torch.Tensor): 1-dimensional tensor with lengths.
        max_len (int): can set the max length manually. Defaults to None,
            in which case the largest value in ``lengths`` is used.
    Returns:
        torch.Tensor: boolean mask of shape [len(lengths), final_length], on the
            same device as ``lengths``, with False where there is pad tokens else True.
    """
    assert len(lengths.shape) == 1, "Length shape should be 1 dimensional."
    final_length = lengths.max().item() if not max_len else max_len
    final_length = max(final_length, 1)  # if all seqs are of len zero we don't want a zero-size tensor
    # Build the positions directly on the target device instead of copying them over from the CPU.
    positions = torch.arange(final_length, device=lengths.device)
    return positions[None, :] < lengths[:, None]
def hash_trick(word: str, vocab_size: int) -> int:
    """Hash trick to pair each word with an index

    The index is the SHA-256 digest of the UTF-8 encoded word, read as an
    integer and reduced modulo ``vocab_size``.

    Args:
        word (str): word we wish to convert to an index
        vocab_size (int): size of the vocabulary
    Returns:
        int: index of the word in the embedding LUT
    """
    digest = hashlib.sha256(word.encode("utf-8")).digest()
    # Big-endian interpretation of the raw digest equals int(hexdigest, 16).
    digest_value = int.from_bytes(digest, "big")
    return digest_value % vocab_size
def with_rank_rng(base_seed: int = 1234):
    """Decorator for a function so that the function will use a Random Number Generator
    whose state depend on the GPU rank. The original RNG state is restored upon returning.

    NOTE(review): the state saved and restored is the one returned by
    ``torch.get_rng_state()``, while ``torch.manual_seed`` is what applies the
    rank-dependent seed; confirm that restoring only the former is intended on GPU runs.

    Args:
        base_seed (int): Random seed.
    """
    def _decorator(fun: tp.Callable):
        @wraps(fun)
        def _decorated(*args, **kwargs):
            saved_state = torch.get_rng_state()
            rank_seed = base_seed ^ flashy.distrib.rank()
            torch.manual_seed(rank_seed)
            logger.debug('Rank dependent seed set to %d', rank_seed)
            try:
                outcome = fun(*args, **kwargs)
            finally:
                # Runs whether `fun` returned or raised.
                torch.set_rng_state(saved_state)
                logger.debug('RNG state restored.')
            return outcome
        return _decorated
    return _decorator
def collate(tensors: tp.List[torch.Tensor], dim: int = 0) -> tp.Tuple[torch.Tensor, torch.Tensor]:
    """Collate a list of tensors into a single zero-padded tensor.

    - `dim` specifies the time dimension which will be stacked and padded.
    - The output will contain 1 new dimension (dimension index 0) which will be the size of
      of the original list.

    Args:
        tensors (tp.List[torch.Tensor]): List of tensors to collate.
        dim (int): Dimension which will be stacked and padded.
    Returns:
        tp.Tuple[torch.Tensor, torch.Tensor]:
            torch.Tensor: Stacked and padded tensor. The output will contain 1 new dimension
                (dimension index 0) which will be the size of the original list.
            torch.Tensor: Tensor containing length of original tensor sizes (without padding).
    """
    # Move the time dimension first, as expected by `pad_sequence`.
    time_first = [item.transpose(0, dim) for item in tensors]
    lens = torch.LongTensor(list(map(len, time_first)))
    padded = pad_sequence(time_first)  # [T, B, ...]
    batch_first = padded.transpose(0, 1)  # [B, T, ...]
    # Put time back where it was, shifted by one because of the new batch dimension.
    restored = batch_first.transpose(1, dim + 1)
    return restored, lens

20
requirements.txt Normal file
View File

@ -0,0 +1,20 @@
# Please make sure you already have a CUDA-enabled PyTorch install!
av
einops
flashy>=0.0.1
hydra-core>=1.1
hydra_colorlog
julius
num2words
numpy
sentencepiece
spacy==3.5.2
torch>=2.0.0
torchaudio>=2.0.0
huggingface_hub
tqdm
transformers
xformers
demucs
librosa
gradio