summary | shortlog | log | commit | commitdiff | tree
raw | patch | inline | side by side (parent: e56552b)
raw | patch | inline | side by side (parent: e56552b)
author | Shiyin Kang <kangshiyin@gmail.com> | |
Wed, 23 Nov 2016 14:05:31 +0000 (22:05 +0800) | ||
committer | Shiyin Kang <kangshiyin@gmail.com> | |
Sat, 26 Nov 2016 05:28:05 +0000 (13:28 +0800) |
working on kernel code
compilable kernel code
fix bug
pass unit test and deriv test
make nnet3 compilable.
speed test for backprop lstm
compilable kernel code
fix bug
pass unit test and deriv test
make nnet3 compilable.
speed test for backprop lstm
index b7571383193c32468ad5793477ba92979d98bdf0..bf504347872ab37b545adff98467fda0e7932dea 100644 (file)
const int params_stride, const int out_stride,
const int cell_dim, const int num_rows,
float* out);
+void cudaD_diff_lstm_nonlinearity(dim3 Gr, dim3 Bl, const int cell_dim,
+ const int num_rows, const double* input,
+ const int in_stride, const double* params,
+ const int params_stride,
+ const double* output_deriv,
+ const int output_deriv_stride,
+ const double* deriv_sum_in,
+ const int deriv_sum_in_stride,
+ const double* self_repair_config,
+ double count, double* input_deriv,
+ const int input_deriv_stride,
+ double* params_deriv,
+ const int params_deriv_stride,
+ double* value_sum_out,
+ const int value_sum_out_stride,
+ double* deriv_sum_out,
+ const int deriv_sum_out_stride,
+ double* self_repair_sum_out,
+ const int self_repair_sum_out_stride);
+void cudaF_diff_lstm_nonlinearity(dim3 Gr, dim3 Bl, const int cell_dim,
+ const int num_rows, const float* input,
+ const int in_stride, const float* params,
+ const int params_stride,
+ const float* output_deriv,
+ const int output_deriv_stride,
+ const double* deriv_sum_in,
+ const int deriv_sum_in_stride,
+ const float* self_repair_config, double count,
+ float* input_deriv,
+ const int input_deriv_stride,
+ float* params_deriv,
+ const int params_deriv_stride,
+ double* value_sum_out,
+ const int value_sum_out_stride,
+ double* deriv_sum_out,
+ const int deriv_sum_out_stride,
+ float* self_repair_sum_out,
+ const int self_repair_sum_out_stride);
+
+
} // extern "C"
index 614f8ec4cc6b9f2215ffd4e6d8791dc043100039..3fb38b54c20b97dfcfd6bb6d29866f0893ef9016 100644 (file)
}
+/**
+ This function does the 'backward' pass corresponding to the function
+ ComputeLstmNonlinearity. It's a little more complicated than you might
+ expect because of the 'self-repair' mechanism that we use to prevent the
+ sigmoid and tanh nonlinearities oversaturating, and because of the
+ average-activation and average-derivative stats that we store for these
+ nonlinearites (these stats are used both to control the self-repair
+ mechanism, and for diagnostic purposes).
+
+ Because the forward pass computes various intermediate values that are not
+ output, this function actually has to do the same computations as the
+ forward pass before it actually does the backprop.
+
+ In the following description, `C` is for `cell_dim`, `N` is for `num_rows`.
+
+ @param [in] input The same as in ComputeLstmNonlinearity().
+ A matrix, of dimension N by 5C (i.e. its num-cols must be
+ a multiple of 5). The column-space is interpreted as 5
+ consecutive blocks, each of dimension C, which we name:
+ (i_part, f_part, c_part, o_part, c_{t-1}).
+ @param [in] params The same as in ComputeLstmNonlinearity().
+ A matrix, of dimension 3 by C, with rows containing the
+ three diagonal parameter matrices used in LSTMs, namely
+ w_{ic}, w_{fc} and w_{oc}.
+ @param [in] output_deriv
+ A matrix, of dimension N by 2C, containing the derivative
+ of the objective function we're backpropagating,
+ w.r.t. the quantities c_t and m_t (in two blocks of
+ column-dimension C).
+ @param [in] deriv_sum_in
+ This is used in the self-repair code to identify
+ oversaturated nonlinearities.
+ It is a matrix, of dimension 5 by C, corresponding to
+ the totals of the derivatives of the 5 sigmoid and tanh
+ nonlinearities, in they order they appear in the equations
+ in the documentation of ComputeLstmNonlinearity()
+ respectively,
+ they appear in the equations for (i_t, f_t, c_t, o_t, m_t).
+ This will be divided by 'count_in' to get the average
+ derivative value so far, for each of the nonlinearities.
+ @param [in] self_repair_config
+ A vector of dimension 10, containing the configuration of
+ the self-repair to be used for the 5 nonlinearities.
+ The first 5 elements are the self_repair_lower_threshold
+ values (typically 0.05 for sigmoid and 0.2 for tanh),
+ and the next 5 elements are the corresponding
+ self-repair-scales (typically 10^-5).
+ @param [in] count_in The data-count that corresponds to the stats in
+ 'deriv_sum_in' at entry to the function.
+ This function should tolerate the count being zero
+ (in that case, it is free to do the self-repair or not,
+ as this should only happen on the 1st minibatch of each
+ training job).
+ @param [out] input_deriv
+ May be NULL; if not, this function writes, to this
+ location, the backpropagated derivative of the objective
+ function w.r.t. the 'input' matrix. This matrix should
+ have the same dimension as 'input' i.e. N by 5C. In
+ addition to the regular backpropagated derivative, the
+ output will include small values relating to 'self-repair'.
+ @param [out] params_deriv
+ May be NULL; if not, this is where this function *writes*
+ [not adds] the backpropagated derivative of the objective
+ function w.r.t. 'params'; it should have the same dimension
+ as 'params' (3 by C). (This matrix will then be processed
+ by the natural gradient code and added to the appropriate
+ copy of the parameter matrix, outside this function).
+ @param [out] value_sum_out
+ Must be NULL if params_deriv is NULL; if not, a matrix of
+ dimension 5 by C. This function *adds* to this location
+ the total value of each of the sigmoid/tanh nonlinearities
+ that it computes (this is for diagnostic purposes).
+ @param [out] deriv_sum_out
+ Must be NULL if params_deriv is NULL; if not, a matrix of
+ dimension 5 by C; this function *adds* to this location the
+ total of the derivative of each of the sigmoid/tanh
+ nonlinearities that it computes (this is for diagnostic
+ purposes and to control the self-repair). This function
+ should tolerate the case when 'deriv_sum_out' points to the
+ same data as 'deriv_sum_in'.
+ @param [out] self_repair_sum_out
+ Must be NULL if params_deriv is NULL; if not, a matrix of
+ dimension 5 by C; this function *writes* to this location
+ the sum of the number of times the self-repair code was
+ activated (integer values 0 <= k <= N). This will be
+ processed outside this function into self-repair stats for
+ diagnostics.
+// Use 2D block (8x32 threads) as we need to compute column sum.
+// Use 1D grid to cover the data matrix `cell_dim`.
+*/
+template<typename Real>
+__global__
+static void _diff_lstm_nonlinearity(const int cell_dim, const int num_rows,
+ const Real* input, const int input_stride,
+ const Real* params, const int params_stride,
+ const Real* output_deriv,
+ const int output_deriv_stride,
+ const double* deriv_sum_in,
+ const int deriv_sum_in_stride,
+ const Real* self_repair_config,
+ double count, Real* input_deriv,
+ const int input_deriv_stride,
+ Real* params_deriv,
+ const int params_deriv_stride,
+ double* value_sum_out,
+ const int value_sum_out_stride,
+ double* deriv_sum_out,
+ const int deriv_sum_out_stride,
+ Real* self_repair_sum_out,
+ const int self_repair_sum_out_stride) {
+ __shared__ Real smem[CU1DBLOCK];
+
+ const int j = blockIdx.x * blockDim.x + threadIdx.x;
+ const int tid = threadIdx.y * blockDim.x + threadIdx.x;
+ const int grid_stride = gridDim.y * blockDim.y;
+ const int i0 = blockIdx.y * blockDim.y + threadIdx.y;
+
+ Real w_ic_deriv_sum = 0;
+ Real w_fc_deriv_sum = 0;
+ Real w_oc_deriv_sum = 0;
+
+ Real i_t_value_sum = 0, i_t_deriv_sum = 0;
+ Real f_t_value_sum = 0, f_t_deriv_sum = 0;
+ Real c_part_value_sum = 0, c_part_deriv_sum = 0;
+ Real o_t_value_sum = 0, o_t_deriv_sum = 0;
+ Real c_t_value_sum = 0, c_t_deriv_sum = 0;
+
+ bool update_sr[5];
+
+ if (j < cell_dim) {
+ const Real w_ic = params[j];
+ const Real w_fc = params[params_stride + j];
+ const Real w_oc = params[2 * params_stride + j];
+
+ const Real* sr_config = self_repair_config;
+# pragma unroll
+ for (int i = 0; i < 5; i++) {
+ update_sr[i] = deriv_sum_in[i * deriv_sum_in_stride + j] / count
+ < sr_config[i];
+ }
+ const Real i_t_self_repair = (update_sr[0] ? sr_config[5] : 0);
+ const Real f_t_self_repair = (update_sr[1] ? sr_config[6] : 0);
+ const Real c_part_self_repair = (update_sr[2] ? sr_config[7] : 0);
+ const Real o_t_self_repair = (update_sr[3] ? sr_config[8] : 0);
+ const Real c_t_self_repair = (update_sr[4] ? sr_config[9] : 0);
+
+ for (int i = i0; i < num_rows; i += grid_stride) {
+ const Real i_part = input[i * input_stride + j];
+ const Real f_part = input[i * input_stride + j + cell_dim];
+ const Real c_part = input[i * input_stride + j + 2 * cell_dim];
+ const Real o_part = input[i * input_stride + j + 3 * cell_dim];
+ const Real c_prev = input[i * input_stride + j + 4 * cell_dim];
+
+ const Real i_t = 1 / (1 + exp(-i_part - w_ic * c_prev));
+ const Real f_t = 1 / (1 + exp(-f_part - w_fc * c_prev));
+ const Real tanh_c_part = tanh(c_part);
+ const Real c_t = f_t * c_prev + i_t * tanh_c_part;
+ const Real o_t = 1 / (1 + exp(-o_part - w_oc * c_t));
+ const Real tanh_c_t = tanh(c_t);
+
+ const Real i_t_deriv = i_t * (1 - i_t);
+ const Real f_t_deriv = f_t * (1 - f_t);
+ const Real c_part_deriv = 1 - tanh_c_part * tanh_c_part;
+ const Real o_t_deriv = o_t * (1 - o_t);
+ const Real c_t_deriv = 1 - tanh_c_t * tanh_c_t;
+
+ if (params_deriv) {
+ i_t_value_sum += i_t;
+ f_t_value_sum += f_t;
+ c_part_value_sum += tanh_c_part;
+ o_t_value_sum += o_t;
+ c_t_value_sum += tanh_c_t;
+
+ i_t_deriv_sum += i_t_deriv;
+ f_t_deriv_sum += f_t_deriv;
+ c_part_deriv_sum += c_part_deriv;
+ o_t_deriv_sum += o_t_deriv;
+ c_t_deriv_sum += c_t_deriv;
+ }
+
+ const Real dc_t_out = output_deriv[i * output_deriv_stride + j];
+ const Real dm_t = output_deriv[i * output_deriv_stride + j + cell_dim];
+
+ const Real dtanh_c_t = o_t * dm_t;
+ const Real do_t = tanh_c_t * dm_t;
+ const Real do_t_input = (o_t_deriv * do_t
+ - (2 * o_t - 1) * o_t_self_repair);
+
+ const Real dc_t = (c_t_deriv * dtanh_c_t + dc_t_out + do_t_input * w_oc)
+ - tanh_c_t * c_t_self_repair;
+ const Real dtanh_c_part = i_t * dc_t;
+ const Real df_t = dc_t * c_prev;
+ const Real df_t_input = (df_t * f_t_deriv
+ - (2 * f_t - 1) * f_t_self_repair);
+ const Real di_t = dc_t * tanh_c_part;
+ const Real di_t_input = (di_t * i_t_deriv
+ - (2 * i_t - 1) * i_t_self_repair);
+
+ if (params_deriv) {
+ w_ic_deriv_sum += c_prev * di_t_input;
+ w_fc_deriv_sum += c_prev * df_t_input;
+ w_oc_deriv_sum += c_t * do_t_input;
+ }
+
+ const Real dc_prev = w_ic * di_t_input + w_fc * df_t_input + f_t * dc_t;
+ const Real do_part = do_t_input;
+ const Real dc_part = (c_part_deriv * dtanh_c_part
+ - tanh_c_part * c_part_self_repair);
+ const Real df_part = df_t_input;
+ const Real di_part = di_t_input;
+
+ if (input_deriv) {
+ input_deriv[i * input_deriv_stride + j] += di_part;
+ input_deriv[i * input_deriv_stride + j + cell_dim] += df_part;
+ input_deriv[i * input_deriv_stride + j + cell_dim * 2] += dc_part;
+ input_deriv[i * input_deriv_stride + j + cell_dim * 3] += do_part;
+ input_deriv[i * input_deriv_stride + j + cell_dim * 4] += dc_prev;
+ }
+ }
+ }
+
+ if (params_deriv) {
+ // compute params_deriv
+ smem[tid] = w_ic_deriv_sum;
+# pragma unroll
+ for (int shift = CU1DBLOCK / 2; shift >= warpSize; shift >>= 1) {
+ __syncthreads();
+ if (tid < shift) {
+ smem[tid] += smem[tid + shift];
+ }
+ }
+ if (tid < warpSize && j < cell_dim) {
+ params_deriv[j] = smem[tid];
+ }
+
+ __syncthreads();
+ smem[tid] = w_fc_deriv_sum;
+# pragma unroll
+ for (int shift = CU1DBLOCK / 2; shift >= warpSize; shift >>= 1) {
+ __syncthreads();
+ if (tid < shift) {
+ smem[tid] += smem[tid + shift];
+ }
+ }
+ if (tid < warpSize && j < cell_dim) {
+ params_deriv[params_deriv_stride + j] = smem[tid];
+ }
+
+ __syncthreads();
+ smem[tid] = w_oc_deriv_sum;
+# pragma unroll
+ for (int shift = CU1DBLOCK / 2; shift >= warpSize; shift >>= 1) {
+ __syncthreads();
+ if (tid < shift) {
+ smem[tid] += smem[tid + shift];
+ }
+ }
+ if (tid < warpSize && j < cell_dim) {
+ params_deriv[2 * params_deriv_stride + j] = smem[tid];
+ }
+
+ // compute value_sum_out
+ __syncthreads();
+ smem[tid] = i_t_value_sum;
+# pragma unroll
+ for (int shift = CU1DBLOCK / 2; shift >= warpSize; shift >>= 1) {
+ __syncthreads();
+ if (tid < shift) {
+ smem[tid] += smem[tid + shift];
+ }
+ }
+ if (tid < warpSize && j < cell_dim) {
+ value_sum_out[j] += smem[tid];
+ }
+
+ __syncthreads();
+ smem[tid] = f_t_value_sum;
+# pragma unroll
+ for (int shift = CU1DBLOCK / 2; shift >= warpSize; shift >>= 1) {
+ __syncthreads();
+ if (tid < shift) {
+ smem[tid] += smem[tid + shift];
+ }
+ }
+ if (tid < warpSize && j < cell_dim) {
+ value_sum_out[value_sum_out_stride + j] += smem[tid];
+ }
+
+ __syncthreads();
+ smem[tid] = c_part_value_sum;
+# pragma unroll
+ for (int shift = CU1DBLOCK / 2; shift >= warpSize; shift >>= 1) {
+ __syncthreads();
+ if (tid < shift) {
+ smem[tid] += smem[tid + shift];
+ }
+ }
+ if (tid < warpSize && j < cell_dim) {
+ value_sum_out[2 * value_sum_out_stride + j] += smem[tid];
+ }
+
+ __syncthreads();
+ smem[tid] = o_t_value_sum;
+# pragma unroll
+ for (int shift = CU1DBLOCK / 2; shift >= warpSize; shift >>= 1) {
+ __syncthreads();
+ if (tid < shift) {
+ smem[tid] += smem[tid + shift];
+ }
+ }
+ if (tid < warpSize && j < cell_dim) {
+ value_sum_out[3 * value_sum_out_stride + j] += smem[tid];
+ }
+
+ __syncthreads();
+ smem[tid] = c_t_value_sum;
+# pragma unroll
+ for (int shift = CU1DBLOCK / 2; shift >= warpSize; shift >>= 1) {
+ __syncthreads();
+ if (tid < shift) {
+ smem[tid] += smem[tid + shift];
+ }
+ }
+ if (tid < warpSize && j < cell_dim) {
+ value_sum_out[4 * value_sum_out_stride + j] += smem[tid];
+ }
+
+ // need to update self_repair_sum_out before deriv_sum_out, because
+ // deriv_sum_out and deriv_sum_in might point to the same memory.
+ if (i0 < 5 && j < cell_dim) {
+ self_repair_sum_out[i0 * self_repair_sum_out_stride + j] +=
+ update_sr[i0] ? num_rows : 0;
+ }
+
+ // compute derive_sum_out
+ __syncthreads();
+ smem[tid] = i_t_deriv_sum;
+# pragma unroll
+ for (int shift = CU1DBLOCK / 2; shift >= warpSize; shift >>= 1) {
+ __syncthreads();
+ if (tid < shift) {
+ smem[tid] += smem[tid + shift];
+ }
+ }
+ if (tid < warpSize && j < cell_dim) {
+ deriv_sum_out[j] += smem[tid];
+ }
+
+ __syncthreads();
+ smem[tid] = f_t_deriv_sum;
+# pragma unroll
+ for (int shift = CU1DBLOCK / 2; shift >= warpSize; shift >>= 1) {
+ __syncthreads();
+ if (tid < shift) {
+ smem[tid] += smem[tid + shift];
+ }
+ }
+ if (tid < warpSize && j < cell_dim) {
+ deriv_sum_out[deriv_sum_out_stride + j] += smem[tid];
+ }
+
+ __syncthreads();
+ smem[tid] = c_part_deriv_sum;
+# pragma unroll
+ for (int shift = CU1DBLOCK / 2; shift >= warpSize; shift >>= 1) {
+ __syncthreads();
+ if (tid < shift) {
+ smem[tid] += smem[tid + shift];
+ }
+ }
+ if (tid < warpSize && j < cell_dim) {
+ deriv_sum_out[2 * deriv_sum_out_stride + j] += smem[tid];
+ }
+
+ __syncthreads();
+ smem[tid] = o_t_deriv_sum;
+# pragma unroll
+ for (int shift = CU1DBLOCK / 2; shift >= warpSize; shift >>= 1) {
+ __syncthreads();
+ if (tid < shift) {
+ smem[tid] += smem[tid + shift];
+ }
+ }
+ if (tid < warpSize && j < cell_dim) {
+ deriv_sum_out[3 * deriv_sum_out_stride + j] += smem[tid];
+ }
+
+ __syncthreads();
+ smem[tid] = c_t_deriv_sum;
+ __syncthreads();
+# pragma unroll
+ for (int shift = CU1DBLOCK / 2; shift >= warpSize; shift >>= 1) {
+ __syncthreads();
+ if (tid < shift) {
+ smem[tid] += smem[tid + shift];
+ }
+ }
+ if (tid < warpSize && j < cell_dim) {
+ deriv_sum_out[4 * deriv_sum_out_stride + j] += smem[tid];
+ }
+ }
+}
+
/***********************************************************************
* ANSI-C wrappers of CUDA kernels
*/
_lstm_nonlinearity<<<Gr, Bl>>>(in, in_stride, params, params_stride,
out_stride, cell_dim, num_rows, out);
}
+void cudaD_diff_lstm_nonlinearity(dim3 Gr, dim3 Bl, const int cell_dim,
+ const int num_rows, const double* input,
+ const int input_stride, const double* params,
+ const int params_stride,
+ const double* output_deriv,
+ const int output_deriv_stride,
+ const double* deriv_sum_in,
+ const int deriv_sum_in_stride,
+ const double* self_repair_config,
+ double count, double* input_deriv,
+ const int input_deriv_stride,
+ double* params_deriv,
+ const int params_deriv_stride,
+ double* value_sum_out,
+ const int value_sum_out_stride,
+ double* deriv_sum_out,
+ const int deriv_sum_out_stride,
+ double* self_repair_sum_out,
+ const int self_repair_sum_out_stride) {
+ _diff_lstm_nonlinearity<<<Gr, Bl>>>(cell_dim, num_rows, input,
+ input_stride, params, params_stride, output_deriv, output_deriv_stride,
+ deriv_sum_in, deriv_sum_in_stride, self_repair_config, count, input_deriv,
+ input_deriv_stride, params_deriv, params_deriv_stride, value_sum_out,
+ value_sum_out_stride, deriv_sum_out, deriv_sum_out_stride,
+ self_repair_sum_out, self_repair_sum_out_stride);
+}
+void cudaF_diff_lstm_nonlinearity(dim3 Gr, dim3 Bl, const int cell_dim,
+ const int num_rows, const float* input,
+ const int input_stride, const float* params,
+ const int params_stride,
+ const float* output_deriv,
+ const int output_deriv_stride,
+ const double* deriv_sum_in,
+ const int deriv_sum_in_stride,
+ const float* self_repair_config, double count,
+ float* input_deriv,
+ const int input_deriv_stride,
+ float* params_deriv,
+ const int params_deriv_stride,
+ double* value_sum_out,
+ const int value_sum_out_stride,
+ double* deriv_sum_out,
+ const int deriv_sum_out_stride,
+ float* self_repair_sum_out,
+ const int self_repair_sum_out_stride) {
+ _diff_lstm_nonlinearity<<<Gr, Bl>>>(cell_dim, num_rows, input,
+ input_stride, params, params_stride, output_deriv, output_deriv_stride,
+ deriv_sum_in, deriv_sum_in_stride, self_repair_config, count, input_deriv,
+ input_deriv_stride, params_deriv, params_deriv_stride, value_sum_out,
+ value_sum_out_stride, deriv_sum_out, deriv_sum_out_stride,
+ self_repair_sum_out, self_repair_sum_out_stride);
+}
index c8912b4ebfc15c28ca2d11baef76286958644060..656b82326a06a7ded8954ec575fc3d3e764abcb5 100644 (file)
cudaF_lstm_nonlinearity(Gr, Bl, in, in_stride, params, params_stride,
out_stride, cell_dim, num_rows, out);
}
+inline void cuda_diff_lstm_nonlinearity(dim3 Gr, dim3 Bl, const int cell_dim,
+ const int num_rows, const double* input,
+ const int input_stride,
+ const double* params,
+ const int params_stride,
+ const double* output_deriv,
+ const int output_deriv_stride,
+ const double* deriv_sum_in,
+ const int deriv_sum_in_stride,
+ const double* self_repair_config,
+ double count, double* input_deriv,
+ const int input_deriv_stride,
+ double* params_deriv,
+ const int params_deriv_stride,
+ double* value_sum_out,
+ const int value_sum_out_stride,
+ double* deriv_sum_out,
+ const int deriv_sum_out_stride,
+ double* self_repair_sum_out,
+ const int self_repair_sum_out_stride) {
+ cudaD_diff_lstm_nonlinearity(Gr, Bl, cell_dim, num_rows, input, input_stride,
+ params, params_stride, output_deriv,
+ output_deriv_stride, deriv_sum_in,
+ deriv_sum_in_stride, self_repair_config, count,
+ input_deriv, input_deriv_stride, params_deriv,
+ params_deriv_stride, value_sum_out,
+ value_sum_out_stride, deriv_sum_out,
+ deriv_sum_out_stride, self_repair_sum_out,
+ self_repair_sum_out_stride);
+}
+inline void cuda_diff_lstm_nonlinearity(dim3 Gr, dim3 Bl, const int cell_dim,
+ const int num_rows, const float* input,
+ const int input_stride,
+ const float* params,
+ const int params_stride,
+ const float* output_deriv,
+ const int output_deriv_stride,
+ const double* deriv_sum_in,
+ const int deriv_sum_in_stride,
+ const float* self_repair_config,
+ double count, float* input_deriv,
+ const int input_deriv_stride,
+ float* params_deriv,
+ const int params_deriv_stride,
+ double* value_sum_out,
+ const int value_sum_out_stride,
+ double* deriv_sum_out,
+ const int deriv_sum_out_stride,
+ float* self_repair_sum_out,
+ const int self_repair_sum_out_stride) {
+ cudaF_diff_lstm_nonlinearity(Gr, Bl, cell_dim, num_rows, input, input_stride,
+ params, params_stride, output_deriv,
+ output_deriv_stride, deriv_sum_in,
+ deriv_sum_in_stride, self_repair_config, count,
+ input_deriv, input_deriv_stride, params_deriv,
+ params_deriv_stride, value_sum_out,
+ value_sum_out_stride, deriv_sum_out,
+ deriv_sum_out_stride, self_repair_sum_out,
+ self_repair_sum_out_stride);
+}
} // namespace kaldi
index 243e696187bc8794a73969d426aca68c12f8ab0e..027eb4fd2a75f3553894ad1fbfabed259441023e 100644 (file)
AssertEqual(Houtput, HDoutput);
}
- for (int i = 16; i <= 1024; i *= 2) {
+ for (int i = 16; i <= 2048; i *= 2) {
BaseFloat time_in_secs = 0.025;
int32 num_rows = i;
int32 cell_dim = i;
}
}
+template<typename Real>
+static void UnitTestBackpropLstmNonlinearity() {
+ for (int i = 0; i < 3; i++) {
+ int32 num_rows = 1 + Rand() % 200;
+ int32 cell_dim = 1 + Rand() % 2000;
+// KALDI_LOG << num_rows << ", " << cell_dim;
+
+ Matrix<Real> hinput(num_rows, 5 * cell_dim);
+ Matrix<Real> hparams(3, cell_dim);
+ Matrix<Real> houtput_deriv(num_rows, 2 * cell_dim);
+ Matrix<double> hderiv_sum_in(5, cell_dim);
+ Vector<Real> hself_repair_config(10);
+ double count_in;
+ Matrix<Real> hinput_deriv(num_rows, 5 * cell_dim);
+ Matrix<Real> hparams_deriv(3, cell_dim);
+ Matrix<double> hvalue_sum_out(5, cell_dim);
+ Matrix<double> hderiv_sum_out(5, cell_dim);
+ Matrix<Real> hself_repair_sum_out(5, cell_dim);
+
+ hinput.SetRandn();
+ hparams.SetRandn();
+ houtput_deriv.SetRandn();
+ hderiv_sum_in.SetRandn();
+ hself_repair_config.SetRandn();
+ count_in = Rand() % num_rows;
+
+ hinput_deriv.SetRandn();
+ hparams_deriv.SetRandn();
+ hvalue_sum_out.SetRandn();
+ hderiv_sum_out.SetRandn();
+ hself_repair_sum_out.SetRandn();
+
+ CuMatrix<Real> dinput(hinput);
+ CuMatrix<Real> dparams(hparams);
+ CuMatrix<Real> doutput_deriv(houtput_deriv);
+ CuMatrix<double> dderiv_sum_in(hderiv_sum_in);
+ CuVector<Real> dself_repair_config(hself_repair_config);
+
+ CuMatrix<Real> dinput_deriv(hinput_deriv);
+ CuMatrix<Real> dparams_deriv(hparams_deriv);
+ CuMatrix<double> dvalue_sum_out(hvalue_sum_out);
+ CuMatrix<double> dderiv_sum_out(hderiv_sum_out);
+ CuMatrix<Real> dself_repair_sum_out(hself_repair_sum_out);
+
+ cu::CpuBackpropLstmNonlinearity(hinput, hparams, houtput_deriv,
+ hderiv_sum_in, hself_repair_config,
+ count_in, (MatrixBase<Real>*) NULL,
+ (MatrixBase<Real>*) NULL,
+ (MatrixBase<double>*) NULL,
+ (MatrixBase<double>*) NULL,
+ (MatrixBase<Real>*) NULL);
+ cu::BackpropLstmNonlinearity(dinput, dparams, doutput_deriv, dderiv_sum_in,
+ dself_repair_config, count_in,
+ (CuMatrixBase<Real>*) NULL,
+ (CuMatrixBase<Real>*) NULL,
+ (CuMatrixBase<double>*) NULL,
+ (CuMatrixBase<double>*) NULL,
+ (CuMatrixBase<Real>*) NULL);
+
+ cu::CpuBackpropLstmNonlinearity(hinput, hparams, houtput_deriv,
+ hderiv_sum_in, hself_repair_config,
+ count_in, (MatrixBase<Real>*) NULL,
+ &hparams_deriv, &hvalue_sum_out,
+ &hderiv_sum_out, &hself_repair_sum_out);
+ cu::BackpropLstmNonlinearity(dinput, dparams, doutput_deriv, dderiv_sum_in,
+ dself_repair_config, count_in,
+ (CuMatrixBase<Real>*) NULL, &dparams_deriv,
+ &dvalue_sum_out, &dderiv_sum_out,
+ &dself_repair_sum_out);
+
+ cu::CpuBackpropLstmNonlinearity(hinput, hparams, houtput_deriv,
+ hderiv_sum_in, hself_repair_config,
+ count_in, &hinput_deriv,
+ (MatrixBase<Real>*) NULL,
+ (MatrixBase<double>*) NULL,
+ (MatrixBase<double>*) NULL,
+ (MatrixBase<Real>*) NULL);
+ cu::BackpropLstmNonlinearity(dinput, dparams, doutput_deriv, dderiv_sum_in,
+ dself_repair_config, count_in, &dinput_deriv,
+ (CuMatrixBase<Real>*) NULL,
+ (CuMatrixBase<double>*) NULL,
+ (CuMatrixBase<double>*) NULL,
+ (CuMatrixBase<Real>*) NULL);
+
+ cu::CpuBackpropLstmNonlinearity(hinput, hparams, houtput_deriv,
+ hderiv_sum_in, hself_repair_config,
+ count_in, &hinput_deriv, &hparams_deriv,
+ &hvalue_sum_out, &hderiv_sum_out,
+ &hself_repair_sum_out);
+ cu::BackpropLstmNonlinearity(dinput, dparams, doutput_deriv, dderiv_sum_in,
+ dself_repair_config, count_in, &dinput_deriv,
+ &dparams_deriv, &dvalue_sum_out,
+ &dderiv_sum_out, &dself_repair_sum_out);
+
+ Matrix<Real> hdinput_deriv(dinput_deriv);
+ Matrix<Real> hdparams_deriv(dparams_deriv);
+ Matrix<double> hdvalue_sum_out(dvalue_sum_out);
+ Matrix<double> hdderiv_sum_out(dderiv_sum_out);
+ Matrix<Real> hdself_repair_sum_out(dself_repair_sum_out);
+
+// KALDI_LOG<< "input_deriv" << hinput_deriv << "d" << hdinput_deriv;
+// KALDI_LOG<< "hparams_deriv" << hparams_deriv << "d" << hdparams_deriv;
+// KALDI_LOG<< "hvalue_sum_out" << hvalue_sum_out << "d" << hdvalue_sum_out;
+// KALDI_LOG<< "hderiv_sum_out" << hderiv_sum_out << "d" << hdderiv_sum_out;
+// KALDI_LOG<< "hself_repair_sum_out" << hself_repair_sum_out << "d" << hdself_repair_sum_out;
+
+ AssertEqual(hinput_deriv, hdinput_deriv);
+ AssertEqual(hparams_deriv, hdparams_deriv);
+ AssertEqual(hvalue_sum_out, hdvalue_sum_out);
+ AssertEqual(hderiv_sum_out, hdderiv_sum_out);
+ AssertEqual(hself_repair_sum_out, hdself_repair_sum_out);
+ }
+
+ for (int i = 16; i <= 2048; i *= 2) {
+ BaseFloat time_in_secs = 0.025;
+ int32 num_rows = i;
+ int32 cell_dim = i;
+
+ CuMatrix<Real> input(num_rows, 5 * cell_dim);
+ CuMatrix<Real> params(3, cell_dim);
+ CuMatrix<Real> output_deriv(num_rows, 2 * cell_dim);
+ CuMatrix<double> deriv_sum_in(5, cell_dim);
+ CuVector<Real> self_repair_config(10);
+ double count_in;
+
+ CuMatrix<Real> input_deriv(num_rows, 5 * cell_dim);
+ CuMatrix<Real> params_deriv(3, cell_dim);
+ CuMatrix<double> value_sum_out(5, cell_dim);
+ CuMatrix<double> deriv_sum_out(5, cell_dim);
+ CuMatrix<Real> self_repair_sum_out(5, cell_dim);
+
+ input.SetRandn();
+ params.SetRandn();
+ output_deriv.SetRandn();
+ deriv_sum_in.SetRandn();
+ self_repair_config.SetRandn();
+ count_in = Rand() % num_rows;
+
+ Timer tim;
+ int32 iter = 0;
+ for (; tim.Elapsed() < time_in_secs; iter++)
+ cu::BackpropLstmNonlinearity(input, params, output_deriv, deriv_sum_in,
+ self_repair_config, count_in, &input_deriv,
+ ¶ms_deriv, &value_sum_out,
+ &deriv_sum_out, &self_repair_sum_out);
+
+
+ BaseFloat gflops = ((BaseFloat) i * i * iter) / (tim.Elapsed() * 1.0e+09);
+ KALDI_LOG << "For BackpropLstmNonlinearity"
+ << (sizeof(Real) == 8 ? "<double>" : "<float>") << ", for dim = "
+ << i << ", speed was " << gflops << " gigaflops";
+ }
+}
+
+
template<typename Real> void CudaMathUnitTest() {
#if HAVE_CUDA == 1
if (CuDevice::Instantiate().DoublePrecisionSupported())
UnitTestCuMathSplice<Real>();
UnitTestCuMathCopy<Real>();
UnitTestLstmNonlinearity();
+ UnitTestBackpropLstmNonlinearity<Real>();
}
} // namespace kaldi
index 806f4e309ab16d527d6bacb264ca258d143d2b7d..9d66720791427209f71ff03032f39ea5ba86b713 100644 (file)
}
}
+
+
template<typename Real>
void BackpropLstmNonlinearity(const CuMatrixBase<Real> &input,
const CuMatrixBase<Real> ¶ms,
#if HAVE_CUDA == 1
if (CuDevice::Instantiate().Enabled()) {
- KALDI_ERR << "CUDA version not implemented";
- // notes for Shiyin:
- // You could do an 'easy' initial version where we have have one thread per dimension,
- // and you can try optimizing this later on.
- // Since the cell-dim is usually quite large, like 1024, this is fairly reasonable.
- // But up to you.
+ Timer tim;
+ // Each thread block is working on 1 row of the data.
+ // It's best that cell dim is a multiple fo CU1DBLOCK
+
+
+ // Use 2D block (8x32 threads) as we need to compute column sum.
+ // Use 1D grid to cover the data matrix width `cell_dim`.
+ const int kWarpSize = 32;
+ dim3 dimBlock(kWarpSize, CU1DBLOCK / kWarpSize);
+// dim3 dimGrid(n_blocks(cell_dim, dimBlock.x),
+// n_blocks(num_rows, dimBlock.y));
+// if (dimGrid.x * dimGrid.y > 1024) {
+// dimGrid.y = std::max(1024 / dimGrid.x, 1);
+// }
+ dim3 dimGrid(n_blocks(cell_dim, dimBlock.x));
+ if (input_deriv == NULL) {
+ if (params_deriv == NULL) {
+ cuda_diff_lstm_nonlinearity(dimGrid, dimBlock, cell_dim, num_rows,
+ input.Data(), input.Stride(), params.Data(),
+ params.Stride(), output_deriv.Data(),
+ output_deriv.Stride(), deriv_sum_in.Data(),
+ deriv_sum_in.Stride(),
+ self_repair_config.Data(), count_in + 1,
+ NULL,
+ 0,
+ NULL,
+ 0,
+ NULL,
+ 0,
+ NULL,
+ 0,
+ NULL,
+ 0);
+
+ } else {
+ cuda_diff_lstm_nonlinearity(dimGrid, dimBlock, cell_dim, num_rows,
+ input.Data(), input.Stride(), params.Data(),
+ params.Stride(), output_deriv.Data(),
+ output_deriv.Stride(), deriv_sum_in.Data(),
+ deriv_sum_in.Stride(),
+ self_repair_config.Data(), count_in + 1,
+ NULL,
+ 0, params_deriv->Data(),
+ params_deriv->Stride(),
+ value_sum_out->Data(),
+ value_sum_out->Stride(),
+ deriv_sum_out->Data(),
+ deriv_sum_out->Stride(),
+ self_repair_sum_out->Data(),
+ self_repair_sum_out->Stride());
+ }
+ } else {
+ if (params_deriv == NULL) {
+ cuda_diff_lstm_nonlinearity(dimGrid, dimBlock, cell_dim, num_rows,
+ input.Data(), input.Stride(), params.Data(),
+ params.Stride(), output_deriv.Data(),
+ output_deriv.Stride(), deriv_sum_in.Data(),
+ deriv_sum_in.Stride(),
+ self_repair_config.Data(), count_in + 1,
+ input_deriv->Data(), input_deriv->Stride(),
+ NULL,
+ 0, NULL, 0, NULL, 0, NULL, 0);
+ } else {
+ cuda_diff_lstm_nonlinearity(dimGrid, dimBlock, cell_dim, num_rows,
+ input.Data(), input.Stride(), params.Data(),
+ params.Stride(), output_deriv.Data(),
+ output_deriv.Stride(), deriv_sum_in.Data(),
+ deriv_sum_in.Stride(),
+ self_repair_config.Data(), count_in + 1,
+ input_deriv->Data(), input_deriv->Stride(),
+ params_deriv->Data(),
+ params_deriv->Stride(),
+ value_sum_out->Data(),
+ value_sum_out->Stride(),
+ deriv_sum_out->Data(),
+ deriv_sum_out->Stride(),
+ self_repair_sum_out->Data(),
+ self_repair_sum_out->Stride());
+ }
+ }
+
+ CU_SAFE_CALL(cudaGetLastError());
+
+ CuDevice::Instantiate().AccuProfile(__func__, tim.Elapsed());
} else
#endif
{
index 33b2c4e6473976f707100c989073ddaef04aae71..c4b442e12b1a6506e93fada36473e1483c7ebcf5 100644 (file)
--- a/src/cudamatrix/cu-math.h
+++ b/src/cudamatrix/cu-math.h
A matrix, of dimension 3 by C, with rows containing the three
diagonal parameter matrices used in LSTMs, namely
w_{ic}, w_{fc} and w_{oc}.
- @param [out] output_deriv
+ @param [in] output_deriv
A matrix, of dimension N by 2C, containing the derivative of the
objective function we're backpropagating, w.r.t. the quantities
c_t and m_t (in two blocks of column-dimension C).
index da45791a06566bd7791fca9161251689966afe98..85c91c5eb0eb17276c2e7c77c79eeda74b206a6b 100644 (file)
virtual void InitFromConfig(ConfigLine *cfl);
- NaturalGradientPerElementScaleComponent() { } // use Init to really initialize.
+ LstmNonlinearityComponent() { } // use Init to really initialize.
virtual std::string Type() const {
- return "NaturalGradientPerElementScaleComponent";
+ return "LstmNonlinearityComponent";
}
virtual void Read(std::istream &is, bool binary);
virtual Component* Copy() const;
// Some functions that are specific to this class:
- explicit NaturalGradientPerElementScaleComponent(
- const NaturalGradientPerElementScaleComponent &other);
+ explicit LstmNonlinearityComponent(
+ const LstmNonlinearityComponent &other);
void Init(int32 dim, BaseFloat param_mean,
BaseFloat param_stddev, int32 rank, int32 update_period,