Commit aa73da51 authored by Dan Povey's avatar Dan Povey
Browse files

Merging changes from trunk; extending usage of flags to tied programs.

git-svn-id: https://svn.code.sf.net/p/kaldi/code/sandbox/discrim@467 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
parent 48c72043
...@@ -912,7 +912,7 @@ class LatticeFasterDecoder { ...@@ -912,7 +912,7 @@ class LatticeFasterDecoder {
e_tail = e->tail; e_tail = e->tail;
toks_.Delete(e); toks_.Delete(e);
} }
toks_.clear(); toks_.Clear();
} }
void ClearActiveTokens() { // a cleanup routine, at utt end/begin void ClearActiveTokens() { // a cleanup routine, at utt end/begin
......
...@@ -26,6 +26,7 @@ GmmFlagsType StringToGmmFlags(std::string str) { ...@@ -26,6 +26,7 @@ GmmFlagsType StringToGmmFlags(std::string str) {
case 'm': flags |= kGmmMeans; break; case 'm': flags |= kGmmMeans; break;
case 'v': flags |= kGmmVariances; break; case 'v': flags |= kGmmVariances; break;
case 'w': flags |= kGmmWeights; break; case 'w': flags |= kGmmWeights; break;
case 't': flags |= kGmmTransitions; break;
case 'a': flags |= kGmmAll; break; case 'a': flags |= kGmmAll; break;
default: KALDI_ERR << "Invalid element " << CharToString(*c) default: KALDI_ERR << "Invalid element " << CharToString(*c)
<< " of GmmFlagsType option string " << " of GmmFlagsType option string "
...@@ -53,6 +54,7 @@ SgmmUpdateFlagsType StringToSgmmUpdateFlags(std::string str) { ...@@ -53,6 +54,7 @@ SgmmUpdateFlagsType StringToSgmmUpdateFlags(std::string str) {
case 'S': flags |= kSgmmCovarianceMatrix; break; case 'S': flags |= kSgmmCovarianceMatrix; break;
case 'c': flags |= kSgmmSubstateWeights; break; case 'c': flags |= kSgmmSubstateWeights; break;
case 'N': flags |= kSgmmSpeakerProjections; break; case 'N': flags |= kSgmmSpeakerProjections; break;
case 't': flags |= kSgmmTransitions; break;
case 'a': flags |= kSgmmAll; break; case 'a': flags |= kSgmmAll; break;
default: KALDI_ERR << "Invalid element " << CharToString(*c) default: KALDI_ERR << "Invalid element " << CharToString(*c)
<< " of SgmmUpdateFlagsType option string " << " of SgmmUpdateFlagsType option string "
...@@ -63,6 +65,24 @@ SgmmUpdateFlagsType StringToSgmmUpdateFlags(std::string str) { ...@@ -63,6 +65,24 @@ SgmmUpdateFlagsType StringToSgmmUpdateFlags(std::string str) {
} }
SgmmUpdateFlagsType StringToSgmmWriteFlags(std::string str) {
SgmmWriteFlagsType flags = 0;
for (const char *c = str.c_str(); *c != '\0'; c++) {
switch (*c) {
case 'g': flags |= kSgmmGlobalParams; break;
case 's': flags |= kSgmmStateParams; break;
case 'n': flags |= kSgmmNormalizers; break;
case 'u': flags |= kSgmmBackgroundGmms; break;
case 'a': flags |= kSgmmAll; break;
default: KALDI_ERR << "Invalid element " << CharToString(*c)
<< " of SgmmWriteFlagsType option string "
<< str;
}
}
return flags;
}
} // End namespace kaldi } // End namespace kaldi
...@@ -22,10 +22,11 @@ ...@@ -22,10 +22,11 @@
namespace kaldi { namespace kaldi {
enum GmmUpdateFlags { enum GmmUpdateFlags {
kGmmMeans = 0x001, // m kGmmMeans = 0x001, // m
kGmmVariances = 0x002, // v kGmmVariances = 0x002, // v
kGmmWeights = 0x004, // w kGmmWeights = 0x004, // w
kGmmAll = 0x007 // a kGmmTransitions = 0x008, // t ... not really part of GMM.
kGmmAll = 0x00F // a
}; };
typedef uint16 GmmFlagsType; ///< Bitwise OR of the above flags. typedef uint16 GmmFlagsType; ///< Bitwise OR of the above flags.
/// Convert string which is some subset of "mSwa" to /// Convert string which is some subset of "mSwa" to
...@@ -43,7 +44,8 @@ enum SgmmUpdateFlags { /// The letters correspond to the variable names. ...@@ -43,7 +44,8 @@ enum SgmmUpdateFlags { /// The letters correspond to the variable names.
kSgmmCovarianceMatrix = 0x008, /// S kSgmmCovarianceMatrix = 0x008, /// S
kSgmmSubstateWeights = 0x010, /// c kSgmmSubstateWeights = 0x010, /// c
kSgmmSpeakerProjections = 0x020, /// N kSgmmSpeakerProjections = 0x020, /// N
kSgmmAll = 0x03F /// a (won't normally use this). kSgmmTransitions = 0x040, /// t .. not really part of SGMM.
kSgmmAll = 0x07F /// a (won't normally use this).
}; };
typedef uint16 SgmmUpdateFlagsType; ///< Bitwise OR of the above flags. typedef uint16 SgmmUpdateFlagsType; ///< Bitwise OR of the above flags.
...@@ -59,7 +61,7 @@ enum SgmmWriteFlags { ...@@ -59,7 +61,7 @@ enum SgmmWriteFlags {
typedef uint16 SgmmWriteFlagsType; ///< Bitwise OR of the above flags. typedef uint16 SgmmWriteFlagsType; ///< Bitwise OR of the above flags.
SgmmUpdateFlagsType StringToSgmmWriteFlags(std::string str); SgmmWriteFlagsType StringToSgmmWriteFlags(std::string str);
} // End namespace kaldi } // End namespace kaldi
......
...@@ -15,9 +15,9 @@ BINFILES = gmm-init-mono gmm-est gmm-acc-stats-ali gmm-align \ ...@@ -15,9 +15,9 @@ BINFILES = gmm-init-mono gmm-est gmm-acc-stats-ali gmm-align \
gmm-est-fmllr-gpost gmm-est-fmllr gmm-est-regtree-fmllr-ali \ gmm-est-fmllr-gpost gmm-est-fmllr gmm-est-regtree-fmllr-ali \
gmm-est-regtree-mllr gmm-decode-kaldi gmm-compute-likes \ gmm-est-regtree-mllr gmm-decode-kaldi gmm-compute-likes \
gmm-decode-faster-regtree-mllr gmm-et-apply-c gmm-latgen-simple \ gmm-decode-faster-regtree-mllr gmm-et-apply-c gmm-latgen-simple \
gmm-rescore-lattice gmm-decode-biglm-faster fmpe-gmm-model-diffs-est \ gmm-rescore-lattice gmm-decode-biglm-faster fmpe-gmm-model-diffs-est \
fmpe-gmm-acc-stats-gpost fmpe-gmm-sum-accs fmpe-init-gmms fmpe-gmm-est \ fmpe-gmm-acc-stats-gpost fmpe-gmm-sum-accs fmpe-init-gmms fmpe-gmm-est \
gmm-est-mmi gmm-est-mmi gmm-latgen-faster
OBJFILES = OBJFILES =
......
...@@ -39,6 +39,7 @@ int main(int argc, char *argv[]) { ...@@ -39,6 +39,7 @@ int main(int argc, char *argv[]) {
int32 mixdown = 0; int32 mixdown = 0;
BaseFloat perturb_factor = 0.01; BaseFloat perturb_factor = 0.01;
BaseFloat power = 0.2; BaseFloat power = 0.2;
std::string update_flags_str = "mvwt";
std::string occs_out_filename; std::string occs_out_filename;
...@@ -49,11 +50,13 @@ int main(int argc, char *argv[]) { ...@@ -49,11 +50,13 @@ int main(int argc, char *argv[]) {
po.Register("mix-down", &mixdown, "If nonzero, merge mixture components to this " po.Register("mix-down", &mixdown, "If nonzero, merge mixture components to this "
"target."); "target.");
po.Register("power", &power, "If mixing up, power to allocate Gaussians to" po.Register("power", &power, "If mixing up, power to allocate Gaussians to"
" states."); " states.");
po.Register("update-flags", &update_flags_str, "Which GMM parameters to "
"update: subset of mvwt.");
po.Register("perturb-factor", &perturb_factor, "While mixing up, perturb " po.Register("perturb-factor", &perturb_factor, "While mixing up, perturb "
"means by standard deviation times this factor."); "means by standard deviation times this factor.");
po.Register("write-occs", &occs_out_filename, "File to write state " po.Register("write-occs", &occs_out_filename, "File to write state "
"occupancies to."); "occupancies to.");
tcfg.Register(&po); tcfg.Register(&po);
gmm_opts.Register(&po); gmm_opts.Register(&po);
...@@ -64,6 +67,8 @@ int main(int argc, char *argv[]) { ...@@ -64,6 +67,8 @@ int main(int argc, char *argv[]) {
exit(1); exit(1);
} }
kaldi::GmmFlagsType update_flags =
StringToGmmFlags(update_flags_str);
std::string model_in_filename = po.GetArg(1), std::string model_in_filename = po.GetArg(1),
stats_filename = po.GetArg(2), stats_filename = po.GetArg(2),
...@@ -87,7 +92,7 @@ int main(int argc, char *argv[]) { ...@@ -87,7 +92,7 @@ int main(int argc, char *argv[]) {
gmm_accs.Read(is.Stream(), binary, true); // true == add; doesn't matter here. gmm_accs.Read(is.Stream(), binary, true); // true == add; doesn't matter here.
} }
{ // Update transition model. if (update_flags & kGmmTransitions) { // Update transition model.
BaseFloat objf_impr, count; BaseFloat objf_impr, count;
trans_model.Update(transition_accs, tcfg, &objf_impr, &count); trans_model.Update(transition_accs, tcfg, &objf_impr, &count);
KALDI_LOG << "Transition model update: average " << (objf_impr/count) KALDI_LOG << "Transition model update: average " << (objf_impr/count)
...@@ -97,7 +102,7 @@ int main(int argc, char *argv[]) { ...@@ -97,7 +102,7 @@ int main(int argc, char *argv[]) {
{ // Update GMMs. { // Update GMMs.
BaseFloat objf_impr, count; BaseFloat objf_impr, count;
MleAmDiagGmmUpdate(gmm_opts, gmm_accs, kGmmAll, &am_gmm, &objf_impr, &count); MleAmDiagGmmUpdate(gmm_opts, gmm_accs, update_flags, &am_gmm, &objf_impr, &count);
KALDI_LOG << "GMM update: average " << (objf_impr/count) KALDI_LOG << "GMM update: average " << (objf_impr/count)
<< " objective function improvement per frame over " << " objective function improvement per frame over "
<< (count) << " frames."; << (count) << " frames.";
......
// gmmbin/gmm-latgen-faster.cc
// Copyright 2009-2011 Microsoft Corporation
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "gmm/am-diag-gmm.h"
#include "tree/context-dep.h"
#include "hmm/transition-model.h"
#include "fstext/fstext-lib.h"
#include "decoder/lattice-faster-decoder.h"
#include "decoder/decodable-am-diag-gmm.h"
#include "util/timer.h"
namespace kaldi {
// Takes care of output. Returns total like.
double ProcessDecodedOutput(const LatticeFasterDecoder &decoder,
const fst::SymbolTable *word_syms,
std::string utt,
double acoustic_scale,
bool determinize,
Int32VectorWriter *alignment_writer,
Int32VectorWriter *words_writer,
CompactLatticeWriter *compact_lattice_writer,
LatticeWriter *lattice_writer) {
using fst::VectorFst;
double likelihood;
{ // First do some stuff with word-level traceback...
VectorFst<LatticeArc> decoded;
if (!decoder.GetBestPath(&decoded))
// Shouldn't really reach this point as already checked success.
KALDI_ERR << "Failed to get traceback for utterance " << utt;
std::vector<int32> alignment;
std::vector<int32> words;
LatticeWeight weight;
GetLinearSymbolSequence(decoded, &alignment, &words, &weight);
if (words_writer->IsOpen())
words_writer->Write(utt, words);
if (alignment_writer->IsOpen())
alignment_writer->Write(utt, alignment);
if (word_syms != NULL) {
std::cerr << utt << ' ';
for (size_t i = 0; i < words.size(); i++) {
std::string s = word_syms->Find(words[i]);
if (s == "")
KALDI_ERR << "Word-id " << words[i] <<" not in symbol table.";
std::cerr << s << ' ';
}
std::cerr << '\n';
}
likelihood = -(weight.Value1() + weight.Value2());
}
if (determinize) {
CompactLattice fst;
if (!decoder.GetLattice(&fst))
KALDI_ERR << "Unexpected problem getting lattice for utterance "
<< utt;
if (acoustic_scale != 0.0) // We'll write the lattice without acoustic scaling
fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &fst);
compact_lattice_writer->Write(utt, fst);
} else {
Lattice fst;
if (!decoder.GetRawLattice(&fst))
KALDI_ERR << "Unexpected problem getting lattice for utterance "
<< utt;
if (acoustic_scale != 0.0) // We'll write the lattice without acoustic scaling
fst::ScaleLattice(fst::AcousticLatticeScale(1.0 / acoustic_scale), &fst);
lattice_writer->Write(utt, fst);
}
return likelihood;
}
}
int main(int argc, char *argv[]) {
try {
using namespace kaldi;
typedef kaldi::int32 int32;
using fst::SymbolTable;
using fst::VectorFst;
using fst::StdArc;
const char *usage =
"Generate lattices using GMM-based model.\n"
"Usage: gmm-latgen-faster [options] model-in fst-in features-rspecifier"
" lattice-wspecifier [ words-wspecifier [alignments-wspecifier] ]\n";
ParseOptions po(usage);
Timer timer;
bool allow_partial = false;
BaseFloat acoustic_scale = 0.1;
LatticeFasterDecoderConfig config;
std::string word_syms_filename;
config.Register(&po);
po.Register("acoustic-scale", &acoustic_scale, "Scaling factor for acoustic likelihoods");
po.Register("word-symbol-table", &word_syms_filename, "Symbol table for words [for debug output]");
po.Register("allow-partial", &allow_partial, "If true, produce output even if end state was not reached.");
po.Read(argc, argv);
if (po.NumArgs() < 4 || po.NumArgs() > 6) {
po.PrintUsage();
exit(1);
}
std::string model_in_filename = po.GetArg(1),
fst_in_filename = po.GetArg(2),
feature_rspecifier = po.GetArg(3),
lattice_wspecifier = po.GetArg(4),
words_wspecifier = po.GetOptArg(5),
alignment_wspecifier = po.GetOptArg(6);
TransitionModel trans_model;
AmDiagGmm am_gmm;
{
bool binary;
Input is(model_in_filename, &binary);
trans_model.Read(is.Stream(), binary);
am_gmm.Read(is.Stream(), binary);
}
VectorFst<StdArc> *decode_fst = NULL;
{
std::ifstream is(fst_in_filename.c_str(), std::ifstream::binary);
if (!is.good()) KALDI_EXIT << "Could not open decoding-graph FST "
<< fst_in_filename;
decode_fst =
VectorFst<StdArc>::Read(is, fst::FstReadOptions(fst_in_filename));
if (decode_fst == NULL) // fst code will warn.
exit(1);
}
bool determinize = config.determinize_lattice;
CompactLatticeWriter compact_lattice_writer;
LatticeWriter lattice_writer;
if (! (determinize ? compact_lattice_writer.Open(lattice_wspecifier)
: lattice_writer.Open(lattice_wspecifier)))
KALDI_EXIT << "Could not open table for writing lattices: "
<< lattice_wspecifier;
Int32VectorWriter words_writer(words_wspecifier);
Int32VectorWriter alignment_writer(alignment_wspecifier);
fst::SymbolTable *word_syms = NULL;
if (word_syms_filename != "")
if (!(word_syms = fst::SymbolTable::ReadText(word_syms_filename)))
KALDI_EXIT << "Could not read symbol table from file "
<< word_syms_filename;
SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
BaseFloat tot_like = 0.0;
kaldi::int64 frame_count = 0;
int num_success = 0, num_fail = 0;
LatticeFasterDecoder decoder(*decode_fst, config);
for (; !feature_reader.Done(); feature_reader.Next()) {
std::string utt = feature_reader.Key();
Matrix<BaseFloat> features (feature_reader.Value());
feature_reader.FreeCurrent();
if (features.NumRows() == 0) {
KALDI_WARN << "Zero-length utterance: " << utt;
num_fail++;
continue;
}
DecodableAmDiagGmmScaled gmm_decodable(am_gmm, trans_model, features,
acoustic_scale);
if (!decoder.Decode(&gmm_decodable)) {
KALDI_WARN << "Failed to decode file " << utt;
num_fail++;
continue;
}
frame_count += features.NumRows();
double like;
if (!decoder.ReachedFinal()) {
if (allow_partial) {
KALDI_WARN << "Outputting partial output for utterance " << utt
<< " since no final-state reached\n";
} else {
KALDI_WARN << "Not producing output for utterance " << utt
<< " since no final-state reached and "
<< "--allow-partial=false.\n";
num_fail++;
continue;
}
}
like = ProcessDecodedOutput(decoder, word_syms, utt, acoustic_scale,
determinize, &alignment_writer, &words_writer,
&compact_lattice_writer, &lattice_writer);
tot_like += like;
KALDI_LOG << "Log-like per frame for utterance " << utt << " is "
<< (like / features.NumRows()) << " over "
<< features.NumRows() << " frames.";
num_success++;
}
double elapsed = timer.Elapsed();
KALDI_LOG << "Time taken "<< elapsed
<< "s: real-time factor assuming 100 frames/sec is "
<< (elapsed*100.0/frame_count);
KALDI_LOG << "Done " << num_success << " utterances, failed for "
<< num_fail;
KALDI_LOG << "Overall log-likelihood per frame is " << (tot_like/frame_count) << " over "
<< frame_count<<" frames.";
delete decode_fst;
if (word_syms) delete word_syms;
if (num_success != 0) return 0;
else return 1;
} catch(const std::exception& e) {
std::cerr << e.what();
return -1;
}
}
...@@ -35,7 +35,7 @@ int main(int argc, char *argv[]) { ...@@ -35,7 +35,7 @@ int main(int argc, char *argv[]) {
ParseOptions po(usage); ParseOptions po(usage);
bool binary = false; bool binary = false;
std::string gselect_rspecifier, spkvecs_rspecifier, utt2spk_rspecifier; std::string gselect_rspecifier, spkvecs_rspecifier, utt2spk_rspecifier;
std::string update_flags_str = "vMNwcS"; std::string update_flags_str = "vMNwcSt";
BaseFloat rand_prune = 1.0e-05; BaseFloat rand_prune = 1.0e-05;
kaldi::SgmmGselectConfig sgmm_opts; kaldi::SgmmGselectConfig sgmm_opts;
po.Register("binary", &binary, "Write output in binary mode"); po.Register("binary", &binary, "Write output in binary mode");
......
...@@ -37,7 +37,7 @@ int main(int argc, char *argv[]) { ...@@ -37,7 +37,7 @@ int main(int argc, char *argv[]) {
ParseOptions po(usage); ParseOptions po(usage);
bool binary = false; bool binary = false;
std::string spkvecs_rspecifier, utt2spk_rspecifier; std::string spkvecs_rspecifier, utt2spk_rspecifier;
std::string update_flags_str = "vMNwcS"; std::string update_flags_str = "vMNwcSt";
BaseFloat rand_prune = 1.0e-05; BaseFloat rand_prune = 1.0e-05;
po.Register("binary", &binary, "Write output in binary mode"); po.Register("binary", &binary, "Write output in binary mode");
......
...@@ -38,7 +38,7 @@ int main(int argc, char *argv[]) { ...@@ -38,7 +38,7 @@ int main(int argc, char *argv[]) {
ParseOptions po(usage); ParseOptions po(usage);
bool binary = false; bool binary = false;
std::string gselect_rspecifier, spkvecs_rspecifier, utt2spk_rspecifier; std::string gselect_rspecifier, spkvecs_rspecifier, utt2spk_rspecifier;
std::string update_flags_str = "vMNwcS"; std::string update_flags_str = "vMNwcSt";
BaseFloat rand_prune = 1.0e-05; BaseFloat rand_prune = 1.0e-05;
SgmmGselectConfig sgmm_opts; SgmmGselectConfig sgmm_opts;
po.Register("binary", &binary, "Write output in binary mode"); po.Register("binary", &binary, "Write output in binary mode");
......
...@@ -34,7 +34,8 @@ int main(int argc, char *argv[]) { ...@@ -34,7 +34,8 @@ int main(int argc, char *argv[]) {
"Usage: sgmm-est [options] <model-in> <stats-in> <model-out>\n"; "Usage: sgmm-est [options] <model-in> <stats-in> <model-out>\n";
bool binary_write = false; bool binary_write = false;
std::string update_flags_str = "vMNwcS"; std::string update_flags_str = "vMNwcSt";
std::string write_flags_str = "gsnu";
kaldi::TransitionUpdateConfig tcfg; kaldi::TransitionUpdateConfig tcfg;
kaldi::MleAmSgmmOptions sgmm_opts; kaldi::MleAmSgmmOptions sgmm_opts;
int32 split_substates = 0; int32 split_substates = 0;
...@@ -65,7 +66,9 @@ int main(int argc, char *argv[]) { ...@@ -65,7 +66,9 @@ int main(int argc, char *argv[]) {
po.Register("write-occs", &occs_out_filename, "File to write state " po.Register("write-occs", &occs_out_filename, "File to write state "
"occupancies to."); "occupancies to.");
po.Register("update-flags", &update_flags_str, "Which SGMM parameters to " po.Register("update-flags", &update_flags_str, "Which SGMM parameters to "
"update: subset of vMNwcS."); "update: subset of vMNwcSt.");
po.Register("write-flags", &write_flags_str, "Which SGMM parameters to "
"write: subset of gsnu");
tcfg.Register(&po); tcfg.Register(&po);
sgmm_opts.Register(&po); sgmm_opts.Register(&po);
...@@ -78,8 +81,11 @@ int main(int argc, char *argv[]) { ...@@ -78,8 +81,11 @@ int main(int argc, char *argv[]) {
stats_filename = po.GetArg(2), stats_filename = po.GetArg(2),
model_out_filename = po.GetArg(3); model_out_filename = po.GetArg(3);
kaldi::SgmmUpdateFlagsType acc_flags = StringToSgmmUpdateFlags(update_flags_str); kaldi::SgmmUpdateFlagsType update_flags =
StringToSgmmUpdateFlags(update_flags_str);
kaldi::SgmmWriteFlagsType write_flags =
StringToSgmmWriteFlags(write_flags_str);
AmSgmm am_sgmm; AmSgmm am_sgmm;
TransitionModel trans_model; TransitionModel trans_model;
{ {
...@@ -98,7 +104,7 @@ int main(int argc, char *argv[]) { ...@@ -98,7 +104,7 @@ int main(int argc, char *argv[]) {
sgmm_accs.Read(is.Stream(), binary, true); // true == add; doesn't matter here. sgmm_accs.Read(is.Stream(), binary, true); // true == add; doesn't matter here.
} }
{ // Update transition model. if (update_flags & kSgmmTransitions) { // Update transition model.
BaseFloat objf_impr, count; BaseFloat objf_impr, count;
trans_model.Update(transition_accs, tcfg, &objf_impr, &count); trans_model.Update(transition_accs, tcfg, &objf_impr, &count);
KALDI_LOG << "Transition model update: average " << (objf_impr/count) KALDI_LOG << "Transition model update: average " << (objf_impr/count)
...@@ -110,7 +116,7 @@ int main(int argc, char *argv[]) { ...@@ -110,7 +116,7 @@ int main(int argc, char *argv[]) {
{ // Update SGMM. { // Update SGMM.
kaldi::MleAmSgmmUpdater sgmm_updater(sgmm_opts); kaldi::MleAmSgmmUpdater sgmm_updater(sgmm_opts);
sgmm_updater.Update(sgmm_accs, &am_sgmm, acc_flags); sgmm_updater.Update(sgmm_accs, &am_sgmm, update_flags);
} }
if (split_substates != 0 || !occs_out_filename.empty()) { // get state occs if (split_substates != 0 || !occs_out_filename.empty()) { // get state occs
...@@ -147,7 +153,7 @@ int main(int argc, char *argv[]) { ...@@ -147,7 +153,7 @@ int main(int argc, char *argv[]) {
{ {
Output os(model_out_filename, binary_write); Output os(model_out_filename, binary_write);
trans_model.Write(os.Stream(), binary_write); trans_model.Write(os.Stream(), binary_write);
am_sgmm.Write(os.Stream(), binary_write, kSgmmWriteAll); am_sgmm.Write(os.Stream(), binary_write, write_flags);
} }
......
...@@ -35,6 +35,7 @@ int main(int argc, char *argv[]) { ...@@ -35,6 +35,7 @@ int main(int argc, char *argv[]) {
"e.g.: tied-diag-gmm-est 1.mdl 1.acc 2.mdl\n"; "e.g.: tied-diag-gmm-est 1.mdl 1.acc 2.mdl\n";
bool binary_write = false; bool binary_write = false;
std::string update_flags_str = "mvwt";
TransitionUpdateConfig tcfg; TransitionUpdateConfig tcfg;
std::string occs_out_filename; std::string occs_out_filename;
...@@ -42,6 +43,8 @@ int main(int argc, char *argv[]) { ...@@ -42,6 +43,8 @@ int main(int argc, char *argv[]) {
po.Register("binary", &binary_write, "Write output in binary mode"); po.Register("binary", &binary_write, "Write output in binary mode");
po.Register("write-occs", &occs_out_filename, "File to write state " po.Register("write-occs", &occs_out_filename, "File to write state "
"occupancies to."); "occupancies to.");
po.Register("update-flags", &update_flags_str, "Which GMM parameters to "
"update: subset of mvwt.");
tcfg.Register(&po); tcfg.Register(&po);
gmm_opts.Register(&po); gmm_opts.Register(&po);
tied_opts.Register(&po);