Commit ded57d4f authored by Dan Povey's avatar Dan Povey
Browse files

Fixes and additions to lattice code.

git-svn-id: https://svn.code.sf.net/p/kaldi/code/sandbox/discrim@479 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
parent 415dd2d4
......@@ -58,9 +58,9 @@ int main(int argc, char *argv[]) {
<< silence_weight_str << '"';
std::vector<int32> silence_phones;
if (!SplitStringToIntegers(silence_phones_str, ":", false, &silence_phones))
KALDI_EXIT << "weight-silence-posteriors: Invalid silence-phones string " << silence_phones_str;
KALDI_EXIT << "Invalid silence-phones string " << silence_phones_str;
if (silence_phones.empty())
KALDI_WARN <<"weight-silence-posteriors: no silence phones, this will have no effect";
KALDI_WARN <<"No silence phones, this will have no effect";
ConstIntegerSet<int32> silence_set(silence_phones); // faster lookup.
TransitionModel trans_model;
......@@ -97,7 +97,7 @@ int main(int argc, char *argv[]) {
}
posterior_writer.Write(posterior_reader.Key(), new_post);
}
KALDI_LOG << "weight-silence-posteriors: processed " << num_posteriors << " posteriors.";
KALDI_LOG << "Done " << num_posteriors << " posteriors.";
} catch(const std::exception& e) {
std::cerr << e.what();
return -1;
......
......@@ -21,7 +21,7 @@ using std::pair;
#include <map>
using std::map;
#include <vector>
using std::vector;
#include "lat/lattice-utils.h"
#include "hmm/transition-model.h"
......@@ -85,7 +85,7 @@ BaseFloat LatticeForwardBackward(const Lattice &lat, Posterior *arc_post) {
vector<int32> state_times;
int32 max_time = LatticeStateTimes(lat, &state_times);
vector< vector<int32> > active_states(max_time + 1);
// the +1 is needed since time is indexed from 0
// the +1 is needed since time is indexed from 0.
vector<double> state_alphas(num_states, kLogZeroDouble),
state_betas(num_states, kLogZeroDouble);
......@@ -98,8 +98,9 @@ BaseFloat LatticeForwardBackward(const Lattice &lat, Posterior *arc_post) {
active_states[cur_time].push_back(state);
if (lat.Final(state) != LatticeWeight::Zero()) { // Check if final state.
state_betas[state] = 0.0;
tot_forward_prob = LogAdd(tot_forward_prob, state_alphas[state]);
BaseFloat final_loglike = -(lat.Final(state).Value1() + lat.Final(state).Value2());
state_betas[state] = final_loglike;
tot_forward_prob = LogAdd(tot_forward_prob, state_alphas[state] + final_loglike);
} else {
ForwardNode(lat, state, &state_alphas);
}
......@@ -134,12 +135,13 @@ BaseFloat LatticeForwardBackward(const Lattice &lat, Posterior *arc_post) {
void LatticeActivePhones(const Lattice &lat, const TransitionModel &trans,
const vector<int32> &sil_phones,
vector< map<int32, int32> > *active_phones) {
KALDI_ASSERT(IsSortedAndUniq(sil_phones));
const vector<int32> &silence_phones,
vector< std::set<int32> > *active_phones) {
KALDI_ASSERT(IsSortedAndUniq(silence_phones));
vector<int32> state_times;
int32 num_states = lat.NumStates();
int32 max_time = LatticeStateTimes(lat, &state_times);
active_phones->clear();
active_phones->resize(max_time);
for (int32 state = 0; state < num_states; ++state) {
int32 cur_time = state_times[state];
......@@ -148,13 +150,70 @@ void LatticeActivePhones(const Lattice &lat, const TransitionModel &trans,
const LatticeArc& arc = aiter.Value();
if (arc.ilabel != 0) { // Non-epsilon arc
int32 phone = trans.TransitionIdToPhone(arc.ilabel);
if (!std::binary_search(sil_phones.begin(), sil_phones.end(), phone))
(*active_phones)[cur_time][phone] = state;
if (!std::binary_search(silence_phones.begin(),
silence_phones.end(), phone))
(*active_phones)[cur_time].insert(phone);
}
} // end looping over arcs
} // end looping over states
}
bool LatticeBoost(const TransitionModel &trans,
const std::vector<std::set<int32> > &active_phones,
const std::vector<int32> &silence_phones,
BaseFloat b,
BaseFloat max_silence_error,
Lattice *lat) {
kaldi::uint64 props = lat->Properties(fst::kFstProperties, false);
if (!(props & fst::kTopSorted)) {
if (fst::TopSort(lat) == false) {
KALDI_WARN << "Cycles detected in lattice";
return false;
}
}
KALDI_ASSERT(IsSortedAndUniq(silence_phones));
KALDI_ASSERT(max_silence_error >= 0.0 && max_silence_error <= 1.0);
vector<int32> state_times;
int32 num_states = lat->NumStates();
LatticeStateTimes(*lat, &state_times);
for (int32 state = 0; state < num_states; ++state) {
int32 cur_time = state_times[state];
if (cur_time < 0 || cur_time > active_phones.size()) {
KALDI_WARN << "Lattice is too long for active_phones: mismatched den and num lattices/alignments?";
return false;
}
for (fst::MutableArcIterator<Lattice> aiter(lat, state); !aiter.Done();
aiter.Next()) {
LatticeArc arc = aiter.Value();
if (arc.ilabel != 0) { // Non-epsilon arc
if (arc.ilabel < 0 || arc.ilabel > trans.NumTransitionIds()) {
KALDI_WARN << "Lattice has out-of-range transition-ids: lattice/model mismatch?";
return false;
}
int32 phone = trans.TransitionIdToPhone(arc.ilabel);
BaseFloat frame_error;
if (active_phones[cur_time].count(phone) == 1) {
frame_error = 0.0;
} else { // an error...
if (std::binary_search(silence_phones.begin(), silence_phones.end(), phone))
frame_error = max_silence_error;
else
frame_error = 1.0;
}
BaseFloat delta_cost = -b * frame_error; // negative cost if
// frame is wrong, to boost likelihood of arcs with errors on them.
// Add this cost to the graph part.
arc.weight.SetValue1(arc.weight.Value1() + delta_cost);
aiter.SetValue(arc);
}
}
}
return true;
}
int32 LatticePhoneFrameAccuracy(const Lattice &hyp, const TransitionModel &trans,
const vector< map<int32, int32> > &ref_phones,
......@@ -277,9 +336,9 @@ void ForwardNode(const Lattice &lat, int32 state,
const LatticeArc& arc = aiter.Value();
double graph_score = arc.weight.Value1(),
am_score = arc.weight.Value2(),
arc_score = (*state_alphas)[state] - am_score - graph_score;
arc_loglike = (*state_alphas)[state] - am_score - graph_score;
(*state_alphas)[arc.nextstate] = LogAdd((*state_alphas)[arc.nextstate],
arc_score);
arc_loglike);
}
}
......@@ -299,10 +358,10 @@ void BackwardNode(const Lattice &lat, int32 state, int32 cur_time,
const LatticeArc& arc = aiter.Value();
if (arc.nextstate == state) {
KALDI_ASSERT(arc.ilabel == 0);
double arc_score = (*state_betas)[state] - arc.weight.Value1()
double arc_loglike = (*state_betas)[state] - arc.weight.Value1()
- arc.weight.Value2();
(*state_betas)[(*st_it)] = LogAdd((*state_betas)[(*st_it)],
arc_score);
arc_loglike);
}
}
}
......@@ -322,9 +381,9 @@ void BackwardNode(const Lattice &lat, int32 state, int32 cur_time,
KALDI_ASSERT(key != 0);
double graph_score = arc.weight.Value1(),
am_score = arc.weight.Value2(),
arc_score = (*state_betas)[state] - graph_score - am_score;
arc_loglike = (*state_betas)[state] - graph_score - am_score;
(*state_betas)[(*st_it)] = LogAdd((*state_betas)[(*st_it)],
arc_score);
arc_loglike);
double gamma = std::exp(state_alphas[(*st_it)] - graph_score - am_score
+ (*state_betas)[state] - tot_forward_prob);
if (post->find(key) == post->end()) // New label found at prev_time
......@@ -347,9 +406,9 @@ void ForwardNodeMpe(const Lattice &lat, const TransitionModel &tr,
const LatticeArc& arc = aiter.Value();
double graph_score = arc.weight.Value1(),
am_score = arc.weight.Value2(),
arc_score = (*state_alphas)[state].first - am_score - graph_score;
arc_loglike = (*state_alphas)[state].first - am_score - graph_score;
(*state_alphas)[arc.nextstate].first =
LogAdd((*state_alphas)[arc.nextstate].first, arc_score);
LogAdd((*state_alphas)[arc.nextstate].first, arc_loglike);
double frame_acc = 0.0;
if (arc.ilabel != 0) {
int32 phone = tr.TransitionIdToPhone(arc.ilabel);
......
......@@ -43,10 +43,32 @@ int32 LatticeStateTimes(const Lattice &lat, std::vector<int32> *times);
/// of the lattice.
BaseFloat LatticeForwardBackward(const Lattice &lat, Posterior *arc_post);
/// Given a lattice, and a transition model to map transition-ids to phones,
/// outputs for each frame the set of phones active on that frame. If
/// sil_phones (which must be sorted and uniq) is nonempty, it excludes
/// phones in this list.
void LatticeActivePhones(const Lattice &lat, const TransitionModel &trans,
const std::vector<int32> &sil_phones,
std::vector< std::map<int32, int32> > *active_phones);
std::vector<std::set<int32> > *active_phones);
/// Boosts LM probabilities by b * [#frame errors]; equivalently, adds
/// -b*[#frame errors] to the graph-component of the cost of each arc/path.
/// There is a frame error if a particular transition-id on a particular frame
/// corresponds to a phone not appearing in active_phones for that frame.
/// This is used in "margin-inspired" discriminative training, esp. Boosted MMI.
/// The TransitionModel is used to map transition-ids in the lattice
/// input-side to phones; the phones appearing in
/// "silence_phones" are treated specially in that we replace the frame error f
/// (either zero or 1) for a frame, with the minimum of f and max_silence_error.
/// For the normal recipe, max_silence_error would be zero; it must lie in the
/// range [0, 1].
/// The lattice is modified in place (and topologically sorted first if it is
/// not already marked as sorted).
/// Returns true on success, false if there was some kind of mismatch: cycles
/// in the lattice, a lattice longer than active_phones, or transition-ids out
/// of range for the model.  A warning is printed in each of these cases.
/// At input, silence_phones must be sorted and unique.
bool LatticeBoost(const TransitionModel &trans,
const std::vector<std::set<int32> > &active_phones,
const std::vector<int32> &silence_phones,
BaseFloat b,
BaseFloat max_silence_error,
Lattice *lat);
/// This function takes a reference lattice
int32 LatticePhoneFrameAccuracy(const Lattice &hyp, const TransitionModel &trans,
......
......@@ -7,7 +7,7 @@ include ../kaldi.mk
BINFILES = lattice-best-path lattice-prune lattice-equivalent lattice-nbest \
lattice-lmrescore lattice-scale lattice-union lattice-to-post \
lattice-determinize lattice-oracle string-to-lattice lattice-rmali \
lattice-compose
lattice-compose lattice-boost-ali lattice-copy
OBJFILES =
......
// latbin/lattice-boost-ali.cc
// Copyright 2009-2011 Microsoft Corporation
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "fstext/fstext-lib.h"
#include "lat/kaldi-lattice.h"
#include "lat/lattice-utils.h"
int main(int argc, char *argv[]) {
try {
typedef kaldi::int32 int32;
using fst::SymbolTable;
using fst::VectorFst;
using fst::StdArc;
const char *usage =
"Boost graph likelihoods (decrease graph costs) by b * #frame-phone-errors\n"
"on each arc in the lattice. Useful for discriminative training, e.g.\n"
"boosted MMI. Modifies input lattices. This version takes the reference\n"
"in the form of alignments. Needs the model (just the transitions) to\n"
"transform pdf-ids to phones. Takes the --silence-phones option and these\n"
"phones appearing in the lattice are always assigned zero error, or with the\n"
"--max-silence-error option, at most this error-count per frame\n"
"(--max-silence-error=1 is equivalent to not specifying --silence-phones).\n"
"\n"
"Usage: lattice-boost-ali [options] model lats-rspecifier ali-rspecifier lats-wspecifier\n"
" e.g.: lattice-boost-ali --silence-phones=1:2:3 --b=0.05 1.mdl ark:1.lats ark:1.ali ark:boosted.lats\n";
kaldi::BaseFloat b = 0.05;
kaldi::BaseFloat max_silence_error = 0.0;
std::string silence_phones_str;
kaldi::ParseOptions po(usage);
po.Register("b", &b,
"Boosting factor (more -> more boosting of errors / larger margin)");
po.Register("max-silence", &max_silence_error,
"Maximum error assigned to silence phones [c.f. --silence-phones option]."
"0.0 -> original BMMI paper, 1.0 -> no special silence treatment.");
po.Register("silence-phones", &silence_phones_str,
"Colon-separated list of integer id's of silence phones, e.g. 46:47");
po.Read(argc, argv);
if (po.NumArgs() != 4) {
po.PrintUsage();
exit(1);
}
std::vector<int32> silence_phones;
if (!kaldi::SplitStringToIntegers(silence_phones_str, ":", false, &silence_phones))
KALDI_EXIT << "Invalid silence-phones string " << silence_phones_str;
kaldi::SortAndUniq(&silence_phones);
if (silence_phones.empty())
KALDI_WARN <<"No silence phones specified, make sure this is what you intended.";
std::string model_rxfilename = po.GetArg(1),
lats_rspecifier = po.GetArg(2),
ali_rspecifier = po.GetArg(3),
lats_wspecifier = po.GetArg(4);
// Read as regular lattice and write as compact.
kaldi::SequentialLatticeReader lattice_reader(lats_rspecifier);
kaldi::RandomAccessInt32VectorReader alignment_reader(ali_rspecifier);
kaldi::CompactLatticeWriter compact_lattice_writer(lats_wspecifier);
kaldi::TransitionModel trans;
{
bool binary_in;
kaldi::Input ki(model_rxfilename, &binary_in);
trans.Read(ki.Stream(), binary_in);
}
int32 n_done = 0, n_err = 0, n_no_ali = 0;
for (; !lattice_reader.Done(); lattice_reader.Next()) {
std::string key = lattice_reader.Key();
kaldi::Lattice lat = lattice_reader.Value();
lattice_reader.FreeCurrent();
if (b != 0.0) {
if (!alignment_reader.HasKey(key)) {
KALDI_WARN << "No alignment for utterance " << key;
n_no_ali++;
continue;
}
const std::vector<int32> &alignment = alignment_reader.Value(key);
std::vector<std::set<int32> > phones_seen(alignment.size());
for (size_t i = 0; i < alignment.size(); i++) {
// next line may crash if alignments mismatched w/ model:
phones_seen[i].insert(trans.TransitionIdToPhone(alignment[i]));
}
if (!LatticeBoost(trans, phones_seen, silence_phones, b,
max_silence_error, &lat)) {
n_err++; // will already have printed warning.
continue;
}
}
kaldi::CompactLattice clat;
ConvertLattice(lat, &clat);
compact_lattice_writer.Write(key, clat);
n_done++;
}
KALDI_LOG << "Done " << n_done << " lattices, missing alignments for "
<< n_no_ali << ", other errors on " << n_err;
return (n_done != 0 ? 0 : 1);
} catch(const std::exception& e) {
std::cerr << e.what();
return -1;
}
}
// latbin/lattice-copy.cc
// Copyright 2009-2011 Microsoft Corporation
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "fstext/fstext-lib.h"
#include "lat/kaldi-lattice.h"
int main(int argc, char *argv[]) {
try {
using namespace kaldi;
typedef kaldi::int32 int32;
typedef kaldi::int64 int64;
using fst::SymbolTable;
using fst::VectorFst;
using fst::StdArc;
const char *usage =
"Copy lattices (e.g. useful for changing to text mode or changing\n"
"format to standard from compact lattice.)\n"
"Usage: lattice-copy [options] lattice-rspecifier lattice-wspecifier\n"
" e.g.: lattice-copy --write-compact=false ark:1.lats ark,t:text.lats\n";
ParseOptions po(usage);
bool write_compact = true;
po.Register("write-compact", &write_compact, "If true, write in normal (compact) form.");
po.Read(argc, argv);
if (po.NumArgs() != 2) {
po.PrintUsage();
exit(1);
}
std::string lats_rspecifier = po.GetArg(1),
lats_wspecifier = po.GetArg(2);
int32 n_done = 0;
if (write_compact) {
SequentialCompactLatticeReader lattice_reader(lats_rspecifier);
CompactLatticeWriter lattice_writer(lats_wspecifier);
for (; !lattice_reader.Done(); lattice_reader.Next(), n_done++)
lattice_writer.Write(lattice_reader.Key(), lattice_reader.Value());
} else {
SequentialLatticeReader lattice_reader(lats_rspecifier);
LatticeWriter lattice_writer(lats_wspecifier);
for (; !lattice_reader.Done(); lattice_reader.Next(), n_done++)
lattice_writer.Write(lattice_reader.Key(), lattice_reader.Value());
}
KALDI_LOG << "Done copying " << n_done << " lattices.";
return (n_done != 0 ? 0 : 1);
} catch(const std::exception& e) {
std::cerr << e.what();
return -1;
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment