Commit d4a25434 authored by Karel Vesely's avatar Karel Vesely
Browse files

trunk,nnet1: cosmetic changes, allowing initializing <LinearTransform> by loading from file.


git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@5196 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
parent f161fca4
......@@ -45,30 +45,43 @@ class LinearTransform : public UpdatableComponent {
// define options
float param_stddev = 0.1;
float learn_rate_coef = 1.0;
std::string read_matrix_file;
// parse config
std::string token;
while (!is.eof()) {
ReadToken(is, false, &token);
/**/ if (token == "<ParamStddev>") ReadBasicType(is, false, &param_stddev);
else if (token == "<LearnRateCoef>") ReadBasicType(is, false, &learn_rate_coef);
else if (token == "<ReadMatrix>") ReadToken(is, false, &read_matrix_file);
else KALDI_ERR << "Unknown token " << token << ", a typo in config?"
<< " (ParamStddev)";
<< " (ParamStddev|ReadMatrix|LearnRateCoef)";
is >> std::ws; // eat-up whitespace
}
//
// initialize
//
Matrix<BaseFloat> mat(output_dim_, input_dim_);
for (int32 r=0; r<output_dim_; r++) {
for (int32 c=0; c<input_dim_; c++) {
mat(r,c) = param_stddev * RandGauss(); // 0-mean Gauss with given std_dev
if (read_matrix_file != "") { // load from file,
bool binary;
Input in(read_matrix_file, &binary);
linearity_.Read(in.Stream(), binary);
in.Close();
KALDI_LOG << "Loaded <LinearTransform> matrix from file : " << read_matrix_file;
} else { // random initialization,
linearity_.Resize(output_dim_, input_dim_);
for (int32 r=0; r<output_dim_; r++) {
for (int32 c=0; c<input_dim_; c++) {
linearity_(r,c) = param_stddev * RandGauss(); // 0-mean Gauss with given std_dev
}
}
}
linearity_ = mat;
//
learn_rate_coef_ = learn_rate_coef;
//
// check dims,
KALDI_ASSERT(linearity_.NumRows() == output_dim_);
KALDI_ASSERT(linearity_.NumCols() == input_dim_);
}
void ReadData(std::istream &is, bool binary) {
......
......@@ -76,6 +76,7 @@ class ParallelComponent : public UpdatableComponent {
Nnet nnet;
nnet.Read(nested_nnet_filename[i]);
nnet_.push_back(nnet);
KALDI_LOG << "Loaded nested <Nnet> from file : " << nested_nnet_filename[i];
}
}
// initialize nnets from prototypes
......@@ -84,6 +85,7 @@ class ParallelComponent : public UpdatableComponent {
Nnet nnet;
nnet.Init(nested_nnet_proto[i]);
nnet_.push_back(nnet);
KALDI_LOG << "Initialized nested <Nnet> from prototype : " << nested_nnet_proto[i];
}
}
// check dim-sum of nested nnets
......
......@@ -28,7 +28,7 @@ int main(int argc, char *argv[]) {
typedef kaldi::int32 int32;
const char *usage =
"Initialize Neural Network parameters according to a prototype.\n"
"Initialize Neural Network parameters according to a prototype (nnet1).\n"
"Usage: nnet-initialize [options] <nnet-prototype-in> <nnet-out>\n"
"e.g.:\n"
" nnet-initialize --binary=false nnet.proto nnet.init\n";
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment