// fstbin/fstrmsymbols.cc // Copyright 2009-2011 Microsoft Corporation // See ../../COPYING for clarification regarding multiple authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, // MERCHANTABLITY OR NON-INFRINGEMENT. // See the Apache 2 License for the specific language governing permissions and // limitations under the License. #include "base/kaldi-common.h" #include "util/common-utils.h" #include "fst/fstlib.h" #include "fstext/determinize-star.h" #include "fstext/fstext-utils.h" #include "fstext/kaldi-fst-io.h" namespace fst { // we can move these functions elsewhere later, if they are needed in other // places. template void RemoveArcsWithSomeInputSymbols(const std::vector &symbols_in, VectorFst *fst) { typedef typename Arc::StateId StateId; kaldi::ConstIntegerSet symbol_set(symbols_in); StateId num_states = fst->NumStates(); StateId dead_state = fst->AddState(); for (StateId s = 0; s < num_states; s++) { for (MutableArcIterator > iter(fst, s); !iter.Done(); iter.Next()) { if (symbol_set.count(iter.Value().ilabel) != 0) { Arc arc = iter.Value(); arc.nextstate = dead_state; iter.SetValue(arc); } } } // Connect() will actually remove the arcs, and the dead state. Connect(fst); if (fst->NumStates() == 0) KALDI_WARN << "After Connect(), fst was empty."; } template void PenalizeArcsWithSomeInputSymbols(const std::vector &symbols_in, float penalty, VectorFst *fst) { typedef typename Arc::StateId StateId; typedef typename Arc::Label Label; typedef typename Arc::Weight Weight; Weight penalty_weight(penalty); kaldi::ConstIntegerSet symbol_set(symbols_in); StateId num_states = fst->NumStates(); for (StateId s = 0; s < num_states; s++) { for (MutableArcIterator > iter(fst, s); !iter.Done(); iter.Next()) { if (symbol_set.count(iter.Value().ilabel) != 0) { Arc arc = iter.Value(); arc.weight = Times(arc.weight, penalty_weight); iter.SetValue(arc); } } } } } int main(int argc, char *argv[]) { try { using namespace kaldi; using namespace fst; using kaldi::int32; bool apply_to_output = false; bool remove_arcs = false; float penalty = -std::numeric_limits::infinity(); const char *usage = "With no options, replaces a subset of symbols with epsilon, wherever\n" "they appear on the input side of an FST." "With --remove-arcs=true, will remove arcs that contain these symbols\n" "on the input\n" "With --penalty=, will add the specified penalty to the\n" "cost of any arc that has one of the given symbols on its input side\n" "In all cases, the option --apply-to-output=true (or for\n" "back-compatibility, --remove-from-output=true) makes this apply\n" "to the output side.\n" "\n" "Usage: fstrmsymbols [options] [ []]\n" "E.g: fstrmsymbols in.list < in.fst > out.fst\n" " is an rxfilename specifying a file containing list of integers\n" "representing symbols, in text form, one per line.\n"; ParseOptions po(usage); po.Register("remove-from-output", &apply_to_output, "If true, this applies to symbols " "on the output, not the input, side. (For back compatibility; use " "--apply-to-output insead)"); po.Register("apply-to-output", &apply_to_output, "If true, this applies to symbols " "on the output, not the input, side."); po.Register("remove-arcs", &remove_arcs, "If true, instead of converting the symbol " "to , remove the arcs."); po.Register("penalty", &penalty, "If specified, instead of converting " "the symbol to , penalize the arc it is on by adding this " "value to its cost."); po.Read(argc, argv); if (remove_arcs && penalty != -std::numeric_limits::infinity()) KALDI_ERR << "--remove-arc and --penalty options are mutually exclusive"; if (po.NumArgs() < 1 || po.NumArgs() > 3) { po.PrintUsage(); exit(1); } std::string disambig_rxfilename = po.GetArg(1), fst_rxfilename = po.GetOptArg(2), fst_wxfilename = po.GetOptArg(3); VectorFst *fst = ReadFstKaldi(fst_rxfilename); std::vector disambig_in; if (!ReadIntegerVectorSimple(disambig_rxfilename, &disambig_in)) KALDI_ERR << "fstrmsymbols: Could not read disambiguation symbols from " << (disambig_rxfilename == "" ? "standard input" : disambig_rxfilename); if (apply_to_output) Invert(fst); if (remove_arcs) { RemoveArcsWithSomeInputSymbols(disambig_in, fst); } else if (penalty != -std::numeric_limits::infinity()) { PenalizeArcsWithSomeInputSymbols(disambig_in, penalty, fst); } else { RemoveSomeInputSymbols(disambig_in, fst); } if (apply_to_output) Invert(fst); WriteFstKaldi(*fst, fst_wxfilename); delete fst; return 0; } catch(const std::exception &e) { std::cerr << e.what(); return -1; } } /* some test examples: ( echo "0 0 1 1"; echo " 0 0 3 2"; echo "0 0"; ) | fstcompile | fstrmsymbols "echo 3; echo 4|" | fstprint # should produce: # 0 0 1 1 # 0 0 0 2 # 0 ( echo "0 0 1 1"; echo " 0 0 3 2"; echo "0 0"; ) | fstcompile | fstrmsymbols --apply-to-output=true "echo 2; echo 3|" | fstprint # should produce: # 0 0 1 1 # 0 0 3 0 # 0 ( echo "0 0 1 1"; echo " 0 0 3 2"; echo "0 0"; ) | fstcompile | fstrmsymbols --remove-arcs=true "echo 3; echo 4|" | fstprint # should produce: # 0 0 1 1 # 0 ( echo "0 0 1 1"; echo " 0 0 3 2"; echo "0 0"; ) | fstcompile | fstrmsymbols --penalty=2 "echo 3; echo 4; echo 5|" | fstprint # should produce: # 0 0 1 1 # 0 0 3 2 2 # 0 */