1 // fstbin/fstrmsymbols.cc
3 // Copyright 2009-2011 Microsoft Corporation
5 // See ../../COPYING for clarification regarding multiple authors
6 //
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 //
11 // http://www.apache.org/licenses/LICENSE-2.0
12 //
13 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
15 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
16 // MERCHANTABLITY OR NON-INFRINGEMENT.
17 // See the Apache 2 License for the specific language governing permissions and
18 // limitations under the License.
21 #include "base/kaldi-common.h"
22 #include "util/common-utils.h"
23 #include "fst/fstlib.h"
24 #include "fstext/determinize-star.h"
25 #include "fstext/fstext-utils.h"
26 #include "fstext/kaldi-fst-io.h"
28 namespace fst {
29 // we can move these functions elsewhere later, if they are needed in other
30 // places.
32 template<class Arc, class I>
33 void RemoveArcsWithSomeInputSymbols(const std::vector<I> &symbols_in,
34 VectorFst<Arc> *fst) {
35 typedef typename Arc::StateId StateId;
37 kaldi::ConstIntegerSet<I> symbol_set(symbols_in);
39 StateId num_states = fst->NumStates();
40 StateId dead_state = fst->AddState();
41 for (StateId s = 0; s < num_states; s++) {
42 for (MutableArcIterator<VectorFst<Arc> > iter(fst, s);
43 !iter.Done(); iter.Next()) {
44 if (symbol_set.count(iter.Value().ilabel) != 0) {
45 Arc arc = iter.Value();
46 arc.nextstate = dead_state;
47 iter.SetValue(arc);
48 }
49 }
50 }
51 // Connect() will actually remove the arcs, and the dead state.
52 Connect(fst);
53 if (fst->NumStates() == 0)
54 KALDI_WARN << "After Connect(), fst was empty.";
55 }
57 template<class Arc, class I>
58 void PenalizeArcsWithSomeInputSymbols(const std::vector<I> &symbols_in,
59 float penalty,
60 VectorFst<Arc> *fst) {
61 typedef typename Arc::StateId StateId;
62 typedef typename Arc::Label Label;
63 typedef typename Arc::Weight Weight;
65 Weight penalty_weight(penalty);
67 kaldi::ConstIntegerSet<I> symbol_set(symbols_in);
69 StateId num_states = fst->NumStates();
70 for (StateId s = 0; s < num_states; s++) {
71 for (MutableArcIterator<VectorFst<Arc> > iter(fst, s);
72 !iter.Done(); iter.Next()) {
73 if (symbol_set.count(iter.Value().ilabel) != 0) {
74 Arc arc = iter.Value();
75 arc.weight = Times(arc.weight, penalty_weight);
76 iter.SetValue(arc);
77 }
78 }
79 }
80 }
82 }
85 int main(int argc, char *argv[]) {
86 try {
87 using namespace kaldi;
88 using namespace fst;
89 using kaldi::int32;
91 bool apply_to_output = false;
92 bool remove_arcs = false;
93 float penalty = -std::numeric_limits<BaseFloat>::infinity();
95 const char *usage =
96 "With no options, replaces a subset of symbols with epsilon, wherever\n"
97 "they appear on the input side of an FST."
98 "With --remove-arcs=true, will remove arcs that contain these symbols\n"
99 "on the input\n"
100 "With --penalty=<float>, will add the specified penalty to the\n"
101 "cost of any arc that has one of the given symbols on its input side\n"
102 "In all cases, the option --apply-to-output=true (or for\n"
103 "back-compatibility, --remove-from-output=true) makes this apply\n"
104 "to the output side.\n"
105 "\n"
106 "Usage: fstrmsymbols [options] <in-disambig-list> [<in.fst> [<out.fst>]]\n"
107 "E.g: fstrmsymbols in.list < in.fst > out.fst\n"
108 "<in-disambig-list> is an rxfilename specifying a file containing list of integers\n"
109 "representing symbols, in text form, one per line.\n";
111 ParseOptions po(usage);
112 po.Register("remove-from-output", &apply_to_output, "If true, this applies to symbols "
113 "on the output, not the input, side. (For back compatibility; use "
114 "--apply-to-output insead)");
115 po.Register("apply-to-output", &apply_to_output, "If true, this applies to symbols "
116 "on the output, not the input, side.");
117 po.Register("remove-arcs", &remove_arcs, "If true, instead of converting the symbol "
118 "to <eps>, remove the arcs.");
119 po.Register("penalty", &penalty, "If specified, instead of converting "
120 "the symbol to <eps>, penalize the arc it is on by adding this "
121 "value to its cost.");
124 po.Read(argc, argv);
126 if (remove_arcs &&
127 penalty != -std::numeric_limits<BaseFloat>::infinity())
128 KALDI_ERR << "--remove-arc and --penalty options are mutually exclusive";
130 if (po.NumArgs() < 1 || po.NumArgs() > 3) {
131 po.PrintUsage();
132 exit(1);
133 }
135 std::string disambig_rxfilename = po.GetArg(1),
136 fst_rxfilename = po.GetOptArg(2),
137 fst_wxfilename = po.GetOptArg(3);
139 VectorFst<StdArc> *fst = ReadFstKaldi(fst_rxfilename);
141 std::vector<int32> disambig_in;
142 if (!ReadIntegerVectorSimple(disambig_rxfilename, &disambig_in))
143 KALDI_ERR << "fstrmsymbols: Could not read disambiguation symbols from "
144 << (disambig_rxfilename == "" ? "standard input" : disambig_rxfilename);
146 if (apply_to_output) Invert(fst);
147 if (remove_arcs) {
148 RemoveArcsWithSomeInputSymbols(disambig_in, fst);
149 } else if (penalty != -std::numeric_limits<BaseFloat>::infinity()) {
150 PenalizeArcsWithSomeInputSymbols(disambig_in, penalty, fst);
151 } else {
152 RemoveSomeInputSymbols(disambig_in, fst);
153 }
154 if (apply_to_output) Invert(fst);
156 WriteFstKaldi(*fst, fst_wxfilename);
158 delete fst;
159 return 0;
160 } catch(const std::exception &e) {
161 std::cerr << e.what();
162 return -1;
163 }
164 }
166 /* some test examples:
168 ( echo "0 0 1 1"; echo " 0 0 3 2"; echo "0 0"; ) | fstcompile | fstrmsymbols "echo 3; echo 4|" | fstprint
169 # should produce:
170 # 0 0 1 1
171 # 0 0 0 2
172 # 0
174 ( echo "0 0 1 1"; echo " 0 0 3 2"; echo "0 0"; ) | fstcompile | fstrmsymbols --apply-to-output=true "echo 2; echo 3|" | fstprint
175 # should produce:
176 # 0 0 1 1
177 # 0 0 3 0
178 # 0
181 ( echo "0 0 1 1"; echo " 0 0 3 2"; echo "0 0"; ) | fstcompile | fstrmsymbols --remove-arcs=true "echo 3; echo 4|" | fstprint
182 # should produce:
183 # 0 0 1 1
184 # 0
186 ( echo "0 0 1 1"; echo " 0 0 3 2"; echo "0 0"; ) | fstcompile | fstrmsymbols --penalty=2 "echo 3; echo 4; echo 5|" | fstprint
187 # should produce:
188 # 0 0 1 1
189 # 0 0 3 2 2
190 # 0
192 */