Commit 6581d035 authored by Guoguo Chen
Browse files

trunk: fixes to lm reading code regarding <s> and </s>

git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@5148 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
parent fc548ffa
#!/usr/bin/env perl
# Copyright 2010-2011 Microsoft Corporation
# 2015 Guoguo Chen
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -18,6 +19,11 @@
# acceptor.
while(<>){
s:^(\d+\s+\d+\s+)\<eps\>(\s+):$1#0$2:;
print;
if (/\s+#0\s+/) {
print STDERR "$0: ERROR: LM has word #0, " .
"which is reserved as disambiguation symbol\n";
exit 1;
}
s:^(\d+\s+\d+\s+)\<eps\>(\s+):$1#0$2:;
print;
}
......@@ -31,12 +31,12 @@ namespace kaldi {
// typedef fst::StdArc::StateId StateId;
// newlyAdded will be updated
// newly_added will be updated
LmFstConverter::StateId LmFstConverter::AddStateFromSymb(
const std::vector<string> &ngramString,
int kstart, int kend,
fst::StdVectorFst *pfst,
bool &newlyAdded) {
bool &newly_added) {
fst::StdArc::StateId sid;
std::string separator;
separator.resize(1);
......@@ -52,12 +52,12 @@ LmFstConverter::StateId LmFstConverter::AddStateFromSymb(
}
}
newlyAdded = false;
newly_added = false;
sid = FindState(hist);
if (sid < 0) {
sid = pfst->AddState();
histState_[hist] = sid;
newlyAdded = true;
hist_state_[hist] = sid;
newly_added = true;
//cerr << "Created state " << sid << " for " << hist << endl;
} else {
//cerr << "State symbol " << hist << " already exists" << endl;
......@@ -71,14 +71,16 @@ void LmFstConverter::ConnectUnusedStates(fst::StdVectorFst *pfst) {
// go through all states with a recorded backoff destination
// and find out any that has no output arcs and is not final
unsigned int connected = 0;
// cerr << "ConnectUnusedStates has recorded "<<bkState_.size()<<" states.\n";
// cerr << "ConnectUnusedStates has recorded "<<backoff_state_.size()<<" states.\n";
for (BkStateMap::iterator bkit = bkState_.begin(); bkit != bkState_.end(); ++bkit) {
for (BackoffStateMap::iterator bkit = backoff_state_.begin(); bkit != backoff_state_.end(); ++bkit) {
// add an output arc to its backoff destination recorded in backoff_
fst::StdArc::StateId src = bkit->first, dst = bkit->second;
if (pfst->NumArcs(src)==0 && !IsFinal(pfst, src)) {
// cerr << "ConnectUnusedStates: adding arc from "<<src<<" to "<<dst<<endl;
pfst->AddArc(src, fst::StdArc(0, 0, fst::StdArc::Weight::One(), dst)); // epsilon arc with no cost
// epsilon arc with no cost
pfst->AddArc(src,
fst::StdArc(0, 0, fst::StdArc::Weight::One(), dst));
connected++;
}
}
......@@ -86,53 +88,56 @@ void LmFstConverter::ConnectUnusedStates(fst::StdVectorFst *pfst) {
}
void LmFstConverter::AddArcsForNgramProb(
int ilev, int maxlev,
int ngram_order, int max_ngram_order,
float logProb,
float logBow,
std::vector<string> &ngs,
std::vector<string> &ngram,
fst::StdVectorFst *fst,
const string startSent,
const string endSent) {
fst::StdArc::StateId src, dst, dbo;
std::string curwrd = ngs[1];
std::string curwrd = ngram[1];
if (curwrd == "<eps>") {
KALDI_ERR << "The word <eps> is not allowed as a word in an ARPA LM.";
}
int64 ilab, olab;
LmWeight prob = ConvertArpaLogProbToWeight(logProb);
LmWeight bow = ConvertArpaLogProbToWeight(logBow);
bool newSrc, newDbo, newDst = false;
if (ilev >= 2) {
if (ngram_order >= 2) {
// General case works from N down to 2-grams
src = AddStateFromSymb(ngs, ilev, 2, fst, newSrc);
if (ilev != maxlev) {
src = AddStateFromSymb(ngram, ngram_order, 2, fst, newSrc);
if (ngram_order != max_ngram_order) {
// add all intermediate levels from 2 to current
// last ones will be current backoff source and destination
for (int iilev=2; iilev <= ilev; iilev++) {
dst = AddStateFromSymb(ngs, iilev, 1, fst, newDst);
dbo = AddStateFromSymb(ngs, iilev-1, 1, fst, newDbo);
bkState_[dst] = dbo;
for (int i = 2; i <= ngram_order; i++) {
dst = AddStateFromSymb(ngram, i, 1, fst, newDst);
dbo = AddStateFromSymb(ngram, i-1, 1, fst, newDbo);
backoff_state_[dst] = dbo;
}
} else {
// add all intermediate levels from 2 to current
// last ones will be current backoff source and destination
for (int iilev=2; iilev <= ilev; iilev++) {
dst = AddStateFromSymb(ngs, iilev-1, 1, fst, newDst);
dbo = AddStateFromSymb(ngs, iilev-2, 1, fst, newDbo);
bkState_[dst] = dbo;
for (int i = 2; i <= ngram_order; i++) {
dst = AddStateFromSymb(ngram, i-1, 1, fst, newDst);
dbo = AddStateFromSymb(ngram, i-2, 1, fst, newDbo);
backoff_state_[dst] = dbo;
}
}
} else {
// special case for 1-grams: start from 0-gram
if (curwrd.compare(startSent) != 0) {
src = AddStateFromSymb(ngs, 0, 1, fst, newSrc);
src = AddStateFromSymb(ngram, 0, 1, fst, newSrc);
} else {
// extra special case if in addition we are at beginning of sentence
// starts from initial state and has no cost
src = fst->Start();
prob = fst::StdArc::Weight::One();
}
dst = AddStateFromSymb(ngs, 1, 1, fst, newDst);
dbo = AddStateFromSymb(ngs, 0, 1, fst, newDbo);
bkState_[dst] = dbo;
dst = AddStateFromSymb(ngram, 1, 1, fst, newDst);
dbo = AddStateFromSymb(ngram, 0, 1, fst, newDbo);
backoff_state_[dst] = dbo;
}
// state is final if last word is end of sentence
......@@ -175,7 +180,7 @@ bool LmTable::ReadFstFromLmFile(std::istream &istrm,
// do not use state symbol table for word histories anymore
string inpline;
size_t pos1, pos2;
int ilev, maxlev = 0;
int ngram_order, max_ngram_order = 0;
// process \data\ section
......@@ -201,12 +206,12 @@ bool LmTable::ReadFstFromLmFile(std::istream &istrm,
continue; // not valid, continue looking
}
// found valid line
ilev = atoi(inpline.substr(pos1+5, pos2-(pos1+5)).c_str());
if (ilev > maxlev) {
maxlev = ilev;
ngram_order = atoi(inpline.substr(pos1+5, pos2-(pos1+5)).c_str());
if (ngram_order > max_ngram_order) {
max_ngram_order = ngram_order;
}
}
if (maxlev == 0) {
if (max_ngram_order == 0) {
// reached end of loop without having found any n-gram
KALDI_ERR << "No ngrams found in specified file";
}
......@@ -222,8 +227,8 @@ bool LmTable::ReadFstFromLmFile(std::istream &istrm,
continue; // not valid line, continue looking for one
}
// found, set current level
ilev = atoi(inpline.substr(pos1+1, pos2-(pos1+1)).c_str());
cerr << "Processing " << ilev <<"-grams" << endl;
ngram_order = atoi(inpline.substr(pos1+1, pos2-(pos1+1)).c_str());
cerr << "Processing " << ngram_order << "-grams" << endl;
// process individual n-grams
while (getline(istrm, inpline) && !istrm.eof()) {
......@@ -248,21 +253,26 @@ bool LmTable::ReadFstFromLmFile(std::istream &istrm,
// found, parse probability from first field
prob = KALDI_STRTOF(cur_cstr, &next_cstr);
if (prob != prob || prob - prob != 0) {
KALDI_ERR << "nan or inf detected in LM file [parsing " << (ilev)
KALDI_ERR << "nan or inf detected in LM file [parsing " << (ngram_order)
<< "-grams]: " << inpline;
}
if (next_cstr == cur_cstr)
KALDI_ERR << "Bad line in LM file [parsing "<<(ilev)<<"-grams]: "<<inpline;
KALDI_ERR << "Bad line in LM file [parsing "<<(ngram_order)<<"-grams]: "<<inpline;
cur_cstr = next_cstr;
while (*cur_cstr && isspace(*cur_cstr))
cur_cstr++;
for (int i = 0; i < ilev; i++) {
// element 0 will be empty, element 1 will be the current word,
// element 2 will be the immediately preceding word, and so on.
// Apparently an IRSTLM convention.
ngramString.resize(ngram_order + 1);
bool illegal_bos_or_eos = false;
for (int i = 0; i < ngram_order; i++) {
if (*cur_cstr == '\0')
KALDI_ERR << "Bad line in LM file [parsing "<<(ilev)<<"-grams]: "<<inpline;
KALDI_ERR << "Bad line in LM file [parsing "<<(ngram_order)<<"-grams]: "<<inpline;
const char *end_cstr = strpbrk(cur_cstr, " \t");
const char *end_cstr = strpbrk(cur_cstr, " \t\r");
std::string this_word;
if (end_cstr == NULL) {
this_word = std::string(cur_cstr);
......@@ -274,33 +284,45 @@ bool LmTable::ReadFstFromLmFile(std::istream &istrm,
cur_cstr++;
}
// words are inserted so position 1 is most recent word,
// and position N oldest word (IRSTLM convention)
ngramString.insert(ngramString.begin(), this_word);
// Checks if <s> only appears at the beginning of the ngram, and if </s>
// only appears at the end of the ngram.
if ((ngram_order > 1 && i != 0 && this_word == "<s>") ||
(ngram_order > 1 && i != ngram_order - 1 && this_word == "</s>")) {
illegal_bos_or_eos = true;
break;
}
ngramString[ngram_order - i].swap(this_word);
}
// reserve an element 0 so that words go from 1, ..., ng.size-1
ngramString.insert(ngramString.begin(), "");
if (illegal_bos_or_eos) {
KALDI_WARN << "<s> is not at the beginning of the n-gram, or </s> is "
<< "not at the end of the n-gram, skipping it: " << inpline;
continue;
}
bow = 0;
if (ilev < maxlev) {
if (ngram_order < max_ngram_order) {
// try converting anything left in the line to a backoff weight
if (*cur_cstr != '\0') {
char *end_cstr;
bow = KALDI_STRTOF(cur_cstr, &end_cstr);
if (bow != bow || bow - bow != 0) {
KALDI_ERR << "nan or inf detected in LM file [parsing " << (ilev)
KALDI_ERR << "nan or inf detected in LM file [parsing " << (ngram_order)
<< "-grams]: " << inpline;
}
if (end_cstr != cur_cstr) { // got something.
while (*end_cstr != '\0' && isspace(*end_cstr))
end_cstr++;
if (*end_cstr != '\0')
KALDI_ERR << "Junk "<<(end_cstr)<<" at end of line [parsing "<<(ilev)<<"-grams]"<<inpline;
KALDI_ERR << "Junk " << (end_cstr) << " at end of line [parsing "
<< (ngram_order) << "-grams]" << inpline;
} else {
KALDI_ERR << "Junk "<<(cur_cstr)<<" at end of line [parsing "<<(ilev)<<"-grams]"<<inpline;
KALDI_ERR << "Junk " << (cur_cstr) << " at end of line [parsing "
<< (ngram_order) << "-grams]" << inpline;
}
}
}
conv_->AddArcsForNgramProb(ilev, maxlev, prob, bow,
conv_->AddArcsForNgramProb(ngram_order, max_ngram_order, prob, bow,
ngramString, fst,
startSent, endSent);
} // end of loop on individual n-gram lines
......@@ -346,7 +368,7 @@ void LmTable::DumpStart(ngram ng,
fst::SymbolTable *pStateSymbs = new fst::SymbolTable("kaldi-lm-state");
// dump level by level
for (int l = 1; l <= maxlev; l++) {
for (int l = 1; l <= max_ngram_order; l++) {
ng.size = 0;
cerr << "Processing " << l << "-grams" << endl;
DumpContinue(ng, 1, l, 0, cursize[1],
......@@ -358,39 +380,39 @@ void LmTable::DumpStart(ngram ng,
}
// run through given levels and positions in table
void LmTable::DumpContinue(ngram ng, int ilev, int elev,
void LmTable::DumpContinue(ngram ng, int ngram_order, int elev,
table_entry_pos_t ipos, table_entry_pos_t epos,
fst::StdVectorFst *fst,
fst::SymbolTable *pStateSymbs,
const string startSent, const string endSent) {
LMT_TYPE ndt = tbltype[ilev];
LMT_TYPE ndt = tbltype[ngram_order];
ngram ing(ng.dict);
int ndsz = nodesize(ndt);
#ifdef KALDI_PARANOID
KALDI_ASSERT(ng.size == ilev - 1);
KALDI_ASSERT(ipos >= 0 && epos <= cursize[ilev] && ipos < epos);
KALDI_ASSERT(ng.size == ngram_order - 1);
KALDI_ASSERT(ipos >= 0 && epos <= cursize[ngram_order] && ipos < epos);
KALDI_ASSERT(pStateSymbs);
#endif
ng.pushc(0);
for (table_entry_pos_t i = ipos; i < epos; i++) {
*ng.wordp(1) = word(table[ilev] + (table_pos_t)i * ndsz);
float ipr = prob(table[ilev] + (table_pos_t)i * ndsz, ndt);
// int ipr = prob(table[ilev] + i * ndsz, ndt);
*ng.wordp(1) = word(table[ngram_order] + (table_pos_t)i * ndsz);
float ipr = prob(table[ngram_order] + (table_pos_t)i * ndsz, ndt);
// int ipr = prob(table[ngram_order] + i * ndsz, ndt);
// skip pruned n-grams
if (isPruned && ipr == NOPROB) continue;
if (ilev < elev) {
if (ngram_order < elev) {
// get first and last successor position
table_entry_pos_t isucc = (i > 0 ? bound(table[ilev] +
table_entry_pos_t isucc = (i > 0 ? bound(table[ngram_order] +
(table_pos_t) (i-1) * ndsz,
ndt) : 0);
table_entry_pos_t esucc = bound(table[ilev] +
table_entry_pos_t esucc = bound(table[ngram_order] +
(table_pos_t) i * ndsz, ndt);
if (isucc < esucc) // there are successors!
DumpContinue(ng, ilev+1, elev, isucc, esucc,
DumpContinue(ng, ngram_order+1, elev, isucc, esucc,
fst, pStateSymbs, startSent, endSent);
// else
// cerr << "no successors for " << ng << "\n";
......@@ -405,7 +427,7 @@ void LmTable::DumpContinue(ngram ng, int ilev, int elev,
ng = ing;
}
// cerr << "ilev " << ilev << " ngsize " << ng.size << endl;
// cerr << "ngram_order " << ngram_order << " ngsize " << ng.size << endl;
// for FST creation: vector of words strings
std::vector<string> ngramString;
......@@ -418,13 +440,13 @@ void LmTable::DumpContinue(ngram ng, int ilev, int elev,
// reserve index 0 so that words go from 1, .., ng.size-1
ngramString.insert(ngramString.begin(), "");
float ibo = 0;
if (ilev < maxlev) {
if (ngram_order < max_ngram_order) {
// Backoff
ibo = bow(table[ilev]+ (table_pos_t)i * ndsz, ndt);
ibo = bow(table[ngram_order]+ (table_pos_t)i * ndsz, ndt);
// if (isQtable) cerr << "\t" << ibo;
// else if (ibo != 0.0) cerr << "\t" << ibo;
}
conv_->AddArcsForNgramProb(ilev, maxlev, ipr, ibo,
conv_->AddArcsForNgramProb(ngram_order, max_ngram_order, ipr, ibo,
ngramString, fst, pStateSymbs,
startSent, endSent);
}
......
......@@ -68,7 +68,7 @@ class LmFstConverter {
typedef fst::StdArc::Weight LmWeight;
typedef fst::StdArc::StateId StateId;
typedef unordered_map<StateId, StateId> BkStateMap;
typedef unordered_map<StateId, StateId> BackoffStateMap;
typedef unordered_map<std::string, StateId, StringHasher> HistStateMap;
public:
......@@ -110,19 +110,17 @@ class LmFstConverter {
int kstart,
int kend,
fst::StdVectorFst *pfst,
bool &newlyAdded);
bool &newly_added);
StateId FindState(const std::string str) {
HistStateMap::const_iterator it = histState_.find(str);
if (it == histState_.end()) {
return -1;
}
return it->second;
HistStateMap::const_iterator it = hist_state_.find(str);
if (it == hist_state_.end()) return -1;
else return it->second;
}
bool use_natural_log_;
BkStateMap bkState_;
HistStateMap histState_;
BackoffStateMap backoff_state_;
HistStateMap hist_state_;
};
#ifndef HAVE_IRSTLM
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment