Commit 7afae3f8 authored by nichongjia's avatar nichongjia

google style code

parent e9439dd3
This diff is collapsed.
...@@ -32,7 +32,7 @@ namespace nnet1 { ...@@ -32,7 +32,7 @@ namespace nnet1 {
Nnet::Nnet(const Nnet& other) { Nnet::Nnet(const Nnet& other) {
// copy the components // copy the components
for(int32 i=0; i<other.NumComponents(); i++) { for(int32 i = 0; i < other.NumComponents(); i++) {
components_.push_back(other.GetComponent(i).Copy()); components_.push_back(other.GetComponent(i).Copy());
} }
// create empty buffers // create empty buffers
...@@ -46,7 +46,7 @@ Nnet::Nnet(const Nnet& other) { ...@@ -46,7 +46,7 @@ Nnet::Nnet(const Nnet& other) {
Nnet & Nnet::operator = (const Nnet& other) { Nnet & Nnet::operator = (const Nnet& other) {
Destroy(); Destroy();
// copy the components // copy the components
for(int32 i=0; i<other.NumComponents(); i++) { for(int32 i = 0; i < other.NumComponents(); i++) {
components_.push_back(other.GetComponent(i).Copy()); components_.push_back(other.GetComponent(i).Copy());
} }
// create empty buffers // create empty buffers
......
...@@ -147,7 +147,6 @@ class Nnet { ...@@ -147,7 +147,6 @@ class Nnet {
NnetTrainOptions opts_; NnetTrainOptions opts_;
}; };
} // namespace nnet1 } // namespace nnet1
} // namespace kaldi } // namespace kaldi
......
...@@ -39,8 +39,8 @@ int main(int argc, char *argv[]) { ...@@ -39,8 +39,8 @@ int main(int argc, char *argv[]) {
" nnet-train-blstm-streams scp:feature.scp ark:labels.ark nnet.init nnet.iter1\n"; " nnet-train-blstm-streams scp:feature.scp ark:labels.ark nnet.init nnet.iter1\n";
ParseOptions po(usage); ParseOptions po(usage);
// training options
NnetTrainOptions trn_opts; // training options NnetTrainOptions trn_opts;
trn_opts.Register(&po); trn_opts.Register(&po);
bool binary = true, bool binary = true,
...@@ -66,11 +66,11 @@ int main(int argc, char *argv[]) { ...@@ -66,11 +66,11 @@ int main(int argc, char *argv[]) {
double frame_limit = 100000; double frame_limit = 100000;
po.Register("frame-limit", &frame_limit, "Max number of frames to be processed"); po.Register("frame-limit", &frame_limit, "Max number of frames to be processed");
int32 report_step=100; int32 report_step = 100;
po.Register("report-step", &report_step, "Step (number of sequences) for status reporting"); po.Register("report-step", &report_step, "Step (number of sequences) for status reporting");
std::string use_gpu="yes"; std::string use_gpu = "yes";
// po.Register("use-gpu", &use_gpu, "yes|no|optional, only has effect if compiled with CUDA"); // po.Register("use-gpu", &use_gpu, "yes|no|optional, only has effect if compiled with CUDA");
po.Read(argc, argv); po.Read(argc, argv);
...@@ -92,13 +92,13 @@ int main(int argc, char *argv[]) { ...@@ -92,13 +92,13 @@ int main(int argc, char *argv[]) {
using namespace kaldi::nnet1; using namespace kaldi::nnet1;
typedef kaldi::int32 int32; typedef kaldi::int32 int32;
Vector<BaseFloat> weights; Vector<BaseFloat> weights;
//Select the GPU // Select the GPU
#if HAVE_CUDA==1 #if HAVE_CUDA == 1
CuDevice::Instantiate().SelectGpuId(use_gpu); CuDevice::Instantiate().SelectGpuId(use_gpu);
#endif #endif
Nnet nnet_transf; Nnet nnet_transf;
if(feature_transform != "") { if ( feature_transform != "" ) {
nnet_transf.Read(feature_transform); nnet_transf.Read(feature_transform);
} }
...@@ -123,9 +123,10 @@ int main(int argc, char *argv[]) { ...@@ -123,9 +123,10 @@ int main(int argc, char *argv[]) {
Timer time; Timer time;
KALDI_LOG << (crossvalidate?"CROSS-VALIDATION":"TRAINING") << " STARTED"; KALDI_LOG << (crossvalidate?"CROSS-VALIDATION":"TRAINING") << " STARTED";
// Feature matrix of every utterance
std::vector< Matrix<BaseFloat> > feats_utt(num_streams); // Feature matrix of every utterance std::vector< Matrix<BaseFloat> > feats_utt(num_streams);
std::vector< Posterior > labels_utt(num_streams); // Label vector of every utterance // Label vector of every utterance
std::vector< Posterior > labels_utt(num_streams);
std::vector< Vector<BaseFloat> > weights_utt(num_streams); std::vector< Vector<BaseFloat> > weights_utt(num_streams);
int32 feat_dim = nnet.InputDim(); int32 feat_dim = nnet.InputDim();
...@@ -162,13 +163,13 @@ int main(int argc, char *argv[]) { ...@@ -162,13 +163,13 @@ int main(int argc, char *argv[]) {
lenght.push_back(targets.size()); lenght.push_back(targets.size());
lenght.push_back(weights.Dim()); lenght.push_back(weights.Dim());
// find min, max // find min, max
int32 min = *std::min_element(lenght.begin(),lenght.end()); int32 min = *std::min_element(lenght.begin(), lenght.end());
int32 max = *std::max_element(lenght.begin(),lenght.end()); int32 max = *std::max_element(lenght.begin(), lenght.end());
// fix or drop ? // fix or drop ?
if (max - min < length_tolerance) { if (max - min < length_tolerance) {
if(mat.NumRows() != min) mat.Resize(min, mat.NumCols(), kCopyData); if (mat.NumRows() != min) mat.Resize(min, mat.NumCols(), kCopyData);
if(targets.size() != min) targets.resize(min); if (targets.size() != min) targets.resize(min);
if(weights.Dim() != min) weights.Resize(min, kCopyData); if (weights.Dim() != min) weights.Resize(min, kCopyData);
} else { } else {
KALDI_WARN << utt << ", length mismatch of targets " << targets.size() KALDI_WARN << utt << ", length mismatch of targets " << targets.size()
<< " and features " << mat.NumRows(); << " and features " << mat.NumRows();
...@@ -200,14 +201,13 @@ int main(int argc, char *argv[]) { ...@@ -200,14 +201,13 @@ int main(int argc, char *argv[]) {
target_host.resize(cur_sequence_num * max_frame_num); target_host.resize(cur_sequence_num * max_frame_num);
weight_host.Resize(cur_sequence_num * max_frame_num, kSetZero); weight_host.Resize(cur_sequence_num * max_frame_num, kSetZero);
///
for (int s = 0; s < cur_sequence_num; s++) { for (int s = 0; s < cur_sequence_num; s++) {
Matrix<BaseFloat> mat_tmp = feats_utt[s]; Matrix<BaseFloat> mat_tmp = feats_utt[s];
for (int r = 0; r < frame_num_utt[s]; r++) { for (int r = 0; r < frame_num_utt[s]; r++) {
feat_mat_host.Row(r*cur_sequence_num + s).CopyFromVec(mat_tmp.Row(r)); feat_mat_host.Row(r*cur_sequence_num + s).CopyFromVec(mat_tmp.Row(r));
} }
} }
///
for (int s = 0; s < cur_sequence_num; s++) { for (int s = 0; s < cur_sequence_num; s++) {
Posterior target_tmp = labels_utt[s]; Posterior target_tmp = labels_utt[s];
for (int r = 0; r < frame_num_utt[s]; r++) { for (int r = 0; r < frame_num_utt[s]; r++) {
...@@ -219,7 +219,7 @@ int main(int argc, char *argv[]) { ...@@ -219,7 +219,7 @@ int main(int argc, char *argv[]) {
} }
} }
////create // transform feature
nnet_transf.Feedforward(CuMatrix<BaseFloat>(feat_mat_host), &feats_transf); nnet_transf.Feedforward(CuMatrix<BaseFloat>(feat_mat_host), &feats_transf);
// Set the original lengths of utterances before padding // Set the original lengths of utterances before padding
...@@ -281,7 +281,7 @@ int main(int argc, char *argv[]) { ...@@ -281,7 +281,7 @@ int main(int argc, char *argv[]) {
<< "]"; << "]";
KALDI_LOG << xent.Report(); KALDI_LOG << xent.Report();
#if HAVE_CUDA==1 #if HAVE_CUDA == 1
CuDevice::Instantiate().PrintProfile(); CuDevice::Instantiate().PrintProfile();
#endif #endif
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment