-
Notifications
You must be signed in to change notification settings - Fork 5.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0148732
commit ffa861c
Showing
8 changed files
with
711 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright 2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang) | ||
# Apache 2.0 | ||
|
||
from datetime import datetime | ||
import logging | ||
|
||
|
||
def setup_logger(log_filename, log_level='info'):
    '''Configure the root logger to log to a timestamped file and to the console.

    Args:
        log_filename: prefix of the log file. The current local time,
                      formatted as `%Y-%m-%d-%H-%M-%S`, is appended to it
                      after a dash.
        log_level: one of 'debug', 'info', 'warning', 'error' or 'critical'.
    Raises:
        ValueError: if `log_level` is not one of the supported names.
    '''
    name2level = {
        'debug': logging.DEBUG,
        'info': logging.INFO,
        'warning': logging.WARNING,
        'error': logging.ERROR,
        'critical': logging.CRITICAL,
    }

    # Validate before touching any global logging state. An unknown name
    # used to fall through the if/elif chain and crash with an
    # UnboundLocalError on `level`.
    if log_level not in name2level:
        raise ValueError('Unsupported log level: {!r}. Expected one of {}'.format(
            log_level, sorted(name2level)))
    level = name2level[log_level]

    now = datetime.now()
    date_time = now.strftime('%Y-%m-%d-%H-%M-%S')
    log_filename = '{}-{}'.format(log_filename, date_time)
    formatter = '%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s'

    logging.basicConfig(filename=log_filename,
                        format=formatter,
                        level=level,
                        filemode='w')

    # basicConfig only installs the file handler; mirror the messages to the
    # console with the same level and format.
    console = logging.StreamHandler()
    console.setLevel(level)
    console.setFormatter(logging.Formatter(formatter))
    logging.getLogger('').addHandler(console)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright 2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang) | ||
# Apache 2.0 | ||
|
||
import os | ||
import logging | ||
|
||
import torch | ||
from torch.nn.utils.rnn import pad_sequence | ||
from torch.utils.data import DataLoader | ||
from torch.utils.data import Dataset | ||
|
||
import kaldi | ||
|
||
|
||
def get_ctc_dataloader(feats_scp,
                       labels_scp=None,
                       batch_size=1,
                       shuffle=False,
                       num_workers=0):
    '''Build a DataLoader over a CtcDataset.

    Args:
        feats_scp: filename for feats.scp
        labels_scp: if provided, it is the filename of labels.scp
        batch_size: forwarded to torch's DataLoader
        shuffle: forwarded to torch's DataLoader
        num_workers: forwarded to torch's DataLoader
    Returns:
        a DataLoader whose batches are produced by CtcDatasetCollateFunc.
    '''
    return DataLoader(CtcDataset(feats_scp=feats_scp, labels_scp=labels_scp),
                      batch_size=batch_size,
                      shuffle=shuffle,
                      num_workers=num_workers,
                      collate_fn=CtcDatasetCollateFunc())
|
||
|
||
class CtcDataset(Dataset):
    '''Dataset of utterances listed in a feats.scp and, optionally, a labels.scp.

    Every line of an scp file consists of two whitespace separated fields:

        uttid rxfilename

    The dataset does not read any feature or label data by itself; each item
    is just the list [uttid, feat_rxfilename, label_rxfilename].
    '''

    def __init__(self, feats_scp, labels_scp=None):
        '''
        Args:
            feats_scp: filename for feats.scp
            labels_scp: if provided, it is the filename of labels.scp
        Raises:
            AssertionError: if a file does not exist, a line is malformed,
                            an utterance id is duplicated, a label refers to
                            an unknown utterance, or labels_scp is given but
                            some utterance has no label.
        '''
        assert os.path.isfile(feats_scp), \
            'feats scp {} does not exist'.format(feats_scp)
        if labels_scp:
            assert os.path.isfile(labels_scp), \
                'labels scp {} does not exist'.format(labels_scp)
            logging.info('labels scp: {}'.format(labels_scp))
        else:
            # logging.warn is a deprecated alias of logging.warning
            logging.warning('No labels scp is given.')

        # items maps uttid to [uttid, feat_rxfilename, None]
        # or [uttid, feat_rxfilename, label_rxfilename] if labels_scp is not None
        items = dict()

        for uttid, rxfilename in self._read_scp(feats_scp):
            assert uttid not in items, \
                'duplicate utterance {} in {}'.format(uttid, feats_scp)
            items[uttid] = [uttid, rxfilename, None]

        if labels_scp:
            for uttid, rxfilename in self._read_scp(labels_scp):
                assert uttid in items, \
                    'utterance {} in {} has no features'.format(
                        uttid, labels_scp)

                # Counting lines alone is not enough: a duplicated label
                # entry could mask an utterance whose label is missing.
                assert items[uttid][-1] is None, \
                    'duplicate utterance {} in {}'.format(uttid, labels_scp)

                items[uttid][-1] = rxfilename

            # every utterance should have a label if
            # labels_scp is given
            missing = [uttid for uttid, item in items.items()
                       if item[-1] is None]
            assert not missing, \
                '{} utterance(s) have no label in {}, e.g., {}'.format(
                    len(missing), labels_scp, missing[0])

        self.items = list(items.values())
        self.num_items = len(self.items)
        self.feats_scp = feats_scp
        self.labels_scp = labels_scp

    @staticmethod
    def _read_scp(filename):
        '''Yield (uttid, rxfilename) for every line of the given scp file.'''
        with open(filename, 'r') as f:
            for line in f:
                # every line has the following format:
                # uttid rxfilename
                uttid_rxfilename = line.split()
                assert len(uttid_rxfilename) == 2, \
                    'invalid line in {}: {!r}'.format(filename, line)

                yield tuple(uttid_rxfilename)

    def __len__(self):
        '''Return the number of utterances.'''
        return self.num_items

    def __getitem__(self, i):
        '''
        Returns:
            a list [key, feat_rxfilename, label_rxfilename]
            Note that label_rxfilename may be None.
        '''
        return self.items[i]

    def __str__(self):
        '''Return a multi-line summary: scp filenames and number of utterances.'''
        s = 'feats scp: {}\n'.format(self.feats_scp)

        if self.labels_scp:
            s += 'labels scp: {}\n'.format(self.labels_scp)

        s += 'num utterances: {}\n'.format(self.num_items)

        return s
|
||
|
||
class CtcDatasetCollateFunc:
    '''Callable that turns a list of dataset items into one padded batch.'''

    def __call__(self, batch):
        '''
        Args:
            batch: a list of [uttid, feat_rxfilename, label_rxfilename].
                   Note that label_rxfilename may be None.
        Returns:
            uttid_list: a list of utterance id
            feat: a 3-D float tensor of shape [batch_size, seq_len, feat_dim]
            feat_len_list: number of frames of each utterance before padding
            label_list: a list of labels of each utterance; It may be None.
            label_len_list: label length of each utterance; It is None if label_list is None.
        '''
        uttid_list, feat_list, feat_len_list = [], [], []
        label_list, label_len_list = [], []

        for uttid, feat_rxfilename, label_rxfilename in batch:
            uttid_list.append(uttid)

            # read the feature matrix and convert it to a float tensor
            mat = kaldi.read_mat(feat_rxfilename).numpy()
            frames = torch.from_numpy(mat).float()
            feat_list.append(frames)
            feat_len_list.append(frames.size(0))

            if not label_rxfilename:
                continue

            label = kaldi.read_vec_int(label_rxfilename)
            label_list.append(label)
            label_len_list.append(len(label))

        # pad every utterance to the length of the longest one in the batch
        feat = pad_sequence(feat_list, batch_first=True)

        if label_list:
            return uttid_list, feat, feat_len_list, label_list, label_len_list

        # no utterance in this batch has a label
        return uttid_list, feat, feat_len_list, None, None
|
||
|
||
def _test_dataset():
    '''Build a CtcDataset from the train_sp scp files and print its summary.'''
    print(CtcDataset(feats_scp='data/train_sp/feats.scp',
                     labels_scp='data/train_sp/labels.scp'))
|
||
|
||
def _test_dataloader():
    '''Iterate over the first few batches of the test set and print their sizes.'''
    loader = DataLoader(CtcDataset(feats_scp='data/test/feats.scp',
                                   labels_scp='data/test/labels.scp'),
                        batch_size=2,
                        num_workers=10,
                        shuffle=True,
                        collate_fn=CtcDatasetCollateFunc())

    for num_seen, batch in enumerate(loader, start=1):
        uttid_list, feat, feat_len_list, label_list, label_len_list = batch
        print(uttid_list, feat.shape, feat_len_list, label_len_list)
        # stop once 11 batches have been printed
        if num_seen > 10:
            break
|
||
|
||
if __name__ == '__main__':
    # Running this file directly executes the dataloader smoke test only;
    # the dataset-only check below is left disabled.
    # _test_dataset()
    _test_dataloader()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
#!/usr/bin/env python3 | ||
|
||
# Copyright 2019-2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang) | ||
# Apache 2.0 | ||
|
||
import logging | ||
|
||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from torch.nn.utils.rnn import pack_padded_sequence | ||
from torch.nn.utils.rnn import pad_packed_sequence | ||
|
||
|
||
def get_ctc_model(input_dim,
                  output_dim,
                  num_layers=4,
                  hidden_dim=512,
                  proj_dim=256):
    '''Create a CtcModel.

    Args:
        input_dim: input dimension of the network
        output_dim: output dimension of the network
        num_layers: number of LSTM layers of the network
        hidden_dim: the dimension of the hidden state of LSTM layers
        proj_dim: dimension of the affine layer after every LSTM layer
    Returns:
        a newly constructed CtcModel.
    '''
    return CtcModel(input_dim=input_dim,
                    output_dim=output_dim,
                    num_layers=num_layers,
                    hidden_dim=hidden_dim,
                    proj_dim=proj_dim)
|
||
|
||
class CtcModel(nn.Module):
    '''Stack of LSTM layers, each followed by an affine projection and tanh.

    The input is batch normalized first, and a final affine layer followed by
    log_softmax produces the output.
    '''

    def __init__(self, input_dim, output_dim, num_layers, hidden_dim, proj_dim):
        '''
        Args:
            input_dim: input dimension of the network
            output_dim: output dimension of the network
            num_layers: number of LSTM layers of the network
            hidden_dim: the dimension of the hidden state of LSTM layers
            proj_dim: dimension of the affine layer after every LSTM layer
        '''
        super().__init__()

        lstm_layer_list = []
        proj_layer_list = []

        # batchnorm requires input of shape [N, C, L] == [batch_size, dim, seq_len]
        self.input_batch_norm = nn.BatchNorm1d(num_features=input_dim,
                                               affine=False)

        for i in range(num_layers):
            # only the first LSTM layer sees the raw features; the others
            # consume the projection of the previous layer
            lstm_input_dim = input_dim if i == 0 else proj_dim

            lstm_layer = nn.LSTM(input_size=lstm_input_dim,
                                 hidden_size=hidden_dim,
                                 num_layers=1,
                                 batch_first=True)

            proj_layer = nn.Linear(in_features=hidden_dim,
                                   out_features=proj_dim)

            lstm_layer_list.append(lstm_layer)
            proj_layer_list.append(proj_layer)

        self.lstm_layer_list = nn.ModuleList(lstm_layer_list)
        self.proj_layer_list = nn.ModuleList(proj_layer_list)

        self.num_layers = num_layers

        self.prefinal_affine = nn.Linear(in_features=proj_dim,
                                         out_features=output_dim)

    def forward(self, feat, feat_len_list):
        '''
        Args:
            feat: a 3-D tensor of shape [batch_size, seq_len, feat_dim]
            feat_len_list: feat length of each utterance before padding
        Returns:
            a 3-D tensor of shape [batch_size, seq_len, output_dim]
            representing log prob, i.e., the output of log_softmax.
        '''
        x = feat

        # Remember the padded length so that the output keeps the documented
        # seq_len even if every utterance is shorter than it; without it
        # pad_packed_sequence pads only up to max(feat_len_list).
        seq_len = x.size(1)

        # at this point, x is of shape [batch_size, seq_len, feat_dim]
        x = x.permute(0, 2, 1)

        # at this point, x is of shape [batch_size, feat_dim, seq_len] == [N, C, L]
        # NOTE: batch norm is applied to the padded tensor, so padded frames
        # take part in it.
        x = self.input_batch_norm(x)

        x = x.permute(0, 2, 1)

        # at this point, x is of shape [batch_size, seq_len, feat_dim] == [N, L, C]

        for i in range(self.num_layers):
            x = pack_padded_sequence(input=x,
                                     lengths=feat_len_list,
                                     batch_first=True,
                                     enforce_sorted=False)

            # TODO(fangjun): save intermediate LSTM state to support streaming inference
            x, _ = self.lstm_layer_list[i](x)

            x, _ = pad_packed_sequence(x,
                                       batch_first=True,
                                       total_length=seq_len)

            x = self.proj_layer_list[i](x)

            x = torch.tanh(x)

        x = self.prefinal_affine(x)

        x = F.log_softmax(x, dim=-1)

        return x
|
||
|
||
def _test_ctc_model():
    '''Check the output shape of a tiny CtcModel on a padded batch of two utterances.'''
    from torch.nn.utils.rnn import pad_sequence

    input_dim, output_dim = 5, 20
    feat_len_list = [6, 8]

    model = CtcModel(input_dim=input_dim,
                     output_dim=output_dim,
                     num_layers=2,
                     hidden_dim=3,
                     proj_dim=4)

    # one random feature matrix per utterance, padded to the longest one
    utterances = [torch.randn((num_frames, input_dim))
                  for num_frames in feat_len_list]
    feat = pad_sequence(utterances, batch_first=True)
    assert feat.shape == torch.Size([2, 8, input_dim])

    x = model(feat, feat_len_list)

    assert x.shape == torch.Size([2, 8, output_dim])
|
||
|
||
if __name__ == '__main__':
    # Running this file directly executes the model shape self-test.
    _test_ctc_model()
Oops, something went wrong.