]> Gitweb @ Texas Instruments - Open Source Git Repositories - git.TI.com/gitweb - processor-sdk/kaldi.git/blob - src/nnet3/attention.h
[src,scripts,egs] Attention modeling, with example scripts (#1731)
[processor-sdk/kaldi.git] / src / nnet3 / attention.h
1 // nnet3/attention.h
3 // Copyright      2017  Johns Hopkins University (author: Daniel Povey)
4 //                      Hossein Hadian
6 // See ../../COPYING for clarification regarding multiple authors
7 //
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 //
12 //  http://www.apache.org/licenses/LICENSE-2.0
13 //
14 // THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 // KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
16 // WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
17 // MERCHANTABLITY OR NON-INFRINGEMENT.
18 // See the Apache 2 License for the specific language governing permissions and
19 // limitations under the License.
21 #ifndef KALDI_NNET3_ATTENTION_H_
22 #define KALDI_NNET3_ATTENTION_H_
24 #include "base/kaldi-common.h"
25 #include "util/common-utils.h"
26 #include "itf/options-itf.h"
27 #include "matrix/matrix-lib.h"
28 #include "cudamatrix/cu-matrix-lib.h"
29 #include "nnet3/nnet-common.h"
30 #include "nnet3/convolution.h"
32 #include <iostream>
34 namespace kaldi {
35 namespace nnet3 {
36 namespace attention {
38 /// @file  attention.h
39 ///
40 /// This file contains the lower-level interface for self-attention.
41 /// This is a form of self-attention, inspired by Google's paper
42 /// "Attention is all you need", but implemented in a way that's more
43 /// obviously suitable for speech tasks.  The main difference is that
44 /// instead of taking as input *all frames* from the previous layer,
45 /// we accept a limited grid of frames (so the left-context and
46 /// right-context are finite).  Also time-encoding is handled in a different
47 /// way-- we encode the time as a relative offset.
51 // Our attention is "multi-head", like in Google's paper.  Note: we're basically
52 // implementing multi-head attention as a fixed nonlinearity, with the actual
53 // parameters relegated to the previous layer.  That is, the attention layer
54 // won't have any parameters of its own, but the parameters of the preceding
55 // layer will be interpretable as the parameters.  It doesn't change what's
56 // computed, it just affects how the neural net is divided into components.
57 //
58 //  * Basic restricted self-attention (without positional encoding).
59 //
60 // To explain what's going on, we start with the simplest form of attention:
61 // single-head, and no positional encoding, but with restricted context.  For purposes
62 // of exposition we assume that the time offsets we need form a contigous
63 // range, i.e. with time-stride == 1; the code does have the notion of a stride (you'll
64 // see later).
65 //
66 // Using notation similar to the Google paper, suppose we have a time-sequence
67 // of inputs, and the inputs are (keys, values and queries):
68 //
69 //   k_t, v_t, q_t
70 //
71 // where k_t and q_t are vectors of dimension 'key_dim' and v_t is a vector
72 // of dimension 'value_dim' (you may choose to make this the same as key_dim, but
73 // that isn't a constraint).
75 // Let's make num_left_inputs and num_right_inputs be the number of
76 // left-context and right-context frames required, and for some t,
77 // let input_indexes(t) be the set
78 //  [ t - num_left_inputs, t - num_left_inputs + 1, ... t + num_right_inputs].
79 // To evaluate the output (which we'll write u_t), we need the query
80 // value q_t, plus the keys and values k_s and v_s for all s in input_indexes(t).
81 // If the inputs are not available for some subset of input_indexes(t),
82 // we just let them be zeros; the network can learn to ignore them if it wants,
83 // but making them zeros is simpler to implement.
84 //
85 //
86 // Anyway, the output u_t (without positional encoding yet) is:
87 //
88 //  u_t := \sum_{s in input_indexes(t)}  Z_t exp(q_t . k_s) v_s
89 //
90 // where Z_t is 1/(\sum_s exp(q_t . k_s)).  We'll handle scaling
91 // issues (the 1/sqrt(dim) factor in the Google paper) later on,
92 // by scaling the keys.
93 //
94 //
95 // * Positional encoding
96 // We now explain how we include positional encoding in the model.
97 //
98 //
99 // Let context_dim = 1 + num_left_inputs + num_right_inputs.
100 // Let v be a vector, and let the function Extend(v, o) (where
101 // 0 <= o < context_dim) extend v with extra dimensions
102 // that encode the time-offset.  To be precise, we have
103 //
104 //  Extend(v, o) = Append(v, u_o)
105 //
106 // where u_o is a unit vector of dimension context_dim that is nonzero in the
107 // o'th dimension (assuming zero-based indexing).
108 //
109 // So when we add the positional encoding (and the scale on the keys), we replace
110 // the equation:
111 //  u_t := \sum_{s in input_indexes(t)}  Z_t exp(q_t . k_s) v_s
112 // with:
113 //  u_t := \sum_{s in input_indexes(t)}  Z_t exp(alpha q_t . Extend(key-scale * k_s, s - t + num_left_inputs)) Extend(v_s, s - t + num_left_inputs)
114 //
115 // (we won't actually physically extend the vectors as we compute this,
116 // we'll do it a different way, but it's equivalent to what we wrote
117 // above. The dimension of q_t is key_dim + context_dim, and the dimension
118 // of the output u_t is value_dim + context_dim.
119 //
120 //
121 // * Multi-head attention
122 //
123 // The attention component if we had a single head, would have an input dimension
124 // of (2*key_dim + context_dim + value_dim), which would be interpreted
125 // as Append(k_t, q_t, v_t), of dimensions respectively
126 // (key_dim, key_dim + context_dim, value_dim).  It would have an output
127 // dimension of value_dim + context_dim.
128 //
129 // In any case, the multi-head version has input and output dimension that
130 // are larger by a factor of 'num_heads', and which is equivalent to
131 // several single-head components appended together.
132 //
133 //
134 //
135 //  * The actual calculation
136 //
137 // Let's assume that we might have multiple independent sequences; we'll
138 // call this 'num_images' because we're borrowing certain structures from
139 // the convolution code.
141 // The input is formatted as a matrix, whose NumRows() could be written as
142 // num_images * num_t_in, where num_t_in is the number of distinct input 't'
143 // values, and whose output is num_images * num_t_out.  To keep it simple we'll
144 // explain this under the assumption that we don't have any 't' stride in the
145 // attention (t_stride == 1 in the code), and that num_heads == 1; both of
146 // those are fairly simple modifications to the basic scheme.
147 // The image (normally 'n') index has a higher stride than the 't' index in
148 // both the input and the output.  We assume that there is 'enough'
149 // context of the input to compute all required offsets of the output.
150 //
151 // Define the intermediate quantity b_{t,o}, which you can think of
152 // as the input to softmax; the index 't' is the output time-index
153 // index at the output, and o ranges from 0 to context_dim - 1.
154 //
155 //    b_{t,o} =  q_t . Extend(key-scale * k_{t + o - num_left_inputs}, o)
156 //
157 // To get rid of the Extend() expressions, define sub-ranges of q_t by
158 // writing q_t = Append(r_t, s_t) where r_t is of dimension 'key_dim'
159 // and s_t is of dimension context_dim.
160 //
161 //    b_{t,o} =   s_{t,o}  +  key-scale (r_t . k_{t+o-num_left_inputs})  [eqn:b]
162 //
163 // The 'b' quantity is the input to the softmax.  Define
164 //     c_t = Softmax(b_t)
165 // so \sum_o c_{t,o} = 1.0.  These are the weights on the values.
166 //
167 //
168 //  The output can be written as:
169 //        u_t :=  \sum_o c_{t,o} Extend(v_{t+o-num_left_inputs}, o)
170 //  but we can write this in a form more suitable for computation as:
171 //     u_t :=  Append(\sum_o c_{t,o} v_{t+o-num_left_inputs},  c_t)     [eqn:u]
172 //
173 //
174 //  * Implementation
175 //
176 // The most time-consuming parts of this computation, we imagine, would be the
177 // dot-products in [eqn:b] and the weighted sum (\sum_o) in [eqn:u].  They have
178 // an awkward band-diagonal structure that would not be particularly convenient
179 // to implement using CUBLAS and the like; I don't believe the relevant operations
180 // exist in the BLAS interface, at least for [eqn:u].
181 //
182 // In the future I hope to implement this with block-wise matrix
183 // multiplies-- imagine covering the band-diagonal part of a matrix with
184 // rectangular blocks in such a way that all the nonzero elements are covered,
185 // but the blocks might go over the zero parts a bit.   This could be done with
186 // Or perhaps we can figure out how to implement the block-diagonal matrix
187 // multiplies in CUDA.
191 /**
192    This function is a utility function that is at the core of how we implement
193    attention.  It may in future need to be renamed and possibly moved into the
194    cudamatrix directory and implemented in CUDA.  The current implementation is
195    quite inefficient.  We can also consider doing a complete redesign of how the
196    implementation works, such that this function doesn't exist at all; or we
197    could have a batched version of this function that would operate on a batch
198    of A, B and C at once (or a "strided, batched" version where the difference
199    between the members of the batch is expressed as a stride).
201    This function implements a special operation that you could view as some kind
202    of matrix multiplication where only a band of the product is retained.
204    The inputs A and B must have the same number of columns
205    (A.NumCols() == B.NumCols()), and A and C must have the same
206    number of rows (A.NumRows() == C->NumRows()).  The number of
207    rows of B must exceed the number of rows of A.  Define
208       num_extra_rows = B.NumRows() - A.NumRows().
209    Then C.NumCols() - 1 must divide num_extra_rows.
210    Define
211       row_shift = num_extra_rows / (C.NumCols() - 1).
213    This function implements:
214       (*C)(i, j) = alpha * VecVec(A.Row(i), B.Row(i + j * row_shift))
215  */
216 void GetAttentionDotProducts(BaseFloat alpha,
217                              const CuMatrixBase<BaseFloat> &A,
218                              const CuMatrixBase<BaseFloat> &B,
219                              CuMatrixBase<BaseFloat> *C);
222 /**
223    This function is related to GetAttentionDotProducts(); it
224    is used in scaling the values by the softmax scales, and
225    in backprop.
227    We have put the A, B and C in an unusual order here in order
228    to make clearer the relationship with GetAttentionDotProducts().
229    The matrices have the same relationship in terms of their
230    dimensions, as A, B and C do in GetAttentionDotProducts().
232    This function implements:
234      A->Row(i) += \sum_j alpha * C(i, j) * B.Row(i + j * row_shift).
235  */
236 void ApplyScalesToOutput(BaseFloat alpha,
237                          const CuMatrixBase<BaseFloat> &B,
238                          const CuMatrixBase<BaseFloat> &C,
239                          CuMatrixBase<BaseFloat> *A);
242 /**
243    This function is related to GetAttentionDotProducts(); it
244    is used in backprop.
246    We have put the A, B and C in an unusual order here in order
247    to make clearer the relationship with GetAttentionDotProducts().
248    The matrices have the same relationship in terms of their
249    dimensions, as A, B and C do in GetAttentionDotProducts().
251    This function implements:
253      B->Row(i + j * row_shift) += alpha * C(i, j) * A.Row(i).
254  */
255 void ApplyScalesToInput(BaseFloat alpha,
256                         const CuMatrixBase<BaseFloat> &A,
257                         const CuMatrixBase<BaseFloat> &C,
258                         CuMatrixBase<BaseFloat> *B);
262 /**
263    This is a higher-level interface to the attention code.
264    Read the extended comment in the file nnet3/attention.h for context.
266      @param [in] key_scale   Scale on the non-context part of the keys.
267      @param [in] keys       Matrix whose rows contains the keys, dimension is
268                             num-input-rows by key-dim.
269      @param [in] queries    Matrix whose rows contains the queries, dimension
270                             is num-output-rows by query-dim, where query-dim
271                             == key-dim + context-dim.
272                             num-output-rows - num-input-rows must be a multiple
273                             of context-dim - 1 (we'll 'shift' the keys by multiples
274                             of 0, n, 2n, ... (context-dim - 1) * n.
275      @param [in] values     Values to average at the output, of dimension
276                             num-input-rows by value-dim.  [we may add context
277                             information to these averages if required, see comment
278                             for 'output'].
279      @param [out] c         Expected to be finite at entry (no infs or nan's);
280                             at exit this will contain the output of the softmax.
281                             Must be of dimension num-output-rows by context-dim.
282      @param [out] output    The output of the attention mechanism will be *added*
283                             to this location.  Dimension must be num-output-rows
284                             by either value-dim, or value-dim + context-dim.  To
285                             the first 'value-dim' columns of this will be added
286                             the weighted combination of 'values', weighted by
287                             the corresponding weights of 'c' (e.g. the first
288                             column of 'c' scaling the first 'output-dim' rows of
289                             'values', then the next column of 'c' scaling the
290                             submatrix of 'values' shifted by 'n', and so on.
291                             If the output->NumCols() is value-dim + context-dim,
292                             'c' will be added to the remaining columns of
293                             'output'.
294  */
295 void AttentionForward(BaseFloat key_scale,
296                       const CuMatrixBase<BaseFloat> &keys,
297                       const CuMatrixBase<BaseFloat> &queries,
298                       const CuMatrixBase<BaseFloat> &values,
299                       CuMatrixBase<BaseFloat> *c,
300                       CuMatrixBase<BaseFloat> *output);
302 /** Performs the backward pass corresponding to 'AttentionForward',
303     propagating the derivative back to the keys, queries and values.
305     The interface should be easy to understand with reference
306     to AttentionForward(), so we won't document it, except to note
307     that 'keys_deriv', 'queries_deriv' and 'values_deriv' are
308     *added to*, not set, by this function.
309  */
310 void AttentionBackward(BaseFloat key_scale,
311                        const CuMatrixBase<BaseFloat> &keys,
312                        const CuMatrixBase<BaseFloat> &queries,
313                        const CuMatrixBase<BaseFloat> &values,
314                        const CuMatrixBase<BaseFloat> &c,
315                        const CuMatrixBase<BaseFloat> &output_deriv,
316                        CuMatrixBase<BaseFloat> *keys_deriv,
317                        CuMatrixBase<BaseFloat> *queries_deriv,
318                        CuMatrixBase<BaseFloat> *values_deriv);
325 } // namespace attention
326 } // namespace nnet3
327 } // namespace kaldi
330 #endif