Commit 03934435 authored by Vijayaditya Peddinti's avatar Vijayaditya Peddinti
Browse files

Modified ResizeOutputLayer method in nnet-nnet.cc to accommodate networks with FixedScaleComponents

parent af37d842
...@@ -1407,6 +1407,20 @@ Component *AffineComponent::CollapseWithNext( ...@@ -1407,6 +1407,20 @@ Component *AffineComponent::CollapseWithNext(
return ans; return ans;
} }
Component *AffineComponent::CollapseWithNext(
const FixedScaleComponent &next_component) const {
KALDI_ASSERT(this->OutputDim() == next_component.InputDim());
AffineComponent *ans =
dynamic_cast<AffineComponent*>(this->Copy());
KALDI_ASSERT(ans != NULL);
ans->linear_params_.MulRowsVec(next_component.scales_);
ans->bias_params_.MulElements(next_component.scales_);
return ans;
}
Component *AffineComponent::CollapseWithPrevious( Component *AffineComponent::CollapseWithPrevious(
const FixedAffineComponent &prev_component) const { const FixedAffineComponent &prev_component) const {
// If at least one was non-updatable, make the whole non-updatable. // If at least one was non-updatable, make the whole non-updatable.
......
...@@ -709,6 +709,7 @@ class ScaleComponent: public Component { ...@@ -709,6 +709,7 @@ class ScaleComponent: public Component {
class SumGroupComponent; // Forward declaration. class SumGroupComponent; // Forward declaration.
class AffineComponent; // Forward declaration. class AffineComponent; // Forward declaration.
class FixedScaleComponent; // Forward declaration.
class SoftmaxComponent: public NonlinearComponent { class SoftmaxComponent: public NonlinearComponent {
public: public:
...@@ -803,6 +804,7 @@ class AffineComponent: public UpdatableComponent { ...@@ -803,6 +804,7 @@ class AffineComponent: public UpdatableComponent {
// FixedLinearComponent yet. // FixedLinearComponent yet.
Component *CollapseWithNext(const AffineComponent &next) const ; Component *CollapseWithNext(const AffineComponent &next) const ;
Component *CollapseWithNext(const FixedAffineComponent &next) const; Component *CollapseWithNext(const FixedAffineComponent &next) const;
Component *CollapseWithNext(const FixedScaleComponent &next) const;
Component *CollapseWithPrevious(const FixedAffineComponent &prev) const; Component *CollapseWithPrevious(const FixedAffineComponent &prev) const;
virtual std::string Info() const; virtual std::string Info() const;
...@@ -1473,6 +1475,7 @@ class FixedScaleComponent: public Component { ...@@ -1473,6 +1475,7 @@ class FixedScaleComponent: public Component {
virtual void Write(std::ostream &os, bool binary) const; virtual void Write(std::ostream &os, bool binary) const;
protected: protected:
friend class AffineComponent; // necessary for collapse
CuVector<BaseFloat> scales_; CuVector<BaseFloat> scales_;
KALDI_DISALLOW_COPY_AND_ASSIGN(FixedScaleComponent); KALDI_DISALLOW_COPY_AND_ASSIGN(FixedScaleComponent);
}; };
......
...@@ -372,16 +372,43 @@ void Nnet::ResizeOutputLayer(int32 new_num_pdfs) { ...@@ -372,16 +372,43 @@ void Nnet::ResizeOutputLayer(int32 new_num_pdfs) {
if ((sc = dynamic_cast<SoftmaxComponent*>(components_[nc - 1])) == NULL) if ((sc = dynamic_cast<SoftmaxComponent*>(components_[nc - 1])) == NULL)
KALDI_ERR << "Expected last component to be SoftmaxComponent."; KALDI_ERR << "Expected last component to be SoftmaxComponent.";
// check if nc-1 has a FixedScaleComponent
bool has_fixed_scale_component = false;
int32 fixed_scale_component_index = -1;
int32 final_affine_component_index = nc - 2;
int32 softmax_component_index = nc - 1;
FixedScaleComponent *fsc =
dynamic_cast<FixedScaleComponent*>(
components_[final_affine_component_index]);
if (fsc != NULL) {
has_fixed_scale_component = true;
fixed_scale_component_index = nc - 2;
final_affine_component_index = nc - 3;
}
// note: it could be child class of AffineComponent. // note: it could be child class of AffineComponent.
AffineComponent *ac = dynamic_cast<AffineComponent*>(components_[nc - 2]); AffineComponent *ac = dynamic_cast<AffineComponent*>(
components_[final_affine_component_index]);
if (ac == NULL) if (ac == NULL)
KALDI_ERR << "Network doesn't have expected structure (didn't find final " KALDI_ERR << "Network doesn't have expected structure (didn't find final "
<< "AffineComponent)."; << "AffineComponent).";
if (has_fixed_scale_component) {
// collapse the fixed_scale_component with the affine_component before it
AffineComponent *ac_new = dynamic_cast<AffineComponent*>(ac->CollapseWithNext(*fsc));
KALDI_ASSERT(ac_new != NULL);
delete fsc;
delete ac;
components_.erase(components_.begin() + fixed_scale_component_index,
components_.begin() + (fixed_scale_component_index + 1));
components_[final_affine_component_index] = ac_new;
ac = ac_new;
softmax_component_index = softmax_component_index - 1;
}
ac->Resize(ac->InputDim(), new_num_pdfs); ac->Resize(ac->InputDim(), new_num_pdfs);
// Remove the softmax component, and replace it with a new one // Remove the softmax component, and replace it with a new one
delete components_[nc - 1]; delete components_[softmax_component_index];
components_[nc - 1] = new SoftmaxComponent(new_num_pdfs); components_[softmax_component_index] = new SoftmaxComponent(new_num_pdfs);
this->Check(); this->Check();
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment