// nnet2/nnet-component.h

// Copyright 2011-2013  Karel Vesely
//           2012-2014  Johns Hopkins University (author: Daniel Povey)
//                2013  Xiaohui Zhang
//                2014  Vijayaditya Peddinti
//           2014-2015  Guoguo Chen

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_NNET2_NNET_COMPONENT_H_
#define KALDI_NNET2_NNET_COMPONENT_H_

#include "base/kaldi-common.h"
#include "itf/options-itf.h"
#include "matrix/matrix-lib.h"
#include "cudamatrix/cu-matrix-lib.h"
#include "thread/kaldi-mutex.h"
#include "nnet2/nnet-precondition-online.h"

#include <iostream>

namespace kaldi {
namespace nnet2 {


40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
/**
   ChunkInfo is a class whose purpose is to describe the structure of matrices 
   holding features.  This is useful mostly in training time. 
   The main reason why we have this is to support efficient
   training for networks which we have splicing components that splice in a
   non-contiguous way, e.g. frames -5, 0 and 5.  We also have in mind future
   extensibility to convnets which might have similar issues.  This class
   describes the structure of a minibatch of features, or of a single
   contiguous block of features.
   Examples are as follows, and offsets is empty if not mentioned:
     When decoding, at input to the network:
       feat_dim = 13, num_chunks = 1, first_offset = 0, last_offset = 691
      and in the middle of the network (assuming splicing is +-7):
       feat_dim = 1024, num_chunks = 1, first_offset = 7, last_offset = 684
    When training, at input to the network:
      feat_dim = 13, num_chunks = 512, first_offset = 0, last_offset= 14
     and in the middle of the network:
      feat_dim = 1024, num_chunks = 512, first_offset = 7, last_offset = 7
   The only situation where offsets would be nonempty would be if we do
   splicing with gaps in.  E.g. suppose at network input we splice +-2 frames
   (contiguous) and somewhere in the middle we splice frames {-5, 0, 5}, then
   we would have the following while training
     At input to the network:
      feat_dim = 13, num_chunks = 512, first_offset = 0, last_offset = 14
     After the first hidden layer:
      feat_dim = 1024, num_chunks = 512, first_offset = 2, last_offset = 12,
       offsets = {2, 10, 12}
     At the output of the last hidden layer (after the {-5, 0, 5} splice):
      feat_dim = 1024, num_chunks = 512, first_offset = 7, last_offset = 7
   (the decoding setup would still look pretty normal, so we don't give an example).
    
*/
class ChunkInfo {
 public:
  ChunkInfo()  // default constructor we assume this object will not be used
      : feat_dim_(0), num_chunks_(0),
        first_offset_(0), last_offset_(0), 
77
        offsets_() { }
78 79 80 81 82
 
  ChunkInfo(int32 feat_dim, int32 num_chunks,
            int32 first_offset, int32 last_offset ) 
      : feat_dim_(feat_dim), num_chunks_(num_chunks),
        first_offset_(first_offset), last_offset_(last_offset),
83
        offsets_() { Check(); }
84 85 86 87 88 89 90
  
  ChunkInfo(int32 feat_dim, int32 num_chunks,
            const std::vector<int32> offsets)
      : feat_dim_(feat_dim), num_chunks_(num_chunks),
        first_offset_(offsets.front()), last_offset_(offsets.back()),
        offsets_(offsets) { if (last_offset_ - first_offset_ + 1 == offsets_.size())
                              offsets_.clear();
91
          Check(); }
92 93 94 95 96 97 98 99 100 101 102 103

  // index : actual row index in the current chunk
  // offset : the time offset of feature frame at current row in the chunk
  // As described above offsets can take a variety of values, we see the indices
  // corresponding to the offsets in each case
  // 1) if first_offset = 0 & last_offset = 691, then chunk has data
  // corresponding to time offsets 0:691, so index = offset 
  // 2) if first_offset = 7 & last_offset = 684, 
  //      then index = offset - first offset
  // 3) if offsets = {2, 10, 12} then indices for these offsets are 0, 1 and 2
 
  // Returns the chunk row index corresponding to given time offset
104
  int32 GetIndex (int32 offset) const;
105 106
  
  // Returns time offset at the current row index in the chunk
107
  int32 GetOffset (int32 index) const;
108 109 110 111 112

  // Makes the offsets vector empty, to ensure that the chunk is processed as a
  // contiguous chunk with the given first_offset and last_offset
  void MakeOffsetsContiguous () { offsets_.clear(); Check(); }

113 114
  // Returns chunk size, meaning the number of distinct frame-offsets we
  // have for each chunk (they don't have to be contiguous).
115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
  inline int32 ChunkSize() const { return NumRows() / num_chunks_; }

  // Returns number of chunks we expect the feature matrix to have
  inline int32 NumChunks() const { return num_chunks_; }

  /// Returns the number of rows that we expect the feature matrix to have.
  int32 NumRows() const { 
    return num_chunks_ * (!offsets_.empty() ? offsets_.size() :
                                         last_offset_ - first_offset_ + 1); }

  /// Returns the number of columns that we expect the feature matrix to have.
  int32 NumCols() const { return feat_dim_; }
    
  /// Checks that the matrix has the size we expect, and die if not.
  void CheckSize(const CuMatrixBase<BaseFloat> &mat) const;

  /// Checks that the data in the ChunkInfo is valid, and die if not.
  void Check() const;  

 private:
  int32 feat_dim_;  // Feature dimension.
  int32 num_chunks_;  // Number of separate equal-sized chunks of features
  int32 first_offset_;  // Start time offset within each chunk, numbered so that at
                      // the input to the network, the first_offset of the first
                      // feature would always be zero.
  int32 last_offset_;  // End time offset within each chunk.
  std::vector<int32> offsets_; // offsets is only nonempty if the chunk contains
                             // a non-contiguous sequence.  If nonempty, it must
                             // be sorted, and offsets.front() == first_offset,
                             // offsets.back() == last_offset.
  
};

148 149 150 151 152 153 154 155 156 157 158
/**
 * Abstract class, basic element of the network,
 * it is a box with defined inputs, outputs,
 * and tranformation functions interface.
 *
 * It is able to propagate and backpropagate
 * exact implementation is to be implemented in descendants.
 *
 */ 
class Component {
 public:
159
  Component(): index_(-1) { }
160 161 162 163
  
  virtual std::string Type() const = 0; // each type should return a string such as
  // "SigmoidComponent".

164 165 166 167 168 169
  /// Returns the index in the sequence of layers in the neural net; intended only
  /// to be used in debugging information.
  virtual int32 Index() const { return index_; }

  virtual void SetIndex(int32 index) { index_ = index; }

170 171 172 173 174 175 176 177 178 179 180
  /// Initialize, typically from a line of a config file.  The "args" will
  /// contain any parameters that need to be passed to the Component, e.g.
  /// dimensions.
  virtual void InitFromString(std::string args) = 0; 
  
  /// Get size of input vectors
  virtual int32 InputDim() const = 0;
  
  /// Get size of output vectors 
  virtual int32 OutputDim() const = 0;

181
  /// Return a vector describing the temporal context this component requires
182 183 184 185 186 187
  /// for each frame of output, as a sorted list.  The default implementation
  /// returns a vector ( 0 ), but a splicing layer might return e.g. (-2, -1, 0,
  /// 1, 2), but it doesn't have to be contiguous.  Note : The context needed by
  /// the entire network is a function of the contexts needed by all the
  /// components.  It is required that Context().front() <= 0 and
  /// Context().back() >= 0.
188
  virtual std::vector<int32> Context() const { return std::vector<int32>(1, 0); }
189 190 191 192 193 194 195 196

  /// Perform forward pass propagation Input->Output.  Each row is
  /// one frame or training example.  Interpreted as "num_chunks"
  /// equally sized chunks of frames; this only matters for layers
  /// that do things like context splicing.  Typically this variable
  /// will either be 1 (when we're processing a single contiguous
  /// chunk of data) or will be the same as in.NumFrames(), but
  /// other values are possible if some layers do splicing.
197 198 199
  virtual void Propagate(const ChunkInfo &in_info,
                         const ChunkInfo &out_info,
                         const CuMatrixBase<BaseFloat> &in,
200 201 202 203 204 205
                         CuMatrixBase<BaseFloat> *out) const = 0;

  /// A non-virtual propagate function that first resizes output if necessary.
  void Propagate(const ChunkInfo &in_info,
                 const ChunkInfo &out_info,
                 const CuMatrixBase<BaseFloat> &in,
206
                 CuMatrix<BaseFloat> *out) const {
207 208 209 210 211 212 213 214 215
    if (out->NumRows() != out_info.NumRows() ||
        out->NumCols() != out_info.NumCols()) {
      out->Resize(out_info.NumRows(), out_info.NumCols());
    }

    // Cast to CuMatrixBase to use the virtual version of propagate function.
    Propagate(in_info, out_info, in,
              static_cast<CuMatrixBase<BaseFloat>*>(out));
  } 
216 217 218 219 220 221 222 223 224
  
  /// Perform backward pass propagation of the derivative, and
  /// also either update the model (if to_update == this) or
  /// update another model or compute the model derivative (otherwise).
  /// Note: in_value and out_value are the values of the input and output
  /// of the component, and these may be dummy variables if respectively
  /// BackpropNeedsInput() or BackpropNeedsOutput() return false for
  /// that component (not all components need these).
  ///
225
  /// num_chunks lets us treat the input matrix as contiguous-in-time
226
  /// chunks of equal size; it only matters if splicing is involved.
227 228 229
  virtual void Backprop(const ChunkInfo &in_info,
                        const ChunkInfo &out_info,
                        const CuMatrixBase<BaseFloat> &in_value,
230 231
                        const CuMatrixBase<BaseFloat> &out_value,                        
                        const CuMatrixBase<BaseFloat> &out_deriv,
232
                        Component *to_update, // may be identical to "this".
233
                        CuMatrix<BaseFloat> *in_deriv) const = 0;
234

235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252
  virtual bool BackpropNeedsInput() const { return true; } // if this returns false,
  // the "in_value" to Backprop may be a dummy variable.
  virtual bool BackpropNeedsOutput() const { return true; } // if this returns false,
  // the "out_value" to Backprop may be a dummy variable.
  
  /// Read component from stream
  static Component* ReadNew(std::istream &is, bool binary);

  /// Copy component (deep copy).
  virtual Component* Copy() const = 0;

  /// Initialize the Component from one line that will contain
  /// first the type, e.g. SigmoidComponent, and then
  /// a number of tokens (typically integers or floats) that will
  /// be used to initialize the component.
  static Component *NewFromString(const std::string &initializer_line);

  /// Return a new Component of the given type e.g. "SoftmaxComponent",
253
  /// or NULL if no such type exists. 
254 255 256 257 258 259 260 261 262 263 264 265 266
  static Component *NewComponentOfType(const std::string &type);
  
  virtual void Read(std::istream &is, bool binary) = 0; // This Read function
  // requires that the Component has the correct type.
  
  /// Write component to stream
  virtual void Write(std::ostream &os, bool binary) const = 0;

  virtual std::string Info() const;

  virtual ~Component() { }

 private:
267
  int32 index_;
268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295
  KALDI_DISALLOW_COPY_AND_ASSIGN(Component);
};


/**
 * Class UpdatableComponent is a Component which has
 * trainable parameters and contains some global 
 * parameters for stochastic gradient descent
 * (learning rate, L2 regularization constant).
 * This is a base-class for Components with parameters.
 */
class UpdatableComponent: public Component {
 public:
  UpdatableComponent(const UpdatableComponent &other):
      learning_rate_(other.learning_rate_){ }
  
  void Init(BaseFloat learning_rate) {
    learning_rate_ = learning_rate;
  }
  UpdatableComponent(BaseFloat learning_rate) {
    Init(learning_rate);
  }

  /// Set parameters to zero, and if treat_as_gradient is true, we'll be
  /// treating this as a gradient so set the learning rate to 1 and make any
  /// other changes necessary (there's a variable we have to set for the
  /// MixtureProbComponent).
  virtual void SetZero(bool treat_as_gradient) = 0;
296
  
297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324
  UpdatableComponent(): learning_rate_(0.001) { }
  
  virtual ~UpdatableComponent() { }

  /// Here, "other" is a component of the same specific type.  This
  /// function computes the dot product in parameters, and is computed while
  /// automatically adjusting learning rates; typically, one of the two will
  /// actually contain the gradient.
  virtual BaseFloat DotProduct(const UpdatableComponent &other) const = 0;
  
  /// We introduce a new virtual function that only applies to
  /// class UpdatableComponent.  This is used in testing.
  virtual void PerturbParams(BaseFloat stddev) = 0;
  
  /// This new virtual function scales the parameters
  /// by this amount.  
  virtual void Scale(BaseFloat scale) = 0;

  /// This new virtual function adds the parameters of another
  /// updatable component, times some constant, to the current
  /// parameters.
  virtual void Add(BaseFloat alpha, const UpdatableComponent &other) = 0;
  
  /// Sets the learning rate of gradient descent
  void SetLearningRate(BaseFloat lrate) {  learning_rate_ = lrate; }
  /// Gets the learning rate of gradient descent
  BaseFloat LearningRate() const { return learning_rate_; }

325 326
  virtual std::string Info() const;
  
327 328 329 330 331 332 333
  // The next few functions are not implemented everywhere; they are
  // intended for use by L-BFGS code, and we won't implement them
  // for all child classes.
  
  /// The following new virtual function returns the total dimension of
  /// the parameters in this class.  E.g. used for L-BFGS update
  virtual int32 GetParameterDim() const { KALDI_ASSERT(0); return 0; }
334 335 336 337

  /// Turns the parameters into vector form.  We put the vector form on the CPU,
  /// because in the kinds of situations where we do this, we'll tend to use
  /// too much memory for the GPU.
338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369
  virtual void Vectorize(VectorBase<BaseFloat> *params) const { KALDI_ASSERT(0); }
  /// Converts the parameters from vector form.
  virtual void UnVectorize(const VectorBase<BaseFloat> &params) {
    KALDI_ASSERT(0);
  }
  
 protected: 
  BaseFloat learning_rate_; ///< learning rate (0.0..0.01)
 private:
  const UpdatableComponent &operator = (const UpdatableComponent &other); // Disallow.
};

/// This kind of Component is a base-class for things like
/// sigmoid and softmax.
class NonlinearComponent: public Component {
 public:
  void Init(int32 dim) { dim_ = dim; count_ = 0.0; }
  explicit NonlinearComponent(int32 dim) { Init(dim); }
  NonlinearComponent(): dim_(0) { } // e.g. prior to Read().
  explicit NonlinearComponent(const NonlinearComponent &other);
  
  virtual int32 InputDim() const { return dim_; }
  virtual int32 OutputDim() const { return dim_; }
  
  /// We implement InitFromString at this level.
  virtual void InitFromString(std::string args);
  
  /// We implement Read at this level as it just needs the Type().
  virtual void Read(std::istream &is, bool binary);
  
  /// Write component to stream.
  virtual void Write(std::ostream &os, bool binary) const;
370 371 372 373
  
  void Scale(BaseFloat scale); // relates to scaling stats, not parameters.
  void Add(BaseFloat alpha, const NonlinearComponent &other); // relates to
                                                              // adding stats
374

375 376
  // The following functions are unique to NonlinearComponent.
  // They mostly relate to diagnostics.
377 378
  const CuVector<double> &ValueSum() const { return value_sum_; }
  const CuVector<double> &DerivSum() const { return deriv_sum_; }
379
  double Count() const { return count_; }
380 381 382 383

  // The following function is used when "widening" neural networks.
  void SetDim(int32 dim);
  
384
 protected:
385
  friend class NormalizationComponent;
386 387 388
  friend class SigmoidComponent;
  friend class TanhComponent;
  friend class SoftmaxComponent;
389
  friend class LogSoftmaxComponent;
390 391
  friend class RectifiedLinearComponent;
  friend class SoftHingeComponent;
392
  
393

394 395 396
  // This function updates the stats "value_sum_", "deriv_sum_", and
  // count_. (If deriv == NULL, it won't update "deriv_sum_").
  // It will be called from the Backprop function of child classes.
397 398
  void UpdateStats(const CuMatrixBase<BaseFloat> &out_value,
                   const CuMatrixBase<BaseFloat> *deriv = NULL);
399

400 401 402
  
  const NonlinearComponent &operator = (const NonlinearComponent &other); // Disallow.
  int32 dim_;
403 404
  CuVector<double> value_sum_; // stats at the output.
  CuVector<double> deriv_sum_; // stats of the derivative of the nonlinearity (only
405 406
  // applicable to element-by-element nonlinearities, not Softmax.
  double count_;
407 408
  // The mutex is used in UpdateStats, only for resizing vectors.
  Mutex mutex_;
409 410
};

411 412 413 414 415 416 417 418 419 420 421
class MaxoutComponent: public Component {
 public:
  void Init(int32 input_dim, int32 output_dim);
  explicit MaxoutComponent(int32 input_dim, int32 output_dim) {
    Init(input_dim, output_dim);
  }
  MaxoutComponent(): input_dim_(0), output_dim_(0) { }
  virtual std::string Type() const { return "MaxoutComponent"; }
  virtual void InitFromString(std::string args); 
  virtual int32 InputDim() const { return input_dim_; }
  virtual int32 OutputDim() const { return output_dim_; }
422
  using Component::Propagate; // to avoid name hiding
423 424 425
  virtual void Propagate(const ChunkInfo &in_info,
                         const ChunkInfo &out_info,
                         const CuMatrixBase<BaseFloat> &in,
426
                         CuMatrixBase<BaseFloat> *out) const; 
427 428 429 430
  virtual void Backprop(const ChunkInfo &in_info,
                        const ChunkInfo &out_info,
                        const CuMatrixBase<BaseFloat> &in_value,
                        const CuMatrixBase<BaseFloat> &,  //out_value,                        
431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450
                        const CuMatrixBase<BaseFloat> &out_deriv,
                        Component *to_update, // may be identical to "this".
                        CuMatrix<BaseFloat> *in_deriv) const;
  virtual bool BackpropNeedsInput() const { return true; }
  virtual bool BackpropNeedsOutput() const { return true; }
  virtual Component* Copy() const { return new MaxoutComponent(input_dim_,
                                                              output_dim_); }
  
  virtual void Read(std::istream &is, bool binary); // This Read function
  // requires that the Component has the correct type.
  
  /// Write component to stream
  virtual void Write(std::ostream &os, bool binary) const;

