[src] Fix bug in fstrmymbols RE recent const-fst changes (thanks: Jon Nichols); other...
[processor-sdk/kaldi.git] / src / fstbin / fstrmsymbols.cc
1 // fstbin/fstrmsymbols.cc
3 // Copyright 2009-2011  Microsoft Corporation
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 //  http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
21 #include "base/kaldi-common.h"
22 #include "util/common-utils.h"
23 #include "fst/fstlib.h"
24 #include "fstext/determinize-star.h"
25 #include "fstext/fstext-utils.h"
26 #include "fstext/kaldi-fst-io.h"
28 namespace fst {
29 // we can move these functions elsewhere later, if they are needed in other
30 // places.
32 template<class Arc, class I>
33 void RemoveArcsWithSomeInputSymbols(const std::vector<I> &symbols_in,
34                                     VectorFst<Arc> *fst) {
35   typedef typename Arc::StateId StateId;
37   kaldi::ConstIntegerSet<I> symbol_set(symbols_in);
39   StateId num_states = fst->NumStates();
40   StateId dead_state = fst->AddState();
41   for (StateId s = 0; s < num_states; s++) {
42     for (MutableArcIterator<VectorFst<Arc> > iter(fst, s);
43          !iter.Done(); iter.Next()) {
44       if (symbol_set.count(iter.Value().ilabel) != 0) {
45         Arc arc = iter.Value();
46         arc.nextstate = dead_state;
47         iter.SetValue(arc);
48       }
49     }
50   }
51   // Connect() will actually remove the arcs, and the dead state.
52   Connect(fst);
53   if (fst->NumStates() == 0)
54     KALDI_WARN << "After Connect(), fst was empty.";
55 }
57 template<class Arc, class I>
58 void PenalizeArcsWithSomeInputSymbols(const std::vector<I> &symbols_in,
59                                       float penalty,
60                                       VectorFst<Arc> *fst) {
61   typedef typename Arc::StateId StateId;
62   typedef typename Arc::Label Label;
63   typedef typename Arc::Weight Weight;
65   Weight penalty_weight(penalty);
67   kaldi::ConstIntegerSet<I> symbol_set(symbols_in);
69   StateId num_states = fst->NumStates();
70   for (StateId s = 0; s < num_states; s++) {
71     for (MutableArcIterator<VectorFst<Arc> > iter(fst, s);
72          !iter.Done(); iter.Next()) {
73       if (symbol_set.count(iter.Value().ilabel) != 0) {
74         Arc arc = iter.Value();
75         arc.weight = Times(arc.weight, penalty_weight);
76         iter.SetValue(arc);
77       }
78     }
79   }
80 }
82 }
85 int main(int argc, char *argv[]) {
86   try {
87     using namespace kaldi;
88     using namespace fst;
89     using kaldi::int32;
91     bool apply_to_output = false;
92     bool remove_arcs = false;
93     float penalty = -std::numeric_limits<BaseFloat>::infinity();
95     const char *usage =
96         "With no options, replaces a subset of symbols with epsilon, wherever\n"
97         "they appear on the input side of an FST."
98         "With --remove-arcs=true, will remove arcs that contain these symbols\n"
99         "on the input\n"
100         "With --penalty=<float>, will add the specified penalty to the\n"
101         "cost of any arc that has one of the given symbols on its input side\n"
102         "In all cases, the option --apply-to-output=true (or for\n"
103         "back-compatibility, --remove-from-output=true) makes this apply\n"
104         "to the output side.\n"
105         "\n"
106         "Usage:  fstrmsymbols [options] <in-disambig-list>  [<in.fst> [<out.fst>]]\n"
107         "E.g:  fstrmsymbols in.list  < in.fst > out.fst\n"
108         "<in-disambig-list> is an rxfilename specifying a file containing list of integers\n"
109         "representing symbols, in text form, one per line.\n";
111     ParseOptions po(usage);
112     po.Register("remove-from-output", &apply_to_output, "If true, this applies to symbols "
113                 "on the output, not the input, side.  (For back compatibility; use "
114                 "--apply-to-output insead)");
115     po.Register("apply-to-output", &apply_to_output, "If true, this applies to symbols "
116                 "on the output, not the input, side.");
117     po.Register("remove-arcs", &remove_arcs, "If true, instead of converting the symbol "
118                 "to <eps>, remove the arcs.");
119     po.Register("penalty", &penalty, "If specified, instead of converting "
120                 "the symbol to <eps>, penalize the arc it is on by adding this "
121                 "value to its cost.");
124     po.Read(argc, argv);
126     if (remove_arcs &&
127         penalty != -std::numeric_limits<BaseFloat>::infinity())
128       KALDI_ERR << "--remove-arc and --penalty options are mutually exclusive";
130     if (po.NumArgs() < 1 || po.NumArgs() > 3) {
131       po.PrintUsage();
132       exit(1);
133     }
135     std::string disambig_rxfilename = po.GetArg(1),
136         fst_rxfilename = po.GetOptArg(2),
137         fst_wxfilename = po.GetOptArg(3);
139     VectorFst<StdArc> *fst = CastOrConvertToVectorFst(
140         ReadFstKaldiGeneric(fst_rxfilename));
142     std::vector<int32> disambig_in;
143     if (!ReadIntegerVectorSimple(disambig_rxfilename, &disambig_in))
144       KALDI_ERR << "fstrmsymbols: Could not read disambiguation symbols from "
145                 << (disambig_rxfilename == "" ? "standard input" : disambig_rxfilename);
147     if (apply_to_output) Invert(fst);
148     if (remove_arcs) {
149       RemoveArcsWithSomeInputSymbols(disambig_in, fst);
150     } else if (penalty != -std::numeric_limits<BaseFloat>::infinity()) {
151       PenalizeArcsWithSomeInputSymbols(disambig_in, penalty, fst);
152     } else {
153       RemoveSomeInputSymbols(disambig_in, fst);
154     }
155     if (apply_to_output) Invert(fst);
157     WriteFstKaldi(*fst, fst_wxfilename);
159     delete fst;
160     return 0;
161   } catch(const std::exception &e) {
162     std::cerr << e.what();
163     return -1;
164   }
167 /* some test examples:
169  ( echo "0 0 1 1"; echo " 0 0 3 2"; echo "0 0"; ) | fstcompile | fstrmsymbols "echo 3; echo  4|" | fstprint
170  # should produce:
171  # 0   0   1   1
172  # 0   0   0   2
173  # 0
175  ( echo "0 0 1 1"; echo " 0 0 3 2"; echo "0 0"; ) | fstcompile | fstrmsymbols --apply-to-output=true "echo 2; echo 3|" | fstprint
176  # should produce:
177  # 0   0   1   1
178  # 0   0   3   0
179  # 0
182  ( echo "0 0 1 1"; echo " 0 0 3 2"; echo "0 0"; ) | fstcompile | fstrmsymbols --remove-arcs=true  "echo 3; echo  4|" | fstprint
183  # should produce:
184  # 0   0   1   1
185  # 0
187  ( echo "0 0 1 1"; echo " 0 0 3 2"; echo "0 0"; ) | fstcompile | fstrmsymbols --penalty=2 "echo 3; echo 4; echo 5|" | fstprint
188 # should produce:
189  # 0   0   1   1
190  # 0   0   3   2   2
191  # 0
193 */