[src,scripts,egs] Attention modeling, with example scripts (#1731)
authorDaniel Povey <dpovey@gmail.com>
Fri, 15 Sep 2017 20:41:19 +0000 (16:41 -0400)
committerGitHub <noreply@github.com>
Fri, 15 Sep 2017 20:41:19 +0000 (16:41 -0400)
23 files changed:
egs/swbd/s5c/local/chain/run_tdnn_attention.sh [new symlink]
egs/swbd/s5c/local/chain/tuning/run_tdnn_attention_1a.sh [new file with mode: 0755]
egs/tedlium/s5_r2/local/chain/run_tdnn_lstm_attention.sh [new symlink]
egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_attention_1a.sh [new file with mode: 0755]
egs/wsj/s5/steps/libs/nnet3/xconfig/attention.py [new file with mode: 0644]
egs/wsj/s5/steps/libs/nnet3/xconfig/layers.py
egs/wsj/s5/steps/libs/nnet3/xconfig/parser.py
src/nnet3/Makefile
src/nnet3/attention-test.cc [new file with mode: 0644]
src/nnet3/attention.cc [new file with mode: 0644]
src/nnet3/attention.h [new file with mode: 0644]
src/nnet3/convolution.cc
src/nnet3/convolution.h
src/nnet3/nnet-attention-component.cc [new file with mode: 0644]
src/nnet3/nnet-attention-component.h [new file with mode: 0644]
src/nnet3/nnet-compile-utils.cc
src/nnet3/nnet-compile-utils.h
src/nnet3/nnet-component-itf.cc
src/nnet3/nnet-compute-test.cc
src/nnet3/nnet-convolutional-component.cc
src/nnet3/nnet-convolutional-component.h
src/nnet3/nnet-test-utils.cc
src/nnet3/nnet-utils.cc

diff --git a/egs/swbd/s5c/local/chain/run_tdnn_attention.sh b/egs/swbd/s5c/local/chain/run_tdnn_attention.sh
new file mode 120000 (symlink)
index 0000000..6bc0859
--- /dev/null
@@ -0,0 +1 @@
+tuning/run_tdnn_attention_1a.sh
\ No newline at end of file
diff --git a/egs/swbd/s5c/local/chain/tuning/run_tdnn_attention_1a.sh b/egs/swbd/s5c/local/chain/tuning/run_tdnn_attention_1a.sh
new file mode 100755 (executable)
index 0000000..3ce4fa6
--- /dev/null
@@ -0,0 +1,269 @@
+#!/bin/bash
+
+# In this recipe everything is the same as tdnn_7k, except the
+# 7th TDNN layer has been replaced with an attention layer
+
+# local/chain/compare_wer_general.sh exp/chain/tdnn_7k_sp exp/chain/tdnn_attend_1a_sp
+# System                tdnn_7k_sp tdnn_attend_1a_sp
+# WER on train_dev(tg)      13.93     13.76
+# WER on train_dev(fg)      12.85     12.62
+# WER on eval2000(tg)        16.7      16.2
+# WER on eval2000(fg)        15.0      14.5
+# Final train prob         -0.085    -0.076
+# Final valid prob         -0.106    -0.098
+# Final train prob (xent)        -1.260    -0.997
+# Final valid prob (xent)       -1.3193   -1.0887
+
+# steps/info/chain_dir_info.pl exp/chain/tdnn_attend_1a_sp
+# exp/chain/tdnn_attend_1a_sp/: num-iters=262 nj=3..16 num-params=16.8M dim=40+100->6076 combine=-0.095->-0.095 xent:train/valid[173,261,final]=(-1.06,-0.993,-0.997/-1.14,-1.09,-1.09) logprob:train/valid[173,261,final]=(-0.084,-0.076,-0.076/-0.104,-0.099,-0.098)
+
+# steps/info/chain_dir_info.pl exp/chain/tdnn_7k_sp
+# exp/chain/tdnn_7k_sp: num-iters=262 nj=3..16 num-params=15.6M dim=40+100->6076 combine=-0.106->-0.106 xent:train/valid[173,261,final]=(-1.32,-1.25,-1.26/-1.36,-1.31,-1.32) logprob:train/valid[173,261,final]=(-0.093,-0.085,-0.085/-0.110,-0.106,-0.106)
+
+set -e
+
+# configs for 'chain'
+affix=1a
+stage=12
+train_stage=-10
+get_egs_stage=-10
+speed_perturb=true
+dir=exp/chain/tdnn_attend  # Note: _sp will get added to this if $speed_perturb == true.
+decode_iter=
+decode_nj=50
+
+# training options
+num_epochs=4
+initial_effective_lrate=0.001
+final_effective_lrate=0.0001
+leftmost_questions_truncate=-1
+max_param_change=2.0
+final_layer_normalize_target=0.5
+num_jobs_initial=3
+num_jobs_final=16
+minibatch_size=128
+frames_per_eg=150
+remove_egs=false
+common_egs_dir=
+xent_regularize=0.1
+
+test_online_decoding=false  # if true, it will run the last decoding stage.
+
+# End configuration section.
+echo "$0 $@"  # Print the command line for logging
+
+. ./cmd.sh
+. ./path.sh
+. ./utils/parse_options.sh
+
+if ! cuda-compiled; then
+  cat <<EOF && exit 1
+This script is intended to be used with GPUs but you have not compiled Kaldi with CUDA
+If you want to use GPUs (and have them), go to src/, and configure and make on a machine
+where "nvcc" is installed.
+EOF
+fi
+
+# The iVector-extraction and feature-dumping parts are the same as the standard
+# nnet3 setup, and you can skip them by setting "--stage 8" if you have already
+# run those things.
+
+suffix=
+if [ "$speed_perturb" == "true" ]; then
+  suffix=_sp
+fi
+
+dir=${dir}${affix:+_$affix}$suffix
+train_set=train_nodup$suffix
+ali_dir=exp/tri4_ali_nodup$suffix
+treedir=exp/chain/tri5_7d_tree$suffix
+lang=data/lang_chain_2y
+
+
+# if we are using the speed-perturbed data we need to generate
+# alignments for it.
+local/nnet3/run_ivector_common.sh --stage $stage \
+  --speed-perturb $speed_perturb \
+  --generate-alignments $speed_perturb || exit 1;
+
+
+if [ $stage -le 9 ]; then
+  # Get the alignments as lattices (gives the LF-MMI training more freedom).
+  # use the same num-jobs as the alignments
+  nj=$(cat exp/tri4_ali_nodup$suffix/num_jobs) || exit 1;
+  steps/align_fmllr_lats.sh --nj $nj --cmd "$train_cmd" data/$train_set \
+    data/lang exp/tri4 exp/tri4_lats_nodup$suffix
+  rm exp/tri4_lats_nodup$suffix/fsts.*.gz # save space
+fi
+
+
+if [ $stage -le 10 ]; then
+  # Create a version of the lang/ directory that has one state per phone in the
+  # topo file. [note, it really has two states.. the first one is only repeated
+  # once, the second one has zero or more repeats.]
+  rm -rf $lang
+  cp -r data/lang $lang
+  silphonelist=$(cat $lang/phones/silence.csl) || exit 1;
+  nonsilphonelist=$(cat $lang/phones/nonsilence.csl) || exit 1;
+  # Use our special topology... note that later on may have to tune this
+  # topology.
+  steps/nnet3/chain/gen_topo.py $nonsilphonelist $silphonelist >$lang/topo
+fi
+
+if [ $stage -le 11 ]; then
+  # Build a tree using our new topology. This is the critically different
+  # step compared with other recipes.
+  steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \
+      --leftmost-questions-truncate $leftmost_questions_truncate \
+      --context-opts "--context-width=2 --central-position=1" \
+      --cmd "$train_cmd" 7000 data/$train_set $lang $ali_dir $treedir
+fi
+
+if [ $stage -le 12 ]; then
+  echo "$0: creating neural net configs using the xconfig parser";
+  num_targets=$(tree-info $treedir/tree |grep num-pdfs|awk '{print $2}')
+  learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python)
+
+  mkdir -p $dir/configs
+  cat <<EOF > $dir/configs/network.xconfig
+  input dim=100 name=ivector
+  input dim=40 name=input
+
+  # please note that it is important to have input layer with the name=input
+  # as the layer immediately preceding the fixed-affine-layer to enable
+  # the use of short notation for the descriptor
+  fixed-affine-layer name=lda input=Append(-1,0,1,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat
+
+  # the first splicing is moved before the lda layer, so no splicing here
+  relu-batchnorm-layer name=tdnn1 dim=625
+  relu-batchnorm-layer name=tdnn2 input=Append(-1,0,1) dim=625
+  relu-batchnorm-layer name=tdnn3 input=Append(-1,0,1) dim=625
+  relu-batchnorm-layer name=tdnn4 input=Append(-3,0,3) dim=625
+  relu-batchnorm-layer name=tdnn5 input=Append(-3,0,3) dim=625
+  relu-batchnorm-layer name=tdnn6 input=Append(-3,0,3) dim=625
+  attention-relu-renorm-layer name=attention1 num-heads=15 value-dim=80 key-dim=40 num-left-inputs=5 num-right-inputs=2 time-stride=3
+
+  ## adding the layers for chain branch
+  relu-batchnorm-layer name=prefinal-chain input=attention1 dim=625 target-rms=0.5
+  output-layer name=output include-log-softmax=false dim=$num_targets max-change=1.5
+
+  # adding the layers for xent branch
+  # This block prints the configs for a separate output that will be
+  # trained with a cross-entropy objective in the 'chain' models... this
+  # has the effect of regularizing the hidden parts of the model.  we use
+  # 0.5 / args.xent_regularize as the learning rate factor- the factor of
+  # 0.5 / args.xent_regularize is suitable as it means the xent
+  # final-layer learns at a rate independent of the regularization
+  # constant; and the 0.5 was tuned so as to make the relative progress
+  # similar in the xent and regular final layers.
+  relu-batchnorm-layer name=prefinal-xent input=attention1 dim=625 target-rms=0.5
+  output-layer name=output-xent dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5
+
+EOF
+  steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/
+fi
+
+if [ $stage -le 13 ]; then
+  if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then
+    utils/create_split_dir.pl \
+     /export/b0{5,6,7,8}/$USER/kaldi-data/egs/swbd-$(date +'%m_%d_%H_%M')/s5c/$dir/egs/storage $dir/egs/storage
+  fi
+
+  steps/nnet3/chain/train.py --stage $train_stage \
+    --cmd "$decode_cmd" \
+    --feat.online-ivector-dir exp/nnet3/ivectors_${train_set} \
+    --feat.cmvn-opts "--norm-means=false --norm-vars=false" \
+    --chain.xent-regularize $xent_regularize \
+    --chain.leaky-hmm-coefficient 0.1 \
+    --chain.l2-regularize 0.00005 \
+    --chain.apply-deriv-weights false \
+    --chain.lm-opts="--num-extra-lm-states=2000" \
+    --egs.dir "$common_egs_dir" \
+    --egs.stage $get_egs_stage \
+    --egs.opts "--frames-overlap-per-eg 0" \
+    --egs.chunk-width $frames_per_eg \
+    --trainer.num-chunk-per-minibatch $minibatch_size \
+    --trainer.frames-per-iter 1500000 \
+    --trainer.num-epochs $num_epochs \
+    --trainer.optimization.num-jobs-initial $num_jobs_initial \
+    --trainer.optimization.num-jobs-final $num_jobs_final \
+    --trainer.optimization.initial-effective-lrate $initial_effective_lrate \
+    --trainer.optimization.final-effective-lrate $final_effective_lrate \
+    --trainer.max-param-change $max_param_change \
+    --cleanup.remove-egs $remove_egs \
+    --feat-dir data/${train_set}_hires \
+    --tree-dir $treedir \
+    --lat-dir exp/tri4_lats_nodup$suffix \
+    --dir $dir  || exit 1;
+
+fi
+
+if [ $stage -le 14 ]; then
+  # Note: it might appear that this $lang directory is mismatched, and it is as
+  # far as the 'topo' is concerned, but this script doesn't read the 'topo' from
+  # the lang directory.
+  utils/mkgraph.sh --self-loop-scale 1.0 data/lang_sw1_tg $dir $dir/graph_sw1_tg
+fi
+
+
+graph_dir=$dir/graph_sw1_tg
+iter_opts=
+if [ ! -z $decode_iter ]; then
+  iter_opts=" --iter $decode_iter "
+fi
+if [ $stage -le 15 ]; then
+  rm $dir/.error 2>/dev/null || true
+  for decode_set in train_dev eval2000; do
+      (
+      steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \
+          --nj $decode_nj --cmd "$decode_cmd" $iter_opts \
+          --online-ivector-dir exp/nnet3/ivectors_${decode_set} \
+          $graph_dir data/${decode_set}_hires \
+          $dir/decode_${decode_set}${decode_iter:+_$decode_iter}_sw1_tg || exit 1;
+      if $has_fisher; then
+          steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \
+            data/lang_sw1_{tg,fsh_fg} data/${decode_set}_hires \
+            $dir/decode_${decode_set}${decode_iter:+_$decode_iter}_sw1_{tg,fsh_fg} || exit 1;
+      fi
+      ) || touch $dir/.error &
+  done
+  wait
+  if [ -f $dir/.error ]; then
+    echo "$0: something went wrong in decoding"
+    exit 1
+  fi
+fi
+
+if $test_online_decoding && [ $stage -le 16 ]; then
+  # note: if the features change (e.g. you add pitch features), you will have to
+  # change the options of the following command line.
+  steps/online/nnet3/prepare_online_decoding.sh \
+       --mfcc-config conf/mfcc_hires.conf \
+       $lang exp/nnet3/extractor $dir ${dir}_online
+
+  rm $dir/.error 2>/dev/null || true
+  for decode_set in train_dev eval2000; do
+    (
+      # note: we just give it "$decode_set" as it only uses the wav.scp, the
+      # feature type does not matter.
+
+      steps/online/nnet3/decode.sh --nj $decode_nj --cmd "$decode_cmd" \
+          --acwt 1.0 --post-decode-acwt 10.0 \
+         $graph_dir data/${decode_set}_hires \
+         ${dir}_online/decode_${decode_set}${decode_iter:+_$decode_iter}_sw1_tg || exit 1;
+      if $has_fisher; then
+          steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" \
+            data/lang_sw1_{tg,fsh_fg} data/${decode_set}_hires \
+            ${dir}_online/decode_${decode_set}${decode_iter:+_$decode_iter}_sw1_{tg,fsh_fg} || exit 1;
+      fi
+    ) || touch $dir/.error &
+  done
+  wait
+  if [ -f $dir/.error ]; then
+    echo "$0: something went wrong in decoding"
+    exit 1
+  fi
+fi
+
+
+exit 0;
diff --git a/egs/tedlium/s5_r2/local/chain/run_tdnn_lstm_attention.sh b/egs/tedlium/s5_r2/local/chain/run_tdnn_lstm_attention.sh
new file mode 120000 (symlink)
index 0000000..6af9f10
--- /dev/null
@@ -0,0 +1 @@
+tuning/run_tdnn_lstm_attention_1a.sh
\ No newline at end of file
diff --git a/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_attention_1a.sh b/egs/tedlium/s5_r2/local/chain/tuning/run_tdnn_lstm_attention_1a.sh
new file mode 100755 (executable)
index 0000000..31470e9
--- /dev/null
@@ -0,0 +1,341 @@
+#!/bin/bash
+
+# In this recipe we replace the last LSTM layer with an attention layer
+# which leads to some consistent improvements in WER
+
+# local/chain/compare_wer_general.sh --looped exp/chain_cleaned/tdnn_lstm1e_sp_bi exp/chain_cleaned/tdnn_lstm_attend1a_sp_bi
+# System                tdnn_lstm1e_sp_bi tdnn_lstm_attend1a_sp_bi
+# WER on dev(orig)            8.9       8.4
+#         [looped:]           8.9       8.5
+# WER on dev(rescored)        8.3       8.0
+#         [looped:]           8.3       8.1
+# WER on test(orig)           9.0       8.8
+#         [looped:]           8.9       8.8
+# WER on test(rescored)       8.5       8.2
+#         [looped:]           8.5       8.3
+# Final train prob        -0.0702   -0.0638
+# Final valid prob        -0.0920   -0.0897
+# Final train prob (xent)   -0.8499   -0.8189
+# Final valid prob (xent)   -0.9621   -0.9234
+
+
+# This setup has 3.5M more parameters than the baseline (see below) but most of
+# these extra parameters are due to the xent branch which is removed at
+# test time. There is a 20% decoding speed-up compared to TDNN-LSTM baseline.
+
+# steps/info/chain_dir_info.pl exp/chain_cleaned/tdnn_lstm_attend1a_sp_bi
+# exp/chain_cleaned/tdnn_lstm_attend1a_sp_bi: num-iters=253 nj=2..12 num-params=13.0M dim=40+100->3604 combine=-0.075->-0.074 xent:train/valid[167,252,final]=(-0.937,-0.827,-0.819/-0.996,-0.932,-0.923) logprob:train/valid[167,252,final]=(-0.078,-0.066,-0.064/-0.093,-0.091,-0.090)
+
+
+# steps/info/chain_dir_info.pl exp/chain_cleaned/tdnn_lstm1e_sp_bi
+# exp/chain_cleaned/tdnn_lstm1e_sp_bi/: num-iters=253 nj=2..12 num-params=9.5M dim=40+100->3604 combine=-0.084->-0.082 xent:train/valid[167,252,final]=(-0.944,-0.852,-0.850/-1.03,-0.971,-0.962) logprob:train/valid[167,252,final]=(-0.082,-0.071,-0.070/-0.098,-0.094,-0.092)
+
+set -e -o pipefail
+
+# First the options that are passed through to run_ivector_common.sh
+# (some of which are also used in this script directly).
+stage=17
+nj=30
+decode_nj=30
+min_seg_len=1.55
+label_delay=5
+xent_regularize=0.1
+train_set=train_cleaned
+gmm=tri3_cleaned  # the gmm for the target data
+num_threads_ubm=32
+nnet3_affix=_cleaned  # cleanup affix for nnet3 and chain dirs, e.g. _cleaned
+# training options
+chunk_left_context=40
+chunk_right_context=0
+chunk_left_context_initial=0
+chunk_right_context_final=0
+frames_per_chunk=140,100,160
+# decode options
+frames_per_chunk_primary=$(echo $frames_per_chunk | cut -d, -f1)
+extra_left_context=50
+extra_right_context=0
+extra_left_context_initial=0
+extra_right_context_final=0
+
+
+# The rest are configs specific to this script.  Most of the parameters
+# are just hardcoded at this level, in the commands below.
+train_stage=-10
+tree_affix=  # affix for tree directory, e.g. "a" or "b", in case we change the configuration.
+affix=1a     # affix for TDNN-LSTM-Attention directory, e.g. "a" or "b", in case we change the configuration.
+common_egs_dir=    # you can set this to use previously dumped egs.
+remove_egs=true
+
+test_online_decoding=false  # if true, it will run the last decoding stage.
+
+# End configuration section.
+echo "$0 $@"  # Print the command line for logging
+
+. cmd.sh
+. ./path.sh
+. ./utils/parse_options.sh
+
+
+if ! cuda-compiled; then
+  cat <<EOF && exit 1
+This script is intended to be used with GPUs but you have not compiled Kaldi with CUDA
+If you want to use GPUs (and have them), go to src/, and configure and make on a machine
+where "nvcc" is installed.
+EOF
+fi
+
+local/nnet3/run_ivector_common.sh --stage $stage \
+                                  --nj $nj \
+                                  --min-seg-len $min_seg_len \
+                                  --train-set $train_set \
+                                  --gmm $gmm \
+                                  --num-threads-ubm $num_threads_ubm \
+                             --nnet3-affix "$nnet3_affix"
+
+
+gmm_dir=exp/$gmm
+ali_dir=exp/${gmm}_ali_${train_set}_sp_comb
+tree_dir=exp/chain${nnet3_affix}/tree_bi${tree_affix}
+lat_dir=exp/chain${nnet3_affix}/${gmm}_${train_set}_sp_comb_lats
+dir=exp/chain${nnet3_affix}/tdnn_lstm_attend${affix}_sp_bi
+train_data_dir=data/${train_set}_sp_hires_comb
+lores_train_data_dir=data/${train_set}_sp_comb
+train_ivector_dir=exp/nnet3${nnet3_affix}/ivectors_${train_set}_sp_hires_comb
+
+
+for f in $gmm_dir/final.mdl $train_data_dir/feats.scp $train_ivector_dir/ivector_online.scp \
+    $lores_train_data_dir/feats.scp $ali_dir/ali.1.gz $gmm_dir/final.mdl; do
+  [ ! -f $f ] && echo "$0: expected file $f to exist" && exit 1
+done
+
+if [ $stage -le 14 ]; then
+  echo "$0: creating lang directory with one state per phone."
+  # Create a version of the lang/ directory that has one state per phone in the
+  # topo file. [note, it really has two states.. the first one is only repeated
+  # once, the second one has zero or more repeats.]
+  if [ -d data/lang_chain ]; then
+    if [ data/lang_chain/L.fst -nt data/lang/L.fst ]; then
+      echo "$0: data/lang_chain already exists, not overwriting it; continuing"
+    else
+      echo "$0: data/lang_chain already exists and seems to be older than data/lang..."
+      echo " ... not sure what to do.  Exiting."
+      exit 1;
+    fi
+  else
+    cp -r data/lang data/lang_chain
+    silphonelist=$(cat data/lang_chain/phones/silence.csl) || exit 1;
+    nonsilphonelist=$(cat data/lang_chain/phones/nonsilence.csl) || exit 1;
+    # Use our special topology... note that later on may have to tune this
+    # topology.
+    steps/nnet3/chain/gen_topo.py $nonsilphonelist $silphonelist >data/lang_chain/topo
+  fi
+fi
+
+if [ $stage -le 15 ]; then
+  # Get the alignments as lattices (gives the chain training more freedom).
+  # use the same num-jobs as the alignments
+  steps/align_fmllr_lats.sh --nj 100 --cmd "$train_cmd" ${lores_train_data_dir} \
+    data/lang $gmm_dir $lat_dir
+  rm $lat_dir/fsts.*.gz # save space
+fi
+
+if [ $stage -le 16 ]; then
+  # Build a tree using our new topology.  We know we have alignments for the
+  # speed-perturbed data (local/nnet3/run_ivector_common.sh made them), so use
+  # those.
+  if [ -f $tree_dir/final.mdl ]; then
+    echo "$0: $tree_dir/final.mdl already exists, refusing to overwrite it."
+    exit 1;
+  fi
+  steps/nnet3/chain/build_tree.sh --frame-subsampling-factor 3 \
+      --context-opts "--context-width=2 --central-position=1" \
+      --leftmost-questions-truncate -1 \
+      --cmd "$train_cmd" 4000 ${lores_train_data_dir} data/lang_chain $ali_dir $tree_dir
+fi
+
+
+if [ $stage -le 17 ]; then
+  mkdir -p $dir
+  echo "$0: creating neural net configs using the xconfig parser";
+
+  num_targets=$(tree-info $tree_dir/tree |grep num-pdfs|awk '{print $2}')
+  learning_rate_factor=$(echo "print 0.5/$xent_regularize" | python)
+
+  mkdir -p $dir/configs
+  cat <<EOF > $dir/configs/network.xconfig
+  input dim=100 name=ivector
+  input dim=40 name=input
+
+  # please note that it is important to have input layer with the name=input
+  # as the layer immediately preceding the fixed-affine-layer to enable
+  # the use of short notation for the descriptor
+  fixed-affine-layer name=lda input=Append(-2,-1,0,1,2,ReplaceIndex(ivector, t, 0)) affine-transform-file=$dir/configs/lda.mat
+
+  # the first splicing is moved before the lda layer, so no splicing here
+  relu-renorm-layer name=tdnn1 dim=512
+  relu-renorm-layer name=tdnn2 dim=512 input=Append(-1,0,1)
+  fast-lstmp-layer name=lstm1 cell-dim=512 recurrent-projection-dim=128 non-recurrent-projection-dim=128 decay-time=20 delay=-3
+  relu-renorm-layer name=tdnn3 dim=512 input=Append(-3,0,3)
+  relu-renorm-layer name=tdnn4 dim=512 input=Append(-3,0,3)
+  fast-lstmp-layer name=lstm2 cell-dim=512 recurrent-projection-dim=128 non-recurrent-projection-dim=128 decay-time=20 delay=-3
+  relu-renorm-layer name=tdnn5 dim=512 input=Append(-3,0,3)
+  relu-renorm-layer name=tdnn6 dim=512 input=Append(-3,0,3)
+  attention-relu-renorm-layer name=attention1 time-stride=3 num-heads=12 value-dim=60 key-dim=40 num-left-inputs=5 num-right-inputs=2
+
+
+  ## adding the layers for chain branch
+  output-layer name=output input=attention1 output-delay=$label_delay include-log-softmax=false dim=$num_targets max-change=1.5
+
+  # adding the layers for xent branch
+  # This block prints the configs for a separate output that will be
+  # trained with a cross-entropy objective in the 'chain' models... this
+  # has the effect of regularizing the hidden parts of the model.  we use
+  # 0.5 / args.xent_regularize as the learning rate factor- the factor of
+  # 0.5 / args.xent_regularize is suitable as it means the xent
+  # final-layer learns at a rate independent of the regularization
+  # constant; and the 0.5 was tuned so as to make the relative progress
+  # similar in the xent and regular final layers.
+  output-layer name=output-xent input=attention1 output-delay=$label_delay dim=$num_targets learning-rate-factor=$learning_rate_factor max-change=1.5
+
+EOF
+  steps/nnet3/xconfig_to_configs.py --xconfig-file $dir/configs/network.xconfig --config-dir $dir/configs/
+fi
+
+
+if [ $stage -le 18 ]; then
+  if [[ $(hostname -f) == *.clsp.jhu.edu ]] && [ ! -d $dir/egs/storage ]; then
+    utils/create_split_dir.pl \
+     /export/b0{5,6,7,8}/$USER/kaldi-data/egs/tedlium-$(date +'%m_%d_%H_%M')/s5_r2/$dir/egs/storage $dir/egs/storage
+  fi
+
+ steps/nnet3/chain/train.py --stage $train_stage \
+    --cmd "$decode_cmd" \
+    --feat.online-ivector-dir $train_ivector_dir \
+    --feat.cmvn-opts "--norm-means=false --norm-vars=false" \
+    --chain.xent-regularize $xent_regularize \
+    --chain.leaky-hmm-coefficient 0.1 \
+    --chain.l2-regularize 0.00005 \
+    --chain.apply-deriv-weights false \
+    --chain.lm-opts="--num-extra-lm-states=2000" \
+    --egs.dir "$common_egs_dir" \
+    --egs.opts "--frames-overlap-per-eg 0" \
+    --egs.chunk-width "$frames_per_chunk" \
+    --egs.chunk-left-context "$chunk_left_context" \
+    --egs.chunk-right-context "$chunk_right_context" \
+    --egs.chunk-left-context-initial "$chunk_left_context_initial" \
+    --egs.chunk-right-context-final "$chunk_right_context_final" \
+    --trainer.num-chunk-per-minibatch 128,64 \
+    --trainer.frames-per-iter 1500000 \
+    --trainer.max-param-change 2.0 \
+    --trainer.num-epochs 4 \
+    --trainer.deriv-truncate-margin 10 \
+    --trainer.optimization.shrink-value 0.99 \
+    --trainer.optimization.num-jobs-initial 2 \
+    --trainer.optimization.num-jobs-final 12 \
+    --trainer.optimization.initial-effective-lrate 0.001 \
+    --trainer.optimization.final-effective-lrate 0.0001 \
+    --trainer.optimization.momentum 0.0 \
+    --cleanup.remove-egs "$remove_egs" \
+    --feat-dir $train_data_dir \
+    --tree-dir $tree_dir \
+    --lat-dir $lat_dir \
+    --dir $dir
+fi
+
+
+
+if [ $stage -le 19 ]; then
+  # Note: it might appear that this data/lang_chain directory is mismatched, and it is as
+  # far as the 'topo' is concerned, but this script doesn't read the 'topo' from
+  # the lang directory.
+  utils/mkgraph.sh --self-loop-scale 1.0 data/lang $dir $dir/graph
+fi
+
+if [ $stage -le 20 ]; then
+  rm $dir/.error 2>/dev/null || true
+  for dset in dev test; do
+      (
+      steps/nnet3/decode.sh --num-threads 4 --nj $decode_nj --cmd "$decode_cmd" \
+          --acwt 1.0 --post-decode-acwt 10.0 \
+          --extra-left-context $extra_left_context  \
+          --extra-right-context $extra_right_context  \
+          --extra-left-context-initial $extra_left_context_initial \
+          --extra-right-context-final $extra_right_context_final \
+          --frames-per-chunk "$frames_per_chunk_primary" \
+          --online-ivector-dir exp/nnet3${nnet3_affix}/ivectors_${dset}_hires \
+          --scoring-opts "--min-lmwt 5 " \
+         $dir/graph data/${dset}_hires $dir/decode_${dset} || exit 1;
+      steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" data/lang data/lang_rescore \
+        data/${dset}_hires ${dir}/decode_${dset} ${dir}/decode_${dset}_rescore || exit 1
+    ) || touch $dir/.error &
+  done
+  wait
+  if [ -f $dir/.error ]; then
+    echo "$0: something went wrong in decoding"
+    exit 1
+  fi
+fi
+
+
+if [ $stage -le 21 ]; then
+  # 'looped' decoding.  we didn't write a -parallel version of this program yet,
+  # so it will take a bit longer as the --num-threads option is not supported.
+  # we just hardcode the --frames-per-chunk option as it doesn't have to
+  # match any value used in training, and it won't affect the results very much (unlike
+  # regular decoding)... [it will affect them slightly due to differences in the
+  # iVector extraction; probably smaller will be worse as it sees less of the future,
+  # but in a real scenario, long chunks will introduce excessive latency].
+  rm $dir/.error 2>/dev/null || true
+  for dset in dev test; do
+      (
+      steps/nnet3/decode_looped.sh --nj $decode_nj --cmd "$decode_cmd" \
+          --acwt 1.0 --post-decode-acwt 10.0 \
+          --extra-left-context-initial $extra_left_context_initial \
+          --frames-per-chunk 30 \
+          --online-ivector-dir exp/nnet3${nnet3_affix}/ivectors_${dset}_hires \
+          --scoring-opts "--min-lmwt 5 " \
+         $dir/graph data/${dset}_hires $dir/decode_looped_${dset} || exit 1;
+      steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" data/lang data/lang_rescore \
+        data/${dset}_hires ${dir}/decode_looped_${dset} ${dir}/decode_looped_${dset}_rescore || exit 1
+    ) || touch $dir/.error &
+  done
+  wait
+  if [ -f $dir/.error ]; then
+    echo "$0: something went wrong in decoding"
+    exit 1
+  fi
+fi
+
+
+if $test_online_decoding && [ $stage -le 22 ]; then
+  # note: if the features change (e.g. you add pitch features), you will have to
+  # change the options of the following command line.
+  steps/online/nnet3/prepare_online_decoding.sh \
+       --mfcc-config conf/mfcc_hires.conf \
+       data/lang_chain exp/nnet3${nnet3_affix}/extractor ${dir} ${dir}_online
+
+  rm $dir/.error 2>/dev/null || true
+  for dset in dev test; do
+    (
+      # note: we just give it "$dset" as it only uses the wav.scp, the
+      # feature type does not matter.
+
+      steps/online/nnet3/decode.sh --nj $decode_nj --cmd "$decode_cmd" \
+          --extra-left-context-initial $extra_left_context_initial \
+          --acwt 1.0 --post-decode-acwt 10.0 \
+          --scoring-opts "--min-lmwt 5 " \
+         $dir/graph data/${dset} ${dir}_online/decode_${dset} || exit 1;
+      steps/lmrescore_const_arpa.sh --cmd "$decode_cmd" data/lang data/lang_rescore \
+        data/${dset}_hires ${dir}_online/decode_${dset} ${dir}_online/decode_${dset}_rescore || exit 1
+    ) || touch $dir/.error &
+  done
+  wait
+  if [ -f $dir/.error ]; then
+    echo "$0: something went wrong in decoding"
+    exit 1
+  fi
+fi
+
+
+
+exit 0
diff --git a/egs/wsj/s5/steps/libs/nnet3/xconfig/attention.py b/egs/wsj/s5/steps/libs/nnet3/xconfig/attention.py
new file mode 100644 (file)
index 0000000..1df56a7
--- /dev/null
@@ -0,0 +1,241 @@
+# Copyright 2017    Johns Hopkins University (Dan Povey)
+#           2017    Hossein Hadian
+# Apache 2.0.
+
+""" This module has the implementation of attention layers.
+"""
+
+from __future__ import print_function
+import math
+import re
+import sys
+from libs.nnet3.xconfig.basic_layers import XconfigLayerBase
+
+# This class is for parsing lines like
+#  'attention-renorm-layer num-heads=10 value-dim=50 key-dim=50 time-stride=3 num-left-inputs=5 num-right-inputs=2.'
+#
+# Parameters of the class, and their defaults:
+#   input='[-1]'               [Descriptor giving the input of the layer.]
+#   self-repair-scale=1.0e-05  [Affects relu, sigmoid and tanh layers.]
+#   learning-rate-factor=1.0   [This can be used to make the affine component
+#                               train faster or slower].
+#   Documentation for the rest of the parameters (related to the
+#   attention component) can be found in nnet-attention-component.h
+
+class XconfigAttentionLayer(XconfigLayerBase):
+    def __init__(self, first_token, key_to_value, prev_names = None):
+        # Here we just list some likely combinations.. you can just add any
+        # combinations you want to use, to this list.
+        assert first_token in ['attention-renorm-layer',
+                               'attention-relu-renorm-layer',
+                               'relu-renorm-attention-layer']
+        XconfigLayerBase.__init__(self, first_token, key_to_value, prev_names)
+
+    def set_default_configs(self):
+        # note: self.config['input'] is a descriptor, '[-1]' means output
+        # the most recent layer.
+        self.config = { 'input':'[-1]',
+                        'dim': -1,
+                        'max-change' : 0.75,
+                        'self-repair-scale' : 1.0e-05,
+                        'target-rms' : 1.0,
+                        'learning-rate-factor' : 1.0,
+                        'ng-affine-options' : '',
+                        'num-left-inputs-required': -1,
+                        'num-right-inputs-required': -1,
+                        'output-context': True,
+                        'time-stride': 1,
+                        'num-heads': 1,
+                        'key-dim': -1,
+                        'key-scale': 0.0,
+                        'value-dim': -1,
+                        'num-left-inputs': -1,
+                        'num-right-inputs': -1,
+                        'dropout-proportion': 0.5}  # dropout-proportion only
+                                                    # affects layers with
+                                                    # 'dropout' in the name.
+
+    def check_configs(self):
+        if self.config['self-repair-scale'] < 0.0 or self.config['self-repair-scale'] > 1.0:
+            raise RuntimeError("self-repair-scale has invalid value {0}"
+                               .format(self.config['self-repair-scale']))
+        if self.config['target-rms'] < 0.0:
+            raise RuntimeError("target-rms has invalid value {0}"
+                               .format(self.config['target-rms']))
+        if self.config['learning-rate-factor'] <= 0.0:
+            raise RuntimeError("learning-rate-factor has invalid value {0}"
+                               .format(self.config['learning-rate-factor']))
+        for conf in ['value-dim', 'key-dim',
+                     'num-left-inputs', 'num-right-inputs']:
+            if self.config[conf] < 0:
+                raise RuntimeError("{0} has invalid value {1}"
+                                   .format(conf, self.config[conf]))
+        if self.config['key-scale'] == 0.0:
+            self.config['key-scale'] = 1.0 / math.sqrt(self.config['key-dim'])
+
+    def output_name(self, auxiliary_output=None):
+        # at a later stage we might want to expose even the pre-nonlinearity
+        # vectors
+        assert auxiliary_output == None
+
+        split_layer_name = self.layer_type.split('-')
+        assert split_layer_name[-1] == 'layer'
+        last_nonlinearity = split_layer_name[-2]
+        # return something like: layer3.renorm
+        return '{0}.{1}'.format(self.name, last_nonlinearity)
+
+    def attention_input_dim(self):
+        context_dim = (self.config['num-left-inputs'] +
+                       self.config['num-right-inputs'] + 1)
+        num_heads = self.config['num-heads']
+        key_dim = self.config['key-dim']
+        value_dim = self.config['value-dim']
+        query_dim = key_dim + context_dim;
+        return num_heads * (key_dim + value_dim + query_dim)
+
+    def attention_output_dim(self):
+        context_dim = (self.config['num-left-inputs'] +
+                       self.config['num-right-inputs'] + 1)
+        num_heads = self.config['num-heads']
+        value_dim = self.config['value-dim']
+        return (num_heads *
+                (value_dim +
+                 (context_dim if self.config['output-context'] else 0)))
+
+    def output_dim(self, auxiliary_output = None):
+      return self.attention_output_dim()
+
+    def get_full_config(self):
+        ans = []
+        config_lines = self._generate_config()
+
+        for line in config_lines:
+            for config_name in ['ref', 'final']:
+                # we do not support user specified matrices in this layer
+                # so 'ref' and 'final' configs are the same.
+                ans.append((config_name, line))
+        return ans
+
+
+    def _generate_config(self):
+        split_layer_name = self.layer_type.split('-')
+        assert split_layer_name[-1] == 'layer'
+        nonlinearities = split_layer_name[:-1]
+
+        # by 'descriptor_final_string' we mean a string that can appear in
+        # config-files, i.e. it contains the 'final' names of nodes.
+        input_desc = self.descriptors['input']['final-string']
+        input_dim = self.descriptors['input']['dim']
+
+        # the child classes e.g. tdnn might want to process the input
+        # before adding the other components
+
+        return self._add_components(input_desc, input_dim, nonlinearities)
+
+    def _add_components(self, input_desc, input_dim, nonlinearities):
+        dim = self.attention_input_dim()
+        self_repair_scale = self.config['self-repair-scale']
+        target_rms = self.config['target-rms']
+        max_change = self.config['max-change']
+        ng_affine_options = self.config['ng-affine-options']
+        learning_rate_factor=self.config['learning-rate-factor']
+        learning_rate_option=('learning-rate-factor={0}'.format(learning_rate_factor)
+                              if learning_rate_factor != 1.0 else '')
+
+        configs = []
+        # First the affine node.
+        line = ('component name={0}.affine'
+                ' type=NaturalGradientAffineComponent'
+                ' input-dim={1}'
+                ' output-dim={2}'
+                ' max-change={3}'
+                ' {4} {5} '
+                ''.format(self.name, input_dim, dim,
+                          max_change, ng_affine_options,
+                          learning_rate_option))
+        configs.append(line)
+
+        line = ('component-node name={0}.affine'
+                ' component={0}.affine input={1}'
+                ''.format(self.name, input_desc))
+        configs.append(line)
+        cur_node = '{0}.affine'.format(self.name)
+
+        for nonlinearity in nonlinearities:
+            if nonlinearity == 'relu':
+                line = ('component name={0}.{1}'
+                        ' type=RectifiedLinearComponent dim={2}'
+                        ' self-repair-scale={3}'
+                        ''.format(self.name, nonlinearity, dim,
+                            self_repair_scale))
+
+            elif nonlinearity == 'attention':
+                line = ('component name={0}.{1}'
+                        ' type=RestrictedAttentionComponent'
+                        ' value-dim={2}'
+                        ' key-dim={3}'
+                        ' num-left-inputs={4}'
+                        ' num-right-inputs={5}'
+                        ' num-left-inputs-required={6}'
+                        ' num-right-inputs-required={7}'
+                        ' output-context={8}'
+                        ' time-stride={9}'
+                        ' num-heads={10}'
+                        ''.format(self.name, nonlinearity,
+                                  self.config['value-dim'],
+                                  self.config['key-dim'],
+                                  self.config['num-left-inputs'],
+                                  self.config['num-right-inputs'],
+                                  self.config['num-left-inputs-required'],
+                                  self.config['num-right-inputs-required'],
+                                  self.config['output-context'],
+                                  self.config['time-stride'],
+                                  self.config['num-heads']))
+                dim = self.attention_output_dim()
+
+            elif nonlinearity == 'sigmoid':
+                line = ('component name={0}.{1}'
+                        ' type=SigmoidComponent dim={2}'
+                        ' self-repair-scale={3}'
+                        ''.format(self.name, nonlinearity, dim,
+                            self_repair_scale))
+
+            elif nonlinearity == 'tanh':
+                line = ('component name={0}.{1}'
+                        ' type=TanhComponent dim={2}'
+                        ' self-repair-scale={3}'
+                        ''.format(self.name, nonlinearity, dim,
+                            self_repair_scale))
+
+            elif nonlinearity == 'renorm':
+                line = ('component name={0}.{1}'
+                        ' type=NormalizeComponent dim={2}'
+                        ' target-rms={3}'
+                        ''.format(self.name, nonlinearity, dim,
+                            target_rms))
+
+            elif nonlinearity == 'batchnorm':
+                line = ('component name={0}.{1}'
+                        ' type=BatchNormComponent dim={2}'
+                        ' target-rms={3}'
+                        ''.format(self.name, nonlinearity, dim,
+                            target_rms))
+
+            elif nonlinearity == 'dropout':
+                line = ('component name={0}.{1} type=DropoutComponent '
+                           'dim={2} dropout-proportion={3}'.format(
+                               self.name, nonlinearity, dim,
+                               self.config['dropout-proportion']))
+
+            else:
+                raise RuntimeError("Unknown nonlinearity type: {0}"
+                                   .format(nonlinearity))
+
+            configs.append(line)
+            line = ('component-node name={0}.{1}'
+                    ' component={0}.{1} input={2}'
+                    ''.format(self.name, nonlinearity, cur_node))
+
+            configs.append(line)
+            cur_node = '{0}.{1}'.format(self.name, nonlinearity)
+        return configs
index ae57b219997859a26c1c6947ee5b6f1708da0129..97d20c4f7079960536150526dcc754ed4d724e2e 100644 (file)
@@ -5,5 +5,6 @@
 
 from basic_layers import *
 from convolution import *
+from attention import *
 from lstm import *
 from stats_layer import *
index 44ceff9eea47738702a97d0e4678d8a408994b0a..5cd9f2beef1abd65bafb311d777d80f0047ec550 100644 (file)
@@ -45,7 +45,10 @@ config_to_layer = {
         'conv-relu-batchnorm-dropout-layer': xlayers.XconfigConvLayer,
         'conv-relu-dropout-layer': xlayers.XconfigConvLayer,
         'res-block': xlayers.XconfigResBlock,
-        'channel-average-layer': xlayers.ChannelAverageLayer
+        'channel-average-layer': xlayers.ChannelAverageLayer,
+        'attention-renorm-layer': xlayers.XconfigAttentionLayer,
+        'attention-relu-renorm-layer': xlayers.XconfigAttentionLayer,
+        'relu-renorm-attention-layer': xlayers.XconfigAttentionLayer
 }
 
 # Turn a config line and a list of previous layers into
index 3fc16cde1f4f0be7a5614f831856e619ee34f93d..3236c52d60ffbab5c9a7731592a6697091aa047a 100644 (file)
@@ -12,7 +12,7 @@ TESTFILES = natural-gradient-online-test nnet-graph-test \
   nnet-compile-utils-test nnet-nnet-test nnet-utils-test \
   nnet-compile-test nnet-analyze-test nnet-compute-test \
   nnet-optimize-test nnet-derivative-test nnet-example-test \
-  nnet-common-test convolution-test
+  nnet-common-test convolution-test attention-test
 
 OBJFILES = nnet-common.o nnet-compile.o nnet-component-itf.o \
   nnet-simple-component.o \
@@ -30,7 +30,8 @@ OBJFILES = nnet-common.o nnet-compile.o nnet-component-itf.o \
   discriminative-training.o nnet-discriminative-training.o \
   nnet-compile-looped.o decodable-simple-looped.o \
   decodable-online-looped.o convolution.o \
-  nnet-convolutional-component.o
+  nnet-convolutional-component.o attention.o \
+  nnet-attention-component.o
 
 
 LIBNAME = kaldi-nnet3
diff --git a/src/nnet3/attention-test.cc b/src/nnet3/attention-test.cc
new file mode 100644 (file)
index 0000000..c07971d
--- /dev/null
@@ -0,0 +1,256 @@
+// nnet3/attention-test.cc
+
+// Copyright      2017  Hossein Hadian
+//                2017  Johns Hopkins University (author: Daniel Povey)
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#include "nnet3/attention.h"
+#include "util/common-utils.h"
+
+namespace kaldi {
+namespace nnet3 {
+namespace attention {
+
+
+// (*C)(i, j) = alpha * VecVec(A.Row(i), B.Row(i + j * row_shift))
+void GetAttentionDotProductsSimple(BaseFloat alpha,
+                                   const CuMatrixBase<BaseFloat> &A,
+                                   const CuMatrixBase<BaseFloat> &B,
+                                   CuMatrixBase<BaseFloat> *C) {
+  KALDI_ASSERT(A.NumCols() == B.NumCols() &&
+               A.NumRows() == C->NumRows());
+  int32 input_num_cols = A.NumCols(),
+      num_extra_rows = B.NumRows() - A.NumRows(),
+      context_dim = C->NumCols();
+  KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
+  int32 row_shift = num_extra_rows / (context_dim - 1);
+  for (int32 i = 0; i < C->NumRows(); i++) {
+    for (int32 j = 0; j < C->NumCols(); j++) {
+      (*C)(i, j) = 0.0;
+      for (int32 k = 0; k < input_num_cols; k++) {
+        (*C)(i, j) += alpha * A(i, k) * B(i + (j * row_shift), k);
+      }
+    }
+  }
+}
+
+//     A->Row(i) += \sum_k alpha * C(i, k) * B.Row(i + k * row_shift).
+void ApplyScalesToOutputSimple(BaseFloat alpha,
+                               const CuMatrixBase<BaseFloat> &B,
+                               const CuMatrixBase<BaseFloat> &C,
+                               CuMatrixBase<BaseFloat> *A) {
+  KALDI_ASSERT(A->NumCols() == B.NumCols() &&
+               A->NumRows() == C.NumRows());
+  int32 num_extra_rows = B.NumRows() - A->NumRows(),
+      context_dim = C.NumCols();
+  KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
+  int32 row_shift = num_extra_rows / (context_dim - 1);
+  for (int32 i = 0; i < A->NumRows(); i++) {
+    for (int32 j = 0; j < A->NumCols(); j++) {
+      for (int32 k = 0; k < context_dim; k++) {
+        (*A)(i, j) += alpha * C(i, k) * B(i + (k * row_shift), j);
+      }
+    }
+  }
+}
+
+//     B->Row(i + j * row_shift) += alpha * C(i, j) * A.Row(i).
+void ApplyScalesToInputSimple(BaseFloat alpha,
+                              const CuMatrixBase<BaseFloat> &A,
+                              const CuMatrixBase<BaseFloat> &C,
+                              CuMatrixBase<BaseFloat> *B) {
+  KALDI_ASSERT(A.NumCols() == B->NumCols() &&
+               A.NumRows() == C.NumRows());
+  int32 num_extra_rows = B->NumRows() - A.NumRows(),
+      context_dim = C.NumCols();
+  KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
+  int32 row_shift = num_extra_rows / (context_dim - 1);
+  for (int32 i = 0; i < A.NumRows(); i++) {
+    for (int32 j = 0; j < A.NumCols(); j++) {
+      for (int32 k = 0; k < context_dim; k++) {
+        (*B)(i + (k * row_shift), j) += alpha * C(i, k) * A(i, j);
+      }
+    }
+  }
+}
+
+void UnitTestAttentionDotProductAndAddScales() {
+  int32 output_num_rows = RandInt(1, 50), input_num_cols = RandInt(1, 10),
+      row_shift = RandInt(1, 5), context_dim = RandInt(2, 5),
+      num_extra_rows = (context_dim - 1) * row_shift,
+      input_num_rows = output_num_rows + num_extra_rows;
+  BaseFloat alpha = 0.25 * RandInt(1, 5);
+  CuMatrix<BaseFloat> A(output_num_rows, input_num_cols),
+      B(input_num_rows, input_num_cols),
+      C(output_num_rows, context_dim);
+
+  B.SetRandn();
+  C.SetRandn();
+  A.Set(0.0);
+  CuMatrix<BaseFloat> A2(A);
+  ApplyScalesToOutput(alpha, B, C, &A);
+  ApplyScalesToOutputSimple(alpha, B, C, &A2);
+  AssertEqual(A, A2);
+
+  CuMatrix<BaseFloat> C2(C);
+  GetAttentionDotProductsSimple(alpha, A, B, &C);
+  GetAttentionDotProducts(alpha, A, B, &C2);
+  AssertEqual(C, C2);
+
+  CuMatrix<BaseFloat> B2(B);
+  ApplyScalesToInput(alpha, A, C, &B);
+  ApplyScalesToInputSimple(alpha, A, C, &B2);
+  AssertEqual(B, B2);
+}
+
+void TestAttentionForwardBackward() {
+  BaseFloat key_scale = 0.5 * RandInt(1, 3);
+  BaseFloat epsilon = 1.0e-03;
+  int32 test_dim = 3;
+  bool output_context = (RandInt(0, 1) == 0);
+  int32 output_num_rows = RandInt(1, 50),
+      value_dim = RandInt(10, 30), key_dim = RandInt(10, 30),
+      row_shift = RandInt(1, 5), context_dim = RandInt(2, 5),
+      num_extra_rows = (context_dim - 1) * row_shift,
+      input_num_rows = output_num_rows + num_extra_rows,
+      query_dim = key_dim + context_dim;
+  CuMatrix<BaseFloat> keys(input_num_rows, key_dim),
+      queries(output_num_rows, query_dim),
+      values(input_num_rows, value_dim),
+      C(output_num_rows, context_dim),
+      output(output_num_rows, value_dim + (output_context ? context_dim : 0));
+
+
+  keys.SetRandn();
+  queries.SetRandn();
+  values.SetRandn();
+
+
+  AttentionForward(key_scale, keys, queries, values, &C, &output);
+
+  CuMatrix<BaseFloat> keys_deriv(input_num_rows, key_dim),
+      queries_deriv(output_num_rows, query_dim),
+      values_deriv(input_num_rows, value_dim),
+      output_deriv(output_num_rows, output.NumCols());
+
+  output_deriv.SetRandn();
+
+  AttentionBackward(key_scale, keys, queries, values, C,
+                    output_deriv, &keys_deriv, &queries_deriv,
+                    &values_deriv);
+
+  BaseFloat objf_baseline = TraceMatMat(output_deriv, output, kTrans);
+
+
+
+
+  {  // perturb the values and see if the objf changes as predicted.
+    Vector<BaseFloat> predicted_vec(test_dim), observed_vec(test_dim);
+    for (int32 i = 0; i < test_dim; i++) {
+      CuMatrix<BaseFloat> values2(input_num_rows, value_dim);
+      values2.SetRandn();
+      values2.Scale(epsilon);
+      BaseFloat predicted_delta_objf = TraceMatMat(values_deriv, values2, kTrans);
+      values2.AddMat(1.0, values);
+
+      output.SetZero();
+      AttentionForward(key_scale, keys, queries, values2, &C, &output);
+      BaseFloat objf2 = TraceMatMat(output_deriv, output, kTrans),
+          observed_delta_objf = objf2 - objf_baseline;
+      KALDI_LOG << "Changing values: predicted objf change is "
+                << predicted_delta_objf << ", observed objf change is "
+                << observed_delta_objf;
+      predicted_vec(i) = predicted_delta_objf;
+      observed_vec(i) = observed_delta_objf;
+    }
+    KALDI_ASSERT(predicted_vec.ApproxEqual(observed_vec, 0.1));
+  }
+
+  {  // perturb the keys and see if the objf changes as predicted.
+    Vector<BaseFloat> predicted_vec(test_dim), observed_vec(test_dim);
+    for (int32 i = 0; i < test_dim; i++) {
+      CuMatrix<BaseFloat> keys2(input_num_rows, key_dim);
+      keys2.SetRandn();
+      keys2.Scale(epsilon);
+      BaseFloat predicted_delta_objf = TraceMatMat(keys_deriv, keys2, kTrans);
+      keys2.AddMat(1.0, keys);
+
+      output.SetZero();
+      AttentionForward(key_scale, keys2, queries, values, &C, &output);
+      BaseFloat objf2 = TraceMatMat(output_deriv, output, kTrans),
+          observed_delta_objf = objf2 - objf_baseline;
+      KALDI_LOG << "Changing keys: predicted objf change is "
+                << predicted_delta_objf << ", observed objf change is "
+                << observed_delta_objf;
+      predicted_vec(i) = predicted_delta_objf;
+      observed_vec(i) = observed_delta_objf;
+    }
+    KALDI_ASSERT(predicted_vec.ApproxEqual(observed_vec, 0.1));
+  }
+
+
+  {  // perturb the queries and see if the objf changes as predicted.
+    Vector<BaseFloat> predicted_vec(test_dim), observed_vec(test_dim);
+    for (int32 i = 0; i < test_dim; i++) {
+      CuMatrix<BaseFloat> queries2(output_num_rows, query_dim);
+      queries2.SetRandn();
+      queries2.Scale(epsilon);
+      BaseFloat predicted_delta_objf = TraceMatMat(queries_deriv, queries2, kTrans);
+      queries2.AddMat(1.0, queries);
+
+      output.SetZero();
+      AttentionForward(key_scale, keys, queries2, values, &C, &output);
+      BaseFloat objf2 = TraceMatMat(output_deriv, output, kTrans),
+          observed_delta_objf = objf2 - objf_baseline;
+      KALDI_LOG << "Changing queries: predicted objf change is "
+                << predicted_delta_objf << ", observed objf change is "
+                << observed_delta_objf;
+      predicted_vec(i) = predicted_delta_objf;
+      observed_vec(i) = observed_delta_objf;
+    }
+    KALDI_ASSERT(predicted_vec.ApproxEqual(observed_vec, 0.1));
+  }
+}
+
+void UnitTestAttention() {
+  UnitTestAttentionDotProductAndAddScales();
+  TestAttentionForwardBackward();
+}
+
+
+} // namespace attention
+} // namespace nnet3
+} // namespace kaldi
+
+
+int main() {
+  using namespace kaldi;
+  using namespace kaldi::nnet3;
+  using namespace kaldi::nnet3::attention;
+  for (int32 loop = 0; loop < 2; loop++) {
+#if HAVE_CUDA == 1
+    CuDevice::Instantiate().SetDebugStrideMode(true);
+    if (loop == 0)
+      CuDevice::Instantiate().SelectGpuId("no"); // -1 means no GPU
+    else
+      CuDevice::Instantiate().SelectGpuId("optional"); // -2 .. automatic selection
+#endif
+    for (int32 i = 0; i < 5; i++) {
+      UnitTestAttention();
+    }
+  }
+}
diff --git a/src/nnet3/attention.cc b/src/nnet3/attention.cc
new file mode 100644 (file)
index 0000000..bd8cb6b
--- /dev/null
@@ -0,0 +1,247 @@
+// nnet3/attention.cc
+
+// Copyright      2017  Johns Hopkins University (author: Daniel Povey)
+//                      Hossein Hadian
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#include <iterator>
+#include <sstream>
+#include <iomanip>
+#include "nnet3/attention.h"
+#include "nnet3/nnet-parse.h"
+
+namespace kaldi {
+namespace nnet3 {
+namespace attention {
+
+
+void GetAttentionDotProducts(BaseFloat alpha,
+                             const CuMatrixBase<BaseFloat> &A,
+                             const CuMatrixBase<BaseFloat> &B,
+                             CuMatrixBase<BaseFloat> *C) {
+  KALDI_ASSERT(A.NumCols() == B.NumCols() &&
+               A.NumRows() == C->NumRows());
+  int32 num_output_rows = A.NumRows(),
+      input_num_cols = A.NumCols(),
+      num_extra_rows = B.NumRows() - A.NumRows(),
+      context_dim = C->NumCols();
+  KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
+  int32 row_shift = num_extra_rows / (context_dim - 1);
+  CuMatrix<BaseFloat> Ctrans(C->NumCols(),
+                             C->NumRows());
+  for (int32 o = 0; o < context_dim; o++) {
+    CuSubVector<BaseFloat> c_col(Ctrans, o);
+    CuSubMatrix<BaseFloat> B_part(B, o * row_shift, num_output_rows,
+                                  0, input_num_cols);
+    c_col.AddDiagMatMat(alpha, A, kNoTrans, B_part, kTrans, 0.0);
+  }
+  C->CopyFromMat(Ctrans, kTrans);
+}
+
+void ApplyScalesToOutput(BaseFloat alpha,
+                         const CuMatrixBase<BaseFloat> &B,
+                         const CuMatrixBase<BaseFloat> &C,
+                         CuMatrixBase<BaseFloat> *A) {
+  KALDI_ASSERT(A->NumCols() == B.NumCols() &&
+               A->NumRows() == C.NumRows());
+  int32 num_output_rows = A->NumRows(),
+      input_num_cols = A->NumCols(),
+      num_extra_rows = B.NumRows() - A->NumRows(),
+      context_dim = C.NumCols();
+  KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
+  int32 row_shift = num_extra_rows / (context_dim - 1);
+  CuMatrix<BaseFloat> Ctrans(C, kTrans);
+  for (int32 o = 0; o < context_dim; o++) {
+    CuSubVector<BaseFloat> c_col(Ctrans, o);
+    CuSubMatrix<BaseFloat> B_part(B, o * row_shift, num_output_rows,
+                                  0, input_num_cols);
+    A->AddDiagVecMat(alpha, c_col, B_part, kNoTrans, 1.0);
+  }
+}
+
+void ApplyScalesToInput(BaseFloat alpha,
+                        const CuMatrixBase<BaseFloat> &A,
+                        const CuMatrixBase<BaseFloat> &C,
+                        CuMatrixBase<BaseFloat> *B) {
+  KALDI_ASSERT(A.NumCols() == B->NumCols() &&
+               A.NumRows() == C.NumRows());
+  int32 num_output_rows = A.NumRows(),
+      input_num_cols = A.NumCols(),
+      num_extra_rows = B->NumRows() - A.NumRows(),
+      context_dim = C.NumCols();
+  KALDI_ASSERT(num_extra_rows > 0 && num_extra_rows % (context_dim - 1) == 0);
+  int32 row_shift = num_extra_rows / (context_dim - 1);
+  CuMatrix<BaseFloat> Ctrans(C, kTrans);
+  for (int32 o = 0; o < context_dim; o++) {
+    CuSubVector<BaseFloat> c_col(Ctrans, o);
+    CuSubMatrix<BaseFloat> B_part(*B, o * row_shift, num_output_rows,
+                                  0, input_num_cols);
+    B_part.AddDiagVecMat(alpha, c_col, A, kNoTrans, 1.0);
+  }
+}
+
+void AttentionForward(BaseFloat key_scale,
+                      const CuMatrixBase<BaseFloat> &keys,
+                      const CuMatrixBase<BaseFloat> &queries,
+                      const CuMatrixBase<BaseFloat> &values,
+                      CuMatrixBase<BaseFloat> *c,
+                      CuMatrixBase<BaseFloat> *output) {
+  // First check the dimensions and values.
+  KALDI_ASSERT(key_scale > 0.0);
+  int32 num_input_rows = keys.NumRows(),
+      key_dim = keys.NumCols(),
+      num_output_rows = queries.NumRows(),
+      context_dim = queries.NumCols() - key_dim,
+      value_dim = values.NumCols();
+  KALDI_ASSERT(num_input_rows > 0 && key_dim > 0 &&
+               num_input_rows > num_output_rows &&
+               context_dim > 0 &&
+               (num_input_rows - num_output_rows) % (context_dim - 1) == 0 &&
+               values.NumRows() == num_input_rows);
+  KALDI_ASSERT(c->NumRows() == num_output_rows &&
+               c->NumCols() == context_dim);
+  KALDI_ASSERT(output->NumRows() == num_output_rows &&
+               (output->NumCols() == value_dim ||
+                output->NumCols() == value_dim + context_dim));
+
+  CuSubMatrix<BaseFloat> queries_key_part(
+      queries, 0, num_output_rows,
+      0, key_dim),
+      queries_context_part(
+          queries, 0, num_output_rows,
+          key_dim, context_dim);
+
+  GetAttentionDotProducts(key_scale,
+                          queries_key_part,
+                          keys, c);
+  // think of 'queries_context_part' as a position-dependent bias term.
+  c->AddMat(1.0, queries_context_part);
+  // compute the soft-max function.  Up till this point, 'c'
+  // actually contained what in attention.h we called 'b', which is
+  // the input to the softmax.
+  c->ApplySoftMaxPerRow(*c);
+
+
+  // the part of the output that is weighted
+  // combinations of the input values.
+  CuSubMatrix<BaseFloat> output_values_part(
+      *output, 0, num_output_rows, 0, value_dim);
+
+  ApplyScalesToOutput(1.0, values, *c, &output_values_part);
+
+
+  if (output->NumCols() == value_dim + context_dim) {
+    CuSubMatrix<BaseFloat> output_context_part(
+        *output, 0, num_output_rows, value_dim, context_dim);
+    output_context_part.CopyFromMat(*c);
+  }
+}
+
+void AttentionBackward(BaseFloat key_scale,
+                       const CuMatrixBase<BaseFloat> &keys,
+                       const CuMatrixBase<BaseFloat> &queries,
+                       const CuMatrixBase<BaseFloat> &values,
+                       const CuMatrixBase<BaseFloat> &c,
+                       const CuMatrixBase<BaseFloat> &output_deriv,
+                       CuMatrixBase<BaseFloat> *keys_deriv,
+                       CuMatrixBase<BaseFloat> *queries_deriv,
+                       CuMatrixBase<BaseFloat> *values_deriv) {
+
+  // First check the dimensions and values.
+  KALDI_ASSERT(key_scale > 0.0);
+  int32 num_input_rows = keys.NumRows(),
+      key_dim = keys.NumCols(),
+      num_output_rows = queries.NumRows(),
+      context_dim = queries.NumCols() - key_dim,
+      value_dim = values.NumCols();
+  KALDI_ASSERT(num_input_rows > 0 && key_dim > 0 &&
+               num_input_rows > num_output_rows &&
+               context_dim > 0 &&
+               (num_input_rows - num_output_rows) % (context_dim - 1) == 0 &&
+               values.NumRows() == num_input_rows);
+  KALDI_ASSERT(SameDim(keys, *keys_deriv) &&
+               SameDim(queries, *queries_deriv) &&
+               SameDim(values, *values_deriv));
+
+  KALDI_ASSERT(c.NumRows() == num_output_rows &&
+               c.NumCols() == context_dim);
+  KALDI_ASSERT(output_deriv.NumRows() == num_output_rows &&
+               (output_deriv.NumCols() == value_dim ||
+                output_deriv.NumCols() == value_dim + context_dim));
+
+  CuMatrix<BaseFloat> c_deriv(num_output_rows, context_dim,
+                              kUndefined);
+
+  CuSubMatrix<BaseFloat> output_values_part_deriv(
+      output_deriv, 0, num_output_rows, 0, value_dim);
+  // This is the backprop w.r.t. the forward-pass statement:
+  // ApplyScalesToOutput(1.0, values, *c, &output_values_part);
+  GetAttentionDotProducts(1.0, output_values_part_deriv,
+                          values, &c_deriv);
+
+  if (output_deriv.NumCols() == value_dim + context_dim) {
+    CuSubMatrix<BaseFloat> output_deriv_context_part(
+        output_deriv, 0, num_output_rows, value_dim, context_dim);
+    // this is the backprop w.r.t. the
+    // forward-pass statement: output_context_part.CopyFromMat(*c);
+    c_deriv.AddMat(1.0, output_deriv_context_part);
+  }
+
+  // Propagate the derivatives back through the softmax nonlinearity,
+  // in-place; this is the backprop w.r.t. the statement
+  // 'c->SoftMaxPerRow(*c);'.  From this point on, c_deriv actually
+  // contains the derivative to the pre-softmax values which we call
+  // 'b' in the math.
+  c_deriv.DiffSoftmaxPerRow(c, c_deriv);
+
+
+  CuSubMatrix<BaseFloat> queries_key_part(
+      queries, 0, num_output_rows,
+      0, key_dim),
+      queries_key_part_deriv(
+          *queries_deriv, 0, num_output_rows,
+          0, key_dim),
+      queries_context_part_deriv(
+          *queries_deriv, 0, num_output_rows,
+          key_dim, context_dim);
+
+  // Below is the backprop corresponding to the forward-propagation command:
+  // c->AddMat(1.0, queries_context_part)
+  queries_context_part_deriv.AddMat(1.0, c_deriv);
+
+  // The following statement is the part of the backprop w.r.t. the
+  // statement:
+  // GetAttentionDotProducts(key_scale, queries_key_part, keys, c);
+  // which propagates the derivative back to 'queries_key_part'.
+  ApplyScalesToOutput(key_scale, keys, c_deriv, &queries_key_part_deriv);
+
+  // The following statement is the part of the backprop w.r.t. the
+  // statement:
+  // GetAttentionDotProducts(key_scale, queries_key_part, keys, c);
+  // which propagates the derivative back to 'keys'.
+  ApplyScalesToInput(key_scale, queries_key_part, c_deriv, keys_deriv);
+
+  // The followign statement is the part of the backprop w.r.t.
+  // the statement:
+  // ApplyScalesToOutput(1.0, values, *c, &output_values_part);
+  // which propagates the derivative back to 'values'.
+  ApplyScalesToInput(1.0, output_values_part_deriv, c,  values_deriv);
+}
+
+} // namespace attention
+} // namespace nnet3
+} // namespace kaldi
diff --git a/src/nnet3/attention.h b/src/nnet3/attention.h
new file mode 100644 (file)
index 0000000..0993b78
--- /dev/null
@@ -0,0 +1,330 @@
+// nnet3/attention.h
+
+// Copyright      2017  Johns Hopkins University (author: Daniel Povey)
+//                      Hossein Hadian
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_NNET3_ATTENTION_H_
+#define KALDI_NNET3_ATTENTION_H_
+
+#include "base/kaldi-common.h"
+#include "util/common-utils.h"
+#include "itf/options-itf.h"
+#include "matrix/matrix-lib.h"
+#include "cudamatrix/cu-matrix-lib.h"
+#include "nnet3/nnet-common.h"
+#include "nnet3/convolution.h"
+
+#include <iostream>
+
+namespace kaldi {
+namespace nnet3 {
+namespace attention {
+
+/// @file  attention.h
+///
+/// This file contains the lower-level interface for self-attention.
+/// This is a form of self-attention, inspired by Google's paper
+/// "Attention is all you need", but implemented in a way that's more
+/// obviously suitable for speech tasks.  The main difference is that
+/// instead of taking as input *all frames* from the previous layer,
+/// we accept a limited grid of frames (so the left-context and
+/// right-context are finite).  Also time-encoding is handled in a different
+/// way-- we encode the time as a relative offset.
+
+
+
+// Our attention is "multi-head", like in Google's paper.  Note: we're basically
+// implementing multi-head attention as a fixed nonlinearity, with the actual
+// parameters relegated to the previous layer.  That is, the attention layer
+// won't have any parameters of its own, but the parameters of the preceding
+// layer will be interpretable as the parameters.  It doesn't change what's
+// computed, it just affects how the neural net is divided into components.
+//
+//  * Basic restricted self-attention (without positional encoding).
+//
+// To explain what's going on, we start with the simplest form of attention:
+// single-head, and no positional encoding, but with restricted context.  For purposes
+// of exposition we assume that the time offsets we need form a contigous
+// range, i.e. with time-stride == 1; the code does have the notion of a stride (you'll
+// see later).
+//
+// Using notation similar to the Google paper, suppose we have a time-sequence
+// of inputs, and the inputs are (keys, values and queries):
+//
+//   k_t, v_t, q_t
+//
+// where k_t and q_t are vectors of dimension 'key_dim' and v_t is a vector
+// of dimension 'value_dim' (you may choose to make this the same as key_dim, but
+// that isn't a constraint).
+
+// Let's make num_left_inputs and num_right_inputs be the number of
+// left-context and right-context frames required, and for some t,
+// let input_indexes(t) be the set
+//  [ t - num_left_inputs, t - num_left_inputs + 1, ... t + num_right_inputs].
+// To evaluate the output (which we'll write u_t), we need the query
+// value q_t, plus the keys and values k_s and v_s for all s in input_indexes(t).
+// If the inputs are not available for some subset of input_indexes(t),
+// we just let them be zeros; the network can learn to ignore them if it wants,
+// but making them zeros is simpler to implement.
+//
+//
+// Anyway, the output u_t (without positional encoding yet) is:
+//
+//  u_t := \sum_{s in input_indexes(t)}  Z_t exp(q_t . k_s) v_s
+//
+// where Z_t is 1/(\sum_s exp(q_t . k_s)).  We'll handle scaling
+// issues (the 1/sqrt(dim) factor in the Google paper) later on,
+// by scaling the keys.
+//
+//
+// * Positional encoding
+// We now explain how we include positional encoding in the model.
+//
+//
+// Let context_dim = 1 + num_left_inputs + num_right_inputs.
+// Let v be a vector, and let the function Extend(v, o) (where
+// 0 <= o < context_dim) extend v with extra dimensions
+// that encode the time-offset.  To be precise, we have
+//
+//  Extend(v, o) = Append(v, u_o)
+//
+// where u_o is a unit vector of dimension context_dim that is nonzero in the
+// o'th dimension (assuming zero-based indexing).
+//
+// So when we add the positional encoding (and the scale on the keys), we replace
+// the equation:
+//  u_t := \sum_{s in input_indexes(t)}  Z_t exp(q_t . k_s) v_s
+// with:
+//  u_t := \sum_{s in input_indexes(t)}  Z_t exp(alpha q_t . Extend(key-scale * k_s, s - t + num_left_inputs)) Extend(v_s, s - t + num_left_inputs)
+//
+// (we won't actually physically extend the vectors as we compute this,
+// we'll do it a different way, but it's equivalent to what we wrote
+// above. The dimension of q_t is key_dim + context_dim, and the dimension
+// of the output u_t is value_dim + context_dim.
+//
+//
+// * Multi-head attention
+//
+// The attention component if we had a single head, would have an input dimension
+// of (2*key_dim + context_dim + value_dim), which would be interpreted
+// as Append(k_t, q_t, v_t), of dimensions respectively
+// (key_dim, key_dim + context_dim, value_dim).  It would have an output
+// dimension of value_dim + context_dim.
+//
+// In any case, the multi-head version has input and output dimension that
+// are larger by a factor of 'num_heads', and which is equivalent to
+// several single-head components appended together.
+//
+//
+//
+//  * The actual calculation
+//
+// Let's assume that we might have multiple independent sequences; we'll
+// call this 'num_images' because we're borrowing certain structures from
+// the convolution code.
+
+// The input is formatted as a matrix, whose NumRows() could be written as
+// num_images * num_t_in, where num_t_in is the number of distinct input 't'
+// values, and whose output is num_images * num_t_out.  To keep it simple we'll
+// explain this under the assumption that we don't have any 't' stride in the
+// attention (t_stride == 1 in the code), and that num_heads == 1; both of
+// those are fairly simple modifications to the basic scheme.
+// The image (normally 'n') index has a higher stride than the 't' index in
+// both the input and the output.  We assume that there is 'enough'
+// context of the input to compute all required offsets of the output.
+//
+// Define the intermediate quantity b_{t,o}, which you can think of
+// as the input to softmax; the index 't' is the output time-index
+// index at the output, and o ranges from 0 to context_dim - 1.
+//
+//    b_{t,o} =  q_t . Extend(key-scale * k_{t + o - num_left_inputs}, o)
+//
+// To get rid of the Extend() expressions, define sub-ranges of q_t by
+// writing q_t = Append(r_t, s_t) where r_t is of dimension 'key_dim'
+// and s_t is of dimension context_dim.
+//
+//    b_{t,o} =   s_{t,o}  +  key-scale (r_t . k_{t+o-num_left_inputs})  [eqn:b]
+//
+// The 'b' quantity is the input to the softmax.  Define
+//     c_t = Softmax(b_t)
+// so \sum_o c_{t,o} = 1.0.  These are the weights on the values.
+//
+//
+//  The output can be written as:
+//        u_t :=  \sum_o c_{t,o} Extend(v_{t+o-num_left_inputs}, o)
+//  but we can write this in a form more suitable for computation as:
+//     u_t :=  Append(\sum_o c_{t,o} v_{t+o-num_left_inputs},  c_t)     [eqn:u]
+//
+//
+//  * Implementation
+//
+// The most time-consuming parts of this computation, we imagine, would be the
+// dot-products in [eqn:b] and the weighted sum (\sum_o) in [eqn:u].  They have
+// an awkward band-diagonal structure that would not be particularly convenient
+// to implement using CUBLAS and the like; I don't believe the relevant operations
+// exist in the BLAS interface, at least for [eqn:u].
+//
+// In the future I hope to implement this with block-wise matrix
+// multiplies-- imagine covering the band-diagonal part of a matrix with
+// rectangular blocks in such a way that all the nonzero elements are covered,
+// but the blocks might go over the zero parts a bit.   This could be done with
+// Or perhaps we can figure out how to implement the block-diagonal matrix
+// multiplies in CUDA.
+
+
+
+/**
+   This function is a utility function that is at the core of how we implement
+   attention.  It may in future need to be renamed and possibly moved into the
+   cudamatrix directory and implemented in CUDA.  The current implementation is
+   quite inefficient.  We can also consider doing a complete redesign of how the
+   implementation works, such that this function doesn't exist at all; or we
+   could have a batched version of this function that would operate on a batch
+   of A, B and C at once (or a "strided, batched" version where the difference
+   between the members of the batch is expressed as a stride).
+
+   This function implements a special operation that you could view as some kind
+   of matrix multiplication where only a band of the product is retained.
+
+   The inputs A and B must have the same number of columns
+   (A.NumCols() == B.NumCols()), and A and C must have the same
+   number of rows (A.NumRows() == C->NumRows()).  The number of
+   rows of B must exceed the number of rows of A.  Define
+      num_extra_rows = B.NumRows() - A.NumRows().
+   Then C.NumCols() - 1 must divide num_extra_rows.
+   Define
+      row_shift = num_extra_rows / (C.NumCols() - 1).
+
+   This function implements:
+      (*C)(i, j) = alpha * VecVec(A.Row(i), B.Row(i + j * row_shift))
+ */
+void GetAttentionDotProducts(BaseFloat alpha,
+                             const CuMatrixBase<BaseFloat> &A,
+                             const CuMatrixBase<BaseFloat> &B,
+                             CuMatrixBase<BaseFloat> *C);
+
+
+/**
+   This function is related to GetAttentionDotProducts(); it
+   is used in scaling the values by the softmax scales, and
+   in backprop.
+
+   We have put the A, B and C in an unusual order here in order
+   to make clearer the relationship with GetAttentionDotProducts().
+   The matrices have the same relationship in terms of their
+   dimensions, as A, B and C do in GetAttentionDotProducts().
+
+   This function implements:
+
+     A->Row(i) += \sum_j alpha * C(i, j) * B.Row(i + j * row_shift).
+ */
+void ApplyScalesToOutput(BaseFloat alpha,
+                         const CuMatrixBase<BaseFloat> &B,
+                         const CuMatrixBase<BaseFloat> &C,
+                         CuMatrixBase<BaseFloat> *A);
+
+
+/**
+   This function is related to GetAttentionDotProducts(); it
+   is used in backprop.
+
+   We have put the A, B and C in an unusual order here in order
+   to make clearer the relationship with GetAttentionDotProducts().
+   The matrices have the same relationship in terms of their
+   dimensions, as A, B and C do in GetAttentionDotProducts().
+
+   This function implements:
+
+     B->Row(i + j * row_shift) += alpha * C(i, j) * A.Row(i).
+ */
+void ApplyScalesToInput(BaseFloat alpha,
+                        const CuMatrixBase<BaseFloat> &A,
+                        const CuMatrixBase<BaseFloat> &C,
+                        CuMatrixBase<BaseFloat> *B);
+
+
+
+/**
+   This is a higher-level interface to the attention code.
+   Read the extended comment in the file nnet3/attention.h for context.
+
+     @param [in] key_scale   Scale on the non-context part of the keys.
+     @param [in] keys       Matrix whose rows contains the keys, dimension is
+                            num-input-rows by key-dim.
+     @param [in] queries    Matrix whose rows contains the queries, dimension
+                            is num-output-rows by query-dim, where query-dim
+                            == key-dim + context-dim.
+                            num-output-rows - num-input-rows must be a multiple
+                            of context-dim - 1 (we'll 'shift' the keys by multiples
+                            of 0, n, 2n, ... (context-dim - 1) * n.
+     @param [in] values     Values to average at the output, of dimension
+                            num-input-rows by value-dim.  [we may add context
+                            information to these averages if required, see comment
+                            for 'output'].
+     @param [out] c         Expected to be finite at entry (no infs or nan's);
+                            at exit this will contain the output of the softmax.
+                            Must be of dimension num-output-rows by context-dim.
+     @param [out] output    The output of the attention mechanism will be *added*
+                            to this location.  Dimension must be num-output-rows
+                            by either value-dim, or value-dim + context-dim.  To
+                            the first 'value-dim' columns of this will be added
+                            the weighted combination of 'values', weighted by
+                            the corresponding weights of 'c' (e.g. the first
+                            column of 'c' scaling the first 'output-dim' rows of
+                            'values', then the next column of 'c' scaling the
+                            submatrix of 'values' shifted by 'n', and so on.
+                            If the output->NumCols() is value-dim + context-dim,
+                            'c' will be added to the remaining columns of
+                            'output'.
+ */
+void AttentionForward(BaseFloat key_scale,
+                      const CuMatrixBase<BaseFloat> &keys,
+                      const CuMatrixBase<BaseFloat> &queries,
+                      const CuMatrixBase<BaseFloat> &values,
+                      CuMatrixBase<BaseFloat> *c,
+                      CuMatrixBase<BaseFloat> *output);
+
+/** Performs the backward pass corresponding to 'AttentionForward',
+    propagating the derivative back to the keys, queries and values.
+
+    The interface should be easy to understand with reference
+    to AttentionForward(), so we won't document it, except to note
+    that 'keys_deriv', 'queries_deriv' and 'values_deriv' are
+    *added to*, not set, by this function.
+ */
+void AttentionBackward(BaseFloat key_scale,
+                       const CuMatrixBase<BaseFloat> &keys,
+                       const CuMatrixBase<BaseFloat> &queries,
+                       const CuMatrixBase<BaseFloat> &values,
+                       const CuMatrixBase<BaseFloat> &c,
+                       const CuMatrixBase<BaseFloat> &output_deriv,
+                       CuMatrixBase<BaseFloat> *keys_deriv,
+                       CuMatrixBase<BaseFloat> *queries_deriv,
+                       CuMatrixBase<BaseFloat> *values_deriv);
+
+
+
+
+
+
+} // namespace attention
+} // namespace nnet3
+} // namespace kaldi
+
+
+#endif
index 279759b2bd1223bb691c59be9d292b5acafe3731..b69215f8d54bb3c123cb9a06bddacde895ab3ada 100644 (file)
@@ -22,6 +22,7 @@
 #include <iomanip>
 #include "nnet3/convolution.h"
 #include "nnet3/nnet-parse.h"
+#include "nnet3/nnet-compile-utils.h"
 
 namespace kaldi {
 namespace nnet3 {
@@ -1413,46 +1414,6 @@ void CompileConvolutionComputation(
 }
 
 
-
-// This function outputs a sorted list of pairs of (n, x) values that are
-// encountered in the provided list of Indexes.
-static void GetNxList(const std::vector<Index> &indexes,
-                      std::vector<std::pair<int32, int32> > *pairs) {
-  // set of (n,x) pairs
-  std::unordered_set<std::pair<int32, int32>, PairHasher<int32> > n_x_set;
-
-  for (std::vector<Index>::const_iterator iter = indexes.begin();
-       iter != indexes.end(); ++iter)
-    n_x_set.insert(std::pair<int32, int32>(iter->n, iter->x));
-  pairs->clear();
-  pairs->reserve(n_x_set.size());
-  for (std::unordered_set<std::pair<int32, int32>, PairHasher<int32> >::iterator
-           iter = n_x_set.begin(); iter != n_x_set.end(); ++iter)
-    pairs->push_back(*iter);
-  std::sort(pairs->begin(), pairs->end());
-}
-
-
-// This function outputs a sorted list of the 't' values that are
-// encountered in the provided list of Indexes.
-static void GetTList(const std::vector<Index> &indexes,
-                     std::vector<int32> *t_values) {
-  // set of t values
-  std::unordered_set<int32> t_set;
-
-  for (std::vector<Index>::const_iterator iter = indexes.begin();
-       iter != indexes.end(); ++iter)
-    if (iter->t != kNoTime)
-      t_set.insert(iter->t);
-  t_values->clear();
-  t_values->reserve(t_set.size());
-  for (std::unordered_set<int32>::iterator iter = t_set.begin();
-       iter != t_set.end(); ++iter)
-    t_values->push_back(*iter);
-  std::sort(t_values->begin(), t_values->end());
-}
-
-
 // Returns the greatest common divisor of the differences between the values in
 // 'vec', or zero if the vector has zero or one element.  It is an error if
 // 'vec' has repeated elements (which could cause a crash in 'Gcd').
@@ -1677,6 +1638,34 @@ void MakeComputation(const ConvolutionModel &model,
   ComputeTempMatrixSize(opts, computation);
 }
 
+
+void ConvolutionComputationIo::Write(std::ostream &os, bool binary) const {
+  WriteToken(os, binary, "<ConvCompIo>");
+  WriteBasicType(os, binary, num_images);
+  WriteBasicType(os, binary, start_t_in);
+  WriteBasicType(os, binary, t_step_in);
+  WriteBasicType(os, binary, num_t_in);
+  WriteBasicType(os, binary, start_t_out);
+  WriteBasicType(os, binary, t_step_out);
+  WriteBasicType(os, binary, num_t_out);
+  WriteBasicType(os, binary, reorder_t_in);
+  WriteToken(os, binary, "</ConvCompIo>");
+}
+
+
+void ConvolutionComputationIo::Read(std::istream &is, bool binary) {
+  ExpectToken(is, binary, "<ConvCompIo>");
+  ReadBasicType(is, binary, &num_images);
+  ReadBasicType(is, binary, &start_t_in);
+  ReadBasicType(is, binary, &t_step_in);
+  ReadBasicType(is, binary, &num_t_in);
+  ReadBasicType(is, binary, &start_t_out);
+  ReadBasicType(is, binary, &t_step_out);
+  ReadBasicType(is, binary, &num_t_out);
+  ReadBasicType(is, binary, &reorder_t_in);
+  ExpectToken(is, binary, "</ConvCompIo>");
+}
+
 } // namespace time_height_convolution
 } // namespace nnet3
 } // namespace kaldi
