Skip to content

Commit

Permalink
add training script.
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Feb 21, 2020
1 parent 0148732 commit ffa861c
Show file tree
Hide file tree
Showing 8 changed files with 711 additions and 2 deletions.
28 changes: 28 additions & 0 deletions egs/aishell/s10b/ctc/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#!/usr/bin/env python3

# Copyright 2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
# Apache 2.0

from datetime import datetime
import logging


def setup_logger(log_filename, log_level='info'):
now = datetime.now()
date_time = now.strftime('%Y-%m-%d-%H-%M-%S')
log_filename = '{}-{}'.format(log_filename, date_time)
formatter = '%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s'
if log_level == 'debug':
level = logging.DEBUG
elif log_level == 'info':
level = logging.INFO
elif log_level == 'warning':
level = logging.WARNING
logging.basicConfig(filename=log_filename,
format=formatter,
level=level,
filemode='w')
console = logging.StreamHandler()
console.setLevel(level)
console.setFormatter(logging.Formatter(formatter))
logging.getLogger('').addHandler(console)
200 changes: 200 additions & 0 deletions egs/aishell/s10b/ctc/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
#!/usr/bin/env python3

# Copyright 2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
# Apache 2.0

import os
import logging

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import kaldi


def get_ctc_dataloader(feats_scp,
labels_scp=None,
batch_size=1,
shuffle=False,
num_workers=0):

dataset = CtcDataset(feats_scp=feats_scp, labels_scp=labels_scp)

collate_fn = CtcDatasetCollateFunc()

dataloader = DataLoader(dataset,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
collate_fn=collate_fn)

return dataloader


class CtcDataset(Dataset):

def __init__(self, feats_scp, labels_scp=None):
'''
Args:
feats_scp: filename for feats.scp
labels_scp: if provided, it is the filename of labels.scp
'''
assert os.path.isfile(feats_scp)
if labels_scp:
assert os.path.isfile(labels_scp)
logging.info('labels scp: {}'.format(labels_scp))
else:
logging.warn('No labels scp is given.')

# items is a dict of [uttid, feat_rxfilename, None]
# or [uttid, feat_rxfilename, label_rxfilename] if labels_scp is not None
items = dict()

with open(feats_scp, 'r') as f:
for line in f:
# every line has the following format:
# uttid feat_rxfilename
uttid_rxfilename = line.split()
assert len(uttid_rxfilename) == 2

uttid, rxfilename = uttid_rxfilename

assert uttid not in items

items[uttid] = [uttid, rxfilename, None]

if labels_scp:
expected_count = len(items)
n = 0
with open(labels_scp, 'r') as f:
for line in f:
# every line has the following format:
# uttid rxfilename
uttid_rxfilename = line.split()

assert len(uttid_rxfilename) == 2

uttid, rxfilename = uttid_rxfilename

assert uttid in items

items[uttid][-1] = rxfilename

n += 1

# every utterance should have a label if
# labels_scp is given
assert n == expected_count

self.items = list(items.values())
self.num_items = len(self.items)
self.feats_scp = feats_scp
self.labels_scp = labels_scp

def __len__(self):
return self.num_items

def __getitem__(self, i):
'''
Returns:
a list [key, feat_rxfilename, label_rxfilename]
Note that label_rxfilename may be None.
'''
return self.items[i]

def __str__(self):
s = 'feats scp: {}\n'.format(self.feats_scp)

if self.labels_scp:
s += 'labels scp: {}\n'.format(self.labels_scp)

s += 'num utterances: {}\n'.format(self.num_items)

return s


class CtcDatasetCollateFunc:

def __call__(self, batch):
'''
Args:
batch: a list of [uttid, feat_rxfilename, label_rxfilename].
Note that label_rxfilename may be None.
Returns:
uttid_list: a list of utterance id
feat: a 3-D float tensor of shape [batch_size, seq_len, feat_dim]
feat_len_list: number of frames of each utterance before padding
label_list: a list of labels of each utterance; It may be None.
label_len_list: label length of each utterance; It is None if label_list is None.
'''
uttid_list = [] # utterance id of each utterance
feat_len_list = [] # number of frames of each utterance
label_list = [] # label of each utterance
label_len_list = [] # label length of each utterance

feat_list = []

for b in batch:
uttid, feat_rxfilename, label_rxfilename = b

uttid_list.append(uttid)

feat = kaldi.read_mat(feat_rxfilename).numpy()
feat = torch.from_numpy(feat).float()
feat_list.append(feat)

feat_len_list.append(feat.size(0))

if label_rxfilename:
label = kaldi.read_vec_int(label_rxfilename)
label_list.append(label)
label_len_list.append(len(label))

feat = pad_sequence(feat_list, batch_first=True)

if not label_list:
label_list = None
label_len_list = None

return uttid_list, feat, feat_len_list, label_list, label_len_list


def _test_dataset():
feats_scp = 'data/train_sp/feats.scp'
labels_scp = 'data/train_sp/labels.scp'

dataset = CtcDataset(feats_scp=feats_scp, labels_scp=labels_scp)

print(dataset)


def _test_dataloader():
feats_scp = 'data/test/feats.scp'
labels_scp = 'data/test/labels.scp'

dataset = CtcDataset(feats_scp=feats_scp, labels_scp=labels_scp)

dataloader = DataLoader(dataset,
batch_size=2,
num_workers=10,
shuffle=True,
collate_fn=CtcDatasetCollateFunc())
i = 0
for batch in dataloader:
uttid_list, feat, feat_len_list, label_list, label_len_list = batch
print(uttid_list, feat.shape, feat_len_list, label_len_list)
i += 1
if i > 10:
break


if __name__ == '__main__':
# _test_dataset()
_test_dataloader()
145 changes: 145 additions & 0 deletions egs/aishell/s10b/ctc/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
#!/usr/bin/env python3

# Copyright 2019-2020 Mobvoi AI Lab, Beijing, China (author: Fangjun Kuang)
# Apache 2.0

import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence


def get_ctc_model(input_dim,
output_dim,
num_layers=4,
hidden_dim=512,
proj_dim=256):
model = CtcModel(input_dim=input_dim,
output_dim=output_dim,
num_layers=num_layers,
hidden_dim=hidden_dim,
proj_dim=proj_dim)

return model


class CtcModel(nn.Module):

def __init__(self, input_dim, output_dim, num_layers, hidden_dim, proj_dim):
'''
Args:
input_dim: input dimension of the network
output_dim: output dimension of the network
num_layers: number of LSTM layers of the network
hidden_dim: the dimension of the hidden state of LSTM layers
proj_dim: dimension of the affine layer after every LSTM layer
'''
super().__init__()

lstm_layer_list = []
proj_layer_list = []

# batchnorm requires input of shape [N, C, L] == [batch_size, dim, seq_len]
self.input_batch_norm = nn.BatchNorm1d(num_features=input_dim,
affine=False)

for i in range(num_layers):
if i == 0:
lstm_input_dim = input_dim
else:
lstm_input_dim = proj_dim

lstm_layer = nn.LSTM(input_size=lstm_input_dim,
hidden_size=hidden_dim,
num_layers=1,
batch_first=True)

proj_layer = nn.Linear(in_features=hidden_dim,
out_features=proj_dim)

lstm_layer_list.append(lstm_layer)
proj_layer_list.append(proj_layer)

self.lstm_layer_list = nn.ModuleList(lstm_layer_list)
self.proj_layer_list = nn.ModuleList(proj_layer_list)

self.num_layers = num_layers

self.prefinal_affine = nn.Linear(in_features=proj_dim,
out_features=output_dim)

def forward(self, feat, feat_len_list):
'''
Args:
feat: a 3-D tensor of shape [batch_size, seq_len, feat_dim]
feat_len_list: feat length of each utterance before padding
Returns:
a 3-D tensor of shape [batch_size, seq_len, output_dim]
representing log prob, i.e., the output of log_softmax.
'''
x = feat

# at his point, x is of shape [batch_size, seq_len, feat_dim]
x = x.permute(0, 2, 1)

# at his point, x is of shape [batch_size, feat_dim, seq_len] == [N, C, L]
x = self.input_batch_norm(x)

x = x.permute(0, 2, 1)

# at his point, x is of shape [batch_size, seq_len, feat_dim] == [N, L, C]

for i in range(self.num_layers):
x = pack_padded_sequence(input=x,
lengths=feat_len_list,
batch_first=True,
enforce_sorted=False)

# TODO(fangjun): save intermediate LSTM state to support streaming inference
x, _ = self.lstm_layer_list[i](x)

x, _ = pad_packed_sequence(x, batch_first=True)

x = self.proj_layer_list[i](x)

x = torch.tanh(x)

x = self.prefinal_affine(x)

x = F.log_softmax(x, dim=-1)

return x


def _test_ctc_model():
input_dim = 5
output_dim = 20
model = CtcModel(input_dim=input_dim,
output_dim=output_dim,
num_layers=2,
hidden_dim=3,
proj_dim=4)

feat1 = torch.randn((6, input_dim))
feat2 = torch.randn((8, input_dim))

from torch.nn.utils.rnn import pad_sequence
feat = pad_sequence([feat1, feat2], batch_first=True)
assert feat.shape == torch.Size([2, 8, input_dim])

feat_len_list = [6, 8]
x = model(feat, feat_len_list)

assert x.shape == torch.Size([2, 8, output_dim])


if __name__ == '__main__':
_test_ctc_model()
Loading

0 comments on commit ffa861c

Please sign in to comment.