Commit 638fd707 authored by Guoguo Chen's avatar Guoguo Chen
Browse files

trunk: adding optional 3rd argument to lattice-to-ctm-conf so that it can also...

trunk: adding optional 3rd argument to lattice-to-ctm-conf so that it can also read the word sequence from input

git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@5215 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
parent e406aeec
// lat/sausages.cc // lat/sausages.cc
// Copyright 2012 Johns Hopkins University (Author: Daniel Povey) // Copyright 2012 Johns Hopkins University (Author: Daniel Povey)
// 2015 Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors // See ../../COPYING for clarification regarding multiple authors
// //
...@@ -264,21 +265,20 @@ void MinimumBayesRisk::AccStats() { ...@@ -264,21 +265,20 @@ void MinimumBayesRisk::AccStats() {
} }
} }
MinimumBayesRisk::MinimumBayesRisk(const CompactLattice &clat_in, bool do_mbr): void MinimumBayesRisk::PrepareLatticeAndInitStats(CompactLattice *clat) {
do_mbr_(do_mbr) { KALDI_ASSERT(clat != NULL);
CompactLattice clat(clat_in); // copy.
CreateSuperFinal(&clat); // Add super-final state to clat... this is CreateSuperFinal(clat); // Add super-final state to clat... this is
// one of the requirements of the MBR algorithm, as mentioned in the // one of the requirements of the MBR algorithm, as mentioned in the
// paper (i.e. just one final state). // paper (i.e. just one final state).
// Topologically sort the lattice, if not already sorted. // Topologically sort the lattice, if not already sorted.
kaldi::uint64 props = clat.Properties(fst::kFstProperties, false); kaldi::uint64 props = clat->Properties(fst::kFstProperties, false);
if (!(props & fst::kTopSorted)) { if (!(props & fst::kTopSorted)) {
if (fst::TopSort(&clat) == false) if (fst::TopSort(clat) == false)
KALDI_ERR << "Cycles detected in lattice."; KALDI_ERR << "Cycles detected in lattice.";
} }
CompactLatticeStateTimes(clat, &state_times_); // work out times of CompactLatticeStateTimes(*clat, &state_times_); // work out times of
// the states in clat // the states in clat
state_times_.push_back(0); // we'll convert to 1-based numbering. state_times_.push_back(0); // we'll convert to 1-based numbering.
for (size_t i = state_times_.size()-1; i > 0; i--) for (size_t i = state_times_.size()-1; i > 0; i--)
...@@ -289,13 +289,13 @@ MinimumBayesRisk::MinimumBayesRisk(const CompactLattice &clat_in, bool do_mbr): ...@@ -289,13 +289,13 @@ MinimumBayesRisk::MinimumBayesRisk(const CompactLattice &clat_in, bool do_mbr):
// arcs preceding any given state. // arcs preceding any given state.
// Note: in our internal format the states will be numbered from 1, // Note: in our internal format the states will be numbered from 1,
// which involves adding 1 to the OpenFst states. // which involves adding 1 to the OpenFst states.
int32 N = clat.NumStates(); int32 N = clat->NumStates();
pre_.resize(N+1); pre_.resize(N+1);
// Careful: "Arc" is a class-member struct, not an OpenFst type of arc as one // Careful: "Arc" is a class-member struct, not an OpenFst type of arc as one
// would normally assume. // would normally assume.
for (int32 n = 1; n <= N; n++) { for (int32 n = 1; n <= N; n++) {
for (fst::ArcIterator<CompactLattice> aiter(clat, n-1); for (fst::ArcIterator<CompactLattice> aiter(*clat, n-1);
!aiter.Done(); !aiter.Done();
aiter.Next()) { aiter.Next()) {
const CompactLatticeArc &carc = aiter.Value(); const CompactLatticeArc &carc = aiter.Value();
...@@ -312,6 +312,13 @@ MinimumBayesRisk::MinimumBayesRisk(const CompactLattice &clat_in, bool do_mbr): ...@@ -312,6 +312,13 @@ MinimumBayesRisk::MinimumBayesRisk(const CompactLattice &clat_in, bool do_mbr):
arcs_.push_back(arc); arcs_.push_back(arc);
} }
} }
}
MinimumBayesRisk::MinimumBayesRisk(const CompactLattice &clat_in, bool do_mbr):
do_mbr_(do_mbr) {
CompactLattice clat(clat_in); // copy.
PrepareLatticeAndInitStats(&clat);
// We don't need to look at clat.Start() or clat.Final(state): // We don't need to look at clat.Start() or clat.Final(state):
// we know clat.Start() == 0 since it's topologically sorted, // we know clat.Start() == 0 since it's topologically sorted,
...@@ -341,5 +348,17 @@ MinimumBayesRisk::MinimumBayesRisk(const CompactLattice &clat_in, bool do_mbr): ...@@ -341,5 +348,17 @@ MinimumBayesRisk::MinimumBayesRisk(const CompactLattice &clat_in, bool do_mbr):
} }
// Constructs the MBR object from a lattice plus an externally supplied word
// sequence: <words> is stored as the hypothesis R_ (rather than being derived
// from the lattice), the accumulated risk L_ is reset, and MbrDecode() is run.
MinimumBayesRisk::MinimumBayesRisk(const CompactLattice &clat_in,
                                   const std::vector<int32> &words,
                                   bool do_mbr)
    : do_mbr_(do_mbr) {
  // PrepareLatticeAndInitStats() modifies the lattice it is given, so hand it
  // a private copy and leave the caller's lattice untouched.
  CompactLattice working_clat(clat_in);
  PrepareLatticeAndInitStats(&working_clat);

  // Start from the caller's word sequence with zero risk.
  R_ = words;
  L_ = 0.0;

  MbrDecode();
}
} // namespace kaldi } // namespace kaldi
// lat/sausages.h // lat/sausages.h
// Copyright 2012 Johns Hopkins University (Author: Daniel Povey) // Copyright 2012 Johns Hopkins University (Author: Daniel Povey)
// 2015 Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors // See ../../COPYING for clarification regarding multiple authors
// //
...@@ -64,7 +65,11 @@ class MinimumBayesRisk { ...@@ -64,7 +65,11 @@ class MinimumBayesRisk {
MinimumBayesRisk(const CompactLattice &clat, bool do_mbr = true); // if do_mbr == false, MinimumBayesRisk(const CompactLattice &clat, bool do_mbr = true); // if do_mbr == false,
// it will just use the MAP recognition output, but will get the MBR stats for things // it will just use the MAP recognition output, but will get the MBR stats for things
// like confidences. // like confidences.
// Alternative constructor: uses the provided <words> to initialise <R_> (the
// sequence returned by GetOneBest()) instead of using the lattice best path,
// and then runs MbrDecode().  The lattice <clat> is copied internally, so the
// caller's lattice is not modified.  Note that do_mbr defaults to false here,
// whereas it defaults to true in the two-argument constructor.
MinimumBayesRisk(const CompactLattice &clat,
const std::vector<int32> &words, bool do_mbr = false);
const std::vector<int32> &GetOneBest() const { // gets one-best (with no epsilons) const std::vector<int32> &GetOneBest() const { // gets one-best (with no epsilons)
return R_; return R_;
} }
...@@ -96,6 +101,8 @@ class MinimumBayesRisk { ...@@ -96,6 +101,8 @@ class MinimumBayesRisk {
} }
private: private:
/// Helper called by the constructors.  Modifies *clat in place: adds a
/// super-final state and topologically sorts the lattice if it is not already
/// sorted (raising an error if cycles are detected).  It then uses the lattice
/// to fill in state_times_, pre_ and arcs_.  clat must not be NULL.
void PrepareLatticeAndInitStats(CompactLattice *clat);
/// Minimum-Bayes-Risk Decode. Top-level algorithm. Figure 6 of the paper. /// Minimum-Bayes-Risk Decode. Top-level algorithm. Figure 6 of the paper.
void MbrDecode(); void MbrDecode();
......
// latbin/lattice-to-ctm-conf.cc // latbin/lattice-to-ctm-conf.cc
// Copyright 2012-2014 Johns Hopkins University (Author: Daniel Povey) // Copyright 2012-2014 Johns Hopkins University (Author: Daniel Povey)
// 2015 Guoguo Chen
// See ../../COPYING for clarification regarding multiple authors // See ../../COPYING for clarification regarding multiple authors
// //
...@@ -28,20 +29,31 @@ int main(int argc, char *argv[]) { ...@@ -28,20 +29,31 @@ int main(int argc, char *argv[]) {
typedef kaldi::int32 int32; typedef kaldi::int32 int32;
const char *usage = const char *usage =
"Generate 1-best from lattices and convert into ctm with confidences.\n" "This tool turns a lattice into a ctm with confidences, based on the\n"
"If --decode-mbr=true, does Minimum Bayes Risk decoding, else normal\n" "posterior probabilities in the lattice. The word sequence in the\n"
"Maximum A Posteriori (but works out the confidences based on posteriors\n" "ctm is determined as follows. Firstly we determine the initial word\n"
"in the lattice, using the MBR code). Note: if you don't need confidences,\n" "sequence. In the 3-argument form, we read it from the\n"
"you can do lattice-1best and pipe to nbest-to-ctm. \n" "<1best-rspecifier> input; otherwise it is the 1-best of the lattice.\n"
"Note: the ctm this produces will be relative to the utterance-id.\n" "Then, if --decode-mbr=true, we iteratively refine the hypothesis\n"
"Note: the times will only be correct if you do lattice-align-words\n" "using Minimum Bayes Risk decoding. If you don't need confidences,\n"
"on the input\n" "you can do lattice-1best and pipe to nbest-to-ctm. The ctm this\n"
"program produces will be relative to the utterance-id; a standard\n"
"ctm relative to the filename can be obtained using\n"
"utils/convert_ctm.pl. The times produced by this program will only\n"
"be meaningful if you do lattice-align-words on the input. The\n"
"<1-best-rspecifier> could be the output of utils/int2sym.pl or\n"
"nbest-to-linear.\n"
"\n" "\n"
"Usage: lattice-to-ctm-conf [options] <lattice-rspecifier> <ctm-wxfilename>\n" "Usage: lattice-to-ctm-conf [options] <lattice-rspecifier> \\\n"
" <ctm-wxfilename>\n"
"Usage: lattice-to-ctm-conf [options] <lattice-rspecifier> \\\n"
" [<1best-rspecifier>] <ctm-wxfilename>\n"
" e.g.: lattice-to-ctm-conf --acoustic-scale=0.1 ark:1.lats 1.ctm\n" " e.g.: lattice-to-ctm-conf --acoustic-scale=0.1 ark:1.lats 1.ctm\n"
"See also: lattice-mbr-decode, and the scripts steps/get_ctm.sh and\n" " or: lattice-to-ctm-conf --acoustic-scale=0.1 --decode-mbr=false\\\n"
" steps/get_train_ctm.sh\n"; " ark:1.lats ark:1.1best 1.ctm\n"
"See also: lattice-mbr-decode, nbest-to-ctm, steps/get_ctm.sh,\n"
" steps/get_train_ctm.sh and utils/convert_ctm.sh.\n";
ParseOptions po(usage); ParseOptions po(usage);
BaseFloat acoustic_scale = 1.0, inv_acoustic_scale = 1.0, lm_scale = 1.0; BaseFloat acoustic_scale = 1.0, inv_acoustic_scale = 1.0, lm_scale = 1.0;
bool decode_mbr = true; bool decode_mbr = true;
...@@ -56,11 +68,11 @@ int main(int argc, char *argv[]) { ...@@ -56,11 +68,11 @@ int main(int argc, char *argv[]) {
"probabilities"); "probabilities");
po.Register("decode-mbr", &decode_mbr, "If true, do Minimum Bayes Risk " po.Register("decode-mbr", &decode_mbr, "If true, do Minimum Bayes Risk "
"decoding (else, Maximum a Posteriori)"); "decoding (else, Maximum a Posteriori)");
po.Register("frame-shift", &frame_shift, "Time in seconds between frames.\n"); po.Register("frame-shift", &frame_shift, "Time in seconds between frames.");
po.Read(argc, argv); po.Read(argc, argv);
if (po.NumArgs() != 2) { if (po.NumArgs() != 2 && po.NumArgs() != 3) {
po.PrintUsage(); po.PrintUsage();
exit(1); exit(1);
} }
...@@ -69,8 +81,17 @@ int main(int argc, char *argv[]) { ...@@ -69,8 +81,17 @@ int main(int argc, char *argv[]) {
if (inv_acoustic_scale != 1.0) if (inv_acoustic_scale != 1.0)
acoustic_scale = 1.0 / inv_acoustic_scale; acoustic_scale = 1.0 / inv_acoustic_scale;
std::string lats_rspecifier = po.GetArg(1), std::string lats_rspecifier, one_best_rspecifier, ctm_wxfilename;
ctm_wxfilename = po.GetArg(2);
if (po.NumArgs() == 2) {
lats_rspecifier = po.GetArg(1);
one_best_rspecifier = "";
ctm_wxfilename = po.GetArg(2);
} else if (po.NumArgs() == 3) {
lats_rspecifier = po.GetArg(1);
one_best_rspecifier = po.GetArg(2);
ctm_wxfilename = po.GetArg(3);
}
// Ensure the output ctm file is not a wspecifier // Ensure the output ctm file is not a wspecifier
WspecifierType ctm_wx_type; WspecifierType ctm_wx_type;
...@@ -83,6 +104,8 @@ int main(int argc, char *argv[]) { ...@@ -83,6 +104,8 @@ int main(int argc, char *argv[]) {
// Read as compact lattice. // Read as compact lattice.
SequentialCompactLatticeReader clat_reader(lats_rspecifier); SequentialCompactLatticeReader clat_reader(lats_rspecifier);
RandomAccessInt32VectorReader one_best_reader(one_best_rspecifier);
Output ko(ctm_wxfilename, false); // false == non-binary writing mode. Output ko(ctm_wxfilename, false); // false == non-binary writing mode.
ko.Stream() << std::fixed; // Set to "fixed" floating point model, where precision() specifies ko.Stream() << std::fixed; // Set to "fixed" floating point model, where precision() specifies
...@@ -98,12 +121,23 @@ int main(int argc, char *argv[]) { ...@@ -98,12 +121,23 @@ int main(int argc, char *argv[]) {
clat_reader.FreeCurrent(); clat_reader.FreeCurrent();
fst::ScaleLattice(fst::LatticeScale(lm_scale, acoustic_scale), &clat); fst::ScaleLattice(fst::LatticeScale(lm_scale, acoustic_scale), &clat);
MinimumBayesRisk mbr(clat, decode_mbr); MinimumBayesRisk *mbr = NULL;
if (one_best_rspecifier == "") {
mbr = new MinimumBayesRisk(clat, decode_mbr);
} else {
if (!one_best_reader.HasKey(key)) {
KALDI_WARN << "No 1-best present for utterance " << key;
continue;
}
const std::vector<int32> &one_best = one_best_reader.Value(key);
mbr = new MinimumBayesRisk(clat, one_best, decode_mbr);
}
const std::vector<BaseFloat> &conf = mbr.GetOneBestConfidences(); const std::vector<BaseFloat> &conf = mbr->GetOneBestConfidences();
const std::vector<int32> &words = mbr.GetOneBest(); const std::vector<int32> &words = mbr->GetOneBest();
const std::vector<std::pair<BaseFloat, BaseFloat> > &times = const std::vector<std::pair<BaseFloat, BaseFloat> > &times =
mbr.GetOneBestTimes(); mbr->GetOneBestTimes();
KALDI_ASSERT(conf.size() == words.size() && words.size() == times.size()); KALDI_ASSERT(conf.size() == words.size() && words.size() == times.size());
for (size_t i = 0; i < words.size(); i++) { for (size_t i = 0; i < words.size(); i++) {
KALDI_ASSERT(words[i] != 0); // Should not have epsilons. KALDI_ASSERT(words[i] != 0); // Should not have epsilons.
...@@ -111,12 +145,14 @@ int main(int argc, char *argv[]) { ...@@ -111,12 +145,14 @@ int main(int argc, char *argv[]) {
<< (frame_shift * (times[i].second-times[i].first)) << ' ' << (frame_shift * (times[i].second-times[i].first)) << ' '
<< words[i] << ' ' << conf[i] << '\n'; << words[i] << ' ' << conf[i] << '\n';
} }
KALDI_LOG << "For utterance " << key << ", Bayes Risk " << mbr.GetBayesRisk() KALDI_LOG << "For utterance " << key << ", Bayes Risk "
<< ", avg. confidence per-word " << mbr->GetBayesRisk() << ", avg. confidence per-word "
<< std::accumulate(conf.begin(),conf.end(),0.0) / words.size(); << std::accumulate(conf.begin(),conf.end(),0.0) / words.size();
n_done++; n_done++;
n_words += mbr.GetOneBest().size(); n_words += mbr->GetOneBest().size();
tot_bayes_risk += mbr.GetBayesRisk(); tot_bayes_risk += mbr->GetBayesRisk();
if (mbr != NULL)
delete mbr;
} }
KALDI_LOG << "Done " << n_done << " lattices."; KALDI_LOG << "Done " << n_done << " lattices.";
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment