Commit b47a690a authored by Arnab Ghoshal's avatar Arnab Ghoshal
Browse files

Merging trunk/src r435 changes to sandbox/discrim/src.

git-svn-id: https://svn.code.sf.net/p/kaldi/code/sandbox/discrim@436 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
parent 14af3a92
......@@ -7,7 +7,7 @@
SUBDIRS = base matrix util feat tree optimization gmm tied transform sgmm \
fstext hmm lm decoder lat \
bin fstbin gmmbin fgmmbin tiedbin sgmmbin featbin \
nnet nnetbin latbin
nnet nnetbin latbin rnn
all: $(SUBDIRS)
echo Done
......
......@@ -7,6 +7,7 @@
=====
minor:
remove KALDI_EXIT and modify KALDI_ERR macro to behave like it.
Make ApplySoftMax() function more efficient (does two exp()'s per element)
gmm-latgen-faster [Not started]. Need to write internal code for this (Mirko, but possibly wait a bit till I finalize lattice-simple-decoder.). As gmm-latgen-simple but not using std::unordered_map for current + previous frame's tokens-- would use the HashList stuff.
......
......@@ -100,6 +100,18 @@ void UnitTestRand() {
if (std::abs((float)sum) < 0.5*sqrt((double)j)*(maxint-minint)) break;
}
}
{ // test RandPrune in basic way.
KALDI_ASSERT(RandPrune(1.1, 1.0) == 1.1);
KALDI_ASSERT(RandPrune(0.0, 0.0) == 0.0);
KALDI_ASSERT(RandPrune(-1.1, 1.0) == -1.1);
KALDI_ASSERT(RandPrune(0.0, 1.0) == 0.0);
KALDI_ASSERT(RandPrune(0.5, 1.0) >= 0.0);
KALDI_ASSERT(RandPrune(-0.5, 1.0) <= 0.0);
BaseFloat f = RandPrune(-0.5, 1.0);
KALDI_ASSERT(f == 0.0 || f == -1.0);
f = RandPrune(0.5, 1.0);
KALDI_ASSERT(f == 0.0 || f == 1.0);
}
}
}
......
......@@ -33,20 +33,20 @@ int32 RoundUpToNearestPowerOfTwo(int32 n) {
return n+1;
}
int32 RandInt(int32 min, int32 max) { // This is not exact.
assert(max >= min);
if (max == min) return min;
int32 RandInt(int32 min_val, int32 max_val) { // This is not exact.
assert(max_val >= min_val);
if (max_val == min_val) return min_val;
#ifdef _MSC_VER
// RAND_MAX is quite small on Windows -> may need to handle larger numbers.
if (RAND_MAX > (max-min)*8) {
if (RAND_MAX > (max_val-min_val)*8) {
// *8 to avoid large inaccuracies in probability, from the modulus...
return min + ((unsigned int)rand() % (unsigned int)(max+1-min));
return min_val + ((unsigned int)rand() % (unsigned int)(max_val+1-min_val));
} else {
if ((unsigned int)(RAND_MAX*RAND_MAX) > (unsigned int)((max+1-min)*8)) {
if ((unsigned int)(RAND_MAX*RAND_MAX) > (unsigned int)((max_val+1-min_val)*8)) {
// *8 to avoid inaccuracies in probability, from the modulus...
return min + ( (unsigned int)( (rand()+RAND_MAX*rand()))
% (unsigned int)(max+1-min));
return min_val + ( (unsigned int)( (rand()+RAND_MAX*rand()))
% (unsigned int)(max_val+1-min_val));
} else {
throw std::runtime_error(std::string()
+"rand_int failed because we do not support "
......@@ -55,7 +55,8 @@ int32 RandInt(int32 min, int32 max) { // This is not exact.
}
}
#else
return min + ((unsigned int32)rand() % (unsigned int32)(max+1-min));
return min_val +
(static_cast<int32>(rand()) % (int32)(max_val+1-min_val));
#endif
}
......
......@@ -55,7 +55,13 @@
# define M_SQRT1_2 0.7071067811865475244008443621048490
#endif
#ifndef M_LOG_2PI
#define M_LOG_2PI 1.8378770664093454835606594728112
#endif
#ifndef M_LN2
#define M_LN2 0.693147180559945309417232121458
#endif
#ifdef _MSC_VER
# define KALDI_ISNAN _isnan
......@@ -98,6 +104,17 @@ inline float RandGauss() {
// to lambda. Faster algorithms exist but are more complex.
int32 RandPoisson(float lambda);
// This is a randomized pruning mechanism that preserves expectations,
// that we typically use to prune posteriors.
template<class Float>
inline Float RandPrune(Float post, BaseFloat prune_thresh) {
KALDI_ASSERT(prune_thresh >= 0.0);
if (post == 0.0 || std::abs(post) >= prune_thresh)
return post;
return (post >= 0 ? 1.0 : -1.0) *
(RandUniform() <= fabs(post)/prune_thresh ? prune_thresh : 0.0);
}
static const double kMinLogDiffDouble = std::log(DBL_EPSILON); // negative!
static const float kMinLogDiffFloat = std::log(FLT_EPSILON); // negative!
......
......@@ -11,7 +11,7 @@ BINFILES = align-equal align-equal-compiled acc-tree-stats \
ali-to-phones ali-to-post weight-silence-post acc-lda est-lda \
ali-to-pdf est-mllt build-tree build-tree-two-level decode-faster \
decode-faster-mapped scale-vecs copy-transition-model rand-prune-post
phones-to-prons prons-to-wordali
OBJFILES =
......
......@@ -36,8 +36,10 @@ int main(int argc, char *argv[]) {
" ali-to-post ark:1.ali ark:- | lda-acc 1.mdl \"ark:splice-feats scp:train.scp|\" ark:- ldaacc.1\n";
bool binary = true;
BaseFloat rand_prune = 0.0;
ParseOptions po(usage);
po.Register("binary", &binary, "Write accumulators in binary mode.");
po.Register("rand-prune", &rand_prune, "Randomized pruning threshold for posteriors");
po.Read(argc, argv);
if (po.NumArgs() != 4) {
......@@ -94,9 +96,11 @@ int main(int argc, char *argv[]) {
SubVector<BaseFloat> feat(feats, i);
for (size_t j = 0; j < post[i].size(); j++) {
int32 tid = post[i][j].first;
BaseFloat weight = post[i][j].second;
int32 pdf = trans_model.TransitionIdToPdf(tid);
lda.Accumulate(feat, pdf, weight);
BaseFloat weight = RandPrune(post[i][j].second, rand_prune);
if (weight != 0.0) {
int32 pdf = trans_model.TransitionIdToPdf(tid);
lda.Accumulate(feat, pdf, weight);
}
}
}
num_done++;
......
......@@ -37,7 +37,7 @@ int main(int argc, char *argv[]) {
"Add self-loops and transition probabilities to transducer, expanding to transition-ids\n"
"Usage: add-self-loops [options] transition-gmm/acoustic-model [fst-in] [fst-out]\n"
"e.g.: \n"
" add-self-loops --self-loop-scale = 0.1 1.mdl < HCLG_noloops.fst > HCLG_full.fst\n";
" add-self-loops --self-loop-scale=0.1 1.mdl < HCLG_noloops.fst > HCLG_full.fst\n";
BaseFloat self_loop_scale = 1.0;
bool reorder = true;
......
......@@ -33,9 +33,14 @@ int main(int argc, char *argv[]) {
" ali-to-phones 1.mdl ark:1.ali ark:phones.tra\n";
ParseOptions po(usage);
bool per_frame = false;
bool write_lengths = false;
po.Register("per-frame", &per_frame, "If true, write out the frame-level phone alignment (else phone sequence)");
po.Register("write-lengths", &write_lengths, "If true, write the #frames for each phone (different format)");
po.Read(argc, argv);
KALDI_ASSERT(!(per_frame && write_lengths) && "Incompatible options.");
if (po.NumArgs() != 3) {
po.PrintUsage();
exit(1);
......@@ -52,10 +57,13 @@ int main(int argc, char *argv[]) {
trans_model.Read(is.Stream(), binary);
}
SequentialInt32VectorReader reader(alignments_rspecifier);
Int32VectorWriter writer(phones_wspecifier);
std::string empty;
Int32VectorWriter phones_writer(write_lengths ? empty : phones_wspecifier);
Int32PairVectorWriter pair_writer(write_lengths ? phones_wspecifier : empty);
int32 n_done = 0;
for (; !reader.Done(); reader.Next()) {
std::string key = reader.Key();
const std::vector<int32> &alignment = reader.Value();
......@@ -63,19 +71,38 @@ int main(int argc, char *argv[]) {
std::vector<std::vector<int32> > split;
SplitToPhones(trans_model, alignment, &split);
std::vector<int32> phones;
for (size_t i = 0; i < split.size(); i++) {
KALDI_ASSERT(split[i].size() > 0);
int32 tid = split[i][0],
tstate = trans_model.TransitionIdToTransitionState(tid),
phone = trans_model.TransitionStateToPhone(tstate);
int32 num_repeats = (per_frame ?
static_cast<int32>(split[i].size()) : 1);
for(int32 j = 0; j < num_repeats; j++)
phones.push_back(phone);
if (!write_lengths) {
std::vector<int32> phones;
for (size_t i = 0; i < split.size(); i++) {
KALDI_ASSERT(split[i].size() > 0);
int32 tid = split[i][0],
tstate = trans_model.TransitionIdToTransitionState(tid),
phone = trans_model.TransitionStateToPhone(tstate);
int32 num_repeats = split[i].size();
KALDI_ASSERT(num_repeats!=0);
if (per_frame)
for(int32 j = 0; j < num_repeats; j++)
phones.push_back(phone);
else
phones.push_back(phone);
}
phones_writer.Write(key, phones);
} else {
std::vector<std::pair<int32,int32> > pairs;
for (size_t i = 0; i < split.size(); i++) {
KALDI_ASSERT(split[i].size() > 0);
int32 tid = split[i][0],
tstate = trans_model.TransitionIdToTransitionState(tid),
phone = trans_model.TransitionStateToPhone(tstate);
int32 num_repeats = split[i].size();
KALDI_ASSERT(num_repeats!=0);
pairs.push_back(std::make_pair(phone, num_repeats));
}
pair_writer.Write(key, pairs);
}
writer.Write(key, phones);
n_done++;
}
KALDI_LOG << "Done " << n_done << " utterances.";
} catch(const std::exception& e) {
std::cerr << e.what();
return -1;
......
......@@ -75,7 +75,7 @@ int main(int argc, char *argv[]) {
std::ifstream is(lex_in_filename.c_str());
if (!is.good()) KALDI_EXIT << "Could not open lexicon FST " << (std::string)lex_in_filename;
lex_fst =
VectorFst<StdArc>::Read(is, fst::FstReadOptions((std::string)lex_in_filename));
VectorFst<StdArc>::Read(is, fst::FstReadOptions(lex_in_filename));
if (lex_fst == NULL)
KALDI_EXIT << "Could not open lexicon FST "<<lex_in_filename;
}
......
......@@ -82,7 +82,7 @@ int main(int argc, char *argv[]) {
std::ifstream is(lex_in_filename.c_str());
if (!is.good()) KALDI_EXIT << "Could not open lexicon FST " << (std::string)lex_in_filename;
lex_fst =
VectorFst<StdArc>::Read(is, fst::FstReadOptions((std::string)lex_in_filename));
VectorFst<StdArc>::Read(is, fst::FstReadOptions(lex_in_filename));
if (lex_fst == NULL)
exit(1);
}
......
......@@ -91,11 +91,9 @@ int main(int argc, char *argv[]) {
// lot of virtual memory.
VectorFst<StdArc> *decode_fst = NULL;
{
std::ifstream is(fst_in_filename.c_str(), std::ifstream::binary);
if (!is.good()) KALDI_EXIT << "Could not open decoding-graph FST "
<< fst_in_filename;
Input ki(fst_in_filename.c_str());
decode_fst =
VectorFst<StdArc>::Read(is, fst::FstReadOptions((std::string)fst_in_filename));
VectorFst<StdArc>::Read(ki.Stream(), fst::FstReadOptions(fst_in_filename));
if (decode_fst == NULL) // fst code will warn.
exit(1);
}
......
......@@ -81,11 +81,9 @@ int main(int argc, char *argv[]) {
// lot of virtual memory.
VectorFst<StdArc> *decode_fst = NULL;
{
std::ifstream is(fst_in_filename.c_str(), std::ifstream::binary);
if (!is.good()) KALDI_EXIT << "Could not open decoding-graph FST "
<< fst_in_filename;
Input ki(fst_in_filename.c_str());
decode_fst =
VectorFst<StdArc>::Read(is, fst::FstReadOptions((std::string)fst_in_filename));
VectorFst<StdArc>::Read(ki.Stream(), fst::FstReadOptions(fst_in_filename));
if (decode_fst == NULL) // fst code will warn.
exit(1);
}
......
// bin/phones-to-prons.cc
// Copyright 2009-2011 Microsoft Corporation
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "base/kaldi-common.h"
#include "hmm/transition-model.h"
#include "hmm/hmm-utils.h"
#include "util/common-utils.h"
#include "fst/fstlib.h"
#include "fstext/fstext-lib.h"
// Create FST that accepts the phone sequence, with any number
// of word-start and word-end symbol in between each phone.
void CreatePhonesAltFst(const std::vector<int32> &phones,
int32 word_start_sym,
int32 word_end_sym,
fst::VectorFst<fst::StdArc> *ofst) {
using fst::StdArc;
typedef fst::StdArc::StateId StateId;
typedef fst::StdArc::Weight Weight;
ofst->SetStart(0);
StateId cur_s = ofst->AddState();
for (size_t i = 0; i < phones.size(); i++) {
StateId next_s = ofst->AddState();
// add arc to next state.
ofst->AddArc(cur_s, StdArc(phones[i], phones[i], Weight::One(),
next_s));
cur_s = next_s;
}
for (StateId s = 0; s <= cur_s; s++) {
ofst->AddArc(s, StdArc(word_end_sym, word_end_sym,
Weight::One(), s));
ofst->AddArc(s, StdArc(word_start_sym, word_start_sym,
Weight::One(), s));
}
ofst->SetFinal(cur_s, Weight::One());
{
fst::OLabelCompare<StdArc> olabel_comp;
ArcSort(ofst, olabel_comp);
}
}
int main(int argc, char *argv[]) {
using namespace kaldi;
using fst::VectorFst;
using fst::StdArc;
typedef kaldi::int32 int32;
try {
const char *usage =
"Convert pairs of (phone-level, word-level) transcriptions to\n"
"output that indicates the phones assigned to each word.\n"
"Format is standard format for archives of vector<vector<int32> >\n"
"i.e. :\n"
"utt-id 600 4 7 19 ; 512 4 18 ; 0 1\n"
"where 600, 512 and 0 are the word-ids (0 for non-word phones, e.g.\n"
"optional-silence introduced by the lexicon), and the phone-ids\n"
"follow the word-ids.\n"
"Note: L_align.fst must have word-start and word-end symbols in it\n"
"\n"
"Usage: phones-to-prons [options] <L_align.fst> <word-start-sym> "
"<word-end-sym> <phones-rspecifier> <words-rspecifier> <prons-wspecifier>\n"
"e.g.: \n"
" ali-to-phones 1.mdl ark:1.ali ark:- | \\\n"
" phones-to-prons L_align.fst 46 47 ark:- 1.tra ark:1.prons\n";
ParseOptions po(usage);
po.Read(argc, argv);
if (po.NumArgs() != 6) {
po.PrintUsage();
exit(1);
}
std::string lex_fst_filename = po.GetArg(1),
word_start_sym_str = po.GetArg(2),
word_end_sym_str = po.GetArg(3),
phones_rspecifier = po.GetArg(4),
words_rspecifier = po.GetArg(5),
prons_wspecifier = po.GetArg(6);
int32 word_start_sym, word_end_sym;
if (!ConvertStringToInteger(word_start_sym_str, &word_start_sym)
|| word_start_sym <= 0)
KALDI_EXIT << "Invalid word start symbol (expecting integer >= 0): "
<< word_start_sym_str;
if (!ConvertStringToInteger(word_end_sym_str, &word_end_sym)
|| word_end_sym <= 0 || word_end_sym == word_start_sym)
KALDI_EXIT << "Invalid word end symbol (expecting integer >= 0"
<< ", different from word start symbol): "
<< word_end_sym_str;
// L should be lexicon with word start and end symbols marked.
VectorFst<StdArc> *L = NULL;
{
Input ki(lex_fst_filename);
L = VectorFst<StdArc>::Read(ki.Stream(),
fst::FstReadOptions(lex_fst_filename));
if (L == NULL) // fst code will warn.
exit(1);
// Make sure that L is sorted on output symbol (words).
fst::OLabelCompare<StdArc> olabel_comp;
ArcSort(L, olabel_comp);
}
SequentialInt32VectorReader phones_reader(phones_rspecifier);
RandomAccessInt32VectorReader words_reader(words_rspecifier);
int32 n_done = 0, n_err = 0;
std::string empty;
Int32VectorVectorWriter prons_writer(prons_wspecifier);
for (; !phones_reader.Done(); phones_reader.Next()) {
std::string key = phones_reader.Key();
const std::vector<int32> &phones = phones_reader.Value();
if (!words_reader.HasKey(key)) {
KALDI_WARN << "Not processing utterance " << key << " because no word "
<< "transcription found.";
n_err++;
continue;
}
const std::vector<int32> &words = words_reader.Value(key);
// convert word alignment to acceptor and compose it with lexicon.
// phn2word will have phones (and word start/end symbols) on its
// input, and words on its output. It will enode the alternative
// pronunciations of this word-sequence, with word start and end
// symbols at the appropriate places.
VectorFst<StdArc> phn2word;
{
VectorFst<StdArc> words_acceptor;
MakeLinearAcceptor(words, &words_acceptor);
Compose(*L, words_acceptor, &phn2word);
}
if (phn2word.Start() == fst::kNoStateId) {
KALDI_WARN << "Phone to word FST is empty (possible mismatch in lexicon?)";
n_err++;
continue;
}
VectorFst<StdArc> phones_alt_fst;
CreatePhonesAltFst(phones, word_start_sym, word_end_sym, &phones_alt_fst);
// phnx2word will have phones and word-start and word-end symbols
// on the input side, and words on the output side.
VectorFst<StdArc> phnx2word;
Compose(phones_alt_fst, phn2word, &phnx2word);
if (phnx2word.Start() == fst::kNoStateId) {
KALDI_WARN << "phnx2word FST is empty (possible mismatch in lexicon?)";
n_err++;
continue;
}
// Now get the best path in phnx2word.
VectorFst<StdArc> phnx2word_best;
ShortestPath(phnx2word, &phnx2word_best);
// Now get seqs of phones and words.
std::vector<int32> phnx, words2;
StdArc::Weight garbage;
if (!fst::GetLinearSymbolSequence(phnx2word_best,
&phnx, &words2, &garbage))
KALDI_ERR << "phnx2word is not a linear transducer (code error?)";
if (words2 != words)
KALDI_ERR << "words have changed! (code error?)";
// Now, "phnx" should be the phone sequence with start and end
// symbols included. At this point we break it up into segments,
// and try to match it up with words.
std::vector<std::vector<int32> > prons;
if (!ConvertPhnxToProns(phnx, words,
word_start_sym, word_end_sym,
&prons)) {
KALDI_WARN << "Error converting phones and words to prons "
<< " (mismatched or non-marked lexicon or partial "
<< " alignment?)";
n_err++;
continue;
}
prons_writer.Write(key, prons);
n_done++;
}
KALDI_LOG << "Done " << n_done << " utterances; " << n_err << " had errors.";
} catch(const std::exception& e) {
std::cerr << e.what();
return -1;
}
}
// bin/prons-to-wordali.cc
// Copyright 2009-2011 Microsoft Corporation
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "base/kaldi-common.h"
#include "hmm/transition-model.h"
#include "hmm/hmm-utils.h"
#include "util/common-utils.h"
#include "fst/fstlib.h"
#include "fstext/fstext-utils.h"
int main(int argc, char *argv[]) {
using namespace kaldi;
using fst::VectorFst;
using fst::StdArc;
typedef kaldi::int32 int32;
try {
const char *usage =
"Given per-utterance pronunciation information as output by \n"
"words-to-prons, and per-utterance phone alignment information\n"
"as output by ali-to-phones --write-lengths, output word alignment\n"
"information that can be turned into the ctm format.\n"
"Outputs is pairs of (word, #frames), or if --per-frame is given,\n"
"just the word for each frame.\n"
"Note: zero word-id usually means optional silence.\n"
"Format is standard format for archives of vector<pair<int32, int32> >\n"
"i.e. :\n"
"utt-id 600 22 ; 1028 32 ; 0 41\n"
"where 600, 1028 and 0 are the word-ids, and 22, 32 and 41 are the\n"
"lengths.\n"
"\n"
"Usage: prons-to-wordali [options] <prons-rspecifier>"
" <phone-lengths-rspecifier> <wordali-wspecifier>\n"
"e.g.: \n"
" ali-to-phones 1.mdl ark:1.ali ark:- | \\\n"
" phones-to-prons L_align.fst 46 47 ark:- 1.tra ark:- | \\\n"
" prons-to-wordali ark:- \\\n"
" \"ark:ali-to-phones --write-lengths 1.mdl ark:1.ali ark:-|\" ark:1.wali\n";
ParseOptions po(usage);
bool per_frame = false;
po.Register("per-frame", &per_frame, "If true, write out the frame-level word alignment (else word sequence)");
po.Read(argc, argv);
if (po.NumArgs() != 3) {
po.PrintUsage();
exit(1);
}
std::string prons_rspecifier = po.GetArg(1),
phone_lengths_rspecifier = po.GetArg(2),
wordali_wspecifier = po.GetArg(3);
SequentialInt32VectorVectorReader prons_reader(prons_rspecifier);
RandomAccessInt32PairVectorReader phones_reader(phone_lengths_rspecifier);
std::string empty;
Int32PairVectorWriter pair_writer(per_frame ? empty : wordali_wspecifier);
Int32VectorWriter frame_writer(per_frame ? wordali_wspecifier : empty);
int32 n_done = 0, n_err = 0;
for (; !prons_reader.Done(); prons_reader.Next()) {
std::string key = prons_reader.Key();
const std::vector<std::vector<int32> > &prons = prons_reader.Value();
if (!phones_reader.HasKey(key)) {
KALDI_WARN << "Not processing utterance " << key << " because no phone "
<< "alignment found.";
n_err++;
continue;
}
// first member of each pair is phone; second is length in
// frames.
const std::vector<std::pair<int32, int32> > &phones =
phones_reader.Value(key);
std::vector<std::pair<int32, int32> > word_alignment;
size_t p = 0; // index into "phones".
for (size_t i = 0; i < prons.size(); i++) {
if (!(prons[i].size() >= 1)) {
KALDI_WARN << "Invalid, empty pronunciation.";
n_err++;
continue;
}
int32 word = prons[i][0], word_len = 0;
for (size_t j = 1; j < prons[i].size(); j++, p++) {
if (!(static_cast<size_t>(p) < phones.size() &&