[ssl] init dataset (#2596)
Mddct authored Aug 7, 2024
1 parent f346e81 commit fcf26a4
Showing 3 changed files with 212 additions and 18 deletions.
154 changes: 154 additions & 0 deletions wenet/ssl/init_dataset.py
@@ -0,0 +1,154 @@
from functools import partial
import sys

import torch
from torch.nn.utils.rnn import pad_sequence
from wenet.dataset import processor
from wenet.dataset.datapipes import WenetRawDatasetSource, WenetTarShardDatasetSource


def padding(data):
    """ Pad a list of samples into a training batch.

    Args:
        data: List[{key, feat}]

    Returns:
        Dict of {"keys", "feats" (zero-padded), "feats_lengths"}
    """
    sample = data
    assert isinstance(sample, list)
    feats_length = torch.tensor([x['feat'].size(0) for x in sample],
                                dtype=torch.int32)
    order = torch.argsort(feats_length, descending=True)
    feats_lengths = torch.tensor([sample[i]['feat'].size(0) for i in order],
                                 dtype=torch.int32)
    sorted_feats = [sample[i]['feat'] for i in order]
    sorted_keys = [sample[i]['key'] for i in order]
    padded_feats = pad_sequence(sorted_feats,
                                batch_first=True,
                                padding_value=0)
    batch = {
        "keys": sorted_keys,
        "feats": padded_feats,
        "feats_lengths": feats_lengths,
    }
    return batch


def Dataset(data_type, data_list_file, conf=None, partition=True):
    """ Construct a dataset from arguments for an SSL model.

    There are two shuffle stages in the Dataset. The first is a global
    shuffle at the shard tar / raw file level; the second is a shuffle
    at the training sample level.

    Args:
        data_type(str): raw/shard
        partition(bool): whether to do data partition in terms of rank
    """
    assert conf is not None
    assert data_type in ['raw', 'shard']
    # cycle dataset
    cycle = conf.get('cycle', 1)
    # stage1 shuffle: source
    list_shuffle = conf.get('list_shuffle', True)

    list_shuffle_size = sys.maxsize
    if list_shuffle:
        list_shuffle_conf = conf.get('list_shuffle_conf', {})
        list_shuffle_size = list_shuffle_conf.get('shuffle_size',
                                                  list_shuffle_size)
    if data_type == 'raw':
        dataset = WenetRawDatasetSource(data_list_file,
                                        partition=partition,
                                        shuffle=list_shuffle,
                                        shuffle_size=list_shuffle_size,
                                        cycle=cycle)
        dataset = dataset.map(processor.parse_json)
    else:
        dataset = WenetTarShardDatasetSource(data_list_file,
                                             partition=partition,
                                             shuffle=list_shuffle,
                                             shuffle_size=list_shuffle_size,
                                             cycle=cycle)
    dataset = dataset.map_ignore_error(processor.decode_wav)

    singal_channel_conf = conf.get('singal_channel_conf', {})
    dataset = dataset.map(
        partial(processor.singal_channel, **singal_channel_conf))

    filter_conf = conf.get('filter_conf', {})
    dataset = dataset.filter(partial(processor.filter, **filter_conf))

    resample_conf = conf.get('resample_conf', {})
    dataset = dataset.map(partial(processor.resample, **resample_conf))

    speed_perturb = conf.get('speed_perturb', False)
    if speed_perturb:
        dataset = dataset.map(partial(processor.speed_perturb))

    feats_type = conf.get('feats_type', 'fbank')
    assert feats_type in ['fbank', 'mfcc', 'log_mel_spectrogram']
    if feats_type == 'fbank':
        fbank_conf = conf.get('fbank_conf', {})
        dataset = dataset.map(partial(processor.compute_fbank, **fbank_conf))
    elif feats_type == 'mfcc':
        mfcc_conf = conf.get('mfcc_conf', {})
        dataset = dataset.map(partial(processor.compute_mfcc, **mfcc_conf))
    elif feats_type == 'log_mel_spectrogram':
        log_mel_spectrogram_conf = conf.get('log_mel_spectrogram_conf', {})
        dataset = dataset.map(
            partial(processor.compute_log_mel_spectrogram,
                    **log_mel_spectrogram_conf))
    spec_aug = conf.get('spec_aug', True)
    spec_sub = conf.get('spec_sub', False)
    spec_trim = conf.get('spec_trim', False)
    if spec_aug:
        spec_aug_conf = conf.get('spec_aug_conf', {})
        dataset = dataset.map(partial(processor.spec_aug, **spec_aug_conf))
    if spec_sub:
        spec_sub_conf = conf.get('spec_sub_conf', {})
        dataset = dataset.map(partial(processor.spec_sub, **spec_sub_conf))
    if spec_trim:
        spec_trim_conf = conf.get('spec_trim_conf', {})
        dataset = dataset.map(partial(processor.spec_trim, **spec_trim_conf))

    shuffle = conf.get('shuffle', True)
    if shuffle:
        shuffle_conf = conf.get('shuffle_conf', {})
        dataset = dataset.shuffle(buffer_size=shuffle_conf['shuffle_size'])

    sort = conf.get('sort', True)
    if sort:
        sort_conf = conf.get('sort_conf', {})
        dataset = dataset.sort(buffer_size=sort_conf['sort_size'],
                               key_func=processor.sort_by_feats)

    batch_conf = conf.get('batch_conf', {})
    batch_type = batch_conf.get('batch_type', 'static')
    assert batch_type in ['static', 'bucket', 'dynamic']
    if batch_type == 'static':
        assert 'batch_size' in batch_conf
        batch_size = batch_conf.get('batch_size', 16)
        dataset = dataset.batch(batch_size, wrapper_class=padding)
    elif batch_type == 'bucket':
        assert 'bucket_boundaries' in batch_conf
        assert 'bucket_batch_sizes' in batch_conf
        dataset = dataset.bucket_by_sequence_length(
            processor.feats_length_fn,
            batch_conf['bucket_boundaries'],
            batch_conf['bucket_batch_sizes'],
            wrapper_class=padding)
    else:
        max_frames_in_batch = batch_conf.get('max_frames_in_batch', 12000)
        dataset = dataset.dynamic_batch(
            processor.DynamicBatchWindow(max_frames_in_batch),
            wrapper_class=padding,
        )

    return dataset


def init_dataset(data_type, data_list_file, conf=None, partition=True):
    return Dataset(data_type, data_list_file, conf, partition)
41 changes: 41 additions & 0 deletions wenet/utils/init_dataset.py
@@ -0,0 +1,41 @@
import copy
from typing import Optional
from wenet.dataset.dataset import Dataset

from wenet.text.base_tokenizer import BaseTokenizer


def init_asr_dataset(data_type,
                     data_list_file,
                     tokenizer: Optional[BaseTokenizer] = None,
                     conf=None,
                     partition=True):
    return Dataset(data_type, data_list_file, tokenizer, conf, partition)


def init_dataset(dataset_type,
                 data_type,
                 data_list_file,
                 tokenizer: Optional[BaseTokenizer] = None,
                 conf=None,
                 partition=True,
                 split='train'):
    assert dataset_type in ['asr', 'ssl']

    if split != 'train':
        cv_conf = copy.deepcopy(conf)
        cv_conf['cycle'] = 1
        cv_conf['speed_perturb'] = False
        cv_conf['spec_aug'] = False
        cv_conf['spec_sub'] = False
        cv_conf['spec_trim'] = False
        cv_conf['shuffle'] = False
        cv_conf['list_shuffle'] = False
        conf = cv_conf

    if dataset_type == 'asr':
        return init_asr_dataset(data_type, data_list_file, tokenizer, conf,
                                partition)
    else:
        from wenet.ssl.init_dataset import init_dataset as init_ssl_dataset
        return init_ssl_dataset(data_type, data_list_file, conf, partition)
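
A minimal sketch of the shared entry point (not part of the commit): dataset_type selects the ASR or SSL pipeline, and split != 'train' strips augmentation and shuffling from a deep copy of the conf. The paths and conf values are illustrative assumptions; an ASR call would additionally pass a tokenizer.

from wenet.utils.init_dataset import init_dataset

# Hypothetical conf, reusing the keys shown for wenet/ssl/init_dataset.py above.
dataset_conf = {
    'fbank_conf': {'num_mel_bins': 80},
    'shuffle_conf': {'shuffle_size': 1500},
    'sort_conf': {'sort_size': 500},
    'batch_conf': {'batch_type': 'static', 'batch_size': 16},
}

# SSL pretraining set: no tokenizer needed, spec_aug/shuffle defaults stay on.
ssl_train = init_dataset('ssl', 'raw', 'data/train/data.list',
                         conf=dataset_conf, split='train')

# SSL cv set: cycle=1; speed_perturb, spec_aug/sub/trim and both shuffle
# stages are forced off before the pipeline is built.
ssl_cv = init_dataset('ssl', 'raw', 'data/dev/data.list',
                      conf=dataset_conf, partition=False, split='cv')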
35 changes: 17 additions & 18 deletions wenet/utils/train_utils.py
@@ -39,7 +39,6 @@
    estimate_zero3_model_states_mem_needs_all_live)
from deepspeed.utils.zero_to_fp32 import (
    convert_zero_checkpoint_to_fp32_state_dict)
from wenet.dataset.dataset import Dataset
from wenet.utils.checkpoint import save_checkpoint
from wenet.utils.common import (StepTimer, get_nested_attribute, lrs_to_str,
                                tensor_to_scalar)
@@ -49,6 +48,7 @@
from wenet.utils.scheduler import WarmupLR, NoamHoldAnnealing
from wenet.utils.ctc_utils import get_blank_id
from wenet.utils.common import TORCH_NPU_AVAILABLE
from wenet.utils.init_dataset import init_dataset


def add_model_args(parser):
@@ -361,24 +361,23 @@ def init_dataset_and_dataloader(args, configs, tokenizer, seed=777):
    # if save_interval in configs, steps mode else epoch mode
    if "save_interval" in configs:
        configs['dataset_conf']['cycle'] = configs.get('max_epoch', 100)
    train_conf = configs['dataset_conf']
    cv_conf = copy.deepcopy(train_conf)
    cv_conf['cycle'] = 1
    cv_conf['speed_perturb'] = False
    cv_conf['spec_aug'] = False
    cv_conf['spec_sub'] = False
    cv_conf['spec_trim'] = False
    cv_conf['shuffle'] = False
    cv_conf['list_shuffle'] = False

    conf = configs['dataset_conf']
    dataset_type = configs.get('dataset', 'asr')
    configs['vocab_size'] = tokenizer.vocab_size()
    train_dataset = Dataset(args.data_type, args.train_data, tokenizer,
                            train_conf, True)
    cv_dataset = Dataset(args.data_type,
                         args.cv_data,
                         tokenizer,
                         cv_conf,
                         partition=False)
    train_dataset = init_dataset(dataset_type,
                                 args.data_type,
                                 args.train_data,
                                 tokenizer,
                                 conf,
                                 True,
                                 split='train')
    cv_dataset = init_dataset(dataset_type,
                              args.data_type,
                              args.cv_data,
                              tokenizer,
                              conf,
                              partition=False,
                              split='cv')

    # NOTE(xcsong): Why we prefer persistent_workers=True ?
    # https://discuss.pytorch.org/t/what-are-the-dis-advantages-of-persistent-workers/102110
