Showing 3 changed files with 212 additions and 18 deletions.
@@ -0,0 +1,154 @@
from functools import partial
import sys

import torch
from torch.nn.utils.rnn import pad_sequence
from wenet.dataset import processor
from wenet.dataset.datapipes import WenetRawDatasetSource, WenetTarShardDatasetSource
def padding(data):
    """ Pad a list of samples into a training batch.
    Args:
        data: List[{key, feat}]
    Returns:
        dict with keys, padded feats and feats lengths, sorted by
        feats length in descending order
    """
    sample = data
    assert isinstance(sample, list)
    feats_length = torch.tensor([x['feat'].size(0) for x in sample],
                                dtype=torch.int32)
    order = torch.argsort(feats_length, descending=True)
    feats_lengths = torch.tensor([sample[i]['feat'].size(0) for i in order],
                                 dtype=torch.int32)
    sorted_feats = [sample[i]['feat'] for i in order]
    sorted_keys = [sample[i]['key'] for i in order]
    padded_feats = pad_sequence(sorted_feats,
                                batch_first=True,
                                padding_value=0)
    batch = {
        "keys": sorted_keys,
        "feats": padded_feats,
        "feats_lengths": feats_lengths,
    }
    return batch
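# Usage sketch (an illustration, not part of the commit): two dummy
# utterances with 80-dim features get sorted by length and zero-padded
# to the longest one.
#
#   batch = padding([
#       {'key': 'utt1', 'feat': torch.zeros(5, 80)},
#       {'key': 'utt2', 'feat': torch.zeros(7, 80)},
#   ])
#   batch['keys']            # ['utt2', 'utt1']
#   batch['feats'].shape     # torch.Size([2, 7, 80])
#   batch['feats_lengths']   # tensor([7, 5], dtype=torch.int32)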
def Dataset(data_type, data_list_file, conf=None, partition=True):
    """ Construct a dataset from arguments for the ssl model.

    There are two shuffle stages in the Dataset. The first is a global
    shuffle at the shard tar/raw file level; the second is a buffered
    shuffle at the training-sample level.

    Args:
        data_type(str): raw/shard
        data_list_file(str): data list file
        conf(dict): dataset configuration
        partition(bool): whether to do data partition in terms of rank
    """
    assert conf is not None
    assert data_type in ['raw', 'shard']
    # cycle dataset
    cycle = conf.get('cycle', 1)
    # stage1 shuffle: source
    list_shuffle = conf.get('list_shuffle', True)

    list_shuffle_size = sys.maxsize
    if list_shuffle:
        list_shuffle_conf = conf.get('list_shuffle_conf', {})
        list_shuffle_size = list_shuffle_conf.get('shuffle_size',
                                                  list_shuffle_size)
    if data_type == 'raw':
        dataset = WenetRawDatasetSource(data_list_file,
                                        partition=partition,
                                        shuffle=list_shuffle,
                                        shuffle_size=list_shuffle_size,
                                        cycle=cycle)
        dataset = dataset.map(processor.parse_json)
    else:
        dataset = WenetTarShardDatasetSource(data_list_file,
                                             partition=partition,
                                             shuffle=list_shuffle,
                                             shuffle_size=list_shuffle_size,
                                             cycle=cycle)
    dataset = dataset.map_ignore_error(processor.decode_wav)

    # 'singal_channel' (sic) follows the spelling of the processor function
    singal_channel_conf = conf.get('singal_channel_conf', {})
    dataset = dataset.map(
        partial(processor.singal_channel, **singal_channel_conf))
    filter_conf = conf.get('filter_conf', {})
    dataset = dataset.filter(partial(processor.filter, **filter_conf))

    resample_conf = conf.get('resample_conf', {})
    dataset = dataset.map(partial(processor.resample, **resample_conf))

    speed_perturb = conf.get('speed_perturb', False)
    if speed_perturb:
        dataset = dataset.map(processor.speed_perturb)

    feats_type = conf.get('feats_type', 'fbank')
    assert feats_type in ['fbank', 'mfcc', 'log_mel_spectrogram']
    if feats_type == 'fbank':
        fbank_conf = conf.get('fbank_conf', {})
        dataset = dataset.map(partial(processor.compute_fbank, **fbank_conf))
    elif feats_type == 'mfcc':
        mfcc_conf = conf.get('mfcc_conf', {})
        dataset = dataset.map(partial(processor.compute_mfcc, **mfcc_conf))
    elif feats_type == 'log_mel_spectrogram':
        log_mel_spectrogram_conf = conf.get('log_mel_spectrogram_conf', {})
        dataset = dataset.map(
            partial(processor.compute_log_mel_spectrogram,
                    **log_mel_spectrogram_conf))
    spec_aug = conf.get('spec_aug', True)
    spec_sub = conf.get('spec_sub', False)
    spec_trim = conf.get('spec_trim', False)
    if spec_aug:
        spec_aug_conf = conf.get('spec_aug_conf', {})
        dataset = dataset.map(partial(processor.spec_aug, **spec_aug_conf))
    if spec_sub:
        spec_sub_conf = conf.get('spec_sub_conf', {})
        dataset = dataset.map(partial(processor.spec_sub, **spec_sub_conf))
    if spec_trim:
        spec_trim_conf = conf.get('spec_trim_conf', {})
        dataset = dataset.map(partial(processor.spec_trim, **spec_trim_conf))
    # stage2 shuffle: buffered shuffle at the sample level
    shuffle = conf.get('shuffle', True)
    if shuffle:
        shuffle_conf = conf.get('shuffle_conf', {})
        dataset = dataset.shuffle(buffer_size=shuffle_conf['shuffle_size'])

    # sort by feature length inside a buffer so that utterances of similar
    # length land in the same batch, which reduces padding
    sort = conf.get('sort', True)
    if sort:
        sort_conf = conf.get('sort_conf', {})
        dataset = dataset.sort(buffer_size=sort_conf['sort_size'],
                               key_func=processor.sort_by_feats)
    batch_conf = conf.get('batch_conf', {})
    batch_type = batch_conf.get('batch_type', 'static')
    assert batch_type in ['static', 'bucket', 'dynamic']
    if batch_type == 'static':
        assert 'batch_size' in batch_conf
        batch_size = batch_conf.get('batch_size', 16)
        dataset = dataset.batch(batch_size, wrapper_class=padding)
    elif batch_type == 'bucket':
        assert 'bucket_boundaries' in batch_conf
        assert 'bucket_batch_sizes' in batch_conf
        dataset = dataset.bucket_by_sequence_length(
            processor.feats_length_fn,
            batch_conf['bucket_boundaries'],
            batch_conf['bucket_batch_sizes'],
            wrapper_class=padding)
    else:
        max_frames_in_batch = batch_conf.get('max_frames_in_batch', 12000)
        dataset = dataset.dynamic_batch(
            processor.DynamicBatchWindow(max_frames_in_batch),
            wrapper_class=padding,
        )
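    # Batch-type cheat sheet (illustrative values, not from the commit):
    #   static  - fixed number of utterances per batch, e.g.
    #             batch_conf: {batch_type: 'static', batch_size: 16}
    #   bucket  - group utterances by length buckets; requires
    #             'bucket_boundaries' and 'bucket_batch_sizes'
    #   dynamic - cap the total frames per batch, e.g.
    #             batch_conf: {batch_type: 'dynamic', max_frames_in_batch: 12000}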
    return dataset


def init_dataset(data_type, data_list_file, conf=None, partition=True):
    return Dataset(data_type, data_list_file, conf, partition)
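For orientation, a minimal usage sketch of this builder (not part of the diff; the path and conf values are illustrative, only the keys come from the code above):

    conf = {
        'feats_type': 'fbank',
        'fbank_conf': {'num_mel_bins': 80},   # kaldi-style fbank options
        'shuffle': True,
        'shuffle_conf': {'shuffle_size': 1500},
        'sort': True,
        'sort_conf': {'sort_size': 500},
        'batch_conf': {'batch_type': 'static', 'batch_size': 16},
    }
    dataset = init_dataset('raw', 'data/train/data.list', conf=conf)
    for batch in dataset:
        feats = batch['feats']            # (B, T, D) padded features
        lengths = batch['feats_lengths']  # per-utterance frame counts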
@@ -0,0 +1,41 @@
import copy
from typing import Optional

from wenet.dataset.dataset import Dataset
from wenet.text.base_tokenizer import BaseTokenizer


def init_asr_dataset(data_type,
                     data_list_file,
                     tokenizer: Optional[BaseTokenizer] = None,
                     conf=None,
                     partition=True):
    return Dataset(data_type, data_list_file, tokenizer, conf, partition)
def init_dataset(dataset_type,
                 data_type,
                 data_list_file,
                 tokenizer: Optional[BaseTokenizer] = None,
                 conf=None,
                 partition=True,
                 split='train'):
    assert dataset_type in ['asr', 'ssl']

    if split != 'train':
        # for dev/test splits, disable cycling, augmentation and shuffling
        cv_conf = copy.deepcopy(conf)
        cv_conf['cycle'] = 1
        cv_conf['speed_perturb'] = False
        cv_conf['spec_aug'] = False
        cv_conf['spec_sub'] = False
        cv_conf['spec_trim'] = False
        cv_conf['shuffle'] = False
        cv_conf['list_shuffle'] = False
        conf = cv_conf

    if dataset_type == 'asr':
        return init_asr_dataset(data_type, data_list_file, tokenizer, conf,
                                partition)
    else:
        from wenet.ssl.init_dataset import init_dataset as init_ssl_dataset
        return init_ssl_dataset(data_type, data_list_file, conf, partition)
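A usage sketch of the dispatcher (illustrative, not part of the diff; the list paths are hypothetical and `conf` is the same dict accepted by the underlying builders):

    # training split: augmentation and shuffling honored as configured
    train_dataset = init_dataset('ssl', 'shard', 'data/train/shards.list',
                                 conf=conf)

    # dev split: cycle, speed perturb, spec_aug/sub/trim and both shuffle
    # stages are forced off by the split != 'train' branch above
    cv_dataset = init_dataset('ssl', 'shard', 'data/dev/shards.list',
                              conf=conf, split='dev')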