Commit 78a8fda0 authored by Jan Trmal's avatar Jan Trmal
Browse files

(trunk) Create a separate kws library. If you are using lat/kaldi-kws.h,...

(trunk) Create a separate kws library. If you are using lat/kaldi-kws.h, change it to kws/kaldi-kws.h.

git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@5176 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
parent 8bd0ffec
......@@ -6,7 +6,7 @@ SHELL := /bin/bash
SUBDIRS = base matrix util feat tree thread gmm transform sgmm \
fstext hmm lm decoder lat cudamatrix nnet \
fstext hmm lm decoder lat kws cudamatrix nnet \
bin fstbin gmmbin fgmmbin sgmmbin featbin \
nnetbin latbin sgmm2 sgmm2bin nnet2 nnet2bin kwsbin \
ivector ivectorbin online2 online2bin lmbin
......@@ -176,4 +176,5 @@ online2bin: base matrix util feat tree optimization gmm transform sgmm sgmm2 fst
# python-kaldi-decoding: base matrix util feat tree optimization thread gmm transform sgmm sgmm2 fstext hmm decoder lat online
online: decoder gmm transform feat matrix util base lat hmm thread tree
online2: decoder gmm transform feat matrix util base lat hmm thread ivector cudamatrix nnet2
kwsbin: fstext lat base util hmm tree matrix
kws: base util hmm tree matrix lat
kwsbin: fstext kws lat base util hmm tree matrix
all:
include ../kaldi.mk
EXTRA_CXXFLAGS += -Wno-sign-compare
OBJFILES = kws-functions.o
LIBNAME = kaldi-kws
ADDLIBS = ../hmm/kaldi-hmm.a ../lat/kaldi-lat.a ../tree/kaldi-tree.a \
../matrix/kaldi-matrix.a ../util/kaldi-util.a ../base/kaldi-base.a
include ../makefiles/default_rules.mk
......@@ -18,10 +18,10 @@
// limitations under the License.
#include "lat/kws-functions.h"
#include "lat/lattice-functions.h"
#include "kws/kws-functions.h"
#include "fstext/determinize-star.h"
#include "fstext/epsilon-property.h"
namespace kaldi {
bool CompareInterval(const Interval &i1,
......@@ -106,83 +106,6 @@ bool ClusterLattice(CompactLattice *clat,
return true;
}
bool ComputeCompactLatticeAlphas(const CompactLattice &clat,
vector<double> *alpha) {
using namespace fst;
// typedef the arc, weight types
typedef CompactLattice::Arc Arc;
typedef Arc::Weight Weight;
typedef Arc::StateId StateId;
//Make sure the lattice is topologically sorted.
if (clat.Properties(fst::kTopSorted, true) == 0) {
KALDI_WARN << "Input lattice must be topologically sorted.";
return false;
}
if (clat.Start() != 0) {
KALDI_WARN << "Input lattice must start from state 0.";
return false;
}
int32 num_states = clat.NumStates();
(*alpha).resize(0);
(*alpha).resize(num_states, kLogZeroDouble);
// Now propagate alphas forward. Note that we don't acount the weight of the
// final state to alpha[final_state] -- we acount it to beta[final_state];
(*alpha)[0] = 0.0;
for (StateId s = 0; s < num_states; s++) {
double this_alpha = (*alpha)[s];
for (ArcIterator<CompactLattice> aiter(clat, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
double arc_like = -(arc.weight.Weight().Value1() + arc.weight.Weight().Value2());
(*alpha)[arc.nextstate] = LogAdd((*alpha)[arc.nextstate], this_alpha + arc_like);
}
}
return true;
}
bool ComputeCompactLatticeBetas(const CompactLattice &clat,
vector<double> *beta) {
using namespace fst;
// typedef the arc, weight types
typedef CompactLattice::Arc Arc;
typedef Arc::Weight Weight;
typedef Arc::StateId StateId;
// Make sure the lattice is topologically sorted.
if (clat.Properties(fst::kTopSorted, true) == 0) {
KALDI_WARN << "Input lattice must be topologically sorted.";
return false;
}
if (clat.Start() != 0) {
KALDI_WARN << "Input lattice must start from state 0.";
return false;
}
int32 num_states = clat.NumStates();
(*beta).resize(0);
(*beta).resize(num_states, kLogZeroDouble);
// Now propagate betas backward. Note that beta[final_state] contains the
// weight of the final state in the lattice -- compare that with alpha.
for (StateId s = num_states-1; s >= 0; s--) {
Weight f = clat.Final(s);
double this_beta = -(f.Weight().Value1()+f.Weight().Value2());
for (ArcIterator<CompactLattice> aiter(clat, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
double arc_like = -(arc.weight.Weight().Value1()+arc.weight.Weight().Value2());
double arc_beta = (*beta)[arc.nextstate] + arc_like;
this_beta = LogAdd(this_beta, arc_beta);
}
(*beta)[s] = this_beta;
}
return true;
}
class CompactLatticeToKwsProductFstMapper {
public:
......
......@@ -22,7 +22,7 @@
#define KALDI_LAT_KWS_FUNCTIONS_H_
#include "lat/kaldi-lattice.h"
#include "lat/kaldi-kws.h"
#include "kws/kaldi-kws.h"
namespace kaldi {
......@@ -64,17 +64,6 @@ bool CompareInterval(const Interval &i1,
bool ClusterLattice(CompactLattice *clat,
const vector<int32> &state_times);
// This function is something similar to LatticeForwardBackward(), but it is on
// the CompactLattice lattice format. Also we only need the alpha in the forward
// path, not the posteriors.
bool ComputeCompactLatticeAlphas(const CompactLattice &lat,
vector<double> *alpha);
// A sibling of the function CompactLatticeAlphas()... We compute the beta from
// the backward path here.
bool ComputeCompactLatticeBetas(const CompactLattice &lat,
vector<double> *beta);
// This function contains two steps: weight pushing and factor generation. The
// original ShortestDistance() is not very efficient, so we do the weight
// pushing and shortest path manually by computing the alphas and betas. The
......
......@@ -14,7 +14,7 @@ OBJFILES =
TESTFILES =
ADDLIBS = ../lat/kaldi-lat.a ../fstext/kaldi-fstext.a \
ADDLIBS = ../kws/kaldi-kws.a ../lat/kaldi-lat.a ../fstext/kaldi-fstext.a \
../hmm/kaldi-hmm.a ../tree/kaldi-tree.a ../matrix/kaldi-matrix.a \
../util/kaldi-util.a ../base/kaldi-base.a
......
......@@ -21,8 +21,8 @@
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "fstext/fstext-utils.h"
#include "lat/kaldi-kws.h"
#include "lat/kws-functions.h"
#include "kws/kaldi-kws.h"
#include "kws/kws-functions.h"
int main(int argc, char *argv[]) {
try {
......
......@@ -21,7 +21,7 @@
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "fstext/fstext-utils.h"
#include "lat/kaldi-kws.h"
#include "kws/kaldi-kws.h"
namespace kaldi {
......
......@@ -24,8 +24,8 @@
#include "fstext/fstext-utils.h"
#include "lat/kaldi-lattice.h"
#include "lat/lattice-functions.h"
#include "lat/kaldi-kws.h"
#include "lat/kws-functions.h"
#include "kws/kaldi-kws.h"
#include "kws/kws-functions.h"
#include "fstext/epsilon-property.h"
int main(int argc, char *argv[]) {
......
......@@ -10,8 +10,8 @@ TESTFILES = kaldi-lattice-test push-lattice-test minimize-lattice-test \
OBJFILES = kaldi-lattice.o lattice-functions.o word-align-lattice.o \
phone-align-lattice.o word-align-lattice-lexicon.o sausages.o \
kws-functions.o push-lattice.o minimize-lattice.o \
determinize-lattice-pruned.o confidence.o
push-lattice.o minimize-lattice.o determinize-lattice-pruned.o \
confidence.o
LIBNAME = kaldi-lat
......@@ -20,3 +20,28 @@ ADDLIBS = ../hmm/kaldi-hmm.a ../tree/kaldi-tree.a ../matrix/kaldi-matrix.a \
include ../makefiles/default_rules.mk
# Overriding the default library rule
# Added 2015-06-22 in connection with creating a standalone kws lib
# It's purpose is to make the transition more seamless for users
# Will be removed in a half a year or so.
$(LIBFILE): $(OBJFILES)
$(AR) -d $(LIBNAME).a kws-functions.o
$(AR) -cru $(LIBNAME).a $(OBJFILES)
$(RANLIB) $(LIBNAME).a
ifeq ($(KALDI_FLAVOR), dynamic)
ifeq ($(shell uname), Darwin)
$(CXX) -dynamiclib -o $@ -install_name @rpath/$@ -framework Accelerate $(LDFLAGS) $(XLDLIBS) $(OBJFILES) $(LDLIBS)
rm -f $(KALDILIBDIR)/$@; ln -s $(shell pwd)/$@ $(KALDILIBDIR)/$@
else
ifeq ($(shell uname), Linux)
# Building shared library from static (static was compiled with -fPIC)
$(CXX) -shared -o $@ -Wl,--no-undefined -Wl,--as-needed -Wl,-soname=$@,--whole-archive $(LIBNAME).a -Wl,--no-whole-archive $(LDFLAGS) $(XDEPENDS) $(LDLIBS)
rm -f $(KALDILIBDIR)/$@; ln -s $(shell pwd)/$@ $(KALDILIBDIR)/$@
#cp $@ $(KALDILIBDIR)
else # Platform not supported
$(error Dynamic libraries not supported on this platform. Run configure with --static flag. )
endif
endif
endif
......@@ -41,7 +41,7 @@ int32 LatticeStateTimes(const Lattice &lat, vector<int32> *times) {
times->clear();
times->resize(num_states, -1);
(*times)[0] = 0;
for (int32 state = 0; state < num_states; state++) {
for (int32 state = 0; state < num_states; state++) {
int32 cur_time = (*times)[state];
for (fst::ArcIterator<Lattice> aiter(lat, state); !aiter.Done();
aiter.Next()) {
......@@ -96,7 +96,7 @@ int32 CompactLatticeStateTimes(const CompactLattice &lat, vector<int32> *times)
utt_len = std::max(utt_len, this_utt_len);
}
}
}
}
}
if (utt_len == -1) {
KALDI_WARN << "Utterance does not have a final-state.";
......@@ -105,12 +105,90 @@ int32 CompactLatticeStateTimes(const CompactLattice &lat, vector<int32> *times)
return utt_len;
}
template<class LatType> // could be Lattice or CompactLattice
bool ComputeCompactLatticeAlphas(const CompactLattice &clat,
vector<double> *alpha) {
using namespace fst;
// typedef the arc, weight types
typedef CompactLattice::Arc Arc;
typedef Arc::Weight Weight;
typedef Arc::StateId StateId;
//Make sure the lattice is topologically sorted.
if (clat.Properties(fst::kTopSorted, true) == 0) {
KALDI_WARN << "Input lattice must be topologically sorted.";
return false;
}
if (clat.Start() != 0) {
KALDI_WARN << "Input lattice must start from state 0.";
return false;
}
int32 num_states = clat.NumStates();
(*alpha).resize(0);
(*alpha).resize(num_states, kLogZeroDouble);
// Now propagate alphas forward. Note that we don't acount the weight of the
// final state to alpha[final_state] -- we acount it to beta[final_state];
(*alpha)[0] = 0.0;
for (StateId s = 0; s < num_states; s++) {
double this_alpha = (*alpha)[s];
for (ArcIterator<CompactLattice> aiter(clat, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
double arc_like = -(arc.weight.Weight().Value1() + arc.weight.Weight().Value2());
(*alpha)[arc.nextstate] = LogAdd((*alpha)[arc.nextstate], this_alpha + arc_like);
}
}
return true;
}
bool ComputeCompactLatticeBetas(const CompactLattice &clat,
vector<double> *beta) {
using namespace fst;
// typedef the arc, weight types
typedef CompactLattice::Arc Arc;
typedef Arc::Weight Weight;
typedef Arc::StateId StateId;
// Make sure the lattice is topologically sorted.
if (clat.Properties(fst::kTopSorted, true) == 0) {
KALDI_WARN << "Input lattice must be topologically sorted.";
return false;
}
if (clat.Start() != 0) {
KALDI_WARN << "Input lattice must start from state 0.";
return false;
}
int32 num_states = clat.NumStates();
(*beta).resize(0);
(*beta).resize(num_states, kLogZeroDouble);
// Now propagate betas backward. Note that beta[final_state] contains the
// weight of the final state in the lattice -- compare that with alpha.
for (StateId s = num_states-1; s >= 0; s--) {
Weight f = clat.Final(s);
double this_beta = -(f.Weight().Value1()+f.Weight().Value2());
for (ArcIterator<CompactLattice> aiter(clat, s); !aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
double arc_like = -(arc.weight.Weight().Value1()+arc.weight.Weight().Value2());
double arc_beta = (*beta)[arc.nextstate] + arc_like;
this_beta = LogAdd(this_beta, arc_beta);
}
(*beta)[s] = this_beta;
}
return true;
}
template<class LatType> // could be Lattice or CompactLattice
bool PruneLattice(BaseFloat beam, LatType *lat) {
typedef typename LatType::Arc Arc;
typedef typename Arc::Weight Weight;
typedef typename Arc::StateId StateId;
KALDI_ASSERT(beam > 0.0);
if (!lat->Properties(fst::kTopSorted, true)) {
if (fst::TopSort(lat) == false) {
......@@ -124,7 +202,7 @@ bool PruneLattice(BaseFloat beam, LatType *lat) {
int32 num_states = lat->NumStates();
if (num_states == 0) return false;
std::vector<double> forward_cost(num_states,
std::numeric_limits<double>::infinity()); // viterbi forward.
std::numeric_limits<double>::infinity()); // viterbi forward.
forward_cost[start] = 0.0; // lattice can't have cycles so couldn't be
// less than this.
double best_final_cost = std::numeric_limits<double>::infinity();
......@@ -151,7 +229,7 @@ bool PruneLattice(BaseFloat beam, LatType *lat) {
}
int32 bad_state = lat->AddState(); // this state is not final.
double cutoff = best_final_cost + beam;
// Go backwards updating the backward probs (which share memory with the
// forward probs), and pruning arcs and deleting final-probs. We prune arcs
// by making them point to the non-final state "bad_state". We'll then use
......@@ -200,7 +278,7 @@ BaseFloat LatticeForwardBackward(const Lattice &lat, Posterior *post,
typedef Lattice::Arc Arc;
typedef Arc::Weight Weight;
typedef Arc::StateId StateId;
if (acoustic_like_sum) *acoustic_like_sum = 0.0;
// Make sure the lattice is topologically sorted.
......@@ -218,7 +296,7 @@ BaseFloat LatticeForwardBackward(const Lattice &lat, Posterior *post,
post->clear();
post->resize(max_time);
alpha[0] = 0.0;
// Propagate alphas forward.
for (StateId s = 0; s < num_states; s++) {
......@@ -274,7 +352,7 @@ BaseFloat LatticeForwardBackward(const Lattice &lat, Posterior *post,
MergePairVectorSumming(&((*post)[t]));
return tot_backward_prob;
}
void LatticeActivePhones(const Lattice &lat, const TransitionModel &trans,
const vector<int32> &silence_phones,
......@@ -411,7 +489,7 @@ void CompactLatticeLimitDepth(int32 max_depth_per_frame,
if (!TopSort(clat))
KALDI_ERR << "Topological sorting of lattice failed.";
}
vector<int32> state_times;
int32 T = CompactLatticeStateTimes(*clat, &state_times);
......@@ -611,7 +689,7 @@ bool LatticeBoost(const TransitionModel &trans,
// get all stored properties (test==false means don't test if not known).
uint64 props = lat->Properties(fst::kFstProperties,
false);
KALDI_ASSERT(IsSortedAndUniq(silence_phones));
KALDI_ASSERT(max_silence_error >= 0.0 && max_silence_error <= 1.0);
vector<int32> state_times;
......@@ -642,7 +720,7 @@ bool LatticeBoost(const TransitionModel &trans,
}
BaseFloat delta_cost = -b * frame_error; // negative cost if
// frame is wrong, to boost likelihood of arcs with errors on them.
// Add this cost to the graph part.
// Add this cost to the graph part.
arc.weight.SetValue1(arc.weight.Value1() + delta_cost);
aiter.SetValue(arc);
}
......@@ -653,7 +731,7 @@ bool LatticeBoost(const TransitionModel &trans,
// lattice was weighted.
lat->SetProperties(props,
~(fst::kWeighted|fst::kUnweighted));
return true;
}
......@@ -674,11 +752,11 @@ BaseFloat LatticeForwardBackwardMpeVariants(
KALDI_ASSERT(criterion == "mpfe" || criterion == "smbr");
bool is_mpfe = (criterion == "mpfe");
if (lat.Properties(fst::kTopSorted, true) == 0)
KALDI_ERR << "Input lattice must be topologically sorted.";
KALDI_ASSERT(lat.Start() == 0);
int32 num_states = lat.NumStates();
vector<int32> state_times;
int32 max_time = LatticeStateTimes(lat, &state_times);
......@@ -953,7 +1031,7 @@ bool CompactLatticeToWordProns(
}
prons->push_back(phones);
phone_lengths->push_back(plengths);
cur_time += length;
state = arc.nextstate;
}
......@@ -1047,7 +1125,7 @@ void CompactLatticeShortestPath(const CompactLattice &clat,
}
}
void AddWordInsPenToCompactLattice(BaseFloat word_ins_penalty,
void AddWordInsPenToCompactLattice(BaseFloat word_ins_penalty,
CompactLattice *clat) {
typedef CompactLatticeArc Arc;
int32 num_states = clat->NumStates();
......@@ -1056,19 +1134,19 @@ void AddWordInsPenToCompactLattice(BaseFloat word_ins_penalty,
for (int32 state = 0; state < num_states; state++) {
for (fst::MutableArcIterator<CompactLattice> aiter(clat, state);
!aiter.Done(); aiter.Next()) {
Arc arc(aiter.Value());
if (arc.ilabel != 0) { // if there is a word on this arc
LatticeWeight weight = arc.weight.Weight();
// add word insertion penalty to lattice
weight.SetValue1( weight.Value1() + word_ins_penalty);
weight.SetValue1( weight.Value1() + word_ins_penalty);
arc.weight.SetWeight(weight);
aiter.SetValue(arc);
}
}
} // end looping over arcs
} // end looping over states
}
} // end looping over states
}
struct ClatRescoreTuple {
ClatRescoreTuple(int32 state, int32 arc, int32 tid):
......@@ -1100,7 +1178,7 @@ bool RescoreCompactLatticeInternal(
}
std::vector<int32> state_times;
int32 utt_len = kaldi::CompactLatticeStateTimes(*clat, &state_times);
std::vector<std::vector<ClatRescoreTuple> > time_to_state(utt_len);
int32 num_states = clat->NumStates();
......@@ -1113,7 +1191,7 @@ bool RescoreCompactLatticeInternal(
!aiter.Done(); aiter.Next(), arc_id++) {
CompactLatticeArc arc = aiter.Value();
std::vector<int32> arc_string = arc.weight.String();
for (size_t offset = 0; offset < arc_string.size(); offset++) {
if (t < utt_len) { // end state may be past this..
int32 tid = arc_string[offset];
......@@ -1151,7 +1229,7 @@ bool RescoreCompactLatticeInternal(
// For frames with only one pdf-id, it will equal speedup_factor (>=1.0)
// with probability 1.0 / speedup_factor, and zero otherwise. If it is zero,
// we can avoid computing the probabilities.
BaseFloat frame_scale = 1.0;
BaseFloat frame_scale = 1.0;
KALDI_ASSERT(!time_to_state[t].empty());
if (tmodel != NULL) {
int32 pdf_id = tmodel->TransitionIdToPdf(time_to_state[t][0].tid);
......@@ -1164,7 +1242,7 @@ bool RescoreCompactLatticeInternal(
}
if (frame_has_multiple_pdfs) {
frame_scale = 1.0;
} else {
} else {
if (WithProb(1.0 / speedup_factor)) {
frame_scale = speedup_factor;
} else {
......@@ -1174,16 +1252,16 @@ bool RescoreCompactLatticeInternal(
if (frame_scale == 0.0)
continue; // the code below would be pointless.
}
for (size_t i = 0; i < time_to_state[t].size(); i++) {
int32 state = time_to_state[t][i].state_id;
int32 arc_id = time_to_state[t][i].arc_id;
int32 tid = time_to_state[t][i].tid;
if (arc_id == -1) { // Final state
// Access the trans_id
CompactLatticeWeight curr_clat_weight = clat->Final(state);
// Calculate likelihood
BaseFloat log_like = decodable->LogLikelihood(t, tid) * frame_scale;
// update weight
......@@ -1232,7 +1310,7 @@ bool RescoreLattice(DecodableInterface *decodable,
KALDI_WARN << "Rescoring empty lattice";
return false;
}
if (!lat->Properties(fst::kTopSorted, true)) {
if (!lat->Properties(fst::kTopSorted, true)) {
if (fst::TopSort(lat) == false) {
KALDI_WARN << "Cycles detected in lattice.";
return false;
......@@ -1240,15 +1318,15 @@ bool RescoreLattice(DecodableInterface *decodable,
}
std::vector<int32> state_times;
int32 utt_len = kaldi::LatticeStateTimes(*lat, &state_times);
std::vector<std::vector<int32> > time_to_state(utt_len );
int32 num_states = lat->NumStates();
KALDI_ASSERT(num_states == state_times.size());
for (size_t state = 0; state < num_states; state++) {
int32 t = state_times[state];
// Don't check t >= 0 because non-accessible states could have t = -1.
KALDI_ASSERT(t <= utt_len);
KALDI_ASSERT(t <= utt_len);
if (t >= 0 && t < utt_len)
time_to_state[t].push_back(state);
}
......@@ -1294,14 +1372,14 @@ BaseFloat LatticeForwardBackwardMmi(
BaseFloat ans = LatticeForwardBackward(lat,
&den_post,
NULL);
Posterior num_post;
AlignmentToPosterior(num_ali, &num_post);
// Now negate the MMI posteriors and add the numerator
// posteriors.
ScalePosterior(-1.0, &den_post);
if (convert_to_pdf_ids) {
Posterior num_tmp;
ConvertPosteriorToPdfs(tmodel, num_post, &num_tmp);
......@@ -1310,10 +1388,10 @@ BaseFloat LatticeForwardBackwardMmi(
ConvertPosteriorToPdfs(tmodel, den_post, &den_tmp);
den_tmp.swap(den_post);
}
MergePosteriors(num_post, den_post,
cancel, drop_frames, post);
return ans;
}
......@@ -1322,7 +1400,7 @@ int32 LongestSentenceLength(const Lattice &lat) {
typedef Lattice::Arc Arc;
typedef Arc::Label Label;
typedef Arc::StateId StateId;
if (lat.Properties(fst::kTopSorted, true) == 0) {
Lattice lat_copy(lat);
if (!TopSort(&lat_copy))
......@@ -1359,7 +1437,7 @@ int32 LongestSentenceLength(const CompactLattice &clat) {
typedef CompactLattice::Arc Arc;
typedef Arc::Label Label;
typedef Arc::StateId StateId;
if (clat.Properties(fst::kTopSorted, true) == 0) {
CompactLattice clat_copy(clat);
if (!TopSort(&clat_copy))
......
......@@ -63,6 +63,17 @@ BaseFloat LatticeForwardBackward(const Lattice &lat,
Posterior *arc_post,
double *acoustic_like_sum = NULL);