decoder-wrappers.h 8.05 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40
// decoder/decoder-wrappers.h

// Copyright   2014  Johns Hopkins University (author: Daniel Povey)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_DECODER_DECODER_WRAPPERS_H_
#define KALDI_DECODER_DECODER_WRAPPERS_H_

#include "itf/options-itf.h"
#include "decoder/lattice-faster-decoder.h"
#include "decoder/lattice-simple-decoder.h"

// This header contains declarations of various convenience functions that are
// called from binary-level programs such as gmm-decode-faster.cc,
// gmm-align-compiled.cc, and so on.

namespace kaldi {


struct AlignConfig {
  BaseFloat beam;
  BaseFloat retry_beam;
  bool careful;

  AlignConfig(): beam(200.0), retry_beam(0.0), careful(false) { }

41 42 43 44 45 46 47
  void Register(OptionsItf *opts) {
    opts->Register("beam", &beam, "Decoding beam used in alignment");
    opts->Register("retry-beam", &retry_beam,
                   "Decoding beam for second try at alignment");
    opts->Register("careful", &careful,
                   "If true, do 'careful' alignment, which is better at detecting "
                   "alignment failure (involves loop to start of decoding graph).");
48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197
  }
};


/// AlignUtteranceWrapper is a wrapper for the alignment code used in training;
/// it is called from many different binaries, e.g. gmm-align,
/// gmm-align-compiled, sgmm-align, etc.
///
/// Output: alignment_writer and scores_writer will only be written to if they
/// are open.
///
/// Statistics: num_done, num_error, num_retried, tot_like and frame_count
/// will (if non-NULL) be incremented or added to -- never set -- by this
/// function.
///
/// @param config          Alignment options (see AlignConfig).
/// @param utt             Utterance name (presumably the key under which the
///                        outputs are written -- confirm in the .cc).
/// @param acoustic_scale  Affects the scores written to scores_writer, if
///                        present.
/// @param fst             Decoding graph.  Non-const because, when
///                        config.careful == true, a loop is added to it.
/// @param decodable       Not const, but is really an input.
void AlignUtteranceWrapper(const AlignConfig &config,
                           const std::string &utt,
                           BaseFloat acoustic_scale,
                           fst::VectorFst<fst::StdArc> *fst,
                           DecodableInterface *decodable,
                           Int32VectorWriter *alignment_writer,
                           BaseFloatWriter *scores_writer,
                           int32 *num_done,
                           int32 *num_error,
                           int32 *num_retried,
                           double *tot_like,
                           int64 *frame_count);



/// Modifies the decoding graph for what we call "careful alignment".
///
/// The problem: if the decoding eats up the words in the graph too fast, it
/// can get stuck at the end and produce what looks like a valid alignment
/// even though there was really a failure.
///
/// The remedy: after the final-states of the graph we introduce a "blind
/// alley", with no final-probs reachable, where the decoding can go to get
/// lost.  The basic idea is to append the decoding graph to itself using the
/// fst Concat operation.  So that there are final-probs at the end of the
/// first copy but not the second, the right-hand argument to Concat is
/// modified so that it has none of the original final-probs, and a
/// "pre-initial" state that is final is added.
void ModifyGraphForCarefulAlignment(fst::VectorFst<fst::StdArc> *fst);


/// DecodeUtteranceLatticeFaster is used in several decoders, so it has been
/// moved here.  Note: this is really "binary-level" code, since it involves
/// table readers and writers; it lives here only because there is no other
/// obvious place to put it.
///
/// Lattice output goes to lattice_writer when determinize == false, and to
/// compact_lattice_writer otherwise.  alignments_writer and words_writer are
/// only written to if they are open.  On success, the utterance's likelihood
/// is put in *like_ptr.
///
/// `decoder` and `decodable` are non-const references, but are really inputs.
///
/// NOTE(review): the bool result presumably reports success -- confirm in
/// decoder-wrappers.cc.
bool DecodeUtteranceLatticeFaster(LatticeFasterDecoder &decoder,
                                  DecodableInterface &decodable,
                                  const TransitionModel &trans_model,
                                  const fst::SymbolTable *word_syms,
                                  std::string utt,
                                  double acoustic_scale,
                                  bool determinize,
                                  bool allow_partial,
                                  Int32VectorWriter *alignments_writer,
                                  Int32VectorWriter *words_writer,
                                  CompactLatticeWriter *compact_lattice_writer,
                                  LatticeWriter *lattice_writer,
                                  double *like_ptr);

/// This class basically does the same job as the function
/// DecodeUtteranceLatticeFaster, but in a way that allows us
/// to build a multi-threaded command line program more easily,
/// using code in ../thread/kaldi-task-sequence.h.  The main
/// computation takes place in operator (), and the output happens
/// in the destructor.
///
/// An object of this class takes ownership of the decoder and decodable
/// passed to its constructor and deletes them in its destructor.  For that
/// reason copying and assignment are disabled: a compiler-generated copy
/// would make two objects delete the same pointers.
class DecodeUtteranceLatticeFasterClass {
 public:
  // The constructor sets various member variables.
  // NOTE: we "take ownership" of "decoder" and "decodable".  These
  // are deleted by the destructor.  On error, "num_err" is incremented.
  DecodeUtteranceLatticeFasterClass(
      LatticeFasterDecoder *decoder,
      DecodableInterface *decodable,
      const TransitionModel &trans_model,
      const fst::SymbolTable *word_syms,
      std::string utt,
      BaseFloat acoustic_scale,
      bool determinize,
      bool allow_partial,
      Int32VectorWriter *alignments_writer,
      Int32VectorWriter *words_writer,
      CompactLatticeWriter *compact_lattice_writer,
      LatticeWriter *lattice_writer,
      double *like_sum, // on success, adds likelihood to this.
      int64 *frame_sum, // on success, adds #frames to this.
      int32 *num_done, // on success (including partial decode), increments this.
      int32 *num_err,  // on failure, increments this.
      int32 *num_partial);  // If partial decode (final-state not reached), increments this.
  void operator () (); // The decoding happens here.
  ~DecodeUtteranceLatticeFasterClass(); // Output happens here.
 private:
  // Copying is disabled: these are declared but deliberately never defined.
  // The destructor deletes decoder_ and decodable_, so an implicit copy
  // would cause them to be deleted twice.
  DecodeUtteranceLatticeFasterClass(
      const DecodeUtteranceLatticeFasterClass &other);
  DecodeUtteranceLatticeFasterClass &operator = (
      const DecodeUtteranceLatticeFasterClass &other);

