Commit cd5af17c authored by Dan Povey's avatar Dan Povey
Browse files

Changes to MMI code; adding scripts for training with MMI.

git-svn-id: https://svn.code.sf.net/p/kaldi/code/sandbox/discrim@474 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
parents 60fa7170 019fa536
...@@ -59,36 +59,40 @@ silprob=0.5 # same prob as word ...@@ -59,36 +59,40 @@ silprob=0.5 # same prob as word
scripts/make_lexicon_fst.pl data/local/lexicon.txt $silprob sil | \ scripts/make_lexicon_fst.pl data/local/lexicon.txt $silprob sil | \
fstcompile --isymbols=data/lang/phones.txt --osymbols=data/lang/words.txt \ fstcompile --isymbols=data/lang/phones.txt --osymbols=data/lang/words.txt \
--keep_isymbols=false --keep_osymbols=false | \ --keep_isymbols=false --keep_osymbols=false | \
fstarcsort --sort_type=olabel > data/lang/L.fst fstarcsort --sort_type=olabel > data/lang/L.fst || exit 1;
scripts/make_lexicon_fst.pl data/local/lexicon_disambig.txt $silprob sil '#'$ndisambig | \ scripts/make_lexicon_fst.pl data/local/lexicon_disambig.txt $silprob sil '#'$ndisambig | \
fstcompile --isymbols=data/lang_test/phones_disambig.txt --osymbols=data/lang/words.txt \ fstcompile --isymbols=data/lang_test/phones_disambig.txt --osymbols=data/lang/words.txt \
--keep_isymbols=false --keep_osymbols=false | fstarcsort --sort_type=olabel \ --keep_isymbols=false --keep_osymbols=false | fstarcsort --sort_type=olabel \
> data/lang_test/L_disambig.fst > data/lang_test/L_disambig.fst || exit 1;
for x in L_disambig.fst phones_disambig.txt; do
cp data/lang_test/$x data/lang || exit 1;
done
fstcompile --isymbols=data/lang/words.txt --osymbols=data/lang/words.txt --keep_isymbols=false \ fstcompile --isymbols=data/lang/words.txt --osymbols=data/lang/words.txt --keep_isymbols=false \
--keep_osymbols=false data/local/G.txt > data/lang_test/G.fst --keep_osymbols=false data/local/G.txt > data/lang_test/G.fst || exit 1;
# Checking that G is stochastic [note, it wouldn't be for an Arpa] # Checking that G is stochastic [note, it wouldn't be for an Arpa]
fstisstochastic data/lang_test/G.fst || echo Error: G is not stochastic fstisstochastic data/lang_test/G.fst || exit 1;
# Checking that G.fst is determinizable. # Checking that G.fst is determinizable.
fstdeterminize data/lang_test/G.fst /dev/null || echo Error determinizing G. fstdeterminize data/lang_test/G.fst /dev/null || exit 1;
# Checking that L_disambig.fst is determinizable. # Checking that L_disambig.fst is determinizable.
fstdeterminize data/lang_test/L_disambig.fst /dev/null || echo Error determinizing L. fstdeterminize data/lang_test/L_disambig.fst /dev/null || exit 1;
# Checking that disambiguated lexicon times G is determinizable # Checking that disambiguated lexicon times G is determinizable
fsttablecompose data/lang_test/L_disambig.fst data/lang_test/G.fst | \ fsttablecompose data/lang_test/L_disambig.fst data/lang_test/G.fst | \
fstdeterminize >/dev/null || echo Error fstdeterminize >/dev/null || exit 1;
# Checking that LG is stochastic: # Checking that LG is stochastic:
fsttablecompose data/lang/L.fst data/lang_test/G.fst | \ fsttablecompose data/lang/L.fst data/lang_test/G.fst | \
fstisstochastic || echo Error: LG is not stochastic. fstisstochastic || exit 1;
# Checking that L_disambig.G is stochastic: # Checking that L_disambig.G is stochastic:
fsttablecompose data/lang_test/L_disambig.fst data/lang_test/G.fst | \ fsttablecompose data/lang_test/L_disambig.fst data/lang_test/G.fst | \
fstisstochastic || echo Error: LG is not stochastic. fstisstochastic || exit 1;
## Check lexicon. ## Check lexicon.
......
...@@ -57,6 +57,9 @@ local/decode.sh steps/decode_deltas.sh exp/tri1/decode ...@@ -57,6 +57,9 @@ local/decode.sh steps/decode_deltas.sh exp/tri1/decode
steps/align_deltas.sh --graphs "ark,s,cs:gunzip -c exp/tri1/graphs.fsts.gz|" \ steps/align_deltas.sh --graphs "ark,s,cs:gunzip -c exp/tri1/graphs.fsts.gz|" \
data/train data/lang exp/tri1 exp/tri1_ali data/train data/lang exp/tri1 exp/tri1_ali
# 2level full-cov training...
steps/train-2lvl.sh data/train data/lang exp/tri1_ali exp/tri1-2lvl 100 1024 1800 0 0 0
# train tri2a [delta+delta-deltas] # train tri2a [delta+delta-deltas]
steps/train_deltas.sh data/train data/lang exp/tri1_ali exp/tri2a steps/train_deltas.sh data/train data/lang exp/tri1_ali exp/tri2a
# decode tri2a # decode tri2a
...@@ -93,17 +96,27 @@ steps/train_lda_et.sh data/train data/lang exp/tri1_ali exp/tri2c ...@@ -93,17 +96,27 @@ steps/train_lda_et.sh data/train data/lang exp/tri1_ali exp/tri2c
scripts/mkgraph.sh data/lang_test exp/tri2c exp/tri2c/graph scripts/mkgraph.sh data/lang_test exp/tri2c exp/tri2c/graph
local/decode.sh steps/decode_lda_et.sh exp/tri2c/decode local/decode.sh steps/decode_lda_et.sh exp/tri2c/decode
# Align all data with LDA+MLLT system (tri2b) and do LDA+MLLT+SAT # Align all data with LDA+MLLT system (tri2b)
steps/align_lda_mllt.sh --graphs "ark,s,cs:gunzip -c exp/tri2b/graphs.fsts.gz|" \ steps/align_lda_mllt.sh --graphs "ark,s,cs:gunzip -c exp/tri2b/graphs.fsts.gz|" \
data/train data/lang exp/tri2b exp/tri2b_ali data/train data/lang exp/tri2b exp/tri2b_ali
# Do MMI on top of LDA+MLLT.
steps/train_lda_etc_mmi.sh data/train data/lang exp/tri2b_ali exp/tri3a &
local/decode.sh steps/decode_lda_mllt.sh exp/tri3a/decode
# Do LDA+MLLT+SAT
steps/train_lda_mllt_sat.sh data/train data/lang exp/tri2b_ali exp/tri3d steps/train_lda_mllt_sat.sh data/train data/lang exp/tri2b_ali exp/tri3d
scripts/mkgraph.sh data/lang_test exp/tri3d exp/tri3d/graph
local/decode.sh steps/decode_lda_mllt_sat.sh exp/tri3d/decode local/decode.sh steps/decode_lda_mllt_sat.sh exp/tri3d/decode
# Align all data with LDA+MLLT+SAT system (tri3d) # Align all data with LDA+MLLT+SAT system (tri3d)
steps/align_lda_mllt_sat.sh --graphs "ark,s,cs:gunzip -c exp/tri3d/graphs.fsts.gz|" \ steps/align_lda_mllt_sat.sh --graphs "ark,s,cs:gunzip -c exp/tri3d/graphs.fsts.gz|" \
data/train data/lang exp/tri3d exp/tri3d_ali data/train data/lang exp/tri3d exp/tri3d_ali
# MMI on top of that.
steps/train_lda_etc_mmi.sh data/train data/lang exp/tri3d_ali exp/tri4a &
local/decode.sh steps/decode_lda_mllt_sat.sh exp/tri4a/decode
# Try another pass on top of that. # Try another pass on top of that.
steps/train_lda_mllt_sat.sh data/train data/lang exp/tri3d_ali exp/tri4d steps/train_lda_mllt_sat.sh data/train data/lang exp/tri3d_ali exp/tri4d
scripts/mkgraph.sh data/lang_test exp/tri4d exp/tri4d/graph scripts/mkgraph.sh data/lang_test exp/tri4d exp/tri4d/graph
......
#!/usr/bin/perl
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABILITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# This script is used in discriminative training.
# This script makes a simple unigram-loop version of G.fst
# using a unigram grammar estimated from some training transcripts.
# This is for MMI training.
# We don't have any silences in G.fst; these are supplied by the
# optional silences in the lexicon.
# Note: the symbols in the transcripts become the input and output
# symbols of G.txt; these can be numeric or not.
# Refuse all command-line arguments: this script is a pure stdin -> stdout filter.
if(@ARGV != 0) {
die "Usage: make_unigram_grammar.pl < text-transcripts > G.txt"
}
$totcount = 0; # Total event count: every symbol occurrence, plus one
# end-of-sentence event per input line (see below).
$nl = 0; # Number of input lines (transcripts) read.
while (<>) {
@A = split(" ", $_); # Whitespace-separated symbols (numeric or not).
foreach $a (@A) {
$count{$a}++; # Unigram count for this symbol.
$totcount++;
}
$nl++;
$totcount++; # Treat end-of-sentence as a symbol for purposes of
# $totcount, so the grammar is properly stochastic. This doesn't
# become </s>, it just becomes the final-prob.
}
# Emit one self-loop arc on state 0 per distinct symbol, in OpenFst text
# format: "src<TAB>dest<TAB>ilabel<TAB>olabel<TAB>cost".
foreach $a (keys %count) {
$prob = $count{$a} / $totcount;
$cost = -log($prob); # Negated natural-log probs.
print "0\t0\t$a\t$a\t$cost\n";
}
# Final cost on state 0: -log of the end-of-sentence probability
# ($nl end-of-sentence events out of $totcount total), so that the arc
# probabilities plus the final probability sum to one (stochasticity).
$final_prob = $nl / $totcount;
$final_cost = -log($final_prob);
print "0\t$final_cost\n";
...@@ -67,10 +67,14 @@ gmm-decode-faster --beam=20.0 --acoustic-scale=0.1 --word-symbol-table=$lang/wor ...@@ -67,10 +67,14 @@ gmm-decode-faster --beam=20.0 --acoustic-scale=0.1 --word-symbol-table=$lang/wor
$srcdir/final.alimdl $graphdir/HCLG.fst "$sifeats" ark,t:$dir/pass1.tra ark,t:$dir/pass1.ali \ $srcdir/final.alimdl $graphdir/HCLG.fst "$sifeats" ark,t:$dir/pass1.tra ark,t:$dir/pass1.ali \
2> $dir/decode_pass1.log || exit 1; 2> $dir/decode_pass1.log || exit 1;
adaptmdl=$srcdir/final.mdl # Compute fMLLR transforms with this model.
[ -f $srcdir/final.adaptmdl ] && adaptmdl=$srcdir/final.adaptmdl # e.g. in MMI-trained systems
( ali-to-post ark:$dir/pass1.ali ark:- | \ ( ali-to-post ark:$dir/pass1.ali ark:- | \
weight-silence-post 0.0 $silphonelist $srcdir/final.alimdl ark:- ark:- | \ weight-silence-post 0.0 $silphonelist $srcdir/final.alimdl ark:- ark:- | \
gmm-post-to-gpost $srcdir/final.alimdl "$sifeats" ark:- ark:- | \ gmm-post-to-gpost $srcdir/final.alimdl "$sifeats" ark:- ark:- | \
gmm-est-fmllr-gpost --spk2utt=ark:$data/spk2utt $srcdir/final.mdl "$sifeats" \ gmm-est-fmllr-gpost --spk2utt=ark:$data/spk2utt $adaptmdl "$sifeats" \
ark,s,cs:- ark:$dir/trans.ark ) \ ark,s,cs:- ark:$dir/trans.ark ) \
2> $dir/trans.log || exit 1; 2> $dir/trans.log || exit 1;
......
#!/bin/bash
# Copyright 2010-2011 Microsoft Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
# MERCHANTABILITY OR NON-INFRINGEMENT.
# See the Apache 2 License for the specific language governing permissions and
# limitations under the License.
# To be run from ..
# This directory does MMI model training, starting from trained
# models. The models must be trained on raw features plus
# cepstral mean normalization plus splice-9-frames, an LDA+[something]
# transform, then possibly speaker-specific affine transforms
# (fMLLR/CMLLR). This script works out from the alignment directory
# whether you trained with some kind of speaker-specific transform.
#
# This training run starts from an initial directory that has
# alignments, models and transforms from an LDA+MLLT system:
# ali, final.mdl, final.mat
# Require exactly four positional arguments.
if [ $# != 4 ]; then
echo "Usage: steps/train_lda_etc_mmi.sh <data-dir> <lang-dir> <ali-dir> <exp-dir>"
echo " e.g.: steps/train_lda_etc_mmi.sh data/train data/lang exp/tri3d_ali exp/tri4a"
exit 1;
fi
# Source the path-setting script if present (puts Kaldi binaries on PATH).
if [ -f path.sh ]; then . path.sh; fi
data=$1    # data directory, e.g. data/train
lang=$2    # lang directory, e.g. data/lang
alidir=$3  # source directory with alignments/models, e.g. exp/tri3d_ali
dir=$4     # output (experiment) directory, e.g. exp/tri4a
num_iters=4      # number of MMI update iterations.
acwt=0.1         # acoustic scale used for lattice posteriors and diagnostics.
beam=20          # decoding beam for denominator-lattice generation.
latticebeam=10   # lattice-pruning beam.
# NOTE(review): scale_opts is not referenced anywhere below — possibly
# vestigial; confirm before relying on it.
scale_opts="--transition-scale=1.0 --acoustic-scale=0.1 --self-loop-scale=0.1"
mkdir -p $dir
cp $alidir/tree $alidir/final.mat $dir # Will use the same tree and transforms as in the baseline.
cp $alidir/final.mdl $dir/0.mdl
# If the source system has an alignment model (i.e. it was trained with
# speaker transforms), keep it, plus the ML-trained model, for decoding.
if [ -f $alidir/final.alimdl ]; then
cp $alidir/final.alimdl $dir/final.alimdl
cp $alidir/final.mdl $dir/final.adaptmdl # This model used by decoding scripts,
# when you don't want to compute fMLLR transforms with the MMI-trained model.
fi
# Split the feature list into 4 pieces so denominator-lattice generation can
# run as 4 parallel jobs.
scripts/split_scp.pl $data/feats.scp $dir/feats{0,1,2,3}.scp
# Feature pipeline: CMVN (per speaker) -> splice frames -> LDA/MLLT transform.
feats="ark:apply-cmvn --norm-vars=false --utt2spk=ark:$data/utt2spk ark:$alidir/cmvn.ark scp:$data/feats.scp ark:- | splice-feats ark:- ark:- | transform-feats $dir/final.mat ark:- ark:- |"
# Same pipeline restricted to each of the 4 feature-list pieces.
for n in 0 1 2 3; do
featspart[$n]="ark:apply-cmvn --norm-vars=false --utt2spk=ark:$data/utt2spk ark:$alidir/cmvn.ark scp:$dir/feats$n.scp ark:- | splice-feats ark:- ark:- | transform-feats $dir/final.mat ark:- ark:- |"
done
# If the source directory has speaker-specific fMLLR transforms, append a
# transform-feats stage to every feature pipeline.
if [ -f $alidir/trans.ark ]; then
echo "Running with speaker transforms $alidir/trans.ark"
feats="$feats transform-feats --utt2spk=ark:$data/utt2spk ark:$alidir/trans.ark ark:- ark:- |"
for n in 0 1 2 3; do
featspart[$n]="${featspart[$n]} transform-feats --utt2spk=ark:$data/utt2spk ark:$alidir/trans.ark ark:- ark:- |"
done
fi
# compute integer form of transcripts.
scripts/sym2int.pl --ignore-first-field $lang/words.txt < $data/text > $dir/train.tra \
|| exit 1;
# Local copy of the lang directory, so we can drop in our own unigram G.fst.
cp -r $lang $dir/lang
# Compute grammar FST which corresponds to unigram decoding graph.
# The awk stage strips the leading utterance-id field from each transcript.
cat $dir/train.tra | awk '{for(n=2;n<=NF;n++){ printf("%s ", $n); } printf("\n"); }' | \
scripts/make_unigram_grammar.pl | fstcompile > $dir/lang/G.fst \
|| exit 1;
# mkgraph.sh expects a whole directory "lang", so put everything in one directory...
# it gets L_disambig.fst and G.fst (among other things) from $dir/lang, and
# final.mdl from $alidir; the output HCLG.fst goes in $dir/graph.
scripts/mkgraph.sh $dir/lang $alidir $dir/graph || exit 1;
echo "Making denominator lattices"
# Generate denominator lattices in 4 parallel background jobs; a failing job
# touches $dir/.error so the failure can be detected after 'wait'.
rm $dir/.error 2>/dev/null
for n in 0 1 2 3; do
gmm-latgen-simple --beam=$beam --lattice-beam=$latticebeam --acoustic-scale=$acwt \
--word-symbol-table=$lang/words.txt \
$alidir/final.mdl $dir/graph/HCLG.fst "${featspart[$n]}" "ark:|gzip -c >$dir/lat$n.gz" \
2>$dir/decode_den.$n.log || touch $dir/.error &
done
wait
if [ -f $dir/.error ]; then
echo "Error creating denominator lattices"
exit 1;
fi
# No need to create "numerator" alignments/lattices: we just use the
# alignments in $alidir.
echo "Note: ignore absolute offsets in the objective function values"
echo "This is caused by not having LM, lexicon or transition-probs in numerator"
# Main MMI loop: accumulate denominator stats from the lattices and numerator
# stats from the fixed alignments, then update the model with gmm-est-mmi.
x=0;
while [ $x -lt $num_iters ]; do
echo "Iteration $x: getting denominator stats."
# Get denominator stats...
if [ $x -eq 0 ]; then
# On iteration 0 the lattice acoustic scores already match 0.mdl (it is a
# copy of the model the lattices were generated with), so no rescoring.
( lattice-to-post --acoustic-scale=$acwt "ark:gunzip -c $dir/lat?.gz|" ark:- | \
gmm-acc-stats $dir/$x.mdl "$feats" ark:- $dir/den_acc.$x.acc ) \
2>$dir/acc_den.$x.log || exit 1;
else # Need to recompute acoustic likelihoods...
( gmm-rescore-lattice $dir/$x.mdl "ark:gunzip -c $dir/lat?.gz|" "$feats" ark:- | \
lattice-to-post --acoustic-scale=$acwt ark:- ark:- | \
gmm-acc-stats $dir/$x.mdl "$feats" ark:- $dir/den_acc.$x.acc ) \
2>$dir/acc_den.$x.log || exit 1;
fi
echo "Iteration $x: getting numerator stats."
# Get numerator stats...
gmm-acc-stats-ali $dir/$x.mdl "$feats" ark:$alidir/ali $dir/num_acc.$x.acc \
2>$dir/acc_num.$x.log || exit 1;
# Update.
gmm-est-mmi $dir/$x.mdl $dir/num_acc.$x.acc $dir/den_acc.$x.acc $dir/$[$x+1].mdl \
2>$dir/update.$x.log || exit 1;
# Diagnostics scraped from the logs: den/num likelihoods, their scaled
# difference (the MMI objective), and the auxiliary-function improvement.
# NOTE(review): the awk field numbers ($7, $11, $10) depend on the exact
# log-line format of these binaries — verify if tool versions change.
den=`grep Overall $dir/acc_den.$x.log | grep lattice-to-post | awk '{print $7}'`
num=`grep Overall $dir/acc_num.$x.log | grep gmm-acc-stats-ali | awk '{print $11}'`
diff=`perl -e "print ($num * $acwt - $den);"`
impr=`grep Overall $dir/update.$x.log | awk '{print $10;}'`
impr=`perl -e "print ($impr * $acwt);"` # auxf impr normalized by multiplying by
# kappa, so it's comparable to an objective-function change.
echo On iter $x, objf was $diff, auxf improvement was $impr
x=$[$x+1]
done
# Just copy the source-dir's occs, in case we later need them for something...
cp $alidir/final.occs $dir
# Link the last model ($num_iters.mdl, since $x now equals $num_iters) as final.mdl.
( cd $dir; ln -s $x.mdl final.mdl )
echo Done
...@@ -150,11 +150,11 @@ inline void AccumDiagGmm::Resize(const DiagGmm &gmm, GmmFlagsType flags) { ...@@ -150,11 +150,11 @@ inline void AccumDiagGmm::Resize(const DiagGmm &gmm, GmmFlagsType flags) {
/// a Gaussian mixture model. /// a Gaussian mixture model.
/// Update using the DiagGmm: exponential form /// Update using the DiagGmm: exponential form
void MleDiagGmmUpdate(const MleDiagGmmOptions &config, void MleDiagGmmUpdate(const MleDiagGmmOptions &config,
const AccumDiagGmm &diaggmm_acc, const AccumDiagGmm &diaggmm_acc,
GmmFlagsType flags, GmmFlagsType flags,
DiagGmm *gmm, DiagGmm *gmm,
BaseFloat *obj_change_out, BaseFloat *obj_change_out,
BaseFloat *count_out); BaseFloat *count_out);
/// Calc using the DiagGMM exponential form /// Calc using the DiagGMM exponential form
BaseFloat MlObjective(const DiagGmm& gmm, BaseFloat MlObjective(const DiagGmm& gmm,
......
...@@ -116,7 +116,7 @@ void UnitTestMmieAmDiagGmm() { ...@@ -116,7 +116,7 @@ void UnitTestMmieAmDiagGmm() {
size_t iteration = 0; size_t iteration = 0;
size_t maxiterations = 2; size_t maxiterations = 2;
MmieDiagGmmOptions config; MmieDiagGmmOptions config;
BaseFloat obj, count; BaseFloat auxf, count;
while (iteration < maxiterations) { while (iteration < maxiterations) {
std::cout << "Iteration :" << iteration << " Num Gauss: " << gmm->NumGauss() << '\n'; std::cout << "Iteration :" << iteration << " Num Gauss: " << gmm->NumGauss() << '\n';
...@@ -146,7 +146,8 @@ void UnitTestMmieAmDiagGmm() { ...@@ -146,7 +146,8 @@ void UnitTestMmieAmDiagGmm() {
} }
/// get 2 mixtures from 1 /// get 2 mixtures from 1
MleDiagGmmUpdate(config, num, flags, gmm, &obj, &count); MleDiagGmmUpdate(config, num, flags, gmm, &auxf, &count);
/// Split gaussian /// Split gaussian
if (iteration < maxiterations -1) gmm->Split(gmm->NumGauss() * 2, 0.001); if (iteration < maxiterations -1) gmm->Split(gmm->NumGauss() * 2, 0.001);
...@@ -214,7 +215,10 @@ void UnitTestMmieAmDiagGmm() { ...@@ -214,7 +215,10 @@ void UnitTestMmieAmDiagGmm() {
am_gmm.GetGaussianMean(0,0,&tmp_mean); am_gmm.GetGaussianMean(0,0,&tmp_mean);
std::cout << "Mean of 1st Gmm before: " << tmp_mean << '\n'; std::cout << "Mean of 1st Gmm before: " << tmp_mean << '\n';
MmieAmDiagGmmUpdate(config, mmi_am_accs, flags, &am_gmm, &obj, &count); BaseFloat auxf_gauss, auxf_weight;
int32 num_floored;
MmieAmDiagGmmUpdate(config, mmi_am_accs, flags, &am_gmm,
&auxf_gauss, &auxf_weight, &count, &num_floored);
am_gmm.GetGaussianMean(0,0,&tmp_mean); am_gmm.GetGaussianMean(0,0,&tmp_mean);
std::cout << "Mean of 1st Gmm after: " << tmp_mean << '\n'; std::cout << "Mean of 1st Gmm after: " << tmp_mean << '\n';
......
...@@ -165,31 +165,55 @@ void MmieAccumAmDiagGmm::WriteDen(std::ostream& out_stream, bool binary) const { ...@@ -165,31 +165,55 @@ void MmieAccumAmDiagGmm::WriteDen(std::ostream& out_stream, bool binary) const {
void MmieAmDiagGmmUpdate(const MmieDiagGmmOptions &config, void MmieAmDiagGmmUpdate(const MmieDiagGmmOptions &config,
const MmieAccumAmDiagGmm &mmieamdiaggmm_acc, const MmieAccumAmDiagGmm &mmieamdiaggmm_acc,
GmmFlagsType flags, GmmFlagsType flags,
AmDiagGmm *am_gmm, AmDiagGmm *am_gmm,
BaseFloat *obj_change_out, BaseFloat *auxf_change_gauss,
BaseFloat *count_out) { BaseFloat *auxf_change_weights,
BaseFloat *count_out,
int32 *num_floored_out) {
KALDI_ASSERT(am_gmm != NULL); KALDI_ASSERT(am_gmm != NULL);
KALDI_ASSERT(mmieamdiaggmm_acc.NumAccs() == am_gmm->NumPdfs()); KALDI_ASSERT(mmieamdiaggmm_acc.NumAccs() == am_gmm->NumPdfs());
if (obj_change_out != NULL) *obj_change_out = 0.0; if (auxf_change_gauss != NULL) *auxf_change_gauss = 0.0;
if (auxf_change_weights != NULL) *auxf_change_weights = 0.0;
if (count_out != NULL) *count_out = 0.0; if (count_out != NULL) *count_out = 0.0;
BaseFloat tmp_obj_change, tmp_count; if (num_floored_out != NULL) *num_floored_out = 0.0;
BaseFloat *p_obj = (obj_change_out != NULL) ? &tmp_obj_change : NULL, BaseFloat tmp_auxf_change_gauss, tmp_auxf_change_weights, tmp_count;
*p_count = (count_out != NULL) ? &tmp_count : NULL; int32 tmp_num_floored;
MmieAccumDiagGmm mmie_gmm; MmieAccumDiagGmm mmie_gmm;
for (size_t i = 0; i < mmieamdiaggmm_acc.NumAccs(); i++) { for (size_t i = 0; i < mmieamdiaggmm_acc.NumAccs(); i++) {
mmie_gmm.Resize(am_gmm->GetPdf(i).NumGauss(), am_gmm->GetPdf(i).Dim(), flags); mmie_gmm.Resize(am_gmm->GetPdf(i).NumGauss(), am_gmm->GetPdf(i).Dim(), flags);
mmie_gmm.SubtractAccumulatorsISmoothing(mmieamdiaggmm_acc.GetNumAcc(i), mmieamdiaggmm_acc.GetDenAcc(i), config); mmie_gmm.SubtractAccumulatorsISmoothing(mmieamdiaggmm_acc.GetNumAcc(i),
mmie_gmm.Update(config, flags, &(am_gmm->GetPdf(i)), p_obj, p_count); mmieamdiaggmm_acc.GetDenAcc(i),
config);
if (obj_change_out != NULL) *obj_change_out += tmp_obj_change; mmie_gmm.Update(config, flags, &(am_gmm->GetPdf(i)),
if (count_out != NULL) *count_out += tmp_count; &tmp_auxf_change_gauss, &tmp_auxf_change_weights,
&tmp_count, &tmp_num_floored);
if (auxf_change_gauss != NULL) *auxf_change_gauss += tmp_auxf_change_gauss;
if (auxf_change_weights != NULL) *auxf_change_weights += tmp_auxf_change_weights;
if (count_out != NULL) *count_out += tmp_count;
if (num_floored_out != NULL) *num_floored_out += tmp_num_floored;
} }
}
BaseFloat MmieAccumAmDiagGmm::TotNumCount() {
BaseFloat ans = 0.0;
for (size_t i = 0; i < num_accumulators_.size(); i++)
if (num_accumulators_[i])
ans += num_accumulators_[i]->occupancy().Sum();
return ans;
} }
BaseFloat MmieAccumAmDiagGmm::TotDenCount() {
BaseFloat ans = 0.0;
for (size_t i = 0; i < den_accumulators_.size(); i++)
if (den_accumulators_[i])
ans += den_accumulators_[i]->occupancy().Sum();
return ans;
}
} // namespace kaldi } // namespace kaldi
...@@ -54,6 +54,8 @@ class MmieAccumAmDiagGmm { ...@@ -54,6 +54,8 @@ class MmieAccumAmDiagGmm {
AccumDiagGmm& GetDenAcc(int32 index) const; AccumDiagGmm& GetDenAcc(int32 index) const;
void CopyToNumAcc(int32 index); void CopyToNumAcc(int32 index);
BaseFloat TotNumCount();
BaseFloat TotDenCount();
private: private:
/// MMIE accumulators and update methods for the GMMs /// MMIE accumulators and update methods for the GMMs
std::vector<AccumDiagGmm*> num_accumulators_; std::vector<AccumDiagGmm*> num_accumulators_;
...@@ -68,10 +70,13 @@ class MmieAccumAmDiagGmm { ...@@ -68,10 +70,13 @@ class MmieAccumAmDiagGmm {
/// for computing the maximum-likelihood estimates of the parameters of /// for computing the maximum-likelihood estimates of the parameters of
/// an acoustic model that uses diagonal Gaussian mixture models as emission densities. /// an acoustic model that uses diagonal Gaussian mixture models as emission densities.
void MmieAmDiagGmmUpdate(const MmieDiagGmmOptions &config, void MmieAmDiagGmmUpdate(const MmieDiagGmmOptions &config,
const MmieAccumAmDiagGmm &mmieamdiaggmm_acc, const MmieAccumAmDiagGmm &mmieamdiaggmm_acc,
GmmFlagsType flags, GmmFlagsType flags,
AmDiagGmm *am_gmm, BaseFloat *obj_change_out, AmDiagGmm *am_gmm,
BaseFloat *count_out); BaseFloat *auxf_change_gauss,
BaseFloat *auxf_change_weight,
BaseFloat *count_out,
int32 *num_floored_out);
} // End namespace kaldi } // End namespace kaldi
......
...@@ -173,7 +173,7 @@ void UnitTestEstimateMmieDiagGmm() { ...@@ -173,7 +173,7 @@ void UnitTestEstimateMmieDiagGmm() {
mmie_gmm.SubtractAccumulatorsISmoothing(num, den, config); mmie_gmm.SubtractAccumulatorsISmoothing(num, den, config);
BaseFloat obj, count; BaseFloat auxf_gauss, auxf_weight, count;
//Vector<double> mean_hlp(dim); //Vector<double> mean_hlp(dim);
//mean_hlp.CopyFromVec(gmm->means_invvars().Row(0)); //mean_hlp.CopyFromVec(gmm->means_invvars().Row(0));
//std::cout << "MEANX: " << mean_hlp << '\n'; //std::cout << "MEANX: " << mean_hlp << '\n';
...@@ -187,8 +187,9 @@ void UnitTestEstimateMmieDiagGmm() { ...@@ -187,8 +187,9 @@ void UnitTestEstimateMmieDiagGmm() {
Input ki("tmp_stats", &binary_in); Input ki("tmp_stats", &binary_in);
mmie_gmm.Read(ki.Stream(), binary_in, false); // false = not adding. mmie_gmm.Read(ki.Stream(), binary_in, false); // false = not adding.
int32 num_floored;
mmie_gmm.Update(config, flags, gmm, &obj, &count); mmie_gmm.Update(config, flags, gmm, &auxf_gauss, &auxf_weight, &count,
&num_floored);
//mean_hlp.CopyFromVec(gmm->means_invvars().Row(0)); //mean_hlp.CopyFromVec(gmm->means_invvars().Row(0));
//std::cout << "MEANY: " << mean_hlp << '\n'; //std::cout << "MEANY: " << mean_hlp << '\n';
std::cout << "MEANY: " << gmm->weights() << '\n'; std::cout << "MEANY: " << gmm->weights() << '\n';
...@@ -214,4 +215,4 @@ int main() { ...@@ -214,4 +215,4 @@ int main() {
kaldi::UnitTestEstimateMmieDiagGmm(); kaldi::UnitTestEstimateMmieDiagGmm();
} }
std::cout << "Test OK.\n"; std::cout << "Test OK.\n";
} }
\ No newline at end of file
...@@ -156,10 +156,11 @@ void MmieAccumDiagGmm::Scale(BaseFloat f, GmmFlagsType flags) { ...@@ -156,10 +156,11 @@ void MmieAccumDiagGmm::Scale(BaseFloat f, GmmFlagsType flags) {
} }
void MmieAccumDiagGmm::SubtractAccumulatorsISmoothing(const AccumDiagGmm& num_acc, void MmieAccumDiagGmm::SubtractAccumulatorsISmoothing(
const AccumDiagGmm& den_acc, const AccumDiagGmm& num_acc,
const MmieDiagGmmOptions& opts){ const AccumDiagGmm& den_acc,
const MmieDiagGmmOptions& opts){
//KALDI_ASSERT(num_acc.NumGauss() == den_acc.NumGauss && num_acc.Dim() == den_acc.Dim()); //KALDI_ASSERT(num_acc.NumGauss() == den_acc.NumGauss && num_acc.Dim() == den_acc.Dim());
//std::cout << "NumGauss: " << num_acc.NumGauss() << " " << den_acc.NumGauss() << " " << num_comp_ << '\n'; //std::cout << "NumGauss: " << num_acc.NumGauss() << " " << den_acc.NumGauss() << " " << num_comp_ << '\n';
KALDI_ASSERT(num_acc.NumGauss() == num_comp_ && num_acc.Dim() == dim_); KALDI_ASSERT(num_acc.NumGauss() == num_comp_ && num_acc.Dim() == dim_);
...@@ -169,173 +170,204 @@ void MmieAccumDiagGmm::SubtractAccumulatorsISmoothing(const AccumDiagGmm& num_ac ...@@ -169,173 +170,204 @@ void MmieAccumDiagGmm::SubtractAccumulatorsISmoothing(const AccumDiagGmm& num_ac
// no subracting occs, just copy them to local vars // no subracting occs, just copy them to local vars
num_occupancy_.CopyFromVec(num_acc.occupancy()); num_occupancy_.CopyFromVec(num_acc.occupancy());
den_occupancy_.CopyFromVec(den_acc.occupancy());