Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
K
kaldi_2015
Project overview
Project overview
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Yoann HOUPERT
kaldi_2015
Commits
4f5d4b4c
Commit
4f5d4b4c
authored
Aug 06, 2015
by
vesis84
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
nnet1: bugfix in MultiTaskLoss,
- in case a task gets 0 frames, use 0.0 as objective value, avoid nan,
parent
efec92b1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
4 deletions
+11
-4
src/nnet/nnet-loss.cc
src/nnet/nnet-loss.cc
+11
-4
No files found.
src/nnet/nnet-loss.cc
View file @
4f5d4b4c
...
...
@@ -163,8 +163,9 @@ void Xent::Eval(const VectorBase<BaseFloat> &frame_weights,
std
::
string
Xent
::
Report
()
{
std
::
ostringstream
oss
;
oss
<<
"AvgLoss: "
<<
(
loss_
-
entropy_
)
/
frames_
<<
" (Xent), "
<<
"[AvgXent: "
<<
loss_
/
frames_
<<
", AvgTargetEnt: "
<<
entropy_
/
frames_
<<
"]"
<<
std
::
endl
;
<<
"[AvgXent "
<<
loss_
/
frames_
<<
", AvgTargetEnt "
<<
entropy_
/
frames_
<<
", frames "
<<
frames_
<<
"]"
<<
std
::
endl
;
if
(
loss_vec_
.
size
()
>
0
)
{
oss
<<
"progress: ["
;
std
::
copy
(
loss_vec_
.
begin
(),
loss_vec_
.
end
(),
std
::
ostream_iterator
<
float
>
(
oss
,
" "
));
...
...
@@ -257,7 +258,8 @@ std::string Mse::Report() {
BaseFloat
root_mean_square
=
sqrt
(
loss_
/
frames_
/
num_tgt
);
// build the message,
std
::
ostringstream
oss
;
oss
<<
"AvgLoss: "
<<
loss_
/
frames_
<<
" (Mse), "
<<
"[RMS "
<<
root_mean_square
<<
"]"
<<
std
::
endl
;
oss
<<
"AvgLoss: "
<<
loss_
/
frames_
<<
" (Mse), "
<<
"[RMS "
<<
root_mean_square
<<
", frames "
<<
frames_
<<
"]"
<<
std
::
endl
;
oss
<<
"progress: ["
;
std
::
copy
(
loss_vec_
.
begin
(),
loss_vec_
.
end
(),
std
::
ostream_iterator
<
float
>
(
oss
,
" "
));
oss
<<
"]"
<<
std
::
endl
;
...
...
@@ -372,7 +374,12 @@ std::string MultiTaskLoss::Report() {
BaseFloat
MultiTaskLoss
::
AvgLoss
()
{
BaseFloat
ans
(
0.0
);
for
(
int32
i
=
0
;
i
<
loss_vec_
.
size
();
i
++
)
{
ans
+=
loss_weights_
[
i
]
*
loss_vec_
[
i
]
->
AvgLoss
();
BaseFloat
val
=
loss_weights_
[
i
]
*
loss_vec_
[
i
]
->
AvgLoss
();
if
(
!
KALDI_ISFINITE
(
val
))
{
KALDI_WARN
<<
"Loss "
<<
i
+
1
<<
", has bad objective function value '"
<<
val
<<
"', using 0.0 instead."
;
val
=
0.0
;
}
ans
+=
val
;
}
return
ans
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment