// nnet/nnet-nnet.h

// Copyright 2011-2013  Brno University of Technology (Author: Karel Vesely)

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_NNET_NNET_NNET_H_
#define KALDI_NNET_NNET_NNET_H_

#include <iostream>
#include <sstream>
#include <vector>

#include "base/kaldi-common.h"
#include "util/kaldi-io.h"
#include "matrix/matrix-lib.h"
#include "nnet/nnet-trnopts.h"
#include "nnet/nnet-component.h"

namespace kaldi {
namespace nnet1 {

class Nnet {
 public:
  Nnet() {}
  Nnet(const Nnet& other);  // Copy constructor.
  Nnet &operator = (const Nnet& other); // Assignment operator.

  ~Nnet();

 public:
  /// Perform forward pass through the network
  void Propagate(const CuMatrixBase<BaseFloat> &in, CuMatrix<BaseFloat> *out);
  /// Perform backward pass through the network
  void Backpropagate(const CuMatrixBase<BaseFloat> &out_diff, CuMatrix<BaseFloat> *in_diff);
  /// Perform forward pass through the network without keeping the buffers (use when not training)
  void Feedforward(const CuMatrixBase<BaseFloat> &in, CuMatrix<BaseFloat> *out);
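
  // Illustrative usage sketch (not part of the original header): one training
  // step with the forward/backward interface above, assuming an initialized
  // 'nnet', input features 'feats' and an objective-function derivative
  // 'obj_diff' prepared by the caller:
  //
  //   CuMatrix<BaseFloat> nnet_out, in_diff;
  //   nnet.Propagate(feats, &nnet_out);        // forward pass, buffers are kept
  //   nnet.Backpropagate(obj_diff, &in_diff);  // backward pass through all components
  //   // evaluation / decoding, no buffers are kept:
  //   nnet.Feedforward(feats, &nnet_out);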

  /// Dimensionality of the network input (input feature dim.)
  int32 InputDim() const;
  /// Dimensionality of network outputs (posteriors | bn-features | etc.)
  int32 OutputDim() const;

  /// Returns the number of components. This is similar to the number of layers,
  /// but e.g. the nonlinearity and the linear part count as separate components,
  /// so the number of components will be larger than the number of layers.
  int32 NumComponents() const { return components_.size(); }

  const Component& GetComponent(int32 c) const;
  Component& GetComponent(int32 c);

  /// Sets the c'th component to "component", taking ownership of the pointer
  /// and deleting the corresponding one that we own.
  void SetComponent(int32 c, Component *component);

  /// Appends the component to the components already in the neural net.
  /// Takes ownership of the pointer.
  void AppendComponent(Component *dynamically_allocated_comp);
  /// Append another network to the current one (copy components).
  void AppendNnet(const Nnet& nnet_to_append);

  /// Remove component
  void RemoveComponent(int32 c);
  void RemoveLastComponent() { RemoveComponent(NumComponents()-1); }
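
  // Illustrative sketch (assumption, not from the original header): building a
  // small MLP by appending dynamically allocated components.  The component
  // types are those listed for 'components_' below (AffineTransform, Sigmoid,
  // Softmax); their (dim_in, dim_out) constructor arguments are assumed here:
  //
  //   Nnet nnet;
  //   nnet.AppendComponent(new AffineTransform(440, 1024));
  //   nnet.AppendComponent(new Sigmoid(1024, 1024));
  //   nnet.AppendComponent(new AffineTransform(1024, 3000));
  //   nnet.AppendComponent(new Softmax(3000, 3000));
  //   // 'nnet' now owns the pointers and releases them in its destructor.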

  /// Access to forward pass buffers
  const std::vector<CuMatrix<BaseFloat> >& PropagateBuffer() const {
    return propagate_buf_;
  }
  /// Access to backward pass buffers
  const std::vector<CuMatrix<BaseFloat> >& BackpropagateBuffer() const {
    return backpropagate_buf_;
  }

  /// Get the number of parameters in the network
  int32 NumParams() const;
  /// Get the network parameters in a supervector
  void GetParams(Vector<BaseFloat>* wei_copy) const;
  /// Get the network weights in a supervector
  void GetWeights(Vector<BaseFloat>* wei_copy) const;
  /// Set the network weights from a supervector
  void SetWeights(const Vector<BaseFloat>& wei_src);
  /// Get the gradient stored in the network
  void GetGradient(Vector<BaseFloat>* grad_copy) const;
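
  // Illustrative sketch (assumption): a round-trip through the supervector
  // interface above; 'wei' holds a copy of the weights, which are modified
  // and written back:
  //
  //   Vector<BaseFloat> wei;
  //   nnet.GetWeights(&wei);   // copy the weights out
  //   wei.Scale(0.5);          // e.g. shrink all the weights
  //   nnet.SetWeights(wei);    // write the modified weights back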

  /// Set the dropout retention rate (1.0 means no dropout)
  void SetDropoutRetention(BaseFloat r);
  /// Reset streams in LSTM multi-stream training.
  void ResetLstmStreams(const std::vector<int32> &stream_reset_flag);

  /// Set the sequence lengths in LSTM multi-stream training.
  void SetSeqLengths(const std::vector<int32> &sequence_lengths);

  /// Initialize MLP from config
  void Init(const std::string &config_file);
  /// Read the MLP from file (can add layers to existing instance of Nnet)
  void Read(const std::string &file);
  /// Read the MLP from stream (can add layers to existing instance of Nnet)
  void Read(std::istream &in, bool binary);
  /// Write MLP to file
  void Write(const std::string &file, bool binary) const;
  /// Write MLP to stream
  void Write(std::ostream &out, bool binary) const;
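
  // Illustrative sketch (assumption): typical model I/O with the file-based
  // interface above ('final.nnet' is a hypothetical filename):
  //
  //   Nnet nnet;
  //   nnet.Read("final.nnet");              // load an existing model
  //   // ... train / modify the network ...
  //   nnet.Write("final_new.nnet", true);   // save it in binary format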

  /// Create a string with a human-readable description of the nnet
  std::string Info() const;
  /// Create string with per-component gradient statistics
  std::string InfoGradient() const;
  /// Create string with propagation-buffer statistics
  std::string InfoPropagate() const;
  /// Create string with back-propagation-buffer statistics
  std::string InfoBackPropagate() const;
  /// Consistency check.
  void Check() const;
  /// Release the memory
  void Destroy();

  /// Set the training hyper-parameters of the network and its UpdatableComponent(s)
  void SetTrainOptions(const NnetTrainOptions& opts);
  /// Get training hyper-parameters from the network
  const NnetTrainOptions& GetTrainOptions() const {
    return opts_;
136
  }
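
  // Illustrative sketch (assumption): passing hyper-parameters to the
  // UpdatableComponent(s); 'learn_rate' and 'momentum' are assumed members
  // of NnetTrainOptions (see nnet/nnet-trnopts.h):
  //
  //   NnetTrainOptions trn_opts;
  //   trn_opts.learn_rate = 0.008;
  //   trn_opts.momentum = 0.0;
  //   nnet.SetTrainOptions(trn_opts);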

 private:
  /// Vector which contains all the components composing the neural network;
  /// the components are, for example, AffineTransform, Sigmoid, Softmax.
  std::vector<Component*> components_;

  std::vector<CuMatrix<BaseFloat> > propagate_buf_;  ///< buffers for forward pass
  std::vector<CuMatrix<BaseFloat> > backpropagate_buf_;  ///< buffers for backward pass

  /// Option class with hyper-parameters passed to UpdatableComponent(s)
  NnetTrainOptions opts_;
};

}  // namespace nnet1
}  // namespace kaldi

#endif  // KALDI_NNET_NNET_NNET_H_