# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Utility to export a training checkpoint to a lightweight release checkpoint.
"""

from pathlib import Path
import typing as tp

from omegaconf import OmegaConf, DictConfig
import torch


def _clean_lm_cfg(cfg: DictConfig):
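    """Patch an LM config for standalone use: restore fields that the training
    solver used to inject, and drop experimental params that are no longer
    supported."""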
    OmegaConf.set_struct(cfg, False)
    # `card` and `n_q` used to be set automatically by the LM solver; hardcoding
    # them here is a stopgap until a more robust solution is found.
    cfg['transformer_lm']['card'] = 2048
    cfg['transformer_lm']['n_q'] = 4
    # Experimental params no longer supported.
    bad_params = ['spectral_norm_attn_iters', 'spectral_norm_ff_iters',
                  'residual_balancer_attn', 'residual_balancer_ff', 'layer_drop']
    for name in bad_params:
        del cfg['transformer_lm'][name]
    OmegaConf.set_struct(cfg, True)
    return cfg


def export_encodec(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
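    """Export a compression model (EnCodec) checkpoint: keep only the EMA
    model weights and the experiment config serialized as YAML, written to
    `out_folder/<sig>.th` where `sig` is the Dora signature."""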
    sig = Path(checkpoint_path).parent.name
    assert len(sig) == 8, "Not a valid Dora signature"
    pkg = torch.load(checkpoint_path, 'cpu')
    new_pkg = {
        'best_state': pkg['ema']['state']['model'],
        'xp.cfg': OmegaConf.to_yaml(pkg['xp.cfg']),
    }
    out_file = Path(out_folder) / f'{sig}.th'
    torch.save(new_pkg, out_file)
    return out_file


def export_lm(checkpoint_path: tp.Union[Path, str], out_folder: tp.Union[Path, str]):
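    """Export an LM checkpoint: keep only the model weights from the FSDP best
    state and the cleaned experiment config serialized as YAML, written to
    `out_folder/<sig>.th` where `sig` is the Dora signature."""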
    sig = Path(checkpoint_path).parent.name
    assert len(sig) == 8, "Not a valid Dora signature"
    pkg = torch.load(checkpoint_path, 'cpu')
    new_pkg = {
        'best_state': pkg['fsdp_best_state']['model'],
        'xp.cfg': OmegaConf.to_yaml(_clean_lm_cfg(pkg['xp.cfg'])),
    }
    out_file = Path(out_folder) / f'{sig}.th'
    torch.save(new_pkg, out_file)
    return out_file
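

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original utility: the arguments are
    # placeholders for a real Dora checkpoint path and output folder. The
    # checkpoint's parent directory name must be its 8-character Dora
    # signature, e.g. /checkpoints/a1b2c3d4/checkpoint.th.
    import sys

    checkpoint_path, out_folder = sys.argv[1], sys.argv[2]
    print(export_lm(checkpoint_path, out_folder))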