// nnet/nnet-nnet.h

// Copyright 2011-2013  Brno University of Technology (Author: Karel Vesely)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_NNET_NNET_NNET_H_
#define KALDI_NNET_NNET_NNET_H_

#include <iostream>
#include <sstream>
#include <vector>

#include "base/kaldi-common.h"
#include "util/kaldi-io.h"
#include "matrix/matrix-lib.h"
#include "nnet/nnet-trnopts.h"
#include "nnet/nnet-component.h"

namespace kaldi {
namespace nnet1 {

class Nnet {
 public:
  Nnet() {}
  Nnet(const Nnet& other); // Copy constructor.
  Nnet &operator = (const Nnet& other); // Assignment operator.

  ~Nnet(); 

 public:
  /// Perform forward pass through the network
  void Propagate(const CuMatrixBase<BaseFloat> &in, CuMatrix<BaseFloat> *out); 
  /// Perform backward pass through the network
  void Backpropagate(const CuMatrixBase<BaseFloat> &out_diff, CuMatrix<BaseFloat> *in_diff);
  /// Perform forward pass through the network without keeping the buffers (use when not training)
  void Feedforward(const CuMatrixBase<BaseFloat> &in, CuMatrix<BaseFloat> *out); 
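
  // A minimal usage sketch (assuming 'nnet' is a loaded Nnet, and 'feats' and
  // 'obj_diff' are CuMatrix<BaseFloat> objects prepared elsewhere: input features
  // and the objective-function derivative w.r.t. the network output):
  //
  //   CuMatrix<BaseFloat> out, in_diff;
  //   nnet.Propagate(feats, &out);             // forward pass, buffers are kept
  //   nnet.Backpropagate(obj_diff, &in_diff);  // backward pass (training)
  //   nnet.Feedforward(feats, &out);           // forward only, no buffers (testing)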

  /// Dimensionality of network input (input feature dim.)
  int32 InputDim() const; 
  /// Dimensionality of network outputs (posteriors | bn-features | etc.)
  int32 OutputDim() const; 

  /// Returns the number of components; think of this as similar to the number
  /// of layers, but e.g. the nonlinearity and the linear part count as
  /// separate components, so the number of components exceeds the number of layers.
  int32 NumComponents() const { return components_.size(); }

  const Component& GetComponent(int32 c) const;
  Component& GetComponent(int32 c);

  /// Sets the c'th component to "component", taking ownership of the pointer
  /// and deleting the corresponding one that we own.
  void SetComponent(int32 c, Component *component);
 
  /// Appends this component to the components already in the neural net.
  /// Takes ownership of the pointer
  void AppendComponent(Component *dynamically_allocated_comp);
  /// Append another network to the current one (copy components).
  void AppendNnet(const Nnet& nnet_to_append);

  /// Remove the c'th component
  void RemoveComponent(int32 c);
  void RemoveLastComponent() { RemoveComponent(NumComponents()-1); }
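
  // A sketch of building a small network by appending components; it assumes
  // the AffineTransform, Sigmoid and Softmax components declared in
  // nnet-affine-transform.h / nnet-activation.h, whose constructors take
  // (input_dim, output_dim). The Nnet takes ownership of the raw pointers.
  //
  //   Nnet nnet;
  //   nnet.AppendComponent(new AffineTransform(40, 100));
  //   nnet.AppendComponent(new Sigmoid(100, 100));
  //   nnet.AppendComponent(new AffineTransform(100, 10));
  //   nnet.AppendComponent(new Softmax(10, 10));
  //   KALDI_ASSERT(nnet.NumComponents() == 4);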

  /// Access to forward pass buffers
  const std::vector<CuMatrix<BaseFloat> >& PropagateBuffer() const { 
    return propagate_buf_; 
  }
  /// Access to backward pass buffers
  const std::vector<CuMatrix<BaseFloat> >& BackpropagateBuffer() const { 
    return backpropagate_buf_; 
  }
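
  // For example, after a Propagate() call one can inspect the per-component
  // activations (a hypothetical diagnostic loop):
  //
  //   const std::vector<CuMatrix<BaseFloat> >& bufs = nnet.PropagateBuffer();
  //   for (size_t i = 0; i < bufs.size(); i++) {
  //     KALDI_LOG << "buffer " << i << ": " << bufs[i].NumRows()
  //               << "x" << bufs[i].NumCols();
  //   }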

  /// Get the number of parameters in the network
  int32 NumParams() const;
  /// Get the network parameters in a supervector
  void GetParams(Vector<BaseFloat>* wei_copy) const;
  /// Get the network weights in a supervector
  void GetWeights(Vector<BaseFloat>* wei_copy) const;
  /// Set the network weights from a supervector
  void SetWeights(const Vector<BaseFloat>& wei_src);
  /// Get the gradient stored in the network
  void GetGradient(Vector<BaseFloat>* grad_copy) const;
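
  // A sketch of round-tripping the parameters through a supervector, e.g. for
  // optimizers or diagnostics that work on a flat parameter vector:
  //
  //   Vector<BaseFloat> params;
  //   nnet.GetParams(&params);
  //   KALDI_ASSERT(params.Dim() == nnet.NumParams());
  //   Vector<BaseFloat> wei;
  //   nnet.GetWeights(&wei);   // copy the weights out
  //   nnet.SetWeights(wei);    // write the same values back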

  /// Set the dropout retention rate (probability of keeping a unit active)
  void SetDropoutRetention(BaseFloat r);
  /// Reset streams in LSTM multi-stream training
  void ResetLstmStreams(const std::vector<int32> &stream_reset_flag);

  /// Set sequence lengths in LSTM multi-stream training
  void SetSeqLengths(const std::vector<int32> &sequence_lengths);
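
  // A sketch of typical test-time / multi-stream settings; 'num_streams' and
  // 'max_frames' are hypothetical values chosen by the caller:
  //
  //   nnet.SetDropoutRetention(1.0);  // keep every unit, i.e. no dropout at test time
  //   std::vector<int32> reset_flag(num_streams, 1);
  //   nnet.ResetLstmStreams(reset_flag);  // clear LSTM state, e.g. at utterance starts
  //   std::vector<int32> seq_lengths(num_streams, max_frames);
  //   nnet.SetSeqLengths(seq_lengths);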

  /// Initialize MLP from config
  void Init(const std::string &config_file);
  /// Read the MLP from file (can add layers to an existing instance of Nnet)
  void Read(const std::string &file);  
  /// Read the MLP from stream (can add layers to an existing instance of Nnet)
  void Read(std::istream &in, bool binary);  
  /// Write MLP to file
  void Write(const std::string &file, bool binary) const;
  /// Write MLP to stream 
  void Write(std::ostream &out, bool binary) const;   
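
  // A sketch of the common I/O calls; the filenames are hypothetical:
  //
  //   Nnet nnet;
  //   nnet.Init("nnet.proto");        // build a new network from a prototype/config, or
  //   // nnet.Read("final.nnet");     // load an already trained model
  //   nnet.Write("nnet.out", true);   // save, binary mode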
  
  /// Create string with human-readable description of the nnet
  std::string Info() const;
  /// Create string with per-component gradient statistics
  std::string InfoGradient() const;
  /// Create string with propagation-buffer statistics
  std::string InfoPropagate() const;
  /// Create string with back-propagation-buffer statistics
  std::string InfoBackPropagate() const;
  /// Consistency check.
  void Check() const;
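
  // For example, after modifying the network one can verify and summarize it:
  //
  //   nnet.Check();              // fails with an error on inconsistent dimensions
  //   KALDI_LOG << nnet.Info();  // topology and parameter statistics
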
  /// Release the memory
  void Destroy();

  /// Set training hyper-parameters for the network and its UpdatableComponent(s)
  void SetTrainOptions(const NnetTrainOptions& opts);
  /// Get training hyper-parameters from the network
  const NnetTrainOptions& GetTrainOptions() const {
    return opts_;
  }
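
  // A sketch of setting the hyper-parameters, assuming the learn_rate and
  // momentum fields declared in nnet-trnopts.h:
  //
  //   NnetTrainOptions trn_opts;
  //   trn_opts.learn_rate = 0.008;
  //   trn_opts.momentum = 0.9;
  //   nnet.SetTrainOptions(trn_opts);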

 private:
  /// Vector which contains all the components composing the neural network;
  /// the components are e.g. AffineTransform, Sigmoid, Softmax
  std::vector<Component*> components_; 

  std::vector<CuMatrix<BaseFloat> > propagate_buf_; ///< buffers for forward pass
  std::vector<CuMatrix<BaseFloat> > backpropagate_buf_; ///< buffers for backward pass

  /// Option class with hyper-parameters passed to UpdatableComponent(s)
  NnetTrainOptions opts_;
};
  

} // namespace nnet1
} // namespace kaldi

#endif  // KALDI_NNET_NNET_NNET_H_