Commit 4f5d4b4c authored by vesis84's avatar vesis84

nnet1: bugfix in MultiTaskLoss,

- in case a task gets 0 frames, use 0.0 as objective value, avoid nan,
parent efec92b1
......@@ -163,8 +163,9 @@ void Xent::Eval(const VectorBase<BaseFloat> &frame_weights,
std::string Xent::Report() {
std::ostringstream oss;
oss << "AvgLoss: " << (loss_-entropy_)/frames_ << " (Xent), "
<< "[AvgXent: " << loss_/frames_
<< ", AvgTargetEnt: " << entropy_/frames_ << "]" << std::endl;
<< "[AvgXent " << loss_/frames_
<< ", AvgTargetEnt " << entropy_/frames_
<< ", frames " << frames_ << "]" << std::endl;
if (loss_vec_.size() > 0) {
oss << "progress: [";
std::copy(loss_vec_.begin(),loss_vec_.end(),std::ostream_iterator<float>(oss," "));
......@@ -257,7 +258,8 @@ std::string Mse::Report() {
BaseFloat root_mean_square = sqrt(loss_/frames_/num_tgt);
// build the message,
std::ostringstream oss;
oss << "AvgLoss: " << loss_/frames_ << " (Mse), " << "[RMS " << root_mean_square << "]" << std::endl;
oss << "AvgLoss: " << loss_/frames_ << " (Mse), "
<< "[RMS " << root_mean_square << ", frames " << frames_ << "]" << std::endl;
oss << "progress: [";
std::copy(loss_vec_.begin(),loss_vec_.end(),std::ostream_iterator<float>(oss," "));
oss << "]" << std::endl;
......@@ -372,7 +374,12 @@ std::string MultiTaskLoss::Report() {
BaseFloat MultiTaskLoss::AvgLoss() {
BaseFloat ans(0.0);
for (int32 i = 0; i < loss_vec_.size(); i++) {
ans += loss_weights_[i] * loss_vec_[i]->AvgLoss();
BaseFloat val = loss_weights_[i] * loss_vec_[i]->AvgLoss();
if(!KALDI_ISFINITE(val)) {
KALDI_WARN << "Loss " << i+1 << ", has bad objective function value '" << val << "', using 0.0 instead.";
val = 0.0;
}
ans += val;
}
return ans;
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment