Commit 4163db46 authored by nichongjia's avatar nichongjia

blstm remove bug according to karel

parent afe298c3
...@@ -76,7 +76,7 @@ class BLstmProjectedStreams : public UpdatableComponent { ...@@ -76,7 +76,7 @@ class BLstmProjectedStreams : public UpdatableComponent {
} }
/// set the utterance length used for parallel training /// set the utterance length used for parallel training
void SetSeqLengths(std::vector<int> &sequence_lengths) { void SetSeqLengths(const std::vector<int32> &sequence_lengths) {
sequence_lengths_ = sequence_lengths; sequence_lengths_ = sequence_lengths;
} }
...@@ -1055,7 +1055,7 @@ class BLstmProjectedStreams : public UpdatableComponent { ...@@ -1055,7 +1055,7 @@ class BLstmProjectedStreams : public UpdatableComponent {
int32 ncell_; ///< the number of cell blocks int32 ncell_; ///< the number of cell blocks
int32 nrecur_; ///< recurrent projection layer dim int32 nrecur_; ///< recurrent projection layer dim
int32 nstream_; int32 nstream_;
std::vector<int> sequence_lengths_; std::vector<int32> sequence_lengths_;
// gradient-clipping value, // gradient-clipping value,
BaseFloat clip_gradient_; BaseFloat clip_gradient_;
......
...@@ -91,9 +91,7 @@ class Component { ...@@ -91,9 +91,7 @@ class Component {
/// Convert component type to marker /// Convert component type to marker
static const char* TypeToMarker(ComponentType t); static const char* TypeToMarker(ComponentType t);
/// Convert marker to component type (case insensitive) /// Convert marker to component type (case insensitive)
static ComponentType MarkerToType(const std::string &s); static ComponentType MarkerToType(const std::string &s);
/// during training of LSTM models.
virtual void SetSeqLengths(std::vector<int> &sequence_lengths) { }
/// General interface of a component /// General interface of a component
public: public:
......
...@@ -360,6 +360,14 @@ void Nnet::ResetLstmStreams(const std::vector<int32> &stream_reset_flag) { ...@@ -360,6 +360,14 @@ void Nnet::ResetLstmStreams(const std::vector<int32> &stream_reset_flag) {
} }
} }
void Nnet::SetSeqLengths(const std::vector<int32> &sequence_lengths) {
for (int32 c=0; c < NumComponents(); c++) {
if (GetComponent(c).GetType() == Component::kBLstmProjectedStreams) {
BLstmProjectedStreams& comp = dynamic_cast<BLstmProjectedStreams&>(GetComponent(c));
comp.SetSeqLengths(sequence_lengths);
}
}
}
void Nnet::Init(const std::string &file) { void Nnet::Init(const std::string &file) {
Input in(file); Input in(file);
......
...@@ -101,6 +101,9 @@ class Nnet { ...@@ -101,6 +101,9 @@ class Nnet {
/// Reset streams in LSTM multi-stream training, /// Reset streams in LSTM multi-stream training,
void ResetLstmStreams(const std::vector<int32> &stream_reset_flag); void ResetLstmStreams(const std::vector<int32> &stream_reset_flag);
/// set sequence length in LSTM multi-stream training
void SetSeqLengths(const std::vector<int32> &sequence_lengths);
/// Initialize MLP from config /// Initialize MLP from config
void Init(const std::string &config_file); void Init(const std::string &config_file);
/// Read the MLP from file (can add layers to exisiting instance of Nnet) /// Read the MLP from file (can add layers to exisiting instance of Nnet)
...@@ -130,13 +133,7 @@ class Nnet { ...@@ -130,13 +133,7 @@ class Nnet {
/// Get training hyper-parameters from the network /// Get training hyper-parameters from the network
const NnetTrainOptions& GetTrainOptions() const { const NnetTrainOptions& GetTrainOptions() const {
return opts_; return opts_;
} }
/// Set lengths of utterances for LSTM parallel training
void SetSeqLengths(std::vector<int> &sequence_lengths) {
for(int32 i=0; i < (int32)components_.size(); i++) {
components_[i]->SetSeqLengths(sequence_lengths);
}
}
private: private:
/// Vector which contains all the components composing the neural network, /// Vector which contains all the components composing the neural network,
......
...@@ -10,7 +10,7 @@ BINFILES = nnet-train-frmshuff \ ...@@ -10,7 +10,7 @@ BINFILES = nnet-train-frmshuff \
nnet-train-perutt \ nnet-train-perutt \
nnet-train-mmi-sequential \ nnet-train-mmi-sequential \
nnet-train-mpe-sequential \ nnet-train-mpe-sequential \
nnet-train-lstm-streams nnet-train-blstm-parallel \ nnet-train-lstm-streams nnet-train-blstm-streams \
rbm-train-cd1-frmshuff rbm-convert-to-nnet \ rbm-train-cd1-frmshuff rbm-convert-to-nnet \
nnet-forward nnet-copy nnet-info nnet-concat \ nnet-forward nnet-copy nnet-info nnet-concat \
transf-to-nnet cmvn-to-nnet nnet-initialize \ transf-to-nnet cmvn-to-nnet nnet-initialize \
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment