]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - processor-sdk/kaldi.git/commitdiff
chain branch: various changes to get decoding working; various mostly minor fixes...
authorDaniel Povey <dpovey@gmail.com>
Sun, 22 Nov 2015 01:05:05 +0000 (20:05 -0500)
committerDaniel Povey <dpovey@gmail.com>
Sun, 22 Nov 2015 01:05:05 +0000 (20:05 -0500)
14 files changed:
egs/swbd/s5c/local/chain/run_tdnn_a.sh
egs/swbd/s5c/local/score_sclite.sh
egs/wsj/s5/steps/nnet3/chain/train_tdnn.sh
egs/wsj/s5/steps/nnet3/decode.sh
src/chain/chain-den-graph.cc
src/chain/chain-kernels.cu
src/chain/chain-supervision.cc
src/chain/language-model.cc
src/chainbin/chain-est-phone-lm.cc
src/lat/word-align-lattice.cc
src/nnet3/nnet-am-decodable-simple.cc
src/nnet3/nnet-am-decodable-simple.h
src/nnet3/nnet-cctc-decodable-simple.cc
src/nnet3/nnet-simple-component.cc

index 3f89db075a653b7c9f234aba530e7fc3366f9543..da68a6bad67d4c48bb159957caa42077b851d728 100755 (executable)
@@ -119,7 +119,8 @@ fi
 
 if [ $stage -le 13 ]; then
   # Note: it might appear that this $lang directory is mismatched, and it is as
-  # far as the 'topo'
+  # far as the 'topo' is concerned, but this script doesn't read the 'topo' from
+  # the lang directory.
   utils/mkgraph.sh --transition-scale 0.0 \
       --self-loop-scale 0.0 data/lang_sw1_tg $dir $dir/graph_sw1_tg
 fi
@@ -129,7 +130,7 @@ graph_dir=$dir/graph_sw1_tg
 if [ $stage -le 14 ]; then
   for decode_set in train_dev eval2000; do
       (
-      steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 --iter 298_cached \
+      steps/nnet3/decode.sh --acwt 1.0 --post-decode-acwt 10.0 \
           --nj 50 --cmd "$decode_cmd" \
           --online-ivector-dir exp/nnet3/ivectors_${decode_set} \
          $graph_dir data/${decode_set}_hires $dir/decode_${decode_set}_${decode_suff} || exit 1;
@@ -138,7 +139,7 @@ if [ $stage -le 14 ]; then
             data/lang_sw1_{tg,fsh_fg} data/${decode_set}_hires \
             $dir/decode_${decode_set}_sw1_{tg,fsh_fg} || exit 1;
       fi
-      ) &
+      ) &
   done
 fi
 wait;
index 2031e472fe534d9a02db05b18f2c79c8ec503ec4..556afbf02bc11892fc812a7b14cceec6faba90a9 100755 (executable)
@@ -51,18 +51,16 @@ nnet3-ctc-info --print-args=false $model  1>/dev/null 2>&1;
 [ $? -eq 0 ] && is_ctc=true;
 [ -z $is_ctc ] && echo "Unknown model type, verify if $model exists" && exit -1;
 align_word=
-reorder=
+reorder_opt=
 if $reverse; then
   align_word="lattice-reverse ark:- ark:- |"
-  reorder="--reorder=false"
+  reorder_opt="--reorder=false"
 fi
 
 if $is_ctc ; then
-  echo "Warning : This is a CTC model, using corresponding scoring pipeline."
+  echo "Warning : This is a 'chain' model, using corresponding scoring pipeline."
   factor=$(cat $dir/../frame_subsampling_factor) || exit 1
   frame_shift_opt="--frame-shift=0.0$factor"
-else
-  align_word="$align_word lattice-align-words $reorder $lang/phones/word_boundary.int $model ark:- ark:- |"
 fi
 
 name=`basename $data`; # e.g. eval2000
@@ -75,7 +73,8 @@ if [ $stage -le 0 ]; then
       mkdir -p $dir/score_LMWT_${wip}/ '&&' \
       lattice-scale --lm-scale=LMWT "ark:gunzip -c $dir/lat.*.gz|" ark:- \| \
       lattice-add-penalty --word-ins-penalty=$wip ark:- ark:- \| \
-      lattice-1best ark:- ark:- \| $align_word \
+      lattice-1best ark:- ark:- \| \
+      lattice-align-words $reorder_opt $lang/phones/word_boundary.int $model ark:- ark:- \| \
       nbest-to-ctm $frame_shift_opt ark:- - \| \
       utils/int2sym.pl -f 5 $lang/words.txt  \| \
       utils/convert_ctm.pl $data/segments $data/reco2file_and_channel \
index 0f24bebdc996a2dabf583c7f8a44d516c534b851..0ef570209704e8b987a23f2c3cce7b0b313ab8dc 100755 (executable)
@@ -182,7 +182,7 @@ if [ $stage -le -6 ]; then
   echo "$0: creating denominator FST"
   copy-transition-model $treedir/final.mdl $dir/0.trans_mdl
   $cmd $dir/log/make_den_fst.log \
-    chain-make-den-graph $dir/tree $dir/0.trans_mdl $dir/phone_lm.fst \
+    chain-make-den-fst $dir/tree $dir/0.trans_mdl $dir/phone_lm.fst \
        $dir/den.fst $dir/normalization.fst || exit 1;
 fi
 
index 17133fc88de7de2d3a523b941e01815ff2e791c6..f4de09740ae7d14937cd28b2fa5511a73d048a4b 100755 (executable)
@@ -128,7 +128,7 @@ fi
 
 if [ ! -z "$online_ivector_dir" ]; then
   ivector_period=$(cat $online_ivector_dir/ivector_period) || exit 1;
-  ivector_opts="--online-ivectors=scp:$online_ivector_dir/ivector_online.scp --online-ivector_period=$ivector_period"
+  ivector_opts="--online-ivectors=scp:$online_ivector_dir/ivector_online.scp --online-ivector-period=$ivector_period"
 fi
 
 if [ "$post_decode_acwt" == 1.0 ]; then
