Commit 4eaefc9b authored by Hainan Xu's avatar Hainan Xu
Browse files

implemented nnet1-to-raw-nnet.cc

git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@4254 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
parent 03d77c7c
// nnet2bin/nnet1-to-raw-nnet.cc
// Copyright 2013 Johns Hopkins University (author: Daniel Povey)
// Copyright 2013 Johns Hopkins University (author: Daniel Povey, Hainan Xu)
// See ../../COPYING for clarification regarding multiple authors
//
......@@ -22,8 +22,9 @@
#include "hmm/transition-model.h"
#include "nnet/nnet-nnet.h"
#include "nnet/nnet-affine-transform.h"
// may need more includes here.
#include "nnet/nnet-activation.h"
#include "nnet2/nnet-nnet.h"
#include "nnet2/nnet-component.h"
namespace kaldi {
......@@ -40,12 +41,86 @@ nnet2::Component *ConvertAffineTransformComponent(
learning_rate);
}
nnet2::Component *ConvertSoftmaxComponent(
const nnet1::Component &nnet1_component) {
const nnet1::Softmax *softmax =
dynamic_cast<const nnet1::Softmax*>(&nnet1_component);
KALDI_ASSERT(softmax != NULL);
return new nnet2::SoftmaxComponent(softmax->InputDim());
}
nnet2::Component *ConvertSigmoidComponent(
const nnet1::Component &nnet1_component) {
const nnet1::Sigmoid *sigmoid =
dynamic_cast<const nnet1::Sigmoid*>(&nnet1_component);
KALDI_ASSERT(sigmoid != NULL);
return new nnet2::SoftmaxComponent(sigmoid->InputDim());
}
nnet2::Component *ConvertSpliceComponent(
const nnet1::Component &nnet1_component) {
const nnet1::Splice *splice =
dynamic_cast<const nnet1::Splice*>(&nnet1_component);
KALDI_ASSERT(splice != NULL);
int32 low, high;
std::vector<int32> frame_offsets;
std::ostringstream ostr;
splice->WriteData(ostr, false);
std::istringstream istr(ostr.str());
ReadIntegerVector(istr, false, &frame_offsets);
for (size_t i = 1; i < frame_offsets.size(); i++) {
KALDI_ASSERT(frame_offsets[i-1] + 1 == frame_offsets[i]);
}
low = frame_offsets[0];
high = frame_offsets[frame_offsets.size() - 1];
nnet2::SpliceComponent *res = new nnet2::SpliceComponent();
res->Init(splice->InputDim(), -low, high);
return res;
}
nnet2::Component *ConvertAddShiftComponent(
const nnet1::Component &nnet1_component) {
const nnet1::AddShift *add_shift =
dynamic_cast<const nnet1::AddShift*>(&nnet1_component);
KALDI_ASSERT(add_shift != NULL);
Vector<BaseFloat> bias;
add_shift->GetParams(&bias);
CuVector<BaseFloat> cu_bias(bias);
nnet2::FixedBiasComponent *res = new nnet2::FixedBiasComponent();
res->Init(cu_bias);
return res;
}
nnet2::Component *ConvertRescaleComponent(
const nnet1::Component &nnet1_component) {
const nnet1::Rescale *rescale =
dynamic_cast<const nnet1::Rescale*>(&nnet1_component);
KALDI_ASSERT(rescale != NULL);
Vector<BaseFloat> scale;
rescale->GetParams(&scale);
CuVector<BaseFloat> cu_scale(scale);
nnet2::FixedScaleComponent *res = new nnet2::FixedScaleComponent();
res->Init(cu_scale);
return res;
}
nnet2::Component *ConvertComponent(const nnet1::Component &nnet1_component) {
nnet1::Component::ComponentType type_in = nnet1_component.GetType();
switch (type_in) {
case nnet1::Component::kAffineTransform:
return ConvertAffineTransformComponent(nnet1_component);
/* case nnet1::Component::kSoftmax:
case nnet1::Component::kSoftmax:
return ConvertSoftmaxComponent(nnet1_component);
case nnet1::Component::kSigmoid:
return ConvertSigmoidComponent(nnet1_component);
......@@ -55,9 +130,8 @@ nnet2::Component *ConvertComponent(const nnet1::Component &nnet1_component) {
// -5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5 .
case nnet1::Component::kAddShift:
return ConvertAddShiftComponent(nnet1_component); // convert to FixedBiasComponent
case nnet1::Component::kRescale
case nnet1::Component::kRescale:
return ConvertRescaleComponent(nnet1_component); // convert to FixedScaleComponent
*/
default: KALDI_ERR << "Un-handled nnet1 component type "
<< nnet1::Component::TypeToMarker(type_in);
return NULL;
......@@ -67,8 +141,18 @@ nnet2::Component *ConvertComponent(const nnet1::Component &nnet1_component) {
nnet2::Nnet *ConvertNnet1ToNnet2(const nnet1::Nnet &nnet1) {
// get a vector of nnet2::Component pointers and initialize the nnet2::Nnet with it.
return NULL;
size_t size = nnet1.NumComponents();
std::vector<nnet2::Component*> *components = new std::vector<nnet2::Component*>();
components->resize(size);
for (size_t i = 0; i < size; i++) {
(*components)[i] = ConvertComponent(nnet1.GetComponent(i));
}
nnet2::Nnet *res = new nnet2::Nnet();
res->Init(components);
// not de-allocate the memory for components
// since the nnet takes the ownership
return res;
}
}
......@@ -87,8 +171,6 @@ int main(int argc, char *argv[]) {
"e.g.:\n"
" nnet1-to-raw-nnet srcdir/final.nnet - | nnet-am-init dest/tree dest/topo - dest/0.mdl\n";
KALDI_ERR << "This program is not finished.";
bool binary_write = true;
int32 srand_seed = 0;
......@@ -112,6 +194,7 @@ int main(int argc, char *argv[]) {
WriteKaldiObject(*nnet2, raw_nnet2_wxfilename, binary_write);
KALDI_LOG << "Converted nnet1 neural net to raw nnet2 and wrote it to "
<< PrintableWxfilename(raw_nnet2_wxfilename);
delete nnet2;
return 0;
} catch(const std::exception &e) {
std::cerr << e.what() << '\n';
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment