Parallelizing Multi-Head Attention

In the multi-head attention mechanism, why do we need to transpose the token (sequence) dimension with the head dimension after reshaping the Q/K/V projections from 3 dimensions to 4? Using the canonical multi-head attention code below as an example, why do we need `Q = self.W_q(x).reshape(B, T, self.num_heads, self.head_dim).transpose(1, 2)`?
```python
import torch
import torch.nn as nn


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_size, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads  # Split embeddings across heads

        # Linear layers for Q, K, V
        self.W_q = nn.Linear(embed_size, embed_size)
        self.W_k = nn.Linear(embed_size, embed_size)
        self.W_v = nn.Linear(embed_size, embed_size)

        # Fully connected output layer, i.e. $W^O$
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, x):
        B, T, D = x.shape  # Batch, Seq_len, Embedding_dim

        Q = self.W_q(x).reshape(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.W_k(x).reshape(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.W_v(x).reshape(B, T, self.num_heads, self.head_dim).transpose(1, 2)

        # Compute attention scores (scaled dot-product attention)
        scores = torch.matmul(Q, K.transpose(-1, -2)) / (self.head_dim ** 0.5)
        attn = torch.softmax(scores, dim=-1)  # Normalize scores
        output = torch.matmul(attn, V)  # Apply attention to values

        # Concatenate and project back to embedding size
        output = output.transpose(1, 2).contiguous().reshape(B, T, D)
        return self.fc_out(output)


# Example usage
x = torch.randn(1, 10, 512)  # Batch size 1, Sequence length 10, Embedding size 512
attention_layer = MultiHeadSelfAttention(embed_size=512, num_heads=8)
output = attention_layer(x)
print(output.shape)  # Should be [1, 10, 512]
```
Review of Multi-Head Attention
Instead of using a single attention head, transformers use multiple heads. Each attention head has its own set of projection matrices for Q, K, and V. Each head learns to focus on different types of relationships. For example:
- One head might focus on long-range dependencies, such as linking a subject to its verb.
- Another might focus on local context, such as detecting adjective-noun pairs.
- Another might specialize in syntax or semantics.
Each head has independent learnable weight/projection matrices $W_i^Q, W_i^K, W_i^V$, where $i$ is the head index. Each weight matrix has shape $(D, D/h)$ where $D$ is the embedding dimension and $h$ is the number of heads.
For input embeddings $X$ of shape $(B, T, D)$, head $i$ computes:
\[Q_i = X W_i^Q, \quad K_i = X W_i^K, \quad V_i = X W_i^V\]
The multi-head attention then concatenates the per-head attention outputs and linearly mixes them:
\[\text{MultiHeadAttention}(Q, K, V) = \text{Concat}\left(\text{head}_1, \text{head}_2, \ldots, \text{head}_h\right)W^O\]
Each head has output shape $(B, T, D/h)$, so the concatenated output has shape $(B, T, D)$. $W^O$ is trainable, has shape $(D, D)$, and mixes information across heads, refining the final representation before passing it to the next layer.
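To make these shapes concrete, here is a minimal standalone sketch (random tensors, my own variable names, no training) that checks the per-head and concatenated shapes described above:

```python
import torch

# Illustrative only: random tensors, shapes as described above.
B, T, D, h = 1, 10, 512, 8
head_dim = D // h                                # D/h = 64

X = torch.randn(B, T, D)
W_q_i = torch.randn(D, head_dim)                 # one head's W_i^Q, shape (D, D/h)
Q_i = X @ W_q_i                                  # Q_i = X W_i^Q -> (B, T, D/h)
print(Q_i.shape)                                 # torch.Size([1, 10, 64])

# Concatenating h per-head outputs restores the embedding dimension,
# then W^O (shape (D, D)) mixes information across heads.
heads = [torch.randn(B, T, head_dim) for _ in range(h)]
W_o = torch.randn(D, D)
out = torch.cat(heads, dim=-1) @ W_o             # (B, T, D)
print(out.shape)                                 # torch.Size([1, 10, 512])
```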
Reshaping for Parallelization
Let’s break down `Q = self.W_q(x).reshape(B, T, self.num_heads, self.head_dim).transpose(1, 2)`:

`self.W_q` is a linear layer (`nn.Linear(embed_size, embed_size)`) that projects `x` into a new representation, specifically the query (`Q`) in multi-head attention. This linear transformation does not change the shape of `x`, which remains `(B, T, D)`.
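As a quick standalone check (same example dimensions as above), the projection indeed leaves the shape unchanged:

```python
import torch
import torch.nn as nn

B, T, D = 1, 10, 512
x = torch.randn(B, T, D)
W_q = nn.Linear(D, D)        # same layer type as self.W_q above
print(W_q(x).shape)          # torch.Size([1, 10, 512]) -- still (B, T, D)
```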
In `.reshape(B, T, self.num_heads, self.head_dim)`, we split the embedding dimension `D` into `self.num_heads` and `self.head_dim` (where `self.head_dim = D / self.num_heads`). For example, if `D = 512` and `self.num_heads = 8`, then `self.head_dim = 512 / 8 = 64`, and the shape becomes `(B, T, 8, 64)`.
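A minimal standalone sketch of just this step, assuming the same example dimensions:

```python
import torch

B, T, D = 1, 10, 512
num_heads, head_dim = 8, 64                    # D == num_heads * head_dim

q = torch.randn(B, T, D)                       # stand-in for self.W_q(x)
q = q.reshape(B, T, num_heads, head_dim)       # split D into (num_heads, head_dim)
print(q.shape)                                 # torch.Size([1, 10, 8, 64])
```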
In `.transpose(1, 2)`, we swap the sequence length dimension (`T`) and the number of heads dimension (`num_heads`). Note that dimensions are zero-indexed, so indices 1 and 2 refer to the second and third dimensions. This changes the shape:

Before: `(B, T, num_heads, head_dim)`
After: `(B, num_heads, T, head_dim)`
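Again as a standalone sketch (same example dimensions), the transpose only swaps the two middle dimensions and returns a view rather than copying data:

```python
import torch

B, T, num_heads, head_dim = 1, 10, 8, 64
q = torch.randn(B, T, num_heads, head_dim)     # shape before: (B, T, num_heads, head_dim)

q_t = q.transpose(1, 2)                        # swap dims 1 (T) and 2 (num_heads)
print(q_t.shape)                               # torch.Size([1, 8, 10, 64])

# transpose returns a view with swapped strides; no data is copied here,
# which is why the original code calls .contiguous() before reshaping back at the end.
print(q_t.is_contiguous())                     # False
```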
This rearrangement makes attention more efficient, because the matrix multiplications can now be parallelized across heads (each head operates independently on its own slice of the embedding). Here is why.
In the line `scores = torch.matmul(Q, K.transpose(-1, -2)) / (self.head_dim ** 0.5)`, `K.transpose(-1, -2)` swaps the last two dimensions, so `K^T` has shape `(B, num_heads, head_dim, T)`.

Then the matrix multiplication `Q @ K^T` is `(B, num_heads, T, head_dim) @ (B, num_heads, head_dim, T)`, which results in `(B, num_heads, T, T)`: for each batch element and each head, a `(T, T)` matrix of attention scores.
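The same shape arithmetic as a standalone sketch (random tensors, example dimensions as above):

```python
import torch

B, num_heads, T, head_dim = 1, 8, 10, 64
Q = torch.randn(B, num_heads, T, head_dim)
K = torch.randn(B, num_heads, T, head_dim)
V = torch.randn(B, num_heads, T, head_dim)

scores = torch.matmul(Q, K.transpose(-1, -2)) / (head_dim ** 0.5)
print(scores.shape)          # torch.Size([1, 8, 10, 10])  i.e. (B, num_heads, T, T)

attn = torch.softmax(scores, dim=-1)
output = torch.matmul(attn, V)
print(output.shape)          # torch.Size([1, 8, 10, 64])  i.e. (B, num_heads, T, head_dim)
```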
Thus, `Q @ K^T` becomes more efficient after `transpose(1, 2)` because:

- Parallel Computation for Each Head: with `num_heads` as the second dimension, the batched `matmul` treats `B * num_heads` as one batch of independent `(T, head_dim) @ (head_dim, T)` products, so every head's computation happens independently but in parallel, using optimized GPU kernels (see the check after this list).
- Better Memory Access Patterns: GPUs are highly optimized for contiguous memory access. With the `(B, num_heads, T, head_dim)` layout, each head's `(T, head_dim)` block is processed as a unit during the matrix multiplication, which improves cache efficiency.
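To see that the batched matmul really does compute every head independently, here is a small standalone check (my own sketch, not part of the original code) comparing the batched score computation against an explicit Python loop over heads:

```python
import torch

B, num_heads, T, head_dim = 1, 8, 10, 64
Q = torch.randn(B, num_heads, T, head_dim)
K = torch.randn(B, num_heads, T, head_dim)

# Batched: a single matmul call covers all B * num_heads independent products.
scores_batched = torch.matmul(Q, K.transpose(-1, -2))        # (B, num_heads, T, T)

# Explicit loop: the same math, one head at a time.
scores_loop = torch.stack(
    [torch.matmul(Q[:, h], K[:, h].transpose(-1, -2)) for h in range(num_heads)],
    dim=1,
)                                                            # (B, num_heads, T, T)

print(torch.allclose(scores_batched, scores_loop))           # True
```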