
GitHub Repository: dlsyscourse/public_notebooks
Path: blob/main/transformer_implementation.ipynb
Kernel: Python 3 (ipykernel)

Implementing Transformers

This notebook walks you through the internals of implementing self-attention and Transformer networks. As with recurrent networks (and unlike convolutions), there is relatively little that is fundamentally new in their implementation: it largely involves applying existing primitives you will already have implemented in your autodiff framework. There is, however, one aspect of an efficient implementation that requires a slight generalization of an operation we have already discussed: a batched version of matrix multiplication. This is required both for the minibatch version of attention and for the common "multihead" version. We will also briefly discuss some approaches to making Transformers more efficient.

Implementing self-attention

Let's begin with a simple implementation of self-attention. This essentially just implements the basic equation

\begin{equation}
Y = \mathrm{softmax}\left(\frac{KQ^T}{\sqrt{d}}\right)V
\end{equation}

By convention, however, self-attention is typically implemented in terms of the actual inputs $X$ rather than the $K$, $Q$, and $V$ values themselves (i.e., the linear projections that produce $K$, $Q$, and $V$ are part of the layer rather than a separate layer). It's also common to include an output weight, which applies an additional linear layer to the output of the entire operation (even though in theory this could be folded into the $W_{KQV}$ terms, specifically into $W_V$). That is, the full operation is given by

\begin{equation}
Y = \left(\mathrm{softmax}\left(\frac{X W_K W_Q^T X^T}{\sqrt{d}}\right)X W_V \right) W_o.
\end{equation}

It's also possible to add bias terms to each of these projections. We won't bother with this: biases are less common for everything but the output weight, and otherwise they mostly just add complexity.
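The implementation that follows stores the three input projections as a single fused matrix $W_{KQV}$ of shape $d \times 3d$. As a minimal sketch (the separate matrices W_K, W_Q, W_V, W_o and all sizes here are illustrative assumptions), we can check that when W_KQV is the column-wise concatenation of the three projections, splitting X @ W_KQV into thirds recovers them, and the full operation above gives the same Y either way.

```python
import numpy as np

def softmax(Z):
    # numerically stable softmax over the last axis
    Z = np.exp(Z - Z.max(axis=-1, keepdims=True))
    return Z / Z.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
T, d = 6, 8
X = rng.standard_normal((T, d))
# hypothetical separate projection matrices, each d x d
W_K, W_Q, W_V, W_o = (rng.standard_normal((d, d)) for _ in range(4))

# fused weight [W_K  W_Q  W_V], shape d x 3d
W_KQV = np.concatenate([W_K, W_Q, W_V], axis=1)
K, Q, V = np.split(X @ W_KQV, 3, axis=-1)
assert np.allclose(K, X @ W_K) and np.allclose(Q, X @ W_Q) and np.allclose(V, X @ W_V)

# the full operation, written directly in terms of the separate weights
Y_direct = softmax(X @ W_K @ W_Q.T @ X.T / np.sqrt(d)) @ X @ W_V @ W_o
# the same quantity computed from the fused projection
Y_fused = softmax(K @ Q.T / np.sqrt(d)) @ V @ W_o
print(np.allclose(Y_direct, Y_fused))  # True
```

Fusing the projections means one larger matrix multiply instead of three smaller ones, which is generally friendlier to hardware.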

Let's see what this implementation looks like.

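```python
import numpy as np
import torch
import torch.nn as nn
```

```python
def softmax(Z):
    # subtract the row max for numerical stability before exponentiating
    Z = np.exp(Z - Z.max(axis=-1, keepdims=True))
    return Z / Z.sum(axis=-1, keepdims=True)

def self_attention(X, mask, W_KQV, W_out):
    # one fused projection, split into K, Q, V along the feature axis
    K, Q, V = np.split(X @ W_KQV, 3, axis=-1)
    attn = softmax(K @ Q.swapaxes(-1, -2) / np.sqrt(X.shape[-1]) + mask)
    return attn @ V @ W_out, attn
```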

We can compare this to PyTorch's self-attention implementation, the nn.MultiheadAttention layer (we'll cover what we mean by "multi-head" shortly). Note that by default (mainly to be consistent with the RNN implementation and other sequence models), nn.MultiheadAttention takes inputs in $(T,N,d)$ form, i.e., with the batch dimension second. Unlike for RNNs, however, this ordering doesn't make much sense for self-attention and Transformers: we compute the operation "in parallel" over all time points, rather than sequentially as in an RNN. So we'll pass the batch_first=True flag to use the more natural $(N,T,d)$ ordering for the inputs.

```python
T = 5
# causal mask: -inf strictly above the diagonal, 0 on and below it
M = torch.triu(-float("inf")*torch.ones(T,T),1)
M
```

```
tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0.]])
```
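The mask is added to the attention logits before the softmax. Because $\exp(-\infty) = 0$, every position $t$ ends up placing zero weight on later positions $s > t$, which is what makes the layer causal. A small NumPy sketch of this (using np.triu in place of torch.triu; the sizes and random logits are arbitrary):

```python
import numpy as np

def softmax(Z):
    Z = np.exp(Z - Z.max(axis=-1, keepdims=True))
    return Z / Z.sum(axis=-1, keepdims=True)

T = 5
# -inf strictly above the diagonal, 0 on and below it (same pattern as M)
mask = np.triu(np.full((T, T), -np.inf), 1)
logits = np.random.default_rng(0).standard_normal((T, T))
A = softmax(logits + mask)

print(np.round(A, 2))
# every row is still a probability distribution...
print(np.allclose(A.sum(axis=-1), 1.0))       # True
# ...but with no weight on "future" positions s > t
print(np.all(A[np.triu_indices(T, 1)] == 0))  # True
```

The diagonal is never masked, so each row always has at least one finite logit and the softmax stays well defined.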
```python
T, d = 100, 64
attn = nn.MultiheadAttention(d, 1, bias=False, batch_first=True)
M = torch.triu(-float("inf")*torch.ones(T,T),1)
X = torch.randn(1,T,d)
Y_, A_ = attn(X,X,X, attn_mask=M)

# PyTorch stores linear weights as (out, in) and computes x W^T, hence the transposes
Y, A = self_attention(X[0].numpy(), M.numpy(),
                      attn.in_proj_weight.detach().numpy().T,
                      attn.out_proj.weight.detach().numpy().T)

print(np.linalg.norm(A - A_[0].detach().numpy()))
print(np.linalg.norm(Y - Y_[0].detach().numpy()))
```

```
1.8741974e-07
1.3277154e-06
```

Minibatching with batch matrix multiply

Once we move from a single example to minibatches, one additional subtlety comes into play for self-attention. Recall that for each sample in the minibatch, we have to compute a matrix product, e.g., the $KQ^T$ term. To process a minibatch, we need to perform this matrix multiplication separately for each sample. This operation is known as a batch matrix multiply.

It may seem as though nothing is new here. True, for an MLP it was possible to perform the entire batch equation as a single matrix multiplication, but didn't we similarly need to batch matrix multiplications for convolutional networks (after the im2col function)? Or for RNNs?

The answer is actually no: until now, we haven't needed true batch matrix multiplication functionality. The situations we encountered before involved multiplying a "batched" tensor by a single weight matrix. For example, in a ConvNet we had something like

\begin{equation}
y = \mathrm{im2col}(x) W
\end{equation}

or, in the batched setting,

\begin{equation}
y^{(i)} = \mathrm{im2col}\left(x^{(i)}\right) W.
\end{equation}

But this operation can be accomplished with "normal" matrix multiplication by simply stacking the samples into the matrix on the left:

\begin{equation}
\begin{bmatrix} y^{(1)} \\ y^{(2)} \\ \vdots \\ y^{(N)} \end{bmatrix} = \begin{bmatrix} \mathrm{im2col}\left(x^{(1)}\right) \\ \mathrm{im2col}\left(x^{(2)}\right) \\ \vdots \\ \mathrm{im2col}\left(x^{(N)}\right) \end{bmatrix} W.
\end{equation}

This is just an ordinary matrix multiplication, so it can be implemented with your framework as it stands, where matrix multiplication always operates on 2-dimensional NDArrays.
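The stacking argument is easy to check numerically. In this sketch (shapes are arbitrary, and cols is a stand-in for the output of im2col on each sample), a batch of per-sample matrices multiplied by one shared $W$ can be flattened into a single 2D product and reshaped back. In self-attention, by contrast, each sample has its own right-hand matrix $Q^{(i)T}$, so there is no shared matrix to stack against and we need one product per sample.

```python
import numpy as np

rng = np.random.default_rng(0)
N, rows, k, m = 4, 6, 5, 3
cols = rng.standard_normal((N, rows, k))  # stand-in for im2col(x^(i)), one per sample
W = rng.standard_normal((k, m))           # a single weight shared by all samples

# per-sample products y^(i) = im2col(x^(i)) W
per_sample = np.stack([cols[i] @ W for i in range(N)])

# stack all samples vertically, do ONE 2D matmul, then unstack
stacked = (cols.reshape(N * rows, k) @ W).reshape(N, rows, m)
print(np.allclose(per_sample, stacked))  # True

# self-attention: each sample has its own right-hand matrix
Kb = rng.standard_normal((N, rows, k))
Qb = rng.standard_normal((N, rows, k))
# no shared matrix to stack against, so one product per sample
KQt = np.stack([Kb[i] @ Qb[i].T for i in range(N)])
print(KQt.shape)  # (4, 6, 6)
```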

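Fortunately, NumPy's @ operator already performs batch matrix multiplication when its operands have more than 2 dimensions: it treats the last two axes as matrices and the leading axes as batch dimensions (which are broadcast against each other).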

```python
# illustration of batch matmul
B = np.random.randn(10,3,5,4)
C = np.random.randn(10,3,4,3)
(B@C).shape
```

```
(10, 3, 5, 3)
```
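When adding batch matrix multiplication to your own framework, it can help to have a slow reference to test against. The sketch below (the function name batch_matmul_loop is hypothetical) loops over the leading batch axes and applies an ordinary 2D product to each pair of trailing matrices. It assumes both inputs have identical leading shapes, so unlike NumPy's @ it does not handle broadcasting.

```python
import itertools
import numpy as np

def batch_matmul_loop(B, C):
    """Reference batch matmul: B[..., n, k] @ C[..., k, m] -> out[..., n, m].

    Assumes B.shape[:-2] == C.shape[:-2] (no broadcasting)."""
    assert B.shape[:-2] == C.shape[:-2] and B.shape[-1] == C.shape[-2]
    batch = B.shape[:-2]
    out = np.empty(batch + (B.shape[-2], C.shape[-1]))
    # one ordinary 2D matrix product per index into the leading axes
    for idx in itertools.product(*(range(s) for s in batch)):
        out[idx] = B[idx] @ C[idx]
    return out

B = np.random.randn(10, 3, 5, 4)
C = np.random.randn(10, 3, 4, 3)
out = batch_matmul_loop(B, C)
print(out.shape)                # (10, 3, 5, 3)
print(np.allclose(out, B @ C))  # True
```

The same contraction can also be written as np.einsum('...nk,...km->...nm', B, C), which makes the "batch axes pass through, last two axes multiply" structure explicit.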

Let's see how this works with our self-attention layer. In fact, thanks to the judicious use of axis=-1 and swapaxes(-1,-2), which index axes from the end, our layer works exactly as it did before.

```python
N = 10
M = torch.triu(-float("inf")*torch.ones(T,T),1)
X = torch.randn(N,T,d)
Y_, A_ = attn(X,X,X, attn_mask=M)

Y, A = self_attention(X.numpy(), M.numpy(),
                      attn.in_proj_weight.detach().numpy().T,
                      attn.out_proj.weight.detach().numpy().T)

print(np.linalg.norm(A - A_.detach().numpy()))
print(np.linalg.norm(Y - Y_.detach().numpy()))
```

```
5.5253105e-07
3.97839e-06
```

Multihead attention

Practical implementations of attention use what is called multihead attention, which simply means that we run the self-attention mechanism on different subsets of the $K$, $Q$, $V$ terms and then concatenate the results. Formally, we partition these terms column-wise as

\begin{equation}
K = \begin{bmatrix} K_1 & K_2 & \cdots & K_{\mathrm{heads}} \end{bmatrix}
\end{equation}

(and similarly for $Q$ and $V$).

We then form the self-attention outputs

\begin{equation}
Y_i = \mathrm{softmax}\left(\frac{K_iQ_i^T}{\sqrt{d/\mathrm{heads}}}\right)V_i
\end{equation}

and the final output

\begin{equation}
Y = \begin{bmatrix} Y_1 & Y_2 & \cdots & Y_{\mathrm{heads}} \end{bmatrix} W_o.
\end{equation}
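Read literally, these equations slice $K$, $Q$, and $V$ into column blocks of width $d/\mathrm{heads}$ and run ordinary attention on each block. The sketch below does exactly that with a Python loop, and checks it against a reshape-and-swapaxes formulation in which each head becomes an extra batch axis for batch matmul. The function names, sizes, and the omission of the mask and of $W_o$ are all simplifying assumptions for illustration.

```python
import numpy as np

def softmax(Z):
    Z = np.exp(Z - Z.max(axis=-1, keepdims=True))
    return Z / Z.sum(axis=-1, keepdims=True)

def multihead_loop(K, Q, V, heads):
    """Y_i = softmax(K_i Q_i^T / sqrt(d/heads)) V_i per column block, then concatenate."""
    T, d = K.shape
    dh = d // heads
    Ys = []
    for i in range(heads):
        cols = slice(i * dh, (i + 1) * dh)  # columns belonging to head i
        Ki, Qi, Vi = K[:, cols], Q[:, cols], V[:, cols]
        Ys.append(softmax(Ki @ Qi.T / np.sqrt(dh)) @ Vi)
    return np.concatenate(Ys, axis=-1)  # [Y_1 Y_2 ... Y_heads], shape (T, d)

def multihead_reshape(K, Q, V, heads):
    """Same computation with the heads moved into a leading batch axis."""
    T, d = K.shape
    K, Q, V = [a.reshape(T, heads, d // heads).swapaxes(0, 1) for a in (K, Q, V)]
    attn = softmax(K @ Q.swapaxes(-1, -2) / np.sqrt(d // heads))  # (heads, T, T)
    return (attn @ V).swapaxes(0, 1).reshape(T, d)

rng = np.random.default_rng(0)
T, d, heads = 7, 16, 4
K, Q, V = (rng.standard_normal((T, d)) for _ in range(3))
Y_loop = multihead_loop(K, Q, V, heads)
Y_vec = multihead_reshape(K, Q, V, heads)
print(np.allclose(Y_loop, Y_vec))  # True
```

The reshape works because reshaping a row of length $d$ into (heads, d/heads) groups contiguous columns, which is exactly the column partition $K = [K_1 \; K_2 \; \cdots]$.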

The motivation for multihead attention is that applying a single self-attention layer to a "high dimensional" hidden state (i.e., where $d$ is large) seems to waste much of the information contained in the hidden layers. Recall, for instance, that the pre-softmax score for the pair $(t,s)$ is proportional to $k_t^T q_s$. If $k_t$ and $q_s$ are high dimensional, a lot of "internal structure" can be lost when everything is reduced to a single weighting term. By breaking this up and computing multiple different attention matrices, each of which weights different dimensions of the $V$ term, we avoid this problem, and in practice this leads to better performance. Note, however, that the "right" tradeoff between the number of heads and $d$ is still rather heuristic.
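One way to make the "lost structure" argument concrete: because $K$ and $Q$ are partitioned into the same column blocks, $KQ^T = \sum_i K_i Q_i^T$. A single head therefore collapses the per-block similarities into one summed score for each pair $(t,s)$ before the softmax, whereas multiple heads keep them as separate scores, each with its own softmax. A quick NumPy check (arbitrary sizes, scaling factors omitted):

```python
import numpy as np

rng = np.random.default_rng(0)
T, d, heads = 5, 12, 3
dh = d // heads
K = rng.standard_normal((T, d))
Q = rng.standard_normal((T, d))

# per-head score matrices K_i Q_i^T, each of shape (T, T)
per_head = [K[:, i*dh:(i+1)*dh] @ Q[:, i*dh:(i+1)*dh].T for i in range(heads)]

# the single-head score matrix is just their sum
print(np.allclose(K @ Q.T, sum(per_head)))  # True

# entry (t, s): one scalar k_t^T q_s versus `heads` separate scalars
t, s = 1, 3
print(K[t] @ Q[s], [float(P[t, s]) for P in per_head])
```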

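```python
def multihead_attention(X, mask, heads, W_KQV, W_out):
    N, T, d = X.shape
    K, Q, V = np.split(X @ W_KQV, 3, axis=-1)
    # (N, T, d) -> (N, heads, T, d/heads): each head becomes an extra batch axis
    K, Q, V = [a.reshape(N, T, heads, d//heads).swapaxes(1, 2) for a in (K, Q, V)]
    attn = softmax(K @ Q.swapaxes(-1, -2) / np.sqrt(d//heads) + mask)
    # concatenate the heads back along the feature axis, then apply W_out
    return (attn @ V).swapaxes(1, 2).reshape(N, T, d) @ W_out, attn
```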
```python
heads = 4
attn = nn.MultiheadAttention(d, heads, bias=False, batch_first=True)
Y_, A_ = attn(X,X,X, attn_mask=M)
Y, A = multihead_attention(X.numpy(), M.numpy(), heads,
                           attn.in_proj_weight.detach().numpy().T,
                           attn.out_proj.weight.detach().numpy().T)
```

```python
A_.shape
```

```
torch.Size([10, 100, 100])
```

```python
A.shape
```

```
(10, 4, 100, 100)
```

```python
# PyTorch returns the attention weights averaged over heads by default
print(np.linalg.norm(Y - Y_.detach().numpy()))
print(np.linalg.norm(A.mean(1) - A_.detach().numpy()))
```

```
4.0823516e-06
4.2045417e-07
```

Transformer Block

Let's finally put all this together into a full Transformer block. A Transformer block simply amounts to a self-attention layer with a residual connection and a layer norm operation, followed by a two-layer feedforward network with another residual connection and layer norm. We can implement this in a few lines of code. Note that in "real" implementations, the layer norm terms, etc., would have trainable scale and bias terms that add a bit more expressivity to the model. The version shown here matches such an implementation only at initialization, when the scales are 1 and the biases are 0.

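```python
def layer_norm(Z, eps):
    # normalize over the feature axis (no learnable scale/bias)
    return (Z - Z.mean(axis=-1, keepdims=True)) / np.sqrt(Z.var(axis=-1, keepdims=True) + eps)

def relu(Z):
    return np.maximum(Z, 0)

def transformer(X, mask, heads, W_KQV, W_out, W_ff1, W_ff2, eps):
    # self-attention + residual connection, then layer norm
    Z = layer_norm(multihead_attention(X, mask, heads, W_KQV, W_out)[0] + X, eps)
    # two-layer feedforward network + residual connection, then layer norm
    return layer_norm(Z + relu(Z @ W_ff1) @ W_ff2, eps)
```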
```python
trans = nn.TransformerEncoderLayer(d, heads, dim_feedforward=128, dropout=0.0, batch_first=True)
# zero the feedforward biases, which our NumPy version does not model
trans.linear1.bias.data.zero_()
trans.linear2.bias.data.zero_();
Y_ = trans(X, M)
```

```python
Y = transformer(X.numpy(), M.numpy(), heads,
                trans.self_attn.in_proj_weight.detach().numpy().T,
                trans.self_attn.out_proj.weight.detach().numpy().T,
                trans.linear1.weight.detach().numpy().T,
                trans.linear2.weight.detach().numpy().T,
                trans.norm1.eps)
```

```python
print(np.linalg.norm(Y - Y_.detach().numpy()))
```

```
2.7750326e-05
```
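To see why the parameter-free layer_norm above only agrees with a trained model at initialization, here is a sketch of the affine version. The names gamma and beta are our own; PyTorch's nn.LayerNorm calls these parameters weight and bias and initializes them to ones and zeros. The "trained" values are simulated with random numbers.

```python
import numpy as np

def layer_norm(Z, eps):
    # parameter-free version: zero mean, unit variance over the feature axis
    return (Z - Z.mean(axis=-1, keepdims=True)) / np.sqrt(Z.var(axis=-1, keepdims=True) + eps)

def layer_norm_affine(Z, gamma, beta, eps):
    # learnable per-feature scale (gamma) and shift (beta) applied after normalizing
    return gamma * layer_norm(Z, eps) + beta

rng = np.random.default_rng(0)
Z = rng.standard_normal((4, 8))
eps = 1e-5

# at initialization gamma = 1 and beta = 0, so both versions agree
gamma, beta = np.ones(8), np.zeros(8)
print(np.allclose(layer_norm(Z, eps), layer_norm_affine(Z, gamma, beta, eps)))  # True

# once gamma and beta move away from their initial values, the outputs differ
gamma, beta = rng.standard_normal(8), rng.standard_normal(8)
print(np.allclose(layer_norm(Z, eps), layer_norm_affine(Z, gamma, beta, eps)))  # False
```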

The question for "efficient Transformers"

Since the Transformer was first proposed, there have been endless attempts to make "efficient" versions of the operation. The key drawback of Transformers, as we have seen, is that they require forming the $T \times T$ attention matrix and multiplying it by $V$, an $O(T^2 d)$ operation:

\begin{equation}
\mathrm{softmax}\left(\frac{KQ^T}{\sqrt{d}}\right)V
\end{equation}

If $T$ is much larger than $d$ (e.g., the sequence is very long), this operation is quite costly.
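A rough count of multiply-adds makes the scaling concrete. This sketch counts only the two matrix products in the expression above, $KQ^T$ and the product of the $T \times T$ attention matrix with $V$, and compares them to a single $d \times d$ projection of the input; it ignores the softmax and constant factors, and the sequence lengths are arbitrary examples.

```python
def attention_matmul_cost(T, d):
    """Multiply-adds for K Q^T (T x d times d x T) plus A V (T x T times T x d)."""
    return T * T * d + T * T * d  # = 2 T^2 d

def projection_cost(T, d):
    """Multiply-adds for X W with X of shape T x d and W of shape d x d."""
    return T * d * d

d = 64
for T in (128, 1024, 8192):
    a, p = attention_matmul_cost(T, d), projection_cost(T, d)
    # the ratio is 2T/d, so attention dominates once T >> d
    print(f"T={T:5d}: attention={a:.2e}, one projection={p:.2e}, ratio={a / p:.0f}")
```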

There are essentially two approaches to making this more efficient, both of which try to represent the attention matrix

\begin{equation}
A = \mathrm{softmax}\left(\frac{KQ^T}{\sqrt{d}}\right)
\end{equation}

using either sparsity or low-rank structure. In general, of course, this matrix is neither sparse nor low rank. But we could simply dictate, for example, that we will only compute some subset of the attention weights, thereby decreasing the number of inner products we need to perform (this is the basis of the so-called "Sparse Attention" layer; similar approaches have been proposed a number of times, and this is just one example). Alternatively, one could try to infer some kind of hard sparsity, e.g., via triangle inequalities or similar arguments (because, remember, we are computing what amounts to a similarity metric between the $x$ terms at different times).
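As one concrete (and deliberately simple) instance of computing only a subset of the attention weights, the sketch below uses a causal sliding window: position $t$ only attends to the previous w positions, including itself. This particular pattern and the function name local_attention are illustrative assumptions, not the pattern of any specific paper. It forms at most $Tw$ inner products (each of length $d$) instead of $T^2$, and its output matches dense attention with every entry outside the window masked to $-\infty$.

```python
import numpy as np

def softmax(Z):
    Z = np.exp(Z - Z.max(axis=-1, keepdims=True))
    return Z / Z.sum(axis=-1, keepdims=True)

def local_attention(K, Q, V, w):
    """Causal sliding-window attention: row t only scores positions max(0, t-w+1)..t."""
    T, d = K.shape
    Y = np.empty_like(V)
    for t in range(T):
        lo = max(0, t - w + 1)
        scores = K[t] @ Q[lo:t + 1].T / np.sqrt(d)  # at most w inner products
        Y[t] = softmax(scores) @ V[lo:t + 1]
    return Y

rng = np.random.default_rng(0)
T, d, w = 10, 8, 3
K, Q, V = (rng.standard_normal((T, d)) for _ in range(3))

# dense reference: full T x T scores, with everything outside the window masked out
idx = np.arange(T)
outside = (idx[None, :] > idx[:, None]) | (idx[None, :] <= idx[:, None] - w)
mask = np.where(outside, -np.inf, 0.0)
Y_dense = softmax(K @ Q.T / np.sqrt(d) + mask) @ V

print(np.allclose(local_attention(K, Q, V, w), Y_dense))  # True
```

The Python loop is for clarity only; a practical version would batch the windowed products, but the point is that the masked-out entries are never computed at all.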

Alternatively, we could try to represent $A$ in low-rank form instead. To see why this could be appealing, consider the case where we don't have a softmax operation at all, but instead use the "attention" layer

\begin{equation}
\left(\frac{KQ^T}{\sqrt{d}}\right)V
\end{equation}

In this case, if $T \gg d$, we could instead perform our multiplication in the order $K(Q^T V)$, which has complexity only $O(Td^2)$, potentially much smaller. Some papers in fact advocate for exactly this, or alternatively try to find a low-rank representation of the actual attention weights, to similar effect.
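A quick check of the reassociation argument (no softmax, the $1/\sqrt{d}$ factor omitted, random data, arbitrary sizes): both orderings give the same result up to floating-point error, but $(KQ^T)V$ materializes a $T \times T$ intermediate while $K(Q^TV)$ only needs a $d \times d$ one.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d = 1024, 16
K, Q, V = (rng.standard_normal((T, d)) for _ in range(3))

left = (K @ Q.T) @ V   # T x T intermediate: about 2 T^2 d multiply-adds
right = K @ (Q.T @ V)  # d x d intermediate: about 2 T d^2 multiply-adds

print((K @ Q.T).shape, (Q.T @ V).shape)  # (1024, 1024) (16, 16)
print(np.allclose(left, right))           # True
```

The trick relies on associativity of matrix products; a row-wise softmax applied to $KQ^T$ breaks it, which is why softmax attention cannot simply be reordered this way.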

The thing to keep in mind with all these "efficient" alternatives (and if you have been reading the literature surrounding Transformers, you have likely seen a ton of them) is whether they are actually more efficient, for an equivalent level of performance, once real execution speed is taken into account. My best understanding of the current situation is that 1) explicit sparse self-attention is indeed sometimes useful for models that need a very long history, but 2) most of the "efficient" Transformer mechanisms that use low-rank structure or inferred sparsity structure don't improve much in practice over traditional attention.