Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Open sidebar
Abdelwahab HEBA
kaldi_2015
Commits
03934435
Commit
03934435
authored
Jul 21, 2015
by
Vijayaditya Peddinti
Browse files
Modified ResizeOutputLayer method in nnet-nnet.cc to accommodate networks with FixedScaleComponents
parent
af37d842
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
48 additions
and
4 deletions
+48
-4
src/nnet2/nnet-component.cc
src/nnet2/nnet-component.cc
+14
-0
src/nnet2/nnet-component.h
src/nnet2/nnet-component.h
+3
-0
src/nnet2/nnet-nnet.cc
src/nnet2/nnet-nnet.cc
+31
-4
No files found.
src/nnet2/nnet-component.cc
View file @
03934435
...
...
@@ -1407,6 +1407,20 @@ Component *AffineComponent::CollapseWithNext(
return
ans
;
}
Component
*
AffineComponent
::
CollapseWithNext
(
const
FixedScaleComponent
&
next_component
)
const
{
KALDI_ASSERT
(
this
->
OutputDim
()
==
next_component
.
InputDim
());
AffineComponent
*
ans
=
dynamic_cast
<
AffineComponent
*>
(
this
->
Copy
());
KALDI_ASSERT
(
ans
!=
NULL
);
ans
->
linear_params_
.
MulRowsVec
(
next_component
.
scales_
);
ans
->
bias_params_
.
MulElements
(
next_component
.
scales_
);
return
ans
;
}
Component
*
AffineComponent
::
CollapseWithPrevious
(
const
FixedAffineComponent
&
prev_component
)
const
{
// If at least one was non-updatable, make the whole non-updatable.
...
...
src/nnet2/nnet-component.h
View file @
03934435
...
...
@@ -709,6 +709,7 @@ class ScaleComponent: public Component {
class
SumGroupComponent
;
// Forward declaration.
class
AffineComponent
;
// Forward declaration.
class
FixedScaleComponent
;
// Forward declaration.
class
SoftmaxComponent
:
public
NonlinearComponent
{
public:
...
...
@@ -803,6 +804,7 @@ class AffineComponent: public UpdatableComponent {
// FixedLinearComponent yet.
Component
*
CollapseWithNext
(
const
AffineComponent
&
next
)
const
;
Component
*
CollapseWithNext
(
const
FixedAffineComponent
&
next
)
const
;
Component
*
CollapseWithNext
(
const
FixedScaleComponent
&
next
)
const
;
Component
*
CollapseWithPrevious
(
const
FixedAffineComponent
&
prev
)
const
;
virtual
std
::
string
Info
()
const
;
...
...
@@ -1473,6 +1475,7 @@ class FixedScaleComponent: public Component {
virtual
void
Write
(
std
::
ostream
&
os
,
bool
binary
)
const
;
protected:
friend
class
AffineComponent
;
// necessary for collapse
CuVector
<
BaseFloat
>
scales_
;
KALDI_DISALLOW_COPY_AND_ASSIGN
(
FixedScaleComponent
);
};
...
...
src/nnet2/nnet-nnet.cc
View file @
03934435
...
...
@@ -372,16 +372,43 @@ void Nnet::ResizeOutputLayer(int32 new_num_pdfs) {
if
((
sc
=
dynamic_cast
<
SoftmaxComponent
*>
(
components_
[
nc
-
1
]))
==
NULL
)
KALDI_ERR
<<
"Expected last component to be SoftmaxComponent."
;
// check if nc-1 has a FixedScaleComponent
bool
has_fixed_scale_component
=
false
;
int32
fixed_scale_component_index
=
-
1
;
int32
final_affine_component_index
=
nc
-
2
;
int32
softmax_component_index
=
nc
-
1
;
FixedScaleComponent
*
fsc
=
dynamic_cast
<
FixedScaleComponent
*>
(
components_
[
final_affine_component_index
]);
if
(
fsc
!=
NULL
)
{
has_fixed_scale_component
=
true
;
fixed_scale_component_index
=
nc
-
2
;
final_affine_component_index
=
nc
-
3
;
}
// note: it could be child class of AffineComponent.
AffineComponent
*
ac
=
dynamic_cast
<
AffineComponent
*>
(
components_
[
nc
-
2
]);
AffineComponent
*
ac
=
dynamic_cast
<
AffineComponent
*>
(
components_
[
final_affine_component_index
]);
if
(
ac
==
NULL
)
KALDI_ERR
<<
"Network doesn't have expected structure (didn't find final "
<<
"AffineComponent)."
;
if
(
has_fixed_scale_component
)
{
// collapse the fixed_scale_component with the affine_component before it
AffineComponent
*
ac_new
=
dynamic_cast
<
AffineComponent
*>
(
ac
->
CollapseWithNext
(
*
fsc
));
KALDI_ASSERT
(
ac_new
!=
NULL
);
delete
fsc
;
delete
ac
;
components_
.
erase
(
components_
.
begin
()
+
fixed_scale_component_index
,
components_
.
begin
()
+
(
fixed_scale_component_index
+
1
));
components_
[
final_affine_component_index
]
=
ac_new
;
ac
=
ac_new
;
softmax_component_index
=
softmax_component_index
-
1
;
}
ac
->
Resize
(
ac
->
InputDim
(),
new_num_pdfs
);
// Remove the softmax component, and replace it with a new one
delete
components_
[
nc
-
1
];
components_
[
nc
-
1
]
=
new
SoftmaxComponent
(
new_num_pdfs
);
delete
components_
[
softmax_component_index
];
components_
[
softmax_component_index
]
=
new
SoftmaxComponent
(
new_num_pdfs
);
this
->
Check
();
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment