Commit e2700de1 authored by Karel Vesely's avatar Karel Vesely
Browse files

trunk,nnet1 : reimplementing the RBM training ani-weight explosion mechanism



git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@4264 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
parent 55e226c0
......@@ -41,12 +41,9 @@ class RbmBase : public Component {
: Component(dim_in, dim_out)
{ }
/*Is included in Component:: itf
virtual void Propagate(
const CuMatrix<BaseFloat> &vis_probs,
CuMatrix<BaseFloat> *hid_probs
) = 0;
*/
// Inherited from Component::
// void Propagate(...)
// virtual void PropagateFnc(...) = 0
virtual void Reconstruct(
const CuMatrix<BaseFloat> &hid_state,
......@@ -76,17 +73,14 @@ class RbmBase : public Component {
protected:
RbmTrainOptions rbm_opts_;
//// Make these methods inaccessible for descendants.
//
private:
// For RBMs we use Reconstruct(.)
//// Make inherited methods inaccessible,
// as for RBMs we use Reconstruct(.)
void Backpropagate(const CuMatrix<BaseFloat> &in, const CuMatrix<BaseFloat> &out,
const CuMatrix<BaseFloat> &out_diff, CuMatrix<BaseFloat> *in_diff) { }
void BackpropagateFnc(const CuMatrix<BaseFloat> &in, const CuMatrix<BaseFloat> &out,
const CuMatrix<BaseFloat> &out_diff, CuMatrix<BaseFloat> *in_diff) { }
//
////
////
};
......@@ -225,7 +219,7 @@ class Rbm : public RbmBase {
// Component API
void PropagateFnc(const CuMatrix<BaseFloat> &in, CuMatrix<BaseFloat> *out) {
// precopy bias
// pre-fill with bias
out->AddVecToRows(1.0, hid_bias_, 0.0);
// multiply by weights^t
out->AddMatMat(1.0, in, kNoTrans, vis_hid_, kTrans, 1.0);
......@@ -246,7 +240,7 @@ class Rbm : public RbmBase {
vis_probs->Resize(hid_state.NumRows(), input_dim_);
}
// precopy bias
// pre-fill with bias
vis_probs->AddVecToRows(1.0, vis_bias_, 0.0);
// multiply by weights
vis_probs->AddMatMat(1.0, hid_state, kNoTrans, vis_hid_, kNoTrans, 1.0);
......@@ -257,7 +251,7 @@ class Rbm : public RbmBase {
}
void RbmUpdate(const CuMatrix<BaseFloat> &pos_vis, const CuMatrix<BaseFloat> &pos_hid, const CuMatrix<BaseFloat> &neg_vis, const CuMatrix<BaseFloat> &neg_hid) {
// dims
KALDI_ASSERT(pos_vis.NumRows() == pos_hid.NumRows() &&
pos_vis.NumRows() == neg_vis.NumRows() &&
pos_vis.NumRows() == neg_hid.NumRows() &&
......@@ -266,23 +260,20 @@ class Rbm : public RbmBase {
pos_vis.NumCols() == input_dim_ &&
pos_hid.NumCols() == output_dim_);
//lazy initialization of buffers
// lazy initialization of buffers
if ( vis_hid_corr_.NumRows() != vis_hid_.NumRows() ||
vis_hid_corr_.NumCols() != vis_hid_.NumCols() ||
vis_bias_corr_.Dim() != vis_bias_.Dim() ||
hid_bias_corr_.Dim() != hid_bias_.Dim() ){
vis_hid_corr_.Resize(vis_hid_.NumRows(),vis_hid_.NumCols(),kSetZero);
//vis_bias_corr_.Resize(vis_bias_.Dim(),kSetZero);
//hid_bias_corr_.Resize(hid_bias_.Dim(),kSetZero);
vis_bias_corr_.Resize(vis_bias_.Dim());
hid_bias_corr_.Resize(hid_bias_.Dim());
vis_bias_corr_.Resize(vis_bias_.Dim(), kSetZero);
hid_bias_corr_.Resize(hid_bias_.Dim(), kSetZero);
}
//
// ANTI-WEIGHT-EXPLOSION PROTECTION
// ANTI-WEIGHT-EXPLOSION PROTECTION (Gaussian-Bernoulli RBM)
// in the following section we detect that the weights in Gaussian-Bernoulli RBM
// are about to explode. The weight explosion is caused by large variance of the
// reconstructed data, which causes increase of weight variance towards the explosion.
// reconstructed data, which causes a feed-back loop that keeps increasing the weights.
//
// To avoid explosion, the variance of the visible-data and reconstructed-data
// should be about the same. The model is particularly sensitive at the very
......@@ -294,90 +285,46 @@ class Rbm : public RbmBase {
// 2. shrink learning rate by 0.9x
// 3. reset the momentum buffer
//
// Wa also display a warning. Note that in later stage
// the training returns back to higher learning rate.
// Also a warning message is put to log. Note that in later stage
// the learning-rate returns to its original value.
//
// An alternative approach is to use smaller values in weight-matrix initialization.
//
if (vis_type_ == RbmBase::Gaussian) {
//get the standard deviations of pos_vis and neg_vis data
//pos_vis
CuMatrix<BaseFloat> pos_vis_pow2(pos_vis);
pos_vis_pow2.MulElements(pos_vis);
CuVector<BaseFloat> pos_vis_second(pos_vis.NumCols());
pos_vis_second.AddRowSumMat(1.0,pos_vis_pow2,0.0);
CuVector<BaseFloat> pos_vis_mean(pos_vis.NumCols());
pos_vis_mean.AddRowSumMat(1.0/pos_vis.NumRows(),pos_vis,0.0);
Vector<BaseFloat> pos_vis_second_h(pos_vis_second.Dim());
pos_vis_second.CopyToVec(&pos_vis_second_h);
Vector<BaseFloat> pos_vis_mean_h(pos_vis_mean.Dim());
pos_vis_mean.CopyToVec(&pos_vis_mean_h);
Vector<BaseFloat> pos_vis_stddev(pos_vis_mean_h);
pos_vis_stddev.MulElements(pos_vis_mean_h);
pos_vis_stddev.Scale(-1.0);
pos_vis_stddev.AddVec(1.0/pos_vis.NumRows(),pos_vis_second_h);
/* set negative values to zero before the square root */
for (int32 i=0; i<pos_vis_stddev.Dim(); i++) {
if(pos_vis_stddev(i) < 0.0) {
KALDI_WARN << "Forcing the variance to be non-negative! (set to zero)"
<< pos_vis_stddev(i);
pos_vis_stddev(i) = 0.0;
}
}
pos_vis_stddev.ApplyPow(0.5);
//neg_vis
CuMatrix<BaseFloat> neg_vis_pow2(neg_vis);
neg_vis_pow2.MulElements(neg_vis);
CuVector<BaseFloat> neg_vis_second(neg_vis.NumCols());
neg_vis_second.AddRowSumMat(1.0,neg_vis_pow2,0.0);
CuVector<BaseFloat> neg_vis_mean(neg_vis.NumCols());
neg_vis_mean.AddRowSumMat(1.0/neg_vis.NumRows(),neg_vis,0.0);
Vector<BaseFloat> neg_vis_second_h(neg_vis_second.Dim());
neg_vis_second.CopyToVec(&neg_vis_second_h);
Vector<BaseFloat> neg_vis_mean_h(neg_vis_mean.Dim());
neg_vis_mean.CopyToVec(&neg_vis_mean_h);
Vector<BaseFloat> neg_vis_stddev(neg_vis_mean_h);
neg_vis_stddev.MulElements(neg_vis_mean_h);
neg_vis_stddev.Scale(-1.0);
neg_vis_stddev.AddVec(1.0/neg_vis.NumRows(),neg_vis_second_h);
/* set negative values to zero before the square root */
for (int32 i=0; i<neg_vis_stddev.Dim(); i++) {
if(neg_vis_stddev(i) < 0.0) {
KALDI_WARN << "Forcing the variance to be non-negative! (set to zero)"
<< neg_vis_stddev(i);
neg_vis_stddev(i) = 0.0;
}
}
neg_vis_stddev.ApplyPow(0.5);
//monitor the standard deviation discrepancy between pos_vis and neg_vis
if (pos_vis_stddev.Sum() * 2 < neg_vis_stddev.Sum()) {
//1) scale-down the weights and biases
BaseFloat scale = pos_vis_stddev.Sum() / neg_vis_stddev.Sum();
// check the data have no nan/inf:
CheckNanInf(pos_vis,"pos_vis");
CheckNanInf(pos_hid,"pos_hid");
CheckNanInf(neg_vis,"neg_vis");
CheckNanInf(neg_hid,"pos_hid");
// get standard deviations of pos_vis and neg_vis:
BaseFloat pos_vis_std = ComputeStdDev(pos_vis);
BaseFloat neg_vis_std = ComputeStdDev(neg_vis);
// monitor the standard deviation mismatch : data vs. reconstruction
if (pos_vis_std * 2 < neg_vis_std) {
// 1) scale-down the weights and biases
BaseFloat scale = pos_vis_std / neg_vis_std;
vis_hid_.Scale(scale);
vis_bias_.Scale(scale);
hid_bias_.Scale(scale);
//2) reduce the learning rate
// 2) reduce the learning rate
rbm_opts_.learn_rate *= 0.9;
//3) reset the momentum buffers
// 3) reset the momentum buffers
vis_hid_corr_.SetZero();
vis_bias_corr_.SetZero();
hid_bias_corr_.SetZero();
KALDI_WARN << "Discrepancy between pos_hid and neg_hid variances, "
KALDI_WARN << "Mismatch between pos_vis and neg_vis variances, "
<< "danger of weight explosion. a) Reducing weights with scale " << scale
<< " b) Lowering learning rate to " << rbm_opts_.learn_rate
<< " [pos_vis_stddev(~1.0):" << pos_vis_stddev.Sum()/pos_vis.NumCols()
<< ",neg_vis_stddev:" << neg_vis_stddev.Sum()/neg_vis.NumCols() << "]";
return; /* i.e. don't update weights with current stats */
<< " [pos_vis_std:" << pos_vis_std
<< ",neg_vis_std:" << neg_vis_std << "]";
return; /* i.e. don't update weights with current stats, as the update would be too BIG */
}
}
//
// End of Gaussian-Bernoulli weight-explosion check
// End of weight-explosion check
// We use these training hyper-parameters
......
......@@ -112,6 +112,32 @@ std::string MomentStatistics(const CuMatrix<Real> &mat) {
return MomentStatistics(mat_host);
}
/**
* Check that matrix contains no nan or inf
*/
template <typename Real>
void CheckNanInf(const CuMatrix<Real> &mat, const char *msg = "") {
Real sum = mat.Sum();
if(KALDI_ISINF(sum)) { KALDI_ERR << "'inf' in " << msg; }
if(KALDI_ISNAN(sum)) { KALDI_ERR << "'nan' in " << msg; }
}
/**
* Get the standard deviation of values in the matrix
*/
template <typename Real>
Real ComputeStdDev(const CuMatrix<Real> &mat) {
int32 N = mat.NumRows() * mat.NumCols();
Real mean = mat.Sum() / N;
CuMatrix<Real> pow_2(mat);
pow_2.MulElements(mat);
Real var = pow_2.Sum() / N - mean * mean;
if (var < 0.0) {
KALDI_WARN << "Forcing the variance to be non-negative! " << var << "->0.0";
var = 0.0;
}
return sqrt(var);
}
/**
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment