// nnet3/attention.h

// Copyright      2017  Johns Hopkins University (author: Daniel Povey)
//                      Hossein Hadian

// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//  http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABILITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#ifndef KALDI_NNET3_ATTENTION_H_
#define KALDI_NNET3_ATTENTION_H_

#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "itf/options-itf.h"
#include "matrix/matrix-lib.h"
#include "cudamatrix/cu-matrix-lib.h"
#include "nnet3/nnet-common.h"
#include "nnet3/convolution.h"

#include <iostream>

namespace kaldi {
namespace nnet3 {
namespace attention {

/// @file  attention.h
///
/// This file contains the lower-level interface for self-attention.
/// This is a form of self-attention, inspired by Google's paper
/// "Attention is all you need", but implemented in a way that's more
/// obviously suitable for speech tasks. The main difference is that
/// instead of taking as input *all frames* from the previous layer,
/// we accept a limited grid of frames (so the left-context and
/// right-context are finite). Also time-encoding is handled in a different
/// way-- we encode the time as a relative offset.


// Our attention is "multi-head", like in Google's paper. Note: we're basically
// implementing multi-head attention as a fixed nonlinearity, with the actual
// parameters relegated to the previous layer. That is, the attention layer
// won't have any parameters of its own, but the parameters of the preceding
// layer will be interpretable as the parameters. It doesn't change what's
// computed, it just affects how the neural net is divided into components.
//
// * Basic restricted self-attention (without positional encoding).
//
// To explain what's going on, we start with the simplest form of attention:
// single-head, and no positional encoding, but with restricted context. For purposes
// of exposition we assume that the time offsets we need form a contiguous
// range, i.e. with time-stride == 1; the code does have the notion of a stride (you'll
// see later).
//
// Using notation similar to the Google paper, suppose we have a time-sequence
// of inputs, and the inputs are (keys, values and queries):
//
//     k_t, v_t, q_t
//
// where k_t and q_t are vectors of dimension 'key_dim' and v_t is a vector
// of dimension 'value_dim' (you may choose to make this the same as key_dim, but
// that isn't a constraint).

// Let's make num_left_inputs and num_right_inputs be the number of
// left-context and right-context frames required, and for some t,
// let input_indexes(t) be the set
//   [ t - num_left_inputs, t - num_left_inputs + 1, ... t + num_right_inputs ].
// To evaluate the output (which we'll write u_t), we need the query
// value q_t, plus the keys and values k_s and v_s for all s in input_indexes(t).
// If the inputs are not available for some subset of input_indexes(t),
// we just let them be zeros; the network can learn to ignore them if it wants,
// but making them zeros is simpler to implement.
//
//
// Anyway, the output u_t (without positional encoding yet) is:
//
//   u_t := \sum_{s in input_indexes(t)} Z_t exp(q_t . k_s) v_s
//
// where Z_t is 1/(\sum_s exp(q_t . k_s)). We'll handle scaling
// issues (the 1/sqrt(dim) factor in the Google paper) later on,
// by scaling the keys.
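//
// As a small worked example (the numbers are chosen purely for illustration):
// with num_left_inputs = 2 and num_right_inputs = 1, input_indexes(t) is
// [t-2, t-1, t, t+1], and u_t is a convex combination of v_{t-2} .. v_{t+1}
// whose weights are the softmax of the four dot products
// q_t . k_{t-2}, ..., q_t . k_{t+1}.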
//
//
// * Positional encoding
// We now explain how we include positional encoding in the model.
//
//
// Let context_dim = 1 + num_left_inputs + num_right_inputs.
// Let v be a vector, and let the function Extend(v, o) (where
// 0 <= o < context_dim) extend v with extra dimensions
// that encode the time-offset. To be precise, we have
//
//   Extend(v, o) = Append(v, u_o)
//
// where u_o is a unit vector of dimension context_dim that is nonzero in the
// o'th dimension (assuming zero-based indexing).
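//
// For example (purely for illustration): if context_dim = 3 and o = 1, then
// Extend(v, 1) = Append(v, [0 1 0]), i.e. we append a one-hot vector that
// marks the middle position of the context window.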
//
// So when we add the positional encoding (and the scale on the keys), we replace
// the equation:
//   u_t := \sum_{s in input_indexes(t)} Z_t exp(q_t . k_s) v_s
// with:
//   u_t := \sum_{s in input_indexes(t)} Z_t exp(q_t . Extend(key-scale * k_s, s - t + num_left_inputs)) Extend(v_s, s - t + num_left_inputs)
//
// (We won't actually physically extend the vectors as we compute this;
// we'll do it a different way, but it's equivalent to what we wrote
// above.) The dimension of q_t is key_dim + context_dim, and the dimension
// of the output u_t is value_dim + context_dim.
//
//
// * Multi-head attention
//
// The attention component, if we had a single head, would have an input dimension
// of (2*key_dim + context_dim + value_dim), which would be interpreted
// as Append(k_t, q_t, v_t), of dimensions respectively
// (key_dim, key_dim + context_dim, value_dim). It would have an output
// dimension of value_dim + context_dim.
//
// In any case, the multi-head version has input and output dimensions that are
// larger by a factor of 'num_heads'; it is equivalent to several single-head
// components appended together.
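//
// As a concrete example (the numbers are arbitrary, purely for illustration):
// with key_dim = 40, value_dim = 60, num_left_inputs = 5 and num_right_inputs = 2,
// we get context_dim = 8, so a single head would have input dimension
// 2*40 + 8 + 60 = 148 and output dimension 60 + 8 = 68; with num_heads = 4 the
// input and output dimensions become 4 * 148 = 592 and 4 * 68 = 272.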
//
//
//
// * The actual calculation
//
// Let's assume that we might have multiple independent sequences; we'll
// call this 'num_images' because we're borrowing certain structures from
// the convolution code.
//
// The input is formatted as a matrix whose NumRows() can be written as
// num_images * num_t_in, where num_t_in is the number of distinct input 't'
// values; the output has num_images * num_t_out rows, where num_t_out is the
// number of distinct output 't' values. To keep it simple we'll explain this
// under the assumption that we don't have any 't' stride in the attention
// (t_stride == 1 in the code), and that num_heads == 1; both of those are
// fairly simple modifications to the basic scheme.
// The image (normally 'n') index has a higher stride than the 't' index in
// both the input and the output. We assume that there is 'enough' context
// in the input to compute all required offsets of the output.
//
// Define the intermediate quantity b_{t,o}, which you can think of
// as the input to the softmax; the index 't' is the time index at the
// output, and o ranges from 0 to context_dim - 1.
//
//   b_{t,o} = q_t . Extend(key-scale * k_{t + o - num_left_inputs}, o)
//
// To get rid of the Extend() expressions, define sub-ranges of q_t by
// writing q_t = Append(r_t, s_t), where r_t is of dimension 'key_dim'
// and s_t is of dimension context_dim. Then:
//
//   b_{t,o} = s_{t,o} + key-scale * (r_t . k_{t+o-num_left_inputs})    [eqn:b]
//
// The 'b' quantity is the input to the softmax. Define
//   c_t = Softmax(b_t)
// so \sum_o c_{t,o} = 1.0. These are the weights on the values.
//
//
// The output can be written as:
//   u_t := \sum_o c_{t,o} Extend(v_{t+o-num_left_inputs}, o)
// but we can write this in a form more suitable for computation as:
//   u_t := Append(\sum_o c_{t,o} v_{t+o-num_left_inputs}, c_t)         [eqn:u]
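//
// Purely for illustration, the per-frame computation above can be restated as
// the following loop-form sketch (single head, t_stride == 1; 'Dot' is a
// hypothetical vector dot-product, and b, c, r_t, s_t, k_t, v_t are as defined
// above; this is pseudocode for exposition, not part of the interface):
//
//   for (int32 t = 0; t < num_t_out; t++) {
//     for (int32 o = 0; o < context_dim; o++)
//       b(t, o) = s_t(o) + key_scale * Dot(r_t, k_{t + o - num_left_inputs});
//     c_t = Softmax(b_t);              // the weights c(t, o) sum to 1 over o.
//     weighted_sum = \sum_o c(t, o) * v_{t + o - num_left_inputs};
//     u_t = Append(weighted_sum, c_t); // [eqn:u]
//   }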
//
//
// * Implementation
//
// The most time-consuming parts of this computation, we imagine, would be the
// dot-products in [eqn:b] and the weighted sum (\sum_o) in [eqn:u]. They have
// an awkward band-diagonal structure that would not be particularly convenient
// to implement using CUBLAS and the like; I don't believe the relevant operations
// exist in the BLAS interface, at least for [eqn:u].
//
// In the future I hope to implement this with block-wise matrix
// multiplies-- imagine covering the band-diagonal part of a matrix with
// rectangular blocks in such a way that all the nonzero elements are covered,
// but the blocks might go over the zero parts a bit. This could be done with
// ordinary dense matrix multiplies on those blocks; or perhaps we can figure
// out how to implement the block-diagonal matrix multiplies in CUDA.


/**
   This function is a utility function that is at the core of how we implement
   attention. It may in future need to be renamed and possibly moved into the
   cudamatrix directory and implemented in CUDA. The current implementation is
   quite inefficient. We can also consider doing a complete redesign of how the
   implementation works, such that this function doesn't exist at all; or we
   could have a batched version of this function that would operate on a batch
   of A, B and C at once (or a "strided, batched" version where the difference
   between the members of the batch is expressed as a stride).

   This function implements a special operation that you could view as some kind
   of matrix multiplication where only a band of the product is retained.

   The inputs A and B must have the same number of columns
   (A.NumCols() == B.NumCols()), and A and C must have the same
   number of rows (A.NumRows() == C->NumRows()). The number of
   rows of B must exceed the number of rows of A. Define
     num_extra_rows = B.NumRows() - A.NumRows().
   Then C->NumCols() - 1 must divide num_extra_rows. Define
     row_shift = num_extra_rows / (C->NumCols() - 1).
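
   For example (the numbers are purely illustrative): if A.NumRows() == 100,
   B.NumRows() == 106 and C->NumCols() == 4, then num_extra_rows = 6 and
   row_shift = 2.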

   This function implements:
     (*C)(i, j) = alpha * VecVec(A.Row(i), B.Row(i + j * row_shift))
*/
void GetAttentionDotProducts(BaseFloat alpha,
                             const CuMatrixBase<BaseFloat> &A,
                             const CuMatrixBase<BaseFloat> &B,
                             CuMatrixBase<BaseFloat> *C);
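
// Purely for illustration, the operation documented above could be restated
// as the following naive loop (a sketch only, not the actual implementation,
// and with no claim about efficiency):
//
//   int32 row_shift = (B.NumRows() - A.NumRows()) / (C->NumCols() - 1);
//   for (int32 i = 0; i < C->NumRows(); i++)
//     for (int32 j = 0; j < C->NumCols(); j++)
//       (*C)(i, j) = alpha * VecVec(A.Row(i), B.Row(i + j * row_shift));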


/**
   This function is related to GetAttentionDotProducts(); it
   is used in scaling the values by the softmax scales, and
   in backprop.

   We have put the A, B and C in an unusual order here in order
   to make clearer the relationship with GetAttentionDotProducts().
   The matrices have the same relationship in terms of their
   dimensions as A, B and C do in GetAttentionDotProducts().

   This function implements:

     A->Row(i) += \sum_j alpha * C(i, j) * B.Row(i + j * row_shift).
*/
void ApplyScalesToOutput(BaseFloat alpha,
                         const CuMatrixBase<BaseFloat> &B,
                         const CuMatrixBase<BaseFloat> &C,
                         CuMatrixBase<BaseFloat> *A);
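
// For illustration only, a naive loop-form sketch of the operation above (not
// the actual implementation), with row_shift defined as in
// GetAttentionDotProducts(), i.e. (B.NumRows() - A->NumRows()) / (C.NumCols() - 1):
//
//   for (int32 i = 0; i < A->NumRows(); i++)
//     for (int32 j = 0; j < C.NumCols(); j++)
//       A->Row(i).AddVec(alpha * C(i, j), B.Row(i + j * row_shift));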


/**
   This function is related to GetAttentionDotProducts(); it
   is used in backprop.

   We have put the A, B and C in an unusual order here in order
   to make clearer the relationship with GetAttentionDotProducts().
   The matrices have the same relationship in terms of their
   dimensions as A, B and C do in GetAttentionDotProducts().

   This function implements:

     B->Row(i + j * row_shift) += alpha * C(i, j) * A.Row(i).
*/
void ApplyScalesToInput(BaseFloat alpha,
                        const CuMatrixBase<BaseFloat> &A,
                        const CuMatrixBase<BaseFloat> &C,
                        CuMatrixBase<BaseFloat> *B);


/**
   This is a higher-level interface to the attention code.
   Read the extended comment at the top of this file (nnet3/attention.h)
   for context.

   @param [in] key_scale  Scale on the non-context part of the keys.
   @param [in] keys       Matrix whose rows contain the keys; its dimension is
                          num-input-rows by key-dim.
   @param [in] queries    Matrix whose rows contain the queries; its dimension
                          is num-output-rows by query-dim, where query-dim
                          == key-dim + context-dim.
                          num-input-rows - num-output-rows must be a multiple
                          of context-dim - 1 (we'll 'shift' the keys by
                          multiples of 0, n, 2n, ... (context-dim - 1) * n).
   @param [in] values     Values to average at the output, of dimension
                          num-input-rows by value-dim. [We may add context
                          information to these averages if required; see the
                          comment for 'output'.]
   @param [out] c         Expected to be finite at entry (no infs or nan's);
                          at exit this will contain the output of the softmax.
                          Must be of dimension num-output-rows by context-dim.
   @param [out] output    The output of the attention mechanism will be *added*
                          to this location. Dimension must be num-output-rows
                          by either value-dim, or value-dim + context-dim. To
                          the first 'value-dim' columns of this will be added
                          the weighted combination of 'values', weighted by
                          the corresponding weights of 'c' (e.g. the first
                          column of 'c' scaling the first num-output-rows rows
                          of 'values', then the next column of 'c' scaling the
                          submatrix of 'values' shifted by 'n', and so on).
                          If output->NumCols() is value-dim + context-dim,
                          'c' will be added to the remaining columns of
                          'output'.
*/
void AttentionForward(BaseFloat key_scale,
                      const CuMatrixBase<BaseFloat> &keys,
                      const CuMatrixBase<BaseFloat> &queries,
                      const CuMatrixBase<BaseFloat> &values,
                      CuMatrixBase<BaseFloat> *c,
                      CuMatrixBase<BaseFloat> *output);
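
// As a concrete example of the dimensions involved (the numbers are
// hypothetical, purely for illustration): with key-dim = 40, value-dim = 60,
// context-dim = 8 and a row-shift n = 2, 'queries' might be 500 x 48; 'keys'
// and 'values' would then have 500 + (8 - 1) * 2 = 514 rows (of widths 40 and
// 60 respectively), 'c' would be 500 x 8, and 'output' 500 x 68 (or 500 x 60
// if the context weights are not appended).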

/** Performs the backward pass corresponding to 'AttentionForward',
    propagating the derivative back to the keys, queries and values.

    The interface should be easy to understand with reference
    to AttentionForward(), so we won't document it, except to note
    that 'keys_deriv', 'queries_deriv' and 'values_deriv' are
    *added to*, not set, by this function.
*/
void AttentionBackward(BaseFloat key_scale,
                       const CuMatrixBase<BaseFloat> &keys,
                       const CuMatrixBase<BaseFloat> &queries,
                       const CuMatrixBase<BaseFloat> &values,
                       const CuMatrixBase<BaseFloat> &c,
                       const CuMatrixBase<BaseFloat> &output_deriv,
                       CuMatrixBase<BaseFloat> *keys_deriv,
                       CuMatrixBase<BaseFloat> *queries_deriv,
                       CuMatrixBase<BaseFloat> *values_deriv);


} // namespace attention
} // namespace nnet3
} // namespace kaldi


#endif  // KALDI_NNET3_ATTENTION_H_