Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Open sidebar
Yoann HOUPERT
kaldi_2015
Commits
5d15fb37
Commit
5d15fb37
authored
Aug 06, 2015
by
Jan "yenda" Trmal
Browse files
Merge pull request #60 from vesis84/nnet1_loss_fix
nnet1: bugfix in MultiTaskLoss,
parents
efec92b1
4f5d4b4c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
4 deletions
+11
-4
src/nnet/nnet-loss.cc
src/nnet/nnet-loss.cc
+11
-4
No files found.
src/nnet/nnet-loss.cc
View file @
5d15fb37
...
...
@@ -163,8 +163,9 @@ void Xent::Eval(const VectorBase<BaseFloat> &frame_weights,
std
::
string
Xent
::
Report
()
{
std
::
ostringstream
oss
;
oss
<<
"AvgLoss: "
<<
(
loss_
-
entropy_
)
/
frames_
<<
" (Xent), "
<<
"[AvgXent: "
<<
loss_
/
frames_
<<
", AvgTargetEnt: "
<<
entropy_
/
frames_
<<
"]"
<<
std
::
endl
;
<<
"[AvgXent "
<<
loss_
/
frames_
<<
", AvgTargetEnt "
<<
entropy_
/
frames_
<<
", frames "
<<
frames_
<<
"]"
<<
std
::
endl
;
if
(
loss_vec_
.
size
()
>
0
)
{
oss
<<
"progress: ["
;
std
::
copy
(
loss_vec_
.
begin
(),
loss_vec_
.
end
(),
std
::
ostream_iterator
<
float
>
(
oss
,
" "
));
...
...
@@ -257,7 +258,8 @@ std::string Mse::Report() {
BaseFloat
root_mean_square
=
sqrt
(
loss_
/
frames_
/
num_tgt
);
// build the message,
std
::
ostringstream
oss
;
oss
<<
"AvgLoss: "
<<
loss_
/
frames_
<<
" (Mse), "
<<
"[RMS "
<<
root_mean_square
<<
"]"
<<
std
::
endl
;
oss
<<
"AvgLoss: "
<<
loss_
/
frames_
<<
" (Mse), "
<<
"[RMS "
<<
root_mean_square
<<
", frames "
<<
frames_
<<
"]"
<<
std
::
endl
;
oss
<<
"progress: ["
;
std
::
copy
(
loss_vec_
.
begin
(),
loss_vec_
.
end
(),
std
::
ostream_iterator
<
float
>
(
oss
,
" "
));
oss
<<
"]"
<<
std
::
endl
;
...
...
@@ -372,7 +374,12 @@ std::string MultiTaskLoss::Report() {
BaseFloat
MultiTaskLoss
::
AvgLoss
()
{
BaseFloat
ans
(
0.0
);
for
(
int32
i
=
0
;
i
<
loss_vec_
.
size
();
i
++
)
{
ans
+=
loss_weights_
[
i
]
*
loss_vec_
[
i
]
->
AvgLoss
();
BaseFloat
val
=
loss_weights_
[
i
]
*
loss_vec_
[
i
]
->
AvgLoss
();
if
(
!
KALDI_ISFINITE
(
val
))
{
KALDI_WARN
<<
"Loss "
<<
i
+
1
<<
", has bad objective function value '"
<<
val
<<
"', using 0.0 instead."
;
val
=
0.0
;
}
ans
+=
val
;
}
return
ans
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment