index 438170c2b98b76fec5f768a4ee3eb06ab55552fa..a365b016e5815bc76c88ecb0e487d24d92511f91 100644 (file)
#include "fstext/fstext-utils.h"
#include "fstext/kaldi-fst-io.h"
-/* some test examples:
- ( echo 3; echo 4) > /tmp/in.list
- ( echo "0 0 1 1"; echo " 0 0 3 2"; echo "0 0"; ) | fstcompile | fstrmsymbols /tmp/in.list | fstprint
+namespace fst {
+// we can move these functions elsewhere later, if they are needed in other
+// places.
- cd ~/tmpdir
- while true; do
- fstrand > 1.fst
- fstpredeterminize out.lst 1.fst | fstdeterminizestar | fstrmsymbols out.lst > 2.fst
- fstequivalent --random=true 1.fst 2.fst || echo "Test failed"
- echo -n "."
- done
+// Deletes from *fst every arc whose input label appears in symbols_in.
+// Matching arcs are first redirected to a freshly added dead state; the
+// subsequent Connect() call is what actually removes those arcs and the
+// dead state.  Logs a warning if the FST has no states left afterwards.
+template<class Arc, class I>
+void RemoveArcsWithSomeInputSymbols(const std::vector<I> &symbols_in,
+                                    VectorFst<Arc> *fst) {
+  typedef typename Arc::StateId StateId;
+
+  kaldi::ConstIntegerSet<I> labels_to_drop(symbols_in);
+
+  // Take the state count before adding the dead state, so that the loop
+  // below visits only the states that were present on entry.
+  const StateId orig_num_states = fst->NumStates();
+  const StateId sink = fst->AddState();
+
+  for (StateId state = 0; state < orig_num_states; state++) {
+    for (MutableArcIterator<VectorFst<Arc> > aiter(fst, state);
+         !aiter.Done(); aiter.Next()) {
+      const Arc &current = aiter.Value();
+      if (labels_to_drop.count(current.ilabel) == 0)
+        continue;  // input label is not in the list; leave the arc alone.
+      Arc redirected = current;
+      redirected.nextstate = sink;
+      aiter.SetValue(redirected);
+    }
+  }
+
+  // Connect() is what physically removes the redirected arcs and the sink.
+  Connect(fst);
+  if (fst->NumStates() == 0)
+    KALDI_WARN << "After Connect(), fst was empty.";
+}
+
+// For every arc of *fst whose input label appears in symbols_in, replaces
+// the arc weight w by Times(w, Weight(penalty)).  Arcs with other input
+// labels, and the structure of the FST, are left unchanged.
+//   symbols_in  list of input labels whose arcs are to be penalized.
+//   penalty     value used to construct the Weight that is multiplied in.
+//   fst         FST modified in place.
+template<class Arc, class I>
+void PenalizeArcsWithSomeInputSymbols(const std::vector<I> &symbols_in,
+                                      float penalty,
+                                      VectorFst<Arc> *fst) {
+  typedef typename Arc::StateId StateId;
+  typedef typename Arc::Weight Weight;
+
+  // Build the weight once; it is the same for every penalized arc.
+  Weight penalty_weight(penalty);
+
+  kaldi::ConstIntegerSet<I> symbol_set(symbols_in);
+
+  StateId num_states = fst->NumStates();
+  for (StateId s = 0; s < num_states; s++) {
+    for (MutableArcIterator<VectorFst<Arc> > iter(fst, s);
+         !iter.Done(); iter.Next()) {
+      if (symbol_set.count(iter.Value().ilabel) != 0) {
+        // Value() is read-only, so modify a copy and write it back.
+        Arc arc = iter.Value();
+        arc.weight = Times(arc.weight, penalty_weight);
+        iter.SetValue(arc);
+      }
+    }
+  }
+}
+
+}
-*/
+// Command-line entry point.  Reads a list of integer symbols and an FST,
+// then, for arcs carrying one of those symbols, either replaces the symbol
+// with epsilon (default), removes the arc (--remove-arcs=true) or penalizes
+// the arc (--penalty=<float>), and writes the resulting FST.
+// Returns 0 on success, -1 if an exception was caught; exits with status 1
+// after printing the usage message if the argument count is wrong.
int main(int argc, char *argv[]) {
 try {
 using namespace fst;
 using kaldi::int32;
- bool remove_from_output = false;
-
+ bool apply_to_output = false;
+ bool remove_arcs = false;
+ // -infinity is the sentinel meaning "--penalty was not specified".
+ float penalty = -std::numeric_limits<BaseFloat>::infinity();
+
 const char *usage =
- "Replaces a subset of symbols with epsilon, wherever they appear on the input side\n"
- "of an FST (or the output side, with --remove-from-output=true)\n"
+ "With no options, replaces a subset of symbols with epsilon, wherever\n"
+ "they appear on the input side of an FST.\n"
+ "With --remove-arcs=true, will remove arcs that contain these symbols\n"
+ "on the input\n"
+ "With --penalty=<float>, will add the specified penalty to the\n"
+ "cost of any arc that has one of the given symbols on its input side\n"
+ "In all cases, the option --apply-to-output=true (or for\n"
+ "back-compatibility, --remove-from-output=true) makes this apply\n"
+ "to the output side.\n"
 "\n"
- "Usage: fstrmsymbols in-disambig-list [in.fst [out.fst] ]\n"
- "E.g: fstrmsymbols in.list < in.fst > out.fst\n";
+ "Usage: fstrmsymbols [options] <in-disambig-list> [<in.fst> [<out.fst>]]\n"
+ "E.g: fstrmsymbols in.list < in.fst > out.fst\n"
+ "<in-disambig-list> is an rxfilename specifying a file containing list of integers\n"
+ "representing symbols, in text form, one per line.\n";
 ParseOptions po(usage);
- po.Register("remove-from-output", &remove_from_output, "If true, remove these symbols from "
- "the output, not the input, side.");
+ // Old option name, kept as an alias: it sets the same variable as
+ // --apply-to-output.
+ po.Register("remove-from-output", &apply_to_output, "If true, this applies to symbols "
+ "on the output, not the input, side. (For back compatibility; use "
+ "--apply-to-output instead)");
+ po.Register("apply-to-output", &apply_to_output, "If true, this applies to symbols "
+ "on the output, not the input, side.");
+ po.Register("remove-arcs", &remove_arcs, "If true, instead of converting the symbol "
+ "to <eps>, remove the arcs.");
+ po.Register("penalty", &penalty, "If specified, instead of converting "
+ "the symbol to <eps>, penalize the arc it is on by adding this "
+ "value to its cost.");
+
+
 po.Read(argc, argv);
+ if (remove_arcs &&
+ penalty != -std::numeric_limits<BaseFloat>::infinity())
+ KALDI_ERR << "--remove-arcs and --penalty options are mutually exclusive";
+
 if (po.NumArgs() < 1 || po.NumArgs() > 3) {
 po.PrintUsage();
 exit(1);
 }
-
+
 std::string disambig_rxfilename = po.GetArg(1),
 fst_rxfilename = po.GetOptArg(2),
 fst_wxfilename = po.GetOptArg(3);
- VectorFst<StdArc> *fst = ReadFstKaldi(fst_rxfilename);
-
+ VectorFst<StdArc> *fst = CastOrConvertToVectorFst(
+ ReadFstKaldiGeneric(fst_rxfilename));
+
 std::vector<int32> disambig_in;
 if (!ReadIntegerVectorSimple(disambig_rxfilename, &disambig_in))
 KALDI_ERR << "fstrmsymbols: Could not read disambiguation symbols from "
 << (disambig_rxfilename == "" ? "standard input" : disambig_rxfilename);
- if (remove_from_output) Invert(fst);
- RemoveSomeInputSymbols(disambig_in, fst);
- if (remove_from_output) Invert(fst);
-
+ // The three routines below all act on input labels; inverting the FST
+ // before and after makes them act on the output side instead.
+ if (apply_to_output) Invert(fst);
+ if (remove_arcs) {
+ RemoveArcsWithSomeInputSymbols(disambig_in, fst);
+ } else if (penalty != -std::numeric_limits<BaseFloat>::infinity()) {
+ PenalizeArcsWithSomeInputSymbols(disambig_in, penalty, fst);
+ } else {
+ RemoveSomeInputSymbols(disambig_in, fst);
+ }
+ if (apply_to_output) Invert(fst);
+
 WriteFstKaldi(*fst, fst_wxfilename);
 delete fst;
- return 0;
+ return 0;
 } catch(const std::exception &e) {
 std::cerr << e.what();
 return -1;
 }
}
+/* some test examples:
+
+ ( echo "0 0 1 1"; echo " 0 0 3 2"; echo "0 0"; ) | fstcompile | fstrmsymbols "echo 3; echo 4|" | fstprint
+ # should produce:
+ # 0 0 1 1
+ # 0 0 0 2
+ # 0
+
+ ( echo "0 0 1 1"; echo " 0 0 3 2"; echo "0 0"; ) | fstcompile | fstrmsymbols --apply-to-output=true "echo 2; echo 3|" | fstprint
+ # should produce:
+ # 0 0 1 1
+ # 0 0 3 0
+ # 0
+
+
+ ( echo "0 0 1 1"; echo " 0 0 3 2"; echo "0 0"; ) | fstcompile | fstrmsymbols --remove-arcs=true "echo 3; echo 4|" | fstprint
+ # should produce:
+ # 0 0 1 1
+ # 0
+
+ ( echo "0 0 1 1"; echo " 0 0 3 2"; echo "0 0"; ) | fstcompile | fstrmsymbols --penalty=2 "echo 3; echo 4; echo 5|" | fstprint
+ # should produce:
+ # 0 0 1 1
+ # 0 0 3 2 2
+ # 0
+
+*/