// fstext/rescale-test.cc

// Copyright 2009-2011  Microsoft Corporation

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "fstext/rescale.h"

#include <cassert>
#include <iostream>

#include "fstext/fstext-utils.h"
#include "fstext/fst-test-utils.h"
#include "base/kaldi-math.h"
// Randomized tests for ComputeTotalWeight() and RescaleToStochastic().

namespace fst
{

29 30
  using kaldi::Exp;
  using kaldi::Log;
31 32 33 34 35 36 37

template<class Arc> void TestComputeTotalWeight() {
  typedef typename Arc::Weight Weight;
  VectorFst<Arc> *fst = RandFst<Arc>();

  std::cout <<" printing FST at start\n";
  {
38 39 40
#ifdef HAVE_OPENFST_GE_10400
    FstPrinter<Arc> fstprinter(*fst, NULL, NULL, NULL, false, true, "\t");
#else
41
    FstPrinter<Arc> fstprinter(*fst, NULL, NULL, NULL, false, true);
42
#endif
43 44 45
    fstprinter.Print(&std::cout, "standard output");
  }

46
  Weight max(-Log(2.0));
47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
  Weight tot = ComputeTotalWeight(*fst, max);
  std::cout << "Total weight is: " << tot.Value() << '\n';


  if (tot.Value() > max.Value()) {  // didn't max out...
    Weight tot2 = ShortestDistance(*fst);
    if (!ApproxEqual(tot, tot2, 0.05)) {
      KALDI_ERR << tot << " differs from " << tot2;
      assert(0);
    }
    std::cout << "our tot: " <<tot.Value() <<", shortest-distance tot: " << tot2.Value() << '\n';
  }

  delete fst;
}



void TestRescaleToStochastic() {
  typedef LogArc Arc;
  typedef Arc::Weight Weight;
  RandFstOptions opts;
  opts.allow_empty = false;
  VectorFst<Arc> *fst = RandFst<Arc>(opts);

  std::cout <<" printing FST at start\n";
  {
74 75 76
#ifdef HAVE_OPENFST_GE_10400
    FstPrinter<Arc> fstprinter(*fst, NULL, NULL, NULL, false, true, "\t");
#else
77
    FstPrinter<Arc> fstprinter(*fst, NULL, NULL, NULL, false, true);
78
#endif
79 80 81 82 83
    fstprinter.Print(&std::cout, "standard output");

  }
  float diff = 0.01;

84
  RescaleToStochastic(fst, diff);
85
  Weight tot = ShortestDistance(*fst),
86
      tot2 = ComputeTotalWeight(*fst, Weight(-Log(2.0)));
87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109
  std::cerr <<  " tot is " << tot<<", tot2 = "<<tot2<<'\n';
  assert(ApproxEqual(tot2, Weight::One(), diff));

  delete fst;
}


} // end namespace fst


// Test driver. Runs ten rounds of the ComputeTotalWeight() check, once with
// tropical (StdArc) and once with log (LogArc) arcs per round. Then runs ten
// rounds of the RescaleToStochastic() check.
int main() {
  for (int round = 0; round < 10; ++round) {
    std::cout << "Testing with tropical\n";
    fst::TestComputeTotalWeight<fst::StdArc>();
    std::cout << "Testing with log:\n";
    fst::TestComputeTotalWeight<fst::LogArc>();
  }

  int iter = 0;
  while (iter < 10) {
    std::cout << "i = " << iter << '\n';
    fst::TestRescaleToStochastic();
    ++iter;
  }
}