// base/kaldi-math.h

// Copyright 2009-2011  Ondrej Glembek;  Microsoft Corporation;  Yanmin Qian;
//                      Jan Silovsky;  Saarland University
//
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_BASE_KALDI_MATH_H_
#define KALDI_BASE_KALDI_MATH_H_ 1

#ifdef _MSC_VER
#include <float.h>
#endif

#include <cmath>
#include <limits>
#include <vector>

#include "base/kaldi-types.h"
#include "base/kaldi-common.h"


// Fallback definitions for floating-point epsilons and mathematical constants.
// Each one is defined here only if the system headers have not already
// provided it.
#ifndef DBL_EPSILON
#define DBL_EPSILON 2.2204460492503131e-16
#endif
#ifndef FLT_EPSILON
#define FLT_EPSILON 1.19209290e-7f
#endif

#ifndef M_PI
#  define M_PI 3.1415926535897932384626433832795
#endif

#ifndef M_SQRT2
#  define M_SQRT2 1.4142135623730950488016887
#endif

// 2 * pi
#ifndef M_2PI
#  define M_2PI 6.283185307179586476925286766559005
#endif

// 1 / sqrt(2)
#ifndef M_SQRT1_2
# define M_SQRT1_2 0.7071067811865475244008443621048490
#endif

// log(2 * pi)
#ifndef M_LOG_2PI
#define M_LOG_2PI 1.8378770664093454835606594728112
#endif

// log(2)
#ifndef M_LN2
#define M_LN2 0.693147180559945309417232121458
#endif

// Project-wide spellings of the standard floating-point classification
// functions.
#define KALDI_ISNAN std::isnan
#define KALDI_ISINF std::isinf
#define KALDI_ISFINITE(x) std::isfinite(x)

// Square of x.  The argument is expanded (and so evaluated) twice, so it must
// not have side effects.
#if !defined(KALDI_SQR)
# define KALDI_SQR(x) ((x) * (x))
#endif

namespace kaldi {

// Exp(): exponential, with a float overload so that single-precision
// arguments can be handled by expf() rather than being promoted to double.
inline double Exp(double x) { return exp(x); }

#if !defined(_MSC_VER) || (_MSC_VER >= 1900)
#ifndef KALDI_NO_EXPF
inline float Exp(float x) { return expf(x); }
#else
// KALDI_NO_EXPF: avoid expf() and compute in double precision instead.
inline float Exp(float x) { return static_cast<float>(exp(static_cast<double>(x))); }
#endif  // KALDI_NO_EXPF
#elif !defined(__INTEL_COMPILER) && _MSC_VER == 1800 && defined(_M_X64)
// Microsoft CL v18.0 buggy 64-bit implementation of
// expf() incorrectly returns -inf for exp(-inf), so go through double exp().
inline float Exp(float x) { return static_cast<float>(exp(static_cast<double>(x))); }
#else
inline float Exp(float x) { return expf(x); }
#endif  // !defined(_MSC_VER) || (_MSC_VER >= 1900)

// Natural logarithm.  The float overload calls logf() so that
// single-precision arguments are not promoted to double.
inline double Log(double x) {
  return log(x);
}
inline float Log(float x) {
  return logf(x);
}

// Log1p(x) returns log(1 + x), remaining accurate when |x| is tiny (where
// forming 1 + x first would discard the low-order bits of x).
#if !defined(_MSC_VER) || (_MSC_VER >= 1700)
inline double Log1p(double x) {  return log1p(x); }
inline float Log1p(float x) {  return log1pf(x); }
#else
// Hand-written fallback for old MSVC (_MSC_VER < 1700).  For |x| below the
// cutoff we use the second-order Taylor series log(1 + x) ~= x - x^2 / 2;
// the cutoff must be tested on |x|, otherwise every negative x (e.g. -0.5)
// would be sent through the series.  Larger |x| use the ordinary logarithm.
inline double Log1p(double x) {
  const double cutoff = 1.0e-08;
  if (std::abs(x) < cutoff)
    return x - 0.5 * x * x;
  else
    return Log(1.0 + x);
}

inline float Log1p(float x) {
  const float cutoff = 1.0e-07f;
  if (std::abs(x) < cutoff)
    return x - 0.5f * x * x;
  else
    // Form 1 + x in double so that moderately small x keep their precision.
    return static_cast<float>(Log(1.0 + static_cast<double>(x)));
}
#endif

static const double kMinLogDiffDouble = Log(DBL_EPSILON);  // negative!
static const float kMinLogDiffFloat = Log(FLT_EPSILON);  // negative!

123 124 125
// -infinity
const float kLogZeroFloat = -std::numeric_limits<float>::infinity();
const double kLogZeroDouble = -std::numeric_limits<double>::infinity();
126
const BaseFloat kLogZeroBaseFloat = -std::numeric_limits<BaseFloat>::infinity();
127

128 129 130 131 132
// Returns a random integer between 0 and RAND_MAX, inclusive
int Rand(struct RandomState* state=NULL);

// State for thread-safe random number generator
struct RandomState {
Dan Povey's avatar
Dan Povey committed
133
  RandomState();
134 135 136
  unsigned seed;
};

137
// Returns a random integer between min and max inclusive.
138
int32 RandInt(int32 min, int32 max, struct RandomState* state=NULL);
139

140
bool WithProb(BaseFloat prob, struct RandomState* state=NULL); // Returns true with probability "prob",
141
// with 0 <= prob <= 1 [we check this].
142
// Internally calls Rand().  This function is carefully implemented so
143 144
// that it should work even if prob is very small.

145 146 147
/// Returns a random number strictly between 0 and 1.
inline float RandUniform(struct RandomState* state = NULL) {
  return static_cast<float>((Rand(state) + 1.0) / (RAND_MAX+2.0));
148 149
}

150
inline float RandGauss(struct RandomState* state = NULL) {
151
  return static_cast<float>(sqrtf (-2 * Log(RandUniform(state)))
152
                            * cosf(2*M_PI*RandUniform(state)));
153 154 155 156 157
}

// Returns poisson-distributed random number.  Uses Knuth's algorithm.
// Take care: this takes time proportinal
// to lambda.  Faster algorithms exist but are more complex.
158 159 160 161 162
int32 RandPoisson(float lambda, struct RandomState* state=NULL);

// Returns a pair of gaussian random numbers. Uses Box-Muller transform
void RandGauss2(float *a, float *b, RandomState *state = NULL);
void RandGauss2(double *a, double *b, RandomState *state = NULL);
163

164 165
// Also see Vector<float,double>::RandCategorical().

166 167 168
// This is a randomized pruning mechanism that preserves expectations,
// that we typically use to prune posteriors.
template<class Float>
169
inline Float RandPrune(Float post, BaseFloat prune_thresh, struct RandomState* state=NULL) {
170 171 172 173
  KALDI_ASSERT(prune_thresh >= 0.0);
  if (post == 0.0 || std::abs(post) >= prune_thresh)
    return post;
  return (post >= 0 ? 1.0 : -1.0) *
174
      (RandUniform(state) <= fabs(post)/prune_thresh ? prune_thresh : 0.0);
175 176
}

177 178 179 180 181 182 183 184 185 186 187 188 189

inline double LogAdd(double x, double y) {
  double diff;
  if (x < y) {
    diff = x - y;
    x = y;
  } else {
    diff = y - x;
  }
  // diff is negative.  x is now the larger one.

  if (diff >= kMinLogDiffDouble) {
    double res;
190
    res = x + Log1p(Exp(diff));
191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209
    return res;
  } else {
    return x;  // return the larger one.
  }
}


inline float LogAdd(float x, float y) {
  float diff;
  if (x < y) {
    diff = x - y;
    x = y;
  } else {
    diff = y - x;
  }
  // diff is negative.  x is now the larger one.

  if (diff >= kMinLogDiffFloat) {
    float res;
210
    res = x + Log1p(Exp(diff));
211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227
    return res;
  } else {
    return x;  // return the larger one.
  }
}


// returns exp(x) - exp(y).
inline double LogSub(double x, double y) {
  if (y >= x) {  // Throws exception if y>=x.
    if (y == x)
      return kLogZeroDouble;
    else
      KALDI_ERR << "Cannot subtract a larger from a smaller number.";
  }

  double diff = y - x;  // Will be negative.
228
  double res = x + Log(1.0 - Exp(diff));
229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246

  // res might be NAN if diff ~0.0, and 1.0-exp(diff) == 0 to machine precision
  if (KALDI_ISNAN(res))
    return kLogZeroDouble;
  return res;
}


// returns exp(x) - exp(y).
inline float LogSub(float x, float y) {
  if (y >= x) {  // Throws exception if y>=x.
    if (y == x)
      return kLogZeroDouble;
    else
      KALDI_ERR << "Cannot subtract a larger from a smaller number.";
  }

  float diff = y - x;  // Will be negative.
247
  float res = x + Log(1.0f - Exp(diff));
248 249 250 251 252 253 254

  // res might be NAN if diff ~0.0, and 1.0-exp(diff) == 0 to machine precision
  if (KALDI_ISNAN(res))
    return kLogZeroFloat;
  return res;
}

/// Returns true if abs(a - b) <= relative_tolerance * (abs(a) + abs(b)).
/// Identical values (including equal infinities) always compare equal; if
/// abs(a - b) is infinite or NaN (e.g. a NaN input, or finite vs. infinite)
/// the result is false.
static inline bool ApproxEqual(float a, float b,
                               float relative_tolerance = 0.001) {
  // a==b handles infinities, for which a - b would be NaN.
  if (a == b) return true;
  const float diff = std::abs(a - b);
  // diff is never -inf, so "not finite" means +inf or NaN.
  if (!std::isfinite(diff)) return false;
  return diff <= relative_tolerance * (std::abs(a) + std::abs(b));
}

266
/// assert abs(a - b) <= relative_tolerance * (abs(a)+abs(b))
267 268
static inline void AssertEqual(float a, float b,
                               float relative_tolerance = 0.001) {
269
  // a==b handles infinities.
270
  KALDI_ASSERT(ApproxEqual(a, b, relative_tolerance));
271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286
}


// RoundUpToNearestPowerOfTwo does the obvious thing. It crashes if n <= 0.
// NOTE(review): "the obvious thing" presumably means returning the smallest
// power of two that is >= n; the definition is not in this header, so confirm
// against the implementation.
int32 RoundUpToNearestPowerOfTwo(int32 n);

/// Returns the greatest common divisor of m and n as a non-negative value.
/// If exactly one argument is zero, returns the absolute value of the other.
/// If both are zero the GCD is undefined and KALDI_ERR is raised.
/// I must be an integer type (checked at compile time).
template<class I> I  Gcd(I m, I n) {
  static_assert(std::numeric_limits<I>::is_integer,
                "Gcd requires an integer type");
  if (m == 0 || n == 0) {
    if (m == 0 && n == 0) {  // gcd not defined, as all integers are divisors.
      KALDI_ERR << "Undefined GCD since m = 0, n = 0.";
    }
    // return absolute value of whichever is nonzero
    return (m == 0 ? (n > 0 ? n : -n) : ( m > 0 ? m : -m));
  }
  // Euclid's algorithm, alternately reducing m modulo n and n modulo m.
  while (1) {
    m %= n;
    if (m == 0) return (n > 0 ? n : -n);
    n %= m;
    if (n == 0) return (m > 0 ? m : -m);
  }
}

/// Returns the least common multiple of two integers.  Both inputs must be
/// positive; this is checked with KALDI_ASSERT.
template<class I> I  Lcm(I m, I n) {
  KALDI_ASSERT(m > 0 && n > 0);
  I gcd = Gcd(m, n);
  // gcd * (m / gcd) is exactly m, so this is m * (n / gcd) == m * n / gcd
  // without ever forming the larger product m * n.
  return gcd * (m/gcd) * (n/gcd);
}

304

305 306 307 308 309 310
template<class I> void Factorize(I m, std::vector<I> *factors) {
  // Splits a number into its prime factors, in sorted order from
  // least to greatest,  with duplication.  A very inefficient
  // algorithm, which is mainly intended for use in the
  // mixed-radix FFT computation (where we assume most factors
  // are small).
311 312
  KALDI_ASSERT(factors != NULL);
  KALDI_ASSERT(m >= 1);  // Doesn't work for zero or negative numbers.
313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333
  factors->clear();
  I small_factors[10] = { 2, 3, 5, 7, 11, 13, 17, 19, 23, 29 };

  // First try small factors.
  for (I i = 0; i < 10; i++) {
    if (m == 1) return;  // We're done.
    while (m % small_factors[i] == 0) {
      m /= small_factors[i];
      factors->push_back(small_factors[i]);
    }
  }
  // Next try all odd numbers starting from 31.
  for (I j = 31;; j += 2) {
    if (m == 1) return;
    while (m % j == 0) {
      m /= j;
      factors->push_back(j);
    }
  }
}

// Hypot(x, y) returns sqrt(x*x + y*y) via the C library's hypot()/hypotf(),
// which are specified to compute this without undue overflow or underflow.
inline double Hypot(double x, double y) {  return hypot(x, y); }
inline float Hypot(float x, float y) {  return hypotf(x, y); }

}  // namespace kaldi


#endif  // KALDI_BASE_KALDI_MATH_H_