  // The following variables correspond to inputs:
  LatticeFasterDecoder *decoder_;
  DecodableInterface *decodable_;
  const TransitionModel *trans_model_;
  const fst::SymbolTable *word_syms_;
  std::string utt_;
  BaseFloat acoustic_scale_;
  bool determinize_;
  bool allow_partial_;
  Int32VectorWriter *alignments_writer_;
  Int32VectorWriter *words_writer_;
  CompactLatticeWriter *compact_lattice_writer_;
  LatticeWriter *lattice_writer_;
  double *like_sum_;
  int64 *frame_sum_;
  int32 *num_done_;
  int32 *num_err_;
  int32 *num_partial_;

  // The following variables are stored by the computation.
  bool computed_; // operator () was called.
  bool success_; // decoding succeeded (possibly partial)
  bool partial_; // decoding was partial.
  CompactLattice *clat_; // Stored output, if determinize_ == true.
  Lattice *lat_; // Stored output, if determinize_ == false.
};

/// DecodeUtteranceLatticeSimple is used in several decoders, so it has been
/// moved here.  Note: this is really "binary-level" code, since it involves
/// table readers and writers; it lives here only because there is no other
/// obvious place to put it.
///
/// Lattice output goes to lattice_writer when determinize == false, and to
/// compact_lattice_writer otherwise.  alignments_writer and words_writer are
/// only written to if they are open.  On success, the utterance's likelihood
/// is put in *like_ptr.
///
/// `decoder` and `decodable` are non-const references, but are really inputs.
///
/// NOTE(review): the bool result presumably reports success -- confirm in
/// decoder-wrappers.cc.
bool DecodeUtteranceLatticeSimple(LatticeSimpleDecoder &decoder,
                                  DecodableInterface &decodable,
                                  const TransitionModel &trans_model,
                                  const fst::SymbolTable *word_syms,
                                  std::string utt,
                                  double acoustic_scale,
                                  bool determinize,
                                  bool allow_partial,
                                  Int32VectorWriter *alignments_writer,
                                  Int32VectorWriter *words_writer,
                                  CompactLatticeWriter *compact_lattice_writer,
                                  LatticeWriter *lattice_writer,
                                  double *like_ptr);



} // end namespace kaldi.


#endif