index 98b99f49a6ffaa2115f27559d3db4f96f9f946cb..a365b016e5815bc76c88ecb0e487d24d92511f91 100644 (file)
// Copyright 2009-2011 Microsoft Corporation
+// See ../../COPYING for clarification regarding multiple authors
+//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
#include "fst/fstlib.h"
#include "fstext/determinize-star.h"
#include "fstext/fstext-utils.h"
+#include "fstext/kaldi-fst-io.h"
+
+namespace fst {
+// we can move these functions elsewhere later, if they are needed in other
+// places.
+
+// Deletes from *fst every arc whose input label appears in 'symbols_in'.
+// Matching arcs are first re-pointed at a freshly added state that has no
+// outgoing arcs and is not final; Connect() then trims that state together
+// with everything leading only to it.  Note that Connect() may also trim
+// other states that are not on a successful path.  A warning is logged if
+// the FST ends up with no states at all.
+template<class Arc, class I>
+void RemoveArcsWithSomeInputSymbols(const std::vector<I> &symbols_in,
+                                    VectorFst<Arc> *fst) {
+  typedef typename Arc::StateId StateId;
+  typedef MutableArcIterator<VectorFst<Arc> > ArcIter;
+
+  // Set of the labels whose arcs are to be removed.
+  kaldi::ConstIntegerSet<I> labels_to_drop(symbols_in);
+
+  // Take the state count before appending the sink, so that the loop below
+  // only walks the states that were originally present.
+  StateId orig_num_states = fst->NumStates();
+  StateId sink = fst->AddState();
+
+  for (StateId state = 0; state < orig_num_states; state++) {
+    for (ArcIter aiter(fst, state); !aiter.Done(); aiter.Next()) {
+      if (labels_to_drop.count(aiter.Value().ilabel) == 0)
+        continue;  // label is not in the list: leave this arc alone.
+      Arc redirected = aiter.Value();
+      redirected.nextstate = sink;
+      aiter.SetValue(redirected);
+    }
+  }
+
+  // The redirected arcs and the sink state are physically removed here.
+  Connect(fst);
+  if (fst->NumStates() == 0) {
+    KALDI_WARN << "After Connect(), fst was empty.";
+  }
+}
+// For every arc of *fst whose input label appears in 'symbols_in', replaces
+// the arc weight w by Times(w, Weight(penalty)).  Arcs with other input
+// labels, and the topology of the FST, are left unchanged.
+// (In the tropical semiring, Times() is addition of costs, which matches the
+// "add this value to its cost" wording of the --penalty option.)
+template<class Arc, class I>
+void PenalizeArcsWithSomeInputSymbols(const std::vector<I> &symbols_in,
+                                      float penalty,
+                                      VectorFst<Arc> *fst) {
+  typedef typename Arc::StateId StateId;
+  typedef typename Arc::Weight Weight;
+
+  // Constructed once, since the same penalty applies to every matching arc.
+  const Weight penalty_weight(penalty);
+
+  kaldi::ConstIntegerSet<I> symbol_set(symbols_in);
+
+  const StateId num_states = fst->NumStates();
+  for (StateId s = 0; s < num_states; s++) {
+    for (MutableArcIterator<VectorFst<Arc> > iter(fst, s);
+         !iter.Done(); iter.Next()) {
+      if (symbol_set.count(iter.Value().ilabel) == 0)
+        continue;  // not one of the symbols to penalize.
+      // Value() cannot be modified in place, so copy, edit and write back.
+      Arc arc = iter.Value();
+      arc.weight = Times(arc.weight, penalty_weight);
+      iter.SetValue(arc);
+    }
+  }
+}
-/* some test examples:
- ( echo 3; echo 4) > /tmp/in.list
- ( echo "0 0 1 1"; echo " 0 0 3 2"; echo "0 0"; ) | fstcompile | fstrmsymbols /tmp/in.list | fstprint
-
- cd ~/tmpdir
- while true; do
- fstrand > 1.fst
- fstpredeterminize out.lst 1.fst | fstdeterminizestar | fstrmsymbols out.lst > 2.fst
- fstequivalent --random=true 1.fst 2.fst || echo "Test failed"
- echo -n "."
- done
+}
-*/
-int main(int argc, char *argv[])
-{
+// Entry point of fstrmsymbols.  Reads a list of integer symbols and an FST,
+// then, for arcs carrying one of those symbols on the input side (or on the
+// output side with --apply-to-output=true), does exactly one of:
+//   - replaces the symbol with epsilon (default),
+//   - removes the arc (--remove-arcs=true), or
+//   - penalizes the arc (--penalty=<float>),
+// and writes the resulting FST.  Returns 0 on success and -1 if an exception
+// is caught; exits with status 1 on a wrong number of arguments.
+int main(int argc, char *argv[]) {
try {
using namespace kaldi;
using namespace fst;
using kaldi::int32;
+ bool apply_to_output = false;
+ bool remove_arcs = false;
+ // -infinity acts as a sentinel meaning "--penalty was not specified".
+ float penalty = -std::numeric_limits<BaseFloat>::infinity();
+
const char *usage =
- "Replaces a subset of symbols with epsilon, wherever they appear on the input side\n"
- "of an FST.\n"
+ "With no options, replaces a subset of symbols with epsilon, wherever\n"
+ "they appear on the input side of an FST.\n"
+ "With --remove-arcs=true, will remove arcs that contain these symbols\n"
+ "on the input\n"
+ "With --penalty=<float>, will add the specified penalty to the\n"
+ "cost of any arc that has one of the given symbols on its input side\n"
+ "In all cases, the option --apply-to-output=true (or for\n"
+ "back-compatibility, --remove-from-output=true) makes this apply\n"
+ "to the output side.\n"
"\n"
- "Usage: fstrmsymbols in-disambig-list [in.fst [out.fst] ]\n"
- "E.g: fstrmsymbols in.list < in.fst > out.fst\n";
+ "Usage: fstrmsymbols [options] <in-disambig-list> [<in.fst> [<out.fst>]]\n"
+ "E.g: fstrmsymbols in.list < in.fst > out.fst\n"
+ "<in-disambig-list> is an rxfilename specifying a file containing list of integers\n"
+ "representing symbols, in text form, one per line.\n";
ParseOptions po(usage);
+ // Both option names below write to the same variable; the first is kept
+ // only for back compatibility.
+ po.Register("remove-from-output", &apply_to_output, "If true, this applies to symbols "
+ "on the output, not the input, side. (For back compatibility; use "
+ "--apply-to-output instead)");
+ po.Register("apply-to-output", &apply_to_output, "If true, this applies to symbols "
+ "on the output, not the input, side.");
+ po.Register("remove-arcs", &remove_arcs, "If true, instead of converting the symbol "
+ "to <eps>, remove the arcs.");
+ po.Register("penalty", &penalty, "If specified, instead of converting "
+ "the symbol to <eps>, penalize the arc it is on by adding this "
+ "value to its cost.");
+
+
po.Read(argc, argv);
+ if (remove_arcs &&
+ penalty != -std::numeric_limits<BaseFloat>::infinity())
+ KALDI_ERR << "--remove-arcs and --penalty options are mutually exclusive";
+
if (po.NumArgs() < 1 || po.NumArgs() > 3) {
po.PrintUsage();
exit(1);
}
- std::string disambig_in_filename = po.GetArg(1);
- if (disambig_in_filename == "-") disambig_in_filename = "";
-
- std::string fst_in_filename;
- fst_in_filename = po.GetOptArg(2);
- if (fst_in_filename == "-") fst_in_filename = "";
+ std::string disambig_rxfilename = po.GetArg(1),
+ fst_rxfilename = po.GetOptArg(2),
+ fst_wxfilename = po.GetOptArg(3);
- std::string fst_out_filename;
- fst_out_filename = po.GetOptArg(3);
- if (fst_out_filename == "-") fst_out_filename = "";
-
- VectorFst<StdArc> *fst = VectorFst<StdArc>::Read(fst_in_filename);
- if (!fst) {
- std::cerr << "fstrmsymbols: could not read input fst from " << fst_in_filename << '\n';
- return 1;
- }
+ // The FST is held through a raw pointer and released by the 'delete' below.
+ VectorFst<StdArc> *fst = CastOrConvertToVectorFst(
+ ReadFstKaldiGeneric(fst_rxfilename));
std::vector<int32> disambig_in;
- if (!ReadIntegerVectorSimple(disambig_in_filename, &disambig_in)) {
- std::cerr << "fstrmsymbols: Could not read disambiguation symbols from "
- << (disambig_in_filename == "" ? "standard input" : disambig_in_filename)
- << '\n';
- return 1;
+ if (!ReadIntegerVectorSimple(disambig_rxfilename, &disambig_in))
+ KALDI_ERR << "fstrmsymbols: Could not read disambiguation symbols from "
+ << (disambig_rxfilename == "" ? "standard input" : disambig_rxfilename);
+
+ // All three operations work on input labels.  To apply them to the output
+ // side, invert the FST first and invert it back afterwards.
+ if (apply_to_output) Invert(fst);
+ if (remove_arcs) {
+ RemoveArcsWithSomeInputSymbols(disambig_in, fst);
+ } else if (penalty != -std::numeric_limits<BaseFloat>::infinity()) {
+ PenalizeArcsWithSomeInputSymbols(disambig_in, penalty, fst);
+ } else {
+ RemoveSomeInputSymbols(disambig_in, fst);
}
+ if (apply_to_output) Invert(fst);
- RemoveSomeInputSymbols(disambig_in, fst);
+ WriteFstKaldi(*fst, fst_wxfilename);
- if (! fst->Write(fst_out_filename) ) {
- std::cerr << "fstrmsymbols: error writing the output to "<<fst_out_filename << '\n';
- return 1;
- }
delete fst;
- } catch(const std::exception& e) {
+ return 0;
+ } catch(const std::exception &e) {
std::cerr << e.what();
return -1;
}
- return 0;
}
+/* some test examples:
+
+ ( echo "0 0 1 1"; echo " 0 0 3 2"; echo "0 0"; ) | fstcompile | fstrmsymbols "echo 3; echo 4|" | fstprint
+ # should produce:
+ # 0 0 1 1
+ # 0 0 0 2
+ # 0
+
+ ( echo "0 0 1 1"; echo " 0 0 3 2"; echo "0 0"; ) | fstcompile | fstrmsymbols --apply-to-output=true "echo 2; echo 3|" | fstprint
+ # should produce:
+ # 0 0 1 1
+ # 0 0 3 0
+ # 0
+
+
+ ( echo "0 0 1 1"; echo " 0 0 3 2"; echo "0 0"; ) | fstcompile | fstrmsymbols --remove-arcs=true "echo 3; echo 4|" | fstprint
+ # should produce:
+ # 0 0 1 1
+ # 0
+
+ ( echo "0 0 1 1"; echo " 0 0 3 2"; echo "0 0"; ) | fstcompile | fstrmsymbols --penalty=2 "echo 3; echo 4; echo 5|" | fstprint
+ # should produce:
+ # 0 0 1 1
+ # 0 0 3 2 2
+ # 0
+
+*/