@@ -137,9 +137,14 @@ else
   lat_wspecifier="ark:|lattice-scale --acoustic-scale=$post_decode_acwt ark:- ark:- | gzip -c >$dir/lat.JOB.gz"
 fi
 
+if [ -f $srcdir/frame_subsampling_factor ]; then
+  # e.g. for 'chain' systems
+  frame_subsampling_opt="--frame-subsampling-factor=$(cat $srcdir/frame_subsampling_factor)"
+fi
+
 if [ $stage -le 1 ]; then
   $cmd --num-threads $num_threads JOB=1:$nj $dir/log/decode.JOB.log \
-    nnet3-latgen-faster$thread_string $ivector_opts \
+    nnet3-latgen-faster$thread_string $ivector_opts $frame_subsampling_opt \
      --frames-per-chunk=$frames_per_chunk \
      --minimize=$minimize --max-active=$max_active --min-active=$min_active --beam=$beam \
      --lattice-beam=$lattice_beam --acoustic-scale=$acwt --allow-partial=true \
index 68a7c88e8888b4153270e763baf593cfb999e16c..1a8b6219a41a689914c3d9b61e79b69fec65fb9c 100644 (file)
@@ -291,7 +291,6 @@ void CreateDenominatorFst(const ContextDependency &ctx_dep,
   KALDI_LOG << "Number of states and arcs in phone-LM FST is "
             << phone_lm.NumStates() << " and " << NumArcs(phone_lm);
 
-
   int32 subsequential_symbol = trans_model.GetPhones().back() + 1;
   if (ctx_dep.CentralPosition() != ctx_dep.ContextWidth() - 1) {
     // note: this function only adds the subseq symbol to the input of what was
@@ -316,14 +315,14 @@ void CreateDenominatorFst(const ContextDependency &ctx_dep,
 
   std::vector<int32> disambig_syms_h; // disambiguation symbols on input side
   // of H -- will be empty.
-  HTransducerConfig h_cfg;
-  h_cfg.transition_scale = 0.0;  // we don't want transition probs.
-  h_cfg.push_weights = false;  // there's nothing to push.
+  HTransducerConfig h_config;
+  h_config.transition_scale = 0.0;  // we don't want transition probs.
+  h_config.push_weights = false;  // there's nothing to push.
 
   StdVectorFst *h_fst = GetHTransducer(cfst.ILabelInfo(),
                                        ctx_dep,
                                        trans_model,
-                                       h_cfg,
+                                       h_config,
                                        &disambig_syms_h);
   KALDI_ASSERT(disambig_syms_h.empty());
   StdVectorFst transition_id_fst;
index 2330f2cc31587ef4b79cfdd57202427cdd709aa1..ca7c8faa7922f8454f4ca3b064dc52131c358f10 100644 (file)
@@ -40,9 +40,9 @@ __device__ inline void atomic_add_thresholded(Real* address, Real value) {
   // threshold itself with probability (value / threshold).  This preserves
   // expectations.  Note: we assume that value >= 0.
 
-  // you can chose any value for the threshold, but powers of 2 are nice
+  // you can choose any value for the threshold, but powers of 2 are nice
   // because they will exactly preserve the precision of the value.
-  const Real threshold = 1.0 / (1 << 16);
+  const Real threshold = 1.0 / (1 << 14);
   if (value >= threshold) {
     atomic_add(address, value);
   } else {
index ab3b3e2f3d7002c7f1250eba333fe89d12a3360a..d6d2412d568ddb916b47c86a77db3b21dcb9be95 100644 (file)
@@ -58,7 +58,7 @@ bool AlignmentToProtoSupervision(const SupervisionOptions &opts,
   std::vector<int32> labels(phones.size());
   int32 num_frames = std::accumulate(durations.begin(), durations.end(), 0),
       factor = opts.frame_subsampling_factor,
-      num_frames_subsampled = num_frames / factor;
+      num_frames_subsampled = (num_frames + factor - 1) / factor;
   proto_supervision->allowed_phones.clear();
   proto_supervision->allowed_phones.resize(num_frames_subsampled);
   proto_supervision->fst.DeleteStates();
@@ -69,14 +69,15 @@ bool AlignmentToProtoSupervision(const SupervisionOptions &opts,
   for (int32 i = 0; i < num_phones; i++) {
     int32 phone = phones[i], duration = durations[i];
     KALDI_ASSERT(phone > 0 && duration > 0);
-    int32 t_start_subsampled =
-        std::max<int32>(0,
-                        (current_frame - opts.left_tolerance) / factor),
-        t_end_subsampled = std::min<int32>(
-            num_frames_subsampled,
-            (current_frame + duration + opts.right_tolerance) / factor);
+    int32 t_start = std::max<int32>(0, (current_frame - opts.left_tolerance)),
+            t_end = std::min<int32>(num_frames,
+                                    (current_frame + duration + opts.right_tolerance)),
+       t_start_subsampled = (t_start + factor - 1) / factor,
+       t_end_subsampled = (t_end + factor - 1) / factor;
+
     // note: if opts.Check() passed, the following assert should pass too.
-    KALDI_ASSERT(t_end_subsampled > t_start_subsampled);
+    KALDI_ASSERT(t_end_subsampled > t_start_subsampled &&
+                 t_end_subsampled <= num_frames_subsampled);
     for (int32 t_subsampled = t_start_subsampled;
          t_subsampled < t_end_subsampled; t_subsampled++)
       proto_supervision->allowed_phones[t_subsampled].push_back(phone);
@@ -127,13 +128,7 @@ bool PhoneLatticeToProtoSupervision(const SupervisionOptions &opts,
   std::vector<int32> state_times;
   int32 num_frames = CompactLatticeStateTimes(lat, &state_times),
       factor = opts.frame_subsampling_factor,
-      num_frames_subsampled = num_frames / factor;
-  if (num_frames < opts.frame_subsampling_factor) {
-    KALDI_WARN << "Number of frames in lattice " << num_frames
-               << " is less than --frame-subsampling-factor="
-               << opts.frame_subsampling_factor;
-    return false;
-  }
+    num_frames_subsampled = (num_frames + factor - 1) / factor;
   for (int32 state = 0; state < num_states; state++)
     proto_supervision->fst.AddState();
   proto_supervision->fst.SetStart(lat.Start());
@@ -156,12 +151,11 @@ bool PhoneLatticeToProtoSupervision(const SupervisionOptions &opts,
                                     fst::StdArc(phone, phone,
                                                 fst::TropicalWeight::One(),
                                                 lat_arc.nextstate));
-      int32 t_begin_subsampled =
-          std::max<int32>(0,
-                          (state_time - opts.left_tolerance) / factor),
-          t_end_subsampled = std::min<int32>(
-              num_frames_subsampled,
-              (next_state_time + opts.right_tolerance) / factor);
+      int32 t_begin = std::max<int32>(0, (state_time - opts.left_tolerance)),
+              t_end = std::min<int32>(num_frames,
+                                      (next_state_time + opts.right_tolerance)),
+ t_begin_subsampled = (t_begin + factor - 1)/ factor,
+   t_end_subsampled = (t_end + factor - 1)/ factor;
     for (int32 t_subsampled = t_begin_subsampled;
          t_subsampled < t_end_subsampled; t_subsampled++)
       proto_supervision->allowed_phones[t_subsampled].push_back(phone);
index 90c2fa8d90033d7c0d762fc02e5da20f3488d6e2..61082a1c659e11f54bb4d60eeb0fc5c8a93006d0 100644 (file)
@@ -293,6 +293,7 @@ void LanguageModelEstimator::OutputToFst(
   int64 tot_den = std::accumulate(den_counts.begin(),
                                   den_counts.end(), 0),
       tot_num = 0;  // for self-testing code.
+  double tot_logprob = 0.0;
 
   PairMapType::const_iterator
       iter = num_counts.begin(), end = num_counts.end();
@@ -304,6 +305,7 @@ void LanguageModelEstimator::OutputToFst(
     int32 den_count = den_counts[this_state];
     KALDI_ASSERT(den_count >= num_count);
     BaseFloat prob = num_count / static_cast<BaseFloat>(den_count);
+    tot_logprob += num_count * log(prob);
     if (phone > 0) {
       // it's a real phone.  find out where the transition is to.
       PairMapType::const_iterator
@@ -320,9 +322,15 @@ void LanguageModelEstimator::OutputToFst(
   }
   KALDI_ASSERT(tot_num == tot_den);
   KALDI_LOG << "Total number of phone instances seen was " << tot_num;
+  BaseFloat perplexity = exp(-(tot_logprob / tot_num));
+  KALDI_LOG << "Perplexity on training data is: " << perplexity;
+  KALDI_LOG << "Note: perplexity on unseen data will be infinity as there is "
+            << "no smoothing.  This is by design, to reduce the number of arcs.";
   fst::Connect(fst);
   // Make sure that Connect does not delete any states.
   KALDI_ASSERT(fst->NumStates() == num_states);
+  // arc-sort.  ilabel or olabel doesn't matter, it's an acceptor.
+  fst::ArcSort(fst, fst::ILabelCompare<fst::StdArc>());
   KALDI_LOG << "Created phone language model with " << num_states << " states.";
 }
 
index b5936501d87e07861b5570c96acd3e42590cd7cb..0d5380649830182b3890ec5678a67df95c65fd63 100644 (file)
@@ -64,6 +64,7 @@ int main(int argc, char *argv[]) {
       const std::vector<int32> &phone_seq = phones_reader.Value();
       lm_estimator.AddCounts(phone_seq);
     }
+    KALDI_LOG << "Estimating phone LM";
     fst::StdVectorFst fst;
     lm_estimator.Estimate(&fst);
 
index 51886e810f80657a9710023f7404d203d5b05444..db22c0c85e47941882c46e47c12b37406e203e0c 100644 (file)
@@ -28,12 +28,12 @@ class LatticeWordAligner {
  public:
   typedef CompactLatticeArc::StateId StateId;
   typedef CompactLatticeArc::Label Label;
-  
+
   class ComputationState { /// The state of the computation in which,
     /// along a single path in the lattice, we work out the word
     /// boundaries and output aligned arcs.
    public:
-    
+
     /// Advance the computation state by adding the symbols and weights
     /// from this arc.  We'll put the weight on the output arc; this helps
     /// keep the state-space smaller.
@@ -71,18 +71,18 @@ class LatticeWordAligner {
     bool OutputSilenceArc(const WordBoundaryInfo &info,
                           const TransitionModel &tmodel,
                           CompactLatticeArc *arc_out,
-                          bool *error); 
+                          bool *error);
     bool OutputOnePhoneWordArc(const WordBoundaryInfo &info,
                                const TransitionModel &tmodel,
                                CompactLatticeArc *arc_out,
-                               bool *error); 
+                               bool *error);
     bool OutputNormalWordArc(const WordBoundaryInfo &info,
                              const TransitionModel &tmodel,
                              CompactLatticeArc *arc_out,
                              bool *error);
-    
+
     bool IsEmpty() { return (transition_ids_.empty() && word_labels_.empty()); }
-    
+
     /// FinalWeight() will return "weight" if both transition_ids
     /// and word_labels are empty, otherwise it will return
     /// Weight::Zero().
@@ -104,7 +104,7 @@ class LatticeWordAligner {
                         const TransitionModel &tmodel,
                         CompactLatticeArc *arc_out,
                         bool *error);
-  
+
     size_t Hash() const {
       VectorHasher<int32> vh;
       return vh(transition_ids_) + 90647 * vh(word_labels_);
@@ -121,7 +121,7 @@ class LatticeWordAligner {
               && word_labels_ == other.word_labels_
               && weight_ == other.weight_);
     }
-    
+
     ComputationState(): weight_(LatticeWeight::One()) { } // initial state.
     ComputationState(const ComputationState &other):
         transition_ids_(other.transition_ids_), word_labels_(other.word_labels_),
@@ -143,7 +143,7 @@ class LatticeWordAligner {
   struct TupleHash {
     size_t operator() (const Tuple &state) const {
       return state.input_state + 102763 * state.comp_state.Hash();
-      // 102763 is just an arbitrary prime number 
+      // 102763 is just an arbitrary prime number
     }
   };
   struct TupleEqual {
@@ -153,7 +153,7 @@ class LatticeWordAligner {
               && state1.comp_state == state2.comp_state);
     }
   };
-  
+
   typedef unordered_map<Tuple, StateId, TupleHash, TupleEqual> MapType;
 
   StateId GetStateForTuple(const Tuple &tuple, bool add_to_queue) {
@@ -168,17 +168,17 @@ class LatticeWordAligner {
       return iter->second;
     }
   }
-  
+
   void ProcessFinal(Tuple tuple, StateId output_state) {
     // ProcessFinal is only called if the input_state has
     // final-prob of One().  [else it should be zero.  This
     // is because we called CreateSuperFinal().]
-    
+
     if (tuple.comp_state.IsEmpty()) { // computation state doesn't have
       // anything pending.
       std::vector<int32> empty_vec;
       CompactLatticeWeight cw(tuple.comp_state.FinalWeight(), empty_vec);
-      lat_out_->SetFinal(output_state, Plus(lat_out_->Final(output_state), cw));      
+      lat_out_->SetFinal(output_state, Plus(lat_out_->Final(output_state), cw));
     } else {
       // computation state has something pending, i.e. input or
       // output symbols that need to be flushed out.  Note: OutputArc() would
@@ -197,7 +197,7 @@ class LatticeWordAligner {
     }
   }
 
-  
+
   void ProcessQueueElement() {
     KALDI_ASSERT(!queue_.empty());
     Tuple tuple = queue_.back().first;
@@ -248,7 +248,7 @@ class LatticeWordAligner {
       }
     }
   }
-  
+
   LatticeWordAligner(const CompactLattice &lat,
                      const TransitionModel &tmodel,
                      const WordBoundaryInfo &info,
@@ -266,7 +266,7 @@ class LatticeWordAligner {
     }
     fst::CreateSuperFinal(&lat_); // Creates a super-final state, so the
     // only final-probs are One().
-    
+
     // Inside this class, we don't want to use zero for the silence
     // or partial-word labels, as this will interfere with the RmEpsilon
     // stage, where we don't want the arcs corresponding to silence or
@@ -296,10 +296,10 @@ class LatticeWordAligner {
       syms_to_remove.push_back(info_.silence_label);
     if (!syms_to_remove.empty()) {
       RemoveSomeInputSymbols(syms_to_remove, lat_out_);
-      Project(lat_out_, fst::PROJECT_INPUT);      
+      Project(lat_out_, fst::PROJECT_INPUT);
     }
   }
-  
+
   bool AlignLattice() {
     lat_out_->DeleteStates();
     if (lat_.Start() == fst::kNoStateId) {
@@ -310,7 +310,7 @@ class LatticeWordAligner {
     Tuple initial_tuple(lat_.Start(), initial_comp_state);
     StateId start_state = GetStateForTuple(initial_tuple, true); // True = add this to queue.
     lat_out_->SetStart(start_state);
-    
+
     while (!queue_.empty()) {
       if (max_states_ > 0 && lat_out_->NumStates() > max_states_) {
         KALDI_WARN << "Number of states in lattice exceeded max-states of "
@@ -323,10 +323,10 @@ class LatticeWordAligner {
     }
 
     RemoveEpsilonsFromLattice();
-    
+
     return !error_;
   }
-  
+
   CompactLattice lat_;
   const TransitionModel &tmodel_;
   const WordBoundaryInfo &info_in_;
@@ -335,12 +335,12 @@ class LatticeWordAligner {
   CompactLattice *lat_out_;
 
   std::vector<std::pair<Tuple, StateId> > queue_;
-  
-  
-  
+
+
+
   MapType map_; // map from tuples to StateId.
   bool error_;
-  
+
 };
 
 bool LatticeWordAligner::ComputationState::OutputSilenceArc(
@@ -355,7 +355,7 @@ bool LatticeWordAligner::ComputationState::OutputSilenceArc(
   size_t len = transition_ids_.size(), i;
   // Keep going till we reach a "final" transition-id; note, if
   // reorder==true, we have to go a bit further after this.
-  for (i = 1; i < len; i++) {
+  for (i = 0; i < len; i++) {
     int32 tid = transition_ids_[i];
     int32 this_phone = tmodel.TransitionIdToPhone(tid);
     if (this_phone != phone && ! *error) { // error condition: should have reached final transition-id first.
@@ -379,7 +379,7 @@ bool LatticeWordAligner::ComputationState::OutputSilenceArc(
   }
   // interpret i as the number of transition-ids to consume.
   std::vector<int32> tids_out(transition_ids_.begin(), transition_ids_.begin()+i);
-  
+
   // consumed transition ids from our internal state.
   *arc_out = CompactLatticeArc(info.silence_label, info.silence_label,
                                CompactLatticeWeight(weight_, tids_out), fst::kNoStateId);
@@ -396,11 +396,11 @@ bool LatticeWordAligner::ComputationState::OutputOnePhoneWordArc(
   if (word_labels_.empty()) return false;
   int32 phone = tmodel.TransitionIdToPhone(transition_ids_[0]);
   if (info.TypeOfPhone(phone) != WordBoundaryInfo::kWordBeginAndEndPhone)
-    return false;  
+    return false;
   // we assume the start of transition_ids_ is the start of the phone.
   // this is a precondition.
   size_t len = transition_ids_.size(), i;
-  for (i = 1; i < len; i++) {
+  for (i = 0; i < len; i++) {
     int32 tid = transition_ids_[i];
     int32 this_phone = tmodel.TransitionIdToPhone(tid);
     if (this_phone != phone && ! *error) { // error condition: should have reached final transition-id first.
@@ -416,17 +416,17 @@ bool LatticeWordAligner::ComputationState::OutputOnePhoneWordArc(
   if (info.reorder) // we have to consume the following self-loop transition-ids.
     while (i < len && tmodel.IsSelfLoop(transition_ids_[i])) i++;
   if (i == len) return false; // we don't know if it ends here... so can't output arc.
-  
+
   if (tmodel.TransitionIdToPhone(transition_ids_[i-1]) != phone
       && ! *error) { // another check.
     KALDI_WARN << "Phone changed unexpectedly in lattice "
         "[broken lattice or mismatched model?]";
     *error = true;
   }
-  
+
   // interpret i as the number of transition-ids to consume.
   std::vector<int32> tids_out(transition_ids_.begin(), transition_ids_.begin()+i);
-  
+
   // consumed transition ids from our internal state.
   int32 word = word_labels_[0];
   *arc_out = CompactLatticeArc(word, word,
@@ -447,7 +447,7 @@ bool LatticeWordAligner::ComputationState::OutputNormalWordArc(
   if (word_labels_.empty()) return false;
   int32 begin_phone = tmodel.TransitionIdToPhone(transition_ids_[0]);
   if (info.TypeOfPhone(begin_phone) != WordBoundaryInfo::kWordBeginPhone)
-    return false;  
+    return false;
   // we assume the start of transition_ids_ is the start of the phone.
   // this is a precondition.
   size_t len = transition_ids_.size(), i;
@@ -488,7 +488,7 @@ bool LatticeWordAligner::ComputationState::OutputNormalWordArc(
   // a "final-transition".
 
   // this variable just used for checks.
-  int32 final_phone = tmodel.TransitionIdToPhone(transition_ids_[i]); 
+  int32 final_phone = tmodel.TransitionIdToPhone(transition_ids_[i]);
   for (; i < len; i++) {
     int32 this_phone = tmodel.TransitionIdToPhone(transition_ids_[i]);
     if (this_phone != final_phone && ! *error) {
@@ -515,7 +515,7 @@ bool LatticeWordAligner::ComputationState::OutputNormalWordArc(
   // OK, we're ready to output the word.
   // Interpret i as the number of transition-ids to consume.
   std::vector<int32> tids_out(transition_ids_.begin(), transition_ids_.begin()+i);
-  
+
   // consumed transition ids from our internal state.
   int32 word = word_labels_[0];
   *arc_out = CompactLatticeArc(word, word,
@@ -550,7 +550,7 @@ static bool IsPlausibleWord(const WordBoundaryInfo &info,
   } else return false;
 }
 
-    
+
 void LatticeWordAligner::ComputationState::OutputArcForce(
     const WordBoundaryInfo &info, const TransitionModel &tmodel,
     CompactLatticeArc *arc_out,  bool *error) {
@@ -560,7 +560,7 @@ void LatticeWordAligner::ComputationState::OutputArcForce(
       && !transition_ids_.empty()) { // We have at least one word to
     // output, and some transition-ids.  We assume that the normal OutputArc was called
     // and failed, so this means we didn't see the end of that
-    // word. 
+    // word.
     int32 word = word_labels_[0];
     if (! *error && !IsPlausibleWord(info, tmodel, transition_ids_)) {
       *error = true;
@@ -686,7 +686,7 @@ WordBoundaryInfo::WordBoundaryInfo(const WordBoundaryInfoNewOpts &opts,
 void WordBoundaryInfo::Init(std::istream &stream) {
   std::string line;
   while (std::getline(stream, line)) {
-    std::vector<std::string> split_line;  
+    std::vector<std::string> split_line;
     SplitStringToVector(line, " \t\r", true, &split_line);// split the line by space or tab
     int32 p = 0;
     if (split_line.size() != 2 ||
@@ -701,13 +701,13 @@ void WordBoundaryInfo::Init(std::istream &stream) {
     else if (t == "singleton") phone_to_type[p] = kWordBeginAndEndPhone;
     else if (t == "end") phone_to_type[p] = kWordEndPhone;
     else if (t == "internal") phone_to_type[p] = kWordInternalPhone;
-    else 
+    else
       KALDI_ERR << "Invalid line in word-boundary file: " << line;
   }
   if (phone_to_type.empty())
     KALDI_ERR << "Empty word-boundary file";
 }
-  
+
 bool WordAlignLattice(const CompactLattice &lat,
                       const TransitionModel &tmodel,
                       const WordBoundaryInfo &info,
@@ -726,7 +726,7 @@ class WordAlignedLatticeTester {
                            const WordBoundaryInfo &info,
                            const CompactLattice &aligned_lat):
       lat_(lat), tmodel_(tmodel), info_(info), aligned_lat_(aligned_lat) { }
-  
+
   void Test() {
     // First test that each aligned arc is valid.
     typedef CompactLattice::StateId StateId ;
@@ -766,7 +766,7 @@ class WordAlignedLatticeTester {
       return false;
     for (size_t i = 0; i < tids.size(); i++)
       if (tmodel_.TransitionIdToPhone(tids[i]) != first_phone) return false;
-      
+
     if (!info_.reorder) return tmodel_.IsFinal(tids.back());
     else {
       for (size_t i = 0; i < tids.size(); i++) {
@@ -794,7 +794,7 @@ class WordAlignedLatticeTester {
         WordBoundaryInfo::kWordBeginAndEndPhone) return false;
     for (size_t i = 0; i < tids.size(); i++)
       if (tmodel_.TransitionIdToPhone(tids[i]) != first_phone) return false;
-      
+
     if (!info_.reorder) return tmodel_.IsFinal(tids.back());
     else {
       for (size_t i = 0; i < tids.size(); i++) {
@@ -871,7 +871,7 @@ class WordAlignedLatticeTester {
     if (tids.empty()) return false;
     return true; // We're pretty liberal when it comes to partial words here.
   }
-  
+
   void TestFinal(const CompactLatticeWeight &w) {
     if (!w.String().empty())
       KALDI_ERR << "Expect to have no strings on final-weights of lattices.";
@@ -890,14 +890,14 @@ class WordAlignedLatticeTester {
       KALDI_ERR << "Equivalence test failed (testing word-alignment of lattices.) "
                 << "Make sure your model and lattices match!";
   }
-  
+
   const CompactLattice &lat_;
   const TransitionModel &tmodel_;
   const WordBoundaryInfo &info_;
   const CompactLattice &aligned_lat_;
 };
-  
-  
+
+
 
 
 /// You should only test a lattice if WordAlignLattice returned true (i.e. it
index ae1ada946c8d870df0ef2097690bd79b6c98b37d..00dcde2047cf14a7c6943c6d1d77f1c89c03e01b 100644 (file)
@@ -40,14 +40,17 @@ NnetDecodableBase::NnetDecodableBase(
     ivector_(ivector), online_ivector_feats_(online_ivectors),
     online_ivector_period_(online_ivector_period),
     compiler_(nnet_, opts_.optimize_config),
-    current_log_post_offset_(0) {
+    current_log_post_subsampled_offset_(0) {
+  num_subsampled_frames_ =
+      (feats_.NumRows() + opts_.frame_subsampling_factor - 1) /
+      opts_.frame_subsampling_factor;
   KALDI_ASSERT(IsSimpleNnet(nnet));
   ComputeSimpleNnetContext(nnet, &nnet_left_context_, &nnet_right_context_);
   KALDI_ASSERT(!(ivector != NULL && online_ivectors != NULL));
   KALDI_ASSERT(!(online_ivectors != NULL && online_ivector_period <= 0 &&
                  "You need to set the --online-ivector-period option!"));
   log_priors_.ApplyLog();
-  PossiblyWarnForFramesPerChunk();
+  CheckAndFixConfigs();
 }
 
 
@@ -83,9 +86,9 @@ int32 NnetDecodableBase::GetIvectorDim() const {
     return 0;
 }
 
-void NnetDecodableBase::EnsureFrameIsComputed(int32 frame) {
-  KALDI_ASSERT(frame >= 0 && frame  < feats_.NumRows());
-
+void NnetDecodableBase::EnsureFrameIsComputed(int32 subsampled_frame) {
+  KALDI_ASSERT(subsampled_frame >= 0 &&
+               subsampled_frame < num_subsampled_frames_);
   int32 feature_dim = feats_.NumCols(),
       ivector_dim = GetIvectorDim(),
       nnet_input_dim = nnet_.InputDim("input"),
@@ -98,30 +101,44 @@ void NnetDecodableBase::EnsureFrameIsComputed(int32 frame) {
     KALDI_ERR << "Neural net expects 'ivector' features with dimension "
               << nnet_ivector_dim << " but you provided " << ivector_dim;
 
-  int32 current_frames_computed = current_log_post_.NumRows(),
-      current_offset = current_log_post_offset_;
-  KALDI_ASSERT(frame < current_offset ||
-               frame >= current_offset + current_frames_computed);
-  // allow the output to be computed for frame 0 ... num_input_frames - 1.
-  int32 start_output_frame = frame,
-      num_output_frames = std::min<int32>(feats_.NumRows() - start_output_frame,
-                                          opts_.frames_per_chunk);
-  KALDI_ASSERT(num_output_frames > 0);
+  int32 current_subsampled_frames_computed = current_log_post_.NumRows(),
+      current_subsampled_offset = current_log_post_subsampled_offset_;
+  KALDI_ASSERT(subsampled_frame < current_subsampled_offset ||
+               subsampled_frame >= current_subsampled_offset +
+                                   current_subsampled_frames_computed);
+
+  // all subsampled frames pertain to the output of the network,
+  // they are output frames divided by opts_.frame_subsampling_factor.
+  int32 subsampling_factor = opts_.frame_subsampling_factor,
+      subsampled_frames_per_chunk = opts_.frames_per_chunk / subsampling_factor,
+      start_subsampled_frame = subsampled_frame,
+     num_subsampled_frames = std::min<int32>(num_subsampled_frames_ -
+                                             start_subsampled_frame,
+                                             subsampled_frames_per_chunk),
+      last_subsampled_frame = start_subsampled_frame + num_subsampled_frames - 1;
+  KALDI_ASSERT(num_subsampled_frames > 0);
+  // the output-frame numbers are the subsampled-frame numbers
+  int32 first_output_frame = start_subsampled_frame * subsampling_factor,
+      last_output_frame = last_subsampled_frame * subsampling_factor;
+
   KALDI_ASSERT(opts_.extra_left_context >= 0);
   int32 left_context = nnet_left_context_ + opts_.extra_left_context;
-  int32 first_input_frame = start_output_frame - left_context,
-      num_input_frames = nnet_left_context_ + num_output_frames +
-      nnet_right_context_;
+  int32 first_input_frame = first_output_frame - left_context,
+      last_input_frame = last_output_frame + nnet_right_context_,
+      num_input_frames = last_input_frame + 1 - first_input_frame;
+
   Vector<BaseFloat> ivector;
-  GetCurrentIvector(start_output_frame, num_output_frames, &ivector);
+  GetCurrentIvector(first_output_frame,
+                    last_output_frame - first_output_frame,
+                    &ivector);
 
   Matrix<BaseFloat> input_feats;
   if (first_input_frame >= 0 &&
-      first_input_frame + num_input_frames <= feats_.NumRows()) {
+      last_input_frame < feats_.NumRows()) {
     SubMatrix<BaseFloat> input_feats(feats_.RowRange(first_input_frame,
                                                      num_input_frames));
     DoNnetComputation(first_input_frame, input_feats, ivector,
-                      start_output_frame, num_output_frames);
+                      first_output_frame, num_subsampled_frames);
   } else {
     Matrix<BaseFloat> feats_block(num_input_frames, feats_.NumCols());
     int32 tot_input_feats = feats_.NumRows();
@@ -134,21 +151,25 @@ void NnetDecodableBase::EnsureFrameIsComputed(int32 frame) {
       dest.CopyFromVec(src);
     }
     DoNnetComputation(first_input_frame, feats_block, ivector,
-                      start_output_frame, num_output_frames);
+                      first_output_frame, num_subsampled_frames);
   }
 }
 
-void NnetDecodableBase::GetOutputForFrame(int32 frame,
+// note: in the normal case (with no frame subsampling) you can ignore the
+// 'subsampled_' in the variable name.
+void NnetDecodableBase::GetOutputForFrame(int32 subsampled_frame,
                                           VectorBase<BaseFloat> *output) {
-  if (frame < current_log_post_offset_ ||
-      frame >= current_log_post_offset_ + current_log_post_.NumRows())
-    EnsureFrameIsComputed(frame);
-  output->CopyFromVec(current_log_post_.Row(frame - current_log_post_offset_));
+  if (subsampled_frame < current_log_post_subsampled_offset_ ||
+      subsampled_frame >= current_log_post_subsampled_offset_ +
+      current_log_post_.NumRows())
+    EnsureFrameIsComputed(subsampled_frame);
+  output->CopyFromVec(current_log_post_.Row(
+      subsampled_frame - current_log_post_subsampled_offset_));
 }
 
 void NnetDecodableBase::GetCurrentIvector(int32 output_t_start,
-                                              int32 num_output_frames,
-                                              Vector<BaseFloat> *ivector) {
+                                          int32 num_output_frames,
+                                          Vector<BaseFloat> *ivector) {
   if (ivector_ != NULL) {
     *ivector = *ivector_;
     return;
@@ -185,7 +206,7 @@ void NnetDecodableBase::DoNnetComputation(
     const MatrixBase<BaseFloat> &input_feats,
     const VectorBase<BaseFloat> &ivector,
     int32 output_t_start,
-    int32 num_output_frames) {
+    int32 num_subsampled_frames) {
   ComputationRequest request;
   request.need_model_derivative = false;
   request.store_component_stats = false;
@@ -205,9 +226,17 @@ void NnetDecodableBase::DoNnetComputation(
     indexes.push_back(Index(0, 0, 0));
     request.inputs.push_back(IoSpecification("ivector", indexes));
   }
-  request.outputs.push_back(
-      IoSpecification("output", time_offset + output_t_start,
-                      time_offset + output_t_start + num_output_frames));
+  IoSpecification output_spec;
+  output_spec.name = "output";
+  output_spec.has_deriv = false;
+  int32 subsample = opts_.frame_subsampling_factor;
+  output_spec.indexes.resize(num_subsampled_frames);
+  // leave n and x values at 0 (the constructor sets these).
+  for (int32 i = 0; i < num_subsampled_frames; i++)
+    output_spec.indexes[i].t = time_offset + output_t_start + i * subsample;
+  request.outputs.resize(1);
+  request.outputs[0].Swap(&output_spec);
+
   const NnetComputation *computation = compiler_.Compile(request);
   Nnet *nnet_to_update = NULL;  // we're not doing any update.
   NnetComputer computer(opts_.compute_config, *computation,
@@ -232,14 +261,31 @@ void NnetDecodableBase::DoNnetComputation(
   current_log_post_.Resize(0, 0);
   // the following statement just swaps the pointers if we're not using a GPU.
   cu_output.Swap(&current_log_post_);
-  current_log_post_offset_ = output_t_start;
+  current_log_post_subsampled_offset_ = output_t_start / subsample;
 }
 
-void NnetDecodableBase::PossiblyWarnForFramesPerChunk() const {
-  static bool warned = false;
+void NnetDecodableBase::CheckAndFixConfigs() {
+  static bool warned_modulus = false,
+      warned_subsampling = false;
   int32 nnet_modulus = nnet_.Modulus();
-  if (opts_.frames_per_chunk % nnet_modulus != 0 && !warned) {
-    warned = true;
+  if (opts_.frame_subsampling_factor < 1 ||
+      opts_.frames_per_chunk < 1)
+    KALDI_ERR << "--frame-subsampling-factor and --frames-per-chunk must be > 0";
+  if (opts_.frames_per_chunk % opts_.frame_subsampling_factor != 0) {
+    int32 f = opts_.frame_subsampling_factor,
+        frames_per_chunk = f * ((opts_.frames_per_chunk + f - 1) / f);
+    if (!warned_subsampling) {
+      warned_subsampling = true;
+      KALDI_LOG << "Increasing --frames-per-chunk from "
+                << opts_.frames_per_chunk << " to "
+                << frames_per_chunk << " to make it a multiple of "
+                << "--frame-subsampling-factor="
+                << opts_.frame_subsampling_factor;
+    }
+    opts_.frames_per_chunk = frames_per_chunk;
+  }
+  if (opts_.frames_per_chunk % nnet_modulus != 0 && !warned_modulus) {
+    warned_modulus = true;
     KALDI_WARN << "It may be more efficient to set the --frames-per-chunk "
                << "(currently " << opts_.frames_per_chunk << " to a "
                << "multiple of the network's shift-invariance modulus "
@@ -249,4 +295,4 @@ void NnetDecodableBase::PossiblyWarnForFramesPerChunk() const {
 
 } // namespace nnet3
 } // namespace kaldi
+
index 4eeb262c7874f7244fe0281c71115bf6f7cc8371..15399b1308d8627a74d8292f3f8bfb5da1ba5ade 100644 (file)
@@ -37,6 +37,7 @@ namespace nnet3 {
 // for which IsSimpleNnet(nnet) would return true.
 struct NnetSimpleComputationOptions {
   int32 extra_left_context;
+  int32 frame_subsampling_factor;
   int32 frames_per_chunk;
   BaseFloat acoustic_scale;
   bool debug_computation;
@@ -45,6 +46,7 @@ struct NnetSimpleComputationOptions {
 
   NnetSimpleComputationOptions():
       extra_left_context(0),
+      frame_subsampling_factor(1),
       frames_per_chunk(50),
       acoustic_scale(0.1),
       debug_computation(false) { }
@@ -54,11 +56,17 @@ struct NnetSimpleComputationOptions {
                    "Number of frames of additional left-context to add on top "
                    "of the neural net's inherent left context (may be useful in "
                    "recurrent setups");
+    opts->Register("frame-subsampling-factor", &frame_subsampling_factor,
+                   "Required if the frame-rate of the output (e.g. in 'chain' "
+                   "models) is less than the frame-rate of the original "
+                   "alignment.");
     opts->Register("acoustic-scale", &acoustic_scale,
                    "Scaling factor for acoustic log-likelihoods");
     opts->Register("frames-per-chunk", &frames_per_chunk,
                    "Number of frames in each chunk that is separately evaluated "
-                   "by the neural net.");
+                   "by the neural net.  Measured before any subsampling, if the "
+                   "--frame-subsampling-factor options is used (i.e. counts "
+                   "input frames");
     opts->Register("debug-computation", &debug_computation, "If true, turn on "
                    "debug for the actual computation (very verbose!)");
 
@@ -115,8 +123,10 @@ class NnetDecodableBase {
                     int32 online_ivector_period = 1);
 
 
-  // returns the number of frames of likelihoods.
-  inline int32 NumFrames() const { return feats_.NumRows(); }
+  // returns the number of frames of likelihoods.  The same as feats_.NumRows()
+  // in the normal case (but may be less if opts_.frame_subsampling_factor !=
+  // 1).
+  inline int32 NumFrames() const { return num_subsampled_frames_; }
 
   inline int32 OutputDim() const { return output_dim_; }
 
@@ -124,19 +134,22 @@ class NnetDecodableBase {
   // 'output' must be correctly sized (with dimension OutputDim()).
   void GetOutputForFrame(int32 frame, VectorBase<BaseFloat> *output);
 
-  // Gets the output for a particular frame and pdf_id, with 0 <= frame < NumFrames(),
+  // Gets the output for a particular frame and pdf_id, with
+  // 0 <= subsampled_frame < NumFrames(),
   // and 0 <= pdf_id < OutputDim().
-  inline BaseFloat GetOutput(int32 frame, int32 pdf_id) {
-    if (frame < current_log_post_offset_ ||
-        frame >= current_log_post_offset_ + current_log_post_.NumRows())
-      EnsureFrameIsComputed(frame);
-    return current_log_post_(frame - current_log_post_offset_,
+  inline BaseFloat GetOutput(int32 subsampled_frame, int32 pdf_id) {
+    if (subsampled_frame < current_log_post_subsampled_offset_ ||
+        subsampled_frame >= current_log_post_subsampled_offset_ +
+                            current_log_post_.NumRows())
+      EnsureFrameIsComputed(subsampled_frame);
+    return current_log_post_(subsampled_frame -
+                             current_log_post_subsampled_offset_,
                              pdf_id);
   }
  private:
   // This call is made to ensure that we have the log-probs for this frame
   // cached in current_log_post_.
-  void EnsureFrameIsComputed(int32 frame);
+  void EnsureFrameIsComputed(int32 subsampled_frame);
 
   // This function does the actual nnet computation; it is called from
   // EnsureFrameIsComputed.  Any padding at file start/end is done by
@@ -146,19 +159,24 @@ class NnetDecodableBase {
                          const MatrixBase<BaseFloat> &input_feats,
                          const VectorBase<BaseFloat> &ivector,
                          int32 output_t_start,
-                         int32 num_output_frames);
-
-  // Gets the iVector that will be used for this chunk of frames, if
-  // we are using iVectors (else does nothing).
-  void GetCurrentIvector(int32 output_t_start, int32 num_output_frames,
+                         int32 num_subsampled_frames);
+
+  // Gets the iVector that will be used for this chunk of frames, if we are
+  // using iVectors (else does nothing).  note: the num_output_frames is
+  // interpreted as the number of t value, which in the subsampled case is not
+  // the same as the number of subsampled frames (it would be larger by
+  // opts_.frame_subsampling_factor).
+  void GetCurrentIvector(int32 output_t_start,
+                         int32 num_output_frames,
                          Vector<BaseFloat> *ivector);
 
-  void PossiblyWarnForFramesPerChunk() const;
+  // called from constructor
+  void CheckAndFixConfigs();
 
   // returns dimension of the provided iVectors if supplied, or 0 otherwise.
   int32 GetIvectorDim() const;
 
-  const NnetSimpleComputationOptions &opts_;
+  NnetSimpleComputationOptions opts_;
   const Nnet &nnet_;
   int32 nnet_left_context_;
   int32 nnet_right_context_;
@@ -166,6 +184,9 @@ class NnetDecodableBase {
   // the log priors (or the empty vector if the priors are not set in the model)
   CuVector<BaseFloat> log_priors_;
   const MatrixBase<BaseFloat> &feats_;
+  // note: num_subsampled_frames_ will equal feats_.NumRows() in the normal case
+  // when opts_.frame_subsampling_factor == 1.
+  int32 num_subsampled_frames_;
 
   // ivector_ is the iVector if we're using iVectors that are estimated in batch
   // mode.
@@ -182,8 +203,10 @@ class NnetDecodableBase {
   // The current log-posteriors that we got from the last time we
   // ran the computation.
   Matrix<BaseFloat> current_log_post_;
-  // The time-offset of the current log-posteriors.
-  int32 current_log_post_offset_;
+  // The time-offset of the current log-posteriors.  Note: if
+  // opts_.frame_subsampling_factor > 1, this will be measured in subsampled
+  // frames.
+  int32 current_log_post_subsampled_offset_;
 
 
 };
index ab09879c9c3648f647bdbcc41bcffdc112451f92..bb61f33e6b0ebd3ce60d18a7463e9114e77c5ed9 100644 (file)
@@ -145,7 +145,7 @@ void DecodableNnetCctcSimple::EnsureFrameIsComputed(int32 subsampled_frame) {
                                    current_subsampled_frames_computed);
 
   // all subsampled frames pertain to the output of the network,
-  // they are output frames divided by opts_.frame_subsampled_factor.
+  // they are output frames divided by opts_.frame_subsampling_factor.
   int32 subsampling_factor = opts_.frame_subsampling_factor,
       subsampled_frames_per_chunk = opts_.frames_per_chunk / subsampling_factor,
       start_subsampled_frame = subsampled_frame,
index b9b1b5ad2822ab410b0306535ec6e3da72714b59..eab486ade85c3ce2ec644613687f6afe4c686b5c 100644 (file)
@@ -261,7 +261,7 @@ void NormalizeComponent::Write(std::ostream &os, bool binary) const {
 std::string NormalizeComponent::Info() const {
   std::stringstream stream;
   stream << NonlinearComponent::Info();
-  stream << ", target_rms=" << target_rms_;
+  stream << ", target-rms=" << target_rms_;
   return stream.str();
 }