# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
|
|
All the functions to build the relevant models and modules
|
|
from the Hydra config.
|
|
"""
|
|
|
|
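# Typical entry points, as a sketch only (assuming a fully populated Hydra
# config `cfg` carrying `compression_model`/`lm_model` keys plus the
# sub-configs consumed by the builders below):
#
#     compression_model = get_compression_model(cfg)
#     lm = get_lm_model(cfg)
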
import typing as tp
import warnings

import audiocraft
import omegaconf
import torch

from .encodec import CompressionModel, EncodecModel, FlattenedCompressionModel  # noqa
from .lm import LMModel
from ..modules.codebooks_patterns import (
    CodebooksPatternProvider,
    DelayedPatternProvider,
    ParallelPatternProvider,
    UnrolledPatternProvider,
    VALLEPattern,
    MusicLMPattern,
)
from ..modules.conditioners import (
    BaseConditioner,
    ConditioningProvider,
    LUTConditioner,
    T5Conditioner,
    ConditionFuser,
    ChromaStemConditioner,
)
from .. import quantization as qt
from ..utils.utils import dict_from_config


def get_quantizer(quantizer: str, cfg: omegaconf.DictConfig, dimension: int) -> qt.BaseQuantizer:
    """Instantiate a quantizer from the Hydra config.
    """
    klass = {
        'no_quant': qt.DummyQuantizer,
        'rvq': qt.ResidualVectorQuantizer
    }[quantizer]
    kwargs = dict_from_config(getattr(cfg, quantizer))
    if quantizer != 'no_quant':
        kwargs['dimension'] = dimension
    return klass(**kwargs)

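# For illustration, a hypothetical `rvq` section of the Hydra config that
# `get_quantizer('rvq', cfg, dimension)` would consume. The field names mirror
# qt.ResidualVectorQuantizer's constructor (`n_q`, `bins`, as also used by the
# debug model below); the values here are made up:
#
#   rvq:
#     n_q: 4
#     bins: 400

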
def get_encodec_autoencoder(encoder_name: str, cfg: omegaconf.DictConfig):
    """Instantiate the SEANet encoder/decoder pair for EnCodec.
    """
    if encoder_name == 'seanet':
        kwargs = dict_from_config(getattr(cfg, 'seanet'))
        encoder_override_kwargs = kwargs.pop('encoder')
        decoder_override_kwargs = kwargs.pop('decoder')
        encoder_kwargs = {**kwargs, **encoder_override_kwargs}
        decoder_kwargs = {**kwargs, **decoder_override_kwargs}
        encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs)
        decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs)
        return encoder, decoder
    else:
        raise KeyError(f'Unexpected autoencoder {encoder_name}')

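# A sketch of the `seanet` config shape expected above: top-level keys are
# shared by both networks, while the `encoder`/`decoder` sub-dicts (popped
# explicitly) override them per network. Apart from `encoder`/`decoder`,
# `dimension` and `n_filters` (used by the debug model below), the keys and
# values here are assumptions for illustration:
#
#   seanet:
#     dimension: 128
#     n_filters: 32
#     encoder: {}
#     decoder: {}

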
def get_compression_model(cfg: omegaconf.DictConfig) -> CompressionModel:
    """Instantiate a compression model.
    """
    if cfg.compression_model == 'encodec':
        kwargs = dict_from_config(getattr(cfg, 'encodec'))
        encoder_name = kwargs.pop('autoencoder')
        quantizer_name = kwargs.pop('quantizer')
        encoder, decoder = get_encodec_autoencoder(encoder_name, cfg)
        quantizer = get_quantizer(quantizer_name, cfg, encoder.dimension)
        frame_rate = kwargs['sample_rate'] // encoder.hop_length
        # Backward compatibility: older configs used `renorm` instead of `renormalize`.
        renormalize = kwargs.pop('renormalize', None)
        renorm = kwargs.pop('renorm')
        if renormalize is None:
            renormalize = renorm is not None
            warnings.warn("You are using a deprecated EnCodec model. Please migrate to new renormalization.")
        return EncodecModel(encoder, decoder, quantizer,
                            frame_rate=frame_rate, renormalize=renormalize, **kwargs).to(cfg.device)
    else:
        raise KeyError(f'Unexpected compression model {cfg.compression_model}')

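# Note on `frame_rate`: the encoder downsamples by the product of its strides
# (`encoder.hop_length`), so with sample_rate=32000 and ratios [10, 8, 16]
# (hop_length 10 * 8 * 16 = 1280, as in the debug model below),
# frame_rate = 32000 // 1280 = 25 Hz.

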
def get_lm_model(cfg: omegaconf.DictConfig) -> LMModel:
    """Instantiate a transformer LM.
    """
    if cfg.lm_model == 'transformer_lm':
        kwargs = dict_from_config(getattr(cfg, 'transformer_lm'))
        n_q = kwargs['n_q']
        q_modeling = kwargs.pop('q_modeling', None)
        codebooks_pattern_cfg = getattr(cfg, 'codebooks_pattern')
        attribute_dropout = dict_from_config(getattr(cfg, 'attribute_dropout'))
        cls_free_guidance = dict_from_config(getattr(cfg, 'classifier_free_guidance'))
        cfg_prob, cfg_coef = cls_free_guidance["training_dropout"], cls_free_guidance["inference_coef"]
        fuser = get_condition_fuser(cfg)
        condition_provider = get_conditioner_provider(kwargs["dim"], cfg).to(cfg.device)
        if len(fuser.fuse2cond['cross']) > 0:  # enforce cross-attention programmatically
            kwargs['cross_attention'] = True
        if codebooks_pattern_cfg.modeling is None:
            assert q_modeling is not None, \
                'LM model should either have a codebook pattern defined or transformer_lm.q_modeling'
            codebooks_pattern_cfg = omegaconf.OmegaConf.create(
                {'modeling': q_modeling, 'delay': {'delays': list(range(n_q))}}
            )
        pattern_provider = get_codebooks_pattern_provider(n_q, codebooks_pattern_cfg)
        return LMModel(
            pattern_provider=pattern_provider,
            condition_provider=condition_provider,
            fuser=fuser,
            cfg_dropout=cfg_prob,
            cfg_coef=cfg_coef,
            attribute_dropout=attribute_dropout,
            dtype=getattr(torch, cfg.dtype),
            device=cfg.device,
            **kwargs
        ).to(cfg.device)
    else:
        raise KeyError(f'Unexpected LM model {cfg.lm_model}')

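# The `q_modeling` fallback above is equivalent to providing this codebooks
# pattern config explicitly (shown for n_q=4 with q_modeling='delay'):
#
#   codebooks_pattern:
#     modeling: delay
#     delay:
#       delays: [0, 1, 2, 3]

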
def get_conditioner_provider(output_dim: int, cfg: omegaconf.DictConfig) -> ConditioningProvider:
    """Instantiate a conditioning provider from the conditioners config.
    """
    device = cfg.device
    duration = cfg.dataset.segment_duration
    cfg = getattr(cfg, "conditioners")
    cfg = omegaconf.OmegaConf.create({}) if cfg is None else cfg
    conditioners: tp.Dict[str, BaseConditioner] = {}
    with omegaconf.open_dict(cfg):
        condition_provider_args = cfg.pop('args', {})
    for cond, cond_cfg in cfg.items():
        model_type = cond_cfg["model"]
        model_args = cond_cfg[model_type]
        if model_type == "t5":
            conditioners[str(cond)] = T5Conditioner(output_dim=output_dim, device=device, **model_args)
        elif model_type == "lut":
            conditioners[str(cond)] = LUTConditioner(output_dim=output_dim, **model_args)
        elif model_type == "chroma_stem":
            model_args.pop('cache_path', None)
            conditioners[str(cond)] = ChromaStemConditioner(
                output_dim=output_dim,
                duration=duration,
                device=device,
                **model_args
            )
        else:
            raise ValueError(f"unrecognized conditioning model: {model_type}")
    conditioner = ConditioningProvider(conditioners, device=device, **condition_provider_args)
    return conditioner

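# A hypothetical `conditioners` config illustrating the dispatch above: one
# entry per condition, where `model` selects the conditioner class and the
# matching sub-dict carries its kwargs. The T5 field names and values here are
# assumptions for illustration:
#
#   conditioners:
#     description:
#       model: t5
#       t5:
#         name: t5-base
#         finetune: false

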
def get_condition_fuser(cfg: omegaconf.DictConfig) -> ConditionFuser:
    """Instantiate a condition fuser object.
    """
    fuser_cfg = getattr(cfg, "fuser")
    fuser_methods = ["sum", "cross", "prepend", "input_interpolate"]
    fuse2cond = {k: fuser_cfg[k] for k in fuser_methods}
    kwargs = {k: v for k, v in fuser_cfg.items() if k not in fuser_methods}
    fuser = ConditionFuser(fuse2cond=fuse2cond, **kwargs)
    return fuser

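# Example `fuser` config, mirroring the debug LM below: each fuse method maps
# to the list of condition names it applies to, and any extra keys are passed
# through to ConditionFuser as kwargs:
#
#   fuser:
#     cross: [description]
#     prepend: []
#     sum: []
#     input_interpolate: []

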
def get_codebooks_pattern_provider(n_q: int, cfg: omegaconf.DictConfig) -> CodebooksPatternProvider:
    """Instantiate a codebooks pattern provider object.
    """
    pattern_providers = {
        'parallel': ParallelPatternProvider,
        'delay': DelayedPatternProvider,
        'unroll': UnrolledPatternProvider,
        'valle': VALLEPattern,
        'musiclm': MusicLMPattern,
    }
    name = cfg.modeling
    kwargs = dict_from_config(cfg.get(name)) if hasattr(cfg, name) else {}
    klass = pattern_providers[name]
    return klass(n_q, **kwargs)

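# A minimal sketch of driving this builder directly; the config shape matches
# the q_modeling fallback in get_lm_model above:
#
#   cfg = omegaconf.OmegaConf.create(
#       {'modeling': 'delay', 'delay': {'delays': [0, 1, 2, 3]}})
#   provider = get_codebooks_pattern_provider(n_q=4, cfg=cfg)

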
def get_debug_compression_model(device='cpu'):
    """Instantiate a debug compression model to be used for unit tests.
    """
    seanet_kwargs = {
        'n_filters': 4,
        'n_residual_layers': 1,
        'dimension': 32,
        'ratios': [10, 8, 16]  # hop_length = 10 * 8 * 16 = 1280, i.e. 25 Hz at 32kHz
    }
    encoder = audiocraft.modules.SEANetEncoder(**seanet_kwargs)
    decoder = audiocraft.modules.SEANetDecoder(**seanet_kwargs)
    quantizer = qt.ResidualVectorQuantizer(dimension=32, bins=400, n_q=4)
    init_x = torch.randn(8, 32, 128)
    quantizer(init_x, 1)  # initialize kmeans etc.
    compression_model = EncodecModel(
        encoder, decoder, quantizer,
        frame_rate=25, sample_rate=32000, channels=1).to(device)
    return compression_model.eval()

def get_debug_lm_model(device='cpu'):
    """Instantiate a debug LM to be used for unit tests.
    """
    pattern = DelayedPatternProvider(n_q=4)
    dim = 16
    providers = {
        'description': LUTConditioner(n_bins=128, dim=dim, output_dim=dim, tokenizer="whitespace"),
    }
    condition_provider = ConditioningProvider(providers)
    fuser = ConditionFuser(
        {'cross': ['description'], 'prepend': [],
         'sum': [], 'input_interpolate': []})
    lm = LMModel(
        pattern, condition_provider, fuser,
        n_q=4, card=400, dim=dim, num_heads=4, custom=True, num_layers=2,
        cross_attention=True, causal=True)
    return lm.to(device).eval()
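

if __name__ == '__main__':
    # Quick smoke test, a sketch rather than part of the library API: build
    # both debug models defined above and report their parameter counts.
    compression_model = get_debug_compression_model()
    lm = get_debug_lm_model()
    print('compression params:', sum(p.numel() for p in compression_model.parameters()))
    print('lm params:', sum(p.numel() for p in lm.parameters()))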