91 lines
3.2 KiB
Python
91 lines
3.2 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
"""
|
|
Utility functions to load from the checkpoints.
|
|
Each checkpoint is a torch.saved dict with the following keys:
|
|
- 'xp.cfg': the hydra config as dumped during training. This should be used
|
|
to rebuild the object using the audiocraft.models.builders functions,
|
|
- 'model_best_state': a readily loadable best state for the model, including
|
|
the conditioner. The model obtained from `xp.cfg` should be compatible
|
|
with this state dict. In the case of a LM, the encodec model would not be
|
|
bundled along but instead provided separately.
|
|
|
|
Those functions also support loading from a remote location with the Torch Hub API.
|
|
They also support overriding some parameters, in particular the device and dtype
|
|
of the returned model.
|
|
"""
|
|
|
|
from pathlib import Path
|
|
from huggingface_hub import hf_hub_download
|
|
import typing as tp
|
|
import os
|
|
|
|
from omegaconf import OmegaConf
|
|
import torch
|
|
|
|
from . import builders
|
|
|
|
|
|
HF_MODEL_CHECKPOINTS_MAP = {
|
|
"small": "facebook/musicgen-small",
|
|
"medium": "facebook/musicgen-medium",
|
|
"large": "facebook/musicgen-large",
|
|
"melody": "facebook/musicgen-melody",
|
|
}
|
|
|
|
|
|
def _get_state_dict(
|
|
file_or_url_or_id: tp.Union[Path, str],
|
|
filename: tp.Optional[str] = None,
|
|
device='cpu',
|
|
cache_dir: tp.Optional[str] = None,
|
|
):
|
|
# Return the state dict either from a file or url
|
|
file_or_url_or_id = str(file_or_url_or_id)
|
|
assert isinstance(file_or_url_or_id, str)
|
|
|
|
if os.path.isfile(file_or_url_or_id):
|
|
return torch.load(file_or_url_or_id, map_location=device)
|
|
|
|
elif file_or_url_or_id.startswith('https://'):
|
|
return torch.hub.load_state_dict_from_url(file_or_url_or_id, map_location=device, check_hash=True)
|
|
|
|
elif file_or_url_or_id in HF_MODEL_CHECKPOINTS_MAP:
|
|
assert filename is not None, "filename needs to be defined if using HF checkpoints"
|
|
|
|
repo_id = HF_MODEL_CHECKPOINTS_MAP[file_or_url_or_id]
|
|
file = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir)
|
|
return torch.load(file, map_location=device)
|
|
|
|
else:
|
|
raise ValueError(f"{file_or_url_or_id} is not a valid name, path or link that can be loaded.")
|
|
|
|
|
|
def load_compression_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
|
|
pkg = _get_state_dict(file_or_url_or_id, filename="compression_state_dict.bin", cache_dir=cache_dir)
|
|
cfg = OmegaConf.create(pkg['xp.cfg'])
|
|
cfg.device = str(device)
|
|
model = builders.get_compression_model(cfg)
|
|
model.load_state_dict(pkg['best_state'])
|
|
model.eval()
|
|
return model
|
|
|
|
|
|
def load_lm_model(file_or_url_or_id: tp.Union[Path, str], device='cpu', cache_dir: tp.Optional[str] = None):
|
|
pkg = _get_state_dict(file_or_url_or_id, filename="state_dict.bin", cache_dir=cache_dir)
|
|
cfg = OmegaConf.create(pkg['xp.cfg'])
|
|
cfg.device = str(device)
|
|
if cfg.device == 'cpu':
|
|
cfg.dtype = 'float32'
|
|
else:
|
|
cfg.dtype = 'float16'
|
|
model = builders.get_lm_model(cfg)
|
|
model.load_state_dict(pkg['best_state'])
|
|
model.eval()
|
|
model.cfg = cfg
|
|
return model
|