CoCalc provides the best real-time collaborative environment for Jupyter Notebooks, LaTeX documents, and SageMath, scalable from individual users to large groups and classes!
Path: blob/main/transformer_implementation.ipynb
Implementing Transformers
This notebook will walk you through the internals of implementing self attention and transformer networks. As with recurrent networks (and unlike convolutions), there is actually relatively little that is fundamentally new in their implementation, as it all largely involves an application of existing primitives you will have already implemented in your autodiff framework. However, there is indeed one aspect of an efficient implementation that requires a slight generalization of an item we have discussed already: a batch version of matrix multiplication. This is required for both the minibatch version of attention as well as the common "multihead" version. We will also briefly discuss some approaches to making Transformers more efficient.
Implementing self-attention
Let's begin with a simple implementation of self-attention. This essentially just implements the basic equation
$$Y = \mathrm{softmax}\left(\frac{K Q^T}{\sqrt{d}}\right) V$$
By convention, however, it's typical to implement self-attention in terms of the actual input $X$ rather than the $K$, $Q$, and $V$ values themselves (i.e., taking the projection weights $W_K$, $W_Q$, $W_V$ as inputs, instead of having the linear layers applied separately). It's also common to have an output weight $W_{out}$ as well (even though this could in theory be folded into the $W_V$ term), which applies an additional linear layer to the output of the entire operation. I.e., the full operation is given by
$$Y = \mathrm{softmax}\left(\frac{X W_K W_Q^T X^T}{\sqrt{d}}\right) X W_V W_{out}.$$
It's possible to also incorporate bias terms into each of these projections, though we won't bother with this, as it is less common for everything but the output weight, and largely just adds complexity.
Let's see what this implementation looks like.
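The notebook's code cells aren't reproduced here, but a minimal numpy sketch of this layer might look like the following (the names `softmax`, `self_attention`, `W_KQV`, and `W_out` are our own choices for illustration, not a fixed API; note the common trick of stacking the $W_K$, $W_Q$, $W_V$ projections into a single matrix):

```python
import numpy as np

def softmax(Z):
    # subtract the per-row max for numerical stability
    Z = np.exp(Z - Z.max(axis=-1, keepdims=True))
    return Z / Z.sum(axis=-1, keepdims=True)

def self_attention(X, W_KQV, W_out):
    # project the input into K, Q, V with one stacked weight matrix,
    # then split along the last axis
    K, Q, V = np.split(X @ W_KQV, 3, axis=-1)
    attn = softmax(K @ Q.swapaxes(-1, -2) / np.sqrt(X.shape[-1]))
    return attn @ V @ W_out, attn

T, d = 100, 64
X = np.random.randn(T, d)
W_KQV = np.random.randn(d, 3 * d) / np.sqrt(d)
W_out = np.random.randn(d, d) / np.sqrt(d)
Y, attn = self_attention(X, W_KQV, W_out)
print(Y.shape, attn.shape)  # (100, 64) (100, 100)
```

Using `swapaxes(-1, -2)` rather than `.T` for the transpose will pay off shortly, when the same code needs to handle batched inputs.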
We can compare this to PyTorch's self-attention implementation, the nn.MultiheadAttention layer (we'll cover what we mean by "multi-head" shortly). Note that, mainly just to be consistent with the RNN implementation and other sequence models, the nn.MultiheadAttention layer by default takes inputs in $(T, N, d)$ form (i.e., with the batch dimension second). But unlike for RNNs, this ordering doesn't make much sense for self-attention and Transformers: we will be computing the operation "in parallel" over all time points, instead of sequentially as for RNNs. So we'll use the batch_first=True flag to make this a more natural dimension ordering for the inputs.
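A short sketch of what that comparison call might look like (the tensor shapes here are our own example values; with `batch_first=True` the layer expects $(N, T, d)$ inputs, and passing the same tensor as query, key, and value gives self-attention):

```python
import torch
import torch.nn as nn

T, d = 100, 64
# batch_first=True: inputs are (N, T, d) rather than the default (T, N, d)
attn = nn.MultiheadAttention(d, num_heads=1, bias=False, batch_first=True)
X = torch.randn(1, T, d)
# self-attention: query, key, and value are all the same input
Y, A = attn(X, X, X)
print(Y.shape, A.shape)  # torch.Size([1, 100, 64]) torch.Size([1, 100, 100])
```

With matching weights (PyTorch stores the stacked projection as `in_proj_weight` and the output projection as `out_proj.weight`), the outputs should agree with a hand-rolled implementation up to the transpose conventions involved.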
Minibatching with batch matrix multiply
Once we move from a single example to minibatches, there is one additional subtlety that comes into play for self-attention. Recall that for each sample in the minibatch, we have to compute a matrix product, e.g., the $K Q^T$ term. If we need to process $N$ examples in a minibatch, we will need to perform this matrix multiplication correspondingly for each sample. This is an operation known as a batch matrix multiply.
It may seem as though nothing is new here. True, for an MLP it was possible to perform the entire batch equation as a single matrix multiplication, but didn't we similarly need to batch matrix multiplications for convolutional networks (after the im2col function)? Or for RNNs?
The answer is actually that no, prior to this we haven't needed true batch matrix multiplication functionality. The situations we had before involved the multiplication of a "batched" tensor by a single weight matrix. I.e., in a ConvNet, we had something like
$$y = x W$$
for a single sample, or in the batched setting
$$y^{(i)} = x^{(i)} W, \quad i = 1, \ldots, N.$$
But this operation can be accomplished with "normal" matrix multiplication by just stacking the multiple samples into the matrix on the left:
$$\begin{bmatrix} y^{(1)} \\ \vdots \\ y^{(N)} \end{bmatrix} = \begin{bmatrix} x^{(1)} \\ \vdots \\ x^{(N)} \end{bmatrix} W.$$
This is just a normal matrix multiplication, so it can be implemented, e.g., using your framework so far, where matrix multiplication always operates on 2-dimensional NDArrays.
Fortunately, numpy's @
operator already performs batch matrix multiplication for the case of arrays with (the same) number of dimensions greater than 2.
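A quick demonstration of this behavior (the shapes here are arbitrary example values): numpy treats all leading dimensions as batch dimensions and multiplies the trailing two as matrices.

```python
import numpy as np

# numpy's @ broadcasts matrix multiplication over leading "batch" dims:
# multiplying (N, T, d) by (N, d, T) performs N separate (T, d) @ (d, T)
# products, one per sample
A = np.random.randn(8, 100, 64)
B = np.random.randn(8, 64, 100)
C = A @ B
print(C.shape)  # (8, 100, 100)

# each batch slice agrees with the corresponding ordinary 2D matmul
assert np.allclose(C[3], A[3] @ B[3])
```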
Let's see how this works with our self-attention layer. In fact, because of the judicious usage of axis=-1
and similar arguments, our layer works exactly the same as it did before.
Multihead attention
Practical implementations of attention use what is called multihead attention, which simply means that we run the self-attention mechanism on different subsets of the $K$, $Q$, $V$ terms, then concatenate the results together. Formally, we'll partition these terms as
$$K = \begin{bmatrix} K_1 & K_2 & \cdots & K_{\text{heads}} \end{bmatrix}$$
(and similarly for $Q$ and $V$).
Then we form the self-attention outputs
$$Y_i = \mathrm{softmax}\left(\frac{K_i Q_i^T}{\sqrt{d/\text{heads}}}\right) V_i$$
and then form the final output
$$Y = \begin{bmatrix} Y_1 & Y_2 & \cdots & Y_{\text{heads}} \end{bmatrix} W_{out}.$$
The advantage of multi-head attention is that applying a single self-attention layer to a "high dimensional" hidden state (i.e., where $d$ is large) seems to waste a lot of the information contained in the hidden layers. Recall, for instance, that the terms in the self-attention matrix are proportional to $k_t^T q_{t'}$. If $k_t$ and $q_{t'}$ are high dimensional, then a lot of "internal structure" could be lost to produce ultimately just one weighting term. By breaking this up and computing multiple different attention matrices, each of which weights different dimensions of the $V$ term, we avoid this problem, and in practice obtain better performance. Note, however, that the "right" tradeoff between the number of heads and $d$ is still rather heuristic in nature.
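The partitioning above can be sketched in numpy by reshaping the heads into an extra batch dimension, so that the batch matrix multiply handles all heads of all samples at once (again, the function and weight names here are our own illustrative choices):

```python
import numpy as np

def softmax(Z):
    Z = np.exp(Z - Z.max(axis=-1, keepdims=True))
    return Z / Z.sum(axis=-1, keepdims=True)

def multihead_attention(X, heads, W_KQV, W_out):
    N, T, d = X.shape
    K, Q, V = np.split(X @ W_KQV, 3, axis=-1)
    # reshape (N, T, d) -> (N, heads, T, d // heads) so that each head
    # attends over its own slice of the feature dimension
    K, Q, V = (a.reshape(N, T, heads, d // heads).swapaxes(1, 2)
               for a in (K, Q, V))
    attn = softmax(K @ Q.swapaxes(-1, -2) / np.sqrt(d // heads))
    # swap the heads back next to the features and concatenate them,
    # then apply the output weight
    return (attn @ V).swapaxes(1, 2).reshape(N, T, d) @ W_out, attn

N, T, d = 2, 100, 64
X = np.random.randn(N, T, d)
W_KQV = np.random.randn(d, 3 * d) / np.sqrt(d)
W_out = np.random.randn(d, d) / np.sqrt(d)
Y, attn = multihead_attention(X, 4, W_KQV, W_out)
print(Y.shape, attn.shape)  # (2, 100, 64) (2, 4, 100, 100)
```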
Transformer Block
Let's finally put all this together into a full transformer block. A transformer block simply amounts to a self-attention block, with a residual connection and layer norm operation, followed by a two-layer feedforward network, with another residual connection and layer norm. We can implement this in a few lines of code. Note that in "real" implementations, the layer norm operations, etc., would have trainable scale/bias terms that add a bit more expressivity to the model. The version we show will only be equivalent, for instance, at initialization.
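A self-contained numpy sketch of such a block might look like the following (the helper names and the choice of ReLU for the feedforward nonlinearity are our own assumptions; as noted above, the layer norms here have no trainable scale/bias):

```python
import numpy as np

def softmax(Z):
    Z = np.exp(Z - Z.max(axis=-1, keepdims=True))
    return Z / Z.sum(axis=-1, keepdims=True)

def layer_norm(Z, eps=1e-6):
    # normalize over the feature dimension (no trainable scale/bias here)
    return (Z - Z.mean(axis=-1, keepdims=True)) / \
        np.sqrt(Z.var(axis=-1, keepdims=True) + eps)

def relu(Z):
    return np.maximum(Z, 0)

def multihead_attention(X, heads, W_KQV, W_out):
    N, T, d = X.shape
    K, Q, V = np.split(X @ W_KQV, 3, axis=-1)
    K, Q, V = (a.reshape(N, T, heads, d // heads).swapaxes(1, 2)
               for a in (K, Q, V))
    attn = softmax(K @ Q.swapaxes(-1, -2) / np.sqrt(d // heads))
    return (attn @ V).swapaxes(1, 2).reshape(N, T, d) @ W_out

def transformer_block(X, heads, W_KQV, W_out, W_ff1, W_ff2):
    # self-attention sublayer with residual connection and layer norm
    Z = layer_norm(X + multihead_attention(X, heads, W_KQV, W_out))
    # two-layer feedforward sublayer, again with residual and layer norm
    return layer_norm(Z + relu(Z @ W_ff1) @ W_ff2)

N, T, d = 2, 100, 64
X = np.random.randn(N, T, d)
W_KQV = np.random.randn(d, 3 * d) / np.sqrt(d)
W_out = np.random.randn(d, d) / np.sqrt(d)
W_ff1 = np.random.randn(d, 4 * d) / np.sqrt(d)
W_ff2 = np.random.randn(4 * d, d) / np.sqrt(4 * d)
Y = transformer_block(X, 4, W_KQV, W_out, W_ff1, W_ff2)
print(Y.shape)  # (2, 100, 64)
```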
The quest for "efficient Transformers"
Since the Transformer was first proposed, there have been endless attempts to make different "efficient" versions of the operation. The key drawback of transformers, as we have seen, is that they require forming the $T \times T$ attention matrix and multiplying it by $V$ (an $O(T^2 d)$ operation). If $T$ is much larger than $d$ (e.g., the sequence is very long), then this operation is quite costly.
There are essentially two approaches to making the operation more efficient: attempting to represent the attention matrix using either sparsity or low-rank structure. In general, of course, this matrix is neither sparse nor low rank. But we could simply dictate, for example, that we will only compute some subset of the attention weights, thereby decreasing the number of inner products we need to perform (this is the basis of the so-called "Sparse Attention" layer; similar approaches have been proposed a number of times, but this is one such example). Alternatively, one could try to infer some kind of hard sparsity via, e.g., triangle inequalities or other similar techniques (because, remember, we are computing what amounts to a similarity metric between the terms at different times).
Alternatively, we could try to represent the attention matrix in low-rank form instead. To see why this could be appealing, consider the case where we don't have a softmax operation at all, but instead used the "attention" layer
$$Y = \frac{(K Q^T) V}{\sqrt{d}}.$$
In this case, if $T \gg d$, we could instead perform our multiplication in the order $K (Q^T V)$, which would only have complexity $O(T d^2)$, potentially much smaller. And some papers in fact advocate for this very thing, or alternatively try to find a low-rank representation of the actual attention weights, to similar effect.
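The reassociation trick above is easy to verify numerically: without the softmax, matrix multiplication is associative, so both orderings give the same answer while forming very different intermediates (a $T \times T$ matrix versus a $d \times d$ one). The sizes here are arbitrary example values:

```python
import numpy as np

T, d = 2000, 64
K, Q, V = (np.random.randn(T, d) for _ in range(3))

# standard order: forms a T x T intermediate, O(T^2 d)
Y1 = (K @ Q.T) @ V / np.sqrt(d)
# reassociated order: only forms d x d / T x d intermediates, O(T d^2)
Y2 = K @ (Q.T @ V) / np.sqrt(d)

# both orderings compute the same result
assert np.allclose(Y1, Y2)
```

The softmax is exactly what breaks this associativity, which is why the low-rank "linear attention" approaches must either drop it or approximate it.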
The thing to keep in mind with all these "efficient" alternatives (and if you have been reading the literature surrounding Transformers, you have likely seen a ton of these) is whether they are actually more efficient, for an equivalent level of performance, once real execution speed is taken into account. My best understanding of the current situation is that 1) explicit sparse self-attention is indeed sometimes useful for models that want very long history, but that 2) most of the "efficient" transformer mechanisms that use low-rank structure or inferred sparsity don't improve much in practice over traditional attention.