Commit 0c7a40d6 authored by Dan Povey's avatar Dan Povey
Browse files

Added sgmm-est-fmllr; various minor fixes.

git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@457 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
parent 6fca69cd
......@@ -8,6 +8,7 @@
minor:
remove KALDI_EXIT and modify KALDI_ERR macro to behave like it.
Make ApplySoftMax() function more efficient (does two exp()'s per element)
possibly change the way lattice-compose works to remove properties hack.
gmm-latgen-faster [Not started]. Need to write internal code for this (Mirko, but possibly wait a bit till I finalize lattice-simple-decoder.). As gmm-latgen-simple but not using std::unordered_map for current + previous frame's tokens-- would use the HashList stuff.
......
......@@ -134,15 +134,24 @@ inline vector<vector<double> > AcousticLatticeScale(double acwt) {
return ans;
}
inline vector<vector<double> > GraphLatticeScale(double acwt) {
inline vector<vector<double> > GraphLatticeScale(double lmwt) {
vector<vector<double> > ans(2);
ans[0].resize(2, 0.0);
ans[1].resize(2, 0.0);
ans[0][0] = acwt;
ans[0][0] = lmwt;
ans[1][1] = 1.0;
return ans;
}
/// Builds a 2x2 diagonal scaling matrix: entry [0][0] is `lmwt`, entry [1][1]
/// is `acwt`, and both off-diagonal entries are zero.  The result has the
/// shape of the scale argument that callers hand to ScaleLattice().
inline vector<vector<double> > LatticeScale(double lmwt, double acwt) {
  // Start from an all-zero 2x2 matrix, then fill in the diagonal.
  vector<vector<double> > scale(2, vector<double>(2, 0.0));
  scale[0][0] = lmwt;
  scale[1][1] = acwt;
  return scale;
}
/** Scales the pairs of weights in LatticeWeight or CompactLatticeWeight by
viewing the pair (a, b) as a 2-vector and pre-multiplying by the 2x2 matrix
......
......@@ -37,16 +37,17 @@ void LatticeAcousticRescore(const AmDiagGmm& am,
KALDI_ERR << "Input lattice must be topologically sorted.";
KALDI_ASSERT(!state_times.empty());
int32 max_time = *std::max_element(state_times.begin(), state_times.end());
KALDI_ASSERT(max_time > 0);
std::vector<std::vector<int32> > time_to_state(max_time+1);
std::vector<std::vector<int32> > time_to_state(data.NumRows());
for (size_t i = 0; i < state_times.size(); i++) {
KALDI_ASSERT(state_times[i] >= 0);
time_to_state[state_times[i]].push_back(i);
if (state_times[i] < data.NumRows()) // end state may be past this..
time_to_state[state_times[i]].push_back(i);
else
KALDI_ASSERT(state_times[i] == data.NumRows()
&& "There appears to be lattice/feature mismatch.");
}
for (int32 t = 0; t <= max_time; t++) {
for (int32 t = 0; t < data.NumRows(); t++) {
unordered_map<int32, BaseFloat> pdf_id_to_like;
for (size_t i = 0; i < time_to_state[t].size(); i++) {
int32 state = time_to_state[t][i];
......@@ -84,9 +85,9 @@ int main(int argc, char *argv[]) {
const char *usage =
"Replace the acoustic scores on a lattice using a new model.\n"
"Usage: gmm-resocre-lattice [options] <model-in> <lattice-rspecifier> "
"Usage: gmm-rescore-lattice [options] <model-in> <lattice-rspecifier> "
"<feature-rspecifier> <lattice-wspecifier>\n"
" e.g.: gmm-resocre-lattice 1.mdl ark:1.lats scp:trn.scp ark:2.lats\n";
" e.g.: gmm-rescore-lattice 1.mdl ark:1.lats scp:trn.scp ark:2.lats\n";
kaldi::BaseFloat old_acoustic_scale = 0.0;
kaldi::ParseOptions po(usage);
......@@ -135,12 +136,8 @@ int main(int argc, char *argv[]) {
kaldi::uint64 props = lat.Properties(fst::kFstProperties, false);
if (!(props & fst::kTopSorted)) {
KALDI_WARN << "Utterance " << key << ": Supplied lattice not "
<< "topologically sorted. Sorting it.";
if (fst::TopSort(&lat) == false)
KALDI_ERR << "Cycles detected in lattice.";
} else {
KALDI_LOG << "Already topologically sorted.";
}
vector<int32> state_times;
......
......@@ -34,10 +34,12 @@ int main(int argc, char *argv[]) {
"Usage: lattice-to-post [options] lats-rspecifier posts-wspecifier\n"
" e.g.: lattice-to-post --acoustic-scale=0.1 ark:1.lats ark:1.post\n";
kaldi::BaseFloat acoustic_scale = 1.0;
kaldi::BaseFloat acoustic_scale = 1.0, lm_scale = 1.0;
kaldi::ParseOptions po(usage);
po.Register("acoustic-scale", &acoustic_scale,
"Scaling factor for acoustic likelihoods");
po.Register("lm-scale", &lm_scale,
"Scaling factor for \"graph costs\" (including LM costs)");
po.Read(argc, argv);
if (po.NumArgs() != 2) {
......@@ -64,9 +66,9 @@ int main(int argc, char *argv[]) {
std::string key = lattice_reader.Key();
kaldi::Lattice lat = lattice_reader.Value();
lattice_reader.FreeCurrent();
if (acoustic_scale != 1.0)
fst::ScaleLattice(fst::AcousticLatticeScale(acoustic_scale), &lat);
if (acoustic_scale != 1.0 || lm_scale != 1.0)
fst::ScaleLattice(fst::LatticeScale(lm_scale, acoustic_scale), &lat);
kaldi::uint64 props = lat.Properties(fst::kFstProperties, false);
if (!(props & fst::kTopSorted)) {
if (fst::TopSort(&lat) == false)
......
......@@ -5,7 +5,7 @@ include ../kaldi.mk
BINFILES = init-ubm sgmm-align sgmm-align-compiled sgmm-acc-stats-ali \
sgmm-sum-accs sgmm-est sgmm-decode-faster sgmm-init sgmm-gselect \
sgmm-acc-stats sgmm-est-spkvecs sgmm-post-to-gpost \
sgmm-est-fmllr sgmm-acc-stats sgmm-est-spkvecs sgmm-post-to-gpost \
sgmm-acc-stats-gpost sgmm-est-spkvecs-gpost sgmm-comp-prexform \
sgmm-est-fmllr-gpost sgmm-acc-fmllrbasis-ali sgmm-est-fmllrbasis \
sgmm-calc-distances sgmm-normalize sgmm-latgen-simple \
......
......@@ -114,9 +114,7 @@ int main(int argc, char *argv[]) {
SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
RandomAccessInt32VectorReader transcript_reader(transcript_rspecifier);
RandomAccessInt32VectorVectorReader gselect_reader;
if (!gselect_rspecifier.empty() && !gselect_reader.Open(gselect_rspecifier))
KALDI_ERR << "Unable to open stream for gaussian-selection indices";
RandomAccessInt32VectorVectorReader gselect_reader(gselect_rspecifier);
RandomAccessBaseFloatVectorReader spkvecs_reader(spkvecs_rspecifier);
RandomAccessTokenReader utt2spk_reader(utt2spk_rspecifier);
......
......@@ -23,7 +23,7 @@ using std::vector;
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "sgmm/am-sgmm.h"
# include "sgmm/fmllr-sgmm.h"
#include "sgmm/fmllr-sgmm.h"
#include "hmm/transition-model.h"
namespace kaldi {
......
// sgmmbin/sgmm-est-fmllr.cc
// Copyright 2009-2011 Saarland University; Microsoft Corporation
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include <string>
using std::string;
#include <vector>
using std::vector;
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "sgmm/am-sgmm.h"
#include "sgmm/fmllr-sgmm.h"
#include "hmm/transition-model.h"
namespace kaldi {
void AccumulateForUtterance(const Matrix<BaseFloat> &feats,
const Matrix<BaseFloat> &transformed_feats, // if already fMLLR
const std::vector<std::vector<int32> > &gselect,
const SgmmGselectConfig &sgmm_config,
const Posterior &post,
const TransitionModel &trans_model,
const AmSgmm &am_sgmm,
const SgmmPerSpkDerivedVars &spk_vars,
BaseFloat logdet,
FmllrSgmmAccs *spk_stats) {
kaldi::SgmmPerFrameDerivedVars per_frame_vars;
for (size_t t = 0; t < post.size(); t++) {
std::vector<int32> this_gselect;
if (!gselect.empty()) {
KALDI_ASSERT(t < gselect.size());
this_gselect = gselect[t];
} else {
am_sgmm.GaussianSelection(sgmm_config, feats.Row(t), &this_gselect);
}
// per-frame vars only used for computing posteriors... use the
// transformed feats for this, if available.
am_sgmm.ComputePerFrameVars(transformed_feats.Row(t), this_gselect, spk_vars,
0.0 /*fMLLR logdet*/, &per_frame_vars);
for (size_t j = 0; j < post[t].size(); j++) {
int32 pdf_id = trans_model.TransitionIdToPdf(post[t][j].first);
Matrix<BaseFloat> posteriors;
am_sgmm.ComponentPosteriors(per_frame_vars, pdf_id,
&posteriors);
posteriors.Scale(post[t][j].second);
spk_stats->AccumulateFromPosteriors(am_sgmm, spk_vars, feats.Row(t),
this_gselect,
posteriors, pdf_id);
}
}
}
} // end namespace kaldi
/// Command-line driver for sgmm-est-fmllr.
///
/// Reads a model file (transition model, SGMM acoustic model and global fMLLR
/// parameters, in that order), features and posteriors, and writes one
/// estimated fMLLR matrix per speaker (when --spk2utt is given) or per
/// utterance (otherwise).  Optional inputs: speaker vectors (--spk-vecs),
/// initial transforms (--input-fmllr) and precomputed Gaussian-selection
/// indices (--gselect).
///
/// Returns 0 on success and -1 if a std::exception is caught; calls exit(1)
/// after printing the usage message when the number of positional arguments
/// is not 4.
int main(int argc, char *argv[]) {
  try {
    typedef kaldi::int32 int32;
    using namespace kaldi;
    const char *usage =
        "Estimate FMLLR transform for SGMMs, either per utterance or for the "
        "supplied set of speakers (with spk2utt option).\n"
        "Reads state-level posteriors. Writes to a table of matrices.\n"
        "Usage: sgmm-est-fmllr [options] <model-in> <feature-rspecifier> "
        "<post-rspecifier> <mats-wspecifier>\n";

    ParseOptions po(usage);
    string spk2utt_rspecifier, spkvecs_rspecifier, fmllr_rspecifier,
        gselect_rspecifier;
    // NOTE(review): min_count is registered as an option below but is not
    // referenced anywhere else in this function.
    BaseFloat min_count = 100;
    SgmmFmllrConfig fmllr_opts;
    SgmmGselectConfig sgmm_opts;

    po.Register("spk2utt", &spk2utt_rspecifier,
                "File to read speaker to utterance-list map from.");
    po.Register("spkvec-min-count", &min_count,
                "Minimum count needed to estimate speaker vectors");
    po.Register("spk-vecs", &spkvecs_rspecifier,
                "Speaker vectors to use during aligment (rspecifier)");
    po.Register("input-fmllr", &fmllr_rspecifier,
                "Initial FMLLR transform per speaker (rspecifier)");
    po.Register("gselect", &gselect_rspecifier,
                "Precomputed Gaussian indices (rspecifier)");
    fmllr_opts.Register(&po);
    sgmm_opts.Register(&po);

    po.Read(argc, argv);
    if (po.NumArgs() != 4) {
      po.PrintUsage();
      exit(1);
    }

    string model_rxfilename = po.GetArg(1),
        feature_rspecifier = po.GetArg(2),
        post_rspecifier = po.GetArg(3),
        fmllr_wspecifier = po.GetArg(4);

    // The three objects are read back-to-back from the same input stream.
    TransitionModel trans_model;
    AmSgmm am_sgmm;
    SgmmFmllrGlobalParams fmllr_globals;
    {
      bool binary;
      Input ki(model_rxfilename, &binary);
      trans_model.Read(ki.Stream(), binary);
      am_sgmm.Read(ki.Stream(), binary);
      fmllr_globals.Read(ki.Stream(), binary);
    }

    RandomAccessPosteriorReader post_reader(post_rspecifier);
    RandomAccessBaseFloatVectorReader spkvecs_reader(spkvecs_rspecifier);
    RandomAccessInt32VectorVectorReader gselect_reader(gselect_rspecifier);
    RandomAccessBaseFloatMatrixReader fmllr_reader(fmllr_rspecifier);
    BaseFloatMatrixWriter fmllr_writer(fmllr_wspecifier);

    int32 dim = am_sgmm.FeatureDim();
    FmllrSgmmAccs spk_stats;
    spk_stats.Init(dim, am_sgmm.NumGauss());
    Matrix<BaseFloat> fmllr_xform(dim, dim + 1);
    BaseFloat logdet = 0.0;
    double tot_impr = 0.0, tot_t = 0.0;
    int32 num_done = 0, num_no_post = 0, num_other_error = 0;
    // Passed to AccumulateForUtterance() when no usable Gaussian-selection
    // info exists, so that it selects Gaussians itself.
    std::vector<std::vector<int32> > empty_gselect;

    if (!spk2utt_rspecifier.empty()) {  // per-speaker adaptation
      SequentialTokenVectorReader spk2utt_reader(spk2utt_rspecifier);
      RandomAccessBaseFloatMatrixReader feature_reader(feature_rspecifier);

      for (; !spk2utt_reader.Done(); spk2utt_reader.Next()) {
        spk_stats.SetZero();
        string spk = spk2utt_reader.Key();
        const vector<string> &uttlist = spk2utt_reader.Value();

        SgmmPerSpkDerivedVars spk_vars;
        if (spkvecs_reader.IsOpen()) {
          if (spkvecs_reader.HasKey(spk)) {
            spk_vars.v_s = spkvecs_reader.Value(spk);
            am_sgmm.ComputePerSpkDerivedVars(&spk_vars);
          } else {
            KALDI_WARN << "Cannot find speaker vector for " << spk;
          }
        }  // else spk_vars is "empty"

        // Starting transform: the supplied one if available, else the unit
        // transform (with log-determinant 0).
        if (fmllr_reader.IsOpen()) {
          if (fmllr_reader.HasKey(spk)) {
            fmllr_xform.CopyFromMat(fmllr_reader.Value(spk));
            logdet = fmllr_xform.Range(0, dim, 0, dim).LogDet();
          } else {
            KALDI_WARN << "Cannot find FMLLR transform for " << spk;
            fmllr_xform.SetUnit();
            logdet = 0.0;
          }
        } else {
          fmllr_xform.SetUnit();
          logdet = 0.0;
        }

        for (size_t i = 0; i < uttlist.size(); i++) {
          std::string utt = uttlist[i];
          if (!feature_reader.HasKey(utt)) {
            KALDI_WARN << "Did not find features for utterance " << utt;
            // Count the skip so it shows up in the final summary.
            num_other_error++;
            continue;
          }
          if (!post_reader.HasKey(utt)) {
            KALDI_WARN << "Did not find posteriors for utterance " << utt;
            num_no_post++;
            continue;
          }
          const Matrix<BaseFloat> &feats = feature_reader.Value(utt);
          const Posterior &post = post_reader.Value(utt);
          if (static_cast<int32>(post.size()) != feats.NumRows()) {
            KALDI_WARN << "posterior vector has wrong size " << (post.size())
                       << " vs. " << (feats.NumRows());
            num_other_error++;
            continue;
          }

          // Use the precomputed Gaussian selection only if it exists for this
          // utterance and has one entry per frame.
          bool have_gselect = !gselect_rspecifier.empty()
              && gselect_reader.HasKey(utt)
              && static_cast<int32>(gselect_reader.Value(utt).size())
                  == feats.NumRows();
          if (!gselect_rspecifier.empty() && !have_gselect)
            KALDI_WARN << "No Gaussian-selection info available for utterance "
                       << utt << " (or wrong size)";
          const std::vector<std::vector<int32> > *gselect =
              (have_gselect ? &gselect_reader.Value(utt) : &empty_gselect);

          Matrix<BaseFloat> transformed_feats(feats);
          for (int32 r = 0; r < transformed_feats.NumRows(); r++) {
            // Bind the row view to a named object: taking the address of the
            // temporary returned by Row() is not valid C++.
            SubVector<BaseFloat> row(transformed_feats.Row(r));
            ApplyAffineTransform(fmllr_xform, &row);
          }

          AccumulateForUtterance(feats, transformed_feats, *gselect, sgmm_opts,
                                 post, trans_model, am_sgmm, spk_vars,
                                 logdet, &spk_stats);
          num_done++;
        }  // end looping over all utterances of the current speaker

        BaseFloat impr, spk_frame_count;
        // Compute the FMLLR transform and write it out.
        spk_stats.Update(am_sgmm, fmllr_globals, fmllr_opts, &fmllr_xform,
                         &spk_frame_count, &impr);
        fmllr_writer.Write(spk, fmllr_xform);
        tot_impr += impr;
        tot_t += spk_frame_count;
      }  // end looping over speakers
    } else {  // per-utterance adaptation
      SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
      for (; !feature_reader.Done(); feature_reader.Next()) {
        string utt = feature_reader.Key();
        if (!post_reader.HasKey(utt)) {
          KALDI_WARN << "Did not find posts for utterance "
                     << utt;
          num_no_post++;
          continue;
        }
        const Matrix<BaseFloat> &feats = feature_reader.Value();

        // In this mode the utterance id is used as the key for the speaker
        // vectors and the initial transforms.
        SgmmPerSpkDerivedVars spk_vars;
        if (spkvecs_reader.IsOpen()) {
          if (spkvecs_reader.HasKey(utt)) {
            spk_vars.v_s = spkvecs_reader.Value(utt);
            am_sgmm.ComputePerSpkDerivedVars(&spk_vars);
          } else {
            KALDI_WARN << "Cannot find speaker vector for " << utt;
          }
        }  // else spk_vars is "empty"

        if (fmllr_reader.IsOpen()) {
          if (fmllr_reader.HasKey(utt)) {
            fmllr_xform.CopyFromMat(fmllr_reader.Value(utt));
            logdet = fmllr_xform.Range(0, dim, 0, dim).LogDet();
          } else {
            KALDI_WARN << "Cannot find FMLLR transform for " << utt;
            fmllr_xform.SetUnit();
            logdet = 0.0;
          }
        } else {
          fmllr_xform.SetUnit();
          logdet = 0.0;
        }

        const Posterior &post = post_reader.Value(utt);
        if (static_cast<int32>(post.size()) != feats.NumRows()) {
          KALDI_WARN << "post has wrong size " << (post.size())
                     << " vs. " << (feats.NumRows());
          num_other_error++;
          continue;
        }
        spk_stats.SetZero();

        Matrix<BaseFloat> transformed_feats(feats);
        for (int32 r = 0; r < transformed_feats.NumRows(); r++) {
          // Bind the row view to a named object: taking the address of the
          // temporary returned by Row() is not valid C++.
          SubVector<BaseFloat> row(transformed_feats.Row(r));
          ApplyAffineTransform(fmllr_xform, &row);
        }

        // Use the precomputed Gaussian selection only if it exists for this
        // utterance and has one entry per frame.
        bool have_gselect = !gselect_rspecifier.empty()
            && gselect_reader.HasKey(utt)
            && static_cast<int32>(gselect_reader.Value(utt).size())
                == feats.NumRows();
        if (!gselect_rspecifier.empty() && !have_gselect)
          KALDI_WARN << "No Gaussian-selection info available for utterance "
                     << utt << " (or wrong size)";
        const std::vector<std::vector<int32> > *gselect =
            (have_gselect ? &gselect_reader.Value(utt) : &empty_gselect);

        AccumulateForUtterance(feats, transformed_feats, *gselect, sgmm_opts,
                               post, trans_model, am_sgmm, spk_vars,
                               logdet, &spk_stats);
        num_done++;

        BaseFloat impr, spk_frame_count;
        // Compute the FMLLR transform and write it out.
        spk_stats.Update(am_sgmm, fmllr_globals, fmllr_opts, &fmllr_xform,
                         &spk_frame_count, &impr);
        fmllr_writer.Write(utt, fmllr_xform);
        tot_impr += impr;
        tot_t += spk_frame_count;
      }
    }

    KALDI_LOG << "Done " << num_done << " files, " << num_no_post
              << " with no posts, " << num_other_error << " with other errors.";
    // Avoid printing NaN (0/0) when no frames were accumulated.
    KALDI_LOG << "Num frames " << tot_t << ", auxf impr per frame is "
              << (tot_t > 0.0 ? tot_impr / tot_t : 0.0);
    return 0;
  } catch(const std::exception& e) {
    std::cerr << e.what();
    return -1;
  }
}
......@@ -164,13 +164,9 @@ int main(int argc, char *argv[]) {
kaldi::uint64 props = lat.Properties(fst::kFstProperties, false);
if (!(props & fst::kTopSorted)) {
KALDI_WARN << "Utterance " << key << ": Supplied lattice not "
<< "topologically sorted. Sorting it.";
if (fst::TopSort(&lat) == false)
KALDI_ERR << "Cycles detected in lattice.";
} else {
KALDI_LOG << "Already topologically sorted.";
}
}
vector<int32> state_times;
int32 max_time = kaldi::LatticeStateTimes(lat, &state_times);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment