Commit 5d15fb37 authored by Jan "yenda" Trmal's avatar Jan "yenda" Trmal

Merge pull request #60 from vesis84/nnet1_loss_fix

nnet1: bugfix in MultiTaskLoss,
parents efec92b1 4f5d4b4c
...@@ -163,8 +163,9 @@ void Xent::Eval(const VectorBase<BaseFloat> &frame_weights, ...@@ -163,8 +163,9 @@ void Xent::Eval(const VectorBase<BaseFloat> &frame_weights,
std::string Xent::Report() { std::string Xent::Report() {
std::ostringstream oss; std::ostringstream oss;
oss << "AvgLoss: " << (loss_-entropy_)/frames_ << " (Xent), " oss << "AvgLoss: " << (loss_-entropy_)/frames_ << " (Xent), "
<< "[AvgXent: " << loss_/frames_ << "[AvgXent " << loss_/frames_
<< ", AvgTargetEnt: " << entropy_/frames_ << "]" << std::endl; << ", AvgTargetEnt " << entropy_/frames_
<< ", frames " << frames_ << "]" << std::endl;
if (loss_vec_.size() > 0) { if (loss_vec_.size() > 0) {
oss << "progress: ["; oss << "progress: [";
std::copy(loss_vec_.begin(),loss_vec_.end(),std::ostream_iterator<float>(oss," ")); std::copy(loss_vec_.begin(),loss_vec_.end(),std::ostream_iterator<float>(oss," "));
...@@ -257,7 +258,8 @@ std::string Mse::Report() { ...@@ -257,7 +258,8 @@ std::string Mse::Report() {
BaseFloat root_mean_square = sqrt(loss_/frames_/num_tgt); BaseFloat root_mean_square = sqrt(loss_/frames_/num_tgt);
// build the message, // build the message,
std::ostringstream oss; std::ostringstream oss;
oss << "AvgLoss: " << loss_/frames_ << " (Mse), " << "[RMS " << root_mean_square << "]" << std::endl; oss << "AvgLoss: " << loss_/frames_ << " (Mse), "
<< "[RMS " << root_mean_square << ", frames " << frames_ << "]" << std::endl;
oss << "progress: ["; oss << "progress: [";
std::copy(loss_vec_.begin(),loss_vec_.end(),std::ostream_iterator<float>(oss," ")); std::copy(loss_vec_.begin(),loss_vec_.end(),std::ostream_iterator<float>(oss," "));
oss << "]" << std::endl; oss << "]" << std::endl;
...@@ -372,7 +374,12 @@ std::string MultiTaskLoss::Report() { ...@@ -372,7 +374,12 @@ std::string MultiTaskLoss::Report() {
BaseFloat MultiTaskLoss::AvgLoss() { BaseFloat MultiTaskLoss::AvgLoss() {
BaseFloat ans(0.0); BaseFloat ans(0.0);
for (int32 i = 0; i < loss_vec_.size(); i++) { for (int32 i = 0; i < loss_vec_.size(); i++) {
ans += loss_weights_[i] * loss_vec_[i]->AvgLoss(); BaseFloat val = loss_weights_[i] * loss_vec_[i]->AvgLoss();
if(!KALDI_ISFINITE(val)) {
KALDI_WARN << "Loss " << i+1 << ", has bad objective function value '" << val << "', using 0.0 instead.";
val = 0.0;
}
ans += val;
} }
return ans; return ans;
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment