526 lines
21 KiB
Python
526 lines
21 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import argparse
|
|
import copy
|
|
from concurrent.futures import ThreadPoolExecutor, Future
|
|
from dataclasses import dataclass, fields
|
|
from contextlib import ExitStack
|
|
import gzip
|
|
import json
|
|
import logging
|
|
import os
|
|
from pathlib import Path
|
|
import random
|
|
import sys
|
|
import typing as tp
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from .audio import audio_read, audio_info
|
|
from .audio_utils import convert_audio
|
|
from .zip import PathInZip
|
|
|
|
try:
|
|
import dora
|
|
except ImportError:
|
|
dora = None # type: ignore
|
|
|
|
|
|
@dataclass(order=True)
class BaseInfo:
    """Base dataclass providing conversions to and from plain dictionaries."""

    @classmethod
    def _dict2fields(cls, dictionary: dict):
        """Return the entries of `dictionary` whose keys are dataclass fields of `cls`.

        Keys that do not match a field are dropped; fields missing from
        `dictionary` are simply absent from the result.
        """
        field_names = (f.name for f in fields(cls))
        return {name: dictionary[name] for name in field_names if name in dictionary}

    @classmethod
    def from_dict(cls, dictionary: dict):
        """Instantiate `cls` from `dictionary`, ignoring keys that are not fields."""
        return cls(**cls._dict2fields(dictionary))

    def to_dict(self):
        """Return a dict mapping each field name to its current value on this instance."""
        return {f.name: getattr(self, f.name) for f in fields(self)}
|
|
|
|
|
|
@dataclass(order=True)
class AudioMeta(BaseInfo):
    """Metadata attached to a single audio file.

    `info_path` is stored as a `PathInZip` on the instance and as a plain
    string in the dictionary form produced by `to_dict`.
    """
    path: str
    duration: float
    sample_rate: int
    amplitude: tp.Optional[float] = None
    weight: tp.Optional[float] = None
    # info_path is used to load additional information about the audio file that is stored in zip files.
    info_path: tp.Optional[PathInZip] = None

    @classmethod
    def from_dict(cls, dictionary: dict):
        """Build an AudioMeta from `dictionary`, wrapping a non-None `info_path` in `PathInZip`."""
        kwargs = cls._dict2fields(dictionary)
        raw_info_path = kwargs.get('info_path')
        if raw_info_path is not None:
            kwargs['info_path'] = PathInZip(raw_info_path)
        return cls(**kwargs)

    def to_dict(self):
        """Return the fields as a dict, with a non-None `info_path` converted to `str`."""
        as_dict = super().to_dict()
        info_path = as_dict['info_path']
        if info_path is not None:
            as_dict['info_path'] = str(info_path)
        return as_dict
|
|
|
|
|
|
@dataclass(order=True)
class SegmentInfo(BaseInfo):
    """Description of one loaded audio segment, paired with the metadata of its source file."""
    # Metadata of the audio file the segment was read from.
    meta: AudioMeta
    # Offset passed to the audio reader when the segment was loaded (0 for full files).
    seek_time: float
    # Actual number of frames, without padding.
    n_frames: int
    # Total number of frames, padding included.
    total_frames: int
    # Actual sample rate.
    sample_rate: int
|
|
|
|
|
|
# File suffixes treated as audio when no explicit extension list is given.
DEFAULT_EXTS = ['.wav', '.mp3', '.flac', '.ogg', '.m4a']

# Logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _get_audio_meta(file_path: str, minimal: bool = True) -> AudioMeta:
    """Build the AudioMeta of an audio file from its path.

    Args:
        file_path (str): Resolved path of valid audio file.
        minimal (bool): When True, only duration and sample rate are collected.
            When False, the file is also read in full to compute its peak
            absolute amplitude (takes longer).
    Returns:
        AudioMeta: Audio file path and its metadata.
    """
    info = audio_info(file_path)
    if minimal:
        peak: tp.Optional[float] = None
    else:
        # The sample rate returned by the reader is not needed here.
        waveform, _ = audio_read(file_path)
        peak = waveform.abs().max().item()
    return AudioMeta(file_path, info.duration, info.sample_rate, peak)
|
|
|
|
|
|
def _resolve_audio_meta(m: AudioMeta, fast: bool = True) -> AudioMeta:
    """If Dora is available as a dependency, try to resolve potential relative paths
    in an AudioMeta. This method is expected to be used when loading meta from file.

    The audio meta is modified in place and also returned. When Dora could not
    be imported, the meta is returned untouched.

    Args:
        m (AudioMeta): Audio meta to resolve.
        fast (bool): If True, uses a really fast check for determining if a file is already absolute or not.
            Only valid on Linux/Mac.
    Returns:
        AudioMeta: Audio meta with resolved path.
    """
    def is_abs(path) -> bool:
        path = str(path)
        if fast:
            # Cheap POSIX-only test; `startswith` also copes with an empty string.
            return path.startswith('/')
        return os.path.isabs(path)

    if not dora:
        return m

    if not is_abs(m.path):
        m.path = dora.git_save.to_absolute_path(m.path)
    if m.info_path is not None and not is_abs(m.info_path.zip_path):
        # Resolve the zip archive's own path, not the audio file path.
        m.info_path.zip_path = dora.git_save.to_absolute_path(m.info_path.zip_path)
    return m
|
|
|
|
|
|
def find_audio_files(path: tp.Union[Path, str],
                     exts: tp.List[str] = DEFAULT_EXTS,
                     resolve: bool = True,
                     minimal: bool = True,
                     progress: bool = False,
                     workers: int = 0) -> tp.List[AudioMeta]:
    """Walk a folder, collect the audio files it contains and fetch their metadata.

    Files whose metadata cannot be obtained are reported on stderr and skipped.

    Args:
        path (str or Path): Path to folder containing audio files.
        exts (list of str): List of file extensions to consider for audio files.
        resolve (bool): Whether to pass each AudioMeta through `_resolve_audio_meta`.
        minimal (bool): Whether to only load the minimal set of metadata (takes longer if not).
        progress (bool): Whether to log progress on audio files collection.
        workers (int): number of parallel workers, if 0, use only the current thread.
    Returns:
        List[AudioMeta]: Sorted list of audio file path and its metadata.
    """
    audio_files = []
    futures: tp.List[Future] = []
    with ExitStack() as stack:
        # The thread pool, when requested, lives until every future has been consumed.
        pool: tp.Optional[ThreadPoolExecutor] = (
            stack.enter_context(ThreadPoolExecutor(workers)) if workers > 0 else None)

        if progress:
            print("Finding audio files...")
        for dirpath, _dirnames, filenames in os.walk(path, followlinks=True):
            for filename in filenames:
                candidate = Path(dirpath) / filename
                if candidate.suffix.lower() not in exts:
                    continue
                audio_files.append(candidate)
                if pool is not None:
                    # Start fetching metadata while the directory walk continues.
                    futures.append(pool.submit(_get_audio_meta, str(candidate), minimal))
                if progress:
                    print(format(len(audio_files), " 8d"), end='\r', file=sys.stderr)

        if progress:
            print("Getting audio metadata...")
        meta: tp.List[AudioMeta] = []
        n_files = len(audio_files)
        for idx, file_path in enumerate(audio_files):
            try:
                # futures[idx] matches audio_files[idx]: both lists are appended to together.
                m = _get_audio_meta(str(file_path), minimal) if pool is None else futures[idx].result()
                if resolve:
                    m = _resolve_audio_meta(m)
            except Exception as err:
                print("Error with", str(file_path), err, file=sys.stderr)
                continue
            meta.append(m)
            if progress:
                print(format((1 + idx) / n_files, " 3.1%"), end='\r', file=sys.stderr)
    meta.sort()
    return meta
|
|
|
|
|
|
def load_audio_meta(path: tp.Union[str, Path],
                    resolve: bool = True, fast: bool = True) -> tp.List[AudioMeta]:
    """Read a list of AudioMeta from a JSON-lines file, gzip-compressed or not.

    The file is treated as gzip when its name ends with `.gz` (case-insensitive).
    Each line is parsed as one JSON object.

    Args:
        path (str or Path): Path to JSON file.
        resolve (bool): Whether to resolve the path from AudioMeta (default=True).
        fast (bool): Forwarded to `_resolve_audio_meta`; activates some tricks to make things faster.
    Returns:
        List[AudioMeta]: One AudioMeta per line of the file, in file order.
    """
    if str(path).lower().endswith('.gz'):
        open_fn = gzip.open
    else:
        open_fn = open
    with open_fn(path, 'rb') as fp:  # type: ignore
        raw_lines = fp.readlines()

    def _parse(raw_line):
        entry = AudioMeta.from_dict(json.loads(raw_line))
        return _resolve_audio_meta(entry, fast=fast) if resolve else entry

    return [_parse(raw_line) for raw_line in raw_lines]
|
|
|
|
|
|
def save_audio_meta(path: tp.Union[str, Path], meta: tp.List[AudioMeta]):
    """Write audio metadata to a JSON-lines file, one entry per line.

    Missing parent directories are created. The output is gzip-compressed when
    the file name ends with `.gz` (case-insensitive).

    Args:
        path (str or Path): Path to JSON file.
        meta (list of AudioMeta): List of audio meta to save.
    """
    target = Path(path)
    target.parent.mkdir(exist_ok=True, parents=True)
    if str(path).lower().endswith('.gz'):
        open_fn = gzip.open
    else:
        open_fn = open
    with open_fn(path, 'wb') as fp:  # type: ignore
        for entry in meta:
            serialized = json.dumps(entry.to_dict()) + '\n'
            fp.write(serialized.encode('utf-8'))
|
|
|
|
|
|
class AudioDataset:
    """Base audio dataset.

    The dataset takes a list of AudioMeta and create a dataset composed of segments of audio
    and potentially additional information, by creating random segments from the list of audio
    files referenced in the metadata and applying minimal data pre-processing such as resampling,
    mixing of channels, padding, etc.

    If no segment_duration value is provided, the AudioDataset will return the full wav for each
    audio file. Otherwise, it will randomly sample audio files and create a segment of the specified
    duration, applying padding if required.

    By default, only the torch Tensor corresponding to the waveform is returned. Setting return_info=True
    allows to return a tuple containing the torch Tensor and additional metadata on the segment and the
    original audio meta.

    Args:
        meta (tp.List[AudioMeta]): List of audio files metadata.
        segment_duration (float): Optional segment duration of audio to load.
            If not specified, the dataset will load the full audio segment from the file.
        shuffle (bool): Set to `True` to have the data reshuffled at every epoch.
        num_samples (int): Length reported by the dataset. Only used if `segment_duration` is provided;
            otherwise it is replaced by the number of files kept after duration filtering.
        sample_rate (int): Target sample rate of the loaded audio samples.
        channels (int): Target number of channels of the loaded audio samples.
        pad (bool): Whether to pad segments at the end up to `segment_duration`, and, when
            `segment_duration` is None, whether the collater may pad a batch to its longest item.
        sample_on_duration (bool): Set to `True` to sample segments with probability
            dependent on audio file duration. This is only used if `segment_duration` is provided.
        sample_on_weight (bool): Set to `True` to sample segments using the `weight` entry of
            `AudioMeta`. If `sample_on_duration` is also True, the actual weight will be the product
            of the file duration and file weight. This is only used if `segment_duration` is provided.
        min_segment_ratio (float): Minimum segment ratio to use when the audio file
            is shorter than the desired segment.
        max_read_retry (int): Maximum number of retries to sample an audio segment from the dataset.
        return_info (bool): Whether to return the wav only or return wav along with segment info and metadata.
        min_audio_duration (tp.Optional[float], optional): Minimum audio file duration, in seconds, if provided
            audio shorter than this will be filtered out.
        max_audio_duration (tp.Optional[float], optional): Maximal audio file duration in seconds, if provided
            audio longer than this will be filtered out.
    """
    def __init__(self,
                 meta: tp.List[AudioMeta],
                 segment_duration: tp.Optional[float] = None,
                 shuffle: bool = True,
                 num_samples: int = 10_000,
                 sample_rate: int = 48_000,
                 channels: int = 2,
                 pad: bool = True,
                 sample_on_duration: bool = True,
                 sample_on_weight: bool = True,
                 min_segment_ratio: float = 0.5,
                 max_read_retry: int = 10,
                 return_info: bool = False,
                 min_audio_duration: tp.Optional[float] = None,
                 max_audio_duration: tp.Optional[float] = None
                 ):
        assert len(meta) > 0, 'No audio meta provided to AudioDataset. Please check loading of audio meta.'
        assert segment_duration is None or segment_duration > 0
        assert segment_duration is None or min_segment_ratio >= 0
        logging.debug(f'sample_on_duration: {sample_on_duration}')
        logging.debug(f'sample_on_weight: {sample_on_weight}')
        logging.debug(f'pad: {pad}')
        logging.debug(f'min_segment_ratio: {min_segment_ratio}')

        self.segment_duration = segment_duration
        self.min_segment_ratio = min_segment_ratio
        self.max_audio_duration = max_audio_duration
        self.min_audio_duration = min_audio_duration
        if self.min_audio_duration is not None and self.max_audio_duration is not None:
            assert self.min_audio_duration <= self.max_audio_duration
        # The duration bounds must be set before filtering, as `_filter_duration` reads them.
        self.meta: tp.List[AudioMeta] = self._filter_duration(meta)
        assert len(self.meta)  # Fail fast if all data has been filtered.
        self.total_duration = sum(d.duration for d in self.meta)

        if segment_duration is None:
            # Full-file mode: one item per remaining file.
            num_samples = len(self.meta)
        self.num_samples = num_samples
        self.shuffle = shuffle
        self.sample_rate = sample_rate
        self.channels = channels
        self.pad = pad
        self.sample_on_weight = sample_on_weight
        self.sample_on_duration = sample_on_duration
        # Requires `sample_on_weight` and `sample_on_duration` to be set already.
        self.sampling_probabilities = self._get_sampling_probabilities()
        self.max_read_retry = max_read_retry
        self.return_info = return_info

    def __len__(self):
        """Return `num_samples` (the number of files when `segment_duration` is None)."""
        return self.num_samples

    def _get_sampling_probabilities(self, normalized: bool = True):
        """Return the sampling probabilities for each file inside `self.meta`.

        Each file starts with a score of 1, multiplied by its `weight` when
        `sample_on_weight` is set and the weight is not None, and by its
        duration when `sample_on_duration` is set.

        Args:
            normalized (bool): Whether to divide the scores by their sum.
        Returns:
            torch.Tensor: One score per entry of `self.meta`.
        """
        scores: tp.List[float] = []
        for file_meta in self.meta:
            score = 1.
            if self.sample_on_weight and file_meta.weight is not None:
                score *= file_meta.weight
            if self.sample_on_duration:
                score *= file_meta.duration
            scores.append(score)
        probabilities = torch.tensor(scores)
        if normalized:
            probabilities /= probabilities.sum()
        return probabilities

    def sample_file(self, rng: torch.Generator) -> AudioMeta:
        """Sample a given file from `self.meta`. Can be overriden in subclasses.
        This is only called if `segment_duration` is not None.

        You must use the provided random number generator `rng` for reproducibility.
        """
        if not self.sample_on_weight and not self.sample_on_duration:
            # No weighting requested: draw a uniformly distributed index.
            file_index = int(torch.randint(len(self.sampling_probabilities), (1,), generator=rng).item())
        else:
            file_index = int(torch.multinomial(self.sampling_probabilities, 1, generator=rng).item())

        return self.meta[file_index]

    def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentInfo]]:
        """Load one item: the full file at `index`, or a randomly sampled segment.

        When `segment_duration` is None, the file `self.meta[index]` is read in full.
        Otherwise a file and a seek time are drawn from a generator seeded from `index`
        (plus extra randomness when `shuffle` is set), retrying with a new draw if
        reading fails; the last failure is re-raised.

        Returns:
            The waveform tensor, or a `(waveform, SegmentInfo)` tuple when `return_info` is set.
        """
        if self.segment_duration is None:
            file_meta = self.meta[index]
            out, sr = audio_read(file_meta.path)
            out = convert_audio(out, sr, self.sample_rate, self.channels)
            n_frames = out.shape[-1]
            segment_info = SegmentInfo(file_meta, seek_time=0., n_frames=n_frames, total_frames=n_frames,
                                       sample_rate=self.sample_rate)
        else:
            rng = torch.Generator()
            if self.shuffle:
                # We use index, plus extra randomness
                rng.manual_seed(index + self.num_samples * random.randint(0, 2**24))
            else:
                # We only use index
                rng.manual_seed(index)

            for retry in range(self.max_read_retry):
                file_meta = self.sample_file(rng)
                # We add some variance in the file position even if audio file is smaller than segment
                # without ending up with empty segments
                max_seek = max(0, file_meta.duration - self.segment_duration * self.min_segment_ratio)
                seek_time = torch.rand(1, generator=rng).item() * max_seek
                try:
                    out, sr = audio_read(file_meta.path, seek_time, self.segment_duration, pad=False)
                    out = convert_audio(out, sr, self.sample_rate, self.channels)
                    n_frames = out.shape[-1]
                    target_frames = int(self.segment_duration * self.sample_rate)
                    if self.pad:
                        out = F.pad(out, (0, target_frames - n_frames))
                    segment_info = SegmentInfo(file_meta, seek_time, n_frames=n_frames, total_frames=target_frames,
                                               sample_rate=self.sample_rate)
                except Exception as exc:
                    logger.warning("Error opening file %s: %r", file_meta.path, exc)
                    if retry == self.max_read_retry - 1:
                        raise
                else:
                    break

        if self.return_info:
            # Returns the wav and additional information on the wave segment
            return out, segment_info
        else:
            return out

    def collater(self, samples):
        """The collater function has to be provided to the dataloader
        if AudioDataset has return_info=True in order to properly collate
        the samples of a batch.

        Args:
            samples (list): Items as returned by `__getitem__`: tensors, or
                `(tensor, SegmentInfo)` tuples when `return_info` is set.
        Returns:
            The stacked waveforms, or a tuple of the stacked waveforms and the
            list of (copied) SegmentInfo when `return_info` is set.
        """
        if self.segment_duration is None and len(samples) > 1:
            assert self.pad, "Must allow padding when batching examples of different durations."

        # In this case the audio reaching the collater is of variable length as segment_duration=None.
        to_pad = self.segment_duration is None and self.pad
        if to_pad:
            # Samples are (wav, info) tuples only when return_info is set; otherwise they are
            # bare tensors, and unpacking them would iterate over the tensor's first dimension.
            max_len = max([(sample[0] if self.return_info else sample).shape[-1] for sample in samples])

            def _pad_wav(wav):
                return F.pad(wav, (0, max_len - wav.shape[-1]))

        if self.return_info:
            if len(samples) > 0:
                assert len(samples[0]) == 2
                assert isinstance(samples[0][0], torch.Tensor)
                assert isinstance(samples[0][1], SegmentInfo)

            wavs = [wav for wav, _ in samples]
            # Copy the infos so that updating `total_frames` below does not alter the caller's objects.
            segment_infos = [copy.deepcopy(info) for _, info in samples]

            if to_pad:
                # Each wav could be of a different duration as they are not segmented.
                for i in range(len(samples)):
                    # Determines the total length of the signal with padding, so we update here as we pad.
                    segment_infos[i].total_frames = max_len
                    wavs[i] = _pad_wav(wavs[i])

            wav = torch.stack(wavs)
            return wav, segment_infos
        else:
            assert isinstance(samples[0], torch.Tensor)
            if to_pad:
                samples = [_pad_wav(s) for s in samples]
            return torch.stack(samples)

    def _filter_duration(self, meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
        """Filters out audio files that are too short or too long.

        Keeps the entries whose duration lies within `min_audio_duration` and
        `max_audio_duration` (each bound is ignored when None), and logs the
        percentage removed: at debug level below 10 percent, as a warning otherwise.

        Args:
            meta (list of AudioMeta): Non-empty list of audio meta to filter.
        Returns:
            list of AudioMeta: The entries that passed both filters.
        """
        orig_len = len(meta)

        # Filter data that is too short.
        if self.min_audio_duration is not None:
            meta = [m for m in meta if m.duration >= self.min_audio_duration]

        # Filter data that is too long.
        if self.max_audio_duration is not None:
            meta = [m for m in meta if m.duration <= self.max_audio_duration]

        filtered_len = len(meta)
        removed_percentage = 100*(1-float(filtered_len)/orig_len)
        msg = 'Removed %.2f percent of the data because it was too short or too long.' % removed_percentage
        if removed_percentage < 10:
            logging.debug(msg)
        else:
            logging.warning(msg)
        return meta

    @classmethod
    def from_meta(cls, root: tp.Union[str, Path], **kwargs):
        """Instantiate AudioDataset from a path to a directory containing a manifest as a jsonl file.

        Args:
            root (str or Path): Path to a manifest file, or to a folder containing
                either `data.jsonl` or `data.jsonl.gz`.
            kwargs: Additional keyword arguments for the AudioDataset.
        Raises:
            ValueError: If `root` is a directory containing neither manifest file.
        """
        root = Path(root)
        if root.is_dir():
            if (root / 'data.jsonl').exists():
                root = root / 'data.jsonl'
            elif (root / 'data.jsonl.gz').exists():
                root = root / 'data.jsonl.gz'
            else:
                raise ValueError("Don't know where to read metadata from in the dir. "
                                 "Expecting either a data.jsonl or data.jsonl.gz file but none found.")
        meta = load_audio_meta(root)
        return cls(meta, **kwargs)

    @classmethod
    def from_path(cls, root: tp.Union[str, Path], minimal_meta: bool = True,
                  exts: tp.List[str] = DEFAULT_EXTS, **kwargs):
        """Instantiate AudioDataset from a path containing (possibly nested) audio files.

        If `root` is a file, it is loaded as a metadata manifest instead of being scanned.

        Args:
            root (str or Path): Path to root folder containing audio files.
            minimal_meta (bool): Whether to only load minimal metadata or not.
            exts (list of str): Extensions for audio files.
            kwargs: Additional keyword arguments for the AudioDataset.
        """
        root = Path(root)
        if root.is_file():
            meta = load_audio_meta(root, resolve=True)
        else:
            meta = find_audio_files(root, exts, minimal=minimal_meta, resolve=True)
        return cls(meta, **kwargs)
|
|
|
|
|
|
def main():
    """Command-line entry point: scan a folder for audio files and save their metadata.

    Parses the arguments, collects the metadata with `find_audio_files`
    (progress reporting enabled) and writes it with `save_audio_meta`.
    """
    logging.basicConfig(stream=sys.stderr, level=logging.INFO)

    cli = argparse.ArgumentParser(
        prog='audio_dataset',
        description='Generate .jsonl files by scanning a folder.')
    cli.add_argument('root', help='Root folder with all the audio files')
    cli.add_argument('output_meta_file',
                     help='Output file to store the metadata, ')
    # `--complete` switches off the default minimal metadata collection.
    cli.add_argument('--complete',
                     action='store_false', dest='minimal', default=True,
                     help='Retrieve all metadata, even the one that are expansive '
                          'to compute (e.g. normalization).')
    cli.add_argument('--resolve',
                     action='store_true', default=False,
                     help='Resolve the paths to be absolute and with no symlinks.')
    cli.add_argument('--workers',
                     default=10, type=int,
                     help='Number of workers.')
    options = cli.parse_args()

    collected = find_audio_files(options.root, DEFAULT_EXTS, progress=True,
                                 resolve=options.resolve, minimal=options.minimal,
                                 workers=options.workers)
    save_audio_meta(options.output_meta_file, collected)


if __name__ == '__main__':
    main()
|