Commit 13c78d3b authored by Karel Vesely

Merge pull request #66 from nichongjia/blstm

blstm remove bug
parents 9c257c5a 7afae3f8
make_lstm_proto.py
@@ -58,17 +58,17 @@ print "<NnetProto>"
 # normally we won't use more than 2 layers of LSTM
 if o.num_layers == 1:
     print "<BLstmProjectedStreams> <InputDim> %d <OutputDim> %d <CellDim> %s <ParamScale> %f <ClipGradient> %f" % \
-        (feat_dim, o.num_recurrent, o.num_cells, o.lstm_stddev_factor, o.clip_gradient)
+        (feat_dim, 2*o.num_recurrent, o.num_cells, o.lstm_stddev_factor, o.clip_gradient)
 elif o.num_layers == 2:
     print "<BLstmProjectedStreams> <InputDim> %d <OutputDim> %d <CellDim> %s <ParamScale> %f <ClipGradient> %f" % \
-        (feat_dim, o.num_recurrent, o.num_cells, o.lstm_stddev_factor, o.clip_gradient)
+        (feat_dim, 2*o.num_recurrent, o.num_cells, o.lstm_stddev_factor, o.clip_gradient)
     print "<BLstmProjectedStreams> <InputDim> %d <OutputDim> %d <CellDim> %s <ParamScale> %f <ClipGradient> %f" % \
-        (o.num_recurrent, o.num_recurrent, o.num_cells, o.lstm_stddev_factor, o.clip_gradient)
+        (2*o.num_recurrent, 2*o.num_recurrent, o.num_cells, o.lstm_stddev_factor, o.clip_gradient)
 else:
     sys.stderr.write("make_lstm_proto.py ERROR: more than 2 layers of LSTM, not supported yet.\n")
     sys.exit(1)
 print "<AffineTransform> <InputDim> %d <OutputDim> %d <BiasMean> 0.0 <BiasRange> 0.0 <ParamStddev> %f" % \
-    (o.num_recurrent, num_leaves, o.param_stddev_factor)
+    (2*o.num_recurrent, num_leaves, o.param_stddev_factor)
 print "<Softmax> <InputDim> %d <OutputDim> %d" % \
     (num_leaves, num_leaves)
 print "</NnetProto>"
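The change is consistent throughout: a BLstmProjectedStreams layer concatenates its forward and backward recurrent projections, so every dimension that consumes its output must be 2*num_recurrent, not num_recurrent. As a hedged illustration (the values below are made up, and the default ParamScale/ClipGradient/ParamStddev values are not visible in this diff), a single-layer run with feat_dim=40, num_leaves=2000, num_recurrent=512 and num_cells=800 would now emit a prototype along these lines:

<NnetProto>
<BLstmProjectedStreams> <InputDim> 40 <OutputDim> 1024 <CellDim> 800 <ParamScale> 0.010000 <ClipGradient> 5.000000
<AffineTransform> <InputDim> 1024 <OutputDim> 2000 <BiasMean> 0.0 <BiasRange> 0.0 <ParamStddev> 0.040000
<Softmax> <InputDim> 2000 <OutputDim> 2000
</NnetProto>

Before the fix the prototype declared 512 where the concatenated forward+backward output is actually 1024 wide, so the declared dimensions did not line up with what the component produces.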
(A large file diff is collapsed here and not shown.)
src/nnet/nnet-component.cc
@@ -120,7 +120,7 @@ Component* Component::NewComponentOfType(ComponentType comp_type,
       ans = new LstmProjectedStreams(input_dim, output_dim);
       break;
     case Component::kBLstmProjectedStreams :
       ans = new BLstmProjectedStreams(input_dim, output_dim);
       break;
     case Component::kSoftmax :
       ans = new Softmax(input_dim, output_dim);
src/nnet/nnet-component.h
@@ -91,7 +91,7 @@ class Component {
   /// Convert component type to marker
   static const char* TypeToMarker(ComponentType t);
   /// Convert marker to component type (case insensitive)
   static ComponentType MarkerToType(const std::string &s);

   /// General interface of a component
  public:
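These two hunks touch the component factory and its marker helpers, the path by which a kBLstmProjectedStreams entry in a model file becomes a live object. A minimal sketch of that path, assuming the usual (type, input_dim, output_dim) factory signature whose first argument is visible in the hunk header (the marker string and dimensions are illustrative):

// Sketch only: turning a marker read from a model file into a component.
Component::ComponentType t =
    Component::MarkerToType("<BLstmProjectedStreams>");  // case-insensitive lookup
// For a BLSTM layer the output is the concatenated forward+backward
// projection, i.e. twice the per-direction size (here 2*512 = 1024).
Component *comp = Component::NewComponentOfType(t, /*input_dim=*/40,
                                                /*output_dim=*/1024);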
src/nnet/nnet-nnet.cc
@@ -32,7 +32,7 @@ namespace nnet1 {

 Nnet::Nnet(const Nnet& other) {
   // copy the components
-  for(int32 i=0; i<other.NumComponents(); i++) {
+  for(int32 i = 0; i < other.NumComponents(); i++) {
     components_.push_back(other.GetComponent(i).Copy());
   }
   // create empty buffers
@@ -40,13 +40,13 @@ Nnet::Nnet(const Nnet& other) {
   backpropagate_buf_.resize(NumComponents()+1);
   // copy train opts
   SetTrainOptions(other.opts_);
   Check();
 }

 Nnet & Nnet::operator = (const Nnet& other) {
   Destroy();
   // copy the components
-  for(int32 i=0; i<other.NumComponents(); i++) {
+  for(int32 i = 0; i < other.NumComponents(); i++) {
     components_.push_back(other.GetComponent(i).Copy());
   }
   // create empty buffers
@@ -356,15 +356,19 @@ void Nnet::ResetLstmStreams(const std::vector<int32> &stream_reset_flag) {
   for (int32 c=0; c < NumComponents(); c++) {
     if (GetComponent(c).GetType() == Component::kLstmProjectedStreams) {
       LstmProjectedStreams& comp = dynamic_cast<LstmProjectedStreams&>(GetComponent(c));
       comp.ResetLstmStreams(stream_reset_flag);
     }
+  }
+}
+void Nnet::SetSeqLengths(const std::vector<int32> &sequence_lengths) {
+  for (int32 c=0; c < NumComponents(); c++) {
     if (GetComponent(c).GetType() == Component::kBLstmProjectedStreams) {
       BLstmProjectedStreams& comp = dynamic_cast<BLstmProjectedStreams&>(GetComponent(c));
-      comp.ResetLstmStreams(stream_reset_flag);
+      comp.SetSeqLengths(sequence_lengths);
     }
   }
 }

 void Nnet::Init(const std::string &file) {
   Input in(file);
   std::istream &is = in.Stream();
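The reorganization separates stream resetting from sequence-length bookkeeping: the unidirectional LstmProjectedStreams keeps its per-stream reset flags, while BLstmProjectedStreams gets SetSeqLengths, since a bidirectional layer runs its backward recurrence from the end of each utterance and therefore needs the complete lengths before propagation. A hedged sketch of how a multi-stream trainer would drive the new call (nnet, feats, obj_diff and the helper are illustrative names, not code from nnet-train-blstm-streams):

// One utterance per stream; all streams padded to the longest utterance.
std::vector<int32> seq_lengths(num_streams);
for (int32 s = 0; s < num_streams; s++) {
  seq_lengths[s] = LengthOfUtteranceInStream(s);  // hypothetical helper
}
nnet.SetSeqLengths(seq_lengths);     // BLSTM needs full lengths up front
nnet.Propagate(feats, &nnet_out);    // runs forward and backward directions
nnet.Backpropagate(obj_diff, NULL);  // gradient pass over both directions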
src/nnet/nnet-nnet.h
@@ -36,23 +36,23 @@ namespace nnet1 {

 class Nnet {
  public:
   Nnet() {}
   Nnet(const Nnet& other);  // Copy constructor.
   Nnet &operator = (const Nnet& other);  // Assignment operator.
   ~Nnet();

  public:
   /// Perform forward pass through the network
   void Propagate(const CuMatrixBase<BaseFloat> &in, CuMatrix<BaseFloat> *out);
   /// Perform backward pass through the network
   void Backpropagate(const CuMatrixBase<BaseFloat> &out_diff, CuMatrix<BaseFloat> *in_diff);
   /// Perform forward pass through the network, don't keep buffers (use it when not training)
   void Feedforward(const CuMatrixBase<BaseFloat> &in, CuMatrix<BaseFloat> *out);

   /// Dimensionality on network input (input feature dim.)
   int32 InputDim() const;
   /// Dimensionality of network outputs (posteriors | bn-features | etc.)
   int32 OutputDim() const;

   /// Returns number of components-- think of this as similar to # of layers, but
   /// e.g. the nonlinearity and the linear part count as separate components,
@@ -65,7 +65,7 @@ class Nnet {
   /// Sets the c'th component to "component", taking ownership of the pointer
   /// and deleting the corresponding one that we own.
   void SetComponent(int32 c, Component *component);

   /// Appends this component to the components already in the neural net.
   /// Takes ownership of the pointer
   void AppendComponent(Component *dynamically_allocated_comp);
@@ -77,12 +77,12 @@ class Nnet {
   void RemoveLastComponent() { RemoveComponent(NumComponents()-1); }

   /// Access to forward pass buffers
   const std::vector<CuMatrix<BaseFloat> >& PropagateBuffer() const {
     return propagate_buf_;
   }
   /// Access to backward pass buffers
   const std::vector<CuMatrix<BaseFloat> >& BackpropagateBuffer() const {
     return backpropagate_buf_;
   }

   /// Get the number of parameters in the network
@@ -96,22 +96,25 @@ class Nnet {
   /// Get the gradient stored in the network
   void GetGradient(Vector<BaseFloat>* grad_copy) const;

   /// Set the dropout rate
   void SetDropoutRetention(BaseFloat r);

   /// Reset streams in LSTM multi-stream training,
   void ResetLstmStreams(const std::vector<int32> &stream_reset_flag);
+  /// set sequence length in LSTM multi-stream training
+  void SetSeqLengths(const std::vector<int32> &sequence_lengths);
+
   /// Initialize MLP from config
   void Init(const std::string &config_file);
   /// Read the MLP from file (can add layers to exisiting instance of Nnet)
   void Read(const std::string &file);
   /// Read the MLP from stream (can add layers to exisiting instance of Nnet)
   void Read(std::istream &in, bool binary);
   /// Write MLP to file
   void Write(const std::string &file, bool binary) const;
   /// Write MLP to stream
   void Write(std::ostream &out, bool binary) const;

   /// Create string with human readable description of the nnet
   std::string Info() const;
   /// Create string with per-component gradient statistics
@@ -135,18 +138,17 @@ class Nnet {
  private:
   /// Vector which contains all the components composing the neural network,
   /// the components are for example: AffineTransform, Sigmoid, Softmax
   std::vector<Component*> components_;

   std::vector<CuMatrix<BaseFloat> > propagate_buf_;  ///< buffers for forward pass
   std::vector<CuMatrix<BaseFloat> > backpropagate_buf_;  ///< buffers for backward pass

   /// Option class with hyper-parameters passed to UpdatableComponent(s)
   NnetTrainOptions opts_;
 };

 }  // namespace nnet1
 }  // namespace kaldi

 #endif  // KALDI_NNET_NNET_NNET_H_
src/nnetbin/Makefile
@@ -10,7 +10,7 @@ BINFILES = nnet-train-frmshuff \
            nnet-train-perutt \
            nnet-train-mmi-sequential \
            nnet-train-mpe-sequential \
-           nnet-train-lstm-streams \
+           nnet-train-lstm-streams nnet-train-blstm-streams \
            rbm-train-cd1-frmshuff rbm-convert-to-nnet \
            nnet-forward nnet-copy nnet-info nnet-concat \
            transf-to-nnet cmvn-to-nnet nnet-initialize \
(A large file diff is collapsed here and not shown.)