Commit 019fa536 authored by Dan Povey
Browse files

Added binary for lattice generation (faster); extend and make consistent the usage of update flags.

git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@465 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
parent fd45504b
......@@ -912,7 +912,7 @@ class LatticeFasterDecoder {
e_tail = e->tail;
toks_.Delete(e);
}
toks_.clear();
toks_.Clear();
}
void ClearActiveTokens() { // a cleanup routine, at utt end/begin
......
......@@ -26,6 +26,7 @@ GmmFlagsType StringToGmmFlags(std::string str) {
case 'm': flags |= kGmmMeans; break;
case 'v': flags |= kGmmVariances; break;
case 'w': flags |= kGmmWeights; break;
case 't': flags |= kGmmTransitions; break;
case 'a': flags |= kGmmAll; break;
default: KALDI_ERR << "Invalid element " << CharToString(*c)
<< " of GmmFlagsType option string "
......@@ -45,6 +46,7 @@ SgmmUpdateFlagsType StringToSgmmUpdateFlags(std::string str) {
case 'S': flags |= kSgmmCovarianceMatrix; break;
case 'c': flags |= kSgmmSubstateWeights; break;
case 'N': flags |= kSgmmSpeakerProjections; break;
case 't': flags |= kSgmmTransitions; break;
case 'a': flags |= kSgmmAll; break;
default: KALDI_ERR << "Invalid element " << CharToString(*c)
<< " of SgmmUpdateFlagsType option string "
......@@ -55,6 +57,24 @@ SgmmUpdateFlagsType StringToSgmmUpdateFlags(std::string str) {
}
SgmmUpdateFlagsType StringToSgmmWriteFlags(std::string str) {
SgmmWriteFlagsType flags = 0;
for (const char *c = str.c_str(); *c != '\0'; c++) {
switch (*c) {
case 'g': flags |= kSgmmGlobalParams; break;
case 's': flags |= kSgmmStateParams; break;
case 'n': flags |= kSgmmNormalizers; break;
case 'u': flags |= kSgmmBackgroundGmms; break;
case 'a': flags |= kSgmmAll; break;
default: KALDI_ERR << "Invalid element " << CharToString(*c)
<< " of SgmmWriteFlagsType option string "
<< str;
}
}
return flags;
}
} // End namespace kaldi
......@@ -22,10 +22,11 @@
namespace kaldi {
enum GmmUpdateFlags {
kGmmMeans = 0x001, // m
kGmmVariances = 0x002, // v
kGmmWeights = 0x004, // w
kGmmAll = 0x007 // a
kGmmMeans = 0x001, // m
kGmmVariances = 0x002, // v
kGmmWeights = 0x004, // w
kGmmTransitions = 0x008, // t ... not really part of GMM.
kGmmAll = 0x00F // a
};
typedef uint16 GmmFlagsType; ///< Bitwise OR of the above flags.
/// Convert string which is some subset of "mvwta" to
......@@ -39,7 +40,8 @@ enum SgmmUpdateFlags { /// The letters correspond to the variable names.
kSgmmCovarianceMatrix = 0x008, /// S
kSgmmSubstateWeights = 0x010, /// c
kSgmmSpeakerProjections = 0x020, /// N
kSgmmAll = 0x03F /// a (won't normally use this).
kSgmmTransitions = 0x040, /// t .. not really part of SGMM.
kSgmmAll = 0x07F /// a (won't normally use this).
};
typedef uint16 SgmmUpdateFlagsType; ///< Bitwise OR of the above flags.
......@@ -55,7 +57,7 @@ enum SgmmWriteFlags {
typedef uint16 SgmmWriteFlagsType; ///< Bitwise OR of the above flags.
SgmmUpdateFlagsType StringToSgmmWriteFlags(std::string str);
SgmmWriteFlagsType StringToSgmmWriteFlags(std::string str);
} // End namespace kaldi
......
......@@ -15,7 +15,7 @@ BINFILES = gmm-init-mono gmm-est gmm-acc-stats-ali gmm-align \
gmm-est-fmllr-gpost gmm-est-fmllr gmm-est-regtree-fmllr-ali \
gmm-est-regtree-mllr gmm-decode-kaldi gmm-compute-likes \
gmm-decode-faster-regtree-mllr gmm-et-apply-c gmm-latgen-simple \
gmm-rescore-lattice gmm-decode-biglm-faster
gmm-rescore-lattice gmm-decode-biglm-faster gmm-latgen-faster
OBJFILES =
......
// gmmbin/gmm-latgen-faster.cc
// Copyright 2009-2011 Microsoft Corporation
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "gmm/am-diag-gmm.h"
#include "tree/context-dep.h"
#include "hmm/transition-model.h"
#include "fstext/fstext-lib.h"
#include "decoder/lattice-faster-decoder.h"
#include "decoder/decodable-am-diag-gmm.h"
#include "util/timer.h"
namespace kaldi {
// Emits all output for one decoded utterance and returns the total
// log-likelihood of its best path.
//
// Two stages, run in this order:
//  1. The best path is extracted from the decoder; its word sequence and
//     alignment are written to whichever of words_writer / alignment_writer
//     is open, and, when word_syms is non-NULL, the words are printed to
//     stderr as "<utt> <word> <word> ...".
//  2. The lattice is fetched (compact lattice if `determinize`, raw lattice
//     otherwise), its acoustic scores are rescaled by 1/acoustic_scale
//     (skipped when acoustic_scale is zero), and it is written under key utt.
//
// Any failure to obtain the best path or lattice, or a word-id that is missing
// from the symbol table, is reported through KALDI_ERR.
double ProcessDecodedOutput(const LatticeFasterDecoder &decoder,
                            const fst::SymbolTable *word_syms,
                            std::string utt,
                            double acoustic_scale,
                            bool determinize,
                            Int32VectorWriter *alignment_writer,
                            Int32VectorWriter *words_writer,
                            CompactLatticeWriter *compact_lattice_writer,
                            LatticeWriter *lattice_writer) {
  double best_path_like;

  {  // Stage 1: word-level traceback.  Scoped so the temporaries are released
     // before the lattice is built.
    fst::VectorFst<LatticeArc> best_path;
    const bool have_path = decoder.GetBestPath(&best_path);
    if (!have_path) {
      // Shouldn't really reach this point as already checked success.
      KALDI_ERR << "Failed to get traceback for utterance " << utt;
    }

    std::vector<int32> ali;
    std::vector<int32> word_ids;
    LatticeWeight path_weight;
    GetLinearSymbolSequence(best_path, &ali, &word_ids, &path_weight);

    if (words_writer->IsOpen()) {
      words_writer->Write(utt, word_ids);
    }
    if (alignment_writer->IsOpen()) {
      alignment_writer->Write(utt, ali);
    }

    if (word_syms != NULL) {
      // Human-readable transcript, streamed directly to stderr.
      std::cerr << utt << ' ';
      for (std::vector<int32>::const_iterator it = word_ids.begin();
           it != word_ids.end(); ++it) {
        const std::string symbol = word_syms->Find(*it);
        if (symbol.empty()) {
          KALDI_ERR << "Word-id " << *it <<" not in symbol table.";
        }
        std::cerr << symbol << ' ';
      }
      std::cerr << '\n';
    }

    // The weight holds costs in its two components; negate their sum to get
    // a log-likelihood.
    best_path_like = -(path_weight.Value1() + path_weight.Value2());
  }

  // Stage 2: lattice output.  The lattice is written without acoustic scaling,
  // so the scale applied during decoding is undone first.
  if (!determinize) {
    Lattice raw_lat;
    const bool got_lattice = decoder.GetRawLattice(&raw_lat);
    if (!got_lattice) {
      KALDI_ERR << "Unexpected problem getting lattice for utterance "
                << utt;
    }
    if (acoustic_scale != 0.0) {
      fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale),
                        &raw_lat);
    }
    lattice_writer->Write(utt, raw_lat);
  } else {
    CompactLattice compact_lat;
    const bool got_lattice = decoder.GetLattice(&compact_lat);
    if (!got_lattice) {
      KALDI_ERR << "Unexpected problem getting lattice for utterance "
                << utt;
    }
    if (acoustic_scale != 0.0) {
      fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale),
                        &compact_lat);
    }
    compact_lattice_writer->Write(utt, compact_lat);
  }

  return best_path_like;
}
}
int main(int argc, char *argv[]) {
try {
using namespace kaldi;
typedef kaldi::int32 int32;
using fst::SymbolTable;
using fst::VectorFst;
using fst::StdArc;
const char *usage =
"Generate lattices using GMM-based model.\n"
"Usage: gmm-latgen-faster [options] model-in fst-in features-rspecifier"
" lattice-wspecifier [ words-wspecifier [alignments-wspecifier] ]\n";
ParseOptions po(usage);
Timer timer;
bool allow_partial = false;
BaseFloat acoustic_scale = 0.1;
LatticeFasterDecoderConfig config;
std::string word_syms_filename;
config.Register(&po);
po.Register("acoustic-scale", &acoustic_scale, "Scaling factor for acoustic likelihoods");
po.Register("word-symbol-table", &word_syms_filename, "Symbol table for words [for debug output]");
po.Register("allow-partial", &allow_partial, "If true, produce output even if end state was not reached.");
po.Read(argc, argv);
if (po.NumArgs() < 4 || po.NumArgs() > 6) {
po.PrintUsage();
exit(1);
}
std::string model_in_filename = po.GetArg(1),
fst_in_filename = po.GetArg(2),
feature_rspecifier = po.GetArg(3),
lattice_wspecifier = po.GetArg(4),
words_wspecifier = po.GetOptArg(5),
alignment_wspecifier = po.GetOptArg(6);
TransitionModel trans_model;
AmDiagGmm am_gmm;
{
bool binary;
Input is(model_in_filename, &binary);
trans_model.Read(is.Stream(), binary);
am_gmm.Read(is.Stream(), binary);
}
VectorFst<StdArc> *decode_fst = NULL;
{
std::ifstream is(fst_in_filename.c_str(), std::ifstream::binary);
if (!is.good()) KALDI_EXIT << "Could not open decoding-graph FST "
<< fst_in_filename;
decode_fst =
VectorFst<StdArc>::Read(is, fst::FstReadOptions(fst_in_filename));
if (decode_fst == NULL) // fst code will warn.
exit(1);
}
bool determinize = config.determinize_lattice;
CompactLatticeWriter compact_lattice_writer;
LatticeWriter lattice_writer;
if (! (determinize ? compact_lattice_writer.Open(lattice_wspecifier)
: lattice_writer.Open(lattice_wspecifier)))
KALDI_EXIT << "Could not open table for writing lattices: "
<< lattice_wspecifier;
Int32VectorWriter words_writer(words_wspecifier);
Int32VectorWriter alignment_writer(alignment_wspecifier);
fst::SymbolTable *word_syms = NULL;
if (word_syms_filename != "")
if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename)))
KALDI_EXIT << "Could not read symbol table from file "
<< word_syms_filename;
SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
BaseFloat tot_like = 0.0;
kaldi::int64 frame_count = 0;
int num_success = 0, num_fail = 0;
LatticeFasterDecoder decoder(*decode_fst, config);
for (; !feature_reader.Done(); feature_reader.Next()) {
std::string utt = feature_reader.Key();
Matrix<BaseFloat> features (feature_reader.Value());
feature_reader.FreeCurrent();
if (features.NumRows() == 0) {
KALDI_WARN << "Zero-length utterance: " << utt;
num_fail++;
continue;
}
DecodableAmDiagGmmScaled gmm_decodable(am_gmm, trans_model, features,
acoustic_scale);
if (!decoder.Decode(&gmm_decodable)) {
KALDI_WARN << "Failed to decode file " << utt;
num_fail++;
continue;
}
frame_count += features.NumRows();
double like;
if (!decoder.ReachedFinal()) {
if (allow_partial) {
KALDI_WARN << "Outputting partial output for utterance " << utt
<< " since no final-state reached\n";
} else {
KALDI_WARN << "Not producing output for utterance " << utt
<< " since no final-state reached and "
<< "--allow-partial=false.\n";
num_fail++;
continue;
}
}
like = ProcessDecodedOutput(decoder, word_syms, utt, acoustic_scale,
determinize, &alignment_writer, &words_writer,
&compact_lattice_writer, &lattice_writer);
tot_like += like;
KALDI_LOG << "Log-like per frame for utterance " << utt << " is "
<< (like / features.NumRows()) << " over "
<< features.NumRows() << " frames.";
num_success++;
}
double elapsed = timer.Elapsed();
KALDI_LOG << "Time taken "<< elapsed
<< "s: real-time factor assuming 100 frames/sec is "
<< (elapsed*100.0/frame_count);
KALDI_LOG << "Done " << num_success << " utterances, failed for "
<< num_fail;
KALDI_LOG << "Overall log-likelihood per frame is " << (tot_like/frame_count) << " over "
<< frame_count<<" frames.";
delete decode_fst;
if (word_syms) delete word_syms;
if (num_success != 0) return 0;
else return 1;
} catch(const std::exception& e) {
std::cerr << e.what();
return -1;
}
}
......@@ -27,7 +27,13 @@ $(BINFILES): ../lat/kaldi-lat.a ../matrix/kaldi-matrix.a ../util/kaldi-util.a ..
$(MAKE) -C ${@D} ${@F}
clean:
rm *.o *.a $(BINFILES)
rm *.o *.a $(TESTFILES) $(BINFILES)
test: $(TESTFILES)
for x in $(TESTFILES); do ./$$x >&/dev/null || (echo "***test $$x failed***"; exit 1); done
echo Tests succeeded
.valgrind: $(TESTFILES)
depend:
-$(CXX) -M $(CXXFLAGS) *.cc > .depend.mk
......
......@@ -35,7 +35,7 @@ int main(int argc, char *argv[]) {
ParseOptions po(usage);
bool binary = false;
std::string gselect_rspecifier, spkvecs_rspecifier, utt2spk_rspecifier;
std::string update_flags_str = "vMNwcS";
std::string update_flags_str = "vMNwcSt";
BaseFloat rand_prune = 1.0e-05;
kaldi::SgmmGselectConfig sgmm_opts;
po.Register("binary", &binary, "Write output in binary mode");
......
......@@ -37,7 +37,7 @@ int main(int argc, char *argv[]) {
ParseOptions po(usage);
bool binary = false;
std::string spkvecs_rspecifier, utt2spk_rspecifier;
std::string update_flags_str = "vMNwcS";
std::string update_flags_str = "vMNwcSt";
BaseFloat rand_prune = 1.0e-05;
po.Register("binary", &binary, "Write output in binary mode");
......
......@@ -38,7 +38,7 @@ int main(int argc, char *argv[]) {
ParseOptions po(usage);
bool binary = false;
std::string gselect_rspecifier, spkvecs_rspecifier, utt2spk_rspecifier;
std::string update_flags_str = "vMNwcS";
std::string update_flags_str = "vMNwcSt";
BaseFloat rand_prune = 1.0e-05;
SgmmGselectConfig sgmm_opts;
po.Register("binary", &binary, "Write output in binary mode");
......
......@@ -34,7 +34,8 @@ int main(int argc, char *argv[]) {
"Usage: sgmm-est [options] <model-in> <stats-in> <model-out>\n";
bool binary_write = false;
std::string update_flags_str = "vMNwcS";
std::string update_flags_str = "vMNwcSt";
std::string write_flags_str = "gsnu";
kaldi::TransitionUpdateConfig tcfg;
kaldi::MleAmSgmmOptions sgmm_opts;
int32 split_substates = 0;
......@@ -65,7 +66,9 @@ int main(int argc, char *argv[]) {
po.Register("write-occs", &occs_out_filename, "File to write state "
"occupancies to.");
po.Register("update-flags", &update_flags_str, "Which SGMM parameters to "
"update: subset of vMNwcS.");
"update: subset of vMNwcSt.");
po.Register("write-flags", &write_flags_str, "Which SGMM parameters to "
"write: subset of gsnu");
tcfg.Register(&po);
sgmm_opts.Register(&po);
......@@ -78,8 +81,11 @@ int main(int argc, char *argv[]) {
stats_filename = po.GetArg(2),
model_out_filename = po.GetArg(3);
kaldi::SgmmUpdateFlagsType acc_flags = StringToSgmmUpdateFlags(update_flags_str);
kaldi::SgmmUpdateFlagsType update_flags =
StringToSgmmUpdateFlags(update_flags_str);
kaldi::SgmmWriteFlagsType write_flags =
StringToSgmmWriteFlags(write_flags_str);
AmSgmm am_sgmm;
TransitionModel trans_model;
{
......@@ -98,7 +104,7 @@ int main(int argc, char *argv[]) {
sgmm_accs.Read(is.Stream(), binary, true); // true == add; doesn't matter here.
}
{ // Update transition model.
if (update_flags & kSgmmTransitions) { // Update transition model.
BaseFloat objf_impr, count;
trans_model.Update(transition_accs, tcfg, &objf_impr, &count);
KALDI_LOG << "Transition model update: average " << (objf_impr/count)
......@@ -110,7 +116,7 @@ int main(int argc, char *argv[]) {
{ // Update SGMM.
kaldi::MleAmSgmmUpdater sgmm_updater(sgmm_opts);
sgmm_updater.Update(sgmm_accs, &am_sgmm, acc_flags);
sgmm_updater.Update(sgmm_accs, &am_sgmm, update_flags);
}
if (split_substates != 0 || !occs_out_filename.empty()) { // get state occs
......@@ -147,7 +153,7 @@ int main(int argc, char *argv[]) {
{
Output os(model_out_filename, binary_write);
trans_model.Write(os.Stream(), binary_write);
am_sgmm.Write(os.Stream(), binary_write, kSgmmWriteAll);
am_sgmm.Write(os.Stream(), binary_write, write_flags);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment