Commit 45ece9fb authored by Dan Povey
Browse files

Name changes in SGMM code; add sgmm-rescore-lattice; make gmm-rescore-lattice more efficient.

git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@451 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
parent 1ea50549
......@@ -33,7 +33,7 @@ int main(int argc, char *argv[]) {
"Usage: weight-silence-post [options] <silence-weight> <silence-phones> "
"<model> <posteriors-rspecifier> <posteriors-wspecifier>\n"
"e.g.:\n"
" weight-silence-post 0.0 1:2:3 1.mdl ark:1.ali ark:1.post\n";
" weight-silence-post 0.0 1:2:3 1.mdl ark:1.post ark:nosil.post\n";
ParseOptions po(usage);
......
......@@ -22,12 +22,12 @@ using std::vector;
namespace kaldi {
BaseFloat DecodableAmSgmm::LogLikelihoodZeroBased(int32 frame, int32 state) {
BaseFloat DecodableAmSgmm::LogLikelihoodZeroBased(int32 frame, int32 pdf_id) {
KALDI_ASSERT(frame >= 0 && frame < NumFrames());
KALDI_ASSERT(state >= 0 && state < NumIndices());
KALDI_ASSERT(pdf_id >= 0 && pdf_id < NumIndices());
if (log_like_cache_[state].hit_time == frame) {
return log_like_cache_[state].log_like; // return cached value, if found
if (log_like_cache_[pdf_id].hit_time == frame) {
return log_like_cache_[pdf_id].log_like; // return cached value, if found
}
const VectorBase<BaseFloat> &data = feature_matrix_.Row(frame);
......@@ -49,18 +49,18 @@ BaseFloat DecodableAmSgmm::LogLikelihoodZeroBased(int32 frame, int32 state) {
previous_frame_ = frame;
}
BaseFloat loglike = acoustic_model_.LogLikelihood(per_frame_vars_, state,
BaseFloat loglike = acoustic_model_.LogLikelihood(per_frame_vars_, pdf_id,
log_prune_);
if (KALDI_ISNAN(loglike) || KALDI_ISINF(loglike))
KALDI_ERR << "Invalid answer (overflow or invalid variances/features?)";
log_like_cache_[state].log_like = loglike;
log_like_cache_[state].hit_time = frame;
log_like_cache_[pdf_id].log_like = loglike;
log_like_cache_[pdf_id].hit_time = frame;
return loglike;
}
void DecodableAmSgmm::ResetLogLikeCache() {
if (log_like_cache_.size() != acoustic_model_.NumStates()) {
log_like_cache_.resize(acoustic_model_.NumStates());
if (log_like_cache_.size() != acoustic_model_.NumPdfs()) {
log_like_cache_.resize(acoustic_model_.NumPdfs());
}
vector<LikelihoodCacheRecord>::iterator it = log_like_cache_.begin(),
end = log_like_cache_.end();
......
......@@ -60,7 +60,7 @@ class DecodableAmSgmm : public DecodableInterface {
void ResetLogLikeCache();
protected:
virtual BaseFloat LogLikelihoodZeroBased(int32 frame, int32 state_index);
virtual BaseFloat LogLikelihoodZeroBased(int32 frame, int32 pdf_id);
const AmSgmm &acoustic_model_;
const SgmmGselectConfig &sgmm_config_;
......@@ -137,7 +137,7 @@ class DecodableAmSgmmFmllr : public DecodableAmSgmm {
}
protected:
virtual BaseFloat LogLikelihoodZeroBased(int32 frame, int32 state_index);
virtual BaseFloat LogLikelihoodZeroBased(int32 frame, int32 pdf_id);
private:
Matrix<BaseFloat> fmllr_mat_;
......
......@@ -18,6 +18,7 @@
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "util/stl-utils.h"
#include "gmm/am-diag-gmm.h"
#include "hmm/transition-model.h"
#include "fstext/fstext-lib.h"
......@@ -35,22 +36,39 @@ void LatticeAcousticRescore(const AmDiagGmm& am,
if (!(props & fst::kTopSorted))
KALDI_ERR << "Input lattice must be topologically sorted.";
int32 num_states = lat->NumStates();
for (int32 state = 0; state < num_states; ++state) {
int32 cur_time = state_times[state];
for (fst::MutableArcIterator<Lattice> aiter(lat, state); !aiter.Done();
aiter.Next()) {
LatticeArc arc = aiter.Value();
int32 trans_id = arc.ilabel;
if (trans_id != 0) { // Non-epsilon input label on arc
int32 pdf_id = trans_model.TransitionIdToPdf(trans_id);
BaseFloat ll = am.LogLikelihood(pdf_id, data.Row(cur_time));
arc.weight.SetValue2(-ll + arc.weight.Value2());
// TODO(arnab): This can be made more efficient by caching likelihoods
KALDI_ASSERT(!state_times.empty());
int32 max_time = *std::max_element(state_times.begin(), state_times.end());
KALDI_ASSERT(max_time > 0);
std::vector<std::vector<int32> > time_to_state(max_time+1);
for (size_t i = 0; i < state_times.size(); i++) {
KALDI_ASSERT(state_times[i] >= 0);
time_to_state[state_times[i]].push_back(i);
}
for (int32 t = 0; t <= max_time; t++) {
unordered_map<int32, BaseFloat> pdf_id_to_like;
for (size_t i = 0; i < time_to_state[t].size(); i++) {
int32 state = time_to_state[t][i];
for (fst::MutableArcIterator<Lattice> aiter(lat, state); !aiter.Done();
aiter.Next()) {
LatticeArc arc = aiter.Value();
int32 trans_id = arc.ilabel;
if (trans_id != 0) { // Non-epsilon input label on arc
int32 pdf_id = trans_model.TransitionIdToPdf(trans_id);
BaseFloat ll;
if (pdf_id_to_like.count(pdf_id) == 0) {
ll = am.LogLikelihood(pdf_id, data.Row(t));
pdf_id_to_like[pdf_id] = ll;
} else {
ll = pdf_id_to_like[pdf_id];
}
arc.weight.SetValue2(-ll + arc.weight.Value2());
aiter.SetValue(arc);
}
}
aiter.SetValue(arc);
} // end iterating over arcs for a state
} // end iterating over states
}
}
}
} // namespace kaldi
......
......@@ -92,8 +92,8 @@ void TestSgmmIO(const AmSgmm &sgmm) {
void TestSgmmSubstates(const AmSgmm &sgmm) {
using namespace kaldi;
int32 target_substates = 2 * sgmm.NumStates();
kaldi::Vector<BaseFloat> occs(sgmm.NumStates());
int32 target_substates = 2 * sgmm.NumPdfs();
kaldi::Vector<BaseFloat> occs(sgmm.NumPdfs());
for (int32 i = 0; i < occs.Dim(); ++i)
occs(i) = std::fabs(kaldi::RandGauss()) * (kaldi::RandUniform()+1);
AmSgmm *sgmm1 = new AmSgmm();
......@@ -173,7 +173,7 @@ void TestSgmmIncreaseDim(const AmSgmm &sgmm) {
void TestSgmmPreXform(const AmSgmm &sgmm) {
kaldi::Matrix<BaseFloat> xform, inv_xform;
kaldi::Vector<BaseFloat> diag_scatter;
kaldi::Vector<BaseFloat> occs(sgmm.NumStates());
kaldi::Vector<BaseFloat> occs(sgmm.NumPdfs());
occs.Set(100);
sgmm.ComputeFmllrPreXform(occs, &xform, &inv_xform, &diag_scatter);
int32 dim = xform.NumRows();
......
......@@ -95,7 +95,7 @@ void AmSgmm::Read(std::istream &in_stream, bool binary) {
void AmSgmm::Write(std::ostream &out_stream, bool binary,
SgmmWriteFlagsType write_params) const {
int32 num_states = NumStates(),
int32 num_states = NumPdfs(),
feat_dim = FeatureDim(),
num_gauss = NumGauss();
......@@ -166,7 +166,7 @@ void AmSgmm::Write(std::ostream &out_stream, bool binary,
}
void AmSgmm::Check(bool show_properties) {
int32 num_states = NumStates(),
int32 num_states = NumPdfs(),
num_gauss = NumGauss(),
feat_dim = FeatureDim(),
phn_dim = PhoneSpaceDim(),
......@@ -318,7 +318,7 @@ void AmSgmm::ComputePerFrameVars(const VectorBase<BaseFloat>& data,
BaseFloat AmSgmm::LogLikelihood(const SgmmPerFrameDerivedVars &per_frame_vars,
int32 j, BaseFloat log_prune) const {
KALDI_ASSERT(j < NumStates());
KALDI_ASSERT(j < NumPdfs());
const vector<int32> &gselect = per_frame_vars.gselect;
......@@ -344,7 +344,7 @@ BaseFloat
AmSgmm::ComponentPosteriors(const SgmmPerFrameDerivedVars &per_frame_vars,
int32 j,
Matrix<BaseFloat> *post) const {
KALDI_ASSERT(j < NumStates());
KALDI_ASSERT(j < NumPdfs());
if (post == NULL) KALDI_ERR << "NULL pointer passed as return argument.";
const vector<int32> &gselect = per_frame_vars.gselect;
post->Resize(gselect.size(), NumSubstates(j));
......@@ -379,7 +379,7 @@ void AmSgmm::SplitSubstates(const Vector<BaseFloat> &state_occupancies,
int32 target_nsubstates, BaseFloat perturb,
BaseFloat power, BaseFloat max_cond) {
// power == p in document. target_nsubstates == T in document.
KALDI_ASSERT(state_occupancies.Dim() == NumStates());
KALDI_ASSERT(state_occupancies.Dim() == NumPdfs());
int32 tot_n_substates_old = 0;
int32 phn_dim = PhoneSpaceDim();
std::priority_queue<SubstateCounter> substate_counts;
......@@ -387,7 +387,7 @@ void AmSgmm::SplitSubstates(const Vector<BaseFloat> &state_occupancies,
SpMatrix<BaseFloat> sqrt_H_sm;
Vector<BaseFloat> rand_vec(phn_dim), v_shift(phn_dim);
for (int32 j = 0; j < NumStates(); j++) {
for (int32 j = 0; j < NumPdfs(); j++) {
// work out the sub-model's prob from the sum of 'c'.;
BaseFloat gamma_p = pow(state_occupancies(j) * c_[j].Sum(), power);
substate_counts.push(SubstateCounter(j, NumSubstates(j), gamma_p));
......@@ -482,7 +482,7 @@ void AmSgmm::IncreasePhoneSpaceDim(int32 target_dim,
w_.Resize(tmp_w.NumRows(), target_dim);
w_.Range(0, tmp_w.NumRows(), 0, tmp_w.NumCols()).CopyFromMat(tmp_w);
for (int32 j = 0; j < NumStates(); ++j) {
for (int32 j = 0; j < NumPdfs(); ++j) {
// Resize v[j]
Matrix<BaseFloat> tmp_v_j = v_[j];
v_[j].Resize(tmp_v_j.NumRows(), target_dim);
......@@ -563,8 +563,8 @@ void AmSgmm::ComputeNormalizers() {
double entropy_count = 0, entropy_sum = 0;
n_.resize(NumStates());
for (int32 j = 0; j < NumStates(); ++j) {
n_.resize(NumPdfs());
for (int32 j = 0; j < NumPdfs(); ++j) {
Vector<BaseFloat> log_w_jm(NumGauss());
n_[j].Resize(NumGauss(), NumSubstates(j));
......@@ -650,8 +650,8 @@ void AmSgmm::ComputeNormalizersNormalized(const std::vector<std::vector<int32> >
// double entropy_count = 0, entropy_sum = 0;
n_.resize(NumStates());
for (int32 j = 0; j < NumStates(); ++j) {
n_.resize(NumPdfs());
for (int32 j = 0; j < NumPdfs(); ++j) {
Vector<BaseFloat> log_w_jm(NumGauss());
n_[j].Resize(NumGauss(), NumSubstates(j));
......@@ -709,7 +709,7 @@ void AmSgmm::ComputeNormalizersNormalized(const std::vector<std::vector<int32> >
void AmSgmm::ComputeFmllrPreXform(const Vector<BaseFloat> &state_occs,
Matrix<BaseFloat> *xform, Matrix<BaseFloat> *inv_xform,
Vector<BaseFloat> *diag_mean_scatter) const {
int32 num_states = NumStates(),
int32 num_states = NumPdfs(),
num_gauss = NumGauss(),
dim = FeatureDim();
KALDI_ASSERT(state_occs.Dim() == num_states);
......@@ -927,13 +927,13 @@ void AmSgmm::ComputeSmoothingTermsFromModel(
BaseFloat max_cond) const {
int32 num_gauss = NumGauss();
BaseFloat tot_sum = 0.0;
KALDI_ASSERT(state_occupancies.Dim() == NumStates());
KALDI_ASSERT(state_occupancies.Dim() == NumPdfs());
Vector<BaseFloat> w_jm(num_gauss);
H_sm->Resize(PhoneSpaceDim());
H_sm->SetZero();
Vector<BaseFloat> gamma_i(num_gauss);
gamma_i.SetZero();
for (int32 j = 0; j < NumStates(); j++) {
for (int32 j = 0; j < NumPdfs(); j++) {
int32 M_j = NumSubstates(j);
KALDI_ASSERT(M_j > 0);
for (int32 m = 0; m < M_j; ++m) {
......@@ -1196,7 +1196,7 @@ void SgmmGauPost::Read(std::istream &is, bool binary) {
void AmSgmmFunctions::ComputeDistances(const AmSgmm& model,
const Vector<BaseFloat> &state_occs,
MatrixBase<BaseFloat> *dists) {
int32 num_states = model.NumStates(),
int32 num_states = model.NumPdfs(),
phn_space_dim = model.PhoneSpaceDim(),
num_gauss = model.NumGauss();
KALDI_ASSERT(dists != NULL && dists->NumRows() == num_states
......
......@@ -227,7 +227,7 @@ class AmSgmm {
Vector<BaseFloat> *diag_mean_scatter) const;
/// Various model dimensions.
int32 NumStates() const { return c_.size(); }
int32 NumPdfs() const { return c_.size(); }
int32 NumSubstates(int32 j) const { return c_[j].Dim(); }
int32 NumGauss() const { return M_.size(); }
int32 PhoneSpaceDim() const { return w_.NumCols(); }
......@@ -332,7 +332,7 @@ template<typename Real>
inline void AmSgmm::GetSubstateMean(int32 j, int32 m, int32 i,
VectorBase<Real> *mean_out) const {
KALDI_ASSERT(mean_out != NULL);
KALDI_ASSERT(j < NumStates() && m < NumSubstates(j) && i < NumGauss());
KALDI_ASSERT(j < NumPdfs() && m < NumSubstates(j) && i < NumGauss());
KALDI_ASSERT(mean_out->Dim() == FeatureDim());
Vector<BaseFloat> mean_tmp(FeatureDim());
mean_tmp.AddMatVec(1.0, M_[i], kNoTrans, v_[j].Row(m), 0.0);
......
......@@ -184,7 +184,7 @@ void MleAmSgmmAccs::Check(const AmSgmm &model,
feature_dim_ << ", S = " << phn_space_dim_ << ", T = " <<
spk_space_dim_ << ", I = " << num_gaussians_;
}
KALDI_ASSERT(num_states_ == model.NumStates() && num_states_ > 0);
KALDI_ASSERT(num_states_ == model.NumPdfs() && num_states_ > 0);
KALDI_ASSERT(num_gaussians_ == model.NumGauss() && num_gaussians_ > 0);
KALDI_ASSERT(feature_dim_ == model.FeatureDim() && feature_dim_ > 0);
KALDI_ASSERT(phn_space_dim_ == model.PhoneSpaceDim() && phn_space_dim_ > 0);
......@@ -275,7 +275,7 @@ void MleAmSgmmAccs::Check(const AmSgmm &model,
void MleAmSgmmAccs::ResizeAccumulators(const AmSgmm &model,
SgmmUpdateFlagsType flags) {
num_states_ = model.NumStates();
num_states_ = model.NumPdfs();
num_gaussians_ = model.NumGauss();
feature_dim_ = model.FeatureDim();
phn_space_dim_ = model.PhoneSpaceDim();
......
......@@ -57,7 +57,7 @@ void TestSgmmFmllrAccsIO(const AmSgmm &sgmm,
frame_vars.Resize(sgmm.NumGauss(), dim, sgmm.PhoneSpaceDim());
sgmm_config.full_gmm_nbest = std::min(sgmm_config.full_gmm_nbest,
sgmm.NumGauss());
kaldi::Vector<BaseFloat> occs(sgmm.NumStates());
kaldi::Vector<BaseFloat> occs(sgmm.NumPdfs());
occs.Set(feats.NumRows());
sgmm.ComputeFmllrPreXform(occs, &fmllr_globals.pre_xform_,
&fmllr_globals.inv_xform_,
......@@ -145,7 +145,7 @@ void TestSgmmFmllrSubspace(const AmSgmm &sgmm,
frame_vars.Resize(sgmm.NumGauss(), dim, sgmm.PhoneSpaceDim());
sgmm_config.full_gmm_nbest = std::min(sgmm_config.full_gmm_nbest,
sgmm.NumGauss());
kaldi::Vector<BaseFloat> occs(sgmm.NumStates());
kaldi::Vector<BaseFloat> occs(sgmm.NumPdfs());
occs.Set(feats.NumRows());
sgmm.ComputeFmllrPreXform(occs, &fmllr_globals.pre_xform_,
&fmllr_globals.inv_xform_,
......
......@@ -3,12 +3,13 @@ all:
EXTRA_CXXFLAGS = -Wno-sign-compare
include ../kaldi.mk
BINFILES = init-ubm sgmm-align sgmm-align-compiled sgmm-acc-stats-ali sgmm-sum-accs \
sgmm-est sgmm-decode-faster sgmm-init sgmm-gselect sgmm-acc-stats \
sgmm-est-spkvecs sgmm-post-to-gpost sgmm-acc-stats-gpost sgmm-est-spkvecs-gpost \
sgmm-comp-prexform sgmm-est-fmllr-gpost sgmm-acc-fmllrbasis-ali sgmm-est-fmllrbasis \
sgmm-calc-distances sgmm-normalize sgmm-latgen-simple
BINFILES = init-ubm sgmm-align sgmm-align-compiled sgmm-acc-stats-ali \
sgmm-sum-accs sgmm-est sgmm-decode-faster sgmm-init sgmm-gselect \
sgmm-acc-stats sgmm-est-spkvecs sgmm-post-to-gpost \
sgmm-acc-stats-gpost sgmm-est-spkvecs-gpost sgmm-comp-prexform \
sgmm-est-fmllr-gpost sgmm-acc-fmllrbasis-ali sgmm-est-fmllrbasis \
sgmm-calc-distances sgmm-normalize sgmm-latgen-simple \
sgmm-rescore-lattice
OBJFILES =
......
......@@ -91,9 +91,7 @@ int main(int argc, char *argv[]) {
SequentialTableReader<fst::VectorFstHolder> fst_reader(fst_rspecifier);
RandomAccessBaseFloatMatrixReader feature_reader(feature_rspecifier);
RandomAccessInt32VectorVectorReader gselect_reader;
if (!gselect_rspecifier.empty() && !gselect_reader.Open(gselect_rspecifier))
KALDI_ERR << "Unable to open stream for gaussian-selection indices";
RandomAccessInt32VectorVectorReader gselect_reader(gselect_rspecifier);
RandomAccessBaseFloatVectorReader spkvecs_reader(spkvecs_rspecifier);
......
......@@ -60,7 +60,7 @@ int main(int argc, char *argv[]) {
occs.Read(is.Stream(), binary);
}
Matrix<BaseFloat> dists(am_sgmm.NumStates(), am_sgmm.NumStates());
Matrix<BaseFloat> dists(am_sgmm.NumPdfs(), am_sgmm.NumPdfs());
AmSgmmFunctions::ComputeDistances(am_sgmm, occs, &dists);
Output os(distances_out_filename, binary);
......
// sgmmbin/sgmm-rescore-lattice.cc
// Copyright 2009-2011 Saarland University
// Author: Arnab Ghoshal
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "util/stl-utils.h"
#include "sgmm/am-sgmm.h"
#include "hmm/transition-model.h"
#include "fstext/fstext-lib.h"
#include "lat/kaldi-lattice.h"
#include "lat/lattice-utils.h"
namespace kaldi {
void LatticeAcousticRescore(const AmSgmm& am,
const TransitionModel& trans_model,
const MatrixBase<BaseFloat>& data,
const SgmmPerSpkDerivedVars &spk_vars,
const std::vector<std::vector<int32> > &gselect,
const SgmmGselectConfig &sgmm_config,
double log_prune,
const std::vector<int32> state_times,
Lattice *lat) {
kaldi::uint64 props = lat->Properties(fst::kFstProperties, false);
if (!(props & fst::kTopSorted))
KALDI_ERR << "Input lattice must be topologically sorted.";
KALDI_ASSERT(!state_times.empty());
int32 max_time = *std::max_element(state_times.begin(), state_times.end());
KALDI_ASSERT(max_time > 0);
std::vector<std::vector<int32> > time_to_state(max_time+1);
for (size_t i = 0; i < state_times.size(); i++) {
KALDI_ASSERT(state_times[i] >= 0);
time_to_state[state_times[i]].push_back(i);
}
for (int32 t = 0; t <= max_time; t++) {
SgmmPerFrameDerivedVars per_frame_vars;
std::vector<int32> this_gselect;
if (!gselect.empty()) {
KALDI_ASSERT(t < gselect.size());
this_gselect = gselect[t];
} else {
am.GaussianSelection(sgmm_config, data.Row(t), &this_gselect);
}
am.ComputePerFrameVars(data.Row(t), this_gselect, spk_vars,
0.0 /*fMLLR logdet*/, &per_frame_vars);
unordered_map<int32, BaseFloat> pdf_id_to_like;
for (size_t i = 0; i < time_to_state[t].size(); i++) {
int32 state = time_to_state[t][i];
for (fst::MutableArcIterator<Lattice> aiter(lat, state); !aiter.Done();
aiter.Next()) {
LatticeArc arc = aiter.Value();
int32 trans_id = arc.ilabel;
if (trans_id != 0) { // Non-epsilon input label on arc
int32 pdf_id = trans_model.TransitionIdToPdf(trans_id);
BaseFloat ll;
if (pdf_id_to_like.count(pdf_id) == 0) {
ll = am.LogLikelihood(per_frame_vars, pdf_id, log_prune);
pdf_id_to_like[pdf_id] = ll;
} else {
ll = pdf_id_to_like[pdf_id];
}
arc.weight.SetValue2(-ll + arc.weight.Value2());
aiter.SetValue(arc);
}
}
}
}
}
} // namespace kaldi
int main(int argc, char *argv[]) {
try {
using namespace kaldi;
typedef kaldi::int32 int32;
typedef kaldi::int64 int64;
using fst::SymbolTable;
using fst::VectorFst;
using fst::StdArc;
const char *usage =
"Replace the acoustic scores on a lattice using a new model.\n"
"Usage: gmm-resocre-lattice [options] <model-in> <lattice-rspecifier> "
"<feature-rspecifier> <lattice-wspecifier>\n"
" e.g.: gmm-resocre-lattice 1.mdl ark:1.lats scp:trn.scp ark:2.lats\n";
kaldi::BaseFloat old_acoustic_scale = 0.0;
BaseFloat log_prune = 5.0;
std::string gselect_rspecifier, spkvecs_rspecifier, utt2spk_rspecifier;
SgmmGselectConfig sgmm_opts;
kaldi::ParseOptions po(usage);
po.Register("old-acoustic-scale", &old_acoustic_scale,
"Add the current acoustic scores with some scale.");
po.Register("log-prune", &log_prune, "Pruning beam used to reduce number of exp() evaluations.");
po.Register("spk-vecs", &spkvecs_rspecifier, "Speaker vectors (rspecifier)");
po.Register("utt2spk", &utt2spk_rspecifier,
"rspecifier for utterance to speaker map");
po.Register("gselect", &gselect_rspecifier, "Precomputed Gaussian indices (rspecifier)");
sgmm_opts.Register(&po);
po.Read(argc, argv);
if (po.NumArgs() != 4) {
po.PrintUsage();
exit(1);
}
std::string model_filename = po.GetArg(1),
lats_rspecifier = po.GetArg(2),
feature_rspecifier = po.GetArg(3),
lats_wspecifier = po.GetArg(4);
AmSgmm am_sgmm;
TransitionModel trans_model;
{
bool binary;
Input is(model_filename, &binary);
trans_model.Read(is.Stream(), binary);
am_sgmm.Read(is.Stream(), binary);
}
RandomAccessTokenReader utt2spk_reader(utt2spk_rspecifier);
RandomAccessInt32VectorVectorReader gselect_reader(gselect_rspecifier);
RandomAccessBaseFloatVectorReader spkvecs_reader(spkvecs_rspecifier);
RandomAccessBaseFloatMatrixReader feature_reader(feature_rspecifier);
// Read as regular lattice
SequentialLatticeReader lattice_reader(lats_rspecifier);
// Write as compact lattice.
CompactLatticeWriter compact_lattice_writer(lats_wspecifier);
int32 n_done = 0, num_no_feats = 0, num_other_error = 0;
for (; !lattice_reader.Done(); lattice_reader.Next()) {
std::string key = lattice_reader.Key();
if (!feature_reader.HasKey(key)) {
KALDI_WARN << "No feature found for utterance " << key << ". Skipping";
num_no_feats++;
continue;
}
Lattice lat = lattice_reader.Value();
lattice_reader.FreeCurrent();
if (old_acoustic_scale != 1.0)
fst::ScaleLattice(fst::AcousticLatticeScale(old_acoustic_scale), &lat);
kaldi::uint64 props = lat.Properties(fst::kFstProperties, false);
if (!(props & fst::kTopSorted)) {
KALDI_WARN << "Utterance " << key << ": Supplied lattice not "
<< "topologically sorted. Sorting it.";
if (fst::TopSort(&lat) == false)
KALDI_ERR << "Cycles detected in lattice.";
} else {
KALDI_LOG << "Already topologically sorted.";
}
vector<int32> state_times;
int32 max_time = kaldi::LatticeStateTimes(lat, &state_times);
const Matrix<BaseFloat> &feats = feature_reader.Value(key);
if (feats.NumRows() != max_time) {
KALDI_WARN << "Skipping utterance " << key << " since number of time "
<< "frames in lattice ("<< max_time << ") differ from "
<< "number of feature frames (" << feats.NumRows() << ").";
num_other_error++;
continue;
}
std::string utt_or_spk; // used to work out speaker vector.
if (utt2spk_rspecifier.empty()) utt_or_spk = key;
else {
if (!utt2spk_reader.HasKey(key)) {
KALDI_WARN << "Utterance " << key << " not present in utt2spk map; "
<< "skipping this utterance.";
num_other_error++;
continue;
} else {
utt_or_spk = utt2spk_reader.Value(key);
}
}
// Get speaker vectors
SgmmPerSpkDerivedVars spk_vars;
if (spkvecs_reader.IsOpen()) {
if (spkvecs_reader.HasKey(utt_or_spk)) {
spk_vars.v_s = spkvecs_reader.Value(utt_or_spk);
am_sgmm.ComputePerSpkDerivedVars(&spk_vars);
} else {
KALDI_WARN << "Cannot find speaker vector for " << utt_or_spk;
}
} // else spk_vars is "empty"
bool have_gselect = !gselect_rspecifier.empty()
&& gselect_reader.HasKey(key)
&& gselect_reader.Value(key).size() == feats.NumRows();
if (!gselect_rspecifier.empty() && !have_gselect)
KALDI_WARN << "No Gaussian-selection info available for utterance "
<< key << " (or wrong size)";
std::vector<std::vector<int32> > empty_gselect;
const std::vector<std::vector<int32> > *gselect =
(have_gselect ? &gselect_reader.Value(key) : &empty_gselect);
kaldi::LatticeAcousticRescore(am_sgmm, trans_model, feats,
spk_vars, *gselect, sgmm_opts,
log_prune, state_times, &lat);
CompactLattice clat_out;
ConvertLattice(lat, &clat_out);
compact_lattice_writer.Write(key, clat_out);
n_done++;
}
KALDI_LOG << "Done " << n_done << " lattices.";
return (n_done != 0 ? 0 : 1);
} catch(const std::exception& e) {