author | Dan Povey <dpovey@gmail.com> | |
Tue, 20 Dec 2011 00:09:48 +0000 (00:09 +0000) | ||
committer | Dan Povey <dpovey@gmail.com> | |
Tue, 20 Dec 2011 00:09:48 +0000 (00:09 +0000) |
31 files changed:
diff --cc egs/rm/s3/run.sh
index 7646f495ef852d2c58c3b82a65d4807d45a64975,3b4db1d5d4f3e15a775da32ddf58e59f13ef505e..fe22890ae4b301dab7c473f4b914363b42c69e5c
--- 1/egs/rm/s3/run.sh
--- 2/egs/rm/s3/run.sh
+++ b/egs/rm/s3/run.sh
# Some system combination experiments (just compose lattices).
local/decode_combine.sh steps/decode_combine.sh exp/tri1/decode exp/tri2a/decode exp/combine_1_2a/decode
local/decode_combine.sh steps/decode_combine.sh exp/sgmm4f/decode/ exp/tri3d/decode exp/combine_sgmm4f_tri3d/decode
+local/decode_combine.sh steps/decode_combine.sh exp/sgmm4f/decode/ exp/tri4a/decode exp/combine_sgmm4f_tri4a/decode
-for x in exp/*/decode*; do [ -d $x ] && grep WER $x/wer_* | scripts/best_wer.sh; done
-exp/combine_1_2a/decode/wer_7:%WER 3.399027 [ 426 / 12533, 55 ins, 94 del, 277 sub ]
-exp/combine_sgmm4f_tri3d/decode/wer_5:%WER 1.731429 [ 217 / 12533, 30 ins, 43 del, 144 sub ]
-exp/mono/decode/wer_6:%WER 10.340701 [ 1296 / 12533, 95 ins, 391 del, 810 sub ]
-exp/sgmm3d/decode/wer_5:%WER 2.267284 [ 284 / 12526, 38 ins, 51 del, 195 sub ]
-exp/sgmm3e/decode/wer_6:%WER 2.122397 [ 266 / 12533, 37 ins, 51 del, 178 sub ]
-exp/sgmm4f/decode/wer_4:%WER 1.795261 [ 225 / 12533, 45 ins, 37 del, 143 sub ]
-exp/sgmm4f/decode_fmllr/wer_5:%WER 1.771324 [ 222 / 12533, 38 ins, 42 del, 142 sub ]
-exp/tri1/decode/wer_6:%WER 3.566584 [ 447 / 12533, 74 ins, 88 del, 285 sub ]
-exp/tri2a/decode/wer_7:%WER 3.518711 [ 441 / 12533, 57 ins, 91 del, 293 sub ]
-exp/tri2b/decode/wer_9:%WER 3.614458 [ 453 / 12533, 59 ins, 111 del, 283 sub ]
-exp/tri2c/decode/wer_6:%WER 2.833653 [ 355 / 12528, 54 ins, 71 del, 230 sub ]
-exp/tri3d/decode/wer_7:%WER 2.489428 [ 312 / 12533, 43 ins, 63 del, 206 sub ]
-exp/tri4d/decode/wer_7:%WER 2.649007 [ 332 / 12533, 53 ins, 67 del, 212 sub ]
+### From here is semi-continuous experiments. ###
+### Note: this is not yet working. Do not run this. ***
+echo "semi-continuous code not finalized" && exit 1;
+
+# Train a classic semi-continuous model using {diag,full} densities
+# the numeric parameters following exp/tri1-semi are:
+# number of gaussians, something like 4096 for diag, 2048 for full
+# number of tree leaves
+# type of suff-stats interpolation (0 regular, 1 preserves counts)
+# rho-stats, rho value for the smoothing of the statistics (0 for no smoothing)
+# rho-iters, rho value to interpolate the parameters with the last iteration (0 for no interpolation)
+
-
+steps/train_ubm_lda_etc.sh 1024 data/train data/lang exp/tri2b_ali exp/ubm3f
+steps/train_lda_mllt_semi_full.sh data/train data/lang exp/tri2b_ali exp/ubm3f/final.ubm exp/tiedfull3f 2500 1 35 0.2
+
+steps/train_semi_full.sh data/train data/lang exp/tri1_ali exp/tri1_semi 1024 2500 1 35 0.2
+local/decode.sh steps/decode_tied_full.sh exp/tri1_semi/decode
+
+# 2level full-cov training...
+steps/train_2lvl.sh data/train data/lang exp/tri1_ali exp/tri1_2lvl 100 1024 1800 0 0 0
+
+# Train a 2-lvl semi-continuous model using {diag,full} densities
+# the numeric parameters following exp/tri1_2lvl are:
+# number of codebooks, typically 1-3 times number of phones, the more, the faster
+# total number of gaussians, something like 2048 for full, 4096 for diag
+# number of tree leaves
+# type of suff-stats interpolation (0 regular, 1 preserves counts)
+# rho-stats, rho value for the smoothing of the statistics (0 for no smoothing)
+# rho-iters, rho value to interpolate the parameters with the last iteration (0 for no interpolation)
+steps/train_2lvl_full.sh data/train data/lang exp/tri1_ali exp/tri1_2lvl 104 2048 2500 0 1 10 0
+local/decode.sh steps/decode_tied_full.sh exp/tri1_2lvl/decode
-local/decode_combine.sh steps/decode_combine.sh exp/tri1/decode exp/tri2a/decode exp/combine_tri3d_sgmm4f
+# note on new gselect:
+# gmm-gselect --n=50 "sgmm-write-ubm exp/sgmm3d/final.mdl - | fgmm-global-to-gmm - - |" 'ark,s,cs:apply-cmvn --norm-vars=false --utt2spk=ark:data/train/utt2spk ark:exp/tri2b_ali/cmvn.ark scp:data/train/feats.scp ark:- | splice-feats ark:- ark:- | transform-feats exp/sgmm3d/final.mat ark:- ark:- |' ark,t:- | fgmm-gselect --n=15 "sgmm-write-ubm exp/sgmm3d/final.mdl -|" 'ark,s,cs:apply-cmvn --norm-vars=false --utt2spk=ark:data/train/utt2spk ark:exp/tri2b_ali/cmvn.ark scp:data/train/feats.scp ark:- | splice-feats ark:- ark:- | transform-feats exp/sgmm3d/final.mat ark:- ark:- |' ark:- ark,t:- | head -1 > f2
diff --cc egs/rm/s3/steps/train_lda_etc_mmi.sh
index d96ef676663ae82a6a7f9a824f7e9e4ea02994d2,234c67bae0de110c227230fb9590ef3182be456e..5a95930e280ec2a2f1a78113e709ed28ddc793b5
# alignments, models and transforms from an LDA+MLLT system:
# ali, final.mdl, final.mat
-b=0 # boosting constant, for boosted MMI.
+boost=0 # boosting constant, for boosted MMI.
++tau=100 # Tau value.
++
if [ $1 == "--boost" ]; then # e.g. "--boost 0.05"
shift;
- b=$1;
+ boost=$1;
shift;
fi
# Get numerator stats...
gmm-acc-stats-ali $dir/$x.mdl "$feats" ark:$alidir/ali $dir/num_acc.$x.acc \
2>$dir/acc_num.$x.log || exit 1;
-- # Update.
-- gmm-est-mmi $dir/$x.mdl $dir/num_acc.$x.acc $dir/den_acc.$x.acc $dir/$[$x+1].mdl \
++
++ ( gmm-est-gaussians-ebw $dir/$x.mdl "gmm-ismooth-stats --tau=$tau $dir/num_acc.$x.acc $dir/num_acc.$x.acc -|" \
++ $dir/den_acc.$x.acc - | \
++ gmm-est-weights-ebw - $dir/num_acc.$x.acc $dir/den_acc.$x.acc $dir/$[$x+1].mdl ) \
2>$dir/update.$x.log || exit 1;
den=`grep Overall $dir/acc_den.$x.log | grep lattice-to-post | awk '{print $7}'`
num=`grep Overall $dir/acc_num.$x.log | grep gmm-acc-stats-ali | awk '{print $11}'`
diff=`perl -e "print ($num * $acwt - $den);"`
-- impr=`grep Overall $dir/update.$x.log | awk '{print $10;}'`
++ impr=`grep Overall $dir/update.$x.log | head -1 | awk '{print $10;}'`
impr=`perl -e "print ($impr * $acwt);"` # auxf impr normalized by multiplying by
# kappa, so it's comparable to an objective-function change.
echo On iter $x, objf was $diff, auxf improvement was $impr | tee $dir/objf.$x.log
diff --cc egs/rm/s3/steps/train_semi_full.sh
Simple merge
diff --cc egs/wsj/s3/run.sh
index 6f5ebbf810c9c5f475d69f00fcb1c35d6c19deb5,b71bcba8c0fa5d3fc4b06393022a3c5808b26102..589a218046fd6ee0effcf5ec316d1987d78712a8
--- 1/egs/wsj/s3/run.sh
--- 2/egs/wsj/s3/run.sh
+++ b/egs/wsj/s3/run.sh
steps/align_lda_mllt.sh --num-jobs 10 --cmd "$train_cmd" \
--use-graphs data/train_si84 data/lang exp/tri2b exp/tri2b_ali_si84
+# Train and test MMI (and boosted MMI) on tri2b system.
+steps/make_denlats_lda_etc.sh --num-jobs 10 --cmd "$train_cmd" \
+ data/train_si84 data/lang exp/tri2b_ali_si84 exp/tri2b_denlats_si84
+steps/train_lda_etc_mmi.sh --num-jobs 10 --cmd "$train_cmd" \
+ data/train_si84 data/lang exp/tri2b_ali_si84 exp/tri2b_denlats_si84 exp/tri2b exp/tri2b_mmi
+scripts/decode.sh --cmd "$decode_cmd" steps/decode_lda_mllt.sh exp/tri2b/graph_tgpr data/test_eval92 exp/tri2b_mmi/decode_tgpr_eval92
+steps/train_lda_etc_mmi.sh --num-jobs 10 --boost 0.1 --cmd "$train_cmd" \
+ data/train_si84 data/lang exp/tri2b_ali_si84 exp/tri2b_denlats_si84 exp/tri2b exp/tri2b_mmi_b0.1
+scripts/decode.sh --cmd "$decode_cmd" steps/decode_lda_mllt.sh exp/tri2b/graph_tgpr data/test_eval92 exp/tri2b_mmi_b0.1/decode_tgpr_eval92
+
++steps/train_lda_etc_mce.sh data/train_si84 data/lang \
++ exp/tri2b_ali_si84 exp/tri2b_mce
++
++ scripts/decode.sh --num-jobs 10 --cmd "$train_cmd" steps/decode_lda_mllt.sh \
++ exp/tri3a/graph_tgpr data/test_eval92 exp/tri2b_mce/decode_tgpr_eval92
++
# Train LDA+ET system.
steps/train_lda_et.sh --num-jobs 10 --cmd "$train_cmd" \
2500 15000 data/train_si84 data/lang exp/tri1_ali_si84 exp/tri2c
diff --cc egs/wsj/s3/steps/train_lda_etc_mce.sh
index 0000000000000000000000000000000000000000,41cbb5313f68811610ecb07703d328ac4cdbcd99..f3db7dd65cb9812c6250cdc03e45f6172f747e3b
mode 000000,100755..100755
mode 000000,100755..100755
--- /dev/null
-# Copyright 2010-2011 Microsoft Corporation
+ #!/bin/bash
- echo "Usage: steps/train_lda_etc_mmi.sh <data-dir> <lang-dir> <ali-dir> <exp-dir>"
- echo " e.g.: steps/train_lda_etc_mmi.sh data/train data/lang exp/tri3d_ali exp/tri4a"
++# Copyright 2011 Chao Weng
++# This script does not work in its current version, due to changes
++# made in the binaries. It will be updated in the next version or two.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ # KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+ # WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+ # MERCHANTABLITY OR NON-INFRINGEMENT.
+ # See the Apache 2 License for the specific language governing permissions and
+ # limitations under the License.
+
+ # To be run from ..
+ # This directory does MMI model training, starting from trained
+ # models. The models must be trained on raw features plus
+ # cepstral mean normalization plus splice-9-frames, an LDA+[something]
+ # transform, then possibly speaker-specific affine transforms
+ # (fMLLR/CMLLR). This script works out from the alignment directory
+ # whether you trained with some kind of speaker-specific transform.
+ #
+ # This training run starts from an initial directory that has
+ # alignments, models and transforms from an LDA+MLLT system:
+ # ali, final.mdl, final.mat
+
+ b=0 # boosting constant, for boosted MMI.
+ if [ $1 == "--boost" ]; then # e.g. "--boost 0.05"
+ shift;
+ b=$1;
+ shift;
+ fi
+
+ if [ $# != 4 ]; then
- # when you don't want to compute fMLLR transforms with the MMI-trained model.
++ echo "Usage: steps/train_lda_etc_mce.sh <data-dir> <lang-dir> <ali-dir> <exp-dir>"
++ echo " e.g.: steps/train_lda_etc_mce.sh data/train data/lang exp/tri3d_ali exp/tri4a"
+ exit 1;
+ fi
+
+ if [ -f path.sh ]; then . path.sh; fi
+
+ data=$1
+ lang=$2
+ alidir=$3
+ dir=$4
+
+ num_iters=4
+ acwt=0.1
+ beam=20
+ latticebeam=10
+ scale_opts="--transition-scale=1.0 --acoustic-scale=0.1 --self-loop-scale=0.1"
+ silphonelist=`cat $lang/silphones.csl`
+
+ oov_sym="<SPOKEN_NOISE>"
+ mkdir -p $dir
+ cp $alidir/tree $alidir/final.mat $dir # Will use the same tree and transforms as in the baseline.
+ cp $alidir/final.mdl $dir/0.mdl
+
+ if [ -f $alidir/final.alimdl ]; then
+ cp $alidir/final.alimdl $dir/final.alimdl
+ cp $alidir/final.mdl $dir/final.adaptmdl # This model used by decoding scripts,
-case 0 in #goto here
- 1)
++ # when you don't want to compute fMLLR transforms with the MCE-trained model.
+ fi
+
+ scripts/split_scp.pl $data/feats.scp $dir/feats{0,1,2,3}.scp
+
+ feats="ark:apply-cmvn --norm-vars=false --utt2spk=ark:$data/utt2spk ark:$alidir/cmvn scp:$data/feats.scp ark:- | splice-feats ark:- ark:- | transform-feats $dir/final.mat ark:- ark:- |"
+
+
++#case 0 in #goto here
++# 1)
+
+
+
+
+ for n in 0 1 2 3; do
+ featspart[$n]="ark:apply-cmvn --norm-vars=false --utt2spk=ark:$data/utt2spk ark:$alidir/$n.cmvn scp:$dir/feats$n.scp ark:- | splice-feats ark:- ark:- | transform-feats $dir/final.mat ark:- ark:- |"
+ done
+
+ if [ -f $alidir/trans.ark ]; then
+ echo "Running with speaker transforms $alidir/trans.ark"
+ feats="$feats transform-feats --utt2spk=ark:$data/utt2spk ark:$alidir/trans.ark ark:- ark:- |"
+ for n in 0 1 2 3; do
+ featspart[$n]="${featspart[$n]} transform-feats --utt2spk=ark:$data/utt2spk ark:$alidir/trans.ark ark:- ark:- |"
+ done
+ fi
+
+ # compute integer form of transcripts.
+ scripts/sym2int.pl --map-oov $oov_sym --ignore-first-field $lang/words.txt < $data/text > $dir/train.tra \
+ || exit 1;
+
+ cp -r $lang $dir/lang
+
+ # Compute grammar FST which corresponds to unigram decoding graph.
+ cat $dir/train.tra | awk '{for(n=2;n<=NF;n++){ printf("%s ", $n); } printf("\n"); }' | \
+ scripts/make_unigram_grammar.pl | fstcompile > $dir/lang/G.fst \
+ || exit 1;
+
+ # mkgraph.sh expects a whole directory "lang", so put everything in one directory...
+ # it gets L_disambig.fst and G.fst (among other things) from $dir/lang, and
+ # final.mdl from $alidir; the output HCLG.fst goes in $dir/graph.
+
+ scripts/mkgraph.sh $dir/lang $alidir $dir/dgraph || exit 1;
+
+
+
+ echo "Making denominator lattices"
+
+
+ rm $dir/.error 2>/dev/null
+ for n in 0 1 2 3; do
+ gmm-latgen-simple --beam=$beam --lattice-beam=$latticebeam --acoustic-scale=$acwt \
+ --word-symbol-table=$lang/words.txt \
+ $alidir/final.mdl $dir/dgraph/HCLG.fst "${featspart[$n]}" \
+ "ark:|lattice-boost-ali --b=$b --silence-phones=$silphonelist $alidir/final.mdl ark:- ark,s,cs:$alidir/$n.ali ark:- | gzip -c >$dir/lat$n.gz" \
+ 2>$dir/decode_den.$n.log || touch $dir/.error &
+ done
+ wait
+ if [ -f $dir/.error ]; then
+ echo "Error creating denominator lattices"
+ exit 1;
+ fi
+
+
+
+ # test generating numerator lattice(contains LM)
+
+ fsttablecompose $dir/lang/L.fst $dir/lang/G.fst > $dir/LG.fst
+ for n in 0 1 2 3; do
+ compile-train-graphs $dir/tree $dir/0.mdl $dir/LG.fst "ark:scripts/sym2int.pl --map-oov \"$oov_sym\" --ignore-first-field $lang/words.txt < $data/split4/$n/text |" ark:$dir/fsts$n \
+ 2>$dir/gen_numLG.$x.log || exit 1;
+ done
+
+ cat $dir/fsts0 $dir/fsts1 $dir/fsts2 $dir/fsts3 > $dir/fsts
+ # test remove numerator lattice
+ # No need to create "numerator" alignments/lattices: we just use the
+ # alignments in $alidir.
+
+ for n in 0 1 2 3; do
+ tra="ark:scripts/sym2int.pl --map-oov \"$oov_sym\" --ignore-first-field $lang/words.txt < $data/split4/$n/text |";
+ lattice-difference "$tra" "ark:gunzip -c $dir/lat$n.gz|" "ark:|lattice-boost-ali --b=$b --silence-phones=$silphonelist $alidir/final.mdl ark:- ark,s,cs:$alidir/$n.ali ark:- | gzip -c >$dir/mcelat$n.gz" \
+ 2>$dir/lattice_difference$n.log || exit 1;
+ done
+
+ exit 1;
+
+ ;; #here:
+ esac
+ #
+
+ echo "Note: ignore absolute offsets in the objective function values"
+ echo "This is caused by not having LM, lexicon or transition-probs in numerator"
+ x=0;
+ while [ $x -lt $num_iters ]; do
+ echo "Iteration $x: getting denominator stats."
+ # Get denominator stats...
+ if [ $x -eq 0 ]; then
+ gmm-align-compiled $scale_opts --beam=10 --retry-beam=40 $dir/$x.mdl ark:$dir/fsts "$feats" "ark:$dir/ali" ark:$dir/num$x.score \
+ 2>$dir/gmm_align_compiled.$x.log || exit 1;
+ lattice-to-post --acoustic-scale=$acwt "ark:gunzip -c $dir/mcelat?.gz|" ark:$dir/den$x.post ark:$dir/den$x.score \
+ 2>$dir/lattice_to_post.$x.log || exit 1;
+ ali-to-post ark:$dir/ali ark:$dir/num.post \
+ 2>$dir/ali_to_post.log || exit 1;
+ compute-mce-scale --mce-alpha=0.01 --mce-beta=0.0 ark:$dir/num$x.score ark:$dir/den$x.score ark:$dir/post$x.scale \
+ 2>$dir/compute_mce_scale$x.log || exit 1;
+ scale-post ark:$dir/post$x.scale ark:$dir/num.post ark:$dir/scaled_num$x.post \
+ 2>$dir/scale_post_num$x.log || exit 1;
+ scale-post ark:$dir/post$x.scale ark:$dir/den$x.post ark:$dir/scaled_den$x.post \
+ 2>$dir/scale_post_den$x.log || exit 1;
+ gmm-acc-stats $dir/$x.mdl "$feats" ark:$dir/scaled_den$x.post $dir/den_acc.$x.acc \
+ 2>$dir/acc_den.$x.log || exit 1;
+
+ else # Need to recompute acoustic likelihoods...
+ gmm-align-compiled $scale_opts --beam=10 --retry-beam=40 $dir/$x.mdl ark:$dir/fsts "$feats" "ark:$dir/ali" ark:$dir/num$x.score \
+ 2>$dir/gen_align_compiled.$x.log || exit 1;
+ ( gmm-rescore-lattice $dir/$x.mdl "ark:gunzip -c $dir/mcelat?.gz|" "$feats" ark:- | \
+ lattice-to-post --acoustic-scale=$acwt ark:- ark:$dir/den$x.post ark:$dir/den$x.score ) \
+ 2>$dir/lattice_to_post.$x.log || exit 1;
+ compute-mce-scale --mce-alpha=0.01 --mce-beta=0.0 ark:$dir/num$x.score ark:$dir/den$x.score ark:$dir/post$x.scale \
+ 2>$dir/compute_mce_scale$x.log || exit 1;
+ scale-post ark:$dir/post$x.scale ark:$dir/num.post ark:$dir/scaled_num$x.post \
+ 2>$dir/scale_post_num$x.log || exit 1;
+ scale-post ark:$dir/post$x.scale ark:$dir/den$x.post ark:$dir/scaled_den$x.post \
+ 2>$dir/scale_post_den$x.log || exit 1;
+ gmm-acc-stats $dir/$x.mdl "$feats" ark:$dir/scaled_den$x.post $dir/den_acc.$x.acc \
+ 2>$dir/acc_den.$x.log || exit 1;
+ fi
+ echo "Iteration $x: getting numerator stats."
+ # Get numerator stats...
+ gmm-acc-stats $dir/$x.mdl "$feats" ark:$dir/scaled_num$x.post $dir/num_acc.$x.acc \
+ 2>$dir/acc_num.$x.log || exit 1;
+ gmm-acc-stats $dir/$x.mdl "$feats" ark:$dir/num.post $dir/i_smooth_acc.$x.acc \
+ 2>$dir/acc_i_smooth.$x.log || exit 1;
+ # Update.
+ gmm-est-mmi --i-smooth-stats=$dir/i_smooth_acc.$x.acc --i-smooth-tau=0.0 $dir/$x.mdl $dir/num_acc.$x.acc $dir/den_acc.$x.acc $dir/$[$x+1].mdl \
+ 2>$dir/update.$x.log || exit 1;
+
+ den=`grep Overall $dir/acc_den.$x.log | grep lattice-to-post | awk '{print $7}'`
+ num=`grep Overall $dir/acc_num.$x.log | grep gmm-acc-stats-ali | awk '{print $11}'`
+ diff=`perl -e "print ($num * $acwt - $den);"`
+ impr=`grep Overall $dir/update.$x.log | awk '{print $10;}'`
+ impr=`perl -e "print ($impr * $acwt);"` # auxf impr normalized by multiplying by
+ # kappa, so it's comparable to an objective-function change.
+ echo On iter $x, objf was $diff, auxf improvement was $impr | tee $dir/objf.$x.log
+ x=$[$x+1]
+ done
+
+ # Just copy the source-dir's occs, in case we later need them for something...
+ cp $alidir/final.occs $dir
+ ( cd $dir; ln -s $x.mdl final.mdl )
+
+
+ echo Done
diff --cc egs/wsj/s3/steps/train_lda_etc_mmi.sh
index 9becdf7919eb1598e47801e9a1fceec7c0350c91,0000000000000000000000000000000000000000..f29a5b8f08bd30557cb72468112453b6d246acaa
mode 100755,000000..100755
mode 100755,000000..100755
--- /dev/null
- gmm-est-mmi $cur_mdl $dir/num_acc.$x.acc $dir/den_acc.$x.acc $dir/$[$x+1].mdl \
- || exit 1;
+#!/bin/bash
+# Copyright 2010-2011 Microsoft Corporation Arnab Ghoshal
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+# WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+# MERCHANTABLITY OR NON-INFRINGEMENT.
+# See the Apache 2 License for the specific language governing permissions and
+# limitations under the License.
+
+# To be run from ..
+# This script does MMI training
+# This script trains a model on top of LDA + [something] features, where
+# [something] may be MLLT, or ET, or MLLT + SAT. Any speaker-specific
+# transforms are expected to be located in the alignment directory.
+# This script never re-estimates any transforms, it just does model
+# training. To make this faster, it initializes the model from the
+# old system's model, i.e. for each p.d.f., it takes the best-match pdf
+# from the old system (based on overlap of tree-stats counts), and
+# uses that GMM to initialize the current GMM.
+
+niters=4
+nj=4
+boost=0.0
++tau=100
+cmd=scripts/run.pl
+acwt=0.1
+stage=0
+
+for x in `seq 8`; do
+ if [ $1 == "--num-jobs" ]; then
+ shift; nj=$1; shift
+ fi
+ if [ $1 == "--num-iters" ]; then
+ shift; niters=$1; shift
+ fi
+ if [ $1 == "--boost" ]; then
+ shift; boost=$1; shift
+ fi
+ if [ $1 == "--cmd" ]; then
+ shift; cmd=$1; shift
+ [ -z "$cmd" ] && echo Empty argument to --cmd option && exit 1;
+ fi
+ if [ $1 == "--acwt" ]; then
+ shift; acwt=$1; shift
+ fi
+ if [ $1 == "--stage" ]; then
+ shift; stage=$1; shift
+ fi
+done
+
+if [ $# != 6 ]; then
+ echo "Usage: steps/train_lda_etc_mmi.sh <data-dir> <lang-dir> <ali-dir> <denlat-dir> <model-dir> <exp-dir>"
+ echo " e.g.: steps/train_lda_etc_mmi.sh data/train_si84 data/lang exp/tri2b_ali_si84 exp/tri2b_denlats_si84 exp/tri2b exp/tri2b_mmi"
+ exit 1;
+fi
+
+if [ -f path.sh ]; then . path.sh; fi
+
+data=$1
+lang=$2
+alidir=$3
+denlatdir=$4
+srcdir=$5 # may be same model as in alidir, but may not be, e.g.
+ # if you want to test MMI with different #iters.
+dir=$6
+silphonelist=`cat $lang/silphones.csl`
+mkdir -p $dir/log
+
+if [ ! -f $srcdir/final.mdl -o ! -f $srcdir/final.mat ]; then
+ echo "Error: alignment dir $alidir does not contain one of final.mdl or final.mat"
+ exit 1;
+fi
+cp $srcdir/final.mat $srcdir/tree $dir
+
+n=`get_splits.pl $nj | awk '{print $1}'`
+if [ -f $alidir/$n.trans ]; then
+ use_trans=true
+ echo Using transforms from directory $alidir
+else
+ echo No transforms present in alignment directory: assuming speaker independent.
+ use_trans=false
+fi
+
+for n in `get_splits.pl $nj`; do
+ featspart[$n]="ark,s,cs:apply-cmvn --norm-vars=false --utt2spk=ark:$data/split$nj/$n/utt2spk ark:$alidir/$n.cmvn scp:$data/split$nj/$n/feats.scp ark:- | splice-feats ark:- ark:- | transform-feats $alidir/final.mat ark:- ark:- |"
+ $use_trans && featspart[$n]="${featspart[$n]} transform-feats --utt2spk=ark:$data/split$nj/$n/utt2spk ark:$alidir/$n.trans ark:- ark:- |"
+
+ [ ! -f $denlatdir/lat.$n.gz ] && echo No such file $denlatdir/lat.$n.gz && exit 1;
+ latspart[$n]="ark:gunzip -c $denlatdir/lat.$n.gz|"
+ # note: in next line, doesn't matter which model we use, it's only used to map to phones.
+ [ $boost != "0.0" -a $boost != "0" ] && latspart[$n]="${latspart[$n]} lattice-boost-ali --b=$boost --silence-phones=$silphonelist $alidir/final.mdl ark:- 'ark,s,cs:gunzip -c $alidir/$n.ali.gz|' ark:- |"
+done
+
+rm $dir/.error 2>/dev/null
+cur_mdl=$srcdir/final.mdl
+x=0
+while [ $x -lt $niters ]; do
+ echo "Iteration $x: getting denominator stats."
+ # Get denominator stats... For simplicity we rescore the lattice
+ # on all iterations, even though it shouldn't be necessary on the zeroth
+ # (but we want this script to work even if $srcdir doesn't contain the
+ # model used to generate the lattice).
+ if [ $stage -le $x ]; then
+ for n in `get_splits.pl $nj`; do
+ $cmd $dir/log/acc_den.$x.$n.log \
+ gmm-rescore-lattice $cur_mdl "${latspart[$n]}" "${featspart[$n]}" ark:- \| \
+ lattice-to-post --acoustic-scale=$acwt ark:- ark:- \| \
+ gmm-acc-stats $cur_mdl "${featspart[$n]}" ark:- $dir/den_acc.$x.$n.acc \
+ || touch $dir/.error &
+ done
+ wait
+ [ -f $dir/.error ] && echo Error accumulating den stats on iter $x && exit 1;
+ $cmd $dir/log/den_acc_sum.$x.log \
+ gmm-sum-accs $dir/den_acc.$x.acc $dir/den_acc.$x.*.acc || exit 1;
+ rm $dir/den_acc.$x.*.acc
+
+ echo "Iteration $x: getting numerator stats."
+ for n in `get_splits.pl $nj`; do
+ $cmd $dir/log/acc_num.$x.$n.log \
+ gmm-acc-stats-ali $cur_mdl "${featspart[$n]}" "ark:gunzip -c $alidir/$n.ali.gz|" \
+ $dir/num_acc.$x.$n.acc || touch $dir/.error &
+ done
+ wait;
+ [ -f $dir/.error ] && echo Error accumulating num stats on iter $x && exit 1;
+ $cmd $dir/log/num_acc_sum.$x.log \
+ gmm-sum-accs $dir/num_acc.$x.acc $dir/num_acc.$x.*.acc || exit 1;
+ rm $dir/num_acc.$x.*.acc
+
+ $cmd $dir/log/update.$x.log \
- impr=`grep Overall $dir/log/update.$x.log | awk '{print $10;}'`
++ gmm-est-gaussians-ebw $cur_mdl "gmm-ismooth-stats --tau=$tau $dir/num_acc.$x.acc $dir/num_acc.$x.acc -|" \
++ $dir/den_acc.$x.acc - \| \
++ gmm-est-weights-ebw - $dir/num_acc.$x.acc $dir/den_acc.$x.acc $dir/$[$x+1].mdl || exit 1;
+ else
+ echo "not doing this iteration because --stage=$stage"
+ fi
+ cur_mdl=$dir/$[$x+1].mdl
+
+ # Some diagnostics
+ den=`grep Overall $dir/log/acc_den.$x.*.log | grep lattice-to-post | awk '{p+=$7*$9; nf+=$9;} END{print p/nf;}'`
+ num=`grep Overall $dir/log/acc_num.$x.*.log | grep gmm-acc-stats-ali | awk '{p+=$11*$13; nf+=$13;} END{print p/nf}'`
+ diff=`perl -e "print ($num * $acwt - $den);"`
++ impr=`grep Overall $dir/log/update.$x.log | head -1 | awk '{print $10;}'`
+ impr=`perl -e "print ($impr * $acwt);"` # auxf impr normalized by multiplying by
+ # kappa, so it's comparable to an objective-function change.
+ echo On iter $x, objf was $diff, auxf improvement was $impr | tee $dir/objf.$x.log
+
+ x=$[$x+1]
+done
+
+echo "Succeeded with $niters iterations of MMI training (boosting factor = $boost)"
+
+( cd $dir; ln -s $x.mdl final.mdl )
diff --cc src/bin/Makefile
index 13f0c19fddcc6a37bedb8a241c6baf02cdc641a1,2218ab3e0bb789a685aa39f602b34f79b54e76db..dff75f187285beb2b436cf2e405e1f32f44e1261
--- 1/src/bin/Makefile
--- 2/src/bin/Makefile
+++ b/src/bin/Makefile
ali-to-phones ali-to-post weight-silence-post acc-lda est-lda \
ali-to-pdf est-mllt build-tree build-tree-two-level decode-faster \
decode-faster-mapped scale-vecs copy-transition-model rand-prune-post \
- phones-to-prons prons-to-wordali copy-gselect copy-tree
- phones-to-prons prons-to-wordali scale-post
++ phones-to-prons prons-to-wordali copy-gselect copy-tree scale-post
OBJFILES =
diff --cc src/gmm/Makefile
index 4b93bc7d96b112f3a327fb0299c96aa37382121b,4b93bc7d96b112f3a327fb0299c96aa37382121b..46b69888051be91fc3d8395892f70eaf0b03f88f
--- 1/src/gmm/Makefile
--- 2/src/gmm/Makefile
+++ b/src/gmm/Makefile
include ../kaldi.mk
TESTFILES = diag-gmm-test mle-diag-gmm-test full-gmm-test mle-full-gmm-test \
-- am-diag-gmm-test mmie-diag-gmm-test mmie-am-diag-gmm-test
++ am-diag-gmm-test ebw-diag-gmm-test
OBJFILES = diag-gmm.o diag-gmm-normal.o mle-diag-gmm.o am-diag-gmm.o mle-am-diag-gmm.o \
full-gmm.o full-gmm-normal.o mle-full-gmm.o fmpe-am-diag-gmm.o model-common.o \
-- model-test-common.o ebw-diag-gmm.o mmie-diag-gmm.o mmie-am-diag-gmm.o
++ model-test-common.o ebw-diag-gmm.o
LIBFILE = kaldi-gmm.a
diff --cc src/gmm/ebw-diag-gmm-test.cc
index 2cadb4c280fa32bf4aa5ec620f730cacba5426fa,e03300108a3a4dfbb65160fcb8c39f561c81c7ee..5dedc23847e638ae866c769b177cbe6bade0cf58
--// gmm/mmie-diag-gmm-test.cc
++// gmm/ebw-diag-gmm-test.cc
// Copyright 2009-2011 Petr Motlicek
#include "gmm/diag-gmm.h"
--#include "gmm/mmie-diag-gmm.h"
++#include "gmm/ebw-diag-gmm.h"
#include "util/kaldi-io.h"
gmm->ComputeGconsts();
-- MmieAccumDiagGmm mmie_gmm;
++ EbwOptions ebw_opts;
++ EbwWeightOptions ebw_weight_opts;
-- MmieDiagGmmOptions config;
-- config.min_variance = 0.01;
-- GmmFlagsType flags = kGmmAll; // Should later try reducing this.
++ int r = rand() % 16;
++ GmmFlagsType flags = (r%2 == 0 ? kGmmMeans : 0)
++ + ((r/2)%2 == 0 ? kGmmVariances : 0)
++ + ((r/4)%2 == 0 ? kGmmWeights : 0);
++ double tau = (r/8)%2 == 0 ? 100 : 0.0;
++
++ if ((flags & kGmmVariances) && !(flags & kGmmMeans))
++ return; // Don't do this case: not supported in the update equations.
AccumDiagGmm num;
AccumDiagGmm den;
den.Resize(gmm->NumGauss(), gmm->Dim(), flags);
den.SetZero(flags);
-- mmie_gmm.Resize(gmm->NumGauss(), gmm->Dim(), flags);
-
--
-// iterate
size_t iteration = 0;
- float lastloglike = 0.0;
- int32 lastloglike_nM = 0;
-
++ double last_log_like_diff;
while (iteration < maxiterations) {
Vector<BaseFloat> featvec_num(dim);
Vector<BaseFloat> featvec_den(dim);
num.SetZero(flags);
den.Resize(gmm->NumGauss(), gmm->Dim(), flags);
den.SetZero(flags);
-- mmie_gmm.Resize(gmm->NumGauss(), gmm->Dim(), flags);
--
double loglike_num = 0.0;
double loglike_den = 0.0;
<< std::scientific << loglike_den << " number of components: "
<< gmm->NumGauss() << '\n';
--
- mmie_gmm.SubtractAccumulatorsISmoothing(num, den, config);
- mmie_gmm.SubtractAccumulatorsISmoothing(num, den, config, num);
++ double loglike_diff = loglike_num - loglike_den;
++ if (iteration > 0) {
++ KALDI_LOG << "Objective changed " << last_log_like_diff
++ << " to " << loglike_diff;
++ if (loglike_diff < last_log_like_diff)
++ KALDI_WARN << "Objective decreased (flags = "
++ << GmmFlagsToString(flags) << ", tau = " << tau << " )";
++ }
++ last_log_like_diff = loglike_diff;
++
++ AccumDiagGmm num_smoothed(num);
++ num_smoothed.SmoothStats(tau); // Apply I-smoothing.
++
BaseFloat auxf_gauss, auxf_weight, count;
-- //Vector<double> mean_hlp(dim);
-- //mean_hlp.CopyFromVec(gmm->means_invvars().Row(0));
-- //std::cout << "MEANX: " << mean_hlp << '\n';
std::cout << "MEANX: " << gmm->weights() << '\n';
-- // binary write
- {
- Output ko("tmp_stats", false);
- mmie_gmm.Write(ko.Stream(), false);
- ko.Stream().flush();
- }
- mmie_gmm.Write(Output("tmp_stats", false).Stream(), false);
-
- // binary read
- bool binary_in;
- Input ki("tmp_stats", &binary_in);
- mmie_gmm.Read(ki.Stream(), binary_in, false); // false = not adding.
-
+ int32 num_floored;
- mmie_gmm.Update(config, flags, gmm, &auxf_gauss, &auxf_weight, &count,
- &num_floored);
++ UpdateEbwDiagGmm(num_smoothed, den, flags, ebw_opts,
++ gmm, &auxf_gauss, &count, &num_floored);
+
++ if (flags & kGmmWeights) {
++ UpdateEbwWeightsDiagGmm(num, den, ebw_weight_opts, gmm, &auxf_weight,
++ &count);
++ }
+
- // binary read
- bool binary_in;
- Input ki("tmp_stats", &binary_in);
- mmie_gmm.Read(ki.Stream(), binary_in, false); // false = not adding.
-
- int32 num_floored;
- mmie_gmm.Update(config, flags, gmm, &auxf_gauss, &auxf_weight, &count,
- &num_floored);
//mean_hlp.CopyFromVec(gmm->means_invvars().Row(0));
//std::cout << "MEANY: " << mean_hlp << '\n';
-- std::cout << "MEANY: " << gmm->weights() << '\n';
++ std::cout << "MEANY: " << gmm->weights() << '\n';
if ((iteration % 3 == 1) && (gmm->NumGauss() * 2 <= maxcomponents)) {
gmm->Split(gmm->NumGauss() * 2, 0.001);
std::cout << "Ngauss, Ndim: " << gmm->NumGauss() << " " << gmm->Dim() << '\n';
--
}
--
iteration++;
}
}
int main() {
-- // repeat the test 5 times
-- for (int i = 0; i < 5; ++i) {
++ // repeat the test 20 times
++ for (int i = 0; i < 20; ++i) {
kaldi::UnitTestEstimateMmieDiagGmm();
}
std::cout << "Test OK.\n";
diff --cc src/gmm/ebw-diag-gmm.cc
index 7e9ada534796b715d5ae78f0c4bafcda62709137,7e9ada534796b715d5ae78f0c4bafcda62709137..46aa1262e688874d67feae561579426aae143779
+++ b/src/gmm/ebw-diag-gmm.cc
--// gmm/mle-diag-gmm.cc
++// gmm/ebw-diag-gmm.cc
// Copyright 2009-2011 Arnab Ghoshal, Petr Motlicek
namespace kaldi {
--void AccumEbwDiagGmm::Read(std::istream &in_stream, bool binary, bool add) {
-- int32 dimension, num_components;
-- GmmFlagsType flags;
-- std::string token;
--
-- ExpectMarker(in_stream, binary, "<GMMEBWACCS>");
-- ExpectMarker(in_stream, binary, "<VECSIZE>");
-- ReadBasicType(in_stream, binary, &dimension);
-- ExpectMarker(in_stream, binary, "<NUMCOMPONENTS>");
-- ReadBasicType(in_stream, binary, &num_components);
-- ExpectMarker(in_stream, binary, "<FLAGS>");
-- ReadBasicType(in_stream, binary, &flags);
--
-- if (add) {
-- if ((NumGauss() != 0 || Dim() != 0 || Flags() != 0)) {
-- if (num_components != NumGauss() || dimension != Dim()
-- || flags != Flags()) {
-- KALDI_ERR << "Dimension or flags mismatch: " << NumGauss() << ", "
-- << Dim() << ", " << Flags() << " vs. " << num_components
-- << ", " << dimension << ", " << flags;
++// This function is used inside the EBW update routines.
++// returns true if all variances were positive.
++static bool EBWUpdateGaussian(
++ BaseFloat D,
++ GmmFlagsType flags,
++ const VectorBase<double> &orig_mean,
++ const VectorBase<double> &orig_var,
++ const VectorBase<double> &x_stats,
++ const VectorBase<double> &x2_stats,
++ double occ,
++ VectorBase<double> *mean,
++ VectorBase<double> *var,
++ double *auxf_impr) {
++ if (! (flags&(kGmmMeans|kGmmVariances)) || occ <= 0.0) { // nothing to do.
++ if (auxf_impr) *auxf_impr = 0.0;
++ mean->CopyFromVec(orig_mean);
++ var->CopyFromVec(orig_var);
++ return true;
++ }
++ KALDI_ASSERT(!( (flags&kGmmVariances) && !(flags&kGmmMeans))
++ && "We didn't make the update cover this case sensibly (update vars not means)");
++
++ mean->SetZero();
++ var->SetZero();
++ mean->AddVec(D, orig_mean);
++ var->AddVec2(D, orig_mean);
++ var->AddVec(D, orig_var);
++ mean->AddVec(1.0, x_stats);
++ var->AddVec(1.0, x2_stats);
++ BaseFloat scale = 1.0 / (occ + D);
++ mean->Scale(scale);
++ var->Scale(scale);
++ var->AddVec2(-1.0, *mean);
++
++ if (!(flags&kGmmVariances)) var->CopyFromVec(orig_var);
++ if (!(flags&kGmmMeans)) mean->CopyFromVec(orig_mean);
++
++ // Return false if any NaN's.
++ for (int32 i = 0; i < mean->Dim(); i++) {
++ double m = ((*mean)(i)), v = ((*var)(i));
++ if (m!=m || v!=v || m-m != 0 || v-v != 0) {
++ return false;
++ }
++ }
++
++ if (var->Min() > 0.0) {
++ if (auxf_impr != NULL) {
++ // work out auxf improvement.
++ BaseFloat old_auxf = 0.0, new_auxf = 0.0;
++ int32 dim = orig_mean.Dim();
++ for (int32 i = 0; i < dim; i++) {
++ BaseFloat mean_diff = (*mean)(i) - orig_mean(i);
++ old_auxf += (occ+D) * -0.5 * (log(orig_var(i)) +
++ ((*var)(i) + mean_diff*mean_diff)
++ / orig_var(i));
++ new_auxf += (occ+D) * -0.5 * (log((*var)(i)) + 1.0);
++
}
-- } else {
-- Resize(num_components, dimension, flags);
++ *auxf_impr = new_auxf - old_auxf;
}
++ return true;
++ } else return false;
++}
++
++// Update Gaussian parameters only (no weights)
++void UpdateEbwDiagGmm(const AccumDiagGmm &num_stats, // with I-smoothing, if used.
++ const AccumDiagGmm &den_stats,
++ GmmFlagsType flags,
++ const EbwOptions &opts,
++ DiagGmm *gmm,
++ BaseFloat *auxf_change_out,
++ BaseFloat *count_out,
++ int32 *num_floored_out) {
++ GmmFlagsType acc_flags = num_stats.Flags();
++ if (flags & ~acc_flags)
++ KALDI_ERR << "Incompatible flags: you requested to update flags \""
++ << GmmFlagsToString(flags) << "\" but accumulators have only \""
++ << GmmFlagsToString(acc_flags) << '"';
++
++ // It could be that the num stats actually contain the difference between
++ // num and den (for mean and var stats), and den stats only have the weights.
++ bool den_has_stats;
++ if (den_stats.Flags() != acc_flags) {
++ den_has_stats = false;
++ if (den_stats.Flags() != kGmmWeights)
++ KALDI_ERR << "Incompatible flags: num stats have flags \""
++ << GmmFlagsToString(acc_flags) << "\" vs. den stats \""
++ << GmmFlagsToString(den_stats.Flags()) << '"';
} else {
-- Resize(num_components, dimension, flags);
++ den_has_stats = true;
++ }
++ int32 num_comp = num_stats.NumGauss();
++ int32 dim = num_stats.Dim();
++ KALDI_ASSERT(num_stats.NumGauss() == den_stats.NumGauss());
++ KALDI_ASSERT(num_stats.Dim() == gmm->Dim());
++ KALDI_ASSERT(gmm->NumGauss() == num_comp);
++
++ if ( !(flags & (kGmmMeans | kGmmVariances)) ) {
++ return; // Nothing to update.
}
++
++ // copy DiagGMM model and transform this to the normal case
++ DiagGmmNormal diaggmmnormal;
++ gmm->ComputeGconsts();
++ diaggmmnormal.CopyFromDiagGmm(*gmm);
-- ReadMarker(in_stream, binary, &token);
-- while (token != "</GMMEBWACCS>") {
-- if (token == "<NUM_OCCUPANCY>") {
-- num_occupancy_.Read(in_stream, binary, add);
-- } else if (token == "<DEN_OCCUPANCY>") {
-- den_occupancy_.Read(in_stream, binary, add);
-- } else if (token == "<MEANACCS>") {
-- mean_accumulator_.Read(in_stream, binary, add);
-- } else if (token == "<DIAGVARACCS>") {
-- variance_accumulator_.Read(in_stream, binary, add);
-- } else {
-- KALDI_ERR << "Unexpected token '" << token << "' in model file ";
++ // go over all components
++ Vector<double> mean(dim), var(dim), mean_stats(dim), var_stats(dim);
++
++ for (int32 g = 0; g < num_comp; g++) {
++ BaseFloat num_count = num_stats.occupancy()(g),
++ den_count = den_stats.occupancy()(g);
++ if (num_count == 0.0 && den_count == 0.0) {
++ KALDI_VLOG(2) << "Not updating Gaussian " << g << " since counts are zero";
++ continue;
++ }
++ mean_stats.CopyFromVec(num_stats.mean_accumulator().Row(g));
++ if (den_has_stats)
++ mean_stats.AddVec(-1.0, den_stats.mean_accumulator().Row(g));
++ if (flags & kGmmVariances) {
++ var_stats.CopyFromVec(num_stats.variance_accumulator().Row(g));
++ if (den_has_stats)
++ var_stats.AddVec(-1.0, den_stats.variance_accumulator().Row(g));
++ }
++ double D = opts.E * den_count / 2; // E*gamma_den/2 where E = 2;
++ // We initialize to half the value of D that would be dictated by
++ // E; this is part of the strategy used to ensure that the value of
++ // D we use is at least twice the value that would ensure positive
++ // variances.
++
++ int32 iter, max_iter = 100;
++ for (iter = 0; iter < max_iter; iter++) { // will normally break from the loop
++ // the first time.
++ if (EBWUpdateGaussian(D, flags,
++ diaggmmnormal.means_.Row(g),
++ diaggmmnormal.vars_.Row(g),
++ mean_stats, var_stats, num_count-den_count,
++ &mean, &var, NULL)) {
++ // Succeeded in getting all +ve vars at this value of D.
++ // So double D and commit changes.
++ D *= 2.0;
++ double auxf_impr = 0.0;
++ EBWUpdateGaussian(D, flags,
++ diaggmmnormal.means_.Row(g),
++ diaggmmnormal.vars_.Row(g),
++ mean_stats, var_stats, num_count-den_count,
++ &mean, &var, &auxf_impr);
++
++ if (auxf_change_out) *auxf_change_out += auxf_impr;
++ if (count_out) *count_out += den_count; // The idea is that for MMI, this will
++ // reflect the actual #frames trained on (the numerator one would be I-smoothed).
++ // In general (e.g. for MPE), we won't know the #frames.
++ diaggmmnormal.means_.CopyRowFromVec(mean, g);
++ diaggmmnormal.vars_.CopyRowFromVec(var, g);
++ break;
++ } else {
++ // small step
++ D *= 1.1;
++ }
}
-- ReadMarker(in_stream, binary, &token);
++ if (iter > 0 && num_floored_out != NULL) *num_floored_out++;
++ if (iter == max_iter) KALDI_WARN << "Dropped off end of loop, recomputing D. (unexpected.)";
}
++ // copy to natural representation according to flags.
++ diaggmmnormal.CopyToDiagGmm(gmm, flags);
++ gmm->ComputeGconsts();
}
--void AccumEbwDiagGmm::Write(std::ostream &out_stream, bool binary) const {
-- WriteMarker(out_stream, binary, "<GMMEBWACCS>");
-- WriteMarker(out_stream, binary, "<VECSIZE>");
-- WriteBasicType(out_stream, binary, dim_);
-- WriteMarker(out_stream, binary, "<NUMCOMPONENTS>");
-- WriteBasicType(out_stream, binary, num_comp_);
-- WriteMarker(out_stream, binary, "<FLAGS>");
-- WriteBasicType(out_stream, binary, flags_);
--
-- // convert into BaseFloat before writing things
-- Vector<BaseFloat> num_occupancy_bf(num_occupancy_.Dim());
-- Vector<BaseFloat> den_occupancy_bf(den_occupancy_.Dim());
-- Matrix<BaseFloat> mean_accumulator_bf(mean_accumulator_.NumRows(),
-- mean_accumulator_.NumCols());
-- Matrix<BaseFloat> variance_accumulator_bf(variance_accumulator_.NumRows(),
-- variance_accumulator_.NumCols());
-- num_occupancy_bf.CopyFromVec(num_occupancy_);
-- den_occupancy_bf.CopyFromVec(den_occupancy_);
-- mean_accumulator_bf.CopyFromMat(mean_accumulator_);
-- variance_accumulator_bf.CopyFromMat(variance_accumulator_);
--
-- WriteMarker(out_stream, binary, "<NUM_OCCUPANCY>");
-- num_occupancy_bf.Write(out_stream, binary);
-- WriteMarker(out_stream, binary, "<DEN_OCCUPANCY>");
-- den_occupancy_bf.Write(out_stream, binary);
-- WriteMarker(out_stream, binary, "<MEANACCS>");
-- mean_accumulator_bf.Write(out_stream, binary);
-- WriteMarker(out_stream, binary, "<DIAGVARACCS>");
-- variance_accumulator_bf.Write(out_stream, binary);
-- WriteMarker(out_stream, binary, "</GMMEBWACCS>");
--}
++void UpdateEbwWeightsDiagGmm(const AccumDiagGmm &num_stats, // should have no I-smoothing
++ const AccumDiagGmm &den_stats,
++ const EbwWeightOptions &opts,
++ DiagGmm *gmm,
++ BaseFloat *auxf_change_out,
++ BaseFloat *count_out) {
--void AccumEbwDiagGmm::Resize(int32 num_comp, int32 dim, GmmFlagsType flags) {
-- KALDI_ASSERT(num_comp > 0 && dim > 0);
-- num_comp_ = num_comp;
-- dim_ = dim;
-- flags_ = AugmentGmmFlags(flags);
-- num_occupancy_.Resize(num_comp);
-- den_occupancy_.Resize(num_comp);
-- if (flags_ & kGmmMeans)
-- mean_accumulator_.Resize(num_comp, dim);
-- else
-- mean_accumulator_.Resize(0, 0);
-- if (flags_ & kGmmVariances)
-- variance_accumulator_.Resize(num_comp, dim);
-- else
-- variance_accumulator_.Resize(0, 0);
--}
++ DiagGmmNormal diaggmmnormal;
++ gmm->ComputeGconsts();
++ diaggmmnormal.CopyFromDiagGmm(*gmm);
--void AccumEbwDiagGmm::SetZero(GmmFlagsType flags) {
-- if (flags & ~flags_)
-- KALDI_ERR << "Flags in argument do not match the active accumulators";
-- if (flags & kGmmWeights) {
-- num_occupancy_.SetZero();
-- den_occupancy_.SetZero();
++ Vector<double> weights(diaggmmnormal.weights_),
++ num_occs(num_stats.occupancy()),
++ den_occs(den_stats.occupancy());
++ if (num_occs.Sum() + den_occs.Sum() < opts.min_num_count_weight_update) {
++ KALDI_LOG << "Not updating weights for this state because total count is "
++ << num_occs.Sum() + den_occs.Sum() << " < "
++ << opts.min_num_count_weight_update;
++ return;
}
-- if (flags & kGmmMeans) mean_accumulator_.SetZero();
-- if (flags & kGmmVariances) variance_accumulator_.SetZero();
--}
--
++ KALDI_ASSERT(weights.Dim() == num_occs.Dim() && num_occs.Dim() == den_occs.Dim());
++ if (weights.Dim() == 1) return; // Nothing to do: only one mixture.
++ double weight_auxf_at_start = 0.0, weight_auxf_at_end = 0.0;
--void AccumEbwDiagGmm::Scale(BaseFloat f, GmmFlagsType flags) {
-- if (flags & ~flags_)
-- KALDI_ERR << "Flags in argument do not match the active accumulators";
-- double d = static_cast<double>(f);
-- if (flags & kGmmWeights) {
-- num_occupancy_.Scale(d);
-- den_occupancy_.SetZero();
++ int32 num_comp = weights.Dim();
++ for (int32 g = 0; g < num_comp; g++) { // c.f. eq. 4.32 in Dan Povey's thesis.
++ weight_auxf_at_start +=
++ num_occs(g) * log (weights(g))
++ - den_occs(g) * weights(g) / diaggmmnormal.weights_(g);
}
-- if (flags & kGmmMeans) mean_accumulator_.Scale(d);
-- if (flags & kGmmVariances) variance_accumulator_.Scale(d);
--}
++ for (int32 iter = 0; iter < 50; iter++) {
++ Vector<double> k_jm(num_comp); // c.f. eq. 4.35
++ double max_m = 0.0;
++ for (int32 g = 0; g < num_comp; g++)
++ max_m = std::max(max_m, den_occs(g)/diaggmmnormal.weights_(g));
++ for (int32 g = 0; g < num_comp; g++)
++ k_jm(g) = max_m - den_occs(g)/diaggmmnormal.weights_(g);
++ for (int32 g = 0; g < num_comp; g++) // c.f. eq. 4.34
++ weights(g) = num_occs(g) + k_jm(g)*weights(g);
++ weights.Scale(1.0 / weights.Sum()); // c.f. eq. 4.34 (denominator)
++ }
++ for (int32 g = 0; g < num_comp; g++) { // weight flooring.
++ if (weights(g) < opts.min_gaussian_weight)
++ weights(g) = opts.min_gaussian_weight;
++ }
++ weights.Scale(1.0 / weights.Sum()); // renormalize after flooring..
++ // floor won't be exact now but doesn't really matter.
--void AccumEbwDiagGmm::AccumulateFromPosteriors(
-- const VectorBase<BaseFloat>& data,
-- const VectorBase<BaseFloat>& pos_post,
-- const VectorBase<BaseFloat>& neg_post) {
-- assert(static_cast<int32>(data.Dim()) == Dim());
-- assert(static_cast<int32>(pos_post.Dim()) == NumGauss());
-- Vector<double> pos_post_d(pos_post),
-- neg_post_d(neg_post); // Copy with type-conversion
--
-- // accumulate
-- num_occupancy_.AddVec(1.0, pos_post_d);
-- num_occupancy_.AddVec(1.0, neg_post_d);
-- if (flags_ & kGmmMeans) {
-- Vector<double> data_d(data); // Copy with type-conversion
-- // TODO(arnab): we need to decide whether the neg posts have negative value
-- mean_accumulator_.AddVecVec(1.0, pos_post_d, data_d);
-- mean_accumulator_.AddVecVec(-1.0, neg_post_d, data_d);
-- if (flags_ & kGmmVariances) {
-- data_d.ApplyPow(2.0);
-- variance_accumulator_.AddVecVec(1.0, pos_post_d, data_d);
-- variance_accumulator_.AddVecVec(-1.0, neg_post_d, data_d);
-- }
++ for (int32 g = 0; g < num_comp; g++) { // c.f. eq. 4.32 in Dan Povey's thesis.
++ weight_auxf_at_end +=
++ num_occs(g) * log (weights(g))
++ - den_occs(g) * weights(g) / diaggmmnormal.weights_(g);
}
--}
--void AccumEbwDiagGmm::SmoothWithAccum(BaseFloat tau, const AccumDiagGmm& src_acc) {
-- KALDI_ASSERT(src_acc.NumGauss() == num_comp_ && src_acc.Dim() == dim_);
-- double tau_d = static_cast<double>(tau);
-- num_occupancy_.AddVec(tau_d, src_acc.occupancy());
-- mean_accumulator_.AddMat(tau_d, src_acc.mean_accumulator(), kNoTrans);
-- variance_accumulator_.AddMat(tau_d, src_acc.variance_accumulator(), kNoTrans);
++ if (auxf_change_out)
++ *auxf_change_out += weight_auxf_at_end - weight_auxf_at_start;
++ if (count_out)
++ *count_out += num_occs.Sum(); // only really valid for MMI.
++
++ diaggmmnormal.weights_.CopyFromVec(weights);
++
++ // copy to natural representation
++ diaggmmnormal.CopyToDiagGmm(gmm, kGmmAll);
++ gmm->ComputeGconsts();
}
++void UpdateEbwAmDiagGmm(const AccumAmDiagGmm &num_stats, // with I-smoothing, if used.
++ const AccumAmDiagGmm &den_stats,
++ GmmFlagsType flags,
++ const EbwOptions &opts,
++ AmDiagGmm *am_gmm,
++ BaseFloat *auxf_change_out,
++ BaseFloat *count_out,
++ int32 *num_floored_out) {
++ KALDI_ASSERT(num_stats.NumAccs() == den_stats.NumAccs()
++ && num_stats.NumAccs() == am_gmm->NumPdfs());
--void AccumEbwDiagGmm::SmoothWithModel(BaseFloat tau, const DiagGmm& gmm) {
-- KALDI_ASSERT(gmm.NumGauss() == num_comp_ && gmm.Dim() == dim_);
-- Matrix<double> means(num_comp_, dim_);
-- Matrix<double> vars(num_comp_, dim_);
-- gmm.GetMeans(&means);
-- gmm.GetVars(&vars);
++ for (int32 pdf = 0; pdf < num_stats.NumAccs(); pdf++)
++ UpdateEbwDiagGmm(num_stats.GetAcc(pdf), den_stats.GetAcc(pdf), flags,
++ opts, &(am_gmm->GetPdf(pdf)), auxf_change_out,
++ count_out, num_floored_out);
++}
-- mean_accumulator_.AddMat(tau, means);
-- means.ApplyPow(2.0);
-- vars.AddMat(1.0, means, kNoTrans);
-- variance_accumulator_.AddMat(tau, vars);
-- num_occupancy_.Add(tau);
--}
++void UpdateEbwWeightsAmDiagGmm(const AccumAmDiagGmm &num_stats, // with I-smoothing, if used.
++ const AccumAmDiagGmm &den_stats,
++ const EbwWeightOptions &opts,
++ AmDiagGmm *am_gmm,
++ BaseFloat *auxf_change_out,
++ BaseFloat *count_out) {
++ KALDI_ASSERT(num_stats.NumAccs() == den_stats.NumAccs()
++ && num_stats.NumAccs() == am_gmm->NumPdfs());
++
++ for (int32 pdf = 0; pdf < num_stats.NumAccs(); pdf++)
++ UpdateEbwWeightsDiagGmm(num_stats.GetAcc(pdf), den_stats.GetAcc(pdf),
++ opts, &(am_gmm->GetPdf(pdf)), auxf_change_out,
++ count_out);
++}
--AccumEbwDiagGmm::AccumEbwDiagGmm(const AccumEbwDiagGmm &other)
-- : dim_(other.dim_), num_comp_(other.num_comp_),
-- flags_(other.flags_), num_occupancy_(other.num_occupancy_),
-- den_occupancy_(other.den_occupancy_),
-- mean_accumulator_(other.mean_accumulator_),
-- variance_accumulator_(other.variance_accumulator_) {}
} // End of namespace kaldi
diff --cc src/gmm/ebw-diag-gmm.h
index 696ab780b0dee7d8258a1e877a25e07057656208,696ab780b0dee7d8258a1e877a25e07057656208..cac5896a56b3e99c620573bbc658792cd22eabbe
+++ b/src/gmm/ebw-diag-gmm.h
#include "gmm/diag-gmm.h"
#include "gmm/diag-gmm-normal.h"
#include "gmm/mle-diag-gmm.h"
++#include "gmm/mle-am-diag-gmm.h"
#include "gmm/model-common.h"
#include "util/parse-options.h"
namespace kaldi {
--class AccumEbwDiagGmm {
-- public:
-- AccumEbwDiagGmm(): dim_(0), num_comp_(0), flags_(0) {}
-- explicit AccumEbwDiagGmm(const DiagGmm &gmm, GmmFlagsType flags) {
-- Resize(gmm.NumGauss(), gmm.Dim(), flags);
++// Options for Extended Baum-Welch Gaussian update.
++struct EbwOptions {
++ BaseFloat E;
++ EbwOptions(): E(2.0) { }
++ void Register(ParseOptions *po) {
++ std::string module = "EbwOptions: ";
++ po->Register("E", &E, module+"Constant E for Extended Baum-Welch (EBW) update");
}
-- // provide copy constructor.
-- explicit AccumEbwDiagGmm(const AccumEbwDiagGmm &other);
--
-- void Read(std::istream &in_stream, bool binary, bool add);
-- void Write(std::ostream &out_stream, bool binary) const;
--
-- /// Allocates memory for accumulators
-- void Resize(int32 num_comp, int32 dim, GmmFlagsType flags);
--
-- /// Returns the number of mixture components
-- int32 NumGauss() const { return num_comp_; }
-- /// Returns the dimensionality of the feature vectors
-- int32 Dim() const { return dim_; }
--
-- void SetZero(GmmFlagsType flags);
-- void Scale(BaseFloat f, GmmFlagsType flags);
--
-- // TODO(arnab): maybe it's better to acc using a single posterior, but we
-- // need to know which occ stats to add to. Create 2 functions instead?
-- /// Accumulate for all components, given the posteriors.
-- void AccumulateFromPosteriors(const VectorBase<BaseFloat>& data,
-- const VectorBase<BaseFloat>& pos_posts,
-- const VectorBase<BaseFloat>& neg_posts);
--
--
-- // TODO(arnab): we could keep the smoothing functions here as well. For
-- // example, MPE stats will be directly accumulated as EBW stats and they
-- // need to be smoothed. For MMIE, the numerator accumulator can be smoothed
-- // before doing the subtraction.
--
-- /// Smooths the accumulated counts using some other accumulator. Performs
-- /// a weighted sum of the current accumulator with the given one. An example
-- /// use for this is I-smoothing for MPE. Both accumulators must have the same
-- /// dimension and number of components.
-- void SmoothWithAccum(BaseFloat tau, const AccumDiagGmm& src_acc);
--
-- /// Smooths the accumulated counts using the parameters of a given model.
-- /// An example use of this is MAP-adaptation. The model must have the
-- /// same dimension and number of components as the current accumulator.
-- void SmoothWithModel(BaseFloat tau, const DiagGmm& src_gmm);
--
-- // Accessors
-- const GmmFlagsType Flags() const { return flags_; }
-- const Vector<double>& num_occupancy() const { return num_occupancy_; }
-- const Vector<double>& den_occupancy() const { return den_occupancy_; }
-- const Matrix<double>& mean_accumulator() const { return mean_accumulator_; }
-- const Matrix<double>& variance_accumulator() const { return variance_accumulator_; }
--
-- private:
-- int32 dim_;
-- int32 num_comp_;
-- /// Flags corresponding to the accumulators that are stored.
-- GmmFlagsType flags_;
--
-- Vector<double> num_occupancy_;
-- Vector<double> den_occupancy_;
-- Matrix<double> mean_accumulator_;
-- Matrix<double> variance_accumulator_;
};
++struct EbwWeightOptions {
++ BaseFloat min_num_count_weight_update; // minimum numerator count at state level, before we update.
++ BaseFloat min_gaussian_weight;
++ EbwWeightOptions(): min_num_count_weight_update(10.0),
++ min_gaussian_weight(1.0e-05) { }
++ void Register(ParseOptions *po) {
++ std::string module = "EbwWeightOptions: ";
++ po->Register("min-num-count-weight-update", &min_num_count_weight_update,
++ module+"Minimum numerator count required at "
++ "state level before we update the weights.");
++ po->Register("min-gaussian-weight", &min_gaussian_weight,
++ module+"Minimum Gaussian weight allowed in EBW update of weights");
++ }
++};
++
++
++// Update Gaussian parameters only (no weights)
++// The pointer parameters auxf_change_out etc. are incremented, not set.
++void UpdateEbwDiagGmm(const AccumDiagGmm &num_stats, // with I-smoothing, if used.
++ const AccumDiagGmm &den_stats,
++ GmmFlagsType flags,
++ const EbwOptions &opts,
++ DiagGmm *gmm,
++ BaseFloat *auxf_change_out,
++ BaseFloat *count_out,
++ int32 *num_floored_out);
++
++// The pointer parameters auxf_change_out etc. are incremented, not set.
++void UpdateEbwAmDiagGmm(const AccumAmDiagGmm &num_stats, // with I-smoothing, if used.
++ const AccumAmDiagGmm &den_stats,
++ GmmFlagsType flags,
++ const EbwOptions &opts,
++ AmDiagGmm *am_gmm,
++ BaseFloat *auxf_change_out,
++ BaseFloat *count_out,
++ int32 *num_floored_out);
++
++// Updates the weights using the EBW-like method described in Dan Povey's thesis
++// (this method has no tunable parameters).
++// The pointer parameters auxf_change_out etc. are incremented, not set.
++void UpdateEbwWeightsDiagGmm(const AccumDiagGmm &num_stats, // should have no I-smoothing
++ const AccumDiagGmm &den_stats,
++ const EbwWeightOptions &opts,
++ DiagGmm *gmm,
++ BaseFloat *auxf_change_out,
++ BaseFloat *count_out);
++
++// The pointer parameters auxf_change_out etc. are incremented, not set.
++void UpdateEbwWeightsAmDiagGmm(const AccumAmDiagGmm &num_stats, // should have no I-smoothing
++ const AccumAmDiagGmm &den_stats,
++ const EbwWeightOptions &opts,
++ AmDiagGmm *amA_gmm,
++ BaseFloat *auxf_change_out,
++ BaseFloat *count_out);
++
++// For I-smoothing functions, see mle-am-diag-gmm.h
++
++
} // End namespace kaldi
diff --cc src/gmm/fmpe-am-diag-gmm.cc
index 830e0966fa87a8505dde23116c33a87fd4b5b94e,830e0966fa87a8505dde23116c33a87fd4b5b94e..ce419246860f75bdf8e0cbd62228356fdb7c2956
}
void FmpeAccumModelDiff::ComputeModelParaDiff(const DiagGmm& diag_gmm,
-- const AccumEbwDiagGmm& ebw_acc,
++ const AccumDiagGmm& num_acc,
++ const AccumDiagGmm& den_acc,
const AccumDiagGmm& mle_acc) {
-- KALDI_ASSERT(ebw_acc.NumGauss() == num_comp_ && ebw_acc.Dim() == dim_);
++ KALDI_ASSERT(num_acc.NumGauss() == num_comp_ && num_acc.Dim() == dim_);
++ KALDI_ASSERT(den_acc.NumGauss() == num_comp_); // den_acc.Dim() may not be defined,
++ // if we used the "compressed form" of accs where den only has counts.
KALDI_ASSERT(mle_acc.NumGauss() == num_comp_ && mle_acc.Dim() == dim_);
--
++
Matrix<double> mean_diff_tmp(num_comp_, dim_);
Matrix<double> var_diff_tmp(num_comp_, dim_);
Matrix<double> mat_tmp(num_comp_, dim_);
Matrix<double> means_invvars(num_comp_, dim_);
Matrix<double> inv_vars(num_comp_, dim_);
-- occ_diff.CopyFromVec(ebw_acc.num_occupancy());
-- occ_diff.AddVec(-1.0, ebw_acc.den_occupancy());
++ occ_diff.CopyFromVec(num_acc.occupancy());
++ occ_diff.AddVec(-1.0, den_acc.occupancy());
means_invvars.CopyFromMat(diag_gmm.means_invvars(), kNoTrans);
inv_vars.CopyFromMat(diag_gmm.inv_vars(), kNoTrans);
/// compute the means differentials first
-- mean_diff_tmp.CopyFromMat(ebw_acc.mean_accumulator(), kNoTrans);
++ mean_diff_tmp.CopyFromMat(num_acc.mean_accumulator(), kNoTrans);
++ if (den_acc.Flags() & kGmmMeans) // probably will be false.
++ mean_diff_tmp.AddMat(-1.0, den_acc.mean_accumulator(), kNoTrans);
mean_diff_tmp.MulElements(inv_vars);
mat_tmp.CopyFromMat(means_invvars, kNoTrans);
mean_diff_accumulator_.CopyFromMat(mean_diff_tmp, kNoTrans);
/// compute the vars differentials second
-- var_diff_tmp.CopyFromMat(ebw_acc.variance_accumulator(), kNoTrans);
++ var_diff_tmp.CopyFromMat(num_acc.variance_accumulator(), kNoTrans);
++ if (den_acc.Flags() & kGmmVariances) // probably will be false.
++ var_diff_tmp.AddMat(-1.0, den_acc.variance_accumulator(), kNoTrans);
++
var_diff_tmp.MulElements(inv_vars);
var_diff_tmp.MulElements(inv_vars);
--
-- mat_tmp.CopyFromMat(ebw_acc.mean_accumulator(), kNoTrans);
++
++ mat_tmp.CopyFromMat(num_acc.mean_accumulator(), kNoTrans);
++ if (den_acc.Flags() & kGmmMeans) // probably will be false.
++ mat_tmp.AddMat(-1.0, den_acc.mean_accumulator(), kNoTrans);
mat_tmp.MulElements(inv_vars);
mat_tmp.MulElements(means_invvars);
diff --cc src/gmm/fmpe-am-diag-gmm.h
index 2bb874b9d14a4b1f9ca06de0285166ecb0b6ba5c,2bb874b9d14a4b1f9ca06de0285166ecb0b6ba5c..edf4dabb38199f916d70fc58eb59cda3d0e14938
/// the MPE training, including the numerator and denominator accumulators
/// and applies I-smoothing to the numerator accs, if needed,
/// which using mle_acc.
-- void ComputeModelParaDiff(const DiagGmm& diag_gmm,
-- const AccumEbwDiagGmm& ebw_acc,
-- const AccumDiagGmm& mle_acc);
++ void ComputeModelParaDiff(const DiagGmm &diag_gmm,
++ const AccumDiagGmm &num_acc,
++ const AccumDiagGmm &den_acc,
++ const AccumDiagGmm &mle_acc);
private:
diff --cc src/gmm/mle-am-diag-gmm.cc
index cbfe872d733db01d5eb1ac756a9b3937fe6ecbe5,cbfe872d733db01d5eb1ac756a9b3937fe6ecbe5..089d7fd08e35687dd5ed6c16ea995eff953ba393
return *(gmm_accumulators_[index]);
}
++AccumDiagGmm& AccumAmDiagGmm::GetAcc(int32 index) {
++ assert(index >= 0 && index < static_cast<int32>(gmm_accumulators_.size()));
++ return *(gmm_accumulators_[index]);
++}
++
AccumAmDiagGmm::~AccumAmDiagGmm() {
DeletePointers(&gmm_accumulators_);
}
}
}
++void AccumAmDiagGmm::SmoothStats(BaseFloat tau) {
++ int32 num_pdfs = gmm_accumulators_.size();
++ for (int32 i = 0; i < num_pdfs; i++)
++ gmm_accumulators_[i]->SmoothStats(tau);
++}
++
++void AccumAmDiagGmm::SmoothWithAccum(BaseFloat tau, const AccumAmDiagGmm &src_accs) {
++ int32 num_pdfs = gmm_accumulators_.size();
++ KALDI_ASSERT(num_pdfs == src_accs.NumAccs());
++ for (int32 i = 0; i < num_pdfs; i++)
++ gmm_accumulators_[i]->SmoothWithAccum(tau, src_accs.GetAcc(i));
++}
++
++void AccumAmDiagGmm::SmoothWithModel(BaseFloat tau, const AmDiagGmm &src_model) {
++ int32 num_pdfs = gmm_accumulators_.size();
++ KALDI_ASSERT(num_pdfs == src_model.NumPdfs());
++ for (int32 i = 0; i < num_pdfs; i++)
++ gmm_accumulators_[i]->SmoothWithModel(tau, src_model.GetPdf(i));
++}
++
++BaseFloat AccumAmDiagGmm::TotCount() const {
++ BaseFloat ans = 0.0;
++ for (int32 pdf = 0; pdf < NumAccs(); pdf++)
++ ans += gmm_accumulators_[pdf]->occupancy().Sum();
++ return ans;
++}
++
void MleAmDiagGmmUpdate(const MleDiagGmmOptions &config,
const AccumAmDiagGmm &amdiaggmm_acc,
GmmFlagsType flags,
diff --cc src/gmm/mle-am-diag-gmm.h
index a852b676d492ebeb0addd2ebb0a355353d705075,a852b676d492ebeb0addd2ebb0a355353d705075..1a171962412cb4f6e03f740bd90baeb249b889d2
int32 NumAccs() const { return gmm_accumulators_.size(); }
++ BaseFloat TotCount() const;
++
const AccumDiagGmm& GetAcc(int32 index) const;
++ AccumDiagGmm& GetAcc(int32 index);
++
++ // The next three functions are mostly useful in discriminative training.
++ // They call the corresponding functions in class AccumDiagGmm.
++ void SmoothStats(BaseFloat tau);
++ void SmoothWithAccum(BaseFloat tau, const AccumAmDiagGmm& src_accs);
++ void SmoothWithModel(BaseFloat tau, const AmDiagGmm& src_gmm);
++
private:
/// MLE accumulators and update methods for the GMMs
std::vector<AccumDiagGmm*> gmm_accumulators_;
/// for computing the maximum-likelihood estimates of the parameters of
/// an acoustic model that uses diagonal Gaussian mixture models as emission densities.
void MleAmDiagGmmUpdate(const MleDiagGmmOptions &config, const AccumAmDiagGmm &amdiaggmm_acc,
-- GmmFlagsType flags, AmDiagGmm *am_gmm, BaseFloat *obj_change_out,
-- BaseFloat *count_out);
++ GmmFlagsType flags, AmDiagGmm *am_gmm, BaseFloat *obj_change_out,
++ BaseFloat *count_out);
} // End namespace kaldi
diff --cc src/gmm/mle-diag-gmm.cc
index 2ea1435943ff6865be707f6ca861e6486bb54369,2ea1435943ff6865be707f6ca861e6486bb54369..d33cf19e46e0ab153afc790f87ed140f117da98a
+++ b/src/gmm/mle-diag-gmm.cc
// Careful: this wouldn't be valid if it were used to update the
// Gaussian weights.
++// Note: if we have zero stats for something, we don't smooth.
void AccumDiagGmm::SmoothStats(BaseFloat tau) {
-- Vector<double> smoothing_vec(occupancy_);
-- smoothing_vec.InvertElements();
-- smoothing_vec.Scale(static_cast<double>(tau));
-- smoothing_vec.Add(1.0);
-- // now smoothing_vec = (tau + occ) / occ
--
-- mean_accumulator_.MulRowsVec(smoothing_vec);
-- variance_accumulator_.MulRowsVec(smoothing_vec);
-- occupancy_.Add(static_cast<double>(tau));
++ AccumDiagGmm tmp_accum(*this);
++ SmoothWithAccum(tau, tmp_accum);
}
KALDI_ASSERT(src_acc.NumGauss() == num_comp_ && src_acc.Dim() == dim_);
for (int32 i = 0; i < num_comp_; i++) {
if (src_acc.occupancy_(i) != 0.0) { // can only smooth if src was nonzero...
-- occupancy_(i) += tau;
-- mean_accumulator_.Row(i).AddVec(tau / src_acc.occupancy_(i),
-- src_acc.mean_accumulator_.Row(i));
-- variance_accumulator_.Row(i).AddVec(tau / src_acc.occupancy_(i),
-- src_acc.variance_accumulator_.Row(i));
++ if (flags_ & kGmmWeights)
++ occupancy_(i) += tau;
++ if (flags_ & kGmmMeans)
++ mean_accumulator_.Row(i).AddVec(tau / src_acc.occupancy_(i),
++ src_acc.mean_accumulator_.Row(i));
++ if (flags_ & kGmmVariances)
++ variance_accumulator_.Row(i).AddVec(tau / src_acc.occupancy_(i),
++ src_acc.variance_accumulator_.Row(i));
} else
KALDI_WARN << "Could not smooth since source acc had zero occupancy.";
}
gmm.GetMeans(&means);
gmm.GetVars(&vars);
-- mean_accumulator_.AddMat(tau, means);
++ if (flags_ & kGmmMeans)
++ mean_accumulator_.AddMat(tau, means);
means.ApplyPow(2.0);
vars.AddMat(1.0, means, kNoTrans);
-- variance_accumulator_.AddMat(tau, vars);
--
-- occupancy_.Add(tau);
++ if (flags_ & kGmmVariances)
++ variance_accumulator_.AddMat(tau, vars);
++ if (flags_ & kGmmWeights)
++ occupancy_.Add(tau);
}
AccumDiagGmm::AccumDiagGmm(const AccumDiagGmm &other)
diff --cc src/gmm/mle-diag-gmm.h
index e46882214af400c1ea6da4f0bb1a7ab03a1e7cf8,e46882214af400c1ea6da4f0bb1a7ab03a1e7cf8..02f6cf33cb126cee09e14d270d8cb1b8081a4fa7
+++ b/src/gmm/mle-diag-gmm.h
BaseFloat frame_posterior);
/// Smooths the accumulated counts by adding 'tau' extra frames. An example
-- /// use for this is I-smoothing for MMIE.
++ /// use for this is I-smoothing for MMIE. Calls SmoothWithAccum.
void SmoothStats(BaseFloat tau);
-- /// Smooths the accumulated counts using some other accumulator. Performs
-- /// a weighted sum of the current accumulator with the given one. An example
-- /// use for this is I-smoothing for MPE. Both accumulators must have the same
-- /// dimension and number of components.
++ /// Smooths the accumulated counts using some other accumulator. Performs a
++ /// weighted sum of the current accumulator with the given one. An example use
++ /// for this is I-smoothing for MMI and MPE. Both accumulators must have the
++ /// same dimension and number of components.
void SmoothWithAccum(BaseFloat tau, const AccumDiagGmm& src_acc);
--
++
/// Smooths the accumulated counts using the parameters of a given model.
/// An example use of this is MAP-adaptation. The model must have the
/// same dimension and number of components as the current accumulator.
diff --cc src/gmm/mmie-am-diag-gmm.cc
index fc9f64b09b5a960dfc8a054648e1a5c9b08bd2d1,9ae070580a2dd0acada93a6da81faee388fc4c1b..0000000000000000000000000000000000000000
deleted file mode 100644,100644
deleted file mode 100644,100644
+++ /dev/null
--// gmm/mmie-am-diag-gmm.cc
--
--// Copyright 2009-2011 Petr Motlicek
--
--// Licensed under the Apache License, Version 2.0 (the "License");
--// you may not use this file except in compliance with the License.
--// You may obtain a copy of the License at
--//
--// http://www.apache.org/licenses/LICENSE-2.0
--//
--// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
--// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
--// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
--// MERCHANTABLITY OR NON-INFRINGEMENT.
--// See the Apache 2 License for the specific language governing permissions and
--// limitations under the License.
--
--#include "gmm/am-diag-gmm.h"
--#include "gmm/mmie-am-diag-gmm.h"
--#include "gmm/mmie-diag-gmm.h"
--#include "util/stl-utils.h"
--
--namespace kaldi {
--
--/// Reads accumulators of MMI Numerators
--AccumDiagGmm& MmieAccumAmDiagGmm::GetNumAcc(int32 index) const {
-- assert(index >= 0 && index < static_cast<int32>(num_accumulators_.size()));
-- return *(num_accumulators_[index]);
--}
--
--/// Reads accumulators of MMI Denominators
--AccumDiagGmm& MmieAccumAmDiagGmm::GetDenAcc(int32 index) const {
-- assert(index >= 0 && index < static_cast<int32>(den_accumulators_.size()));
-- return *(den_accumulators_[index]);
--}
--
-/// Reads accumulators of I-Smooth
-AccumDiagGmm& MmieAccumAmDiagGmm::GetISmoothAcc(int32 index) const {
- assert(index >= 0 && index < static_cast<int32>(i_smooth_accumulators_.size()));
- return *(i_smooth_accumulators_[index]);
-}
--
--MmieAccumAmDiagGmm::~MmieAccumAmDiagGmm() {
-- DeletePointers(&num_accumulators_);
-- DeletePointers(&den_accumulators_);
--}
--
--void MmieAccumAmDiagGmm::Init(const AmDiagGmm &model,
-- GmmFlagsType flags) {
-- DeletePointers(&num_accumulators_); // in case was non-empty when called.
-- DeletePointers(&den_accumulators_); // in case was non-empty when called.
--
-- num_accumulators_.resize(model.NumPdfs(), NULL);
-- den_accumulators_.resize(model.NumPdfs(), NULL);
--
-- for (int32 i = 0; i < model.NumPdfs(); i++) {
-- num_accumulators_[i] = new AccumDiagGmm();
-- num_accumulators_[i]->Resize(model.GetPdf(i), flags);
-- den_accumulators_[i] = new AccumDiagGmm();
-- den_accumulators_[i]->Resize(model.GetPdf(i), flags);
--
-- }
--}
--
--void MmieAccumAmDiagGmm::Init(const AmDiagGmm &model,
-- int32 dim, GmmFlagsType flags) {
-- KALDI_ASSERT(dim > 0);
-- DeletePointers(&num_accumulators_); // in case was non-empty when called.
-- DeletePointers(&den_accumulators_); // in case was non-empty when called.
-- num_accumulators_.resize(model.NumPdfs(), NULL);
-- den_accumulators_.resize(model.NumPdfs(), NULL);
--
-- for (int32 i = 0; i < model.NumPdfs(); i++) {
-- num_accumulators_[i] = new AccumDiagGmm();
-- num_accumulators_[i]->Resize(model.GetPdf(i).NumGauss(),
-- dim, flags);
-- den_accumulators_[i] = new AccumDiagGmm();
-- den_accumulators_[i]->Resize(model.GetPdf(i).NumGauss(),
-- dim, flags);
--
-- }
--}
--
--void MmieAccumAmDiagGmm::SetZero(GmmFlagsType flags) {
-- for (size_t i = 0; i < num_accumulators_.size(); ++i) {
-- num_accumulators_[i]->SetZero(flags);
-- den_accumulators_[i]->SetZero(flags);
-- }
--}
--
--
--void MmieAccumAmDiagGmm::ReadNum(std::istream& in_stream, bool binary,
-- bool add) {
-- int32 num_pdfs;
-- ExpectMarker(in_stream, binary, "<NUMPDFS>");
-- ReadBasicType(in_stream, binary, &num_pdfs);
-- KALDI_ASSERT(num_pdfs > 0);
-- if (!add || (add && num_accumulators_.empty())) {
-- num_accumulators_.resize(num_pdfs, NULL);
-- for (std::vector<AccumDiagGmm*>::iterator it = num_accumulators_.begin(),
-- end = num_accumulators_.end(); it != end; ++it) {
-- if (*it != NULL) delete *it;
-- *it = new AccumDiagGmm();
-- (*it)->Read(in_stream, binary, add);
-- }
--
--
-- } else {
-- if (num_accumulators_.size() != static_cast<size_t> (num_pdfs))
-- KALDI_ERR << "Adding NUM accumulators but num-pdfs do not match: "
-- << (num_accumulators_.size()) << " vs. "
-- << (num_pdfs);
--
-- for (std::vector<AccumDiagGmm*>::iterator it = num_accumulators_.begin(),
-- end = num_accumulators_.end(); it != end; ++it)
-- (*it)->Read(in_stream, binary, add);
--
-- }
--}
--
--void MmieAccumAmDiagGmm::ReadDen(std::istream& in_stream, bool binary,
-- bool add) {
-- int32 num_pdfs;
-- ExpectMarker(in_stream, binary, "<NUMPDFS>");
-- ReadBasicType(in_stream, binary, &num_pdfs);
-- KALDI_ASSERT(num_pdfs > 0);
-- if (!add || (add && den_accumulators_.empty())) {
-- den_accumulators_.resize(num_pdfs, NULL);
-- for (std::vector<AccumDiagGmm*>::iterator it = den_accumulators_.begin(),
-- end = den_accumulators_.end(); it != end; ++it) {
-- if (*it != NULL) delete *it;
-- *it = new AccumDiagGmm();
-- (*it)->Read(in_stream, binary, add);
-- }
--
--
-- } else {
-- if (den_accumulators_.size() != static_cast<size_t> (num_pdfs))
-- KALDI_ERR << "Adding DEN accumulators but num-pdfs do not match: "
-- << (den_accumulators_.size()) << " vs. "
-- << (num_pdfs);
--
-- for (std::vector<AccumDiagGmm*>::iterator it = den_accumulators_.begin(),
-- end = den_accumulators_.end(); it != end; ++it)
- (*it)->Read(in_stream, binary, add);
-
- }
-}
-
-void MmieAccumAmDiagGmm::ReadISmooth(std::istream& in_stream, bool binary,
- bool add) {
- int32 num_pdfs;
- ExpectMarker(in_stream, binary, "<NUMPDFS>");
- ReadBasicType(in_stream, binary, &num_pdfs);
- KALDI_ASSERT(num_pdfs > 0);
- if (!add || (add && i_smooth_accumulators_.empty())) {
- i_smooth_accumulators_.resize(num_pdfs, NULL);
- for (std::vector<AccumDiagGmm*>::iterator it = i_smooth_accumulators_.begin(),
- end = i_smooth_accumulators_.end(); it != end; ++it) {
- if (*it != NULL) delete *it;
- *it = new AccumDiagGmm();
- (*it)->Read(in_stream, binary, add);
- }
-
-
- } else {
- if (i_smooth_accumulators_.size() != static_cast<size_t> (num_pdfs))
- KALDI_ERR << "Adding DEN accumulators but num-pdfs do not match: "
- << (i_smooth_accumulators_.size()) << " vs. "
- << (num_pdfs);
-
- for (std::vector<AccumDiagGmm*>::iterator it = i_smooth_accumulators_.begin(),
- end = i_smooth_accumulators_.end(); it != end; ++it)
-- (*it)->Read(in_stream, binary, add);
--
-- }
--}
--
--void MmieAccumAmDiagGmm::WriteNum(std::ostream& out_stream, bool binary) const {
-- int32 num_pdfs = num_accumulators_.size();
-- WriteMarker(out_stream, binary, "<NUMPDFS>");
-- WriteBasicType(out_stream, binary, num_pdfs);
-- for (std::vector<AccumDiagGmm*>::const_iterator it =
-- num_accumulators_.begin(), end = num_accumulators_.end(); it != end; ++it) {
-- (*it)->Write(out_stream, binary);
-- }
--}
--
--
--void MmieAccumAmDiagGmm::WriteDen(std::ostream& out_stream, bool binary) const {
-- int32 num_pdfs = den_accumulators_.size();
-- WriteMarker(out_stream, binary, "<NUMPDFS>");
-- WriteBasicType(out_stream, binary, num_pdfs);
-- for (std::vector<AccumDiagGmm*>::const_iterator it =
-- den_accumulators_.begin(), end = den_accumulators_.end(); it != end; ++it) {
-- (*it)->Write(out_stream, binary);
-- }
--}
--
--
--void MmieAmDiagGmmUpdate(const MmieDiagGmmOptions &config,
-- const MmieAccumAmDiagGmm &mmieamdiaggmm_acc,
-- GmmFlagsType flags,
-- AmDiagGmm *am_gmm,
-- BaseFloat *auxf_change_gauss,
-- BaseFloat *auxf_change_weights,
-- BaseFloat *count_out,
-- int32 *num_floored_out) {
-- KALDI_ASSERT(am_gmm != NULL);
-- KALDI_ASSERT(mmieamdiaggmm_acc.NumAccs() == am_gmm->NumPdfs());
-- if (auxf_change_gauss != NULL) *auxf_change_gauss = 0.0;
-- if (auxf_change_weights != NULL) *auxf_change_weights = 0.0;
-- if (count_out != NULL) *count_out = 0.0;
-- if (num_floored_out != NULL) *num_floored_out = 0.0;
-- BaseFloat tmp_auxf_change_gauss, tmp_auxf_change_weights, tmp_count;
-- int32 tmp_num_floored;
--
-- MmieAccumDiagGmm mmie_gmm;
--
-- for (size_t i = 0; i < mmieamdiaggmm_acc.NumAccs(); i++) {
-- mmie_gmm.Resize(am_gmm->GetPdf(i).NumGauss(), am_gmm->GetPdf(i).Dim(), flags);
-- mmie_gmm.SubtractAccumulatorsISmoothing(mmieamdiaggmm_acc.GetNumAcc(i),
-- mmieamdiaggmm_acc.GetDenAcc(i),
- config);
- config,
- config.has_i_smooth_stats ?
- mmieamdiaggmm_acc.GetISmoothAcc(i):
- mmieamdiaggmm_acc.GetNumAcc(i));
-- mmie_gmm.Update(config, flags, &(am_gmm->GetPdf(i)),
-- &tmp_auxf_change_gauss, &tmp_auxf_change_weights,
-- &tmp_count, &tmp_num_floored);
-- if (auxf_change_gauss != NULL) *auxf_change_gauss += tmp_auxf_change_gauss;
-- if (auxf_change_weights != NULL) *auxf_change_weights += tmp_auxf_change_weights;
-- if (count_out != NULL) *count_out += tmp_count;
-- if (num_floored_out != NULL) *num_floored_out += tmp_num_floored;
-- }
--}
--
--BaseFloat MmieAccumAmDiagGmm::TotNumCount() {
-- BaseFloat ans = 0.0;
-- for (size_t i = 0; i < num_accumulators_.size(); i++)
-- if (num_accumulators_[i])
-- ans += num_accumulators_[i]->occupancy().Sum();
-- return ans;
--}
--
--BaseFloat MmieAccumAmDiagGmm::TotDenCount() {
-- BaseFloat ans = 0.0;
-- for (size_t i = 0; i < den_accumulators_.size(); i++)
-- if (den_accumulators_[i])
-- ans += den_accumulators_[i]->occupancy().Sum();
-- return ans;
--}
--
--
--
--} // namespace kaldi
diff --cc src/gmm/mmie-am-diag-gmm.h
index 92a90a6a53daceaa97d2261a27117576954ff1e4,b57c5056046d54c96d77809dcd735525adf268cc..0000000000000000000000000000000000000000
deleted file mode 100644,100644
deleted file mode 100644,100644
+++ /dev/null
--// gmm/mmie-am-diag-gmm.h
--
--// Copyright 2009-2011
--// Author: Petr Motlicek
--
--// Licensed under the Apache License, Version 2.0 (the "License");
--// you may not use this file except in compliance with the License.
--// You may obtain a copy of the License at
--//
--// http://www.apache.org/licenses/LICENSE-2.0
--//
--// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
--// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
--// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
--// MERCHANTABLITY OR NON-INFRINGEMENT.
--// See the Apache 2 License for the specific language governing permissions and
--// limitations under the License.
--
--
--#ifndef KALDI_GMM_MMIE_AM_DIAG_GMM_H_
--#define KALDI_GMM_MMIE_AM_DIAG_GMM_H_ 1
--
--#include <vector>
--
--#include "gmm/am-diag-gmm.h"
--#include "gmm/mmie-diag-gmm.h"
--#include "gmm/mle-diag-gmm.h"
--
--
--namespace kaldi {
--
--class MmieAccumAmDiagGmm {
-- public:
-- MmieAccumAmDiagGmm() {}
-- ~MmieAccumAmDiagGmm();
--
-- void ReadNum(std::istream &in_stream, bool binary, bool add);
-- void ReadDen(std::istream &in_stream, bool binary, bool add);
- void ReadISmooth(std::istream &in_stream, bool binary, bool add);
-- void WriteNum(std::ostream &out_stream, bool binary) const;
-- void WriteDen(std::ostream &out_stream, bool binary) const;
-
- //
-- /// Initializes accumulators for each GMM based on the number of components
-- /// and dimension.
-- void Init(const AmDiagGmm &model, GmmFlagsType flags);
-- /// Initialization using different dimension than model.
-- void Init(const AmDiagGmm &model, int32 dim, GmmFlagsType flags);
-- void SetZero(GmmFlagsType flags);
--
-- int32 NumAccs() { return num_accumulators_.size(); }
--
-- int32 NumAccs() const { return num_accumulators_.size(); }
--
-- AccumDiagGmm& GetNumAcc(int32 index) const;
-- AccumDiagGmm& GetDenAcc(int32 index) const;
- AccumDiagGmm& GetISmoothAcc(int32 index) const;
--
-- void CopyToNumAcc(int32 index);
-- BaseFloat TotNumCount();
-- BaseFloat TotDenCount();
-
-- private:
-- /// MMIE accumulators and update methods for the GMMs
-- std::vector<AccumDiagGmm*> num_accumulators_;
-- std::vector<AccumDiagGmm*> den_accumulators_;
-
-
- std::vector<AccumDiagGmm*> i_smooth_accumulators_;
-- // Cannot have copy constructor and assigment operator
-- KALDI_DISALLOW_COPY_AND_ASSIGN(MmieAccumAmDiagGmm);
--};
--
--
--/// for computing the maximum-likelihood estimates of the parameters of
--/// an acoustic model that uses diagonal Gaussian mixture models as emission densities.
--void MmieAmDiagGmmUpdate(const MmieDiagGmmOptions &config,
-- const MmieAccumAmDiagGmm &mmieamdiaggmm_acc,
-- GmmFlagsType flags,
-- AmDiagGmm *am_gmm,
-- BaseFloat *auxf_change_gauss,
-- BaseFloat *auxf_change_weight,
-- BaseFloat *count_out,
-- int32 *num_floored_out);
--
--} // End namespace kaldi
--
--
--#endif // KALDI_GMM_MMIE_AM_DIAG_GMM_H_
diff --cc src/gmm/mmie-diag-gmm.cc
index 7edf6cdf723847dfdec49ba1190ac1516f730859,00b14de0ebe63741f23827c577fe4f1bfe76ef18..0000000000000000000000000000000000000000
deleted file mode 100644,100644
deleted file mode 100644,100644
+++ /dev/null
--// gmm/mmie-diag-gmm.cc
--
--// Copyright 2009-2011 Petr Motlicek, Arnab Ghoshal
--
--// Licensed under the Apache License, Version 2.0 (the "License");
--// you may not use this file except in compliance with the License.
--// You may obtain a copy of the License at
--//
--// http://www.apache.org/licenses/LICENSE-2.0
--//
--// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
--// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
--// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
--// MERCHANTABLITY OR NON-INFRINGEMENT.
--// See the Apache 2 License for the specific language governing permissions and
--// limitations under the License.
--
--#include <algorithm> // for std::max
--#include <string>
--#include <vector>
--
--#include "gmm/diag-gmm.h"
--#include "gmm/mmie-diag-gmm.h"
--
--
--namespace kaldi {
--
--
--void MmieAccumDiagGmm::Read(std::istream &in_stream, bool binary, bool add) {
-- int32 dimension, num_components;
-- GmmFlagsType flags;
-- std::string token;
--
-- ExpectMarker(in_stream, binary, "<GMMMMIACCS>");
-- ExpectMarker(in_stream, binary, "<VECSIZE>");
-- ReadBasicType(in_stream, binary, &dimension);
-- ExpectMarker(in_stream, binary, "<NUMCOMPONENTS>");
-- ReadBasicType(in_stream, binary, &num_components);
-- ExpectMarker(in_stream, binary, "<FLAGS>");
-- ReadBasicType(in_stream, binary, &flags);
--
-- if (add) {
-- if ((NumGauss() != 0 || Dim() != 0 || Flags() != 0)) {
-- if (num_components != NumGauss() || dimension != Dim()
-- || flags != Flags()) {
-- KALDI_ERR << "Dimension or flags mismatch: " << NumGauss() << ", "
-- << Dim() << ", " << Flags() << " vs. " << num_components
-- << ", " << dimension << ", " << flags;
-- }
-- } else {
-- Resize(num_components, dimension, flags);
-- }
-- } else {
-- Resize(num_components, dimension, flags);
-- }
--
-- ReadMarker(in_stream, binary, &token);
-- while (token != "</GMMMMIACCS>") {
-- if (token == "<NUM_OCCUPANCY>") {
-- num_occupancy_.Read(in_stream, binary, add);
-- } else if (token == "<DEN_OCCUPANCY>") {
-- den_occupancy_.Read(in_stream, binary, add);
-- } else if (token == "<MEANACCS>") {
-- mean_accumulator_.Read(in_stream, binary, add);
-- } else if (token == "<DIAGVARACCS>") {
-- variance_accumulator_.Read(in_stream, binary, add);
-- } else {
-- KALDI_ERR << "Unexpected token '" << token << "' in model file ";
-- }
-- ReadMarker(in_stream, binary, &token);
-- }
-- /// get difference
-- occupancy_.CopyFromVec(num_occupancy_);
-- occupancy_.AddVec(-1.0, den_occupancy_);
--
--}
--
--void MmieAccumDiagGmm::Write(std::ostream &out_stream, bool binary) const {
-- WriteMarker(out_stream, binary, "<GMMMMIACCS>");
-- WriteMarker(out_stream, binary, "<VECSIZE>");
-- WriteBasicType(out_stream, binary, dim_);
-- WriteMarker(out_stream, binary, "<NUMCOMPONENTS>");
-- WriteBasicType(out_stream, binary, num_comp_);
-- WriteMarker(out_stream, binary, "<FLAGS>");
-- WriteBasicType(out_stream, binary, flags_);
--
-- // convert into BaseFloat before writing things
-- Vector<BaseFloat> num_occupancy_bf(num_occupancy_.Dim());
-- Vector<BaseFloat> den_occupancy_bf(den_occupancy_.Dim());
-- Matrix<BaseFloat> mean_accumulator_bf(mean_accumulator_.NumRows(),
-- mean_accumulator_.NumCols());
-- Matrix<BaseFloat> variance_accumulator_bf(variance_accumulator_.NumRows(),
-- variance_accumulator_.NumCols());
-- num_occupancy_bf.CopyFromVec(num_occupancy_);
-- den_occupancy_bf.CopyFromVec(den_occupancy_);
-- mean_accumulator_bf.CopyFromMat(mean_accumulator_);
-- variance_accumulator_bf.CopyFromMat(variance_accumulator_);
--
-- WriteMarker(out_stream, binary, "<NUM_OCCUPANCY>");
-- num_occupancy_bf.Write(out_stream, binary);
-- WriteMarker(out_stream, binary, "<DEN_OCCUPANCY>");
-- den_occupancy_bf.Write(out_stream, binary);
-- WriteMarker(out_stream, binary, "<MEANACCS>");
-- mean_accumulator_bf.Write(out_stream, binary);
-- WriteMarker(out_stream, binary, "<DIAGVARACCS>");
-- variance_accumulator_bf.Write(out_stream, binary);
-- WriteMarker(out_stream, binary, "</GMMMMIACCS>");
--}
--
--
--
--
--void MmieAccumDiagGmm::Resize(int32 num_comp, int32 dim, GmmFlagsType flags) {
-- KALDI_ASSERT(num_comp > 0 && dim > 0);
-- num_comp_ = num_comp;
-- dim_ = dim;
-- flags_ = AugmentGmmFlags(flags);
-- num_occupancy_.Resize(num_comp);
-- den_occupancy_.Resize(num_comp);
-- occupancy_.Resize(num_comp);
-- if (flags_ & kGmmMeans)
-- mean_accumulator_.Resize(num_comp, dim);
-- else
-- mean_accumulator_.Resize(0, 0);
-- if (flags_ & kGmmVariances)
-- variance_accumulator_.Resize(num_comp, dim);
-- else
-- variance_accumulator_.Resize(0, 0);
--}
--
--
--void MmieAccumDiagGmm::SetZero(GmmFlagsType flags) {
-- if (flags & ~flags_)
-- KALDI_ERR << "Flags in argument do not match the active accumulators";
-- if (flags & kGmmWeights) {
-- num_occupancy_.SetZero();
-- den_occupancy_.SetZero();
-- occupancy_.SetZero();
-- }
-- if (flags & kGmmMeans) mean_accumulator_.SetZero();
-- if (flags & kGmmVariances) variance_accumulator_.SetZero();
--}
--
--
--void MmieAccumDiagGmm::Scale(BaseFloat f, GmmFlagsType flags) {
-- if (flags & ~flags_)
-- KALDI_ERR << "Flags in argument do not match the active accumulators";
-- double d = static_cast<double>(f);
-- if (flags & kGmmWeights) {
-- num_occupancy_.Scale(d);
-- den_occupancy_.Scale(d);
-- occupancy_.Scale(d);
-- }
-- if (flags & kGmmMeans) mean_accumulator_.Scale(d);
-- if (flags & kGmmVariances) variance_accumulator_.Scale(d);
--}
--
--
--void MmieAccumDiagGmm::SubtractAccumulatorsISmoothing(
-- const AccumDiagGmm& num_acc,
-- const AccumDiagGmm& den_acc,
- const MmieDiagGmmOptions& opts){
- const MmieDiagGmmOptions& opts,
- const AccumDiagGmm& i_smooth_acc){
--
-- //KALDI_ASSERT(num_acc.NumGauss() == den_acc.NumGauss && num_acc.Dim() == den_acc.Dim());
-- //std::cout << "NumGauss: " << num_acc.NumGauss() << " " << den_acc.NumGauss() << " " << num_comp_ << '\n';
-- KALDI_ASSERT(num_acc.NumGauss() == num_comp_ && num_acc.Dim() == dim_);
-- KALDI_ASSERT(den_acc.NumGauss() == num_comp_ && den_acc.Dim() == dim_);
-
- KALDI_ASSERT(i_smooth_acc.NumGauss() == num_comp_ && i_smooth_acc.Dim() == dim_);
--
-- // no subracting occs, just copy them to local vars
-- num_occupancy_.CopyFromVec(num_acc.occupancy());
-- den_occupancy_.CopyFromVec(den_acc.occupancy());
-- occupancy_.CopyFromVec(num_occupancy_);
-- occupancy_.AddVec(-1.0, den_occupancy_);
--
-- // Copy nums to private vars
-- mean_accumulator_.CopyFromMat(num_acc.mean_accumulator(), kNoTrans);
-- variance_accumulator_.CopyFromMat(num_acc.variance_accumulator(), kNoTrans);
--
- // Copy I- smoothing stats
- Vector<double> i_smooth_occupancy(i_smooth_acc.occupancy());
- Matrix<double> i_smooth_mean_accumulator(i_smooth_acc.mean_accumulator());
- Matrix<double> i_smooth_variance_accumulator(i_smooth_acc.variance_accumulator());
-- // I- smoothing
-- for (int32 g = 0; g < num_comp_; g++) {
- double occ = num_occupancy_(g);
- double occ = i_smooth_occupancy(g);
-- if (occ >= 0.0) {
-- occupancy_(g) += opts.i_smooth_tau; // Add I-smoothing to occupancy_, but
-- // *not* to num_occupancy_, which remains the original count before
-- // I-smoothing, and which we use to update the weights.
-- mean_accumulator_.Row(g).AddVec(opts.i_smooth_tau/occ,
- mean_accumulator_.Row(g));
- i_smooth_mean_accumulator.Row(g));
-- variance_accumulator_.Row(g).AddVec(opts.i_smooth_tau/occ,
- variance_accumulator_.Row(g));
- i_smooth_variance_accumulator.Row(g));
-- }
-- }
-- // Subtract den from smoothed num
-- mean_accumulator_.AddMat(-1.0, den_acc.mean_accumulator(), kNoTrans);
-- variance_accumulator_.AddMat(-1.0, den_acc.variance_accumulator(), kNoTrans);
--}
--
--
--bool MmieAccumDiagGmm::EBWUpdateGaussian(
-- BaseFloat D,
-- GmmFlagsType flags,
-- const VectorBase<double> &orig_mean,
-- const VectorBase<double> &orig_var,
-- const VectorBase<double> &x_stats,
-- const VectorBase<double> &x2_stats,
-- double occ,
-- VectorBase<double> *mean,
-- VectorBase<double> *var,
-- double *auxf_impr) const {
-- if (! (flags&(kGmmMeans|kGmmVariances)) || occ <= 0.0) { // nothing to do.
-- if (auxf_impr) *auxf_impr = 0.0;
-- mean->CopyFromVec(orig_mean);
-- var->CopyFromVec(orig_var);
-- return true;
-- }
-- KALDI_ASSERT(!( (flags&kGmmVariances) && !(flags&kGmmMeans)));
--
-- mean->SetZero();
-- var->SetZero();
-- mean->AddVec(D, orig_mean);
-- var->AddVec2(D, orig_mean);
-- var->AddVec(D, orig_var);
-- mean->AddVec(1.0, x_stats);
-- var->AddVec(1.0, x2_stats);
-- BaseFloat scale = 1.0 / (occ + D);
-- mean->Scale(scale);
-- var->Scale(scale);
-- var->AddVec2(-1.0, *mean);
--
-- if (!(flags&kGmmVariances)) var->CopyFromVec(orig_var);
-- if (!(flags&kGmmMeans)) mean->CopyFromVec(orig_mean);
--
-- if (var->Min() > 0.0) {
-- if (auxf_impr != NULL) {
-- // work out auxf improvement.
-- BaseFloat old_auxf = 0.0, new_auxf = 0.0;
-- int32 dim = orig_mean.Dim();
-- for (int32 i = 0; i < dim; i++) {
-- BaseFloat mean_diff = (*mean)(i) - orig_mean(i);
-- old_auxf += (occ+D) * -0.5 * (log(orig_var(i)) +
-- ((*var)(i) + mean_diff*mean_diff)
-- / orig_var(i));
-- new_auxf += (occ+D) * -0.5 * (log((*var)(i)) + 1.0);
-
-- }
-- *auxf_impr = new_auxf - old_auxf;
-- }
-- return true;
-- } else return false;
--}
--
--
--void MmieAccumDiagGmm::Update(const MmieDiagGmmOptions &config,
-- GmmFlagsType flags,
-- DiagGmm *gmm,
-- BaseFloat *auxf_change_out_gauss,
-- BaseFloat *auxf_change_out_weights,
-- BaseFloat *count_out,
-- int32 *num_floored_out) const {
-- if (flags_ & ~flags)
-- KALDI_ERR << "Flags in argument do not match the active accumulators";
--
-- if (auxf_change_out_gauss) *auxf_change_out_gauss = 0.0;
-- if (auxf_change_out_weights) *auxf_change_out_weights = 0.0;
-- if (count_out) *count_out = 0.0;
-- if (num_floored_out) *num_floored_out = 0;
--
-- KALDI_ASSERT(gmm->NumGauss() == (num_comp_));
-- if (flags_ & kGmmMeans)
-- KALDI_ASSERT(dim_ == mean_accumulator_.NumCols());
--
-- int32 num_comp = num_comp_;
-- int32 dim = dim_;
--
-- // copy DiagGMM model and transform this to the normal case
-- DiagGmmNormal diaggmmnormal;
-- gmm->ComputeGconsts();
-- diaggmmnormal.CopyFromDiagGmm(*gmm);
--
-- // go over all components
-- double occ;
-- Vector<double> mean(dim), var(dim);
-- for (int32 g = 0; g < num_comp; g++) {
-- double D = config.ebw_e * den_occupancy_(g) / 2; // E*y_den/2 where E = 2;
-- // We initialize to half the value of D that would be dicated by
-- // E; this is part of the strategy used to ensure that the value of
-- // D we use is at least twice the value that would ensure positive
-- // variances.
--
-- occ = occupancy_(g);
--
-- int32 iter, max_iter = 100;
-- for (iter = 0; iter < max_iter; iter++) { // will normally break the first time.
-- if (EBWUpdateGaussian(D, flags,
-- diaggmmnormal.means_.Row(g),
-- diaggmmnormal.vars_.Row(g),
-- mean_accumulator_.Row(g),
-- variance_accumulator_.Row(g),
-- occ,
-- &mean,
-- &var,
-- NULL)) {
-- // Succeeded in getting all +ve vars at this value of D.
-- // So double D and commit changes.
-- D *= 2.0;
-- double auxf_impr = 0.0;
-- EBWUpdateGaussian(D, flags,
-- diaggmmnormal.means_.Row(g),
-- diaggmmnormal.vars_.Row(g),
-- mean_accumulator_.Row(g),
-- variance_accumulator_.Row(g),
-- occ,
-- &mean,
-- &var,
-- &auxf_impr);
-- if (auxf_change_out_gauss) *auxf_change_out_gauss += auxf_impr;
-- if (count_out) *count_out += num_occupancy_(g);
-- // the EBWUpdateGaussian function only updates the
-- // appropriate parameters according to the flags.
- // variance flooring
- //for (int32 i = 0; i < var.Dim(); i++) {
- // if (var(i) < config.min_variance) {
- // var(i) = config.min_variance;
- // KALDI_WARN << " flooring variance with value = " << var(i);
- // }
- //}
-- diaggmmnormal.means_.CopyRowFromVec(mean, g);
-- diaggmmnormal.vars_.CopyRowFromVec(var, g);
-
-- break;
-- } else {
-- // small step
-- D *= 1.1;
-- }
-- }
-- if (iter > 0 && num_floored_out != NULL) *num_floored_out++;
-- if (iter == max_iter) KALDI_WARN << "Dropped off end of loop, recomputing D. (unexpected.)";
-- }
--
-- // Now update weights...
- if (flags & kGmmWeights && num_comp > 1 &&
- num_occupancy_.Sum() > config.min_count_weight_update) {
- if (flags & kGmmWeights) {
-- double weight_auxf_at_start = 0.0, weight_auxf_at_end = 0.0;
-- Vector<double> weights(diaggmmnormal.weights_);
-- for (int32 g = 0; g < num_comp; g++) { // c.f. eq. 4.32 in Dan Povey's thesis.
-- weight_auxf_at_start +=
-- num_occupancy_(g) * log (weights(g))
-- - den_occupancy_(g) * weights(g) / diaggmmnormal.weights_(g);
-- }
-- for (int32 iter = 0; iter < 50; iter++) {
-- Vector<double> k_jm(num_comp); // c.f. eq. 4.35
-- double max_m = 0.0;
-- for (int32 g = 0; g < num_comp; g++)
-- max_m = std::max(max_m, den_occupancy_(g)/diaggmmnormal.weights_(g));
-- for (int32 g = 0; g < num_comp; g++)
-- k_jm(g) = max_m - den_occupancy_(g)/diaggmmnormal.weights_(g);
-- for (int32 g = 0; g < num_comp; g++) // c.f. eq. 4.34
-- weights(g) = num_occupancy_(g) + k_jm(g)*weights(g);
-- weights.Scale(1.0 / weights.Sum()); // c.f. eq. 4.34 (denominator)
-- }
-- for (int32 g = 0; g < num_comp; g++) { // weight flooring.
-- if (weights(g) < config.min_gaussian_weight)
-- weights(g) = config.min_gaussian_weight;
-- }
-- weights.Scale(1.0 / weights.Sum()); // renormalize after flooring..
-- // floor won't be exact now but doesn't really matter.
--
-- for (int32 g = 0; g < num_comp; g++) { // c.f. eq. 4.32 in Dan Povey's thesis.
-- weight_auxf_at_end +=
-- num_occupancy_(g) * log (weights(g))
-- - den_occupancy_(g) * weights(g) / diaggmmnormal.weights_(g);
-- }
--
-- if (auxf_change_out_weights)
-- *auxf_change_out_weights += weight_auxf_at_end - weight_auxf_at_start;
-- diaggmmnormal.weights_.CopyFromVec(weights);
-- }
-- // copy to natural representation according to flags
-- diaggmmnormal.CopyToDiagGmm(gmm, flags);
-- gmm->ComputeGconsts();
--}
--
--
--
--
--MmieAccumDiagGmm::MmieAccumDiagGmm(const MmieAccumDiagGmm &other)
-- : dim_(other.dim_), num_comp_(other.num_comp_),
-- flags_(other.flags_), num_occupancy_(other.num_occupancy_),
-- den_occupancy_(other.den_occupancy_),
-- mean_accumulator_(other.mean_accumulator_),
-- variance_accumulator_(other.variance_accumulator_) {}
--
--
--
--//BaseFloat ComputeD(const DiagGmm& old_gmm, int32 mix_index, BaseFloat ebw_e){
--//}
--
--
--
--//BaseFloat MmieDiagGmm::MmiObjective(const DiagGmm& gmm) const {
--//}
--
--} // End of namespace kaldi
diff --cc src/gmm/mmie-diag-gmm.h
index 510afccb79dd10283fbe2a4bf1d35897a16e308c,29f325b9a891b8fbfb6ddd5558f13a95c3545f4e..0000000000000000000000000000000000000000
deleted file mode 100644,100644
deleted file mode 100644,100644
+++ /dev/null
--// gmm/mmie-diag-gmm.h
--
--// Copyright 2009-2011 Petr Motlicek, Arnab Ghoshal
--
--// Licensed under the Apache License, Version 2.0 (the "License");
--// you may not use this file except in compliance with the License.
--// You may obtain a copy of the License at
--//
--// http://www.apache.org/licenses/LICENSE-2.0
--//
--// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
--// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
--// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
--// MERCHANTABLITY OR NON-INFRINGEMENT.
--// See the Apache 2 License for the specific language governing permissions and
--// limitations under the License.
--
--
--#ifndef KALDI_GMM_MMIE_DIAG_GMM_H_
--#define KALDI_GMM_MMIE_DIAG_GMM_H_ 1
--
--#include <string>
--
--#include "gmm/diag-gmm.h"
--#include "gmm/mle-diag-gmm.h"
--#include "gmm/model-common.h"
--#include "util/parse-options.h"
--
--namespace kaldi {
--
--/** \struct MmieDiagGmmOptions
-- * Configuration variables like variance floor, minimum occupancy, etc.
-- * needed in the estimation process.
-- */
--struct MmieDiagGmmOptions : public MleDiagGmmOptions {
- BaseFloat i_smooth_tau;
- BaseFloat i_smooth_tau;
-- BaseFloat ebw_e;
- BaseFloat min_count_weight_update;
- //this boolean indicates whether we have extra Ismoothing stats in
- bool has_i_smooth_stats;
-- MmieDiagGmmOptions() : MleDiagGmmOptions() {
-- i_smooth_tau = 100.0;
-- ebw_e = 2.0;
- min_count_weight_update = 10.0;
- has_i_smooth_stats = false;
-- }
-- void Register(ParseOptions *po) {
-- std::string module = "MmieDiagGmmOptions: ";
-- po->Register("min-gaussian-weight", &min_gaussian_weight,
-- module+"Min Gaussian weight before we remove it.");
- po->Register("min-count-weight-update", &min_count_weight_update,
- module+"Minimum state-level numerator count required to do the weight update");
-- po->Register("min-variance", &min_variance,
-- module+"Variance floor (absolute variance).");
-- po->Register("remove-low-count-gaussians", &remove_low_count_gaussians,
-- module+"If true, remove Gaussians that fall below the floors.");
-- po->Register("i-smooth-tau", &i_smooth_tau,
-- module+"Coefficient for I-smoothing.");
-- po->Register("ebw-e", &ebw_e, module+"Smoothing constant for EBW update.");
-- }
--};
--
--
--/** Class for computing the maximum mutual information estimate of the
-- * parameters of a Gaussian mixture model.
-- */
--class MmieAccumDiagGmm {
-- public:
-- MmieAccumDiagGmm(): dim_(0), num_comp_(0), flags_(0) {}
-- //MmieDiagGmm() {}
-- explicit MmieAccumDiagGmm(const DiagGmm &gmm, GmmFlagsType flags) {
-- Resize(gmm.NumGauss(), gmm.Dim(), flags);
-- }
--
-- // provide copy constructor.
-- explicit MmieAccumDiagGmm(const MmieAccumDiagGmm &other);
--
--
-- void Read(std::istream &in_stream, bool binary, bool add);
-- void Write(std::ostream &out_stream, bool binary) const;
--
-- /// Allocates memory for accumulators
-- void Resize(int32 num_comp, int32 dim, GmmFlagsType flags);
--/// Calls ResizeAccumulators with arguments based on gmm
-- void Resize(const DiagGmm &gmm, GmmFlagsType flags);
--
--
-- /// Returns the number of mixture components
-- int32 NumGauss() const { return num_comp_; }
-- /// Returns the dimensionality of the feature vectors
-- int32 Dim() const { return dim_; }
--
-- void SetZero(GmmFlagsType flags);
-- void Scale(BaseFloat f, GmmFlagsType flags);
--
--
-- /// Computes the difference between the numerator and denominator accumulators
-- /// and applies I-smoothing to the numerator accs, if needed.
-- void SubtractAccumulatorsISmoothing(const AccumDiagGmm& num_acc,
-- const AccumDiagGmm& den_acc,
- const MmieDiagGmmOptions& opts);
- const MmieDiagGmmOptions& opts,
- const AccumDiagGmm& i_smooth_acc);
--
-- /// Does EBW update on one diagonal Gaussian; returns true if resulting
-- /// variance was all positive.
-- bool EBWUpdateGaussian(
-- BaseFloat D,
-- GmmFlagsType flags,
-- const VectorBase<double> &orig_mean,
-- const VectorBase<double> &orig_var,
-- const VectorBase<double> &x_stats,
-- const VectorBase<double> &x2_stats,
-- double occ,
-- VectorBase<double> *mean,
-- VectorBase<double> *var,
-- double *auxf_impr) const;
--
-- /// MMIE update
-- void Update(const MmieDiagGmmOptions &config,
-- GmmFlagsType flags,
-- DiagGmm *gmm,
-- BaseFloat *auxf_change_out_gauss, // gets set to EBW auxf impr.
-- BaseFloat *auxf_change_out_weights, // auxf impr in weights auxf.
-- BaseFloat *count_out, // gets set to numerator count.
-- int32 *num_floored_out) const;
--
--
--
-- // Accessors
-- const GmmFlagsType Flags() const { return flags_; }
-- const Vector<double>& num_occupancy() const { return num_occupancy_; }
-- const Vector<double>& den_occupancy() const { return den_occupancy_; }
-- const Vector<double>& occupancy() const { return occupancy_; }
-- const Matrix<double>& mean_accumulator() const { return mean_accumulator_; }
-- const Matrix<double>& variance_accumulator() const { return variance_accumulator_; }
--
--
-- private:
-- int32 dim_;
-- int32 num_comp_;
-- /// Flags corresponding to the accumulators that are stored.
-- GmmFlagsType flags_;
--
-- /// Accumulators.
-- /// We store the difference of mean and var; we keep occupancy
-- /// for num and den and their difference (with I-smoothing)
--
-- Vector<double> num_occupancy_; // Numerator occupancy
-- Vector<double> den_occupancy_; // Denominator occupancy
-- Vector<double> occupancy_; // Num-Den occupancy *plus I-smoothing*
-- Matrix<double> mean_accumulator_; // Sum of num-den+I-smooth stats.
-- Matrix<double> variance_accumulator_; // Sum of num-den+I-smooth stats.
--
-- // BaseFloat ComputeD(const DiagGmm& old_gmm, int32 mix_index, BaseFloat ebw_e);
--
-- /// Cannot have copy constructor and assigment operator
-- //KALDI_DISALLOW_COPY_AND_ASSIGN(MmieDiagGmm);
--};
--
--
--inline void MmieAccumDiagGmm::Resize(const DiagGmm &gmm, GmmFlagsType flags) {
-- Resize(gmm.NumGauss(), gmm.Dim(), flags);
--}
--
--
--} // End namespace kaldi
--
--
--#endif // KALDI_GMM_MMIE_DIAG_GMM_H_
diff --cc src/gmm/model-common.cc
index 3b6f4ee63068fea2a6a02ce03c39a54b32b06888,3b6f4ee63068fea2a6a02ce03c39a54b32b06888..6443e82dba8b067125dd6bc3dcf78119d2705a07
+++ b/src/gmm/model-common.cc
return flags;
}
++std::string GmmFlagsToString(GmmFlagsType flags) {
++ std::string ans;
++ if (flags & kGmmMeans) ans += "m";
++ if (flags & kGmmVariances) ans += "v";
++ if (flags & kGmmWeights) ans += "w";
++ if (flags & kGmmTransitions) ans += "t";
++ return ans;
++}
++
GmmFlagsType AugmentGmmFlags(GmmFlagsType flags) {
KALDI_ASSERT((flags & ~kGmmAll) == 0); // make sure only valid flags are present.
if (flags & kGmmVariances) flags |= kGmmMeans;
if (flags & kGmmMeans) flags |= kGmmWeights;
-- KALDI_ASSERT(flags & kGmmWeights); // make sure zero-stats will be accumulated
++ if (!(flags & kGmmWeights)) {
++ KALDI_WARN << "Adding in kGmmWeights (\"w\") to empty flags.";
++ flags |= kGmmWeights; // Just add this in regardless:
++ // if user wants no stats, this will stop programs from crashing due to dim mismatches.
++ }
return flags;
}
diff --cc src/gmm/model-common.h
index 00e4e7b52d97401544727d22b137b32978f43e1a,00e4e7b52d97401544727d22b137b32978f43e1a..869b13f44a476f7df5c22a402922deddb499cfc7
+++ b/src/gmm/model-common.h
/// flags.
GmmFlagsType StringToGmmFlags(std::string str);
++/// Convert GMM flags to string
++std::string GmmFlagsToString(GmmFlagsType gmm_flags);
++
// Make sure that the flags make sense, i.e. if there is variance
// accumulation that there is also mean accumulation
GmmFlagsType AugmentGmmFlags(GmmFlagsType flags);
diff --cc src/gmmbin/Makefile
index 9bc806327e4b4c95be2e95b67d967e3aa21cba3a,9f115a8b6aa302c946be711e3108a37a1536c1ad..7cc9ec10da300fce66ac67e345b42e15eb270211
--- 1/src/gmmbin/Makefile
--- 2/src/gmmbin/Makefile
+++ b/src/gmmbin/Makefile
gmm-decode-faster-regtree-mllr gmm-et-apply-c gmm-latgen-simple \
gmm-rescore-lattice gmm-decode-biglm-faster fmpe-gmm-model-diffs-est \
fmpe-gmm-acc-stats-gpost fmpe-gmm-sum-accs fmpe-init-gmms fmpe-gmm-est \
- gmm-est-mmi gmm-latgen-faster gmm-copy \
- gmm-est-mmi gmm-latgen-faster
++ gmm-est-gaussians-ebw gmm-est-weights-ebw gmm-latgen-faster gmm-copy \
+ gmm-global-acc-stats gmm-global-est gmm-global-sum-accs gmm-gselect \
- gmm-latgen-biglm-faster
++ gmm-latgen-biglm-faster gmm-ismooth-stats
OBJFILES =
diff --cc src/gmmbin/gmm-align-compiled.cc
index 4010074ed7e159bc79cf8d1aefae95104c909709,fb9be9b26e18af3ef1c42b91c5955c878bd7e3be..44127bf47562d3f7a17dbf0e48260b855a26109c
frame_count += features.NumRows();
GetLinearSymbolSequence(decoded, &alignment, &words, &weight);
-- BaseFloat like = (-weight.Value1() -weight.Value2()) / acoustic_scale;
++ BaseFloat like = -(weight.Value1()+weight.Value2()) / acoustic_scale;
tot_like += like;
- if (scores_writer.IsOpen()) {
- scores_writer.Write(key, like*acoustic_scale);
- }
++ if (scores_writer.IsOpen())
++ scores_writer.Write(key, -(weight.Value1()+weight.Value2()));
alignment_writer.Write(key, alignment);
num_success ++;
if (num_success % 50 == 0) {
diff --cc src/gmmbin/gmm-est-gaussians-ebw.cc
index 0000000000000000000000000000000000000000,0000000000000000000000000000000000000000..3712f93c1def0c391d08ae38b370c66db599af8a
new file mode 100644 (file)
new file mode 100644 (file)
--- /dev/null
--- /dev/null
@@@ -1,0 -1,0 +1,115 @@@
++// gmmbin/gmm-est-gaussians-ebw.cc
++
++// Copyright 2009-2011 Petr Motlicek Chao Weng
++
++// Licensed under the Apache License, Version 2.0 (the "License");
++// you may not use this file except in compliance with the License.
++// You may obtain a copy of the License at
++//
++// http://www.apache.org/licenses/LICENSE-2.0
++//
++// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
++// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
++// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
++// MERCHANTABLITY OR NON-INFRINGEMENT.
++// See the Apache 2 License for the specific language governing permissions and
++// limitations under the License.
++
++#include "base/kaldi-common.h"
++#include "util/common-utils.h"
++#include "gmm/am-diag-gmm.h"
++#include "tree/context-dep.h"
++#include "hmm/transition-model.h"
++#include "gmm/ebw-diag-gmm.h"
++
++int main(int argc, char *argv[]) {
++ try {
++ using namespace kaldi;
++ typedef kaldi::int32 int32;
++
++ const char *usage =
++ "Do EBW update for MMI, MPE or MCE discriminative training.\n"
++ "Numerator stats should already be I-smoothed (e.g. use gmm-ismooth-stats)\n"
++ "Usage: gmm-est-gaussians-ebw [options] <model-in> <stats-num-in> <stats-den-in> <model-out>\n"
++ "e.g.: gmm-est-gaussians-ebw 1.mdl num.acc den.acc 2.mdl\n";
++
++ bool binary_write = false;
++ std::string update_flags_str = "mv";
++
++ EbwOptions ebw_opts;
++ ParseOptions po(usage);
++ po.Register("binary", &binary_write, "Write output in binary mode");
++ po.Register("update-flags", &update_flags_str, "Which GMM parameters to "
++ "update: e.g. m or mv (w, t ignored).");
++
++ ebw_opts.Register(&po);
++
++ po.Read(argc, argv);
++
++ if (po.NumArgs() != 4) {
++ po.PrintUsage();
++ exit(1);
++ }
++
++ kaldi::GmmFlagsType update_flags =
++ StringToGmmFlags(update_flags_str);
++
++ std::string model_in_filename = po.GetArg(1),
++ num_stats_filename = po.GetArg(2),
++ den_stats_filename = po.GetArg(3),
++ model_out_filename = po.GetArg(4);
++
++ AmDiagGmm am_gmm;
++ TransitionModel trans_model;
++ {
++ bool binary_read;
++ Input ki(model_in_filename, &binary_read);
++ trans_model.Read(ki.Stream(), binary_read);
++ am_gmm.Read(ki.Stream(), binary_read);
++ }
++
++ Vector<double> num_transition_accs; // won't be used.
++ Vector<double> den_transition_accs; // won't be used.
++
++ AccumAmDiagGmm num_stats;
++ AccumAmDiagGmm den_stats;
++ {
++ bool binary;
++ Input ki(num_stats_filename, &binary);
++ num_transition_accs.Read(ki.Stream(), binary);
++ num_stats.Read(ki.Stream(), binary, true); // true == add; doesn't matter here.
++ }
++
++ {
++ bool binary;
++ Input ki(den_stats_filename, &binary);
++ num_transition_accs.Read(ki.Stream(), binary);
++ den_stats.Read(ki.Stream(), binary, true); // true == add; doesn't matter here.
++ }
++
++
++ { // Update GMMs.
++ BaseFloat auxf_impr = 0.0, count = 0.0;
++ int32 num_floored = 0;
++ UpdateEbwAmDiagGmm(num_stats, den_stats, update_flags, ebw_opts, &am_gmm,
++ &auxf_impr, &count, &num_floored);
++ KALDI_LOG << "Num count " << num_stats.TotCount() << ", den count "
++ << den_stats.TotCount();
++ KALDI_LOG << "Overall auxf impr/frame from Gaussian update is " << (auxf_impr/count)
++ << " over " << count << " frames; floored D for "
++ << num_floored << " Gaussians.";
++ }
++
++ {
++ Output ko(model_out_filename, binary_write);
++ trans_model.Write(ko.Stream(), binary_write);
++ am_gmm.Write(ko.Stream(), binary_write);
++ }
++
++ KALDI_LOG << "Written model to " << model_out_filename;
++
++ } catch(const std::exception& e) {
++ std::cerr << e.what() << '\n';
++ return -1;
++ }
++}
diff --cc src/gmmbin/gmm-est-mmi.cc
index bde3caad49c86daaa8b20ed595a0e382763f1ae9,8cf2977beefe986b6396db83de7d6799d08409c3..b2b9679b68a293252138defe4ce466be24e77736
// gmmbin/gmm-est-mmi.cc
--// Copyright 2009-2011 Petr Motlicek
++// Copyright 2009-2011 Petr Motlicek Chao Weng
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
MmieDiagGmmOptions mmi_opts;
const char *usage =
-- "Do EBW update with I-smoothing for MMI discriminative training.\n"
-- "Usage: gmm-est-mmi [options] <model-in> <stats-num-in> <stats-den-in> <model-out>\n"
-- "e.g.: gmm-est 1.mdl num.acc den.acc 2.mdl\n";
++ "Do EBW update with I-smoothing for MMI, MPE or MCE discriminative training.\n"
++ "Usage: gmm-est-mmi [options] <model-in> <stats-num-in> <stats-den-in> [<stats-ismooth-in>] <model-out>\n"
++        "e.g.: gmm-est-mmi 1.mdl num.acc den.acc 2.mdl\n"
++        "or (for MPE): gmm-est-mmi 1.mdl num.acc den.acc ml.acc 2.mdl\n";
bool binary_write = false;
//TransitionUpdateConfig tcfg;
po.Read(argc, argv);
-- if (po.NumArgs() != 4) {
++ if (po.NumArgs() < 4 || po.NumArgs() > 5) {
po.PrintUsage();
exit(1);
}
std::string model_in_filename = po.GetArg(1),
num_stats_filename = po.GetArg(2),
den_stats_filename = po.GetArg(3),
-- model_out_filename = po.GetArg(4);
--
++ i_smooth_stats_filename = (po.NumArgs() == 5 ? po.GetArg(4) : ""),
++ model_out_filename = po.GetArg(po.NumArgs());
AmDiagGmm am_gmm;
TransitionModel trans_model;
}
{
bool binary;
- Input is(den_stats_filename, &binary);
- num_transition_accs.Read(is.Stream(), binary);
- mmi_accs.ReadDen(is.Stream(), binary, true); // true == add; doesn't matter here.
+ Input ki(den_stats_filename, &binary);
+ num_transition_accs.Read(ki.Stream(), binary);
+ mmi_accs.ReadDen(ki.Stream(), binary, true); // true == add; doesn't matter here.
}
-
-
+
+ if (!i_smooth_stats_filename.empty()) {
+ mmi_opts.has_i_smooth_stats = true;
+ bool binary;
+ Input is(i_smooth_stats_filename, &binary);
+      num_transition_accs.Read(is.Stream(), binary); // transition accs are read only to consume them from the stream; the values are unused here.
+ mmi_accs.ReadISmooth(is.Stream(), binary, true);
+ }
+
{ // Update GMMs.
BaseFloat auxf_impr_gauss, auxf_impr_weights, count;
int32 num_floored;
diff --cc src/gmmbin/gmm-est-weights-ebw.cc
index 0000000000000000000000000000000000000000,0000000000000000000000000000000000000000..154e6ee91547ed7a425e31d11122f1ca797802c9
new file mode 100644 (file)
new file mode 100644 (file)
--- /dev/null
--- /dev/null
@@@ -1,0 -1,0 +1,114 @@@
++// gmmbin/gmm-est-weights-ebw.cc
++
++// Copyright 2009-2011 Petr Motlicek Chao Weng
++
++// Licensed under the Apache License, Version 2.0 (the "License");
++// you may not use this file except in compliance with the License.
++// You may obtain a copy of the License at
++//
++// http://www.apache.org/licenses/LICENSE-2.0
++//
++// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
++// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
++// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
++// MERCHANTABLITY OR NON-INFRINGEMENT.
++// See the Apache 2 License for the specific language governing permissions and
++// limitations under the License.
++
++#include "base/kaldi-common.h"
++#include "util/common-utils.h"
++#include "gmm/am-diag-gmm.h"
++#include "tree/context-dep.h"
++#include "hmm/transition-model.h"
++#include "gmm/ebw-diag-gmm.h"
++
++int main(int argc, char *argv[]) {
++ try {
++ using namespace kaldi;
++ typedef kaldi::int32 int32;
++
++ const char *usage =
++ "Do EBW update on weights for MMI, MPE or MCE discriminative training.\n"
++ "Numerator stats should not be I-smoothed\n"
++ "Usage: gmm-est-weights-ebw [options] <model-in> <stats-num-in> <stats-den-in> <model-out>\n"
++ "e.g.: gmm-est-weights-ebw 1.mdl num.acc den.acc 2.mdl\n";
++
++ bool binary_write = false;
++ std::string update_flags_str = "w";
++
++ EbwWeightOptions ebw_weight_opts;
++ ParseOptions po(usage);
++ po.Register("binary", &binary_write, "Write output in binary mode");
++ po.Register("update-flags", &update_flags_str, "Which GMM parameters to "
++ "update; only \"w\" flag is looked at.");
++
++ ebw_weight_opts.Register(&po);
++
++ po.Read(argc, argv);
++
++ if (po.NumArgs() != 4) {
++ po.PrintUsage();
++ exit(1);
++ }
++
++ kaldi::GmmFlagsType update_flags =
++ StringToGmmFlags(update_flags_str);
++
++ std::string model_in_filename = po.GetArg(1),
++ num_stats_filename = po.GetArg(2),
++ den_stats_filename = po.GetArg(3),
++ model_out_filename = po.GetArg(4);
++
++ AmDiagGmm am_gmm;
++ TransitionModel trans_model;
++ {
++ bool binary_read;
++ Input ki(model_in_filename, &binary_read);
++ trans_model.Read(ki.Stream(), binary_read);
++ am_gmm.Read(ki.Stream(), binary_read);
++ }
++
++ Vector<double> num_transition_accs; // won't be used.
++ Vector<double> den_transition_accs; // won't be used.
++
++ AccumAmDiagGmm num_stats;
++ AccumAmDiagGmm den_stats;
++ {
++ bool binary;
++ Input ki(num_stats_filename, &binary);
++ num_transition_accs.Read(ki.Stream(), binary);
++ num_stats.Read(ki.Stream(), binary, true); // true == add; doesn't matter here.
++ }
++
++ {
++ bool binary;
++ Input ki(den_stats_filename, &binary);
++ num_transition_accs.Read(ki.Stream(), binary);
++ den_stats.Read(ki.Stream(), binary, true); // true == add; doesn't matter here.
++ }
++
++ if (update_flags & kGmmWeights) { // Update weights.
++ BaseFloat auxf_impr = 0.0, count = 0.0;
++ UpdateEbwWeightsAmDiagGmm(num_stats, den_stats, ebw_weight_opts, &am_gmm,
++ &auxf_impr, &count);
++ KALDI_LOG << "Num count " << num_stats.TotCount() << ", den count "
++ << den_stats.TotCount();
++ KALDI_LOG << "Overall auxf impr/frame from weight update is " << (auxf_impr/count)
++ << " over " << count << " frames.";
++ } else {
++ KALDI_LOG << "Doing nothing because flags do not specify to update the weights.";
++ }
++
++ {
++ Output ko(model_out_filename, binary_write);
++ trans_model.Write(ko.Stream(), binary_write);
++ am_gmm.Write(ko.Stream(), binary_write);
++ }
++
++ KALDI_LOG << "Written model to " << model_out_filename;
++
++ } catch(const std::exception& e) {
++ std::cerr << e.what() << '\n';
++ return -1;
++ }
++}
diff --cc src/gmmbin/gmm-est.cc
index 6bc4ca87ce69680b8eede04b1bd406eeaffb9f83,94c6d9c2ed575d75a6c68d9cdfaa9ac996e843fa..ac72728b44009afc112fa6924b7e0f90a7c3cfef
+++ b/src/gmmbin/gmm-est.cc
state_occs.Resize(gmm_accs.NumAccs());
for (int i = 0; i < gmm_accs.NumAccs(); i++)
state_occs(i) = gmm_accs.GetAcc(i).occupancy().Sum();
--
++
if (mixdown != 0)
am_gmm.MergeByCount(state_occs, mixdown, power, min_count);
diff --cc src/gmmbin/gmm-ismooth-stats.cc
index 0000000000000000000000000000000000000000,0000000000000000000000000000000000000000..9bf1fa1e972305efc8daaa1d145cb642b52e9837
new file mode 100644 (file)
new file mode 100644 (file)
--- /dev/null
--- /dev/null
@@@ -1,0 -1,0 +1,119 @@@
++// gmmbin/gmm-ismooth-stats.cc
++
++// Copyright 2009-2011 Petr Motlicek Chao Weng
++
++// Licensed under the Apache License, Version 2.0 (the "License");
++// you may not use this file except in compliance with the License.
++// You may obtain a copy of the License at
++//
++// http://www.apache.org/licenses/LICENSE-2.0
++//
++// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
++// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
++// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
++// MERCHANTABLITY OR NON-INFRINGEMENT.
++// See the Apache 2 License for the specific language governing permissions and
++// limitations under the License.
++
++#include "base/kaldi-common.h"
++#include "util/common-utils.h"
++#include "gmm/am-diag-gmm.h"
++#include "tree/context-dep.h"
++#include "hmm/transition-model.h"
++#include "gmm/ebw-diag-gmm.h"
++
++int main(int argc, char *argv[]) {
++ try {
++ using namespace kaldi;
++ typedef kaldi::int32 int32;
++
++ const char *usage =
++ "Apply I-smoothing to statistics, e.g. for discriminative training\n"
++ "Usage: gmm-ismooth-stats [options] [--smooth-from-model] [<src-stats-in>|<src-model-in>] <dst-stats-in> <stats-out>\n"
++ "e.g.: gmm-ismooth-stats --tau=100 ml.acc num.acc smoothed.acc\n"
++ "or: gmm-ismooth-stats --tau=50 --smooth-from-model 1.mdl num.acc smoothed.acc\n"
++ "or: gmm-ismooth-stats --tau=100 num.acc num.acc smoothed.acc\n";
++
++ bool binary_write = false;
++ bool smooth_from_model = false;
++ BaseFloat tau = 100;
++
++ ParseOptions po(usage);
++ po.Register("binary", &binary_write, "Write output in binary mode");
++     po.Register("smooth-from-model", &smooth_from_model, "Expect first argument to be a model file");
++ po.Register("tau", &tau, "Tau value for I-smoothing");
++
++ po.Read(argc, argv);
++
++ if (po.NumArgs() != 3) {
++ po.PrintUsage();
++ exit(1);
++ }
++
++ std::string src_stats_or_model_filename = po.GetArg(1),
++ dst_stats_filename = po.GetArg(2),
++ stats_out_filename = po.GetArg(3);
++
++ if (src_stats_or_model_filename == dst_stats_filename) { // as an optimization, just read once.
++ KALDI_ASSERT(!smooth_from_model);
++ Vector<double> transition_accs;
++ AccumAmDiagGmm stats;
++ {
++ bool binary;
++ Input ki(dst_stats_filename, &binary);
++ transition_accs.Read(ki.Stream(), binary);
++ stats.Read(ki.Stream(), binary, true); // true == add; doesn't matter here.
++ }
++ stats.SmoothStats(tau);
++ Output ko(stats_out_filename, binary_write);
++ transition_accs.Write(ko.Stream(), binary_write);
++ stats.Write(ko.Stream(), binary_write);
++ } else if (smooth_from_model) { // Smoothing from model...
++ AmDiagGmm am_gmm;
++ TransitionModel trans_model;
++ Vector<double> dst_transition_accs;
++ AccumAmDiagGmm dst_stats;
++ { // read src model
++ bool binary;
++ Input ki(src_stats_or_model_filename, &binary);
++ trans_model.Read(ki.Stream(), binary);
++ am_gmm.Read(ki.Stream(), binary);
++ }
++ { // read dst stats.
++ bool binary;
++ Input ki(dst_stats_filename, &binary);
++ dst_transition_accs.Read(ki.Stream(), binary);
++ dst_stats.Read(ki.Stream(), binary, true); // true == add; doesn't matter here.
++ }
++ dst_stats.SmoothWithModel(tau, am_gmm);
++ Output ko(stats_out_filename, binary_write);
++ dst_transition_accs.Write(ko.Stream(), binary_write);
++ dst_stats.Write(ko.Stream(), binary_write);
++ } else { // Smooth from stats.
++ Vector<double> src_transition_accs;
++ Vector<double> dst_transition_accs;
++ AccumAmDiagGmm src_stats;
++ AccumAmDiagGmm dst_stats;
++ { // read src stats.
++ bool binary;
++ Input ki(src_stats_or_model_filename, &binary);
++ src_transition_accs.Read(ki.Stream(), binary);
++ src_stats.Read(ki.Stream(), binary, true); // true == add; doesn't matter here.
++ }
++ { // read dst stats.
++ bool binary;
++ Input ki(dst_stats_filename, &binary);
++ dst_transition_accs.Read(ki.Stream(), binary);
++ dst_stats.Read(ki.Stream(), binary, true); // true == add; doesn't matter here.
++ }
++ dst_stats.SmoothWithAccum(tau, src_stats);
++ Output ko(stats_out_filename, binary_write);
++ dst_transition_accs.Write(ko.Stream(), binary_write);
++ dst_stats.Write(ko.Stream(), binary_write);
++ }
++ KALDI_LOG << "Smoothed stats with tau = " << tau;
++ } catch(const std::exception& e) {
++ std::cerr << e.what() << '\n';
++ return -1;
++ }
++}
diff --cc src/latbin/lattice-to-post.cc
index 23b3513a9a655f866f86e24bd8ba3f36f9dbe2b7,b769c0c7cb6657eb2abd7638323469d95ae49fe7..04af0741cb28aab32ea133b84f36de583e15c547
}
if (acoustic_scale == 0.0)
- KALDI_EXIT << "Do not use a zero acoustic scale (cannot be inverted)";
+ KALDI_ERR << "Do not use a zero acoustic scale (cannot be inverted)";
std::string lats_rspecifier = po.GetArg(1),
- posteriors_wspecifier = po.GetArg(2);
+ posteriors_wspecifier = po.GetArg(2),
+ scores_wspecifier = po.GetOptArg(3);
// Read as regular lattice
kaldi::SequentialLatticeReader lattice_reader(lats_rspecifier);