Commit afe298c3 authored by nichongjia's avatar nichongjia
Browse files

blstm remove bug

parent 05e3fbd5
...@@ -58,17 +58,17 @@ print "<NnetProto>" ...@@ -58,17 +58,17 @@ print "<NnetProto>"
# normally we won't use more than 2 layers of LSTM # normally we won't use more than 2 layers of LSTM
if o.num_layers == 1: if o.num_layers == 1:
print "<BLstmProjectedStreams> <InputDim> %d <OutputDim> %d <CellDim> %s <ParamScale> %f <ClipGradient> %f" % \ print "<BLstmProjectedStreams> <InputDim> %d <OutputDim> %d <CellDim> %s <ParamScale> %f <ClipGradient> %f" % \
(feat_dim, o.num_recurrent, o.num_cells, o.lstm_stddev_factor, o.clip_gradient) (feat_dim, 2*o.num_recurrent, o.num_cells, o.lstm_stddev_factor, o.clip_gradient)
elif o.num_layers == 2: elif o.num_layers == 2:
print "<BLstmProjectedStreams> <InputDim> %d <OutputDim> %d <CellDim> %s <ParamScale> %f <ClipGradient> %f" % \ print "<BLstmProjectedStreams> <InputDim> %d <OutputDim> %d <CellDim> %s <ParamScale> %f <ClipGradient> %f" % \
(feat_dim, o.num_recurrent, o.num_cells, o.lstm_stddev_factor, o.clip_gradient) (feat_dim, 2*o.num_recurrent, o.num_cells, o.lstm_stddev_factor, o.clip_gradient)
print "<BLstmProjectedStreams> <InputDim> %d <OutputDim> %d <CellDim> %s <ParamScale> %f <ClipGradient> %f" % \ print "<BLstmProjectedStreams> <InputDim> %d <OutputDim> %d <CellDim> %s <ParamScale> %f <ClipGradient> %f" % \
(o.num_recurrent, o.num_recurrent, o.num_cells, o.lstm_stddev_factor, o.clip_gradient) (2*o.num_recurrent, 2*o.num_recurrent, o.num_cells, o.lstm_stddev_factor, o.clip_gradient)
else: else:
sys.stderr.write("make_lstm_proto.py ERROR: more than 2 layers of LSTM, not supported yet.\n") sys.stderr.write("make_lstm_proto.py ERROR: more than 2 layers of LSTM, not supported yet.\n")
sys.exit(1) sys.exit(1)
print "<AffineTransform> <InputDim> %d <OutputDim> %d <BiasMean> 0.0 <BiasRange> 0.0 <ParamStddev> %f" % \ print "<AffineTransform> <InputDim> %d <OutputDim> %d <BiasMean> 0.0 <BiasRange> 0.0 <ParamStddev> %f" % \
(o.num_recurrent, num_leaves, o.param_stddev_factor) (2*o.num_recurrent, num_leaves, o.param_stddev_factor)
print "<Softmax> <InputDim> %d <OutputDim> %d" % \ print "<Softmax> <InputDim> %d <OutputDim> %d" % \
(num_leaves, num_leaves) (num_leaves, num_leaves)
print "</NnetProto>" print "</NnetProto>"
......
...@@ -19,8 +19,8 @@ ...@@ -19,8 +19,8 @@
#ifndef KALDI_NNET_NNET_BLSTM_PROJECTED_STREAMS_H_ #ifndef KALDI_NNET_BLSTM_PROJECTED_STREAMS_H_
#define KALDI_NNET_NNET_BLSTM_PROJECTED_STREAMS_H_ #define KALDI_NNET_BLSTM_PROJECTED_STREAMS_H_
#include "nnet/nnet-component.h" #include "nnet/nnet-component.h"
#include "nnet/nnet-utils.h" #include "nnet/nnet-utils.h"
...@@ -49,7 +49,7 @@ class BLstmProjectedStreams : public UpdatableComponent { ...@@ -49,7 +49,7 @@ class BLstmProjectedStreams : public UpdatableComponent {
BLstmProjectedStreams(int32 input_dim, int32 output_dim) : BLstmProjectedStreams(int32 input_dim, int32 output_dim) :
UpdatableComponent(input_dim, output_dim), UpdatableComponent(input_dim, output_dim),
ncell_(0), ncell_(0),
nrecur_(output_dim), nrecur_(int32(output_dim/2)),
nstream_(0), nstream_(0),
clip_gradient_(0.0) clip_gradient_(0.0)
//, dropout_rate_(0.0) //, dropout_rate_(0.0)
...@@ -74,6 +74,11 @@ class BLstmProjectedStreams : public UpdatableComponent { ...@@ -74,6 +74,11 @@ class BLstmProjectedStreams : public UpdatableComponent {
} }
v = tmp; v = tmp;
} }
/// Records the utterance lengths used for parallel training.
/// The vector is copied into sequence_lengths_; the caller's vector is
/// left untouched.
void SetSeqLengths(std::vector<int> &sequence_lengths) {
  // Take a private copy first, then install it with a swap.
  std::vector<int> lengths_copy(sequence_lengths);
  sequence_lengths_.swap(lengths_copy);
}
void InitData(std::istream &is) { void InitData(std::istream &is) {
// define options // define options
...@@ -439,52 +444,9 @@ class BLstmProjectedStreams : public UpdatableComponent { ...@@ -439,52 +444,9 @@ class BLstmProjectedStreams : public UpdatableComponent {
"\n B_DR " + MomentStatistics(B_DR); "\n B_DR " + MomentStatistics(B_DR);
} }
void ResetLstmStreams(const std::vector<int32> &stream_reset_flag) {
// allocate f_prev_nnet_state_, b_prev_nnet_state_ if not done yet,
if (nstream_ == 0) {
// Karel: we just got number of streams! (before the 1st batch comes)
nstream_ = stream_reset_flag.size();
// forward direction
f_prev_nnet_state_.Resize(nstream_, 7*ncell_ + 1*nrecur_, kSetZero);
// backward direction
b_prev_nnet_state_.Resize(nstream_, 7*ncell_ + 1*nrecur_, kSetZero);
KALDI_LOG << "Running training with " << nstream_ << " streams.";
}
// reset flag: 1 - reset stream network state
KALDI_ASSERT(f_prev_nnet_state_.NumRows() == stream_reset_flag.size());
KALDI_ASSERT(b_prev_nnet_state_.NumRows() == stream_reset_flag.size());
for (int s = 0; s < stream_reset_flag.size(); s++) {
if (stream_reset_flag[s] == 1) {
// forward direction
f_prev_nnet_state_.Row(s).SetZero();
// backward direction
b_prev_nnet_state_.Row(s).SetZero();
}
}
}
void PropagateFnc(const CuMatrixBase<BaseFloat> &in, CuMatrixBase<BaseFloat> *out) { void PropagateFnc(const CuMatrixBase<BaseFloat> &in, CuMatrixBase<BaseFloat> *out) {
int DEBUG = 0; int DEBUG = 0;
int32 nstream_ = sequence_lengths_.size();
static bool do_stream_reset = false;
if (nstream_ == 0) {
do_stream_reset = true;
nstream_ = 1; // Karel: we are in nnet-forward, so we will use 1 stream,
// forward direction
f_prev_nnet_state_.Resize(nstream_, 7*ncell_ + 1*nrecur_, kSetZero);
// backward direction
b_prev_nnet_state_.Resize(nstream_, 7*ncell_ + 1*nrecur_, kSetZero);
KALDI_LOG << "Running nnet-forward with per-utterance BLSTM-state reset";
}
if (do_stream_reset) {
// resetting the forward and backward streams
f_prev_nnet_state_.SetZero();
b_prev_nnet_state_.SetZero();
}
KALDI_ASSERT(nstream_ > 0);
KALDI_ASSERT(in.NumRows() % nstream_ == 0); KALDI_ASSERT(in.NumRows() % nstream_ == 0);
int32 T = in.NumRows() / nstream_; int32 T = in.NumRows() / nstream_;
int32 S = nstream_; int32 S = nstream_;
...@@ -492,13 +454,9 @@ class BLstmProjectedStreams : public UpdatableComponent { ...@@ -492,13 +454,9 @@ class BLstmProjectedStreams : public UpdatableComponent {
// 0:forward pass history, [1, T]:current sequence, T+1:dummy // 0:forward pass history, [1, T]:current sequence, T+1:dummy
// forward direction // forward direction
f_propagate_buf_.Resize((T+2)*S, 7 * ncell_ + nrecur_, kSetZero); f_propagate_buf_.Resize((T+2)*S, 7 * ncell_ + nrecur_, kSetZero);
f_propagate_buf_.RowRange(0*S,S).CopyFromMat(f_prev_nnet_state_);
// backward direction // backward direction
b_propagate_buf_.Resize((T+2)*S, 7 * ncell_ + nrecur_, kSetZero); b_propagate_buf_.Resize((T+2)*S, 7 * ncell_ + nrecur_, kSetZero);
// for the backward direction, we initialize it at (T+1) frame
b_propagate_buf_.RowRange((T+1)*S,S).CopyFromMat(b_prev_nnet_state_);
// disassembling forward-pass forward-propagation buffer into different neurons, // disassembling forward-pass forward-propagation buffer into different neurons,
CuSubMatrix<BaseFloat> F_YG(f_propagate_buf_.ColRange(0*ncell_, ncell_)); CuSubMatrix<BaseFloat> F_YG(f_propagate_buf_.ColRange(0*ncell_, ncell_));
CuSubMatrix<BaseFloat> F_YI(f_propagate_buf_.ColRange(1*ncell_, ncell_)); CuSubMatrix<BaseFloat> F_YI(f_propagate_buf_.ColRange(1*ncell_, ncell_));
...@@ -532,6 +490,7 @@ class BLstmProjectedStreams : public UpdatableComponent { ...@@ -532,6 +490,7 @@ class BLstmProjectedStreams : public UpdatableComponent {
for (int t = 1; t <= T; t++) { for (int t = 1; t <= T; t++) {
// multistream buffers for current time-step // multistream buffers for current time-step
CuSubMatrix<BaseFloat> y_all(f_propagate_buf_.RowRange(t*S,S));
CuSubMatrix<BaseFloat> y_g(F_YG.RowRange(t*S,S)); CuSubMatrix<BaseFloat> y_g(F_YG.RowRange(t*S,S));
CuSubMatrix<BaseFloat> y_i(F_YI.RowRange(t*S,S)); CuSubMatrix<BaseFloat> y_i(F_YI.RowRange(t*S,S));
CuSubMatrix<BaseFloat> y_f(F_YF.RowRange(t*S,S)); CuSubMatrix<BaseFloat> y_f(F_YF.RowRange(t*S,S));
...@@ -582,6 +541,12 @@ class BLstmProjectedStreams : public UpdatableComponent { ...@@ -582,6 +541,12 @@ class BLstmProjectedStreams : public UpdatableComponent {
// m -> r // m -> r
y_r.AddMatMat(1.0, y_m, kNoTrans, f_w_r_m_, kTrans, 0.0); y_r.AddMatMat(1.0, y_m, kNoTrans, f_w_r_m_, kTrans, 0.0);
// set zeros
//for (int s = 0; s < S; s++) {
// if (t > sequence_lengths_[s])
// y_all.Row(s).SetZero();
//}
if (DEBUG) { if (DEBUG) {
std::cerr << "forward direction forward-pass frame " << t << "\n"; std::cerr << "forward direction forward-pass frame " << t << "\n";
...@@ -615,6 +580,7 @@ class BLstmProjectedStreams : public UpdatableComponent { ...@@ -615,6 +580,7 @@ class BLstmProjectedStreams : public UpdatableComponent {
// backward direction, from T to 1, t-- // backward direction, from T to 1, t--
for (int t = T; t >= 1; t--) { for (int t = T; t >= 1; t--) {
// multistream buffers for current time-step // multistream buffers for current time-step
CuSubMatrix<BaseFloat> y_all(b_propagate_buf_.RowRange(t*S,S));
CuSubMatrix<BaseFloat> y_g(B_YG.RowRange(t*S,S)); CuSubMatrix<BaseFloat> y_g(B_YG.RowRange(t*S,S));
CuSubMatrix<BaseFloat> y_i(B_YI.RowRange(t*S,S)); CuSubMatrix<BaseFloat> y_i(B_YI.RowRange(t*S,S));
CuSubMatrix<BaseFloat> y_f(B_YF.RowRange(t*S,S)); CuSubMatrix<BaseFloat> y_f(B_YF.RowRange(t*S,S));
...@@ -665,7 +631,12 @@ class BLstmProjectedStreams : public UpdatableComponent { ...@@ -665,7 +631,12 @@ class BLstmProjectedStreams : public UpdatableComponent {
// m -> r // m -> r
y_r.AddMatMat(1.0, y_m, kNoTrans, b_w_r_m_, kTrans, 0.0); y_r.AddMatMat(1.0, y_m, kNoTrans, b_w_r_m_, kTrans, 0.0);
for (int s = 0; s < S; s++) {
if (t > sequence_lengths_[s])
y_all.Row(s).SetZero();
}
if (DEBUG) { if (DEBUG) {
std::cerr << "backward direction forward-pass frame " << t << "\n"; std::cerr << "backward direction forward-pass frame " << t << "\n";
std::cerr << "activation of g: " << y_g; std::cerr << "activation of g: " << y_g;
...@@ -679,18 +650,17 @@ class BLstmProjectedStreams : public UpdatableComponent { ...@@ -679,18 +650,17 @@ class BLstmProjectedStreams : public UpdatableComponent {
} }
} }
// According to definition of BLSTM, for output YR of BLSTM, YR should be F_YR + B_YR /// final outputs now become the concatenation of the forward and backward activations
CuSubMatrix<BaseFloat> YR(F_YR.RowRange(1*S,T*S)); CuMatrix<BaseFloat> YR_FB;
YR.AddMat(1.0,B_YR.RowRange(1*S,T*S)); YR_FB.Resize((T+2)*S, 2 * nrecur_, kSetZero);
// forward part
YR_FB.ColRange(0, nrecur_).CopyFromMat(f_propagate_buf_.ColRange(7*ncell_, nrecur_));
// backward part
YR_FB.ColRange(nrecur_, nrecur_).CopyFromMat(b_propagate_buf_.ColRange(7*ncell_, nrecur_));
// recurrent projection layer is also feed-forward as BLSTM output // recurrent projection layer is also feed-forward as BLSTM output
out->CopyFromMat(YR);
out->CopyFromMat(YR_FB.RowRange(1*S,T*S));
// now the last frame state becomes previous network state for next batch
f_prev_nnet_state_.CopyFromMat(f_propagate_buf_.RowRange(T*S,S));
// now the last frame (that is, the first frame) becomes the previous network state for the next batch
b_prev_nnet_state_.CopyFromMat(b_propagate_buf_.RowRange(1*S,S));
} }
...@@ -698,7 +668,8 @@ class BLstmProjectedStreams : public UpdatableComponent { ...@@ -698,7 +668,8 @@ class BLstmProjectedStreams : public UpdatableComponent {
const CuMatrixBase<BaseFloat> &out_diff, CuMatrixBase<BaseFloat> *in_diff) { const CuMatrixBase<BaseFloat> &out_diff, CuMatrixBase<BaseFloat> *in_diff) {
int DEBUG = 0; int DEBUG = 0;
int32 nstream_ = sequence_lengths_.size(); // the number of sequences to be processed in parallel
int32 T = in.NumRows() / nstream_; int32 T = in.NumRows() / nstream_;
int32 S = nstream_; int32 S = nstream_;
// disassembling forward-pass forward-propagation buffer into different neurons, // disassembling forward-pass forward-propagation buffer into different neurons,
...@@ -727,7 +698,7 @@ class BLstmProjectedStreams : public UpdatableComponent { ...@@ -727,7 +698,7 @@ class BLstmProjectedStreams : public UpdatableComponent {
CuSubMatrix<BaseFloat> F_DGIFO(f_backpropagate_buf_.ColRange(0, 4*ncell_)); CuSubMatrix<BaseFloat> F_DGIFO(f_backpropagate_buf_.ColRange(0, 4*ncell_));
// projection layer to BLSTM output is not recurrent, so backprop it all in once // projection layer to BLSTM output is not recurrent, so backprop it all in once
F_DR.RowRange(1*S,T*S).CopyFromMat(out_diff); F_DR.RowRange(1*S,T*S).CopyFromMat(out_diff.ColRange(0, nrecur_));
for (int t = T; t >= 1; t--) { for (int t = T; t >= 1; t--) {
CuSubMatrix<BaseFloat> y_g(F_YG.RowRange(t*S,S)); CuSubMatrix<BaseFloat> y_g(F_YG.RowRange(t*S,S));
...@@ -747,7 +718,7 @@ class BLstmProjectedStreams : public UpdatableComponent { ...@@ -747,7 +718,7 @@ class BLstmProjectedStreams : public UpdatableComponent {
CuSubMatrix<BaseFloat> d_h(F_DH.RowRange(t*S,S)); CuSubMatrix<BaseFloat> d_h(F_DH.RowRange(t*S,S));
CuSubMatrix<BaseFloat> d_m(F_DM.RowRange(t*S,S)); CuSubMatrix<BaseFloat> d_m(F_DM.RowRange(t*S,S));
CuSubMatrix<BaseFloat> d_r(F_DR.RowRange(t*S,S)); CuSubMatrix<BaseFloat> d_r(F_DR.RowRange(t*S,S));
CuSubMatrix<BaseFloat> d_all(f_backpropagate_buf_.RowRange(t*S, S));
// r // r
// Version 1 (precise gradients): // Version 1 (precise gradients):
// backprop error from g(t+1), i(t+1), f(t+1), o(t+1) to r(t) // backprop error from g(t+1), i(t+1), f(t+1), o(t+1) to r(t)
...@@ -842,7 +813,7 @@ class BLstmProjectedStreams : public UpdatableComponent { ...@@ -842,7 +813,7 @@ class BLstmProjectedStreams : public UpdatableComponent {
CuSubMatrix<BaseFloat> B_DGIFO(b_backpropagate_buf_.ColRange(0, 4*ncell_)); CuSubMatrix<BaseFloat> B_DGIFO(b_backpropagate_buf_.ColRange(0, 4*ncell_));
// projection layer to BLSTM output is not recurrent, so backprop it all in once // projection layer to BLSTM output is not recurrent, so backprop it all in once
B_DR.RowRange(1*S,T*S).CopyFromMat(out_diff); B_DR.RowRange(1*S,T*S).CopyFromMat(out_diff.ColRange(nrecur_, nrecur_));
for (int t = 1; t <= T; t++) { for (int t = 1; t <= T; t++) {
CuSubMatrix<BaseFloat> y_g(B_YG.RowRange(t*S,S)); CuSubMatrix<BaseFloat> y_g(B_YG.RowRange(t*S,S));
...@@ -862,6 +833,7 @@ class BLstmProjectedStreams : public UpdatableComponent { ...@@ -862,6 +833,7 @@ class BLstmProjectedStreams : public UpdatableComponent {
CuSubMatrix<BaseFloat> d_h(B_DH.RowRange(t*S,S)); CuSubMatrix<BaseFloat> d_h(B_DH.RowRange(t*S,S));
CuSubMatrix<BaseFloat> d_m(B_DM.RowRange(t*S,S)); CuSubMatrix<BaseFloat> d_m(B_DM.RowRange(t*S,S));
CuSubMatrix<BaseFloat> d_r(B_DR.RowRange(t*S,S)); CuSubMatrix<BaseFloat> d_r(B_DR.RowRange(t*S,S));
CuSubMatrix<BaseFloat> d_all(b_backpropagate_buf_.RowRange(t*S, S));
// r // r
// Version 1 (precise gradients): // Version 1 (precise gradients):
...@@ -1083,9 +1055,7 @@ class BLstmProjectedStreams : public UpdatableComponent { ...@@ -1083,9 +1055,7 @@ class BLstmProjectedStreams : public UpdatableComponent {
int32 ncell_; ///< the number of cell blocks int32 ncell_; ///< the number of cell blocks
int32 nrecur_; ///< recurrent projection layer dim int32 nrecur_; ///< recurrent projection layer dim
int32 nstream_; int32 nstream_;
std::vector<int> sequence_lengths_;
CuMatrix<BaseFloat> f_prev_nnet_state_;
CuMatrix<BaseFloat> b_prev_nnet_state_;
// gradient-clipping value, // gradient-clipping value,
BaseFloat clip_gradient_; BaseFloat clip_gradient_;
......
...@@ -92,6 +92,8 @@ class Component { ...@@ -92,6 +92,8 @@ class Component {
static const char* TypeToMarker(ComponentType t); static const char* TypeToMarker(ComponentType t);
/// Convert marker to component type (case insensitive) /// Convert marker to component type (case insensitive)
static ComponentType MarkerToType(const std::string &s); static ComponentType MarkerToType(const std::string &s);
/// Hook for handing a vector of sequence lengths to a component; used
/// during training of LSTM models.
/// This default implementation ignores the argument and does nothing.
virtual void SetSeqLengths(std::vector<int> &sequence_lengths) { }
/// General interface of a component /// General interface of a component
public: public:
......
...@@ -356,11 +356,7 @@ void Nnet::ResetLstmStreams(const std::vector<int32> &stream_reset_flag) { ...@@ -356,11 +356,7 @@ void Nnet::ResetLstmStreams(const std::vector<int32> &stream_reset_flag) {
if (GetComponent(c).GetType() == Component::kLstmProjectedStreams) { if (GetComponent(c).GetType() == Component::kLstmProjectedStreams) {
LstmProjectedStreams& comp = dynamic_cast<LstmProjectedStreams&>(GetComponent(c)); LstmProjectedStreams& comp = dynamic_cast<LstmProjectedStreams&>(GetComponent(c));
comp.ResetLstmStreams(stream_reset_flag); comp.ResetLstmStreams(stream_reset_flag);
} }
if (GetComponent(c).GetType() == Component::kBLstmProjectedStreams) {
BLstmProjectedStreams& comp = dynamic_cast<BLstmProjectedStreams&>(GetComponent(c));
comp.ResetLstmStreams(stream_reset_flag);
}
} }
} }
......
...@@ -131,6 +131,12 @@ class Nnet { ...@@ -131,6 +131,12 @@ class Nnet {
const NnetTrainOptions& GetTrainOptions() const { const NnetTrainOptions& GetTrainOptions() const {
return opts_; return opts_;
} }
/// Set lengths of utterances for LSTM parallel training
void SetSeqLengths(std::vector<int> &sequence_lengths) {
for(int32 i=0; i < (int32)components_.size(); i++) {
components_[i]->SetSeqLengths(sequence_lengths);
}
}
private: private:
/// Vector which contains all the components composing the neural network, /// Vector which contains all the components composing the neural network,
......
...@@ -10,7 +10,7 @@ BINFILES = nnet-train-frmshuff \ ...@@ -10,7 +10,7 @@ BINFILES = nnet-train-frmshuff \
nnet-train-perutt \ nnet-train-perutt \
nnet-train-mmi-sequential \ nnet-train-mmi-sequential \
nnet-train-mpe-sequential \ nnet-train-mpe-sequential \
nnet-train-lstm-streams \ nnet-train-lstm-streams nnet-train-blstm-parallel \
rbm-train-cd1-frmshuff rbm-convert-to-nnet \ rbm-train-cd1-frmshuff rbm-convert-to-nnet \
nnet-forward nnet-copy nnet-info nnet-concat \ nnet-forward nnet-copy nnet-info nnet-concat \
transf-to-nnet cmvn-to-nnet nnet-initialize \ transf-to-nnet cmvn-to-nnet nnet-initialize \
......
// nnetbin/nnet-train-blstm-parallel.cc
// Copyright 2015 Chongjia Ni
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "nnet/nnet-trnopts.h"
#include "nnet/nnet-nnet.h"
#include "nnet/nnet-loss.h"
#include "nnet/nnet-randomizer.h"
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "base/timer.h"
#include "cudamatrix/cu-device.h"
int main(int argc, char *argv[]) {
using namespace kaldi;
using namespace kaldi::nnet1;
typedef kaldi::int32 int32;
try {
const char *usage =
"Perform one iteration of senones training by SGD.\n"
"The updates are done per-utternace and by processing multiple utterances in parallel.\n"
"\n"
"Usage: nnet-train-blstm-parallel [options] <feature-rspecifier> <labels-rspecifier> <model-in> [<model-out>]\n"
"e.g.: \n"
" nnet-train-blstm-parallel scp:feature.scp ark:labels.ark nnet.init nnet.iter1\n";
ParseOptions po(usage);
NnetTrainOptions trn_opts; // training options
trn_opts.Register(&po);
bool binary = true,
crossvalidate = false;
po.Register("binary", &binary, "Write model in binary mode");
po.Register("cross-validate", &crossvalidate, "Perform cross-validation (no backpropagation)");
std::string feature_transform;
po.Register("feature-transform", &feature_transform, "Feature transform in Nnet format");
int32 length_tolerance = 5;
po.Register("length-tolerance", &length_tolerance, "Allowed length difference of features/targets (frames)");
std::string frame_weights;
po.Register("frame-weights", &frame_weights, "Per-frame weights to scale gradients (frame selection/weighting).");
std::string objective_function = "xent";
po.Register("objective-function", &objective_function, "Objective function : xent|mse");
int32 num_sequence = 5;
po.Register("num-sequence", &num_sequence, "Number of sequences processed in parallel");
double frame_limit = 100000;
po.Register("frame-limit", &frame_limit, "Max number of frames to be processed");
int32 report_step=100;
po.Register("report-step", &report_step, "Step (number of sequences) for status reporting");
std::string use_gpu="yes";
// po.Register("use-gpu", &use_gpu, "yes|no|optional, only has effect if compiled with CUDA");
po.Read(argc, argv);
if (po.NumArgs() != 4-(crossvalidate?1:0)) {
po.PrintUsage();
exit(1);
}
std::string feature_rspecifier = po.GetArg(1),
targets_rspecifier = po.GetArg(2),
model_filename = po.GetArg(3);
std::string target_model_filename;
if (!crossvalidate) {
target_model_filename = po.GetArg(4);
}
using namespace kaldi;
using namespace kaldi::nnet1;
typedef kaldi::int32 int32;
Vector<BaseFloat> weights;
//Select the GPU
#if HAVE_CUDA==1
CuDevice::Instantiate().SelectGpuId(use_gpu);
#endif
Nnet nnet_transf;
if(feature_transform != "") {
nnet_transf.Read(feature_transform);
}
Nnet nnet;
nnet.Read(model_filename);
nnet.SetTrainOptions(trn_opts);
kaldi::int64 total_frames = 0;
// Initialize feature and label readers
SequentialBaseFloatMatrixReader feature_reader(feature_rspecifier);
RandomAccessPosteriorReader targets_reader(targets_rspecifier);
RandomAccessBaseFloatVectorReader weights_reader;
if (frame_weights != "") {
weights_reader.Open(frame_weights);
}
Xent xent;
Mse mse;
CuMatrix<BaseFloat> feats, feats_transf, nnet_out, obj_diff;
Timer time;
KALDI_LOG << (crossvalidate?"CROSS-VALIDATION":"TRAINING") << " STARTED";
std::vector< Matrix<BaseFloat> > feats_utt(num_sequence); // Feature matrix of every utterance
std::vector< Posterior > labels_utt(num_sequence); // Label vector of every utterance
std::vector< Vector<BaseFloat> > weights_utt(num_sequence);
int32 feat_dim = nnet.InputDim();
int32 num_done = 0, num_no_tgt_mat = 0, num_other_error = 0;
while (1) {
std::vector<int> frame_num_utt;
int32 sequence_index = 0, max_frame_num = 0;
for ( ; !feature_reader.Done(); feature_reader.Next()) {
std::string utt = feature_reader.Key();
// Check that we have targets
if (!targets_reader.HasKey(utt)) {
KALDI_WARN << utt << ", missing targets";
num_no_tgt_mat++;
continue;
}
// Get feature / target pair
Matrix<BaseFloat> mat = feature_reader.Value();
Posterior targets = targets_reader.Value(utt);
if (frame_weights != "") {
weights = weights_reader.Value(utt);
} else { // all per-frame weights are 1.0
weights.Resize(mat.NumRows());
weights.Set(1.0);
}
// correct small length mismatch ... or drop sentence
{
// add lengths to vector
std::vector<int32> lenght;
lenght.push_back(mat.NumRows());
lenght.push_back(targets.size());
lenght.push_back(weights.Dim());
// find min, max
int32 min = *std::min_element(lenght.begin(),lenght.end());
int32 max = *std::max_element(lenght.begin(),lenght.end());
// fix or drop ?
if (max - min < length_tolerance) {
if(mat.NumRows() != min) mat.Resize(min, mat.NumCols(), kCopyData);
if(targets.size() != min) targets.resize(min);
if(weights.Dim() != min) weights.Resize(min, kCopyData);
} else {
KALDI_WARN << utt << ", length mismatch of targets " << targets.size()
<< " and features " << mat.NumRows();
num_other_error++;
continue;
}
}
if (max_frame_num < mat.NumRows()) max_frame_num = mat.NumRows();
feats_utt[sequence_index] = mat;
labels_utt[sequence_index] = targets;
weights_utt[sequence_index] = weights;
frame_num_utt.push_back(mat.NumRows());
sequence_index++;
// If the total number of frames reaches frame_limit, then stop adding more sequences, regardless of whether
// the number of utterances reaches num_sequence or not.
if (frame_num_utt.size() == num_sequence || frame_num_utt.size() * max_frame_num > frame_limit) {
feature_reader.Next(); break;
}
}
int32 cur_sequence_num = frame_num_utt.size();
// Create the final feature matrix. Every utterance is padded to the max length within this group of utterances
Matrix<BaseFloat>