634a6267c4e8a9f9493e02e57dd07be4ba926a18
[processor-sdk/kaldi.git] / src / lm / arpa-lm-compiler.cc
1 // lm/arpa-lm-compiler.cc
3 // Copyright 2009-2011 Gilles Boulianne
4 // Copyright 2016 Smart Action LLC (kkm)
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 //  http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
21 #include <algorithm>
22 #include <limits>
23 #include <sstream>
24 #include <utility>
26 #include "base/kaldi-math.h"
27 #include "lm/arpa-lm-compiler.h"
28 #include "util/stl-utils.h"
29 #include "util/text-utils.h"
30 #include "fstext/remove-eps-local.h"
32 namespace kaldi {
34 class ArpaLmCompilerImplInterface {
35  public:
36   virtual ~ArpaLmCompilerImplInterface() { }
37   virtual void ConsumeNGram(const NGram& ngram, bool is_highest) = 0;
38 };
40 namespace {
42 typedef int32 StateId;
43 typedef int32 Symbol;
45 // GeneralHistKey can represent state history in an arbitrarily large n
46 // n-gram model with symbol ids fitting int32.
47 class GeneralHistKey {
48  public:
49   // Construct key from being and end iterators.
50   template<class InputIt>
51   GeneralHistKey(InputIt begin, InputIt end) : vector_(begin, end) { }
52   // Construct empty history key.
53   GeneralHistKey() : vector_() { }
54   // Return tails of the key as a GeneralHistKey. The tails of an n-gram
55   // w[1..n] is the sequence w[2..n] (and the heads is w[1..n-1], but the
56   // key class does not need this operartion).
57   GeneralHistKey Tails() const {
58     return GeneralHistKey(vector_.begin() + 1, vector_.end());
59   }
60   // Keys are equal if represent same state.
61   friend bool operator==(const GeneralHistKey& a, const GeneralHistKey& b) {
62     return a.vector_ == b.vector_;
63   }
64   // Public typename HashType for hashing.
65   struct HashType : public std::unary_function<GeneralHistKey, size_t> {
66     size_t operator()(const GeneralHistKey& key) const {
67       return VectorHasher<Symbol>().operator()(key.vector_);
68     }
69   };
71  private:
72   std::vector<Symbol> vector_;
73 };
75 // OptimizedHistKey combines 3 21-bit symbol ID values into one 64-bit
76 // machine word. allowing significant memory reduction and some runtime
77 // benefit over GeneralHistKey. Since 3 symbols are enough to track history
78 // in a 4-gram model, this optimized key is used for smaller models with up
79 // to 4-gram and symbol values up to 2^21-1.
80 //
81 // See GeneralHistKey for interface requrements of a key class.
82 class OptimizedHistKey {
83  public:
84   enum {
85     kShift = 21,  // 21 * 3 = 63 bits for data.
86     kMaxData = (1 << kShift) - 1
87   };
88   template<class InputIt>
89   OptimizedHistKey(InputIt begin, InputIt end) : data_(0) {
90     for (uint32 shift = 0; begin != end; ++begin, shift += kShift) {
91       data_ |= static_cast<uint64>(*begin) << shift;
92     }
93   }
94   OptimizedHistKey() : data_(0) { }
95   OptimizedHistKey Tails() const {
96     return OptimizedHistKey(data_ >> kShift);
97   }
98   friend bool operator==(const OptimizedHistKey& a, const OptimizedHistKey& b) {
99     return a.data_ == b.data_;
100   }
101   struct HashType : public std::unary_function<OptimizedHistKey, size_t> {
102     size_t operator()(const OptimizedHistKey& key) const { return key.data_; }
103   };
105  private:
106   explicit OptimizedHistKey(uint64 data) : data_(data) { }
107   uint64 data_;
108 };
110 }  // namespace
112 template <class HistKey>
113 class ArpaLmCompilerImpl : public ArpaLmCompilerImplInterface {
114  public:
115   ArpaLmCompilerImpl(ArpaLmCompiler* parent, fst::StdVectorFst* fst,
116                      Symbol sub_eps);
118   virtual void ConsumeNGram(const NGram &ngram, bool is_highest);
120  private:
121   StateId AddStateWithBackoff(HistKey key, float backoff);
122   void CreateBackoff(HistKey key, StateId state, float weight);
124   ArpaLmCompiler *parent_;  // Not owned.
125   fst::StdVectorFst* fst_;  // Not owned.
126   Symbol bos_symbol_;
127   Symbol eos_symbol_;
128   Symbol sub_eps_;
130   StateId eos_state_;
131   typedef unordered_map<HistKey, StateId,
132                         typename HistKey::HashType> HistoryMap;
133   HistoryMap history_;
134 };
136 template <class HistKey>
137 ArpaLmCompilerImpl<HistKey>::ArpaLmCompilerImpl(
138     ArpaLmCompiler* parent, fst::StdVectorFst* fst, Symbol sub_eps)
139     : parent_(parent), fst_(fst), bos_symbol_(parent->Options().bos_symbol),
140       eos_symbol_(parent->Options().eos_symbol), sub_eps_(sub_eps) {
141   // The algorithm maintains state per history. The 0-gram is a special state
142   // for emptry history. All unigrams (including BOS) backoff into this state.
143   StateId zerogram = fst_->AddState();
144   history_[HistKey()] = zerogram;
146   // Also, if </s> is not treated as epsilon, create a common end state for
147   // all transitions acepting the </s>, since they do not back off. This small
148   // optimization saves about 2% states in an average grammar.
149   if (sub_eps_ == 0) {
150     eos_state_ = fst_->AddState();
151     fst_->SetFinal(eos_state_, 0);
152   }
155 template <class HistKey>
156 void ArpaLmCompilerImpl<HistKey>::ConsumeNGram(const NGram &ngram,
157                                                bool is_highest) {
158   // Generally, we do the following. Suppose we are adding an n-gram "A B
159   // C". Then find the node for "A B", add a new node for "A B C", and connect
160   // them with the arc accepting "C" with the specified weight. Also, add a
161   // backoff arc from the new "A B C" node to its backoff state "B C".
162   //
163   // Two notable exceptions are the highest order n-grams, and final n-grams.
164   //
165   // When adding a highest order n-gram (e. g., our "A B C" is in a 3-gram LM),
166   // the following optimization is performed. There is no point adding a node
167   // for "A B C" with a "C" arc from "A B", since there will be no other
168   // arcs ingoing to this node, and an epsilon backoff arc into the backoff
169   // model "B C", with the weight of \bar{1}. To save a node, create an arc
170   // accepting "C" directly from "A B" to "B C". This saves as many nodes
171   // as there are the highest order n-grams, which is typically about half
172   // the size of a large 3-gram model.
173   //
174   // Indeed, this does not apply to n-grams ending in EOS, since they do not
175   // back off. These are special, as they do not have a back-off state, and
176   // the node for "(..anything..) </s>" is always final. These are handled
177   // in one of the two possible ways, If symbols <s> and </s> are being
178   // replaced by epsilons, neither node nor arc is created, and the logprob
179   // of the n-gram is applied to its source node as final weight. If <s> and
180   // </s> are preserved, then a special final node for </s> is allocated and
181   // used as the destination of the "</s>" acceptor arc.
182   HistKey heads(ngram.words.begin(), ngram.words.end() - 1);
183   typename HistoryMap::iterator source_it = history_.find(heads);
184   if (source_it == history_.end()) {
185     // There was no "A B", therefore the probability of "A B C" is zero.
186     // Print a warning and discard current n-gram.
187     if (parent_->ShouldWarn())
188       KALDI_WARN << parent_->LineReference()
189                  << " skipped: no parent (n-1)-gram exists";
190     return;
191   }
193   StateId source = source_it->second;
194   StateId dest;
195   Symbol sym = ngram.words.back();
196   float weight = -ngram.logprob;
197   if (sym == eos_symbol_) {
198     if (sub_eps_ == 0) {
199       // Keep </s> as a real symbol when not substituting.
200       dest = eos_state_;
201     } else {
202       // Treat </s> as if it was epsilon: mark source final, with the weight
203       // of the n-gram.
204       fst_->SetFinal(source, weight);
205       return;
206     }
207   } else {
208     // For the highest order n-gram, this may find an existing state, for
209     // non-highest, will create one (unless there are duplicate n-grams
210     // in the grammar, which cannot be reliably detected if highest order,
211     // so we better do not do that at all).
212     dest = AddStateWithBackoff(
213         HistKey(ngram.words.begin() + (is_highest ? 1 : 0),
214                 ngram.words.end()),
215         -ngram.backoff);
216   }
218   if (sym == bos_symbol_) {
219     weight = 0;  // Accepting <s> is always free.
220     if (sub_eps_ == 0) {
221       // <s> is as a real symbol, only accepted in the start state.
222       source = fst_->AddState();
223       fst_->SetStart(source);
224     } else {
225       // The new state for <s> unigram history *is* the start state.
226       fst_->SetStart(dest);
227       return;
228     }
229   }
231   // Add arc from source to dest, whichever way it was found.
232   fst_->AddArc(source, fst::StdArc(sym, sym, weight, dest));
233   return;
236 // Find or create a new state for n-gram defined by key, and ensure it has a
237 // backoff transition.  The key is either the current n-gram for all but
238 // highest orders, or the tails of the n-gram for the highest order. The
239 // latter arises from the chain-collapsing optimization described above.
240 template <class HistKey>
241 StateId ArpaLmCompilerImpl<HistKey>::AddStateWithBackoff(HistKey key,
242                                                          float backoff) {
243   typename HistoryMap::iterator dest_it = history_.find(key);
244   if (dest_it != history_.end()) {
245     // Found an existing state in the history map. Invariant: if the state in
246     // the map, then its backoff arc is in the FST. We are done.
247     return dest_it->second;
248   }
249   // Otherwise create a new state and its backoff arc, and register in the map.
250   StateId dest = fst_->AddState();
251   history_[key] = dest;
252   CreateBackoff(key.Tails(), dest, backoff);
253   return dest;
256 // Create a backoff arc for a state. Key is a backoff destination that may or
257 // may not exist. When the destination is not found, naturally fall back to
258 // the lower order model, and all the way down until one is found (since the
259 // 0-gram model is always present, the search is guaranteed to terminate).
260 template <class HistKey>
261 inline void ArpaLmCompilerImpl<HistKey>::CreateBackoff(
262     HistKey key, StateId state, float weight) {
263   typename HistoryMap::iterator dest_it = history_.find(key);
264   while (dest_it == history_.end()) {
265     key = key.Tails();
266     dest_it = history_.find(key);
267   }
269   // The arc should transduce either <eos> or #0 to <eps>, depending on the
270   // epsilon substitution mode. This is the only case when input and output
271   // label may differ.
272   fst_->AddArc(state, fst::StdArc(sub_eps_, 0, weight, dest_it->second));
275 ArpaLmCompiler::~ArpaLmCompiler() {
276   if (impl_ != NULL)
277     delete impl_;
280 void ArpaLmCompiler::HeaderAvailable() {
281   KALDI_ASSERT(impl_ == NULL);
282   // Use optimized implementation if the grammar is 4-gram or less, and the
283   // maximum attained symbol id will fit into the optimized range.
284   int64 max_symbol = 0;
285   if (Symbols() != NULL)
286     max_symbol = Symbols()->AvailableKey() - 1;
287   // If augmenting the symbol table, assume the wors case when all words in
288   // the model being read are novel.
289   if (Options().oov_handling == ArpaParseOptions::kAddToSymbols)
290     max_symbol += NgramCounts()[0];
292   if (NgramCounts().size() <= 4 && max_symbol < OptimizedHistKey::kMaxData) {
293     impl_ = new ArpaLmCompilerImpl<OptimizedHistKey>(this, &fst_, sub_eps_);
294   } else {
295     impl_ = new ArpaLmCompilerImpl<GeneralHistKey>(this, &fst_, sub_eps_);
296     KALDI_LOG << "Reverting to slower state tracking because model is large: "
297               << NgramCounts().size() << "-gram with symbols up to "
298               << max_symbol;
299   }
302 void ArpaLmCompiler::ConsumeNGram(const NGram &ngram) {
303   // <s> is invalid in tails, </s> in heads of an n-gram.
304   for (int i = 0; i < ngram.words.size(); ++i) {
305     if ((i > 0 && ngram.words[i] == Options().bos_symbol) ||
306         (i + 1 < ngram.words.size()
307          && ngram.words[i] == Options().eos_symbol)) {
308       if (ShouldWarn())
309         KALDI_WARN << LineReference()
310                    << " skipped: n-gram has invalid BOS/EOS placement";
311       return;
312     }
313   }
315   bool is_highest = ngram.words.size() == NgramCounts().size();
316   impl_->ConsumeNGram(ngram, is_highest);
319 void ArpaLmCompiler::RemoveRedundantStates() {
320   fst::StdArc::Label backoff_symbol = sub_eps_;
321   if (backoff_symbol == 0) {
322     // The method of removing redundant states implemented in this function
323     // leads to slow determinization of L o G when people use the older style of
324     // usage of arpa2fst where the --disambig-symbol option was not specified.
325     // The issue seems to be that it creates a non-deterministic FST, while G is
326     // supposed to be deterministic.  By 'return'ing below, we just disable this
327     // method if people were using an older script.  This method isn't really
328     // that consequential anyway, and people will move to the newer-style
329     // scripts (see current utils/format_lm.sh), so this isn't much of a
330     // problem.
331     return;
332   }
334   fst::StdArc::StateId num_states = fst_.NumStates();
337   // replace the #0 symbols on the input of arcs out of redundant states (states
338   // that are not final and have only a backoff arc leaving them), with <eps>.
339   for (fst::StdArc::StateId state = 0; state < num_states; state++) {
340     if (fst_.NumArcs(state) == 1 && fst_.Final(state) == fst::TropicalWeight::Zero()) {
341       fst::MutableArcIterator<fst::StdVectorFst> iter(&fst_, state);
342       fst::StdArc arc = iter.Value();
343       if (arc.ilabel == backoff_symbol) {
344         arc.ilabel = 0;
345         iter.SetValue(arc);
346       }
347     }
348   }
350   // we could call fst::RemoveEps, and it would have the same effect in normal
351   // cases, where backoff_symbol != 0 and there are no epsilons in unexpected
352   // places, but RemoveEpsLocal is a bit safer in case something weird is going
353   // on; it guarantees not to blow up the FST.
354   fst::RemoveEpsLocal(&fst_);
355   KALDI_LOG << "Reduced num-states from " << num_states << " to "
356             << fst_.NumStates();
359 void ArpaLmCompiler::ReadComplete() {
360   fst_.SetInputSymbols(Symbols());
361   fst_.SetOutputSymbols(Symbols());
362   RemoveRedundantStates();
365 }  // namespace kaldi