Commit 18057466 authored by Vijayaditya Peddinti's avatar Vijayaditya Peddinti
Browse files

Minor changes.

parent 03934435
...@@ -707,9 +707,9 @@ class ScaleComponent: public Component { ...@@ -707,9 +707,9 @@ class ScaleComponent: public Component {
class SumGroupComponent; // Forward declaration. class SumGroupComponent; // Forward declaration.
class AffineComponent; // Forward declaration. class AffineComponent; // Forward declaration.
class FixedScaleComponent; // Forward declaration. class FixedScaleComponent; // Forward declaration.
class SoftmaxComponent: public NonlinearComponent { class SoftmaxComponent: public NonlinearComponent {
public: public:
...@@ -1475,7 +1475,7 @@ class FixedScaleComponent: public Component { ...@@ -1475,7 +1475,7 @@ class FixedScaleComponent: public Component {
virtual void Write(std::ostream &os, bool binary) const; virtual void Write(std::ostream &os, bool binary) const;
protected: protected:
friend class AffineComponent; // necessary for collapse friend class AffineComponent; // necessary for collapse
CuVector<BaseFloat> scales_; CuVector<BaseFloat> scales_;
KALDI_DISALLOW_COPY_AND_ASSIGN(FixedScaleComponent); KALDI_DISALLOW_COPY_AND_ASSIGN(FixedScaleComponent);
}; };
......
...@@ -51,7 +51,6 @@ int32 Nnet::LeftContext() const { ...@@ -51,7 +51,6 @@ int32 Nnet::LeftContext() const {
// non-negative left context. In addition, the NnetExample also stores data // non-negative left context. In addition, the NnetExample also stores data
// left context as positive integer. To be compatible with these other classes // left context as positive integer. To be compatible with these other classes
// Nnet::LeftContext() returns a non-negative left context. // Nnet::LeftContext() returns a non-negative left context.
} }
int32 Nnet::RightContext() const { int32 Nnet::RightContext() const {
...@@ -66,8 +65,8 @@ int32 Nnet::RightContext() const { ...@@ -66,8 +65,8 @@ int32 Nnet::RightContext() const {
void Nnet::ComputeChunkInfo(int32 input_chunk_size, void Nnet::ComputeChunkInfo(int32 input_chunk_size,
int32 num_chunks, int32 num_chunks,
std::vector<ChunkInfo> *chunk_info_out) const { std::vector<ChunkInfo> *chunk_info_out) const {
// First compute the output-chunk indices for the last component in the network. // First compute the output-chunk indices for the last component in the
// we assume that the numbering of the input starts from zero. // network. we assume that the numbering of the input starts from zero.
int32 output_chunk_size = input_chunk_size - LeftContext() - RightContext(); int32 output_chunk_size = input_chunk_size - LeftContext() - RightContext();
KALDI_ASSERT(output_chunk_size > 0); KALDI_ASSERT(output_chunk_size > 0);
std::vector<int32> current_output_inds; std::vector<int32> current_output_inds;
...@@ -88,7 +87,7 @@ void Nnet::ComputeChunkInfo(int32 input_chunk_size, ...@@ -88,7 +87,7 @@ void Nnet::ComputeChunkInfo(int32 input_chunk_size,
for (int32 i = NumComponents() - 1; i >= 0; i--) { for (int32 i = NumComponents() - 1; i >= 0; i--) {
std::vector<int32> current_context = GetComponent(i).Context(); std::vector<int32> current_context = GetComponent(i).Context();
std::set<int32> current_input_ind_set; std::set<int32> current_input_ind_set;
for (size_t j = 0; j < current_context.size(); j++) for (size_t j = 0; j < current_context.size(); j++)
for (size_t k = 0; k < current_output_inds.size(); k++) for (size_t k = 0; k < current_output_inds.size(); k++)
current_input_ind_set.insert(current_context[j] + current_input_ind_set.insert(current_context[j] +
current_output_inds[k]); current_output_inds[k]);
...@@ -137,7 +136,6 @@ void Nnet::ComputeChunkInfo(int32 input_chunk_size, ...@@ -137,7 +136,6 @@ void Nnet::ComputeChunkInfo(int32 input_chunk_size,
(*chunk_info_out)[i].Check(); (*chunk_info_out)[i].Check();
// (*chunk_info_out)[i].ToString(); // (*chunk_info_out)[i].ToString();
} }
} }
const Component& Nnet::GetComponent(int32 component) const { const Component& Nnet::GetComponent(int32 component) const {
...@@ -359,7 +357,8 @@ void Nnet::ResizeOutputLayer(int32 new_num_pdfs) { ...@@ -359,7 +357,8 @@ void Nnet::ResizeOutputLayer(int32 new_num_pdfs) {
KALDI_ASSERT(new_num_pdfs > 0); KALDI_ASSERT(new_num_pdfs > 0);
KALDI_ASSERT(NumComponents() > 2); KALDI_ASSERT(NumComponents() > 2);
int32 nc = NumComponents(); int32 nc = NumComponents();
SumGroupComponent *sgc = dynamic_cast<SumGroupComponent*>(components_[nc - 1]); SumGroupComponent *sgc =
dynamic_cast<SumGroupComponent*>(components_[nc - 1]);
if (sgc != NULL) { if (sgc != NULL) {
// Remove it. We'll resize things later. // Remove it. We'll resize things later.
delete sgc; delete sgc;
...@@ -367,7 +366,6 @@ void Nnet::ResizeOutputLayer(int32 new_num_pdfs) { ...@@ -367,7 +366,6 @@ void Nnet::ResizeOutputLayer(int32 new_num_pdfs) {
components_.begin() + nc); components_.begin() + nc);
nc--; nc--;
} }
SoftmaxComponent *sc; SoftmaxComponent *sc;
if ((sc = dynamic_cast<SoftmaxComponent*>(components_[nc - 1])) == NULL) if ((sc = dynamic_cast<SoftmaxComponent*>(components_[nc - 1])) == NULL)
KALDI_ERR << "Expected last component to be SoftmaxComponent."; KALDI_ERR << "Expected last component to be SoftmaxComponent.";
...@@ -381,11 +379,10 @@ void Nnet::ResizeOutputLayer(int32 new_num_pdfs) { ...@@ -381,11 +379,10 @@ void Nnet::ResizeOutputLayer(int32 new_num_pdfs) {
dynamic_cast<FixedScaleComponent*>( dynamic_cast<FixedScaleComponent*>(
components_[final_affine_component_index]); components_[final_affine_component_index]);
if (fsc != NULL) { if (fsc != NULL) {
has_fixed_scale_component = true; has_fixed_scale_component = true;
fixed_scale_component_index = nc - 2; fixed_scale_component_index = nc - 2;
final_affine_component_index = nc - 3; final_affine_component_index = nc - 3;
} }
// note: it could be child class of AffineComponent. // note: it could be child class of AffineComponent.
AffineComponent *ac = dynamic_cast<AffineComponent*>( AffineComponent *ac = dynamic_cast<AffineComponent*>(
components_[final_affine_component_index]); components_[final_affine_component_index]);
...@@ -394,7 +391,8 @@ void Nnet::ResizeOutputLayer(int32 new_num_pdfs) { ...@@ -394,7 +391,8 @@ void Nnet::ResizeOutputLayer(int32 new_num_pdfs) {
<< "AffineComponent)."; << "AffineComponent).";
if (has_fixed_scale_component) { if (has_fixed_scale_component) {
// collapse the fixed_scale_component with the affine_component before it // collapse the fixed_scale_component with the affine_component before it
AffineComponent *ac_new = dynamic_cast<AffineComponent*>(ac->CollapseWithNext(*fsc)); AffineComponent *ac_new =
dynamic_cast<AffineComponent*>(ac->CollapseWithNext(*fsc));
KALDI_ASSERT(ac_new != NULL); KALDI_ASSERT(ac_new != NULL);
delete fsc; delete fsc;
delete ac; delete ac;
...@@ -404,11 +402,11 @@ void Nnet::ResizeOutputLayer(int32 new_num_pdfs) { ...@@ -404,11 +402,11 @@ void Nnet::ResizeOutputLayer(int32 new_num_pdfs) {
ac = ac_new; ac = ac_new;
softmax_component_index = softmax_component_index - 1; softmax_component_index = softmax_component_index - 1;
} }
ac->Resize(ac->InputDim(), new_num_pdfs); ac->Resize(ac->InputDim(), new_num_pdfs);
// Remove the softmax component, and replace it with a new one // Remove the softmax component, and replace it with a new one
delete components_[softmax_component_index]; delete components_[softmax_component_index];
components_[softmax_component_index] = new SoftmaxComponent(new_num_pdfs); components_[softmax_component_index] = new SoftmaxComponent(new_num_pdfs);
this->SetIndexes(); // used for debugging
this->Check(); this->Check();
} }
...@@ -682,8 +680,9 @@ void Nnet::Vectorize(VectorBase<BaseFloat> *params) const { ...@@ -682,8 +680,9 @@ void Nnet::Vectorize(VectorBase<BaseFloat> *params) const {
KALDI_ASSERT(offset == GetParameterDim()); KALDI_ASSERT(offset == GetParameterDim());
} }
void Nnet::ResetGenerators() { // resets random-number generators for all random void Nnet::ResetGenerators() {
// components. // resets random-number generators for all random
// components.
for (int32 c = 0; c < NumComponents(); c++) { for (int32 c = 0; c < NumComponents(); c++) {
RandomComponent *rc = dynamic_cast<RandomComponent*>( RandomComponent *rc = dynamic_cast<RandomComponent*>(
&(GetComponent(c))); &(GetComponent(c)));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment