Commit 1a429da1 authored by Karel Vesely's avatar Karel Vesely
Browse files

trunk,nnet1 : changing CuMatrix -> CuMatrixBase and CuVector -> CuVectorBase...

trunk,nnet1 : changing CuMatrix -> CuMatrixBase and CuVector -> CuVectorBase in interfaces, where applicable. Note that the 'Base' classes mean the dimensions are fixed and the object has pre-allocated data.



git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@4272 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
parent 70be4ae0
......@@ -127,7 +127,7 @@ void Randomize(const CuMatrixBase<Real> &src,
template<typename Real>
void Splice(const CuMatrix<Real> &src, const CuArray<int32> &frame_offsets,
void Splice(const CuMatrixBase<Real> &src, const CuArray<int32> &frame_offsets,
CuMatrixBase<Real> *tgt) {
KALDI_ASSERT(src.NumCols()*frame_offsets.Dim() == tgt->NumCols());
......@@ -167,7 +167,7 @@ void Splice(const CuMatrix<Real> &src, const CuArray<int32> &frame_offsets,
template<typename Real>
void Copy(const CuMatrix<Real> &src, const CuArray<int32> &copy_from_indices,
void Copy(const CuMatrixBase<Real> &src, const CuArray<int32> &copy_from_indices,
CuMatrixBase<Real> *tgt) {
KALDI_ASSERT(copy_from_indices.Dim() == tgt->NumCols());
......@@ -208,16 +208,16 @@ template
void RegularizeL1(CuMatrixBase<double> *weight, CuMatrixBase<double> *grad, double l1, double lr);
template
void Splice(const CuMatrix<float> &src, const CuArray<int32> &frame_offsets,
void Splice(const CuMatrixBase<float> &src, const CuArray<int32> &frame_offsets,
CuMatrixBase<float> *tgt);
template
void Splice(const CuMatrix<double> &src, const CuArray<int32> &frame_offsets,
void Splice(const CuMatrixBase<double> &src, const CuArray<int32> &frame_offsets,
CuMatrixBase<double> *tgt);
template
void Copy(const CuMatrix<float> &src, const CuArray<int32> &copy_from_indices,
void Copy(const CuMatrixBase<float> &src, const CuArray<int32> &copy_from_indices,
CuMatrixBase<float> *tgt);
template
void Copy(const CuMatrix<double> &src, const CuArray<int32> &copy_from_indices,
void Copy(const CuMatrixBase<double> &src, const CuArray<int32> &copy_from_indices,
CuMatrixBase<double> *tgt);
template
......
......@@ -59,7 +59,7 @@ void Randomize(const CuMatrixBase<Real> &src,
/// is replaced by src(src.NumRows()-1, j) or src(0, j) respectively, to avoid
/// an index out of bounds.
template<typename Real>
void Splice(const CuMatrix<Real> &src,
void Splice(const CuMatrixBase<Real> &src,
const CuArray<int32> &frame_offsets,
CuMatrixBase<Real> *tgt);
......@@ -69,7 +69,7 @@ void Splice(const CuMatrix<Real> &src,
/// in the src matrix. As a result, tgt(i, j) == src(i, copy_from_indices[j]).
/// Also see CuMatrix::CopyCols(), which is more general.
template<typename Real>
void Copy(const CuMatrix<Real> &src,
void Copy(const CuMatrixBase<Real> &src,
const CuArray<int32> &copy_from_indices,
CuMatrixBase<Real> *tgt);
......
......@@ -76,10 +76,10 @@ class CuMatrixBase {
friend class CuBlockMatrix<Real>;
friend void cu::RegularizeL1<Real>(CuMatrixBase<Real> *weight,
CuMatrixBase<Real> *grad, Real l1, Real lr);
friend void cu::Splice<Real>(const CuMatrix<Real> &src,
friend void cu::Splice<Real>(const CuMatrixBase<Real> &src,
const CuArray<int32> &frame_offsets,
CuMatrixBase<Real> *tgt);
friend void cu::Copy<Real>(const CuMatrix<Real> &src,
friend void cu::Copy<Real>(const CuMatrixBase<Real> &src,
const CuArray<int32> &copy_from_indices,
CuMatrixBase<Real> *tgt);
friend void cu::Randomize<Real>(const CuMatrixBase<Real> &src,
......
......@@ -57,7 +57,7 @@ class CuVectorBase {
template <typename OtherReal>
friend OtherReal VecVec(const CuVectorBase<OtherReal> &v1,
const CuVectorBase<OtherReal> &v2);
friend void cu::Splice<Real>(const CuMatrix<Real> &src,
friend void cu::Splice<Real>(const CuMatrixBase<Real> &src,
const CuArray<int32> &frame_offsets,
CuMatrixBase<Real> *tgt);
friend class CuRand<Real>;
......
......@@ -39,13 +39,13 @@ class Softmax : public Component {
Component* Copy() const { return new Softmax(*this); }
ComponentType GetType() const { return kSoftmax; }
void PropagateFnc(const CuMatrix<BaseFloat> &in, CuMatrix<BaseFloat> *out) {
void PropagateFnc(const CuMatrixBase<BaseFloat> &in, CuMatrixBase<BaseFloat> *out) {
// y = e^x_j/sum_j(e^x_j)
out->ApplySoftMaxPerRow(in);
}
void BackpropagateFnc(const CuMatrix<BaseFloat> &in, const CuMatrix<BaseFloat> &out,
const CuMatrix<BaseFloat> &out_diff, CuMatrix<BaseFloat> *in_diff) {
void BackpropagateFnc(const CuMatrixBase<BaseFloat> &in, const CuMatrixBase<BaseFloat> &out,
const CuMatrixBase<BaseFloat> &out_diff, CuMatrixBase<BaseFloat> *in_diff) {
// simply copy the error derivative
// (ie. assume crossentropy error function,
// while in_diff contains (net_output-target) :
......@@ -68,13 +68,13 @@ class Sigmoid : public Component {
Component* Copy() const { return new Sigmoid(*this); }
ComponentType GetType() const { return kSigmoid; }
void PropagateFnc(const CuMatrix<BaseFloat> &in, CuMatrix<BaseFloat> *out) {
void PropagateFnc(const CuMatrixBase<BaseFloat> &in, CuMatrixBase<BaseFloat> *out) {
// y = 1/(1+e^-x)
out->Sigmoid(in);
}
void BackpropagateFnc(const CuMatrix<BaseFloat> &in, const CuMatrix<BaseFloat> &out,
const CuMatrix<BaseFloat> &out_diff, CuMatrix<BaseFloat> *in_diff) {
void BackpropagateFnc(const CuMatrixBase<BaseFloat> &in, const CuMatrixBase<BaseFloat> &out,
const CuMatrixBase<BaseFloat> &out_diff, CuMatrixBase<BaseFloat> *in_diff) {
// ey = y(1-y)ex
in_diff->DiffSigmoid(out, out_diff);
}
......@@ -93,13 +93,13 @@ class Tanh : public Component {
Component* Copy() const { return new Tanh(*this); }
ComponentType GetType() const { return kTanh; }
void PropagateFnc(const CuMatrix<BaseFloat> &in, CuMatrix<BaseFloat> *out) {
void PropagateFnc(const CuMatrixBase<BaseFloat> &in, CuMatrixBase<BaseFloat> *out) {
// y = (e^x - e^(-x)) / (e^x + e^(-x))
out->Tanh(in);
}
void BackpropagateFnc(const CuMatrix<BaseFloat> &in, const CuMatrix<BaseFloat> &out,
const CuMatrix<BaseFloat> &out_diff, CuMatrix<BaseFloat> *in_diff) {
void BackpropagateFnc(const CuMatrixBase<BaseFloat> &in, const CuMatrixBase<BaseFloat> &out,
const CuMatrixBase<BaseFloat> &out_diff, CuMatrixBase<BaseFloat> *in_diff) {
// ey = (1 - y^2)ex
in_diff->DiffTanh(out, out_diff);
}
......@@ -118,7 +118,7 @@ class Dropout : public Component {
Component* Copy() const { return new Dropout(*this); }
ComponentType GetType() const { return kDropout; }
void PropagateFnc(const CuMatrix<BaseFloat> &in, CuMatrix<BaseFloat> *out) {
void PropagateFnc(const CuMatrixBase<BaseFloat> &in, CuMatrixBase<BaseFloat> *out) {
out->CopyFromMat(in);
// switch off 50% of the inputs...
dropout_mask_.Resize(out->NumRows(),out->NumCols());
......@@ -127,8 +127,8 @@ class Dropout : public Component {
out->MulElements(dropout_mask_);
}
void BackpropagateFnc(const CuMatrix<BaseFloat> &in, const CuMatrix<BaseFloat> &out,
const CuMatrix<BaseFloat> &out_diff, CuMatrix<BaseFloat> *in_diff) {
void BackpropagateFnc(const CuMatrixBase<BaseFloat> &in, const CuMatrixBase<BaseFloat> &out,
const CuMatrixBase<BaseFloat> &out_diff, CuMatrixBase<BaseFloat> *in_diff) {
in_diff->CopyFromMat(out_diff);
// use same mask on the error derivatives...
in_diff->MulElements(dropout_mask_);
......
......@@ -132,21 +132,21 @@ class AffineTransform : public UpdatableComponent {
}
void PropagateFnc(const CuMatrix<BaseFloat> &in, CuMatrix<BaseFloat> *out) {
void PropagateFnc(const CuMatrixBase<BaseFloat> &in, CuMatrixBase<BaseFloat> *out) {
// precopy bias
out->AddVecToRows(1.0, bias_, 0.0);
// multiply by weights^t
out->AddMatMat(1.0, in, kNoTrans, linearity_, kTrans, 1.0);
}
void BackpropagateFnc(const CuMatrix<BaseFloat> &in, const CuMatrix<BaseFloat> &out,
const CuMatrix<BaseFloat> &out_diff, CuMatrix<BaseFloat> *in_diff) {
void BackpropagateFnc(const CuMatrixBase<BaseFloat> &in, const CuMatrixBase<BaseFloat> &out,
const CuMatrixBase<BaseFloat> &out_diff, CuMatrixBase<BaseFloat> *in_diff) {
// multiply error derivative by weights
in_diff->AddMatMat(1.0, out_diff, kNoTrans, linearity_, kNoTrans, 0.0);
}
void Update(const CuMatrix<BaseFloat> &input, const CuMatrix<BaseFloat> &diff) {
void Update(const CuMatrixBase<BaseFloat> &input, const CuMatrixBase<BaseFloat> &diff) {
// we use following hyperparameters from the option class
const BaseFloat lr = opts_.learn_rate;
const BaseFloat mmt = opts_.momentum;
......@@ -171,30 +171,30 @@ class AffineTransform : public UpdatableComponent {
}
/// Accessors to the component parameters
const CuVector<BaseFloat>& GetBias() const {
const CuVectorBase<BaseFloat>& GetBias() const {
return bias_;
}
void SetBias(const CuVector<BaseFloat>& bias) {
void SetBias(const CuVectorBase<BaseFloat>& bias) {
KALDI_ASSERT(bias.Dim() == bias_.Dim());
bias_.CopyFromVec(bias);
}
const CuMatrix<BaseFloat>& GetLinearity() const {
const CuMatrixBase<BaseFloat>& GetLinearity() const {
return linearity_;
}
void SetLinearity(const CuMatrix<BaseFloat>& linearity) {
void SetLinearity(const CuMatrixBase<BaseFloat>& linearity) {
KALDI_ASSERT(linearity.NumRows() == linearity_.NumRows());
KALDI_ASSERT(linearity.NumCols() == linearity_.NumCols());
linearity_.CopyFromMat(linearity);
}
const CuVector<BaseFloat>& GetBiasCorr() const {
const CuVectorBase<BaseFloat>& GetBiasCorr() const {
return bias_corr_;
}
const CuMatrix<BaseFloat>& GetLinearityCorr() const {
const CuMatrixBase<BaseFloat>& GetLinearityCorr() const {
return linearity_corr_;
}
......
......@@ -120,7 +120,7 @@ class AveragePooling2DComponent : public Component {
}
void PropagateFnc(const CuMatrix<BaseFloat> &in, CuMatrix<BaseFloat> *out) {
void PropagateFnc(const CuMatrixBase<BaseFloat> &in, CuMatrixBase<BaseFloat> *out) {
// useful dims
int32 num_input_fmaps = input_dim_ / (fmap_x_len_ * fmap_y_len_);
......@@ -151,8 +151,8 @@ class AveragePooling2DComponent : public Component {
}
}
void BackpropagateFnc(const CuMatrix<BaseFloat> &in, const CuMatrix<BaseFloat> &out,
const CuMatrix<BaseFloat> &out_diff, CuMatrix<BaseFloat> *in_diff) {
void BackpropagateFnc(const CuMatrixBase<BaseFloat> &in, const CuMatrixBase<BaseFloat> &out,
const CuMatrixBase<BaseFloat> &out_diff, CuMatrixBase<BaseFloat> *in_diff) {
// useful dims
int32 num_input_fmaps = input_dim_ / (fmap_x_len_ * fmap_y_len_);
......
......@@ -95,7 +95,7 @@ class AveragePoolingComponent : public Component {
WriteBasicType(os, binary, pool_stride_);
}
void PropagateFnc(const CuMatrix<BaseFloat> &in, CuMatrix<BaseFloat> *out) {
void PropagateFnc(const CuMatrixBase<BaseFloat> &in, CuMatrixBase<BaseFloat> *out) {
// useful dims
int32 num_patches = input_dim_ / pool_stride_;
int32 num_pools = 1 + (num_patches - pool_size_) / pool_step_;
......@@ -113,8 +113,8 @@ class AveragePoolingComponent : public Component {
}
}
void BackpropagateFnc(const CuMatrix<BaseFloat> &in, const CuMatrix<BaseFloat> &out,
const CuMatrix<BaseFloat> &out_diff, CuMatrix<BaseFloat> *in_diff) {
void BackpropagateFnc(const CuMatrixBase<BaseFloat> &in, const CuMatrixBase<BaseFloat> &out,
const CuMatrixBase<BaseFloat> &out_diff, CuMatrixBase<BaseFloat> *in_diff) {
// useful dims
int32 num_patches = input_dim_ / pool_stride_;
int32 num_pools = 1 + (num_patches - pool_size_) / pool_step_;
......
......@@ -115,12 +115,12 @@ class Component {
}
/// Perform forward pass propagation Input->Output
void Propagate(const CuMatrix<BaseFloat> &in, CuMatrix<BaseFloat> *out);
void Propagate(const CuMatrixBase<BaseFloat> &in, CuMatrix<BaseFloat> *out);
/// Perform backward pass propagation, out_diff -> in_diff
/// '&in' and '&out' will sometimes be unused...
void Backpropagate(const CuMatrix<BaseFloat> &in,
const CuMatrix<BaseFloat> &out,
const CuMatrix<BaseFloat> &out_diff,
void Backpropagate(const CuMatrixBase<BaseFloat> &in,
const CuMatrixBase<BaseFloat> &out,
const CuMatrixBase<BaseFloat> &out_diff,
CuMatrix<BaseFloat> *in_diff);
/// Initialize component from a line in config file
......@@ -138,13 +138,13 @@ class Component {
/// Abstract interface for propagation/backpropagation
protected:
/// Forward pass transformation (to be implemented by descending class...)
virtual void PropagateFnc(const CuMatrix<BaseFloat> &in,
CuMatrix<BaseFloat> *out) = 0;
virtual void PropagateFnc(const CuMatrixBase<BaseFloat> &in,
CuMatrixBase<BaseFloat> *out) = 0;
/// Backward pass transformation (to be implemented by descending class...)
virtual void BackpropagateFnc(const CuMatrix<BaseFloat> &in,
const CuMatrix<BaseFloat> &out,
const CuMatrix<BaseFloat> &out_diff,
CuMatrix<BaseFloat> *in_diff) = 0;
virtual void BackpropagateFnc(const CuMatrixBase<BaseFloat> &in,
const CuMatrixBase<BaseFloat> &out,
const CuMatrixBase<BaseFloat> &out_diff,
CuMatrixBase<BaseFloat> *in_diff) = 0;
/// Initialize internal data of a component
virtual void InitData(std::istream &is) { }
......@@ -190,8 +190,8 @@ class UpdatableComponent : public Component {
virtual void GetParams(Vector<BaseFloat> *params) const = 0;
/// Compute gradient and update parameters
virtual void Update(const CuMatrix<BaseFloat> &input,
const CuMatrix<BaseFloat> &diff) = 0;
virtual void Update(const CuMatrixBase<BaseFloat> &input,
const CuMatrixBase<BaseFloat> &diff) = 0;
/// Sets the training options to the component
virtual void SetTrainOptions(const NnetTrainOptions &opts) {
......@@ -210,7 +210,7 @@ class UpdatableComponent : public Component {
};
inline void Component::Propagate(const CuMatrix<BaseFloat> &in,
inline void Component::Propagate(const CuMatrixBase<BaseFloat> &in,
CuMatrix<BaseFloat> *out) {
// Check the dims
if (input_dim_ != in.NumCols()) {
......@@ -224,9 +224,9 @@ inline void Component::Propagate(const CuMatrix<BaseFloat> &in,
}
inline void Component::Backpropagate(const CuMatrix<BaseFloat> &in,
const CuMatrix<BaseFloat> &out,
const CuMatrix<BaseFloat> &out_diff,
inline void Component::Backpropagate(const CuMatrixBase<BaseFloat> &in,
const CuMatrixBase<BaseFloat> &out,
const CuMatrixBase<BaseFloat> &out_diff,
CuMatrix<BaseFloat> *in_diff) {
// Check the dims
if (output_dim_ != out_diff.NumCols()) {
......
......@@ -231,7 +231,7 @@ class Convolutional2DComponent : public UpdatableComponent {
"\n bias_grad" + MomentStatistics(bias_grad_);
}
void PropagateFnc(const CuMatrix<BaseFloat> &in, CuMatrix<BaseFloat> *out) {
void PropagateFnc(const CuMatrixBase<BaseFloat> &in, CuMatrixBase<BaseFloat> *out) {
// useful dims
int32 num_input_fmaps = input_dim_ / (fmap_x_len_ * fmap_y_len_);
// int32 inp_fmap_size = fmap_x_len_ * fmap_y_len_;
......@@ -292,8 +292,8 @@ class Convolutional2DComponent : public UpdatableComponent {
}
void BackpropagateFnc(const CuMatrix<BaseFloat> &in, const CuMatrix<BaseFloat> &out,
const CuMatrix<BaseFloat> &out_diff, CuMatrix<BaseFloat> *in_diff) {
void BackpropagateFnc(const CuMatrixBase<BaseFloat> &in, const CuMatrixBase<BaseFloat> &out,
const CuMatrixBase<BaseFloat> &out_diff, CuMatrixBase<BaseFloat> *in_diff) {
// useful dims
int32 num_input_fmaps = input_dim_ / (fmap_x_len_ * fmap_y_len_);
......@@ -350,7 +350,7 @@ class Convolutional2DComponent : public UpdatableComponent {
}
void Update(const CuMatrix<BaseFloat> &input, const CuMatrix<BaseFloat> &diff) {
void Update(const CuMatrixBase<BaseFloat> &input, const CuMatrixBase<BaseFloat> &diff) {
// useful dims
// int32 num_input_fmaps = input_dim_ / (fmap_x_len_ * fmap_y_len_);
......
......@@ -203,7 +203,7 @@ class ConvolutionalComponent : public UpdatableComponent {
"\n bias_grad" + MomentStatistics(bias_grad_);
}
void PropagateFnc(const CuMatrix<BaseFloat> &in, CuMatrix<BaseFloat> *out) {
void PropagateFnc(const CuMatrixBase<BaseFloat> &in, CuMatrixBase<BaseFloat> *out) {
// useful dims
int32 num_splice = input_dim_ / patch_stride_;
int32 num_patches = 1 + (patch_stride_ - patch_dim_) / patch_step_;
......@@ -253,8 +253,8 @@ class ConvolutionalComponent : public UpdatableComponent {
}
void BackpropagateFnc(const CuMatrix<BaseFloat> &in, const CuMatrix<BaseFloat> &out,
const CuMatrix<BaseFloat> &out_diff, CuMatrix<BaseFloat> *in_diff) {
void BackpropagateFnc(const CuMatrixBase<BaseFloat> &in, const CuMatrixBase<BaseFloat> &out,
const CuMatrixBase<BaseFloat> &out_diff, CuMatrixBase<BaseFloat> *in_diff) {
// useful dims
int32 num_splice = input_dim_ / patch_stride_;
int32 num_patches = 1 + (patch_stride_ - patch_dim_) / patch_step_;
......@@ -287,7 +287,7 @@ class ConvolutionalComponent : public UpdatableComponent {
}
void Update(const CuMatrix<BaseFloat> &input, const CuMatrix<BaseFloat> &diff) {
void Update(const CuMatrixBase<BaseFloat> &input, const CuMatrixBase<BaseFloat> &diff) {
// useful dims
int32 num_patches = 1 + (patch_stride_ - patch_dim_) / patch_step_;
int32 num_filters = filters_.NumRows();
......
......@@ -43,7 +43,7 @@ class KlHmm : public Component {
return kKlHmm;
}
void PropagateFnc(const CuMatrix<BaseFloat> &in, CuMatrix<BaseFloat> *out) {
void PropagateFnc(const CuMatrixBase<BaseFloat> &in, CuMatrixBase<BaseFloat> *out) {
if (kl_inv_q_.NumRows() == 0) {
// Copy the CudaMatrix to a Matrix
Matrix<BaseFloat> in_tmp(in.NumRows(), in.NumCols());
......@@ -97,8 +97,8 @@ class KlHmm : public Component {
out->Scale(-1);
}
void BackpropagateFnc(const CuMatrix<BaseFloat> &in, const CuMatrix<BaseFloat> &out,
const CuMatrix<BaseFloat> &out_diff, CuMatrix<BaseFloat> *in_diff) {
void BackpropagateFnc(const CuMatrixBase<BaseFloat> &in, const CuMatrixBase<BaseFloat> &out,
const CuMatrixBase<BaseFloat> &out_diff, CuMatrixBase<BaseFloat> *in_diff) {
KALDI_ERR << "Unimplemented";
}
......
......@@ -104,19 +104,19 @@ class LinearTransform : public UpdatableComponent {
", lr-coef " + ToString(learn_rate_coef_);
}
void PropagateFnc(const CuMatrix<BaseFloat> &in, CuMatrix<BaseFloat> *out) {
void PropagateFnc(const CuMatrixBase<BaseFloat> &in, CuMatrixBase<BaseFloat> *out) {
// multiply by weights^t
out->AddMatMat(1.0, in, kNoTrans, linearity_, kTrans, 0.0);
}
void BackpropagateFnc(const CuMatrix<BaseFloat> &in, const CuMatrix<BaseFloat> &out,
const CuMatrix<BaseFloat> &out_diff, CuMatrix<BaseFloat> *in_diff) {
void BackpropagateFnc(const CuMatrixBase<BaseFloat> &in, const CuMatrixBase<BaseFloat> &out,
const CuMatrixBase<BaseFloat> &out_diff, CuMatrixBase<BaseFloat> *in_diff) {
// multiply error derivative by weights
in_diff->AddMatMat(1.0, out_diff, kNoTrans, linearity_, kNoTrans, 0.0);
}
void Update(const CuMatrix<BaseFloat> &input, const CuMatrix<BaseFloat> &diff) {
void Update(const CuMatrixBase<BaseFloat> &input, const CuMatrixBase<BaseFloat> &diff) {
// we use following hyperparameters from the option class
const BaseFloat lr = opts_.learn_rate;
const BaseFloat mmt = opts_.momentum;
......@@ -139,17 +139,17 @@ class LinearTransform : public UpdatableComponent {
}
/// Accessors to the component parameters
const CuMatrix<BaseFloat>& GetLinearity() {
const CuMatrixBase<BaseFloat>& GetLinearity() {
return linearity_;
}
void SetLinearity(const CuMatrix<BaseFloat>& linearity) {
void SetLinearity(const CuMatrixBase<BaseFloat>& linearity) {
KALDI_ASSERT(linearity.NumRows() == linearity_.NumRows());
KALDI_ASSERT(linearity.NumCols() == linearity_.NumCols());
linearity_.CopyFromMat(linearity);
}
const CuMatrix<BaseFloat>& GetLinearityCorr() {
const CuMatrixBase<BaseFloat>& GetLinearityCorr() {
return linearity_corr_;
}
......
......@@ -30,7 +30,7 @@ namespace nnet1 {
/* Xent */
void Xent::Eval(const CuMatrix<BaseFloat> &net_out, const CuMatrix<BaseFloat> &target, CuMatrix<BaseFloat> *diff) {
void Xent::Eval(const CuMatrixBase<BaseFloat> &net_out, const CuMatrixBase<BaseFloat> &target, CuMatrix<BaseFloat> *diff) {
KALDI_ASSERT(net_out.NumCols() == target.NumCols());
KALDI_ASSERT(net_out.NumRows() == target.NumRows());
diff->Resize(net_out.NumRows(), net_out.NumCols());
......@@ -71,7 +71,7 @@ void Xent::Eval(const CuMatrix<BaseFloat> &net_out, const CuMatrix<BaseFloat> &t
}
void Xent::Eval(const CuMatrix<BaseFloat>& net_out, const Posterior& post, CuMatrix<BaseFloat>* diff) {
void Xent::Eval(const CuMatrixBase<BaseFloat>& net_out, const Posterior& post, CuMatrix<BaseFloat>* diff) {
int32 num_frames = net_out.NumRows(),
num_pdf = net_out.NumCols();
KALDI_ASSERT(num_frames == post.size());
......@@ -155,7 +155,7 @@ void Xent::Eval(const CuMatrix<BaseFloat>& net_out, const Posterior& post, CuMat
}
void Xent::EvalVec(const CuMatrix<BaseFloat> &net_out, const std::vector<int32> &target, CuMatrix<BaseFloat> *diff) {
void Xent::EvalVec(const CuMatrixBase<BaseFloat> &net_out, const std::vector<int32> &target, CuMatrix<BaseFloat> *diff) {
// evaluate the frame-level classification
int32 correct=0;
net_out.FindRowMaxId(&max_id_out_);
......@@ -213,7 +213,7 @@ std::string Xent::Report() {
/* Mse */
void Mse::Eval(const CuMatrix<BaseFloat>& net_out, const CuMatrix<BaseFloat>& target, CuMatrix<BaseFloat>* diff) {
void Mse::Eval(const CuMatrixBase<BaseFloat>& net_out, const CuMatrixBase<BaseFloat>& target, CuMatrix<BaseFloat>* diff) {
KALDI_ASSERT(net_out.NumCols() == target.NumCols());
KALDI_ASSERT(net_out.NumRows() == target.NumRows());
int32 num_frames = net_out.NumRows();
......@@ -253,7 +253,7 @@ void Mse::Eval(const CuMatrix<BaseFloat>& net_out, const CuMatrix<BaseFloat>& ta
}
void Mse::Eval(const CuMatrix<BaseFloat>& net_out, const Posterior& post, CuMatrix<BaseFloat>* diff) {
void Mse::Eval(const CuMatrixBase<BaseFloat>& net_out, const Posterior& post, CuMatrix<BaseFloat>* diff) {
int32 num_frames = net_out.NumRows(),
num_pdf = net_out.NumCols();
KALDI_ASSERT(num_frames == post.size());
......
......@@ -37,13 +37,13 @@ class Xent {
~Xent() { }
/// Evaluate cross entropy from a matrix of soft targets
void Eval(const CuMatrix<BaseFloat> &net_out, const CuMatrix<BaseFloat> &target,
void Eval(const CuMatrixBase<BaseFloat> &net_out, const CuMatrixBase<BaseFloat> &target,
CuMatrix<BaseFloat> *diff);
/// Evaluate cross entropy from posteriors
void Eval(const CuMatrix<BaseFloat> &net_out, const Posterior &target,
void Eval(const CuMatrixBase<BaseFloat> &net_out, const Posterior &target,
CuMatrix<BaseFloat> *diff);
/// Evaluate cross entropy from hard labels (integer class indices)
void EvalVec(const CuMatrix<BaseFloat> &net_out, const std::vector<int32> &target,
void EvalVec(const CuMatrixBase<BaseFloat> &net_out, const std::vector<int32> &target,
CuMatrix<BaseFloat> *diff);
/// Generate string with error report
......@@ -85,9 +85,9 @@ class Mse {
~Mse() { }
/// Evaluate mean square error from target values
void Eval(const CuMatrix<BaseFloat>& net_out, const CuMatrix<BaseFloat>& target,
void Eval(const CuMatrixBase<BaseFloat>& net_out, const CuMatrixBase<BaseFloat>& target,
CuMatrix<BaseFloat>* diff);
void Eval(const CuMatrix<BaseFloat>& net_out, const Posterior& target,
void Eval(const CuMatrixBase<BaseFloat>& net_out, const Posterior& target,
CuMatrix<BaseFloat>* diff);
/// Generate string with error report
......
......@@ -120,7 +120,7 @@ class MaxPooling2DComponent : public Component {
}
void PropagateFnc(const CuMatrix<BaseFloat> &in, CuMatrix<BaseFloat> *out) {
void PropagateFnc(const CuMatrixBase<BaseFloat> &in, CuMatrixBase<BaseFloat> *out) {
// useful dims
int32 num_input_fmaps = input_dim_ / (fmap_x_len_ * fmap_y_len_);
......@@ -150,8 +150,8 @@ class MaxPooling2DComponent : public Component {
}
}
void BackpropagateFnc(const CuMatrix<BaseFloat> &in, const CuMatrix<BaseFloat> &out,
const CuMatrix<BaseFloat> &out_diff, CuMatrix<BaseFloat> *in_diff) {
void BackpropagateFnc(const CuMatrixBase<BaseFloat> &in, const CuMatrixBase<BaseFloat> &out,
const CuMatrixBase<BaseFloat> &out_diff, CuMatrixBase<BaseFloat> *in_diff) {
// useful dims
int32 num_input_fmaps = input_dim_ / (fmap_x_len_ * fmap_y_len_);
......
......@@ -95,7 +95,7 @@ class MaxPoolingComponent : public Component {
WriteBasicType(os, binary, pool_stride_);
}
void PropagateFnc(const CuMatrix<BaseFloat> &in, CuMatrix<BaseFloat> *out) {
void PropagateFnc(const CuMatrixBase<BaseFloat> &in, CuMatrixBase<BaseFloat> *out) {
// useful dims
int32 num_patches = input_dim_ / pool_stride_;
int32 num_pools = 1 + (num_patches - pool_size_) / pool_step_;
......@@ -112,8 +112,8 @@ class MaxPoolingComponent : public Component {
}
}
void BackpropagateFnc(const CuMatrix<BaseFloat> &in, const CuMatrix<BaseFloat> &out,
const CuMatrix<BaseFloat> &out_diff, CuMatrix<BaseFloat> *in_diff) {
void BackpropagateFnc(const CuMatrixBase<BaseFloat> &in, const CuMatrixBase<BaseFloat> &out,
const CuMatrixBase<BaseFloat> &out_diff, CuMatrixBase<BaseFloat> *in_diff) {
// useful dims
int32 num_patches = input_dim_ / pool_stride_;
int32 num_pools = 1 + (num_patches - pool_size_) / pool_step_;
......
......@@ -62,7 +62,7 @@ Nnet::~Nnet() {
}
void Nnet::Propagate(const CuMatrix<BaseFloat> &in, CuMatrix<BaseFloat> *out) {
void Nnet::Propagate(const CuMatrixBase<BaseFloat> &in, CuMatrix<BaseFloat> *out) {
KALDI_ASSERT(NULL != out);
if (NumComponents() == 0) {
......@@ -79,14 +79,12 @@ void Nnet::Propagate(const CuMatrix<BaseFloat> &in, CuMatrix<BaseFloat> *out) {
for(int32 i=0; i<(int32)components_.size(); i++) {
components_[i]->Propagate(propagate_buf_[i], &propagate_buf_[i+1]);
}
CuMatrix<BaseFloat> &mat = propagate_buf_[components_.size()];
out->Resize(mat.NumRows(), mat.NumCols());
out->CopyFromMat(mat);