  virtual std::string Info() const;
 protected:
  int32 input_dim_;
  int32 output_dim_;
};

451 452 453 454 455 456 457 458 459 460 461
class PnormComponent: public Component {
 public:
  void Init(int32 input_dim, int32 output_dim, BaseFloat p);
  explicit PnormComponent(int32 input_dim, int32 output_dim, BaseFloat p) {
    Init(input_dim, output_dim, p);
  }
  PnormComponent(): input_dim_(0), output_dim_(0), p_(0) { }
  virtual std::string Type() const { return "PnormComponent"; }
  virtual void InitFromString(std::string args); 
  virtual int32 InputDim() const { return input_dim_; }
  virtual int32 OutputDim() const { return output_dim_; }
462
  using Component::Propagate; // to avoid name hiding
463 464 465
  virtual void Propagate(const ChunkInfo &in_info,
                         const ChunkInfo &out_info,
                         const CuMatrixBase<BaseFloat> &in,
466
                         CuMatrixBase<BaseFloat> *out) const; 
467 468 469 470
  virtual void Backprop(const ChunkInfo &in_info,
                        const ChunkInfo &out_info,
                        const CuMatrixBase<BaseFloat> &in_value,
                        const CuMatrixBase<BaseFloat> &,  //out_value,                        
471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498
                        const CuMatrixBase<BaseFloat> &out_deriv,
                        Component *to_update, // may be identical to "this".
                        CuMatrix<BaseFloat> *in_deriv) const;
  virtual bool BackpropNeedsInput() const { return true; }
  virtual bool BackpropNeedsOutput() const { return true; }
  virtual Component* Copy() const { return new PnormComponent(input_dim_,
                                                              output_dim_, p_); }
  
  virtual void Read(std::istream &is, bool binary); // This Read function
  // requires that the Component has the correct type.
  
  /// Write component to stream
  virtual void Write(std::ostream &os, bool binary) const;

  virtual std::string Info() const;
 protected:
  int32 input_dim_;
  int32 output_dim_;
  BaseFloat p_;
};

class NormalizeComponent: public NonlinearComponent {
 public:
  explicit NormalizeComponent(int32 dim): NonlinearComponent(dim) { }
  explicit NormalizeComponent(const NormalizeComponent &other): NonlinearComponent(other) { }
  NormalizeComponent() { }
  virtual std::string Type() const { return "NormalizeComponent"; }
  virtual Component* Copy() const { return new NormalizeComponent(*this); }
499 500
  virtual bool BackpropNeedsInput() const { return true; }
  virtual bool BackpropNeedsOutput() const { return true; }
501
  using Component::Propagate; // to avoid name hiding
502 503 504
  virtual void Propagate(const ChunkInfo &in_info,
                         const ChunkInfo &out_info,
                         const CuMatrixBase<BaseFloat> &in,
505
                         CuMatrixBase<BaseFloat> *out) const; 
506 507 508 509
  virtual void Backprop(const ChunkInfo &in_info,
                        const ChunkInfo &out_info,
                        const CuMatrixBase<BaseFloat> &in_value,
                        const CuMatrixBase<BaseFloat> &out_value,                        
510 511 512 513 514
                        const CuMatrixBase<BaseFloat> &out_deriv,
                        Component *to_update, // may be identical to "this".
                        CuMatrix<BaseFloat> *in_deriv) const;
 private:
  NormalizeComponent &operator = (const NormalizeComponent &other); // Disallow.
515 516 517 518
  static const BaseFloat kNormFloor;
  // about 0.7e-20.  We need a value that's exactly representable in
  // float and whose inverse square root is also exactly representable
  // in float (hence, an even power of two).
519 520
};

521

522 523 524 525 526 527
class SigmoidComponent: public NonlinearComponent {
 public:
  explicit SigmoidComponent(int32 dim): NonlinearComponent(dim) { }
  explicit SigmoidComponent(const SigmoidComponent &other): NonlinearComponent(other) { }    
  SigmoidComponent() { }
  virtual std::string Type() const { return "SigmoidComponent"; }
528 529
  virtual bool BackpropNeedsInput() const { return false; }
  virtual bool BackpropNeedsOutput() const { return true; }
530
  virtual Component* Copy() const { return new SigmoidComponent(*this); }
531
  using Component::Propagate; // to avoid name hiding
532 533 534
  virtual void Propagate(const ChunkInfo &in_info,
                         const ChunkInfo &out_info,
                         const CuMatrixBase<BaseFloat> &in,
535
                         CuMatrixBase<BaseFloat> *out) const; 
536 537 538 539
  virtual void Backprop(const ChunkInfo &in_info,
                        const ChunkInfo &out_info,
                        const CuMatrixBase<BaseFloat> &in_value,
                        const CuMatrixBase<BaseFloat> &out_value,                        
540
                        const CuMatrixBase<BaseFloat> &out_deriv,
541
                        Component *to_update, // may be identical to "this".
542
                        CuMatrix<BaseFloat> *in_deriv) const;
543 544 545 546 547 548 549 550 551 552 553
 private:
  SigmoidComponent &operator = (const SigmoidComponent &other); // Disallow.
};

class TanhComponent: public NonlinearComponent {
 public:
  explicit TanhComponent(int32 dim): NonlinearComponent(dim) { }
  explicit TanhComponent(const TanhComponent &other): NonlinearComponent(other) { }
  TanhComponent() { }
  virtual std::string Type() const { return "TanhComponent"; }
  virtual Component* Copy() const { return new TanhComponent(*this); }
554 555
  virtual bool BackpropNeedsInput() const { return false; }
  virtual bool BackpropNeedsOutput() const { return true; }
556
  using Component::Propagate; // to avoid name hiding
557 558 559
  virtual void Propagate(const ChunkInfo &in_info,
                         const ChunkInfo &out_info,
                         const CuMatrixBase<BaseFloat> &in,
560
                         CuMatrixBase<BaseFloat> *out) const; 
561 562 563 564
  virtual void Backprop(const ChunkInfo &in_info,
                        const ChunkInfo &out_info,
                        const CuMatrixBase<BaseFloat> &in_value,
                        const CuMatrixBase<BaseFloat> &out_value,                        
565
                        const CuMatrixBase<BaseFloat> &out_deriv,
566
                        Component *to_update, // may be identical to "this".
567
                        CuMatrix<BaseFloat> *in_deriv) const;
568 569 570 571
 private:
  TanhComponent &operator = (const TanhComponent &other); // Disallow.
};

572 573 574 575 576 577 578 579 580 581 582 583 584
/// Take the absoute values of an input vector to a power.
/// The derivative for zero input will be treated as zero.
class PowerComponent: public NonlinearComponent {
 public:
  void Init(int32 dim, BaseFloat power = 2);
  explicit PowerComponent(int32 dim, BaseFloat power = 2) {
    Init(dim, power);
  }
  PowerComponent(): dim_(0), power_(2) { }
  virtual std::string Type() const { return "PowerComponent"; }
  virtual void InitFromString(std::string args); 
  virtual int32 InputDim() const { return dim_; }
  virtual int32 OutputDim() const { return dim_; }
585
  using Component::Propagate; // to avoid name hiding
586 587 588
  virtual void Propagate(const ChunkInfo &in_info,
                         const ChunkInfo &out_info,
                         const CuMatrixBase<BaseFloat> &in,
589
                         CuMatrixBase<BaseFloat> *out) const; 
590 591 592 593
  virtual void Backprop(const ChunkInfo &in_info,
                        const ChunkInfo &out_info,
                        const CuMatrixBase<BaseFloat> &in_value,
                        const CuMatrixBase<BaseFloat> &out_value,                        
594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612
                        const CuMatrixBase<BaseFloat> &out_deriv,
                        Component *to_update, // may be identical to "this".
                        CuMatrix<BaseFloat> *in_deriv) const;
  virtual bool BackpropNeedsInput() const { return true; }
  virtual bool BackpropNeedsOutput() const { return true; }
  virtual Component* Copy() const { return new PowerComponent(dim_, power_); }
  virtual void Read(std::istream &is, bool binary); // This Read function
  // requires that the Component has the correct type.
  
  /// Write component to stream
  virtual void Write(std::ostream &os, bool binary) const;

  virtual std::string Info() const;

 private:
  int32 dim_;
  BaseFloat power_;
};

613 614 615 616 617 618 619
class RectifiedLinearComponent: public NonlinearComponent {
 public:
  explicit RectifiedLinearComponent(int32 dim): NonlinearComponent(dim) { }
  explicit RectifiedLinearComponent(const RectifiedLinearComponent &other): NonlinearComponent(other) { }
  RectifiedLinearComponent() { }
  virtual std::string Type() const { return "RectifiedLinearComponent"; }
  virtual Component* Copy() const { return new RectifiedLinearComponent(*this); }
620 621
  virtual bool BackpropNeedsInput() const { return false; }
  virtual bool BackpropNeedsOutput() const { return true; }
622
  using Component::Propagate; // to avoid name hiding
623 624 625
  virtual void Propagate(const ChunkInfo &in_info,
                         const ChunkInfo &out_info,
                         const CuMatrixBase<BaseFloat> &in,
626
                         CuMatrixBase<BaseFloat> *out) const; 
627 628 629 630
  virtual void Backprop(const ChunkInfo &in_info,
                        const ChunkInfo &out_info,
                        const CuMatrixBase<BaseFloat> &in_value,
                        const CuMatrixBase<BaseFloat> &out_value,                        
631
                        const CuMatrixBase<BaseFloat> &out_deriv,
632
                        Component *to_update, // may be identical to "this".
633
                        CuMatrix<BaseFloat> *in_deriv) const;
634 635 636 637 638 639 640 641 642 643 644
 private:
  RectifiedLinearComponent &operator = (const RectifiedLinearComponent &other); // Disallow.
};

class SoftHingeComponent: public NonlinearComponent {
 public:
  explicit SoftHingeComponent(int32 dim): NonlinearComponent(dim) { }
  explicit SoftHingeComponent(const SoftHingeComponent &other): NonlinearComponent(other) { }
  SoftHingeComponent() { }
  virtual std::string Type() const { return "SoftHingeComponent"; }
  virtual Component* Copy() const { return new SoftHingeComponent(*this); }
645 646
  virtual bool BackpropNeedsInput() const { return true; }
  virtual bool BackpropNeedsOutput() const { return true; }
647
  using Component::Propagate; // to avoid name hiding
648 649 650
  virtual void Propagate(const ChunkInfo &in_info,
                         const ChunkInfo &out_info,
                         const CuMatrixBase<BaseFloat> &in,
651
                         CuMatrixBase<BaseFloat> *out) const; 
652 653 654 655
  virtual void Backprop(const ChunkInfo &in_info,
                        const ChunkInfo &out_info,
                        const CuMatrixBase<BaseFloat> &in_value,
                        const CuMatrixBase<BaseFloat> &out_value,                        
656
                        const CuMatrixBase<BaseFloat> &out_deriv,
657
                        Component *to_update, // may be identical to "this".
658
                        CuMatrix<BaseFloat> *in_deriv) const;
659 660 661 662
 private:
  SoftHingeComponent &operator = (const SoftHingeComponent &other); // Disallow.
};

663 664 665 666 667 668 669 670 671 672 673 674

// This class scales the input by a specified constant.  This is, of course,
// useless, but we use it when we want to change how fast the next layer learns.
// (e.g. a smaller scale will make the next layer learn slower.)
class ScaleComponent: public Component {
 public:
  explicit ScaleComponent(int32 dim, BaseFloat scale): dim_(dim), scale_(scale) { }
  explicit ScaleComponent(const ScaleComponent &other):
      dim_(other.dim_), scale_(other.scale_) { }
  ScaleComponent(): dim_(0), scale_(0.0) { }
  virtual std::string Type() const { return "ScaleComponent"; }
  virtual Component* Copy() const { return new ScaleComponent(*this); }
675 676
  virtual bool BackpropNeedsInput() const { return false; }
  virtual bool BackpropNeedsOutput() const { return false; }
677
  using Component::Propagate; // to avoid name hiding
678 679 680
  virtual void Propagate(const ChunkInfo &in_info,
                         const ChunkInfo &out_info,
                         const CuMatrixBase<BaseFloat> &in,
681
                         CuMatrixBase<BaseFloat> *out) const; 
682 683 684 685
  virtual void Backprop(const ChunkInfo &in_info,
                        const ChunkInfo &out_info,
                        const CuMatrixBase<BaseFloat> &in_value,
                        const CuMatrixBase<BaseFloat> &out_value,                        
686
                        const CuMatrixBase<BaseFloat> &out_deriv,
687
                        Component *to_update, // may be identical to "this".
688
                        CuMatrix<BaseFloat> *in_deriv) const;
689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708

  virtual int32 InputDim() const { return dim_; }
  virtual int32 OutputDim() const { return dim_; }
  virtual void Read(std::istream &is, bool binary);
  
  virtual void Write(std::ostream &os, bool binary) const;

  void Init(int32 dim, BaseFloat scale);
  
  virtual void InitFromString(std::string args); 

  virtual std::string Info() const;
  
 private:
  int32 dim_;
  BaseFloat scale_;
  ScaleComponent &operator = (const ScaleComponent &other); // Disallow.
};


709

710
// Forward declarations of classes that are referred to (by pointer or
// reference) in the declarations below.
class SumGroupComponent;   // Used by SoftmaxComponent::MixUp().
class AffineComponent;     // Used by SoftmaxComponent::MixUp().
class FixedScaleComponent; // Used by AffineComponent::CollapseWithNext().
713 714 715 716 717 718

class SoftmaxComponent: public NonlinearComponent {
 public:
  explicit SoftmaxComponent(int32 dim): NonlinearComponent(dim) { }
  explicit SoftmaxComponent(const SoftmaxComponent &other): NonlinearComponent(other) { }  
  SoftmaxComponent() { }
719
  virtual std::string Type() const { return "SoftmaxComponent"; }
720 721
  virtual bool BackpropNeedsInput() const { return false; }
  virtual bool BackpropNeedsOutput() const { return true; }
722
  using Component::Propagate; // to avoid name hiding
723 724 725
  virtual void Propagate(const ChunkInfo &in_info,
                         const ChunkInfo &out_info,
                         const CuMatrixBase<BaseFloat> &in,
726
                         CuMatrixBase<BaseFloat> *out) const; 
727 728 729 730
  virtual void Backprop(const ChunkInfo &in_info,
                        const ChunkInfo &out_info,
                        const CuMatrixBase<BaseFloat> &in_value,
                        const CuMatrixBase<BaseFloat> &out_value,                        
731
                        const CuMatrixBase<BaseFloat> &out_deriv,
732
                        Component *to_update, // may be identical to "this".
733
                        CuMatrix<BaseFloat> *in_deriv) const;
734
  
735
  void MixUp(int32 num_mixtures,
736 737 738 739
             BaseFloat power,
             BaseFloat min_count,
             BaseFloat perturb_stddev,
             AffineComponent *ac,
740 741
             SumGroupComponent *sc);
  
742 743 744 745 746
  virtual Component* Copy() const { return new SoftmaxComponent(*this); }
 private:
  SoftmaxComponent &operator = (const SoftmaxComponent &other); // Disallow.
};

747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772
/// Nonlinearity component registered under the type name
/// "LogSoftmaxComponent".  Propagate() and Backprop() are defined out of line.
class LogSoftmaxComponent: public NonlinearComponent {
 public:
  // Construction: default, from a dimension, or as a copy of another object.
  LogSoftmaxComponent() { }
  explicit LogSoftmaxComponent(int32 dim)
      : NonlinearComponent(dim) { }
  explicit LogSoftmaxComponent(const LogSoftmaxComponent &other)
      : NonlinearComponent(other) { }

  // Identification and cloning.
  virtual std::string Type() const {
    return "LogSoftmaxComponent";
  }
  virtual Component* Copy() const {
    return new LogSoftmaxComponent(*this);
  }

  // What Backprop() requires: the output of Propagate(), but not its input.
  virtual bool BackpropNeedsOutput() const {
    return true;
  }
  virtual bool BackpropNeedsInput() const {
    return false;
  }

  // Forward and backward passes.  The using-declaration keeps the base-class
  // Propagate() overloads visible alongside the one declared here.
  using Component::Propagate;
  virtual void Propagate(const ChunkInfo &in_info,
                         const ChunkInfo &out_info,
                         const CuMatrixBase<BaseFloat> &in,
                         CuMatrixBase<BaseFloat> *out) const;
  virtual void Backprop(const ChunkInfo &in_info,
                        const ChunkInfo &out_info,
                        const CuMatrixBase<BaseFloat> &in_value,
                        const CuMatrixBase<BaseFloat> &out_value,
                        const CuMatrixBase<BaseFloat> &out_deriv,
                        Component *to_update,  // may be the same object as "this".
                        CuMatrix<BaseFloat> *in_deriv) const;

 private:
  // Assignment is disallowed: declared private for that purpose.
  LogSoftmaxComponent &operator = (const LogSoftmaxComponent &other);
};

773

774
// Forward declaration; AffineComponent::CollapseWithNext() and
// AffineComponent::CollapseWithPrevious() take this type by reference.
class FixedAffineComponent;
775 776 777 778 779 780 781 782 783

// Affine means a linear function plus an offset.
// Note: although this class can be instantiated, it also
// function as a base-class for more specialized versions of
// AffineComponent.
class AffineComponent: public UpdatableComponent {
  friend class SoftmaxComponent; // Friend declaration relates to mixing up.
 public:
  explicit AffineComponent(const AffineComponent &other);
784
  // The next constructor is used in converting from nnet1.
785 786
  AffineComponent(const CuMatrixBase<BaseFloat> &linear_params,
                  const CuVectorBase<BaseFloat> &bias_params,
787 788
                  BaseFloat learning_rate);
  
789 790 791 792
  virtual int32 InputDim() const { return linear_params_.NumCols(); }
  virtual int32 OutputDim() const { return linear_params_.NumRows(); }
  void Init(BaseFloat learning_rate,
            int32 input_dim, int32 output_dim,
793 794 795 796
            BaseFloat param_stddev, BaseFloat bias_stddev);
  void Init(BaseFloat learning_rate,
            std::string matrix_filename);

Dan Povey's avatar
Dan Povey committed
797 798 799 800
  // This function resizes the dimensions of the component, setting the
  // parameters to zero, while leaving any other configuration values the same.
  virtual void Resize(int32 input_dim, int32 output_dim);

801 802 803 804 805 806
  // The following functions are used for collapsing multiple layers
  // together.  They return a pointer to a new Component equivalent to
  // the sequence of two components.  We haven't implemented this for
  // FixedLinearComponent yet.
  Component *CollapseWithNext(const AffineComponent &next) const ;
  Component *CollapseWithNext(const FixedAffineComponent &next) const;
807
  Component *CollapseWithNext(const FixedScaleComponent &next) const;
808 809
  Component *CollapseWithPrevious(const FixedAffineComponent &prev) const;

810 811 812 813 814 815 816
  virtual std::string Info() const;
  virtual void InitFromString(std::string args);
  
  AffineComponent(): is_gradient_(false) { } // use Init to really initialize.
  virtual std::string Type() const { return "AffineComponent"; }
  virtual bool BackpropNeedsInput() const { return true; }
  virtual bool BackpropNeedsOutput() const { return false; }
817
  using Component::Propagate; // to avoid name hiding
818 819 820
  virtual void Propagate(const ChunkInfo &in_info,
                         const ChunkInfo &out_info,
                         const CuMatrixBase<BaseFloat> &in,
821
                         CuMatrixBase<BaseFloat> *out) const; 
822 823
  virtual void Scale(BaseFloat scale);
  virtual void Add(BaseFloat alpha, const UpdatableComponent &other);
824 825 826 827
  virtual void Backprop(const ChunkInfo &in_info,
                        const ChunkInfo &out_info,
                        const CuMatrixBase<BaseFloat> &in_value,
                        const CuMatrixBase<BaseFloat> &out_value,                        
828
                        const CuMatrixBase<BaseFloat> &out_deriv,
829
                        Component *to_update, // may be identical to "this".
830
                        CuMatrix<BaseFloat> *in_deriv) const;
831 832 833 834 835 836 837 838 839
  virtual void SetZero(bool treat_as_gradient);
  virtual void Read(std::istream &is, bool binary);
  virtual void Write(std::ostream &os, bool binary) const;
  virtual BaseFloat DotProduct(const UpdatableComponent &other) const;
  virtual Component* Copy() const;
  virtual void PerturbParams(BaseFloat stddev);
  // This new function is used when mixing up:
  virtual void SetParams(const VectorBase<BaseFloat> &bias,
                         const MatrixBase<BaseFloat> &linear);
840 841
  const CuVector<BaseFloat> &BiasParams() { return bias_params_; }
  const CuMatrix<BaseFloat> &LinearParams() { return linear_params_; }
842 843 844 845

  virtual int32 GetParameterDim() const;
  virtual void Vectorize(VectorBase<BaseFloat> *params) const;
  virtual void UnVectorize(const VectorBase<BaseFloat> &params);
846 847 848 849 850 851 852 853 854 855 856 857 858

  /// This function is for getting a low-rank approximations of this
  /// AffineComponent by two AffineComponents.
  virtual void LimitRank(int32 dimension,
                         AffineComponent **a, AffineComponent **b) const;

  /// This function is implemented in widen-nnet.cc
  void Widen(int32 new_dimension,
             BaseFloat param_stddev,
             BaseFloat bias_stddev,
             std::vector<NonlinearComponent*> c2, // will usually have just one
                                                  // element.
             AffineComponent *c3);
859
 protected:
860
  friend class AffineComponentPreconditionedOnline;
861 862
  // This function Update() is for extensibility; child classes may override this.
  virtual void Update(
863 864
      const CuMatrixBase<BaseFloat> &in_value,
      const CuMatrixBase<BaseFloat> &out_deriv) {
865 866 867 868 869
    UpdateSimple(in_value, out_deriv);
  }
  // UpdateSimple is used when *this is a gradient.  Child classes may
  // or may not override this.
  virtual void UpdateSimple(
870 871
      const CuMatrixBase<BaseFloat> &in_value,
      const CuMatrixBase<BaseFloat> &out_deriv);  
872 873

  const AffineComponent &operator = (const AffineComponent &other); // Disallow.
874 875
  CuMatrix<BaseFloat> linear_params_;
  CuVector<BaseFloat> bias_params_;
876 877 878 879

  bool is_gradient_; // If true, treat this as just a gradient.
};

880

881 882 883
// This is an idea Dan is trying out, a little bit like
// preconditioning the update with the Fisher matrix, but the
// Fisher matrix has a special structure.
884
// [note: it is currently used in the standard recipe].
885 886 887 888 889 890 891 892 893
class AffineComponentPreconditioned: public AffineComponent {
 public:
  virtual std::string Type() const { return "AffineComponentPreconditioned"; }

  virtual void Read(std::istream &is, bool binary);
  virtual void Write(std::ostream &os, bool binary) const;
  void Init(BaseFloat learning_rate,
            int32 input_dim, int32 output_dim,
            BaseFloat param_stddev, BaseFloat bias_stddev,
894 895 896 897
            BaseFloat alpha, BaseFloat max_change);
  void Init(BaseFloat learning_rate, BaseFloat alpha,
            BaseFloat max_change, std::string matrix_filename);
  
898 899 900
  virtual void InitFromString(std::string args);
  virtual std::string Info() const;
  virtual Component* Copy() const;
901
  AffineComponentPreconditioned(): alpha_(1.0), max_change_(0.0) { }
902
  void SetMaxChange(BaseFloat max_change) { max_change_ = max_change; }
903
 protected:
904 905
  KALDI_DISALLOW_COPY_AND_ASSIGN(AffineComponentPreconditioned);
  BaseFloat alpha_;
906 907 908 909 910
  BaseFloat max_change_; // If > 0, this is the maximum amount of parameter change (in L2 norm)
                         // that we allow per minibatch.  This was introduced in order to
                         // control instability.  Instead of the exact L2 parameter change,
                         // for efficiency purposes we limit a bound on the exact change.
                         // The limit is applied via a constant <= 1.0 for each minibatch,
911
                         // A suitable value might be, for example, 10 or so; larger if there are
912 913 914 915
                         // more parameters.

  /// The following function is only called if max_change_ > 0.  It returns the
  /// greatest value alpha <= 1.0 such that (alpha times the sum over the
916
  /// row-index of the two matrices of the product the l2 norms of the two rows
917 918
  /// times learning_rate_)
  /// is <= max_change.
919 920
  BaseFloat GetScalingFactor(const CuMatrix<BaseFloat> &in_value_precon,
                             const CuMatrix<BaseFloat> &out_deriv_precon);
921

922
  virtual void Update(
923 924
      const CuMatrixBase<BaseFloat> &in_value,
      const CuMatrixBase<BaseFloat> &out_deriv);
925 926 927
};


928 929 930
/// Keywords: natural gradient descent, NG-SGD, naturalgradient.  For
/// the top-level of the natural gradient code look here, and also in
/// nnet-precondition-online.h.
931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951
/// AffineComponentPreconditionedOnline is, like AffineComponentPreconditioned,
/// a version of AffineComponent that has a non-(multiple of unit) learning-rate
/// matrix.  See nnet-precondition-online.h for a description of the technique.
/// This method maintains an orthogonal matrix N with a small number of rows,
/// actually two (for input and output dims) which gets modified each time;
/// we maintain a mutex for access to this (we just use it to copy it when
/// we need it and write to it when we change it).  For multi-threaded use,
/// the parallelization method is to lock a mutex whenever we want to
/// read N or change it, but just quickly make a copy and release the mutex;
/// this is to ensure operations on N are atomic.
class AffineComponentPreconditionedOnline: public AffineComponent {
 public:
  virtual std::string Type() const {
    return "AffineComponentPreconditionedOnline";
  }

  virtual void Read(std::istream &is, bool binary);
  virtual void Write(std::ostream &os, bool binary) const;
  void Init(BaseFloat learning_rate,
            int32 input_dim, int32 output_dim,
            BaseFloat param_stddev, BaseFloat bias_stddev,
952
            int32 rank_in, int32 rank_out, int32 update_period,
953
            BaseFloat num_samples_history, BaseFloat alpha,
954
            BaseFloat max_change_per_sample);
955
  void Init(BaseFloat learning_rate, int32 rank_in,
956 957
            int32 rank_out, int32 update_period,
            BaseFloat num_samples_history,
958 959 960
            BaseFloat alpha, BaseFloat max_change_per_sample,
            std::string matrix_filename);

Dan Povey's avatar
Dan Povey committed
961 962
  virtual void Resize(int32 input_dim, int32 output_dim);
  
963 964
  // This constructor is used when converting neural networks partway through
  // training, from AffineComponent or AffineComponentPreconditioned to
965
  // AffineComponentPreconditionedOnline.
966
  AffineComponentPreconditionedOnline(const AffineComponent &orig,
967
                                      int32 rank_in, int32 rank_out,
968
                                      int32 update_period,
969
                                      BaseFloat eta, BaseFloat alpha);
970 971 972 973
  
  virtual void InitFromString(std::string args);
  virtual std::string Info() const;
  virtual Component* Copy() const;
974
  AffineComponentPreconditionedOnline(): max_change_per_sample_(0.0) { }
975 976 977 978

 private:
  KALDI_DISALLOW_COPY_AND_ASSIGN(AffineComponentPreconditionedOnline);

979

980 981 982 983
  // Configs for preconditioner.  The input side tends to be better conditioned ->
  // smaller rank needed, so make them separately configurable.
  int32 rank_in_;
  int32 rank_out_;
984
  int32 update_period_;
985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017
  BaseFloat num_samples_history_;
  BaseFloat alpha_;
  
  OnlinePreconditioner preconditioner_in_;

  OnlinePreconditioner preconditioner_out_;

  BaseFloat max_change_per_sample_;
  // If > 0, max_change_per_sample_ this is the maximum amount of parameter
  // change (in L2 norm) that we allow per sample, averaged over the minibatch.
  // This was introduced in order to control instability.
  // Instead of the exact L2 parameter change, for
  // efficiency purposes we limit a bound on the exact
  // change.  The limit is applied via a constant <= 1.0
  // for each minibatch, A suitable value might be, for
  // example, 10 or so; larger if there are more
  // parameters.

  /// The following function is only called if max_change_per_sample_ > 0, it returns a
  /// scaling factor alpha <= 1.0 (1.0 in the normal case) that enforces the
  /// "max-change" constraint.  "in_products" is the inner product with itself
  /// of each row of the matrix of preconditioned input features; "out_products"
  /// is the same for the output derivatives.  gamma_prod is a product of two
  /// scalars that are output by the preconditioning code (for the input and
  /// output), which we will need to multiply into the learning rate.
  /// out_products is a pointer because we modify it in-place.
  BaseFloat GetScalingFactor(const CuVectorBase<BaseFloat> &in_products,
                             BaseFloat gamma_prod,
                             CuVectorBase<BaseFloat> *out_products);

  // Sets the configs rank, alpha and eta in the preconditioner objects,
  // from the class variables.
  void SetPreconditionerConfigs();
1018 1019 1020 1021 1022 1023

  virtual void Update(
      const CuMatrixBase<BaseFloat> &in_value,
      const CuMatrixBase<BaseFloat> &out_deriv);
};

1024 1025 1026 1027
class RandomComponent: public Component {
 public:
  // This function is required in testing code and in other places we need
  // consistency in the random number generation (e.g. when optimizing
1028
  // validation-set performance), but check where else we call sRand().  You'll
1029 1030 1031 1032 1033 1034
  // need to call srand as well as making this call.  
  void ResetGenerator() { random_generator_.SeedGpu(0); }
 protected:
  CuRand<BaseFloat> random_generator_;
};

1035
/// Splices a context window of frames together [over time]
1036 1037 1038
class SpliceComponent: public Component {
 public:
  SpliceComponent() { }  // called only prior to Read() or Init().
1039 1040 1041
  // Note: it is required that the elements of "context" be in
  // strictly increasing order, that the lowest element of component
  // be nonpositive, and the highest element be nonnegative.
1042
  void Init(int32 input_dim,
1043
            std::vector<int32> context,
1044 1045 1046 1047 1048 1049
            int32 const_component_dim=0);
  virtual std::string Type() const { return "SpliceComponent"; }
  virtual std::string Info() const;
  virtual void InitFromString(std::string args);
  virtual int32 InputDim() const { return input_dim_; }
  virtual int32 OutputDim() const;
1050
  virtual std::vector<int32> Context() const { return context_; }
1051
  using Component::Propagate; // to avoid name hiding
1052 1053 1054
  virtual void Propagate(const ChunkInfo &in_info,
                         const ChunkInfo &out_info,
                         const CuMatrixBase<BaseFloat> &in,
1055
                         CuMatrixBase<BaseFloat> *out) const; 
1056 1057 1058 1059
  virtual void Backprop(const ChunkInfo &in_info,
                        const ChunkInfo &out_info,
                        const CuMatrixBase<BaseFloat> &in_value,
                        const CuMatrixBase<BaseFloat> &out_value,                        
1060
                        const CuMatrixBase<BaseFloat> &out_deriv,
1061
                        Component *to_update, // may be identical to "this".
1062
                        CuMatrix<BaseFloat> *in_deriv) const;
1063 1064 1065 1066 1067 1068 1069 1070
  virtual bool BackpropNeedsInput() const { return false; }
  virtual bool BackpropNeedsOutput() const { return false; }
  virtual Component* Copy() const;
  virtual void Read(std::istream &is, bool binary);
  virtual void Write(std::ostream &os, bool binary) const;
 private:
  KALDI_DISALLOW_COPY_AND_ASSIGN(SpliceComponent);
  int32 input_dim_;
1071
  std::vector<int32> context_;
1072 1073 1074
  int32 const_component_dim_;
};

1075 1076 1077 1078 1079 1080
/// This is as SpliceComponent but outputs the max of
/// any of the inputs (taking the max across time).
class SpliceMaxComponent: public Component {
 public:
  SpliceMaxComponent() { }  // called only prior to Read() or Init().
  void Init(int32 dim,
1081
            std::vector<int32> context);
1082 1083 1084 1085 1086
  virtual std::string Type() const { return "SpliceMaxComponent"; }
  virtual std::string Info() const;
  virtual void InitFromString(std::string args);
  virtual int32 InputDim() const { return dim_; }
  virtual int32 OutputDim() const { return dim_; }
1087
  virtual std::vector<int32> Context() const  { return context_; }
1088
  using Component::Propagate; // to avoid name hiding
1089 1090 1091
  virtual void Propagate(const ChunkInfo &in_info,
                         const ChunkInfo &out_info,
                         const CuMatrixBase<BaseFloat> &in,
1092
                         CuMatrixBase<BaseFloat> *out) const; 
1093 1094 1095 1096
  virtual void Backprop(const ChunkInfo &in_info,