Commit a8ad8600 authored by Pegah Ghahremani's avatar Pegah Ghahremani
Browse files

trunk: some fixes to p-norm related functions

git-svn-id: https://svn.code.sf.net/p/kaldi/code/trunk@4230 5e6a8d80-dfce-4ca6-a32a-6e07a63d50c8
parent d904fb57
......@@ -1002,53 +1002,52 @@ void MatrixBase<Real>::MulRowsVec(const VectorBase<Real> &scale) {
}
}
template<typename Real>
void MatrixBase<Real>::MulRowsGroupMat(const MatrixBase<Real> &src) {
KALDI_ASSERT(src.NumCols() > 0 && src.NumCols() <= this->NumCols());
KALDI_ASSERT(this->NumCols() % src.NumCols() == 0 ||
this->NumCols() % (src.NumCols() - 1) < this->NumCols() / (src.NumCols() - 1));
int group_size = 0;
if (this->NumCols() % src.NumCols() == 0) {
group_size = this->NumCols() / src.NumCols();
} else {
group_size = this->NumCols() / src.NumCols() + 1;
}
MatrixIndexT M = num_rows_, N = num_cols_;
KALDI_ASSERT(src.NumRows() == this->NumRows() &&
this->NumCols() % src.NumCols() == 0);
int32 group_size = this->NumCols() / src.NumCols(),
num_groups = this->NumCols() / group_size,
num_rows = this->NumRows();
for (MatrixIndexT i = 0; i < M; i++)
for (MatrixIndexT j = 0; j < N; j++)
(*this)(i, j) *= src(i, j / group_size);
for (MatrixIndexT i = 0; i < num_rows; i++) {
Real *data = this->RowData(i);
for (MatrixIndexT j = 0; j < num_groups; j++, data += group_size) {
Real scale = src(i, j);
cblas_Xscal(group_size, scale, data, 1);
}
}
}
template<typename Real>
void MatrixBase<Real>::GroupPnormDeriv(const MatrixBase<Real> &src1,
const MatrixBase<Real> &src2,
void MatrixBase<Real>::GroupPnormDeriv(const MatrixBase<Real> &input,
const MatrixBase<Real> &output,
Real power) {
KALDI_ASSERT(src2.NumCols() > 0 && src2.NumCols() <= this->NumCols());
KALDI_ASSERT(this->NumCols() % src2.NumCols() == 0 ||
this->NumCols() % (src2.NumCols() - 1) < this->NumCols() / (src2.NumCols() - 1));
int group_size = 0;
if (this->NumCols() % src2.NumCols() == 0) {
group_size = this->NumCols() / src2.NumCols();
} else {
group_size = this->NumCols() / src2.NumCols() + 1;
}
MatrixIndexT M = this->NumRows(), N = this->NumCols();
KALDI_ASSERT(input.NumCols() == this->NumCols() && input.NumRows() == this->NumRows());
KALDI_ASSERT(this->NumCols() % output.NumCols() == 0 &&
this->NumRows() == output.NumRows());
int group_size = this->NumCols() / output.NumCols(),
num_rows = this->NumRows(), num_cols = this->NumCols();
if (power == 1.0) {
for (MatrixIndexT i = 0; i < M; i++)
for (MatrixIndexT j = 0; j < N; j++)
(*this)(i, j) = (src1(i, j) == 0 ? 0 : (src1(i, j) > 0 ? 1 : -1));
for (MatrixIndexT i = 0; i < num_rows; i++) {
for (MatrixIndexT j = 0; j < num_cols; j++) {
Real input_val = input(i, j);
(*this)(i, j) = (input_val == 0 ? 0 : (input_val > 0 ? 1 : -1));
}
}
} else {
for (MatrixIndexT i = 0; i < M; i++) {
for (MatrixIndexT j = 0; j < N; j++) {
if (src2(i, j / group_size) == 0) {
for (MatrixIndexT i = 0; i < num_rows; i++) {
for (MatrixIndexT j = 0; j < num_cols; j++) {
Real output_val = output(i, j / group_size),
input_val = input(i, j);
if (output_val == 0)
(*this)(i, j) = 0;
} else {
(*this)(i, j) = pow(std::abs(src1(i, j)), power - 1) *
(src2(i, j / group_size) > 0 ? pow(src2(i, j / group_size), 1 - power) : 1) *
(src1(i, j) >= 0 ? 1 : -1) ;
}
else
(*this)(i, j) = pow(std::abs(input_val), power - 1) *
pow(output_val, 1 - power) * (input_val >= 0 ? 1 : -1) ;
}
}
}
......@@ -2428,12 +2427,15 @@ void MatrixBase<Real>::SoftHinge(const MatrixBase<Real> &src) {
}
}
}
template<typename Real>
void MatrixBase<Real>::GroupPnorm(const MatrixBase<Real> &src, Real power) {
int group_size = src.NumCols() / this->NumCols();
KALDI_ASSERT(src.NumCols() == this->NumCols() * group_size);
for (MatrixIndexT i = 0; i < src.NumRows(); i++)
for (MatrixIndexT j = 0; j < this->NumCols(); j++)
KALDI_ASSERT(src.NumCols() % this->NumCols() == 0 &&
src.NumRows() == this->NumRows());
int group_size = src.NumCols() / this->NumCols(),
num_rows = this->NumRows(), num_cols = this->NumCols();
for (MatrixIndexT i = 0; i < num_rows; i++)
for (MatrixIndexT j = 0; j < num_cols; j++)
(*this)(i, j) = src.Row(i).Range(j * group_size, group_size).Norm(power);
}
......
......@@ -240,8 +240,9 @@ class MatrixBase {
/// each row by a scalar taken from that dimension of the vector.
void MulRowsVec(const VectorBase<Real> &scale);
/// divide each row into src.NumCols() groups,
/// and then scale i'th row's jth group of elements by src[i, j].
/// Divide each row into src.NumCols() equal groups, and then scale i'th row's
/// j'th group of elements by src(i, j). Requires src.NumRows() ==
/// this->NumRows() and this->NumCols() % src.NumCols() == 0.
void MulRowsGroupMat(const MatrixBase<Real> &src);
/// Returns logdet of matrix.
......@@ -418,8 +419,8 @@ class MatrixBase {
/// Set each element to y = log(1 + exp(x))
void SoftHinge(const MatrixBase<Real> &src);
/// Apply the function y(i) = (sum_{j = i*G}^{(i+1)*G-1} x_j ^ (power)) ^ (1 / p)
/// where G = x.NumCols() / y.NumCols() must be an integer.
/// Apply the function y(i) = (sum_{j = i*G}^{(i+1)*G-1} x_j^(power))^(1 / p).
/// Requires src.NumRows() == this->NumRows() and src.NumCols() % this->NumCols() == 0.
void GroupPnorm(const MatrixBase<Real> &src, Real power);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment