// nnet/nnet-component.h

// Copyright 2011-2013  Brno University of Technology (Author: Karel Vesely)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.



#ifndef KALDI_NNET_NNET_COMPONENT_H_
#define KALDI_NNET_NNET_COMPONENT_H_

// Standard headers for the std:: names used directly in this header
// (std::istream/std::ostream, std::string, std::vector).
#include <iostream>
#include <string>
#include <vector>

#include "base/kaldi-common.h"
#include "matrix/matrix-lib.h"
#include "cudamatrix/cu-matrix.h"
#include "cudamatrix/cu-vector.h"
#include "nnet/nnet-trnopts.h"

namespace kaldi {
namespace nnet1 {


/**
38 39 40 41
 * Abstract class, building block of the network.
 * It is able to propagate (PropagateFnc: compute the output based on its input)
 * and backpropagate (BackpropagateFnc: i.e. transform loss derivative w.r.t. output to derivative w.r.t. the input)
 * the formulas are implemented in descendant classes (AffineTransform,Sigmoid,Softmax,...).
42 43 44
 */ 
class Component {

45
 /// Component type identification mechanism
46
 public: 
47
  /// Types of Components
48 49 50 51
  typedef enum {
    kUnknown = 0x0,
     
    kUpdatableComponent = 0x0100, 
52
    kAffineTransform,
53
    kLinearTransform,
54
    kConvolutionalComponent,
55
    kConvolutional2DComponent,
56
    kLstmProjectedStreams,
57
    kBLstmProjectedStreams,
58 59 60

    kActivationFunction = 0x0200, 
    kSoftmax, 
61
    kBlockSoftmax, 
62
    kSigmoid,
63 64
    kTanh,
    kDropout,
65

Karel Vesely's avatar
Karel Vesely committed
66
    kTranform = 0x0400,
67
    kRbm,
68
    kSplice,
69 70 71
    kCopy,
    kTranspose,
    kBlockLinearity,
72
    kAddShift,
David Imseng's avatar
David Imseng committed
73
    kRescale,
74 75 76 77
    
    kKlHmm = 0x0800,
    kSentenceAveragingComponent,
    kAveragePoolingComponent,
78
    kAveragePooling2DComponent,
79
    kMaxPoolingComponent,
80
    kMaxPooling2DComponent,
81
    kFramePoolingComponent, 
82
    kParallelComponent
83
  } ComponentType;
84
  /// A pair of type and marker 
85 86
  struct key_value {
    const Component::ComponentType key;
87
    const char *value;
88
  };
89
  /// Mapping of types and markers (the table is defined in nnet-component.cc) 
90 91 92
  static const struct key_value kMarkerMap[];
  /// Convert component type to marker
  static const char* TypeToMarker(ComponentType t);
93
  /// Convert marker to component type (case insensitive)
94
  static ComponentType MarkerToType(const std::string &s);
nichongjia's avatar
nichongjia committed
95 96
  /// during training of LSTM models.
  virtual void SetSeqLengths(std::vector<int> &sequence_lengths) { }
97 98
 
 /// General interface of a component  
99
 public:
100 101 102 103 104 105 106
  Component(int32 input_dim, int32 output_dim) 
      : input_dim_(input_dim), output_dim_(output_dim) { }
  virtual ~Component() { }

  /// Copy component (deep copy).
  virtual Component* Copy() const = 0;

107 108 109 110 111 112 113 114
  /// Get Type Identification of the component
  virtual ComponentType GetType() const = 0;  
  /// Check if contains trainable parameters 
  virtual bool IsUpdatable() const { 
    return false; 
  }

  /// Get size of input vectors
115
  int32 InputDim() const { 
116 117 118
    return input_dim_; 
  }  
  /// Get size of output vectors 
119
  int32 OutputDim() const { 
120 121 122
    return output_dim_; 
  }
 
123
  /// Perform forward pass propagation Input->Output
124
  void Propagate(const CuMatrixBase<BaseFloat> &in, CuMatrix<BaseFloat> *out); 
125
  /// Perform backward pass propagation, out_diff -> in_diff
126
  /// '&in' and '&out' will sometimes be unused... 
127 128 129
  void Backpropagate(const CuMatrixBase<BaseFloat> &in,
                     const CuMatrixBase<BaseFloat> &out,
                     const CuMatrixBase<BaseFloat> &out_diff,
130
                     CuMatrix<BaseFloat> *in_diff); 
131

132 133
  /// Initialize component from a line in config file
  static Component* Init(const std::string &conf_line);
134
  /// Read component from stream
135
  static Component* Read(std::istream &is, bool binary);
136
  /// Write component to stream
137
  void Write(std::ostream &os, bool binary) const;
138

139 140
  /// Optionally print some additional info
  virtual std::string Info() const { return ""; }
Karel Vesely's avatar
Karel Vesely committed
141
  virtual std::string InfoGradient() const { return ""; }
142

143

144
 /// Abstract interface for propagation/backpropagation 
145
 protected:
146
  /// Forward pass transformation (to be implemented by descending class...)
147 148
  virtual void PropagateFnc(const CuMatrixBase<BaseFloat> &in,
                            CuMatrixBase<BaseFloat> *out) = 0;
149
  /// Backward pass transformation (to be implemented by descending class...)
150 151 152 153
  virtual void BackpropagateFnc(const CuMatrixBase<BaseFloat> &in,
                                const CuMatrixBase<BaseFloat> &out,
                                const CuMatrixBase<BaseFloat> &out_diff,
                                CuMatrixBase<BaseFloat> *in_diff) = 0;
154

155 156 157
  /// Initialize internal data of a component
  virtual void InitData(std::istream &is) { }

158
  /// Reads the component content
159
  virtual void ReadData(std::istream &is, bool binary) { }
160 161

  /// Writes the component content
162
  virtual void WriteData(std::ostream &os, bool binary) const { }
163

164
 /// Data members
165
 protected:
166 167
  int32 input_dim_;  ///< Size of input vectors
  int32 output_dim_; ///< Size of output vectors
168 169 170 171 172

 private:
  /// Create new intance of component
  static Component* NewComponentOfType(ComponentType t, 
                      int32 input_dim, int32 output_dim);
173
  
174 175
 protected:
  //KALDI_DISALLOW_COPY_AND_ASSIGN(Component);
176 177 178 179
};


/**
180 181
 * Class UpdatableComponent is a Component which has trainable parameters,
 * contains SGD training hyper-parameters in NnetTrainOptions.
182 183 184
 */
class UpdatableComponent : public Component {
 public: 
185 186
  UpdatableComponent(int32 input_dim, int32 output_dim)
    : Component(input_dim, output_dim) { }
187 188 189 190 191 192 193
  virtual ~UpdatableComponent() { }

  /// Check if contains trainable parameters 
  bool IsUpdatable() const { 
    return true; 
  }

Karel Vesely's avatar
Karel Vesely committed
194 195 196 197
  /// Number of trainable parameters
  virtual int32 NumParams() const = 0;
  virtual void GetParams(Vector<BaseFloat> *params) const = 0;

198
  /// Compute gradient and update parameters
199 200
  virtual void Update(const CuMatrixBase<BaseFloat> &input,
                      const CuMatrixBase<BaseFloat> &diff) = 0;
201

202
  /// Sets the training options to the component
203
  virtual void SetTrainOptions(const NnetTrainOptions &opts) {
204
    opts_ = opts;
205
  }
206 207 208
  /// Gets the training options from the component
  const NnetTrainOptions& GetTrainOptions() const { 
    return opts_; 
209 210
  }

211 212
  virtual void InitData(std::istream &is) = 0;

213
 protected:
214 215
  /// Option-class with training hyper-parameters
  NnetTrainOptions opts_; 
216 217 218
};


219
inline void Component::Propagate(const CuMatrixBase<BaseFloat> &in,
220
                                 CuMatrix<BaseFloat> *out) {
221
  // Check the dims
222
  if (input_dim_ != in.NumCols()) {
223 224
    KALDI_ERR << "Non-matching dims! " << TypeToMarker(GetType()) 
              << " input-dim : " << input_dim_ << " data : " << in.NumCols();
225
  }
226
  // Allocate target buffer
227
  out->Resize(in.NumRows(), output_dim_, kSetZero); // reset
228
  // Call the propagation implementation of the component
229 230 231 232
  PropagateFnc(in, out);
}


233 234 235
inline void Component::Backpropagate(const CuMatrixBase<BaseFloat> &in,
                                     const CuMatrixBase<BaseFloat> &out,
                                     const CuMatrixBase<BaseFloat> &out_diff,
236
                                     CuMatrix<BaseFloat> *in_diff) {
237
  // Check the dims
238
  if (output_dim_ != out_diff.NumCols()) {
239
    KALDI_ERR << "Non-matching output dims, component:" << output_dim_ 
240
              << " data:" << out_diff.NumCols();
241
  }
242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262
  
  // Target buffer NULL : backpropagate only through components with nested nnets.
  if (in_diff == NULL) {
    if (GetType() == kParallelComponent ||
        GetType() == kSentenceAveragingComponent) {
      BackpropagateFnc(in, out, out_diff, NULL);
    } else {
      return;
    }
  } else {
    // Allocate target buffer
    in_diff->Resize(out_diff.NumRows(), input_dim_, kSetZero); // reset
    // Asserts on the dims
    KALDI_ASSERT((in.NumRows() == out.NumRows()) &&
                 (in.NumRows() == out_diff.NumRows()) &&
                 (in.NumRows() == in_diff->NumRows()));
    KALDI_ASSERT(in.NumCols() == in_diff->NumCols());
    KALDI_ASSERT(out.NumCols() == out_diff.NumCols());
    // Call the backprop implementation of the component
    BackpropagateFnc(in, out, out_diff, in_diff);
  }
263 264 265
}


} // namespace nnet1
} // namespace kaldi


#endif  // KALDI_NNET_NNET_COMPONENT_H_