Commit 4163db46 authored by nichongjia's avatar nichongjia

blstm remove bug according to karel

parent afe298c3
......@@ -76,7 +76,7 @@ class BLstmProjectedStreams : public UpdatableComponent {
}
/// set the utterance length used for parallel training
void SetSeqLengths(std::vector<int> &sequence_lengths) {
void SetSeqLengths(const std::vector<int32> &sequence_lengths) {
sequence_lengths_ = sequence_lengths;
}
......@@ -1055,7 +1055,7 @@ class BLstmProjectedStreams : public UpdatableComponent {
int32 ncell_; ///< the number of cell blocks
int32 nrecur_; ///< recurrent projection layer dim
int32 nstream_;
std::vector<int> sequence_lengths_;
std::vector<int32> sequence_lengths_;
// gradient-clipping value,
BaseFloat clip_gradient_;
......
......@@ -92,8 +92,6 @@ class Component {
static const char* TypeToMarker(ComponentType t);
/// Convert marker to component type (case insensitive)
static ComponentType MarkerToType(const std::string &s);
/// during training of LSTM models.
virtual void SetSeqLengths(std::vector<int> &sequence_lengths) { }
/// General interface of a component
public:
......
......@@ -360,6 +360,14 @@ void Nnet::ResetLstmStreams(const std::vector<int32> &stream_reset_flag) {
}
}
void Nnet::SetSeqLengths(const std::vector<int32> &sequence_lengths) {
for (int32 c=0; c < NumComponents(); c++) {
if (GetComponent(c).GetType() == Component::kBLstmProjectedStreams) {
BLstmProjectedStreams& comp = dynamic_cast<BLstmProjectedStreams&>(GetComponent(c));
comp.SetSeqLengths(sequence_lengths);
}
}
}
void Nnet::Init(const std::string &file) {
Input in(file);
......
......@@ -101,6 +101,9 @@ class Nnet {
/// Reset streams in LSTM multi-stream training,
void ResetLstmStreams(const std::vector<int32> &stream_reset_flag);
/// set sequence length in LSTM multi-stream training
void SetSeqLengths(const std::vector<int32> &sequence_lengths);
/// Initialize MLP from config
void Init(const std::string &config_file);
/// Read the MLP from file (can add layers to exisiting instance of Nnet)
......@@ -131,12 +134,6 @@ class Nnet {
const NnetTrainOptions& GetTrainOptions() const {
return opts_;
}
/// Set lengths of utterances for LSTM parallel training
void SetSeqLengths(std::vector<int> &sequence_lengths) {
for(int32 i=0; i < (int32)components_.size(); i++) {
components_[i]->SetSeqLengths(sequence_lengths);
}
}
private:
/// Vector which contains all the components composing the neural network,
......
......@@ -10,7 +10,7 @@ BINFILES = nnet-train-frmshuff \
nnet-train-perutt \
nnet-train-mmi-sequential \
nnet-train-mpe-sequential \
nnet-train-lstm-streams nnet-train-blstm-parallel \
nnet-train-lstm-streams nnet-train-blstm-streams \
rbm-train-cd1-frmshuff rbm-convert-to-nnet \
nnet-forward nnet-copy nnet-info nnet-concat \
transf-to-nnet cmvn-to-nnet nnet-initialize \
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment