Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Open sidebar
Yoann HOUPERT
kaldi_2015
Commits
e9439dd3
Commit
e9439dd3
authored
Aug 12, 2015
by
nichongjia
Browse files
blstm remove bug according to karel
parent
4163db46
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
294 additions
and
1 deletion
+294
-1
src/nnet/nnet-nnet.h
src/nnet/nnet-nnet.h
+1
-1
src/nnetbin/nnet-train-blstm-streams.cc
src/nnetbin/nnet-train-blstm-streams.cc
+293
-0
No files found.
src/nnet/nnet-nnet.h
View file @
e9439dd3
...
...
@@ -133,7 +133,7 @@ class Nnet {
/// Get training hyper-parameters from the network
const
NnetTrainOptions
&
GetTrainOptions
()
const
{
return
opts_
;
}
}
private:
/// Vector which contains all the components composing the neural network,
...
...
src/nnetbin/nnet-train-blstm-streams.cc
0 → 100644
View file @
e9439dd3
// nnetbin/nnet-train-blstm-parallel.cc
// Copyright 2015 Chongjia Ni
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.
#include "nnet/nnet-trnopts.h"
#include "nnet/nnet-nnet.h"
#include "nnet/nnet-loss.h"
#include "nnet/nnet-randomizer.h"
#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "base/timer.h"
#include "cudamatrix/cu-device.h"
int
main
(
int
argc
,
char
*
argv
[])
{
using
namespace
kaldi
;
using
namespace
kaldi
::
nnet1
;
typedef
kaldi
::
int32
int32
;
try
{
const
char
*
usage
=
"Perform one iteration of senones training by SGD.
\n
"
"The updates are done per-utternace and by processing multiple utterances in parallel.
\n
"
"
\n
"
"Usage: nnet-train-blstm-streams [options] <feature-rspecifier> <labels-rspecifier> <model-in> [<model-out>]
\n
"
"e.g.:
\n
"
" nnet-train-blstm-streams scp:feature.scp ark:labels.ark nnet.init nnet.iter1
\n
"
;
ParseOptions
po
(
usage
);
NnetTrainOptions
trn_opts
;
// training options
trn_opts
.
Register
(
&
po
);
bool
binary
=
true
,
crossvalidate
=
false
;
po
.
Register
(
"binary"
,
&
binary
,
"Write model in binary mode"
);
po
.
Register
(
"cross-validate"
,
&
crossvalidate
,
"Perform cross-validation (no backpropagation)"
);
std
::
string
feature_transform
;
po
.
Register
(
"feature-transform"
,
&
feature_transform
,
"Feature transform in Nnet format"
);
int32
length_tolerance
=
5
;
po
.
Register
(
"length-tolerance"
,
&
length_tolerance
,
"Allowed length difference of features/targets (frames)"
);
std
::
string
frame_weights
;
po
.
Register
(
"frame-weights"
,
&
frame_weights
,
"Per-frame weights to scale gradients (frame selection/weighting)."
);
std
::
string
objective_function
=
"xent"
;
po
.
Register
(
"objective-function"
,
&
objective_function
,
"Objective function : xent|mse"
);
int32
num_streams
=
4
;
po
.
Register
(
"num_streams"
,
&
num_streams
,
"Number of sequences processed in parallel"
);
double
frame_limit
=
100000
;
po
.
Register
(
"frame-limit"
,
&
frame_limit
,
"Max number of frames to be processed"
);
int32
report_step
=
100
;
po
.
Register
(
"report-step"
,
&
report_step
,
"Step (number of sequences) for status reporting"
);
std
::
string
use_gpu
=
"yes"
;
// po.Register("use-gpu", &use_gpu, "yes|no|optional, only has effect if compiled with CUDA");
po
.
Read
(
argc
,
argv
);
if
(
po
.
NumArgs
()
!=
4
-
(
crossvalidate
?
1
:
0
))
{
po
.
PrintUsage
();
exit
(
1
);
}
std
::
string
feature_rspecifier
=
po
.
GetArg
(
1
),
targets_rspecifier
=
po
.
GetArg
(
2
),
model_filename
=
po
.
GetArg
(
3
);
std
::
string
target_model_filename
;
if
(
!
crossvalidate
)
{
target_model_filename
=
po
.
GetArg
(
4
);
}
using
namespace
kaldi
;
using
namespace
kaldi
::
nnet1
;
typedef
kaldi
::
int32
int32
;
Vector
<
BaseFloat
>
weights
;
//Select the GPU
#if HAVE_CUDA==1
CuDevice
::
Instantiate
().
SelectGpuId
(
use_gpu
);
#endif
Nnet
nnet_transf
;
if
(
feature_transform
!=
""
)
{
nnet_transf
.
Read
(
feature_transform
);
}
Nnet
nnet
;
nnet
.
Read
(
model_filename
);
nnet
.
SetTrainOptions
(
trn_opts
);
kaldi
::
int64
total_frames
=
0
;
// Initialize feature ans labels readers
SequentialBaseFloatMatrixReader
feature_reader
(
feature_rspecifier
);
RandomAccessPosteriorReader
targets_reader
(
targets_rspecifier
);
RandomAccessBaseFloatVectorReader
weights_reader
;
if
(
frame_weights
!=
""
)
{
weights_reader
.
Open
(
frame_weights
);
}
Xent
xent
;
Mse
mse
;
CuMatrix
<
BaseFloat
>
feats
,
feats_transf
,
nnet_out
,
obj_diff
;
Timer
time
;
KALDI_LOG
<<
(
crossvalidate
?
"CROSS-VALIDATION"
:
"TRAINING"
)
<<
" STARTED"
;
std
::
vector
<
Matrix
<
BaseFloat
>
>
feats_utt
(
num_streams
);
// Feature matrix of every utterance
std
::
vector
<
Posterior
>
labels_utt
(
num_streams
);
// Label vector of every utterance
std
::
vector
<
Vector
<
BaseFloat
>
>
weights_utt
(
num_streams
);
int32
feat_dim
=
nnet
.
InputDim
();
int32
num_done
=
0
,
num_no_tgt_mat
=
0
,
num_other_error
=
0
;
while
(
1
)
{
std
::
vector
<
int32
>
frame_num_utt
;
int32
sequence_index
=
0
,
max_frame_num
=
0
;
for
(
;
!
feature_reader
.
Done
();
feature_reader
.
Next
())
{
std
::
string
utt
=
feature_reader
.
Key
();
// Check that we have targets
if
(
!
targets_reader
.
HasKey
(
utt
))
{
KALDI_WARN
<<
utt
<<
", missing targets"
;
num_no_tgt_mat
++
;
continue
;
}
// Get feature / target pair
Matrix
<
BaseFloat
>
mat
=
feature_reader
.
Value
();
Posterior
targets
=
targets_reader
.
Value
(
utt
);
if
(
frame_weights
!=
""
)
{
weights
=
weights_reader
.
Value
(
utt
);
}
else
{
// all per-frame weights are 1.0
weights
.
Resize
(
mat
.
NumRows
());
weights
.
Set
(
1.0
);
}
// correct small length mismatch ... or drop sentence
{
// add lengths to vector
std
::
vector
<
int32
>
lenght
;
lenght
.
push_back
(
mat
.
NumRows
());
lenght
.
push_back
(
targets
.
size
());
lenght
.
push_back
(
weights
.
Dim
());
// find min, max
int32
min
=
*
std
::
min_element
(
lenght
.
begin
(),
lenght
.
end
());
int32
max
=
*
std
::
max_element
(
lenght
.
begin
(),
lenght
.
end
());
// fix or drop ?
if
(
max
-
min
<
length_tolerance
)
{
if
(
mat
.
NumRows
()
!=
min
)
mat
.
Resize
(
min
,
mat
.
NumCols
(),
kCopyData
);
if
(
targets
.
size
()
!=
min
)
targets
.
resize
(
min
);
if
(
weights
.
Dim
()
!=
min
)
weights
.
Resize
(
min
,
kCopyData
);
}
else
{
KALDI_WARN
<<
utt
<<
", length mismatch of targets "
<<
targets
.
size
()
<<
" and features "
<<
mat
.
NumRows
();
num_other_error
++
;
continue
;
}
}
if
(
max_frame_num
<
mat
.
NumRows
())
max_frame_num
=
mat
.
NumRows
();
feats_utt
[
sequence_index
]
=
mat
;
labels_utt
[
sequence_index
]
=
targets
;
weights_utt
[
sequence_index
]
=
weights
;
frame_num_utt
.
push_back
(
mat
.
NumRows
());
sequence_index
++
;
// If the total number of frames reaches frame_limit, then stop adding more sequences, regardless of whether
// the number of utterances reaches num_sequence or not.
if
(
frame_num_utt
.
size
()
==
num_streams
||
frame_num_utt
.
size
()
*
max_frame_num
>
frame_limit
)
{
feature_reader
.
Next
();
break
;
}
}
int32
cur_sequence_num
=
frame_num_utt
.
size
();
// Create the final feature matrix. Every utterance is padded to the max length within this group of utterances
Matrix
<
BaseFloat
>
feat_mat_host
(
cur_sequence_num
*
max_frame_num
,
feat_dim
,
kSetZero
);
Posterior
target_host
;
Vector
<
BaseFloat
>
weight_host
;
target_host
.
resize
(
cur_sequence_num
*
max_frame_num
);
weight_host
.
Resize
(
cur_sequence_num
*
max_frame_num
,
kSetZero
);
///
for
(
int
s
=
0
;
s
<
cur_sequence_num
;
s
++
)
{
Matrix
<
BaseFloat
>
mat_tmp
=
feats_utt
[
s
];
for
(
int
r
=
0
;
r
<
frame_num_utt
[
s
];
r
++
)
{
feat_mat_host
.
Row
(
r
*
cur_sequence_num
+
s
).
CopyFromVec
(
mat_tmp
.
Row
(
r
));
}
}
///
for
(
int
s
=
0
;
s
<
cur_sequence_num
;
s
++
)
{
Posterior
target_tmp
=
labels_utt
[
s
];
for
(
int
r
=
0
;
r
<
frame_num_utt
[
s
];
r
++
)
{
target_host
[
r
*
cur_sequence_num
+
s
]
=
target_tmp
[
r
];
}
Vector
<
BaseFloat
>
weight_tmp
=
weights_utt
[
s
];
for
(
int
r
=
0
;
r
<
frame_num_utt
[
s
];
r
++
)
{
weight_host
(
r
*
cur_sequence_num
+
s
)
=
weight_tmp
(
r
);
}
}
////create
nnet_transf
.
Feedforward
(
CuMatrix
<
BaseFloat
>
(
feat_mat_host
),
&
feats_transf
);
// Set the original lengths of utterances before padding
nnet
.
SetSeqLengths
(
frame_num_utt
);
// Propagation and xent training
nnet
.
Propagate
(
feats_transf
,
&
nnet_out
);
if
(
objective_function
==
"xent"
)
{
// gradients re-scaled by weights in Eval,
xent
.
Eval
(
weight_host
,
nnet_out
,
target_host
,
&
obj_diff
);
}
else
if
(
objective_function
==
"mse"
)
{
// gradients re-scaled by weights in Eval,
mse
.
Eval
(
weight_host
,
nnet_out
,
target_host
,
&
obj_diff
);
}
else
{
KALDI_ERR
<<
"Unknown objective function code : "
<<
objective_function
;
}
// Backward pass
if
(
!
crossvalidate
)
{
nnet
.
Backpropagate
(
obj_diff
,
NULL
);
}
// 1st minibatch : show what happens in network
if
(
kaldi
::
g_kaldi_verbose_level
>=
2
&&
total_frames
==
0
)
{
// vlog-1
KALDI_VLOG
(
1
)
<<
"### After "
<<
total_frames
<<
" frames,"
;
KALDI_VLOG
(
1
)
<<
nnet
.
InfoPropagate
();
if
(
!
crossvalidate
)
{
KALDI_VLOG
(
1
)
<<
nnet
.
InfoBackPropagate
();
KALDI_VLOG
(
1
)
<<
nnet
.
InfoGradient
();
}
}
num_done
+=
cur_sequence_num
;
total_frames
+=
feats_transf
.
NumRows
();
if
(
feature_reader
.
Done
())
break
;
// end loop of while(1)
}
// Check network parameters and gradients when training finishes
if
(
kaldi
::
g_kaldi_verbose_level
>=
1
)
{
// vlog-1
KALDI_VLOG
(
1
)
<<
"### After "
<<
total_frames
<<
" frames,"
;
KALDI_VLOG
(
1
)
<<
nnet
.
InfoPropagate
();
if
(
!
crossvalidate
)
{
KALDI_VLOG
(
1
)
<<
nnet
.
InfoBackPropagate
();
KALDI_VLOG
(
1
)
<<
nnet
.
InfoGradient
();
}
}
if
(
!
crossvalidate
)
{
nnet
.
Write
(
target_model_filename
,
binary
);
}
KALDI_LOG
<<
"Done "
<<
num_done
<<
" files, "
<<
num_no_tgt_mat
<<
" with no tgt_mats, "
<<
num_other_error
<<
" with other errors. "
<<
"["
<<
(
crossvalidate
?
"CROSS-VALIDATION"
:
"TRAINING"
)
<<
", "
<<
time
.
Elapsed
()
/
60
<<
" min, fps"
<<
total_frames
/
time
.
Elapsed
()
<<
"]"
;
KALDI_LOG
<<
xent
.
Report
();
#if HAVE_CUDA==1
CuDevice
::
Instantiate
().
PrintProfile
();
#endif
return
0
;
}
catch
(
const
std
::
exception
&
e
)
{
std
::
cerr
<<
e
.
what
();
return
-
1
;
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment