diff --git a/egs/wsj/s5/utils/nnet/make_blstm_proto.py b/egs/wsj/s5/utils/nnet/make_blstm_proto.py
old mode 100755
new mode 100644
index ea1ce29522dff2af2b9f78a2019ff4372933b3da..4873d51d67a230edb3fa1bc7eca231d5affc4ad6
--- a/egs/wsj/s5/utils/nnet/make_blstm_proto.py
+++ b/egs/wsj/s5/utils/nnet/make_blstm_proto.py
@@ -58,17 +58,17 @@ print ""
# normally we won't use more than 2 layers of LSTM
if o.num_layers == 1:
print " %d %d %s %f %f" % \
- (feat_dim, o.num_recurrent, o.num_cells, o.lstm_stddev_factor, o.clip_gradient)
+ (feat_dim, 2*o.num_recurrent, o.num_cells, o.lstm_stddev_factor, o.clip_gradient)
elif o.num_layers == 2:
print " %d %d %s %f %f" % \
- (feat_dim, o.num_recurrent, o.num_cells, o.lstm_stddev_factor, o.clip_gradient)
+ (feat_dim, 2*o.num_recurrent, o.num_cells, o.lstm_stddev_factor, o.clip_gradient)
print " %d %d %s %f %f" % \
- (o.num_recurrent, o.num_recurrent, o.num_cells, o.lstm_stddev_factor, o.clip_gradient)
+ (2*o.num_recurrent, 2*o.num_recurrent, o.num_cells, o.lstm_stddev_factor, o.clip_gradient)
else:
sys.stderr.write("make_lstm_proto.py ERROR: more than 2 layers of LSTM, not supported yet.\n")
sys.exit(1)
print " %d %d 0.0 0.0 %f" % \
- (o.num_recurrent, num_leaves, o.param_stddev_factor)
+ (2*o.num_recurrent, num_leaves, o.param_stddev_factor)
print " %d %d" % \
(num_leaves, num_leaves)
print ""
diff --git a/src/nnet/nnet-blstm-projected-streams.h b/src/nnet/nnet-blstm-projected-streams.h
index 05126a2f3536463b318b5784b8bd4be062e0c2e5..72706e54c1f518fd10a79bdb0b1b30a176da7ad4 100644
--- a/src/nnet/nnet-blstm-projected-streams.h
+++ b/src/nnet/nnet-blstm-projected-streams.h
@@ -1,4 +1,4 @@
-// nnet/nnet-lstm-projected-streams.h
+// nnet/nnet-blstm-projected-streams.h
// Copyright 2014 Jiayu DU (Jerry), Wei Li
// Copyright 2015 Chongjia Ni
@@ -49,7 +49,7 @@ class BLstmProjectedStreams : public UpdatableComponent {
BLstmProjectedStreams(int32 input_dim, int32 output_dim) :
UpdatableComponent(input_dim, output_dim),
ncell_(0),
- nrecur_(output_dim),
+ nrecur_(static_cast(output_dim/2)),
nstream_(0),
clip_gradient_(0.0)
//, dropout_rate_(0.0)
@@ -75,7 +75,12 @@ class BLstmProjectedStreams : public UpdatableComponent {
v = tmp;
}
- void InitData(std::istream &is) {
+ /// set the utterance length used for parallel training
+ void SetSeqLengths(const std::vector &sequence_lengths) {
+ sequence_lengths_ = sequence_lengths;
+ }
+
+ void InitData(const std::istream &is) {
// define options
float param_scale = 0.02;
// parse config
@@ -86,7 +91,7 @@ class BLstmProjectedStreams : public UpdatableComponent {
ReadBasicType(is, false, &ncell_);
else if (token == "")
ReadBasicType(is, false, &clip_gradient_);
- //else if (token == "")
+ // else if (token == "")
// ReadBasicType(is, false, &dropout_rate_);
else if (token == "")
ReadBasicType(is, false, ¶m_scale);
@@ -121,7 +126,6 @@ class BLstmProjectedStreams : public UpdatableComponent {
InitVecParam(f_bias_, param_scale);
InitVecParam(b_bias_, param_scale);
- // This is for input gate, forgot gate and output gate connected with the previous cell
// forward direction
f_peephole_i_c_.Resize(ncell_, kUndefined);
f_peephole_f_c_.Resize(ncell_, kUndefined);
@@ -174,8 +178,8 @@ class BLstmProjectedStreams : public UpdatableComponent {
ReadBasicType(is, binary, &ncell_);
ExpectToken(is, binary, "");
ReadBasicType(is, binary, &clip_gradient_);
- //ExpectToken(is, binary, "");
- //ReadBasicType(is, binary, &dropout_rate_);
+ // ExpectToken(is, binary, "");
+ // ReadBasicType(is, binary, &dropout_rate_);
// reading parameters corresponding to forward direction
f_w_gifo_x_.Read(is, binary);
@@ -439,52 +443,9 @@ class BLstmProjectedStreams : public UpdatableComponent {
"\n B_DR " + MomentStatistics(B_DR);
}
-
- void ResetLstmStreams(const std::vector &stream_reset_flag) {
- // allocate f_prev_nnet_state_, b_prev_nnet_state_ if not done yet,
- if (nstream_ == 0) {
- // Karel: we just got number of streams! (before the 1st batch comes)
- nstream_ = stream_reset_flag.size();
- // forward direction
- f_prev_nnet_state_.Resize(nstream_, 7*ncell_ + 1*nrecur_, kSetZero);
- // backward direction
- b_prev_nnet_state_.Resize(nstream_, 7*ncell_ + 1*nrecur_, kSetZero);
- KALDI_LOG << "Running training with " << nstream_ << " streams.";
- }
- // reset flag: 1 - reset stream network state
- KALDI_ASSERT(f_prev_nnet_state_.NumRows() == stream_reset_flag.size());
- KALDI_ASSERT(b_prev_nnet_state_.NumRows() == stream_reset_flag.size());
- for (int s = 0; s < stream_reset_flag.size(); s++) {
- if (stream_reset_flag[s] == 1) {
- // forward direction
- f_prev_nnet_state_.Row(s).SetZero();
- // backward direction
- b_prev_nnet_state_.Row(s).SetZero();
- }
- }
- }
-
-
void PropagateFnc(const CuMatrixBase &in, CuMatrixBase *out) {
int DEBUG = 0;
-
- static bool do_stream_reset = false;
- if (nstream_ == 0) {
- do_stream_reset = true;
- nstream_ = 1; // Karel: we are in nnet-forward, so we will use 1 stream,
- // forward direction
- f_prev_nnet_state_.Resize(nstream_, 7*ncell_ + 1*nrecur_, kSetZero);
- // backward direction
- b_prev_nnet_state_.Resize(nstream_, 7*ncell_ + 1*nrecur_, kSetZero);
- KALDI_LOG << "Running nnet-forward with per-utterance BLSTM-state reset";
- }
- if (do_stream_reset) {
- // resetting the forward and backward streams
- f_prev_nnet_state_.SetZero();
- b_prev_nnet_state_.SetZero();
- }
- KALDI_ASSERT(nstream_ > 0);
-
+ int32 nstream_ = sequence_lengths_.size();
KALDI_ASSERT(in.NumRows() % nstream_ == 0);
int32 T = in.NumRows() / nstream_;
int32 S = nstream_;
@@ -492,12 +453,8 @@ class BLstmProjectedStreams : public UpdatableComponent {
// 0:forward pass history, [1, T]:current sequence, T+1:dummy
// forward direction
f_propagate_buf_.Resize((T+2)*S, 7 * ncell_ + nrecur_, kSetZero);
- f_propagate_buf_.RowRange(0*S,S).CopyFromMat(f_prev_nnet_state_);
-
// backward direction
b_propagate_buf_.Resize((T+2)*S, 7 * ncell_ + nrecur_, kSetZero);
- // for the backward direction, we initialize it at (T+1) frame
- b_propagate_buf_.RowRange((T+1)*S,S).CopyFromMat(b_prev_nnet_state_);
// disassembling forward-pass forward-propagation buffer into different neurons,
CuSubMatrix F_YG(f_propagate_buf_.ColRange(0*ncell_, ncell_));
@@ -525,32 +482,33 @@ class BLstmProjectedStreams : public UpdatableComponent {
// forward direction
// x -> g, i, f, o, not recurrent, do it all in once
- F_YGIFO.RowRange(1*S,T*S).AddMatMat(1.0, in, kNoTrans, f_w_gifo_x_, kTrans, 0.0);
+ F_YGIFO.RowRange(1*S, T*S).AddMatMat(1.0, in, kNoTrans, f_w_gifo_x_, kTrans, 0.0);
// bias -> g, i, f, o
- F_YGIFO.RowRange(1*S,T*S).AddVecToRows(1.0, f_bias_);
+ F_YGIFO.RowRange(1*S, T*S).AddVecToRows(1.0, f_bias_);
for (int t = 1; t <= T; t++) {
// multistream buffers for current time-step
- CuSubMatrix y_g(F_YG.RowRange(t*S,S));
- CuSubMatrix y_i(F_YI.RowRange(t*S,S));
- CuSubMatrix y_f(F_YF.RowRange(t*S,S));
- CuSubMatrix y_o(F_YO.RowRange(t*S,S));
- CuSubMatrix y_c(F_YC.RowRange(t*S,S));
- CuSubMatrix y_h(F_YH.RowRange(t*S,S));
- CuSubMatrix y_m(F_YM.RowRange(t*S,S));
- CuSubMatrix y_r(F_YR.RowRange(t*S,S));
-
- CuSubMatrix y_gifo(F_YGIFO.RowRange(t*S,S));
+ CuSubMatrix y_all(f_propagate_buf_.RowRange(t*S, S));
+ CuSubMatrix y_g(F_YG.RowRange(t*S, S));
+ CuSubMatrix y_i(F_YI.RowRange(t*S, S));
+ CuSubMatrix y_f(F_YF.RowRange(t*S, S));
+ CuSubMatrix y_o(F_YO.RowRange(t*S, S));
+ CuSubMatrix y_c(F_YC.RowRange(t*S, S));
+ CuSubMatrix y_h(F_YH.RowRange(t*S, S));
+ CuSubMatrix y_m(F_YM.RowRange(t*S, S));
+ CuSubMatrix y_r(F_YR.RowRange(t*S, S));
+
+ CuSubMatrix y_gifo(F_YGIFO.RowRange(t*S, S));
// r(t-1) -> g, i, f, o
- y_gifo.AddMatMat(1.0, F_YR.RowRange((t-1)*S,S), kNoTrans, f_w_gifo_r_, kTrans, 1.0);
+ y_gifo.AddMatMat(1.0, F_YR.RowRange((t-1)*S, S), kNoTrans, f_w_gifo_r_, kTrans, 1.0);
// c(t-1) -> i(t) via peephole
- y_i.AddMatDiagVec(1.0, F_YC.RowRange((t-1)*S,S), kNoTrans, f_peephole_i_c_, 1.0);
+ y_i.AddMatDiagVec(1.0, F_YC.RowRange((t-1)*S, S), kNoTrans, f_peephole_i_c_, 1.0);
// c(t-1) -> f(t) via peephole
- y_f.AddMatDiagVec(1.0, F_YC.RowRange((t-1)*S,S), kNoTrans, f_peephole_f_c_, 1.0);
+ y_f.AddMatDiagVec(1.0, F_YC.RowRange((t-1)*S, S), kNoTrans, f_peephole_f_c_, 1.0);
// i, f sigmoid squashing
y_i.Sigmoid(y_i);
@@ -563,7 +521,7 @@ class BLstmProjectedStreams : public UpdatableComponent {
y_c.AddMatMatElements(1.0, y_g, y_i, 0.0);
// c(t-1) -> c(t) via forget-gate
- y_c.AddMatMatElements(1.0, F_YC.RowRange((t-1)*S,S), y_f, 1.0);
+ y_c.AddMatMatElements(1.0, F_YC.RowRange((t-1)*S, S), y_f, 1.0);
y_c.ApplyFloor(-50); // optional clipping of cell activation
y_c.ApplyCeiling(50); // google paper Interspeech2014: LSTM for LVCSR
@@ -583,6 +541,12 @@ class BLstmProjectedStreams : public UpdatableComponent {
// m -> r
y_r.AddMatMat(1.0, y_m, kNoTrans, f_w_r_m_, kTrans, 0.0);
+ // set zeros
+ // for (int s = 0; s < S; s++) {
+ // if (t > sequence_lengths_[s])
+ // y_all.Row(s).SetZero();
+ // }
+
if (DEBUG) {
std::cerr << "forward direction forward-pass frame " << t << "\n";
std::cerr << "activation of g: " << y_g;
@@ -597,43 +561,43 @@ class BLstmProjectedStreams : public UpdatableComponent {
}
// backward direction
- B_YGIFO.RowRange(1*S,T*S).AddMatMat(1.0, in, kNoTrans, b_w_gifo_x_, kTrans, 0.0);
+ B_YGIFO.RowRange(1*S, T*S).AddMatMat(1.0, in, kNoTrans, b_w_gifo_x_, kTrans, 0.0);
//// LSTM forward dropout
//// Google paper 2014: Recurrent Neural Network Regularization
//// by Wojciech Zaremba, Ilya Sutskever, Oriol Vinyals
- //if (dropout_rate_ != 0.0) {
+ // if (dropout_rate_ != 0.0) {
// dropout_mask_.Resize(in.NumRows(), 4*ncell_, kUndefined);
// dropout_mask_.SetRandUniform(); // [0,1]
// dropout_mask_.Add(-dropout_rate_); // [-dropout_rate, 1-dropout_rate_],
// dropout_mask_.ApplyHeaviside(); // -tive -> 0.0, +tive -> 1.0
// YGIFO.RowRange(1*S,T*S).MulElements(dropout_mask_);
- //}
+ // }
// bias -> g, i, f, o
- B_YGIFO.RowRange(1*S,T*S).AddVecToRows(1.0, b_bias_);
+ B_YGIFO.RowRange(1*S, T*S).AddVecToRows(1.0, b_bias_);
// backward direction, from T to 1, t--
for (int t = T; t >= 1; t--) {
// multistream buffers for current time-step
- CuSubMatrix y_g(B_YG.RowRange(t*S,S));
- CuSubMatrix y_i(B_YI.RowRange(t*S,S));
- CuSubMatrix y_f(B_YF.RowRange(t*S,S));
- CuSubMatrix y_o(B_YO.RowRange(t*S,S));
- CuSubMatrix y_c(B_YC.RowRange(t*S,S));
- CuSubMatrix y_h(B_YH.RowRange(t*S,S));
- CuSubMatrix y_m(B_YM.RowRange(t*S,S));
- CuSubMatrix y_r(B_YR.RowRange(t*S,S));
-
- CuSubMatrix y_gifo(B_YGIFO.RowRange(t*S,S));
+ CuSubMatrix y_all(b_propagate_buf_.RowRange(t*S, S));
+ CuSubMatrix y_g(B_YG.RowRange(t*S, S));
+ CuSubMatrix y_i(B_YI.RowRange(t*S, S));
+ CuSubMatrix y_f(B_YF.RowRange(t*S, S));
+ CuSubMatrix y_o(B_YO.RowRange(t*S, S));
+ CuSubMatrix y_c(B_YC.RowRange(t*S, S));
+ CuSubMatrix y_h(B_YH.RowRange(t*S, S));
+ CuSubMatrix y_m(B_YM.RowRange(t*S, S));
+ CuSubMatrix y_r(B_YR.RowRange(t*S, S));
+ CuSubMatrix y_gifo(B_YGIFO.RowRange(t*S, S));
// r(t+1) -> g, i, f, o
- y_gifo.AddMatMat(1.0, B_YR.RowRange((t+1)*S,S), kNoTrans, b_w_gifo_r_, kTrans, 1.0);
+ y_gifo.AddMatMat(1.0, B_YR.RowRange((t+1)*S, S), kNoTrans, b_w_gifo_r_, kTrans, 1.0);
// c(t+1) -> i(t) via peephole
- y_i.AddMatDiagVec(1.0, B_YC.RowRange((t+1)*S,S), kNoTrans, b_peephole_i_c_, 1.0);
+ y_i.AddMatDiagVec(1.0, B_YC.RowRange((t+1)*S, S), kNoTrans, b_peephole_i_c_, 1.0);
// c(t+1) -> f(t) via peephole
- y_f.AddMatDiagVec(1.0, B_YC.RowRange((t+1)*S,S), kNoTrans, b_peephole_f_c_, 1.0);
+ y_f.AddMatDiagVec(1.0, B_YC.RowRange((t+1)*S, S), kNoTrans, b_peephole_f_c_, 1.0);
// i, f sigmoid squashing
y_i.Sigmoid(y_i);
@@ -646,7 +610,7 @@ class BLstmProjectedStreams : public UpdatableComponent {
y_c.AddMatMatElements(1.0, y_g, y_i, 0.0);
// c(t+1) -> c(t) via forget-gate
- y_c.AddMatMatElements(1.0, B_YC.RowRange((t+1)*S,S), y_f, 1.0);
+ y_c.AddMatMatElements(1.0, B_YC.RowRange((t+1)*S, S), y_f, 1.0);
y_c.ApplyFloor(-50); // optional clipping of cell activation
y_c.ApplyCeiling(50); // google paper Interspeech2014: LSTM for LVCSR
@@ -666,6 +630,11 @@ class BLstmProjectedStreams : public UpdatableComponent {
// m -> r
y_r.AddMatMat(1.0, y_m, kNoTrans, b_w_r_m_, kTrans, 0.0);
+ for (int s = 0; s < S; s++) {
+ if (t > sequence_lengths_[s])
+ y_all.Row(s).SetZero();
+ }
+
if (DEBUG) {
std::cerr << "backward direction forward-pass frame " << t << "\n";
std::cerr << "activation of g: " << y_g;
@@ -679,26 +648,21 @@ class BLstmProjectedStreams : public UpdatableComponent {
}
}
- // According to definition of BLSTM, for output YR of BLSTM, YR should be F_YR + B_YR
- CuSubMatrix YR(F_YR.RowRange(1*S,T*S));
- YR.AddMat(1.0,B_YR.RowRange(1*S,T*S));
-
+ CuMatrix YR_FB;
+ YR_FB.Resize((T+2)*S, 2 * nrecur_, kSetZero);
+ // forward part
+ YR_FB.ColRange(0, nrecur_).CopyFromMat(f_propagate_buf_.ColRange(7*ncell_, nrecur_));
+ // backward part
+ YR_FB.ColRange(nrecur_, nrecur_).CopyFromMat(b_propagate_buf_.ColRange(7*ncell_, nrecur_));
// recurrent projection layer is also feed-forward as BLSTM output
- out->CopyFromMat(YR);
-
- // now the last frame state becomes previous network state for next batch
- f_prev_nnet_state_.CopyFromMat(f_propagate_buf_.RowRange(T*S,S));
-
- // now the last frame (,that is the first frame) becomes previous netwok state for next batch
- b_prev_nnet_state_.CopyFromMat(b_propagate_buf_.RowRange(1*S,S));
+ out->CopyFromMat(YR_FB.RowRange(1*S, T*S));
}
-
void BackpropagateFnc(const CuMatrixBase &in, const CuMatrixBase &out,
const CuMatrixBase &out_diff, CuMatrixBase *in_diff) {
-
int DEBUG = 0;
-
+ // the number of sequences to be processed in parallel
+ int32 nstream_ = sequence_lengths_.size();
int32 T = in.NumRows() / nstream_;
int32 S = nstream_;
// disassembling forward-pass forward-propagation buffer into different neurons,
@@ -727,31 +691,31 @@ class BLstmProjectedStreams : public UpdatableComponent {
CuSubMatrix F_DGIFO(f_backpropagate_buf_.ColRange(0, 4*ncell_));
// projection layer to BLSTM output is not recurrent, so backprop it all in once
- F_DR.RowRange(1*S,T*S).CopyFromMat(out_diff);
+ F_DR.RowRange(1*S, T*S).CopyFromMat(out_diff.ColRange(0, nrecur_));
for (int t = T; t >= 1; t--) {
- CuSubMatrix y_g(F_YG.RowRange(t*S,S));
- CuSubMatrix y_i(F_YI.RowRange(t*S,S));
- CuSubMatrix y_f(F_YF.RowRange(t*S,S));
- CuSubMatrix y_o(F_YO.RowRange(t*S,S));
- CuSubMatrix y_c(F_YC.RowRange(t*S,S));
- CuSubMatrix y_h(F_YH.RowRange(t*S,S));
- CuSubMatrix y_m(F_YM.RowRange(t*S,S));
- CuSubMatrix y_r(F_YR.RowRange(t*S,S));
-
- CuSubMatrix d_g(F_DG.RowRange(t*S,S));
- CuSubMatrix d_i(F_DI.RowRange(t*S,S));
- CuSubMatrix d_f(F_DF.RowRange(t*S,S));
- CuSubMatrix d_o(F_DO.RowRange(t*S,S));
- CuSubMatrix d_c(F_DC.RowRange(t*S,S));
- CuSubMatrix d_h(F_DH.RowRange(t*S,S));
- CuSubMatrix d_m(F_DM.RowRange(t*S,S));
- CuSubMatrix d_r(F_DR.RowRange(t*S,S));
-
+ CuSubMatrix y_g(F_YG.RowRange(t*S, S));
+ CuSubMatrix y_i(F_YI.RowRange(t*S, S));
+ CuSubMatrix y_f(F_YF.RowRange(t*S, S));
+ CuSubMatrix y_o(F_YO.RowRange(t*S, S));
+ CuSubMatrix y_c(F_YC.RowRange(t*S, S));
+ CuSubMatrix y_h(F_YH.RowRange(t*S, S));
+ CuSubMatrix y_m(F_YM.RowRange(t*S, S));
+ CuSubMatrix y_r(F_YR.RowRange(t*S, S));
+
+ CuSubMatrix d_g(F_DG.RowRange(t*S, S));
+ CuSubMatrix d_i(F_DI.RowRange(t*S, S));
+ CuSubMatrix d_f(F_DF.RowRange(t*S, S));
+ CuSubMatrix d_o(F_DO.RowRange(t*S, S));
+ CuSubMatrix d_c(F_DC.RowRange(t*S, S));
+ CuSubMatrix d_h(F_DH.RowRange(t*S, S));
+ CuSubMatrix d_m(F_DM.RowRange(t*S, S));
+ CuSubMatrix d_r(F_DR.RowRange(t*S, S));
+ CuSubMatrix d_all(f_backpropagate_buf_.RowRange(t*S, S));
// r
// Version 1 (precise gradients):
// backprop error from g(t+1), i(t+1), f(t+1), o(t+1) to r(t)
- d_r.AddMatMat(1.0, F_DGIFO.RowRange((t+1)*S,S), kNoTrans, f_w_gifo_r_, kNoTrans, 1.0);
+ d_r.AddMatMat(1.0, F_DGIFO.RowRange((t+1)*S, S), kNoTrans, f_w_gifo_r_, kNoTrans, 1.0);
/*
// Version 2 (Alex Graves' PhD dissertation):
@@ -785,13 +749,13 @@ class BLstmProjectedStreams : public UpdatableComponent {
// 4. diff from f(t+1) (via peephole)
// 5. diff from o(t) (via peephole, not recurrent)
d_c.AddMat(1.0, d_h);
- d_c.AddMatMatElements(1.0, F_DC.RowRange((t+1)*S,S), F_YF.RowRange((t+1)*S,S), 1.0);
- d_c.AddMatDiagVec(1.0, F_DI.RowRange((t+1)*S,S), kNoTrans, f_peephole_i_c_, 1.0);
- d_c.AddMatDiagVec(1.0, F_DF.RowRange((t+1)*S,S), kNoTrans, f_peephole_f_c_, 1.0);
+ d_c.AddMatMatElements(1.0, F_DC.RowRange((t+1)*S, S), F_YF.RowRange((t+1)*S, S), 1.0);
+ d_c.AddMatDiagVec(1.0, F_DI.RowRange((t+1)*S, S), kNoTrans, f_peephole_i_c_, 1.0);
+ d_c.AddMatDiagVec(1.0, F_DF.RowRange((t+1)*S, S), kNoTrans, f_peephole_f_c_, 1.0);
d_c.AddMatDiagVec(1.0, d_o , kNoTrans, f_peephole_o_c_, 1.0);
// f
- d_f.AddMatMatElements(1.0, d_c, F_YC.RowRange((t-1)*S,S), 0.0);
+ d_f.AddMatMatElements(1.0, d_c, F_YC.RowRange((t-1)*S, S), 0.0);
d_f.DiffSigmoid(y_f, d_f);
// i
@@ -838,35 +802,35 @@ class BLstmProjectedStreams : public UpdatableComponent {
CuSubMatrix B_DH(b_backpropagate_buf_.ColRange(5*ncell_, ncell_));
CuSubMatrix B_DM(b_backpropagate_buf_.ColRange(6*ncell_, ncell_));
CuSubMatrix B_DR(b_backpropagate_buf_.ColRange(7*ncell_, nrecur_));
-
CuSubMatrix B_DGIFO(b_backpropagate_buf_.ColRange(0, 4*ncell_));
// projection layer to BLSTM output is not recurrent, so backprop it all in once
- B_DR.RowRange(1*S,T*S).CopyFromMat(out_diff);
+ B_DR.RowRange(1*S, T*S).CopyFromMat(out_diff.ColRange(nrecur_, nrecur_));
for (int t = 1; t <= T; t++) {
- CuSubMatrix y_g(B_YG.RowRange(t*S,S));
- CuSubMatrix y_i(B_YI.RowRange(t*S,S));
- CuSubMatrix y_f(B_YF.RowRange(t*S,S));
- CuSubMatrix y_o(B_YO.RowRange(t*S,S));
- CuSubMatrix y_c(B_YC.RowRange(t*S,S));
- CuSubMatrix y_h(B_YH.RowRange(t*S,S));
- CuSubMatrix y_m(B_YM.RowRange(t*S,S));
- CuSubMatrix y_r(B_YR.RowRange(t*S,S));
-
- CuSubMatrix d_g(B_DG.RowRange(t*S,S));
- CuSubMatrix d_i(B_DI.RowRange(t*S,S));
- CuSubMatrix d_f(B_DF.RowRange(t*S,S));
- CuSubMatrix d_o(B_DO.RowRange(t*S,S));
- CuSubMatrix d_c(B_DC.RowRange(t*S,S));
- CuSubMatrix d_h(B_DH.RowRange(t*S,S));
- CuSubMatrix d_m(B_DM.RowRange(t*S,S));
- CuSubMatrix d_r(B_DR.RowRange(t*S,S));
+ CuSubMatrix y_g(B_YG.RowRange(t*S, S));
+ CuSubMatrix y_i(B_YI.RowRange(t*S, S));
+ CuSubMatrix y_f(B_YF.RowRange(t*S, S));
+ CuSubMatrix y_o(B_YO.RowRange(t*S, S));
+ CuSubMatrix y_c(B_YC.RowRange(t*S, S));
+ CuSubMatrix y_h(B_YH.RowRange(t*S, S));
+ CuSubMatrix y_m(B_YM.RowRange(t*S, S));
+ CuSubMatrix y_r(B_YR.RowRange(t*S, S));
+
+ CuSubMatrix d_g(B_DG.RowRange(t*S, S));
+ CuSubMatrix d_i(B_DI.RowRange(t*S, S));
+ CuSubMatrix d_f(B_DF.RowRange(t*S, S));
+ CuSubMatrix d_o(B_DO.RowRange(t*S, S));
+ CuSubMatrix d_c(B_DC.RowRange(t*S, S));
+ CuSubMatrix d_h(B_DH.RowRange(t*S, S));
+ CuSubMatrix d_m(B_DM.RowRange(t*S, S));
+ CuSubMatrix d_r(B_DR.RowRange(t*S, S));
+ CuSubMatrix d_all(b_backpropagate_buf_.RowRange(t*S, S));
// r
// Version 1 (precise gradients):
// backprop error from g(t-1), i(t-1), f(t-1), o(t-1) to r(t)
- d_r.AddMatMat(1.0, B_DGIFO.RowRange((t-1)*S,S), kNoTrans, b_w_gifo_r_, kNoTrans, 1.0);
+ d_r.AddMatMat(1.0, B_DGIFO.RowRange((t-1)*S, S), kNoTrans, b_w_gifo_r_, kNoTrans, 1.0);
/*
// Version 2 (Alex Graves' PhD dissertation):
@@ -899,13 +863,13 @@ class BLstmProjectedStreams : public UpdatableComponent {
// 4. diff from f(t+1) (via peephole)
// 5. diff from o(t) (via peephole, not recurrent)
d_c.AddMat(1.0, d_h);
- d_c.AddMatMatElements(1.0, B_DC.RowRange((t-1)*S,S), B_YF.RowRange((t-1)*S,S), 1.0);
- d_c.AddMatDiagVec(1.0, B_DI.RowRange((t-1)*S,S), kNoTrans, b_peephole_i_c_, 1.0);
- d_c.AddMatDiagVec(1.0, B_DF.RowRange((t-1)*S,S), kNoTrans, b_peephole_f_c_, 1.0);
+ d_c.AddMatMatElements(1.0, B_DC.RowRange((t-1)*S, S), B_YF.RowRange((t-1)*S, S), 1.0);
+ d_c.AddMatDiagVec(1.0, B_DI.RowRange((t-1)*S, S), kNoTrans, b_peephole_i_c_, 1.0);
+ d_c.AddMatDiagVec(1.0, B_DF.RowRange((t-1)*S, S), kNoTrans, b_peephole_f_c_, 1.0);
d_c.AddMatDiagVec(1.0, d_o , kNoTrans, b_peephole_o_c_, 1.0);
// f
- d_f.AddMatMatElements(1.0, d_c, B_YC.RowRange((t-1)*S,S), 0.0);
+ d_f.AddMatMatElements(1.0, d_c, B_YC.RowRange((t-1)*S, S), 0.0);
d_f.DiffSigmoid(y_f, d_f);
// i
@@ -932,13 +896,12 @@ class BLstmProjectedStreams : public UpdatableComponent {
// g,i,f,o -> x, do it all in once
// forward direction difference
- in_diff->AddMatMat(1.0, F_DGIFO.RowRange(1*S,T*S), kNoTrans, f_w_gifo_x_, kNoTrans, 0.0);
+ in_diff->AddMatMat(1.0, F_DGIFO.RowRange(1*S, T*S), kNoTrans, f_w_gifo_x_, kNoTrans, 0.0);
// backward direction difference
- in_diff->AddMatMat(1.0, B_DGIFO.RowRange(1*S,T*S), kNoTrans, b_w_gifo_x_, kNoTrans, 1.0);
-
+ in_diff->AddMatMat(1.0, B_DGIFO.RowRange(1*S, T*S), kNoTrans, b_w_gifo_x_, kNoTrans, 1.0);
// backward pass dropout
- //if (dropout_rate_ != 0.0) {
+ // if (dropout_rate_ != 0.0) {
// in_diff->MulElements(dropout_mask_);
//}
@@ -947,26 +910,26 @@ class BLstmProjectedStreams : public UpdatableComponent {
// forward direction
// weight x -> g, i, f, o
- f_w_gifo_x_corr_.AddMatMat(1.0, F_DGIFO.RowRange(1*S,T*S), kTrans,
+ f_w_gifo_x_corr_.AddMatMat(1.0, F_DGIFO.RowRange(1*S, T*S), kTrans,
in, kNoTrans, mmt);
// recurrent weight r -> g, i, f, o
- f_w_gifo_r_corr_.AddMatMat(1.0, F_DGIFO.RowRange(1*S,T*S), kTrans,
- F_YR.RowRange(0*S,T*S), kNoTrans, mmt);
+ f_w_gifo_r_corr_.AddMatMat(1.0, F_DGIFO.RowRange(1*S, T*S), kTrans,
+ F_YR.RowRange(0*S, T*S), kNoTrans, mmt);
// bias of g, i, f, o
- f_bias_corr_.AddRowSumMat(1.0, F_DGIFO.RowRange(1*S,T*S), mmt);
+ f_bias_corr_.AddRowSumMat(1.0, F_DGIFO.RowRange(1*S, T*S), mmt);
// recurrent peephole c -> i
- f_peephole_i_c_corr_.AddDiagMatMat(1.0, F_DI.RowRange(1*S,T*S), kTrans,
- F_YC.RowRange(0*S,T*S), kNoTrans, mmt);
+ f_peephole_i_c_corr_.AddDiagMatMat(1.0, F_DI.RowRange(1*S, T*S), kTrans,
+ F_YC.RowRange(0*S, T*S), kNoTrans, mmt);
// recurrent peephole c -> f
- f_peephole_f_c_corr_.AddDiagMatMat(1.0, F_DF.RowRange(1*S,T*S), kTrans,
- F_YC.RowRange(0*S,T*S), kNoTrans, mmt);
+ f_peephole_f_c_corr_.AddDiagMatMat(1.0, F_DF.RowRange(1*S, T*S), kTrans,
+ F_YC.RowRange(0*S, T*S), kNoTrans, mmt);
// peephole c -> o
- f_peephole_o_c_corr_.AddDiagMatMat(1.0, F_DO.RowRange(1*S,T*S), kTrans,
- F_YC.RowRange(1*S,T*S), kNoTrans, mmt);
+ f_peephole_o_c_corr_.AddDiagMatMat(1.0, F_DO.RowRange(1*S, T*S), kTrans,
+ F_YC.RowRange(1*S, T*S), kNoTrans, mmt);
- f_w_r_m_corr_.AddMatMat(1.0, F_DR.RowRange(1*S,T*S), kTrans,
- F_YM.RowRange(1*S,T*S), kNoTrans, mmt);
+ f_w_r_m_corr_.AddMatMat(1.0, F_DR.RowRange(1*S, T*S), kTrans,
+ F_YM.RowRange(1*S, T*S), kNoTrans, mmt);
// apply the gradient clipping for forwardpass gradients
if (clip_gradient_ > 0.0) {
@@ -988,25 +951,25 @@ class BLstmProjectedStreams : public UpdatableComponent {
// backward direction backpropagate
// weight x -> g, i, f, o
- b_w_gifo_x_corr_.AddMatMat(1.0, B_DGIFO.RowRange(1*S,T*S), kTrans, in, kNoTrans, mmt);
+ b_w_gifo_x_corr_.AddMatMat(1.0, B_DGIFO.RowRange(1*S, T*S), kTrans, in, kNoTrans, mmt);
// recurrent weight r -> g, i, f, o
- b_w_gifo_r_corr_.AddMatMat(1.0, B_DGIFO.RowRange(1*S,T*S), kTrans,
- B_YR.RowRange(0*S,T*S) , kNoTrans, mmt);
+ b_w_gifo_r_corr_.AddMatMat(1.0, B_DGIFO.RowRange(1*S, T*S), kTrans,
+ B_YR.RowRange(0*S, T*S) , kNoTrans, mmt);
// bias of g, i, f, o
- b_bias_corr_.AddRowSumMat(1.0, B_DGIFO.RowRange(1*S,T*S), mmt);
-
- // recurrent peephole c -> i, c(t+1) --> i ##commented by chongjia
- b_peephole_i_c_corr_.AddDiagMatMat(1.0, B_DI.RowRange(1*S,T*S), kTrans,
- B_YC.RowRange(2*S,T*S), kNoTrans, mmt);
- // recurrent peephole c -> f, c(t+1) --> f ###commented by chongjia
- b_peephole_f_c_corr_.AddDiagMatMat(1.0, B_DF.RowRange(1*S,T*S), kTrans,
- B_YC.RowRange(2*S,T*S), kNoTrans, mmt);
+ b_bias_corr_.AddRowSumMat(1.0, B_DGIFO.RowRange(1*S, T*S), mmt);
+
+ // recurrent peephole c -> i, c(t+1) --> i
+ b_peephole_i_c_corr_.AddDiagMatMat(1.0, B_DI.RowRange(1*S, T*S), kTrans,
+ B_YC.RowRange(2*S, T*S), kNoTrans, mmt);
+ // recurrent peephole c -> f, c(t+1) --> f
+ b_peephole_f_c_corr_.AddDiagMatMat(1.0, B_DF.RowRange(1*S, T*S), kTrans,
+ B_YC.RowRange(2*S, T*S), kNoTrans, mmt);
// peephole c -> o
- b_peephole_o_c_corr_.AddDiagMatMat(1.0, B_DO.RowRange(1*S,T*S), kTrans,
- B_YC.RowRange(1*S,T*S), kNoTrans, mmt);
+ b_peephole_o_c_corr_.AddDiagMatMat(1.0, B_DO.RowRange(1*S, T*S), kTrans,
+ B_YC.RowRange(1*S, T*S), kNoTrans, mmt);
- b_w_r_m_corr_.AddMatMat(1.0, B_DR.RowRange(1*S,T*S), kTrans,
- B_YM.RowRange(1*S,T*S), kNoTrans, mmt);
+ b_w_r_m_corr_.AddMatMat(1.0, B_DR.RowRange(1*S, T*S), kTrans,
+ B_YM.RowRange(1*S, T*S), kNoTrans, mmt);
// apply the gradient clipping for backwardpass gradients
if (clip_gradient_ > 0.0) {
@@ -1083,16 +1046,14 @@ class BLstmProjectedStreams : public UpdatableComponent {
int32 ncell_; ///< the number of cell blocks
int32 nrecur_; ///< recurrent projection layer dim
int32 nstream_;
-
- CuMatrix f_prev_nnet_state_;
- CuMatrix b_prev_nnet_state_;
+ std::vector sequence_lengths_;
// gradient-clipping value,
BaseFloat clip_gradient_;
// non-recurrent dropout
- //BaseFloat dropout_rate_;
- //CuMatrix dropout_mask_;
+ // BaseFloat dropout_rate_;
+ // CuMatrix dropout_mask_;
// feed-forward connections: from x to [g, i, f, o]
// forward direction
@@ -1160,7 +1121,7 @@ class BLstmProjectedStreams : public UpdatableComponent {
CuMatrix b_backpropagate_buf_;
};
-} // namespace nnet1
-} // namespace kaldi
+} // namespace nnet1
+} // namespace kaldi
#endif
diff --git a/src/nnet/nnet-component.cc b/src/nnet/nnet-component.cc
index 65f872fdc74ece29c857bf619e2cc3ecae2d95f9..42eb0166181250b3eb19f42b226b07396bb99656 100644
--- a/src/nnet/nnet-component.cc
+++ b/src/nnet/nnet-component.cc
@@ -120,7 +120,7 @@ Component* Component::NewComponentOfType(ComponentType comp_type,
ans = new LstmProjectedStreams(input_dim, output_dim);
break;
case Component::kBLstmProjectedStreams :
- ans = new BLstmProjectedStreams(input_dim, output_dim);
+ ans = new BLstmProjectedStreams(input_dim, output_dim);
break;
case Component::kSoftmax :
ans = new Softmax(input_dim, output_dim);
diff --git a/src/nnet/nnet-component.h b/src/nnet/nnet-component.h
index 75c6a482bf49c4c0586e8050ee914027c703a550..2ceabf851df1189936ed7de52fc5eccbaed9d0dc 100644
--- a/src/nnet/nnet-component.h
+++ b/src/nnet/nnet-component.h
@@ -91,7 +91,7 @@ class Component {
/// Convert component type to marker
static const char* TypeToMarker(ComponentType t);
/// Convert marker to component type (case insensitive)
- static ComponentType MarkerToType(const std::string &s);
+ static ComponentType MarkerToType(const std::string &s);
/// General interface of a component
public:
diff --git a/src/nnet/nnet-nnet.cc b/src/nnet/nnet-nnet.cc
index 5d001e3b04abf54f4bd7818c85a367b65b974788..cdad2938f89c5dcfca1ce35e27aade1bed6a2a00 100644
--- a/src/nnet/nnet-nnet.cc
+++ b/src/nnet/nnet-nnet.cc
@@ -32,7 +32,7 @@ namespace nnet1 {
Nnet::Nnet(const Nnet& other) {
// copy the components
- for(int32 i=0; i &stream_reset_flag) {
if (GetComponent(c).GetType() == Component::kLstmProjectedStreams) {
LstmProjectedStreams& comp = dynamic_cast(GetComponent(c));
comp.ResetLstmStreams(stream_reset_flag);
- }
+ }
+ }
+}
+
+void Nnet::SetSeqLengths(const std::vector &sequence_lengths) {
+ for (int32 c=0; c < NumComponents(); c++) {
if (GetComponent(c).GetType() == Component::kBLstmProjectedStreams) {
BLstmProjectedStreams& comp = dynamic_cast(GetComponent(c));
- comp.ResetLstmStreams(stream_reset_flag);
+ comp.SetSeqLengths(sequence_lengths);
}
}
}
-
void Nnet::Init(const std::string &file) {
Input in(file);
std::istream &is = in.Stream();
diff --git a/src/nnet/nnet-nnet.h b/src/nnet/nnet-nnet.h
index b387f10400a61f84894a29f38e6c6126f9c2d4d8..f33f1fbde6532291fea78e5168bd11b37162e48f 100644
--- a/src/nnet/nnet-nnet.h
+++ b/src/nnet/nnet-nnet.h
@@ -36,23 +36,23 @@ namespace nnet1 {
class Nnet {
public:
Nnet() {}
- Nnet(const Nnet& other); // Copy constructor.
+ Nnet(const Nnet& other); // Copy constructor.
Nnet &operator = (const Nnet& other); // Assignment operator.
- ~Nnet();
+ ~Nnet();
public:
/// Perform forward pass through the network
- void Propagate(const CuMatrixBase &in, CuMatrix *out);
+ void Propagate(const CuMatrixBase &in, CuMatrix *out);
/// Perform backward pass through the network
void Backpropagate(const CuMatrixBase &out_diff, CuMatrix *in_diff);
/// Perform forward pass through the network, don't keep buffers (use it when not training)
- void Feedforward(const CuMatrixBase &in, CuMatrix *out);
+ void Feedforward(const CuMatrixBase &in, CuMatrix *out);
/// Dimensionality on network input (input feature dim.)
- int32 InputDim() const;
+ int32 InputDim() const;
/// Dimensionality of network outputs (posteriors | bn-features | etc.)
- int32 OutputDim() const;
+ int32 OutputDim() const;
/// Returns number of components-- think of this as similar to # of layers, but
/// e.g. the nonlinearity and the linear part count as separate components,
@@ -65,7 +65,7 @@ class Nnet {
/// Sets the c'th component to "component", taking ownership of the pointer
/// and deleting the corresponding one that we own.
void SetComponent(int32 c, Component *component);
-
+
/// Appends this component to the components already in the neural net.
/// Takes ownership of the pointer
void AppendComponent(Component *dynamically_allocated_comp);
@@ -77,12 +77,12 @@ class Nnet {
void RemoveLastComponent() { RemoveComponent(NumComponents()-1); }
/// Access to forward pass buffers
- const std::vector >& PropagateBuffer() const {
- return propagate_buf_;
+ const std::vector >& PropagateBuffer() const {
+ return propagate_buf_;
}
/// Access to backward pass buffers
- const std::vector >& BackpropagateBuffer() const {
- return backpropagate_buf_;
+ const std::vector >& BackpropagateBuffer() const {
+ return backpropagate_buf_;
}
/// Get the number of parameters in the network
@@ -96,22 +96,25 @@ class Nnet {
/// Get the gradient stored in the network
void GetGradient(Vector* grad_copy) const;
- /// Set the dropout rate
+ /// Set the dropout rate
void SetDropoutRetention(BaseFloat r);
/// Reset streams in LSTM multi-stream training,
void ResetLstmStreams(const std::vector &stream_reset_flag);
+ /// set sequence length in LSTM multi-stream training
+ void SetSeqLengths(const std::vector &sequence_lengths);
+
/// Initialize MLP from config
void Init(const std::string &config_file);
/// Read the MLP from file (can add layers to exisiting instance of Nnet)
- void Read(const std::string &file);
+ void Read(const std::string &file);
/// Read the MLP from stream (can add layers to exisiting instance of Nnet)
- void Read(std::istream &in, bool binary);
+ void Read(std::istream &in, bool binary);
/// Write MLP to file
void Write(const std::string &file, bool binary) const;
- /// Write MLP to stream
- void Write(std::ostream &out, bool binary) const;
-
+ /// Write MLP to stream
+ void Write(std::ostream &out, bool binary) const;
+
/// Create string with human readable description of the nnet
std::string Info() const;
/// Create string with per-component gradient statistics
@@ -135,18 +138,17 @@ class Nnet {
private:
/// Vector which contains all the components composing the neural network,
/// the components are for example: AffineTransform, Sigmoid, Softmax
- std::vector components_;
+ std::vector components_;
- std::vector > propagate_buf_; ///< buffers for forward pass
- std::vector > backpropagate_buf_; ///< buffers for backward pass
+ std::vector > propagate_buf_; ///< buffers for forward pass
+ std::vector > backpropagate_buf_; ///< buffers for backward pass
/// Option class with hyper-parameters passed to UpdatableComponent(s)
NnetTrainOptions opts_;
};
-
-} // namespace nnet1
-} // namespace kaldi
+} // namespace nnet1
+} // namespace kaldi
#endif // KALDI_NNET_NNET_NNET_H_
diff --git a/src/nnetbin/Makefile b/src/nnetbin/Makefile
index ffa6163afc033af88c45fbf282031fe91632bc6f..1475041b99a0aca6320ee88cb4b666221e7c8172 100644
--- a/src/nnetbin/Makefile
+++ b/src/nnetbin/Makefile
@@ -10,7 +10,7 @@ BINFILES = nnet-train-frmshuff \
nnet-train-perutt \
nnet-train-mmi-sequential \
nnet-train-mpe-sequential \
- nnet-train-lstm-streams \
+ nnet-train-lstm-streams nnet-train-blstm-streams \
rbm-train-cd1-frmshuff rbm-convert-to-nnet \
nnet-forward nnet-copy nnet-info nnet-concat \
transf-to-nnet cmvn-to-nnet nnet-initialize \
diff --git a/src/nnetbin/nnet-train-blstm-streams.cc b/src/nnetbin/nnet-train-blstm-streams.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0ba2f5aed4596e2f4f47db3b3ca305bcf75292b0
--- /dev/null
+++ b/src/nnetbin/nnet-train-blstm-streams.cc
@@ -0,0 +1,293 @@
+// nnetbin/nnet-train-blstm-streams.cc
+
+// Copyright 2015 Chongjia Ni
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#include "nnet/nnet-trnopts.h"
+#include "nnet/nnet-nnet.h"
+#include "nnet/nnet-loss.h"
+#include "nnet/nnet-randomizer.h"
+#include "base/kaldi-common.h"
+#include "util/common-utils.h"
+#include "base/timer.h"
+#include "cudamatrix/cu-device.h"
+
+int main(int argc, char *argv[]) {
+ using namespace kaldi;
+ using namespace kaldi::nnet1;
+ typedef kaldi::int32 int32;
+
+ try {
+ const char *usage =
+        "Perform one iteration of senone training by SGD.\n"
+        "The updates are done per-utterance and by processing multiple utterances in parallel.\n"
+ "\n"
+ "Usage: nnet-train-blstm-streams [options] []\n"
+ "e.g.: \n"
+ " nnet-train-blstm-streams scp:feature.scp ark:labels.ark nnet.init nnet.iter1\n";
+
+ ParseOptions po(usage);
+ // training options
+ NnetTrainOptions trn_opts;
+ trn_opts.Register(&po);
+
+ bool binary = true,
+ crossvalidate = false;
+ po.Register("binary", &binary, "Write model in binary mode");
+ po.Register("cross-validate", &crossvalidate, "Perform cross-validation (no backpropagation)");
+
+ std::string feature_transform;
+ po.Register("feature-transform", &feature_transform, "Feature transform in Nnet format");
+
+ int32 length_tolerance = 5;
+ po.Register("length-tolerance", &length_tolerance, "Allowed length difference of features/targets (frames)");
+
+ std::string frame_weights;
+ po.Register("frame-weights", &frame_weights, "Per-frame weights to scale gradients (frame selection/weighting).");
+
+ std::string objective_function = "xent";
+ po.Register("objective-function", &objective_function, "Objective function : xent|mse");
+
+ int32 num_streams = 4;
+ po.Register("num_streams", &num_streams, "Number of sequences processed in parallel");
+
+ double frame_limit = 100000;
+ po.Register("frame-limit", &frame_limit, "Max number of frames to be processed");
+
+ int32 report_step = 100;
+ po.Register("report-step", &report_step, "Step (number of sequences) for status reporting");
+
+ std::string use_gpu = "yes";
+ // po.Register("use-gpu", &use_gpu, "yes|no|optional, only has effect if compiled with CUDA");
+
+ po.Read(argc, argv);
+
+ if (po.NumArgs() != 4-(crossvalidate?1:0)) {
+ po.PrintUsage();
+ exit(1);
+ }
+
+ std::string feature_rspecifier = po.GetArg(1),
+ targets_rspecifier = po.GetArg(2),
+ model_filename = po.GetArg(3);
+
+ std::string target_model_filename;
+ if (!crossvalidate) {
+ target_model_filename = po.GetArg(4);
+ }
+
+ using namespace kaldi;
+ using namespace kaldi::nnet1;
+ typedef kaldi::int32 int32;
+ Vector weights;
+ // Select the GPU
+#if HAVE_CUDA == 1
+ CuDevice::Instantiate().SelectGpuId(use_gpu);
+#endif
+
+ Nnet nnet_transf;
+ if ( feature_transform != "" ) {
+ nnet_transf.Read(feature_transform);
+ }
+
+ Nnet nnet;
+ nnet.Read(model_filename);
+ nnet.SetTrainOptions(trn_opts);
+
+ kaldi::int64 total_frames = 0;
+
+    // Initialize feature and label readers
+ SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
+ RandomAccessPosteriorReader targets_reader(targets_rspecifier);
+ RandomAccessBaseFloatVectorReader weights_reader;
+ if (frame_weights != "") {
+ weights_reader.Open(frame_weights);
+ }
+
+ Xent xent;
+ Mse mse;
+
+ CuMatrix feats, feats_transf, nnet_out, obj_diff;
+
+ Timer time;
+ KALDI_LOG << (crossvalidate?"CROSS-VALIDATION":"TRAINING") << " STARTED";
+ // Feature matrix of every utterance
+ std::vector< Matrix > feats_utt(num_streams);
+ // Label vector of every utterance
+ std::vector< Posterior > labels_utt(num_streams);
+ std::vector< Vector > weights_utt(num_streams);
+
+ int32 feat_dim = nnet.InputDim();
+
+ int32 num_done = 0, num_no_tgt_mat = 0, num_other_error = 0;
+ while (1) {
+
+ std::vector frame_num_utt;
+ int32 sequence_index = 0, max_frame_num = 0;
+
+ for ( ; !feature_reader.Done(); feature_reader.Next()) {
+ std::string utt = feature_reader.Key();
+ // Check that we have targets
+ if (!targets_reader.HasKey(utt)) {
+ KALDI_WARN << utt << ", missing targets";
+ num_no_tgt_mat++;
+ continue;
+ }
+ // Get feature / target pair
+ Matrix mat = feature_reader.Value();
+ Posterior targets = targets_reader.Value(utt);
+
+ if (frame_weights != "") {
+ weights = weights_reader.Value(utt);
+ } else { // all per-frame weights are 1.0
+ weights.Resize(mat.NumRows());
+ weights.Set(1.0);
+ }
+ // correct small length mismatch ... or drop sentence
+ {
+ // add lengths to vector
+ std::vector lenght;
+ lenght.push_back(mat.NumRows());
+ lenght.push_back(targets.size());
+ lenght.push_back(weights.Dim());
+ // find min, max
+ int32 min = *std::min_element(lenght.begin(), lenght.end());
+ int32 max = *std::max_element(lenght.begin(), lenght.end());
+ // fix or drop ?
+ if (max - min < length_tolerance) {
+ if (mat.NumRows() != min) mat.Resize(min, mat.NumCols(), kCopyData);
+ if (targets.size() != min) targets.resize(min);
+ if (weights.Dim() != min) weights.Resize(min, kCopyData);
+ } else {
+ KALDI_WARN << utt << ", length mismatch of targets " << targets.size()
+ << " and features " << mat.NumRows();
+ num_other_error++;
+ continue;
+ }
+ }
+
+ if (max_frame_num < mat.NumRows()) max_frame_num = mat.NumRows();
+ feats_utt[sequence_index] = mat;
+ labels_utt[sequence_index] = targets;
+ weights_utt[sequence_index] = weights;
+
+ frame_num_utt.push_back(mat.NumRows());
+ sequence_index++;
+      // If the total number of frames reaches frame_limit, then stop adding more sequences, regardless of whether
+      // the number of utterances reaches num_streams or not.
+ if (frame_num_utt.size() == num_streams || frame_num_utt.size() * max_frame_num > frame_limit) {
+ feature_reader.Next(); break;
+ }
+ }
+ int32 cur_sequence_num = frame_num_utt.size();
+
+ // Create the final feature matrix. Every utterance is padded to the max length within this group of utterances
+ Matrix feat_mat_host(cur_sequence_num * max_frame_num, feat_dim, kSetZero);
+ Posterior target_host;
+ Vector weight_host;
+
+ target_host.resize(cur_sequence_num * max_frame_num);
+ weight_host.Resize(cur_sequence_num * max_frame_num, kSetZero);
+
+ for (int s = 0; s < cur_sequence_num; s++) {
+ Matrix mat_tmp = feats_utt[s];
+ for (int r = 0; r < frame_num_utt[s]; r++) {
+ feat_mat_host.Row(r*cur_sequence_num + s).CopyFromVec(mat_tmp.Row(r));
+ }
+ }
+
+ for (int s = 0; s < cur_sequence_num; s++) {
+ Posterior target_tmp = labels_utt[s];
+ for (int r = 0; r < frame_num_utt[s]; r++) {
+ target_host[r*cur_sequence_num+s] = target_tmp[r];
+ }
+ Vector weight_tmp = weights_utt[s];
+ for (int r = 0; r < frame_num_utt[s]; r++) {
+ weight_host(r*cur_sequence_num+s) = weight_tmp(r);
+ }
+ }
+
+ // transform feature
+ nnet_transf.Feedforward(CuMatrix(feat_mat_host), &feats_transf);
+
+ // Set the original lengths of utterances before padding
+ nnet.SetSeqLengths(frame_num_utt);
+
+ // Propagation and xent training
+ nnet.Propagate(feats_transf, &nnet_out);
+
+ if (objective_function == "xent") {
+ // gradients re-scaled by weights in Eval,
+ xent.Eval(weight_host, nnet_out, target_host, &obj_diff);
+ } else if (objective_function == "mse") {
+ // gradients re-scaled by weights in Eval,
+ mse.Eval(weight_host, nnet_out, target_host, &obj_diff);
+ } else {
+ KALDI_ERR << "Unknown objective function code : " << objective_function;
+ }
+
+ // Backward pass
+ if (!crossvalidate) {
+ nnet.Backpropagate(obj_diff, NULL);
+ }
+
+ // 1st minibatch : show what happens in network
+ if (kaldi::g_kaldi_verbose_level >= 2 && total_frames == 0) { // vlog-1
+ KALDI_VLOG(1) << "### After " << total_frames << " frames,";
+ KALDI_VLOG(1) << nnet.InfoPropagate();
+ if (!crossvalidate) {
+ KALDI_VLOG(1) << nnet.InfoBackPropagate();
+ KALDI_VLOG(1) << nnet.InfoGradient();
+ }
+ }
+
+ num_done += cur_sequence_num;
+ total_frames += feats_transf.NumRows();
+
+ if (feature_reader.Done()) break; // end loop of while(1)
+ }
+
+ // Check network parameters and gradients when training finishes
+ if (kaldi::g_kaldi_verbose_level >= 1) { // vlog-1
+ KALDI_VLOG(1) << "### After " << total_frames << " frames,";
+ KALDI_VLOG(1) << nnet.InfoPropagate();
+ if (!crossvalidate) {
+ KALDI_VLOG(1) << nnet.InfoBackPropagate();
+ KALDI_VLOG(1) << nnet.InfoGradient();
+ }
+ }
+
+ if (!crossvalidate) {
+ nnet.Write(target_model_filename, binary);
+ }
+
+ KALDI_LOG << "Done " << num_done << " files, " << num_no_tgt_mat
+ << " with no tgt_mats, " << num_other_error
+ << " with other errors. "
+ << "[" << (crossvalidate?"CROSS-VALIDATION":"TRAINING")
+ << ", " << time.Elapsed()/60 << " min, fps" << total_frames/time.Elapsed()
+ << "]";
+ KALDI_LOG << xent.Report();
+
+#if HAVE_CUDA == 1
+ CuDevice::Instantiate().PrintProfile();
+#endif
+
+ return 0;
+ } catch(const std::exception &e) {
+ std::cerr << e.what();
+ return -1;
+ }
+}