updating the HCLG boosting code, debugging
KarelVesely84 committed Jun 18, 2021
1 parent f399514 commit f184046
Showing 2 changed files with 95 additions and 81 deletions.
4 changes: 2 additions & 2 deletions egs/wsj/s5/steps/nnet3/decode_compose.sh
@@ -22,8 +22,8 @@ min_active=200
ivector_scale=1.0
lattice_beam=8.0 # Beam we use in lattice generation.
iter=final
num_threads=1 # if >1, will use gmm-latgen-faster-parallel
use_gpu=false # If true, will use a GPU, with nnet3-latgen-faster-batch.
#num_threads=1 # if >1, will use gmm-latgen-faster-parallel
#use_gpu=false # If true, will use a GPU, with nnet3-latgen-faster-batch.
# In that case it is recommended to set num-threads to a large
# number, e.g. 20 if you have that many free CPU slots on a GPU
# node, and to use a small number of jobs.
172 changes: 93 additions & 79 deletions src/nnet3bin/nnet3-latgen-faster-compose.cc
@@ -31,6 +31,7 @@
#include "base/timer.h"

#include <fst/compose.h>
#include <fst/rmepsilon.h>
#include <memory>


@@ -154,106 +155,119 @@ int main(int argc, char *argv[]) {

RandomAccessTableReader<fst::VectorFstHolder> boosting_fst_reader(boosting_fst_rspecifier);

// HCLG FST is just one FST, not a table of FSTs.
auto hclg_fst = std::unique_ptr<VectorFst<StdArc>>(fst::ReadFstKaldi(hclg_fst_rxfilename));
// 'hclg_fst' is a single FST.
VectorFst<StdArc> hclg_fst;
{
auto hclg_fst_tmp = std::unique_ptr<Fst<StdArc>>(fst::ReadFstKaldiGeneric(hclg_fst_rxfilename));
hclg_fst = VectorFst<StdArc>(*hclg_fst_tmp); // Fst -> VectorFst, as it has to be MutableFst...
// 'hclg_fst_tmp' is released automatically when it goes out of scope ...
}
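A side note on the conversion above: ReadFstKaldiGeneric returns a pointer to a generic fst::Fst, while the ArcSort below needs a MutableFst, hence the copy into a VectorFst. In plain OpenFst terms the same pattern looks roughly like this sketch (the helper name and the handling of a failed read are illustrative, not part of the binary):

#include <fst/fstlib.h>
#include <memory>
#include <string>

// Reads a graph whose concrete FST type is not known in advance and
// returns a mutable copy of it (same pattern as above, OpenFst only).
fst::VectorFst<fst::StdArc> ReadAsMutableFst(const std::string &path) {
  std::unique_ptr<fst::Fst<fst::StdArc>> generic(
      fst::Fst<fst::StdArc>::Read(path));
  if (generic == nullptr) {
    // Read() returns nullptr on failure; here we just return an empty FST,
    return fst::VectorFst<fst::StdArc>();
  }
  return fst::VectorFst<fst::StdArc>(*generic);  // VectorFst is a MutableFst,
}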

// make sure hclg is sorted on olabel
if (hclg_fst->Properties(fst::kOLabelSorted, true) == 0) {
if (hclg_fst.Properties(fst::kOLabelSorted, true) == 0) {
fst::OLabelCompare<StdArc> olabel_comp;
fst::ArcSort(hclg_fst.get(), olabel_comp);
fst::ArcSort(&hclg_fst, olabel_comp);
}
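For context on the two ArcSort calls in this file: fst::Compose matches output labels of the left FST against input labels of the right FST, so at least one side has to be sorted on its matching labels; here HCLG is olabel-sorted once up front and each boosting graph is ilabel-sorted per utterance. A minimal sketch of that check-then-sort pattern, on a hypothetical helper rather than the binary's own variables:

#include <fst/fstlib.h>

// Sorts 'left' on output labels and 'right' on input labels, but only
// when the FSTs are not already known to be sorted (hypothetical helper).
void SortForCompose(fst::VectorFst<fst::StdArc> *left,
                    fst::VectorFst<fst::StdArc> *right) {
  if (left->Properties(fst::kOLabelSorted, true) == 0)
    fst::ArcSort(left, fst::OLabelCompare<fst::StdArc>());
  if (right->Properties(fst::kILabelSorted, true) == 0)
    fst::ArcSort(right, fst::ILabelCompare<fst::StdArc>());
}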

timer.Reset();

{

for (; !feature_reader.Done(); feature_reader.Next()) {
std::string utt = feature_reader.Key();
const Matrix<BaseFloat> &features (feature_reader.Value());
if (features.NumRows() == 0) {
KALDI_WARN << "Zero-length utterance: " << utt;
//// MAIN LOOP ////
for (; !feature_reader.Done(); feature_reader.Next()) {
std::string utt = feature_reader.Key();
const Matrix<BaseFloat> &features (feature_reader.Value());
if (features.NumRows() == 0) {
KALDI_WARN << "Zero-length utterance: " << utt;
num_fail++;
continue;
}
const Matrix<BaseFloat> *online_ivectors = NULL;
const Vector<BaseFloat> *ivector = NULL;
if (!ivector_rspecifier.empty()) {
if (!ivector_reader.HasKey(utt)) {
KALDI_WARN << "No iVector available for utterance " << utt;
num_fail++;
continue;
} else {
ivector = &ivector_reader.Value(utt);
}
const Matrix<BaseFloat> *online_ivectors = NULL;
const Vector<BaseFloat> *ivector = NULL;
if (!ivector_rspecifier.empty()) {
if (!ivector_reader.HasKey(utt)) {
KALDI_WARN << "No iVector available for utterance " << utt;
num_fail++;
continue;
} else {
ivector = &ivector_reader.Value(utt);
}
}
if (!online_ivector_rspecifier.empty()) {
if (!online_ivector_reader.HasKey(utt)) {
KALDI_WARN << "No online iVector available for utterance " << utt;
num_fail++;
continue;
} else {
online_ivectors = &online_ivector_reader.Value(utt);
}
}

// get the boosting graph,
VectorFst<StdArc> boosting_fst;
if (!boosting_fst_reader.HasKey(utt)) {
KALDI_WARN << "No boosting fst for utterance " << utt;
}
if (!online_ivector_rspecifier.empty()) {
if (!online_ivector_reader.HasKey(utt)) {
KALDI_WARN << "No online iVector available for utterance " << utt;
num_fail++;
continue;
} else {
boosting_fst = boosting_fst_reader.Value(utt); // copy,
online_ivectors = &online_ivector_reader.Value(utt);
}
}

timer_compose.Reset();

// make sure boosting graph is sorted on ilabel,
if (boosting_fst.Properties(fst::kILabelSorted, true) == 0) {
fst::ILabelCompare<StdArc> ilabel_comp;
fst::ArcSort(&boosting_fst, ilabel_comp);
}
// get the boosting graph,
VectorFst<StdArc> boosting_fst;
if (!boosting_fst_reader.HasKey(utt)) {
KALDI_WARN << "No boosting fst for utterance " << utt;
num_fail++;
continue;
} else {
boosting_fst = boosting_fst_reader.Value(utt); // copy,
}

// TODO: should we call rmepsilon on boosting_fst ?
timer_compose.Reset();

// run composition (measure time),
VectorFst<StdArc> decode_fst;
fst::Compose(*hclg_fst, boosting_fst, &decode_fst);
// RmEpsilon saved 30% of composition runtime...
// - Note: we are loading 2-state graphs with eps back-link to the initial state.
if (boosting_fst.Properties(fst::kIEpsilons, true) != 0) {
fst::RmEpsilon(&boosting_fst);
}
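The comment above describes the boosting graphs as 2-state FSTs with an epsilon back-link to the initial state. Purely as an illustration (the actual graphs are built elsewhere; the word id and the negative boost cost below are made-up values), such a graph and the same pre-composition cleanup could look like this:

#include <fst/fstlib.h>

// Illustrative sketch only: a 2-state boosting graph in the shape the
// comment above describes, followed by the same epsilon removal.
fst::VectorFst<fst::StdArc> MakeToyBoostingFst() {
  typedef fst::StdArc Arc;
  fst::VectorFst<Arc> f;
  Arc::StateId s0 = f.AddState(), s1 = f.AddState();
  f.SetStart(s0);
  f.SetFinal(s0, Arc::Weight::One());   // the empty path is accepted,
  const Arc::Label kWordId = 1234;      // hypothetical word symbol id,
  const float kBoostCost = -2.0;        // negative cost = boosting (assumed value),
  f.AddArc(s0, Arc(kWordId, kWordId, kBoostCost, s1));
  f.AddArc(s1, Arc(0, 0, 0.0f, s0));    // eps back-link to the initial state,
  f.SetFinal(s1, Arc::Weight::One());
  // same cleanup as above: drop the eps arcs before composition, if any,
  if (f.Properties(fst::kIEpsilons, true) != 0) {
    fst::RmEpsilon(&f);
  }
  return f;
}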

// TODO: should we sort the 'decode_fst' by isymbols ?
// (we don't do it, as it would take time;
// not sure if decoding would be faster if
// decode_fst was sorted by isymbols)
// make sure boosting graph is sorted on ilabel,
if (boosting_fst.Properties(fst::kILabelSorted, true) == 0) {
fst::ILabelCompare<StdArc> ilabel_comp;
fst::ArcSort(&boosting_fst, ilabel_comp);
}

// Check that composed graph is non-empty,
if (decode_fst.Start() == fst::kNoStateId) {
KALDI_WARN << "Empty 'decode_fst' HCLG for utterance "
<< utt << " (bad boosting graph?)";
num_fail++;
continue;
}
// run composition,
VectorFst<StdArc> decode_fst;
fst::Compose(hclg_fst, boosting_fst, &decode_fst);

elapsed_compose += timer_compose.Elapsed();

DecodableAmNnetSimple nnet_decodable(
decodable_opts, trans_model, am_nnet,
features, ivector, online_ivectors,
online_ivector_period, &compiler);

LatticeFasterDecoder decoder(decode_fst, config);

double like;
if (DecodeUtteranceLatticeFaster(
decoder, nnet_decodable, trans_model, word_syms.get(), utt,
decodable_opts.acoustic_scale, determinize, allow_partial,
&alignment_writer, &words_writer, &compact_lattice_writer,
&lattice_writer,
&like)) {
tot_like += like;
frame_count += nnet_decodable.NumFramesReady();
num_success++;
} else num_fail++;
// check that composed graph is non-empty,
if (decode_fst.Start() == fst::kNoStateId) {
KALDI_WARN << "Empty 'decode_fst' HCLG for utterance "
<< utt << " (bad boosting graph?)";
num_fail++;
continue;
}

elapsed_compose += timer_compose.Elapsed();

DecodableAmNnetSimple nnet_decodable(
decodable_opts, trans_model, am_nnet,
features, ivector, online_ivectors,
online_ivector_period, &compiler);

// Note: decode_fst is a VectorFst, not a ConstFst.
//
// OpenFst docs say that more specific iterators
// are faster than generic iterators, and HCLG
// is usually loaded for decoding as a ConstFst.
//
// auto decode_fst_ = ConstFst<StdArc>(decode_fst);
//
// I tried casting the VectorFst to a ConstFst this way,
// but it made the decoding 20% slower.
//
LatticeFasterDecoder decoder(decode_fst, config);
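Spelled out, the experiment the note above refers to would have replaced the line just above with something like the following (left commented, as it was measured ~20% slower here):

// ConstFst<StdArc> const_decode_fst(decode_fst);  // copies into ConstFst storage,
// LatticeFasterDecoder decoder(const_decode_fst, config);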

double like;
if (DecodeUtteranceLatticeFaster(
decoder, nnet_decodable, trans_model, word_syms.get(), utt,
decodable_opts.acoustic_scale, determinize, allow_partial,
&alignment_writer, &words_writer, &compact_lattice_writer,
&lattice_writer,
&like)) {
tot_like += like;
frame_count += nnet_decodable.NumFramesReady();
num_success++;
} else num_fail++;
}
}
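Since the binary reads the boosting graphs through a RandomAccessTableReader<fst::VectorFstHolder>, the per-utterance graphs have to be available as a Kaldi table keyed by utterance id. A hedged sketch of how such a table could be written from C++ (the archive path, helper name, and input container are placeholders, not part of this commit):

#include "util/kaldi-table.h"
#include "fstext/kaldi-fst-io.h"
#include <utility>
#include <vector>

// Writes one boosting graph per utterance into an archive that can then be
// passed to this binary via the boosting-fst rspecifier.
void WriteBoostingFsts(
    const std::vector<std::pair<std::string, fst::VectorFst<fst::StdArc>>> &fsts) {
  kaldi::TableWriter<fst::VectorFstHolder> writer("ark:boosting_graphs.ark");  // placeholder path,
  for (const auto &kv : fsts) {
    writer.Write(kv.first, kv.second);  // key = utterance id,
  }
}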