index 8a09c2e9551eb86f5bc98aa6c948cbdeac9cb884..cb4d8331cbfdec68f5913360882fd0e7c3f6251e 100644 (file)
@@ -409,6 +409,9 @@ struct ConvolutionComputationIo {
   // a reshaping such that we can imagine that the input and output have the
   // same 't' increment; it's useful in subsampling convolutions..
   int32 reorder_t_in;
+
+  void Write(std::ostream &os, bool binary) const;
+  void Read(std::istream &is, bool binary);
 };
 
 /**
@@ -547,6 +550,7 @@ void ConvolveBackwardParams(
    (e.g. as supplied to ReorderIndexes()), and figures out a regular structure
    for them (i.e. the smallest grid that will completely cover all the t,n
    pairs).
+   This function ignores any 't' values that are kNoTime.
 */
 void GetComputationIo(
     const std::vector<Index> &input_indexes,
diff --git a/src/nnet3/nnet-attention-component.cc b/src/nnet3/nnet-attention-component.cc
new file mode 100644 (file)
index 0000000..58e662a
--- /dev/null
@@ -0,0 +1,671 @@
+// nnet3/nnet-attention-component.cc
+
+// Copyright      2017  Johns Hopkins University (author: Daniel Povey)
+//                2017  Hossein Hadian
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#include <iterator>
+#include <sstream>
+#include <iomanip>
+#include "nnet3/nnet-attention-component.h"
+#include "nnet3/nnet-parse.h"
+#include "nnet3/nnet-compile-utils.h"
+
+namespace kaldi {
+namespace nnet3 {
+
+
+std::string RestrictedAttentionComponent::Info() const {
+  std::stringstream stream;
+  stream << Type() << ", input-dim=" << InputDim()
+         << ", output-dim=" << OutputDim()
+         << ", num-heads=" << num_heads_
+         << ", time-stride=" << time_stride_
+         << ", key-dim=" << key_dim_
+         << ", value-dim=" << value_dim_
+         << ", num-left-inputs=" << num_left_inputs_
+         << ", num-right-inputs=" << num_right_inputs_
+         << ", context-dim=" << context_dim_
+         << ", num-left-inputs-required=" << num_left_inputs_required_
+         << ", num-right-inputs-required=" << num_right_inputs_required_
+         << ", output-context=" << (output_context_ ? "true" : "false")
+         << ", key-scale=" << key_scale_;
+  if (stats_count_ != 0.0) {
+    stream << ", entropy=";
+    for (int32 i = 0; i < entropy_stats_.Dim(); i++)
+      stream << (entropy_stats_(i) / stats_count_) << ',';
+    for (int32 i = 0; i < num_heads_ && i < 5; i++) {
+      stream << " posterior-stats[" << i <<"]=";
+      for (int32 j = 0; j < posterior_stats_.NumCols(); j++)
+        stream << (posterior_stats_(i,j) / stats_count_) << ',';
+    }
+    stream << " stats-count=" << stats_count_;
+  }
+  return stream.str();
+}
+
+RestrictedAttentionComponent::RestrictedAttentionComponent(
+    const RestrictedAttentionComponent &other):
+    num_heads_(other.num_heads_),
+    key_dim_(other.key_dim_),
+    value_dim_(other.value_dim_),
+    num_left_inputs_(other.num_left_inputs_),
+    num_right_inputs_(other.num_right_inputs_),
+    time_stride_(other.time_stride_),
+    context_dim_(other.context_dim_),
+    num_left_inputs_required_(other.num_left_inputs_required_),
+    num_right_inputs_required_(other.num_right_inputs_required_),
+    output_context_(other.output_context_),
+    key_scale_(other.key_scale_),
+    stats_count_(other.stats_count_),
+    entropy_stats_(other.entropy_stats_),
+    posterior_stats_(other.posterior_stats_) { }
+
+
+
+void RestrictedAttentionComponent::InitFromConfig(ConfigLine *cfl) {
+  num_heads_ = 1;
+  key_dim_ = -1;
+  value_dim_ = -1;
+  num_left_inputs_ = -1;
+  num_right_inputs_ = -1;
+  time_stride_ = 1;
+  num_left_inputs_required_ = -1;
+  num_right_inputs_required_ = -1;
+  output_context_ = true;
+  key_scale_ = -1.0;
+
+
+  // mandatory arguments.
+  bool ok = cfl->GetValue("key-dim", &key_dim_) &&
+      cfl->GetValue("value-dim", &value_dim_) &&
+      cfl->GetValue("num-left-inputs", &num_left_inputs_) &&
+      cfl->GetValue("num-right-inputs", &num_right_inputs_);
+
+  if (!ok)
+    KALDI_ERR << "All of the values key-dim, value-dim, "
+        "num-left-inputs and num-right-inputs must be defined.";
+  // optional arguments.
+  cfl->GetValue("num-heads", &num_heads_);
+  cfl->GetValue("time-stride", &time_stride_);
+  cfl->GetValue("num-left-inputs-required", &num_left_inputs_required_);
+  cfl->GetValue("num-right-inputs-required", &num_right_inputs_required_);
+  cfl->GetValue("output-context", &output_context_);
+  cfl->GetValue("key-scale", &key_scale_);
+
+  if (key_scale_ < 0.0) key_scale_ = 1.0 / sqrt(key_dim_);
+  if (num_left_inputs_required_ < 0)
+    num_left_inputs_required_ = num_left_inputs_;
+  if (num_right_inputs_required_ < 0)
+    num_right_inputs_required_ = num_right_inputs_;
+
+  if (num_heads_ <= 0 || key_dim_ <= 0 || value_dim_ <= 0 ||
+      num_left_inputs_ < 0 || num_right_inputs_ < 0 ||
+      (num_left_inputs_ + num_right_inputs_) <= 0 ||
+      num_left_inputs_required_ > num_left_inputs_ ||
+      num_right_inputs_required_ > num_right_inputs_ ||
+      time_stride_ <= 0)
+    KALDI_ERR << "Config line contains invalid values: "
+              << cfl->WholeLine();
+  stats_count_ = 0.0;
+  context_dim_ = num_left_inputs_ + 1 + num_right_inputs_;
+  Check();
+}
+
+
+
+void*
+RestrictedAttentionComponent::Propagate(const ComponentPrecomputedIndexes *indexes_in,
+                                        const CuMatrixBase<BaseFloat> &in,
+                                        CuMatrixBase<BaseFloat> *out) const {
+  const PrecomputedIndexes *indexes = dynamic_cast<const PrecomputedIndexes*>(
+      indexes_in);
+  KALDI_ASSERT(indexes != NULL &&
+               in.NumRows() == indexes->io.num_t_in * indexes->io.num_images &&
+               out->NumRows() == indexes->io.num_t_out * indexes->io.num_images);
+
+
+  Memo *memo = new Memo();
+  memo->c.Resize(out->NumRows(), context_dim_ * num_heads_);
+
+  int32 query_dim = key_dim_ + context_dim_;
+  int32 input_dim_per_head = key_dim_ + value_dim_ + query_dim,
+      output_dim_per_head = value_dim_ + (output_context_ ? context_dim_ : 0);
+  for (int32 h = 0; h < num_heads_; h++) {
+    CuSubMatrix<BaseFloat> in_part(in, 0, in.NumRows(),
+                                   h * input_dim_per_head, input_dim_per_head),
+        c_part(memo->c, 0, out->NumRows(),
+               h * context_dim_, context_dim_),
+        out_part(*out, 0, out->NumRows(),
+                 h * output_dim_per_head, output_dim_per_head);
+    PropagateOneHead(indexes->io, in_part, &c_part, &out_part);
+  }
+  return static_cast<void*>(memo);
+}
+
+void RestrictedAttentionComponent::PropagateOneHead(
+    const time_height_convolution::ConvolutionComputationIo &io,
+    const CuMatrixBase<BaseFloat> &in,
+    CuMatrixBase<BaseFloat> *c,
+    CuMatrixBase<BaseFloat> *out) const {
+  int32 query_dim = key_dim_ + context_dim_,
+      full_value_dim = value_dim_ + (output_context_ ? context_dim_ : 0);
+  KALDI_ASSERT(in.NumRows() == io.num_images * io.num_t_in &&
+               out->NumRows() == io.num_images * io.num_t_out &&
+               out->NumCols() == full_value_dim &&
+               in.NumCols() == (key_dim_ + value_dim_ + query_dim) &&
+               io.t_step_in == io.t_step_out &&
+               (io.start_t_out - io.start_t_in) % io.t_step_in == 0);
+
+  // 'steps_left_context' is the number of time-steps the input has on the left
+  // that don't appear in the output.
+  int32 steps_left_context = (io.start_t_out - io.start_t_in) / io.t_step_in,
+      rows_left_context = steps_left_context * io.num_images;
+  KALDI_ASSERT(rows_left_context >= 0);
+
+  // 'queries' contains the queries.  We don't use all rows of the input
+  // queries; only the rows that correspond to the time-indexes at the
+  // output, i.e. we exclude the left-context and right-context.
+  // 'out'; the remaining rows of 'in' that we didn't select correspond to left
+  // and right temporal context.
+  CuSubMatrix<BaseFloat> queries(in, rows_left_context, out->NumRows(),
+                                 key_dim_ + value_dim_, query_dim);
+  // 'keys' contains the keys; note, these are not extended with
+  // context information; that happens further in.
+  CuSubMatrix<BaseFloat> keys(in, 0, in.NumRows(), 0, key_dim_);
+
+  // 'values' contains the values which we will interpolate.
+  // these don't contain the context information; that will be added
+  // later if output_context_ == true.
+  CuSubMatrix<BaseFloat> values(in, 0, in.NumRows(), key_dim_, value_dim_);
+
+  attention::AttentionForward(key_scale_, keys, queries, values, c, out);
+}
+
+
+void RestrictedAttentionComponent::StoreStats(
+    const CuMatrixBase<BaseFloat> &, // in_value
+    const CuMatrixBase<BaseFloat> &, // out_value
+    void *memo_in) {
+  const Memo *memo = static_cast<const Memo*>(memo_in);
+  KALDI_ASSERT(memo != NULL);
+  if (entropy_stats_.Dim() != num_heads_) {
+    entropy_stats_.Resize(num_heads_);
+    posterior_stats_.Resize(num_heads_, context_dim_);
+    stats_count_ = 0.0;
+  }
+  const CuMatrix<BaseFloat> &c = memo->c;
+  if (RandInt(0, 2) == 0)
+    return;  // only actually store the stats for one in three minibatches, to
+             // save time.
+
+  { // first get the posterior stats.
+    CuVector<BaseFloat> c_sum(num_heads_ * context_dim_);
+    c_sum.AddRowSumMat(1.0, c, 0.0);
+    // view the vector as a matrix.
+    CuSubMatrix<BaseFloat> c_sum_as_mat(c_sum.Data(), num_heads_,
+                                        context_dim_, context_dim_);
+    CuMatrix<double> c_sum_as_mat_dbl(c_sum_as_mat);
+    posterior_stats_.AddMat(1.0, c_sum_as_mat_dbl);
+    KALDI_ASSERT(c.NumCols() == num_heads_ * context_dim_);
+  }
+  { // now get the entropy stats.
+    CuMatrix<BaseFloat> log_c(c);
+    log_c.ApplyFloor(1.0e-20);
+    log_c.ApplyLog();
+    CuVector<BaseFloat> dot_prod(num_heads_ * context_dim_);
+    dot_prod.AddDiagMatMat(-1.0, c, kTrans, log_c, kNoTrans, 0.0);
+    // dot_prod is the sum over the matrix's rows (which correspond
+    // to heads, and context positions), of - c * log(c), which is
+    // part of the entropy.  To get the actual contribution to the
+    // entropy, we have to sum 'dot_prod' over blocks of
+    // size 'context_dim_'; that gives us the entropy contribution
+    // per head.  We'd have to divide by c.NumRows() to get the
+    // actual entropy, but that's reflected in stats_count_.
+    CuSubMatrix<BaseFloat> entropy_mat(dot_prod.Data(), num_heads_,
+                                       context_dim_, context_dim_);
+    CuVector<BaseFloat> entropy_vec(num_heads_);
+    entropy_vec.AddColSumMat(1.0, entropy_mat);
+    Vector<double> entropy_vec_dbl(entropy_vec);
+    entropy_stats_.AddVec(1.0, entropy_vec_dbl);
+  }
+  stats_count_ += c.NumRows();
+}
+
+void RestrictedAttentionComponent::ZeroStats() {
+  entropy_stats_.SetZero();
+  posterior_stats_.SetZero();
+  stats_count_ = 0.0;
+}
+
+void RestrictedAttentionComponent::Scale(BaseFloat scale) {
+  entropy_stats_.Scale(scale);
+  posterior_stats_.Scale(scale);
+  stats_count_ *= scale;
+}
+
+void RestrictedAttentionComponent::Add(BaseFloat alpha, const Component &other_in) {
+  const RestrictedAttentionComponent *other =
+      dynamic_cast<const RestrictedAttentionComponent*>(&other_in);
+  KALDI_ASSERT(other != NULL);
+  if (entropy_stats_.Dim() == 0 && other->entropy_stats_.Dim() != 0)
+    entropy_stats_.Resize(other->entropy_stats_.Dim());
+  if (posterior_stats_.NumRows() == 0 && other->posterior_stats_.NumRows() != 0)
+    posterior_stats_.Resize(other->posterior_stats_.NumRows(), other->posterior_stats_.NumCols());
+  if (other->entropy_stats_.Dim() != 0)
+    entropy_stats_.AddVec(alpha, other->entropy_stats_);
+  if (other->posterior_stats_.NumRows() != 0)
+    posterior_stats_.AddMat(alpha, other->posterior_stats_);
+  stats_count_ += alpha * other->stats_count_;
+}
+
+
+void RestrictedAttentionComponent::Check() const {
+  KALDI_ASSERT(num_heads_ > 0 && key_dim_ > 0 && value_dim_ > 0 &&
+               num_left_inputs_ >= 0 && num_right_inputs_ >= 0 &&
+               (num_left_inputs_ + num_right_inputs_) > 0 &&
+               time_stride_ > 0 &&
+               context_dim_ == (num_left_inputs_ + 1 + num_right_inputs_) &&
+               num_left_inputs_required_ >= 0 &&
+               num_left_inputs_required_ <= num_left_inputs_ &&
+               num_right_inputs_required_ >= 0 &&
+               num_right_inputs_required_ <= num_right_inputs_ &&
+               key_scale_ > 0.0 && key_scale_ <= 1.0 &&
+               stats_count_ >= 0.0);
+}
+
+
+void RestrictedAttentionComponent::Backprop(
+    const std::string &debug_info,
+    const ComponentPrecomputedIndexes *indexes_in,
+    const CuMatrixBase<BaseFloat> &in_value,
+    const CuMatrixBase<BaseFloat> &, // out_value
+    const CuMatrixBase<BaseFloat> &out_deriv,
+    void *memo_in,
+    Component *to_update_in,
+    CuMatrixBase<BaseFloat> *in_deriv) const {
+  const PrecomputedIndexes *indexes =
+      dynamic_cast<const PrecomputedIndexes*>(indexes_in);
+  KALDI_ASSERT(indexes != NULL);
+  Memo *memo = static_cast<Memo*>(memo_in);
+  KALDI_ASSERT(memo != NULL);
+  const time_height_convolution::ConvolutionComputationIo &io = indexes->io;
+  KALDI_ASSERT(indexes != NULL &&
+               in_value.NumRows() == io.num_t_in * io.num_images &&
+               out_deriv.NumRows() == io.num_t_out * io.num_images &&
+               in_deriv != NULL && SameDim(in_value, *in_deriv));
+
+  const CuMatrix<BaseFloat> &c = memo->c;
+
+  int32 query_dim = key_dim_ + context_dim_,
+      input_dim_per_head = key_dim_ + value_dim_ + query_dim,
+      output_dim_per_head = value_dim_ + (output_context_ ? context_dim_ : 0);
+
+  for (int32 h = 0; h < num_heads_; h++) {
+    CuSubMatrix<BaseFloat>
+        in_value_part(in_value, 0, in_value.NumRows(),
+                      h * input_dim_per_head, input_dim_per_head),
+        c_part(c, 0, out_deriv.NumRows(),
+               h * context_dim_, context_dim_),
+        out_deriv_part(out_deriv, 0, out_deriv.NumRows(),
+                       h * output_dim_per_head, output_dim_per_head),
+        in_deriv_part(*in_deriv, 0, in_value.NumRows(),
+                      h * input_dim_per_head, input_dim_per_head);
+    BackpropOneHead(io, in_value_part, c_part, out_deriv_part,
+                    &in_deriv_part);
+  }
+}
+
+
+void RestrictedAttentionComponent::BackpropOneHead(
+    const time_height_convolution::ConvolutionComputationIo &io,
+    const CuMatrixBase<BaseFloat> &in_value,
+    const CuMatrixBase<BaseFloat> &c,
+    const CuMatrixBase<BaseFloat> &out_deriv,
+    CuMatrixBase<BaseFloat> *in_deriv) const {
+  // the easiest way to understand this is to compare it with PropagateOneHead().
+  int32 query_dim = key_dim_ + context_dim_,
+      full_value_dim = value_dim_ + (output_context_ ? context_dim_ : 0);
+  KALDI_ASSERT(in_value.NumRows() == io.num_images * io.num_t_in &&
+               out_deriv.NumRows() == io.num_images * io.num_t_out &&
+               out_deriv.NumCols() == full_value_dim &&
+               in_value.NumCols() == (key_dim_ + value_dim_ + query_dim) &&
+               io.t_step_in == io.t_step_out &&
+               (io.start_t_out - io.start_t_in) % io.t_step_in == 0 &&
+               SameDim(in_value, *in_deriv) &&
+               c.NumRows() == out_deriv.NumRows() &&
+               c.NumCols() == context_dim_);
+
+  // 'steps_left_context' is the number of time-steps the input has on the left
+  // that don't appear in the output.
+  int32 steps_left_context = (io.start_t_out - io.start_t_in) / io.t_step_in,
+      rows_left_context = steps_left_context * io.num_images;
+  KALDI_ASSERT(rows_left_context >= 0);
+
+
+  CuSubMatrix<BaseFloat> queries(in_value, rows_left_context, out_deriv.NumRows(),
+                                 key_dim_ + value_dim_, query_dim),
+      queries_deriv(*in_deriv, rows_left_context, out_deriv.NumRows(),
+                    key_dim_ + value_dim_, query_dim),
+      keys(in_value, 0, in_value.NumRows(), 0, key_dim_),
+      keys_deriv(*in_deriv,  0, in_value.NumRows(), 0, key_dim_),
+      values(in_value, 0, in_value.NumRows(), key_dim_, value_dim_),
+      values_deriv(*in_deriv, 0, in_value.NumRows(), key_dim_, value_dim_);
+
+  attention::AttentionBackward(key_scale_, keys, queries, values, c, out_deriv,
+                               &keys_deriv, &queries_deriv, &values_deriv);
+}
+
+
+
+void RestrictedAttentionComponent::ReorderIndexes(
+    std::vector<Index> *input_indexes,
+    std::vector<Index> *output_indexes) const {
+  using namespace time_height_convolution;
+  ConvolutionComputationIo io;
+  GetComputationStructure(*input_indexes, *output_indexes, &io);
+  std::vector<Index> new_input_indexes, new_output_indexes;
+  GetIndexes(*input_indexes, *output_indexes, io,
+             &new_input_indexes, &new_output_indexes);
+  input_indexes->swap(new_input_indexes);
+  output_indexes->swap(new_output_indexes);
+}
+
+void RestrictedAttentionComponent::GetComputationStructure(
+      const std::vector<Index> &input_indexes,
+      const std::vector<Index> &output_indexes,
+      time_height_convolution::ConvolutionComputationIo *io) const {
+  GetComputationIo(input_indexes, output_indexes, io);
+  // if there was only one output and/or input index (unlikely),
+  // just let the grid periodicity be t_stride_.
+  if (io->t_step_out == 0) io->t_step_out = time_stride_;
+  if (io->t_step_in == 0) io->t_step_in = time_stride_;
+
+  // We need the grid size on the input and output to be the same, and to divide
+  // t_stride_.  If someone is requesting the output more frequently than
+  // t_stride_, then after this change we may end up computing more outputs than
+  // we need, but this is not a configuration that I think is very likely.  We
+  // let the grid step be the gcd of the input and output steps, and of
+  // t_stride_.
+  // The next few statements may have the effect of making the grid finer at the
+  // input and output, while having the same start and end point.
+  int32 t_step = Gcd(Gcd(io->t_step_out, io->t_step_in), time_stride_);
+  int32 multiple_out = io->t_step_out / t_step,
+      multiple_in = io->t_step_in / t_step;
+  io->t_step_in = t_step;
+  io->t_step_out = t_step;
+  io->num_t_out = 1 + multiple_out * (io->num_t_out - 1);
+  io->num_t_in = 1 + multiple_in * (io->num_t_in - 1);
+
+  // Now ensure that the extent of the input has at least
+  // the requested left-context and right context; if
+  // this increases the amount of input, we'll do zero-padding.
+  int32 first_requested_input =
+          io->start_t_out - (time_stride_ * num_left_inputs_),
+      first_required_input =
+         io->start_t_out - (time_stride_ * num_left_inputs_required_),
+      last_t_out = io->start_t_out + (io->num_t_out - 1) * t_step,
+      last_t_in = io->start_t_in + (io->num_t_in - 1) * t_step,
+      last_requested_input = last_t_out + (time_stride_ * num_right_inputs_),
+      last_required_input =
+           last_t_out + (time_stride_ * num_right_inputs_required_);
+
+  // check that we don't have *more* than the requested context,
+  // but that we have at least the required context.
+  KALDI_ASSERT(io->start_t_in >= first_requested_input &&
+               last_t_in <= last_requested_input &&
+               io->start_t_in <= first_required_input &&
+               last_t_in >= last_required_input);
+
+  // For the inputs that were requested, but not required,
+  // we pad with zeros.  We pad the 'io' object, adding these
+  // extra inputs structurally; in runtime they'll be set to zero.
+  io->start_t_in = first_requested_input;
+  io->num_t_in = 1 + (last_requested_input - first_requested_input) / t_step;
+}
+
+void RestrictedAttentionComponent::Write(std::ostream &os, bool binary) const {
+  WriteToken(os, binary, "<RestrictedAttentionComponent>");
+  WriteToken(os, binary, "<NumHeads>");
+  WriteBasicType(os, binary, num_heads_);
+  WriteToken(os, binary, "<KeyDim>");
+  WriteBasicType(os, binary, key_dim_);
+  WriteToken(os, binary, "<ValueDim>");
+  WriteBasicType(os, binary, value_dim_);
+  WriteToken(os, binary, "<NumLeftInputs>");
+  WriteBasicType(os, binary, num_left_inputs_);
+  WriteToken(os, binary, "<NumRightInputs>");
+  WriteBasicType(os, binary, num_right_inputs_);
+  WriteToken(os, binary, "<TimeStride>");
+  WriteBasicType(os, binary, time_stride_);
+  WriteToken(os, binary, "<NumLeftInputsRequired>");
+  WriteBasicType(os, binary, num_left_inputs_required_);
+  WriteToken(os, binary, "<NumRightInputsRequired>");
+  WriteBasicType(os, binary, num_right_inputs_required_);
+  WriteToken(os, binary, "<OutputContext>");
+  WriteBasicType(os, binary, output_context_);
+  WriteToken(os, binary, "<KeyScale>");
+  WriteBasicType(os, binary, key_scale_);
+  WriteToken(os, binary, "<StatsCount>");
+  WriteBasicType(os, binary, stats_count_);
+  WriteToken(os, binary, "<EntropyStats>");
+  entropy_stats_.Write(os, binary);
+  WriteToken(os, binary, "<PosteriorStats>");
+  posterior_stats_.Write(os, binary);
+  WriteToken(os, binary, "</RestrictedAttentionComponent>");
+}
+
+void RestrictedAttentionComponent::Read(std::istream &is, bool binary) {
+  ExpectOneOrTwoTokens(is, binary, "<RestrictedAttentionComponent>",
+                       "<NumHeads>");
+  ReadBasicType(is, binary, &num_heads_);
+  ExpectToken(is, binary, "<KeyDim>");
+  ReadBasicType(is, binary, &key_dim_);
+  ExpectToken(is, binary, "<ValueDim>");
+  ReadBasicType(is, binary, &value_dim_);
+  ExpectToken(is, binary, "<NumLeftInputs>");
+  ReadBasicType(is, binary, &num_left_inputs_);
+  ExpectToken(is, binary, "<NumRightInputs>");
+  ReadBasicType(is, binary, &num_right_inputs_);
+  ExpectToken(is, binary, "<TimeStride>");
+  ReadBasicType(is, binary, &time_stride_);
+  ExpectToken(is, binary, "<NumLeftInputsRequired>");
+  ReadBasicType(is, binary, &num_left_inputs_required_);
+  ExpectToken(is, binary, "<NumRightInputsRequired>");
+  ReadBasicType(is, binary, &num_right_inputs_required_);
+  ExpectToken(is, binary, "<OutputContext>");
+  ReadBasicType(is, binary, &output_context_);
+  ExpectToken(is, binary, "<KeyScale>");
+  ReadBasicType(is, binary, &key_scale_);
+  ExpectToken(is, binary, "<StatsCount>");
+  ReadBasicType(is, binary, &stats_count_);
+  ExpectToken(is, binary, "<EntropyStats>");
+  entropy_stats_.Read(is, binary);
+  ExpectToken(is, binary, "<PosteriorStats>");
+  posterior_stats_.Read(is, binary);
+  ExpectToken(is, binary, "</RestrictedAttentionComponent>");
+
+  context_dim_ = num_left_inputs_ + 1 + num_right_inputs_;
+}
+
+
+void RestrictedAttentionComponent::GetInputIndexes(
+    const MiscComputationInfo &misc_info,
+    const Index &output_index,
+    std::vector<Index> *desired_indexes) const {
+  KALDI_ASSERT(output_index.t != kNoTime);
+  int32 first_time = output_index.t - (time_stride_ * num_left_inputs_),
+      last_time = output_index.t + (time_stride_ * num_right_inputs_);
+  desired_indexes->clear();
+  desired_indexes->resize(context_dim_);
+  int32 n = output_index.n, x = output_index.x,
+      i = 0;
+  for (int32 t = first_time; t <= last_time; t += time_stride_, i++) {
+    (*desired_indexes)[i].n = n;
+    (*desired_indexes)[i].t = t;
+    (*desired_indexes)[i].x = x;
+  }
+  KALDI_ASSERT(i == context_dim_);
+}
+
+
+bool RestrictedAttentionComponent::IsComputable(
+    const MiscComputationInfo &misc_info,
+    const Index &output_index,
+    const IndexSet &input_index_set,
+    std::vector<Index> *used_inputs) const {
+  KALDI_ASSERT(output_index.t != kNoTime);
+  Index index(output_index);
+
+  if (used_inputs != NULL) {
+    int32 first_time = output_index.t - (time_stride_ * num_left_inputs_),
+        last_time = output_index.t + (time_stride_ * num_right_inputs_);
+    used_inputs->clear();
+    used_inputs->reserve(context_dim_);
+
+    for (int32 t = first_time; t <= last_time; t += time_stride_) {
+      index.t = t;
+      if (input_index_set(index)) {
+        // This input index is available.
+        used_inputs->push_back(index);
+      } else {
+        // This input index is not available.
+        int32 offset = (t - output_index.t) / time_stride_;
+        if (offset >= num_left_inputs_required_ &&
+            offset <= num_right_inputs_required_) {
+          used_inputs->clear();
+          return false;
+        }
+      }
+    }
+    // All required time-offsets of the output were computable. -> return true.
+    return true;
+  } else {
+    int32 t = output_index.t,
+        first_time_required = t - (time_stride_ * num_left_inputs_required_),
+        last_time_required = t + (time_stride_ * num_right_inputs_required_);
+    for (int32 t = first_time_required;
+         t <= last_time_required;
+         t += time_stride_) {
+      index.t = t;
+      if (!input_index_set(index))
+        return false;
+    }
+    return true;
+  }
+}
+
+
+// static
+void RestrictedAttentionComponent::CreateIndexesVector(
+    const std::vector<std::pair<int32, int32> > &n_x_pairs,
+    int32 t_start, int32 t_step, int32 num_t_values,
+    const std::unordered_set<Index, IndexHasher> &index_set,
+    std::vector<Index> *output_indexes) {
+  output_indexes->resize(static_cast<size_t>(num_t_values) * n_x_pairs.size());
+  std::vector<Index>::iterator out_iter = output_indexes->begin();
+  for (int32 t = t_start; t < t_start + (t_step * num_t_values); t += t_step) {
+    std::vector<std::pair<int32, int32> >::const_iterator
+        iter = n_x_pairs.begin(), end = n_x_pairs.end();
+    for (; iter != end; ++iter) {
+      out_iter->n = iter->first;
+      out_iter->t = t;
+      out_iter->x = iter->second;
+      if (index_set.count(*out_iter) == 0)
+        out_iter->t = kNoTime;
+      ++out_iter;
+    }
+  }
+  KALDI_ASSERT(out_iter == output_indexes->end());
+}
+
+void RestrictedAttentionComponent::GetIndexes(
+      const std::vector<Index> &input_indexes,
+      const std::vector<Index> &output_indexes,
+      time_height_convolution::ConvolutionComputationIo &io,
+      std::vector<Index> *new_input_indexes,
+      std::vector<Index> *new_output_indexes) const {
+
+  std::unordered_set<Index, IndexHasher> input_set, output_set;
+  for (std::vector<Index>::const_iterator iter = input_indexes.begin();
+       iter != input_indexes.end(); ++iter)
+    input_set.insert(*iter);
+  for (std::vector<Index>::const_iterator iter = output_indexes.begin();
+       iter != output_indexes.end(); ++iter)
+    output_set.insert(*iter);
+
+  std::vector<std::pair<int32, int32> > n_x_pairs;
+  GetNxList(input_indexes, &n_x_pairs);  // the n,x pairs at the output will be
+                                         // identical.
+  KALDI_ASSERT(n_x_pairs.size() == io.num_images);
+  CreateIndexesVector(n_x_pairs, io.start_t_in, io.t_step_in, io.num_t_in,
+                      input_set, new_input_indexes);
+  CreateIndexesVector(n_x_pairs, io.start_t_out, io.t_step_out, io.num_t_out,
+                      output_set, new_output_indexes);
+}
+
+ComponentPrecomputedIndexes* RestrictedAttentionComponent::PrecomputeIndexes(
+    const MiscComputationInfo &,  // misc_info
+    const std::vector<Index> &input_indexes,
+    const std::vector<Index> &output_indexes,
+    bool) // need_backprop
+    const {
+  PrecomputedIndexes *ans = new PrecomputedIndexes();
+  GetComputationStructure(input_indexes, output_indexes, &(ans->io));
+  if (GetVerboseLevel() >= 2) {
+    // what goes next is just a check.
+    std::vector<Index> new_input_indexes, new_output_indexes;
+    GetIndexes(input_indexes, output_indexes, ans->io,
+               &new_input_indexes, &new_output_indexes);
+    // input_indexes and output_indexes should be the ones that were
+    // output by ReorderIndexes(), so they should already
+    // have gone through the GetComputationStructure()->GetIndexes()
+    // procedure.  Applying the same procedure twice is supposed to
+    // give an unchanged results.
+    KALDI_ASSERT(input_indexes == new_input_indexes &&
+                 output_indexes == new_output_indexes);
+  }
+  return ans;
+}
+
+
+
+RestrictedAttentionComponent::PrecomputedIndexes*
+RestrictedAttentionComponent::PrecomputedIndexes::Copy() const {
+  return new PrecomputedIndexes(*this);
+}
+
+void RestrictedAttentionComponent::PrecomputedIndexes::Write(
+    std::ostream &os, bool binary) const {
+  WriteToken(os, binary, "<RestrictedAttentionComponentPrecomputedIndexes>");
+  WriteToken(os, binary, "<Io>");
+  io.Write(os, binary);
+  WriteToken(os, binary, "</RestrictedAttentionComponentPrecomputedIndexes>");
+}
+
+void RestrictedAttentionComponent::PrecomputedIndexes::Read(
+    std::istream &is, bool binary) {
+  ExpectOneOrTwoTokens(is, binary,
+                       "<RestrictedAttentionComponentPrecomputedIndexes>",
+                       "<Io>");
+  io.Read(is, binary);
+  ExpectToken(is, binary, "</RestrictedAttentionComponentPrecomputedIndexes>");
+}
+
+
+} // namespace nnet3
+} // namespace kaldi
diff --git a/src/nnet3/nnet-attention-component.h b/src/nnet3/nnet-attention-component.h
new file mode 100644 (file)
index 0000000..6072fcd
--- /dev/null
@@ -0,0 +1,311 @@
+// nnet3/nnet-attention-component.h
+
+// Copyright      2017  Johns Hopkins University (author: Daniel Povey)
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_NNET3_NNET_ATTENTION_COMPONENT_H_
+#define KALDI_NNET3_NNET_ATTENTION_COMPONENT_H_
+
+#include "nnet3/nnet-common.h"
+#include "nnet3/nnet-component-itf.h"
+#include "nnet3/natural-gradient-online.h"
+#include "nnet3/attention.h"
+#include <iostream>
+
+namespace kaldi {
+namespace nnet3 {
+
+/// @file  nnet-attention-component.h
+///
+/// Contains component(s) related to attention models.
+
+
+
+/**
+   RestrictedAttentionComponent implements an attention model with restricted
+   temporal context.  What is implemented here is a case of self-attention,
+   meaning that the set of indexes on the input is the same set as the indexes
+   on the output (like an N->N mapping, ignoring edge effects, as opposed to an
+   N->M mapping that you might see in a translation model).  "Restricted" means
+   that the source indexes are constrained to be close to the destination
+   indexes, i.e. when outputting something for time 't' we attend to a narrow
+   range of source time indexes close to 't'.
+
+   This component is just a fixed nonlinearity (albeit of a type that
+   "knows about" time, i.e. the output at time 't' depends on inputs
+   at a range of time values).  This component is not updatable; all the
+   parameters are expected to live in the previous component which is most
+   likely going to be of type NaturalGradientAffineComponent.  For a more
+   in-depth explanation, please see comments in the source of the file
+   attention.h.  Also, look at the comments for InputDim() and OutputDim() which
+   help to clarify the input and output formats.
+
+   The following are the parameters accepted on the config line, with examples
+   of their values.
+
+
+     num-heads        E.g. num-heads=10.  Defaults to 1.  Having multiple heads
+                      just means the same nonlinearity is repeated many times.
+                      InputDim() and OutputDim() are multiples of num-heads.
+     key-dim          E.g. key-dim=60.  Must be specified.  Dimension of input keys.
+     value-dim        E.g. value-dim=60.  Must be specified.  Dimension of input
+                      values (these are the things over which the component forms
+                      a weighted sum, although if output-context=true we append
+                      to the output the weights of the weighted sum, as they might
+                      also carry useful information.
+     time-stride      Stride for 't' value, e.g. 1 or 3.  For example, if time-stride=3,
+                      to compute the output for t=10 we'd use the input for time
+                      values like ... t=7, t=10, t=13, ... (the ends of
+                      this range depend on num-left-inputs and num-right-inputs).
+     num-left-inputs  Number of frames to the left of the current frame, that we
+                      use as input, e.g. 5.  (The 't' values used will be separated
+                      by 'time-stride').  num-left-inputs must be >= 0.  Must be
+                      specified.
+     num-right-inputs  Number of frames to the right of the current frame, that we
+                      use as input, e.g. 2.  Must be >= 0 and must be specified.
+                      You are not allowed to set both num-left-inputs and
+                      num-right-inputs to zero.
+     num-left-inputs-required  The number of frames to the left, that are
+                      required in order to produce an output.  Defaults to the
+                      same as num-left-inputs, but you can set it to a smaller
+                      value if you want.  We'll use zero-padding for
+                      non-required inputs that are not present in the input.  Be
+                      careful with this because it interacts with decoding
+                      settings; for non-online decoding and for dumping of egs
+                      it would be advisable to increase the extra-left-context
+                      parameter by the sum of the difference between
+                      num-left-inputs-required and num-left-inputs, although you
+                      could leave extra-left-context-initial at zero.
+     num-right-inputs-required  See num-left-inputs-required for explanation;
+                      it's the mirror image.  Defaults to num-right-inputs.
+                      However, be even more careful with the right-hand version;
+                      if you set this, online (looped) decoding will not work
+                      correctly.  It might be wiser just to reduce num-right-inputs
+                      if you care about real-time decoding.
+     key-scale        Scale on the keys (but not the added context).  Defaults to 1.0 /
+                      sqrt(key-dim), like the 1/sqrt(d_k) value in the
+                      "Attention is all you need" paper.  This helps prevent saturation
+                      of the softmax.
+     output-context  (Default: true).  If true, output the softmax that encodes which
+                     positions we chose, in addition to the input values.
+ */
+class RestrictedAttentionComponent: public Component {
+ public:
+
+  // The use of this constructor should only precede InitFromConfig()
+  RestrictedAttentionComponent() { }
+
+  // Copy constructor
+  RestrictedAttentionComponent(const RestrictedAttentionComponent &other);
+
+  virtual int32 InputDim() const {
+    // the input is interpreted as being appended blocks one for each head; each
+    // such block is interpreted as (key, value, query).
+    int32 query_dim = key_dim_ + context_dim_;
+    return num_heads_ * (key_dim_ + value_dim_ + query_dim);
+  }
+  virtual int32 OutputDim() const {
+    // the output consists of appended blocks, one for each head; each such
+    // block is is the attention weighted average of the input values, to which
+    // we append softmax encoding of the positions we chose, if output_context_
+    // == true.
+    return num_heads_ * (value_dim_ + (output_context_ ? context_dim_ : 0));
+  }
+  virtual std::string Info() const;
+  virtual void InitFromConfig(ConfigLine *cfl);
+  virtual std::string Type() const { return "RestrictedAttentionComponent"; }
+  virtual int32 Properties() const {
+    return kReordersIndexes|kBackpropNeedsInput|kPropagateAdds|kBackpropAdds|
+        kStoresStats|kUsesMemo;
+  }
+  virtual void* Propagate(const ComponentPrecomputedIndexes *indexes,
+                         const CuMatrixBase<BaseFloat> &in,
+                         CuMatrixBase<BaseFloat> *out) const;
+  virtual void StoreStats(const CuMatrixBase<BaseFloat> &in_value,
+                          const CuMatrixBase<BaseFloat> &out_value,
+                          void *memo);
+  virtual void Scale(BaseFloat scale);
+  virtual void Add(BaseFloat alpha, const Component &other);
+  virtual void ZeroStats();
+
+  virtual void Backprop(const std::string &debug_info,
+                        const ComponentPrecomputedIndexes *indexes,
+                        const CuMatrixBase<BaseFloat> &in_value,
+                        const CuMatrixBase<BaseFloat> &out_value,
+                        const CuMatrixBase<BaseFloat> &out_deriv,
+                        void *memo,
+                        Component *to_update,
+                        CuMatrixBase<BaseFloat> *in_deriv) const;
+  virtual void Read(std::istream &is, bool binary);
+  virtual void Write(std::ostream &os, bool binary) const;
+  virtual Component* Copy() const {
+    return new RestrictedAttentionComponent(*this);
+  }
+  virtual void DeleteMemo(void *memo) const { delete static_cast<Memo*>(memo); }
+
+  // Some functions that are only to be reimplemented for GeneralComponents.
+
+  // This ReorderIndexes function may insert 'blank' indexes (indexes with
+  // t == kNoTime) as well as reordering the indexes.  This is allowed
+  // behavior of ReorderIndexes functions.
+  virtual void ReorderIndexes(std::vector<Index> *input_indexes,
+                              std::vector<Index> *output_indexes) const;
+
+  virtual void GetInputIndexes(const MiscComputationInfo &misc_info,
+                               const Index &output_index,
+                               std::vector<Index> *desired_indexes) const;
+
+  // This function returns true if at least one of the input indexes used to
+  // compute this output index is computable.
+  virtual bool IsComputable(const MiscComputationInfo &misc_info,
+                            const Index &output_index,
+                            const IndexSet &input_index_set,
+                            std::vector<Index> *used_inputs) const;
+
+  virtual ComponentPrecomputedIndexes* PrecomputeIndexes(
+      const MiscComputationInfo &misc_info,
+      const std::vector<Index> &input_indexes,
+      const std::vector<Index> &output_indexes,
+      bool need_backprop) const;
+
+  class PrecomputedIndexes: public ComponentPrecomputedIndexes {
+   public:
+    PrecomputedIndexes() { }
+    PrecomputedIndexes(const PrecomputedIndexes &other):
+        io(other.io) { }
+    virtual PrecomputedIndexes *Copy() const;
+    virtual void Write(std::ostream &os, bool binary) const;
+    virtual void Read(std::istream &os, bool binary);
+    virtual std::string Type() const {
+      return "RestrictedAttentionComponentPrecomputedIndexes";
+    }
+    virtual ~PrecomputedIndexes() { }
+
+    time_height_convolution::ConvolutionComputationIo io;
+  };
+
+  // This is what's returned as the 'memo' from the Propagate() function.
+  struct Memo {
+    // c is of dimension (num_heads_ * num-output-frames) by context_dim_,
+    // where num-output-frames is the number of frames of output the
+    // corresponding Propagate function produces.
+    // Each block of 'num-output-frames' rows of c_t is the
+    // post-softmax matrix of weights.
+    CuMatrix<BaseFloat> c;
+  };
+
+ private:
+
+  // Does the propagation for one head; this is called for each
+  // head by the top-level Propagate function.  Later on we may
+  // figure out a way to avoid doing this sequentially.
+  // 'in' and 'out' are submatrices of the 'in' and 'out' passed
+  // to the top-level Propagate function, and 'c' is a submatrix
+  // of the 'c' matrix in the memo we're creating.
+  //
+  // Assumes 'c' has already been zerooed.
+  void PropagateOneHead(
+      const time_height_convolution::ConvolutionComputationIo &io,
+      const CuMatrixBase<BaseFloat> &in,
+      CuMatrixBase<BaseFloat> *c,
+      CuMatrixBase<BaseFloat> *out) const;
+
+
+  // does the backprop for one head; called by Backprop().
+  void BackpropOneHead(
+      const time_height_convolution::ConvolutionComputationIo &io,
+      const CuMatrixBase<BaseFloat> &in_value,
+      const CuMatrixBase<BaseFloat> &c,
+      const CuMatrixBase<BaseFloat> &out_deriv,
+      CuMatrixBase<BaseFloat> *in_deriv) const;
+
+  // This function, used in ReorderIndexes() and PrecomputedIndexes(),
+  // works out what grid structure over time we will have at the input
+  // and the output.
+  // Note: it may produce a grid that encompasses more than what was
+  // listed in 'input_indexes' and 'output_indexes'.  This is OK.
+  // ReorderIndexes() will add placeholders with t == kNoTime for
+  // padding, and at the input of this component those placeholders
+  // will be zero; at the output the placeholders will be ignored.
+  void GetComputationStructure(
+      const std::vector<Index> &input_indexes,
+      const std::vector<Index> &output_indexes,
+      time_height_convolution::ConvolutionComputationIo *io) const;
+
+  // This function, used in ReorderIndexes(), obtains the indexes with the
+  // correct structure and order (the structure is specified in the 'io' object.
+  // This may involve not just reordering the provided indexes, but padding them
+  // with indexes that have kNoTime as the time.
+  //
+  // Basically the indexes this function outputs form a grid where 't' has the
+  // larger stride than the (n, x) pairs.   The number of distinct (n, x) pairs
+  // should equal io.num_images.  Where 't' values need to appear in the
+  // new indexes that were not present in the old indexes, they get replaced with
+  // kNoTime.
+  void GetIndexes(
+      const std::vector<Index> &input_indexes,
+      const std::vector<Index> &output_indexes,
+      time_height_convolution::ConvolutionComputationIo &io,
+      std::vector<Index> *new_input_indexes,
+      std::vector<Index> *new_output_indexes) const;
+
+  /// Utility function used in GetIndexes().  Creates a grid of Indexes, where
+  /// 't' has the larger stride, and within each block of Indexes for a given
+  /// 't', we have the given list of (n, x) pairs.  For Indexes that we create
+  /// where the 't' value was not present in 'index_set', we set the 't' value
+  /// to kNoTime (indicating that it's only for padding, not a real input or an
+  /// output that's ever used).
+  static void CreateIndexesVector(
+      const std::vector<std::pair<int32, int32> > &n_x_pairs,
+      int32 t_start, int32 t_step, int32 num_t_values,
+      const std::unordered_set<Index, IndexHasher> &index_set,
+      std::vector<Index> *output_indexes);
+
+
+  void Check() const;
+
+  int32 num_heads_;
+  int32 key_dim_;
+  int32 value_dim_;
+  int32 num_left_inputs_;
+  int32 num_right_inputs_;
+  int32 time_stride_;
+  int32 context_dim_;  // This derived parameter equals 1 + num_left_inputs_ +
+                       // num_right_inputs_.
+  int32 num_left_inputs_required_;
+  int32 num_right_inputs_required_;
+  bool output_context_;
+  BaseFloat key_scale_;
+
+  double stats_count_;  // Count of frames corresponding to the stats.
+  Vector<double> entropy_stats_;  // entropy stats, indexed per head.
+                                  // (dimension is num_heads_).  Divide
+                                  // by stats_count_ to normalize.
+  CuMatrix<double> posterior_stats_;  // stats of posteriors of different
+                                      // offsets, of dimension num_heads_ *
+                                      // context_dim_ (num-heads has the
+                                      // larger stride).
+};
+
+
+
+
+} // namespace nnet3
+} // namespace kaldi
+
+
+#endif
index fa23d4d305ad163d8bfcb6fa26da6f3007bb8a32..8b10f78547dc8ca96a1bc8b6ae1026e40693df63 100644 (file)
@@ -598,5 +598,44 @@ bool HasContiguousProperty(
   return true;
 }
 
+
+// see comment in header.
+void GetNxList(const std::vector<Index> &indexes,
+               std::vector<std::pair<int32, int32> > *pairs) {
+  // set of (n,x) pairs
+  std::unordered_set<std::pair<int32, int32>, PairHasher<int32> > n_x_set;
+
+  for (std::vector<Index>::const_iterator iter = indexes.begin();
+       iter != indexes.end(); ++iter)
+    n_x_set.insert(std::pair<int32, int32>(iter->n, iter->x));
+  pairs->clear();
+  pairs->reserve(n_x_set.size());
+  for (std::unordered_set<std::pair<int32, int32>, PairHasher<int32> >::iterator
+           iter = n_x_set.begin(); iter != n_x_set.end(); ++iter)
+    pairs->push_back(*iter);
+  std::sort(pairs->begin(), pairs->end());
+}
+
+
+// see comment in header.
+void GetTList(const std::vector<Index> &indexes,
+              std::vector<int32> *t_values) {
+  // set of t values
+  std::unordered_set<int32> t_set;
+
+  for (std::vector<Index>::const_iterator iter = indexes.begin();
+       iter != indexes.end(); ++iter)
+    if (iter->t != kNoTime)
+      t_set.insert(iter->t);
+  t_values->clear();
+  t_values->reserve(t_set.size());
+  for (std::unordered_set<int32>::iterator iter = t_set.begin();
+       iter != t_set.end(); ++iter)
+    t_values->push_back(*iter);
+  std::sort(t_values->begin(), t_values->end());
+}
+
+
+
 }  // namespace nnet3
 }  // namespace kaldi
index 8caad5f757c1600f9a008117a401f2fdd60dd752..ee69c2028f41e506a75da5c98da506bcdcae6bf7 100644 (file)
@@ -131,6 +131,31 @@ void EnsureContiguousProperty(
     const std::vector<int32> &indexes,
     std::vector<std::vector<int32> > *indexes_out);
 
+/**
+   This function outputs a sorted, unique list of the 't' values that are
+   encountered in the provided list of Indexes
+   If 't' values equal to kNoTime are encountered, they are ignored and
+   are not output.
+*/
+void GetTList(const std::vector<Index> &indexes,
+              std::vector<int32> *t_values);
+
+
+/**
+   This function outputs a sorted, unique list of the 't' values that are
+   encountered in the provided list of Indexes
+   If 't' values equal to kNoTime are encountered, they are ignored and
+   are not output.
+*/
+void GetTList(const std::vector<Index> &indexes,
+              std::vector<int32> *t_values);
+
+/**
+   This function outputs a unique, lexicographically sorted list of the pairs of
+   (n, x) values that are encountered in the provided list of Indexes.
+*/
+void GetNxList(const std::vector<Index> &indexes,
+               std::vector<std::pair<int32, int32> > *pairs);
 
 
 } // namespace nnet3
index 574c707db108284cc2d213d3fd08f1a875d249cc..3d4e1706aa8e6ccc390eada001538bd60148f013 100644 (file)
 #include "nnet3/nnet-simple-component.h"
 #include "nnet3/nnet-general-component.h"
 #include "nnet3/nnet-convolutional-component.h"
+#include "nnet3/nnet-attention-component.h"
 #include "nnet3/nnet-parse.h"
 #include "nnet3/nnet-computation-graph.h"
 
 
+
 // \file This file contains some more-generic component code: things in base classes.
 //       See nnet-component.cc for the code of the actual Components.
 
@@ -61,6 +63,8 @@ ComponentPrecomputedIndexes* ComponentPrecomputedIndexes::NewComponentPrecompute
     ans = new BackpropTruncationComponentPrecomputedIndexes();
   } else if (cpi_type == "TimeHeightConvolutionComponentPrecomputedIndexes") {
     ans = new TimeHeightConvolutionComponent::PrecomputedIndexes();
+  } else if (cpi_type == "RestrictedAttentionComponentPrecomputedIndexes") {
+    ans = new RestrictedAttentionComponent::PrecomputedIndexes();
   }
   if (ans != NULL) {
     KALDI_ASSERT(cpi_type == ans->Type());
@@ -159,6 +163,8 @@ Component* Component::NewComponentOfType(const std::string &component_type) {
     ans = new BatchNormComponent();
   } else if (component_type == "TimeHeightConvolutionComponent") {
     ans = new TimeHeightConvolutionComponent();
+  } else if (component_type == "RestrictedAttentionComponent") {
+    ans = new RestrictedAttentionComponent();
   } else if (component_type == "SumBlockComponent") {
     ans = new SumBlockComponent();
   }
index 4072ac56eea8868d4e1257c6cbfb8cd0bc191c1b..cd9f87abe309ee59f7e2a63e91d066b9e4a75cf1 100644 (file)
@@ -125,9 +125,13 @@ void TestNnetDecodable(Nnet *nnet) {
   }
 
 
+  // the components that we exclude from this test, are excluded because they
+  // all take "optional" right context, and this destroys the equivalence that
+  // we are testing.
   if (!NnetIsRecurrent(*nnet) &&
       nnet->Info().find("statistics-extraction") == std::string::npos &&
-      nnet->Info().find("TimeHeightConvolutionComponent") == std::string::npos) {
+      nnet->Info().find("TimeHeightConvolutionComponent") == std::string::npos &&
+      nnet->Info().find("RestrictedAttentionComponent") == std::string::npos) {
     // this equivalence will not hold for recurrent nnets, or those that
     // have the statistics-extraction/statistics-pooling layers,
     // or in general for nnets with convolution components (because these
index 524c39807122a4c6a4a307114f43482ee7c48f9b..19eb4c845193d62d8900e5552a3f82cbc2effe64 100644 (file)
@@ -49,7 +49,7 @@ TimeHeightConvolutionComponent::TimeHeightConvolutionComponent(
 }
 
 
-void TimeHeightConvolutionComponent::Check() {
+void TimeHeightConvolutionComponent::Check() const {
   model_.Check();
   KALDI_ASSERT(bias_params_.Dim() == model_.num_filters_out &&
                linear_params_.NumRows() == model_.ParamRows() &&
index 60c92036c31d9ecbad3edde162b470fb3603355a..cceab937b30dc9edf0777e99b7d14578f48638db 100644 (file)
@@ -235,12 +235,6 @@ class TimeHeightConvolutionComponent: public UpdatableComponent {
                         void *memo,
                         Component *to_update,
                         CuMatrixBase<BaseFloat> *in_deriv) const;
-  // This ReorderIndexes function may insert 'blank' indexes (indexes with
-  // t == kNoTime) as well as reordering the indexes.  This is allowed
-  // behavior of ReorderIndexes functions.
-  virtual void ReorderIndexes(std::vector<Index> *input_indexes,
-                              std::vector<Index> *output_indexes) const;
-
 
   virtual void Read(std::istream &is, bool binary);
   virtual void Write(std::ostream &os, bool binary) const;
@@ -250,6 +244,13 @@ class TimeHeightConvolutionComponent: public UpdatableComponent {
 
 
   // Some functions that are only to be reimplemented for GeneralComponents.
+
+  // This ReorderIndexes function may insert 'blank' indexes (indexes with
+  // t == kNoTime) as well as reordering the indexes.  This is allowed
+  // behavior of ReorderIndexes functions.
+  virtual void ReorderIndexes(std::vector<Index> *input_indexes,
+                              std::vector<Index> *output_indexes) const;
+
   virtual void GetInputIndexes(const MiscComputationInfo &misc_info,
                                const Index &output_index,
                                std::vector<Index> *desired_indexes) const;
@@ -297,7 +298,7 @@ class TimeHeightConvolutionComponent: public UpdatableComponent {
   void ScaleLinearParams(BaseFloat alpha) { linear_params_.Scale(alpha); }
  private:
 
-  void Check();
+  void Check() const;
 
   // computes derived parameters required_time_offsets_ and all_time_offsets_.
   void ComputeDerived();
index a138fcaccebcf96aef2962b68a87cc2a4b98c41f..47a040a1789cf19936e577430aa216023b019602 100644 (file)
@@ -1076,6 +1076,63 @@ void GenerateConfigSequenceCnnNew(
 }
 
 
+
+void GenerateConfigSequenceRestrictedAttention(
+    const NnetGenerationOptions &opts,
+    std::vector<std::string> *configs) {
+  std::ostringstream ss;
+
+
+  int32 input_dim = RandInt(100, 150),
+      num_heads = RandInt(1, 2),
+      key_dim = RandInt(20, 40),
+      value_dim = RandInt(20, 40),
+      time_stride = RandInt(1, 3),
+      num_left_inputs = RandInt(1, 4),
+      num_right_inputs = RandInt(0, 2),
+      num_left_inputs_required = RandInt(0, num_left_inputs),
+      num_right_inputs_required = RandInt(0, num_right_inputs);
+  bool output_context = (RandInt(0, 1) == 0);
+  int32 context_dim = (num_left_inputs + 1 + num_right_inputs),
+      query_dim = key_dim + context_dim;
+  int32 attention_input_dim = num_heads * (key_dim + value_dim + query_dim);
+
+  std::string cur_layer_descriptor = "input";
+
+  { // input layer.
+    ss << "input-node name=input dim=" << input_dim
+       << std::endl;
+  }
+
+  { // affine component
+    ss << "component name=affine type=NaturalGradientAffineComponent input-dim="
+       << input_dim << " output-dim=" << attention_input_dim << std::endl;
+    ss << "component-node name=affine component=affine input=input"
+       << std::endl;
+  }
+
+  { // attention component
+    ss << "component-node name=attention component=attention input=affine"
+       << std::endl;
+    ss << "component name=attention type=RestrictedAttentionComponent"
+       << " num-heads=" << num_heads << " key-dim=" << key_dim
+       << " value-dim=" << value_dim << " time-stride=" << time_stride
+       << " num-left-inputs=" << num_left_inputs << " num-right-inputs="
+       << num_right_inputs << " num-left-inputs-required="
+       << num_left_inputs_required << " num-right-inputs-required="
+       << num_right_inputs_required
+       << " output-context=" << (output_context ? "true" : "false")
+       << (RandInt(0, 1) == 0 ? " key-scale=1.0" : "")
+       << std::endl;
+  }
+
+  { // output
+    ss << "output-node name=output input=attention" << std::endl;
+  }
+  configs->push_back(ss.str());
+}
+
+
 // generates a config sequence involving DistributeComponent.
 void GenerateConfigSequenceDistribute(
     const NnetGenerationOptions &opts,
@@ -1212,11 +1269,16 @@ start:
       // We're allocating more case statements to the most recently
       // added type of model, to give more thorough testing where
       // it's needed most.
-    case 12: case 13: case 14:
+    case 12:
       if (!opts.allow_nonlinearity || !opts.allow_context)
         goto start;
       GenerateConfigSequenceCnnNew(opts, configs);
       break;
+    case 13: case 14:
+      if (!opts.allow_nonlinearity || !opts.allow_context)
+        goto start;
+      GenerateConfigSequenceRestrictedAttention(opts, configs);
+      break;
     default:
       KALDI_ERR << "Error generating config sequence.";
   }
index 0019fbef414b9c50483b644ff317eaa37395bb88..4b230b7fdb967a9e5f5d84289fdda39093e85f3d 100644 (file)
@@ -155,7 +155,8 @@ void ComputeSimpleNnetContext(const Nnet &nnet,
 
   // This will crash if the total context (left + right) is greater
   // than window_size.
-  int32 window_size = 150;
+  int32 window_size = 200;
+
   // by going "<= modulus" instead of "< modulus" we do one more computation
   // than we really need; it becomes a sanity check.
   for (int32 input_start = 0; input_start <= modulus; input_start++)