Attention from First Principles - 2
Multi-Head Attention and Causal Self-Attention
In the first post in this series, we built the intuition for how to construct self-attention. Now we will continue our exploration of the self-attention mechanism. You can refer to the first part of this series here:
In this post, I will cover the following topics:
Multi-Head Attention: The Power of Multiple Perspectives
The Multi-Head Illusion: Where Does Separation Actually Happen?
Causal Self-Attention: Looking Only Backward
Multi-Head Attention: The Power of Multiple Perspectives
The limitation of single-head attention:
Remember our attention weights? They form a single probability distribution — each token must decide how to allocate 100% of its attention across all positions.
But what if a token needs to attend to multiple things for different reasons?
Example: “The bank by the river approved my loan”
The word “bank” might need to:
Attend to “river” (for word sense: geographical vs. financial)
Attend to “approved” and “loan” (for syntactic role: subject of the sentence)
Attend to “my” (for semantic relation: whose loan?)
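With a single attention head, these needs compete. The softmax forces a choice: maybe “loan” gets 60% of the weight, “river” 30%, and “my” 10%. The output is then a single weighted average of those tokens’ value vectors, so the three distinct signals get blended together and partly lost.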
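A tiny plain-Python sketch makes the competition concrete. The scores for the query “bank” are made-up numbers, not taken from a real model. The point is only that one softmax produces weights that sum to 1, so raising one token’s score necessarily lowers every other token’s weight.

```python
import math

def softmax(scores):
    """Turn raw scores into a probability distribution (numerically stable)."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

tokens = ["river", "approved", "loan", "my"]

# Hypothetical raw attention scores for the query "bank"
weights = softmax([1.0, 0.5, 2.0, 0.0])
print({t: round(w, 2) for t, w in zip(tokens, weights)})

# Boosting "river" (1.0 -> 3.0) shrinks every other weight, because the total is fixed at 1
boosted = softmax([3.0, 0.5, 2.0, 0.0])
print({t: round(w, 2) for t, w in zip(tokens, boosted)})
```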
The insight: Let different heads specialize
What if we could have:
Head 1: Focuses on word sense disambiguation (attends to “river”)
Head 2: Focuses on syntactic structure (attends to “approved”, “loan”)
Head 3: Focuses on semantic roles (attends to “my”)
Each head gets its own attention pattern, its own way of mixing context. They run in parallel, and we combine their outputs.
Why split dimensions instead of stacking?
We could make each head full d_model dimension, but that would multiply parameters by h. Instead:
Split d_model into h pieces of size d_k = d_model / h
Each head works in a smaller subspace
Total dimension after concatenation: back to d_model
Parameters stay manageable, while we still gain representational diversity
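To put numbers on the parameter argument, here is a small counting sketch in plain Python. The sizes d_model = 512 and h = 8 are illustrative choices, biases are ignored, and the “full-width” variant is a hypothetical design in which every head projects to the full d_model.

```python
def projection_params(d_model, num_heads, split=True):
    """Count weights in the Q, K, V and output projections (biases ignored).

    split=True : each head works in d_k = d_model // num_heads dimensions.
    split=False: hypothetical variant where every head uses the full d_model.
    """
    head_dim = d_model // num_heads if split else d_model
    qkv = 3 * d_model * (num_heads * head_dim)  # W_Q, W_K, W_V
    out = (num_heads * head_dim) * d_model      # W_O maps the concatenation back to d_model
    return qkv + out

d_model, h = 512, 8
split = projection_params(d_model, h, split=True)
full = projection_params(d_model, h, split=False)
print(split, full, full // split)
```

With splitting, the total equals that of four d_model × d_model matrices no matter how many heads we use; the full-width variant grows linearly with h.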
Mental Model:
Single head = one lens on the data. Multi-head = multiple specialized lenses simultaneously. Like having multiple experts analyze the same sentence, each noticing different patterns.
Multi-Head Attention: Mathematical Formulation
For each head i (i = 1 to h):
$$Q_i = X W_Q^{(i)}, \qquad W_Q^{(i)} \in \mathbb{R}^{d_{\text{model}} \times d_k}$$
$$K_i = X W_K^{(i)}, \qquad W_K^{(i)} \in \mathbb{R}^{d_{\text{model}} \times d_k}$$
$$V_i = X W_V^{(i)}, \qquad W_V^{(i)} \in \mathbb{R}^{d_{\text{model}} \times d_k}$$
$$\text{head}_i = \text{Attention}(Q_i, K_i, V_i) = \text{softmax}\left(\frac{Q_i K_i^{\top}}{\sqrt{d_k}}\right) V_i$$
Each head produces an output of shape (seq_len × d_k).
Combine all heads:
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \text{head}_2, \ldots, \text{head}_h)\, W_O$$
Where:
Concat: concatenates along the feature dimension
Combined shape: (seq_len × d_model), since h × d_k = d_model
W_O: (d_model × d_model) output projection
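The formula can be transcribed almost literally into a plain-Python sketch with separate per-head matrices, random weights, and tiny hypothetical sizes (seq_len = 3, d_model = 8, h = 2). It is meant only to show how the shapes line up, not to be efficient.

```python
import math
import random

def matmul(A, B):
    """(n x m) @ (m x p) -> (n x p) for nested lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(M):
    return [list(col) for col in zip(*M)]

def softmax_rows(M):
    """Row-wise softmax: each row becomes a probability distribution."""
    out = []
    for row in M:
        m = max(row)
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

def rand_matrix(rows, cols, rng):
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]

def multi_head(X, W_Q, W_K, W_V, W_O):
    """W_Q, W_K, W_V: lists of per-head (d_model x d_k) matrices; W_O: (h*d_k x d_model)."""
    heads = []
    for Wq, Wk, Wv in zip(W_Q, W_K, W_V):
        Q, K, V = matmul(X, Wq), matmul(X, Wk), matmul(X, Wv)
        d_k = len(Wq[0])
        scores = [[s / math.sqrt(d_k) for s in row] for row in matmul(Q, transpose(K))]
        heads.append(matmul(softmax_rows(scores), V))  # head_i: (seq_len x d_k)
    # Concat along the feature dimension: (seq_len x h*d_k) = (seq_len x d_model)
    concat = [sum((head[t] for head in heads), []) for t in range(len(X))]
    return matmul(concat, W_O)

rng = random.Random(0)
seq_len, d_model, h = 3, 8, 2
d_k = d_model // h
X = rand_matrix(seq_len, d_model, rng)
W_Q = [rand_matrix(d_model, d_k, rng) for _ in range(h)]
W_K = [rand_matrix(d_model, d_k, rng) for _ in range(h)]
W_V = [rand_matrix(d_model, d_k, rng) for _ in range(h)]
W_O = rand_matrix(d_model, d_model, rng)

out = multi_head(X, W_Q, W_K, W_V, W_O)
print(len(out), len(out[0]))  # 3 8, i.e. (seq_len, d_model)
```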
Why the output projection W_O?
After concatenation, we have h independent representations side-by-side. W_O lets the model learn how to mix information across heads — allowing heads to interact and combine their insights.
Let’s look at a PyTorch implementation of the multi-head attention mechanism:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Linear projections for Q, K, V (all heads at once)
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)

        # Output projection
        self.W_O = nn.Linear(d_model, d_model)

    def split_heads(self, x):
        """Split the last dimension into (num_heads, d_k)."""
        batch_size, seq_len, d_model = x.size()
        # Output: (batch, num_heads, seq_len, d_k)
        return x.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, X):
        batch_size = X.size(0)

        # 1. Linear projections
        Q = self.W_Q(X)  # (batch, seq_len, d_model)
        K = self.W_K(X)
        V = self.W_V(X)

        # 2. Split into multiple heads
        Q = self.split_heads(Q)  # (batch, num_heads, seq_len, d_k)
        K = self.split_heads(K)
        V = self.split_heads(V)

        # 3. Scaled dot-product attention (for all heads in parallel)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        attention_weights = F.softmax(scores, dim=-1)
        attention_output = torch.matmul(attention_weights, V)
        # Output: (batch, num_heads, seq_len, d_k)

        # 4. Concatenate heads
        attention_output = attention_output.transpose(1, 2).contiguous()
        # Shape: (batch, seq_len, num_heads, d_k)
        attention_output = attention_output.view(batch_size, -1, self.d_model)
        # Shape: (batch, seq_len, d_model)

        # 5. Final linear projection
        output = self.W_O(attention_output)
        return output, attention_weights
```
The Multi-Head Illusion: Where Does Separation Actually Happen?
Looking at the implementation, you might notice something puzzling:
```python
Q = self.W_Q(X)          # Single projection matrix
Q = self.split_heads(Q)  # Just a reshape!
```
We have one large W_Q matrix, and we’re just reshaping its output. There’s no separate matrix per head. So where does the “multi” in multi-head come from?
What the reshape actually does:
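```python
# Before: (batch, seq_len, d_model)
# After:  (batch, num_heads, seq_len, d_k)
```
We’re not creating new data; we’re reorganizing the same numbers. With d_k = 64 (for example, d_model = 512 split across 8 heads), the weights in columns 0–63 of W_Q become “head 1”, columns 64–127 become “head 2”, and so on. (This uses the math convention Q = X W_Q; nn.Linear stores the transposed matrix in its .weight, so in code these are rows 0–63, 64–127, and so on.) It’s the same matrix, sliced differently.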
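The “same matrix, sliced differently” claim can be checked numerically. This plain-Python sketch uses tiny made-up sizes (d_model = 4, h = 2, d_k = 2) and the math convention Q = X W_Q: taking head i’s slice of the full projection gives exactly what projecting with head i’s own column block of W_Q would give.

```python
def matmul(A, B):
    """Plain-Python matrix product: (n x m) @ (m x p) -> (n x p)."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def column_slice(M, start, stop):
    """Return columns [start, stop) of matrix M."""
    return [row[start:stop] for row in M]

# Hypothetical toy inputs: seq_len = 2, d_model = 4
X = [[1.0, 2.0, 0.5, -1.0],
     [0.0, 1.0, 3.0, 2.0]]
W_Q = [[0.1, 0.2, 0.3, 0.4],
       [0.5, 0.6, 0.7, 0.8],
       [0.9, 1.0, 1.1, 1.2],
       [1.3, 1.4, 1.5, 1.6]]  # one big (d_model x d_model) matrix

h, d_k = 2, 2
Q = matmul(X, W_Q)  # one projection for all heads: (seq_len, d_model)

matches = []
for i in range(h):
    # What the reshape hands to head i: a contiguous block of d_k features
    q_from_reshape = column_slice(Q, i * d_k, (i + 1) * d_k)
    # What a separate per-head matrix (head i's column block of W_Q) would produce
    q_from_own_matrix = matmul(X, column_slice(W_Q, i * d_k, (i + 1) * d_k))
    matches.append(q_from_reshape == q_from_own_matrix)

print(matches)  # [True, True]
```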
So how do heads learn different patterns?
They don’t start out specialized; gradient descent makes them different.
Random initialization gives each head slightly different starting weights. During training, different parts of the output affect the loss in different ways, so gradients flow back differently to different columns of W_Q. Over thousands of updates, columns 0–63 might learn to extract syntactic features, while columns 64–127 learn semantic features. The specialization emerges from training; it’s not designed in.
But wait — if it’s just emergent learning, why bother reshaping at all?
This is the key insight. The reshape enables something critical: separate attention patterns.
Without reshaping, we’d compute one attention distribution:
```
attention = softmax(Q @ K^T)  # Shape: (seq_len × seq_len)
```
One softmax. One probability distribution over tokens. Each token allocates its 100% attention once.
With reshaping, we compute h independent attention distributions:
```
attention = softmax(Q @ K^T)  # Shape: (num_heads, seq_len, seq_len)
```
Each head gets its own softmax, its own probability distribution. Head 1 can put 80% attention on “river” while head 2 puts 80% attention on “loan”, simultaneously.
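Here is a toy plain-Python sketch of that effect. The query and key vectors are made-up, already-projected numbers for “bank” and three other tokens (d_model = 4 split into h = 2 heads). Each head takes its own d_k-sized slice, scales by √d_k, and runs its own softmax.

```python
import math

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

tokens = ["river", "approved", "loan"]
q = [2.0, 0.0, 0.0, 2.0]  # hypothetical query for "bank"
keys = {"river": [2.0, 0.0, 0.0, 0.0],
        "approved": [0.0, 0.0, 0.0, 0.5],
        "loan": [0.0, 0.0, 0.0, 2.0]}

d_model, h = 4, 2
d_k = d_model // h

# No split: a single softmax over full-width dot products ("river" and "loan" tie)
single = softmax([dot(q, keys[t]) / math.sqrt(d_model) for t in tokens])

# Split: each head uses only its own d_k features and gets its own distribution
per_head = []
for i in range(h):
    sl = slice(i * d_k, (i + 1) * d_k)
    per_head.append(softmax([dot(q[sl], keys[t][sl]) / math.sqrt(d_k) for t in tokens]))

print("single:", [round(w, 2) for w in single])
for i, p in enumerate(per_head):
    print(f"head {i + 1}:", [round(w, 2) for w in p])
```

Head 1 concentrates on “river” and head 2 on “loan”, while the single distribution has to split its weight between them.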
The real separation: Operations on the last two dimensions
The crucial operations that create independent heads happen on the (seq_len × seq_len) matrices:
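```python
import torch
import torch.nn.functional as F
seq_len, d_k = 5, 64
Q, K, V = (torch.randn(seq_len, d_k) for _ in range(3))  # one head's Q, K, V
# For each head independently:
scores = Q @ K.transpose(-2, -1) / d_k ** 0.5  # (seq_len × d_k) @ (d_k × seq_len) → (seq_len × seq_len)
attention = F.softmax(scores, dim=-1)          # applied to (seq_len × seq_len)
output = attention @ V                         # (seq_len × seq_len) @ (seq_len × d_k) → (seq_len × d_k)
```
At the GPU level: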
When the scores have shape (batch, num_heads, seq_len, seq_len), frameworks treat batch and num_heads as batch dimensions and run the per-head work in parallel:
8 heads = 8 attention matrices (per sequence) computed simultaneously
Each GPU thread block can handle a different head (or a tile of one)
The softmax runs independently per head; it is memory-bound and, in naive implementations, can take a noticeable share of the runtime
A single batched matrix-multiply kernel covers all heads at once
Why this matters for performance:
```
# Single head: one large attention computation
(seq_len × d_model) @ (d_model × seq_len) → (seq_len × seq_len)
Cost: seq_len² × d_model multiply-adds

# Multi-head: h smaller computations in parallel
h × [(seq_len × d_k) @ (d_k × seq_len) → (seq_len × seq_len)]
Cost per head: seq_len² × d_k multiply-adds
Total: seq_len² × (h × d_k) = seq_len² × d_model (same!)

But: each head's (seq_len × seq_len) matrix is independent
→ All heads can be computed in parallel as one batched operation
→ Good hardware utilization with no extra matmul cost
```
The d_k dimension split isn’t just conceptual: it creates h independent (seq_len × seq_len) attention matrices that GPUs can process simultaneously. Each head’s attention pattern is computed, normalized, and applied completely independently. The one part that does grow with h is the softmax, which now runs over h × seq_len² scores instead of seq_len².
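The cost comparison is simple arithmetic, sketched below in plain Python for the Q @ Kᵀ step. The sizes seq_len = 1024 and d_model = 512 are arbitrary illustrative values. The sketch also counts how many attention scores (softmax inputs) each variant produces.

```python
def qk_cost(seq_len, d_model, num_heads):
    """Multiply-adds for Q @ K^T across all heads, plus the number of scores produced."""
    d_k = d_model // num_heads
    macs = num_heads * seq_len * seq_len * d_k  # h independent (seq_len x seq_len) products
    scores = num_heads * seq_len * seq_len      # entries the softmax must normalize
    return macs, scores

seq_len, d_model = 1024, 512
for h in (1, 8):
    macs, scores = qk_cost(seq_len, d_model, h)
    print(f"h={h}: multiply-adds={macs:,}  attention scores={scores:,}")
```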
The weight matrix is one large pool of learnable parameters. The reshape splits the attention computation into parallel tracks at the (seq_len × seq_len) level — where the actual attention patterns form. Each head computes its own independent attention matrix, and GPUs process these matrices in parallel. Specialization emerges because (1) different gradient signals reach different parts of the weights, and (2) each head computes its own independent attention pattern over the last two dimensions.
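We can update our multi-head attention code to use PyTorch’s optimized F.scaled_dot_product_attention (available since PyTorch 2.0). It handles the scaling, softmax, and matrix multiplications for us, and it dispatches to fused kernels such as FlashAttention when the inputs and hardware support them, which is typically faster and more memory-efficient, especially for long sequences. Note that it returns only the attention output, not the attention weights, so this version of forward returns just the output. Look at step 3 in the code below: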
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Linear projections for Q, K, V
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)

        # Output projection
        self.W_O = nn.Linear(d_model, d_model)

    def forward(self, X):
        batch_size, seq_len, _ = X.size()

        # 1. Linear projections
        Q = self.W_Q(X)  # (batch, seq_len, d_model)
        K = self.W_K(X)
        V = self.W_V(X)

        # 2. Reshape and transpose for multi-head attention
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        # Shape: (batch, num_heads, seq_len, d_k)

        # 3. Apply scaled dot-product attention (optimized by PyTorch)
        attention_output = F.scaled_dot_product_attention(Q, K, V)
        # Shape: (batch, num_heads, seq_len, d_k)

        # 4. Concatenate heads
        attention_output = attention_output.transpose(1, 2).contiguous()
        attention_output = attention_output.view(batch_size, seq_len, self.d_model)

        # 5. Final projection
        output = self.W_O(attention_output)
        return output
```
Causal Self-Attention: Looking Only Backward
The problem with standard self-attention:
In our implementation so far, every token can attend to all other tokens, including future ones. This is fine for tasks like:
Sentence classification (BERT-style)
Translation (encoder side)
But it breaks autoregressive generation (like GPT), where the model predicts one token at a time from the tokens before it. During training, we compute the predictions for every position in parallel, so we can’t let token i “peek” at tokens i+1, i+2, etc. That would be cheating!
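To see why every position needs protecting, here is a plain-Python sketch of the usual next-token training setup (teacher forcing) on the sentence “The cat sat on the mat”, split on whitespace for simplicity. All positions are trained at once, and position i has to predict its target using only tokens 0 through i.

```python
# Hypothetical example sentence, tokenized by whitespace for simplicity
sentence = "The cat sat on the mat".split()

# Teacher forcing: inputs are the sentence, targets are the same sentence shifted by one
inputs, targets = sentence[:-1], sentence[1:]

pairs = []
for i, target in enumerate(targets):
    visible = inputs[: i + 1]  # a causal model at position i may only use these tokens
    pairs.append((visible, target))
    print(" ".join(visible), "->", target)
```

Without a mask, position 2 (“sat”) could simply read its target “on” from position 3 of the same input.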
Example:
Sentence: “The cat sat on the”
Predicting: “mat”
When predicting the token that follows “sat”, the model shouldn’t see “on” or “the” during training
Otherwise the model learns to cheat, not to predict
The solution: Masking
We prevent tokens from attending to future positions by setting their attention scores to -∞ before softmax:
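```python
scores = Q @ K.transpose(-2, -1) / d_k ** 0.5
# Apply causal mask: positions where mask == 0 (the future) get -inf
scores = scores.masked_fill(mask == 0, float("-inf"))
attention_weights = F.softmax(scores, dim=-1)
```
Since exp(-∞) = 0, the softmax assigns those positions a weight of exactly 0.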
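A plain-Python sketch of what happens to one row of scores (made-up numbers for token 1 of a 4-token sequence): the masked positions get weight exactly 0, and the remaining weights are simply the softmax over the visible positions. In this sketch, a row that was entirely -∞ would make the max -∞ and produce NaN, which is one reason a causal mask always keeps a token’s own position visible.

```python
import math

def softmax(scores):
    m = max(scores)                             # finite as long as one score is unmasked
    exps = [math.exp(s - m) for s in scores]    # math.exp(-inf) == 0.0 exactly
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical raw scores for token 1 in a 4-token sequence (positions 0..3)
scores = [0.3, 1.2, 2.5, -0.7]
mask = [1, 1, 0, 0]  # token 1 may see positions 0 and 1 only
masked = [s if keep else float("-inf") for s, keep in zip(scores, mask)]

weights = softmax(masked)
print(weights)  # last two entries are 0.0; the first two sum to 1
```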
What does the mask look like?
For a sequence of length 4:
```
[[1, 0, 0, 0],   # Token 0 can only see itself
 [1, 1, 0, 0],   # Token 1 can see tokens 0, 1
 [1, 1, 1, 0],   # Token 2 can see tokens 0, 1, 2
 [1, 1, 1, 1]]   # Token 3 can see all tokens
```
This is a lower triangular matrix: each token sees itself and everything before it.
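For a mask-to-weights view in plain Python, this sketch builds the same lower-triangular pattern (row i allows columns 0 through i, which is what torch.tril produces from a matrix of ones) and applies it to a hypothetical 4 × 4 score matrix in which every raw score is equal.

```python
import math

def causal_mask(n):
    """Lower-triangular 0/1 mask as nested lists: row i allows columns 0..i."""
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]

def masked_softmax_rows(scores, mask):
    """Set disallowed positions to -inf, then softmax each row."""
    out = []
    for row, mrow in zip(scores, mask):
        vals = [s if keep else float("-inf") for s, keep in zip(row, mrow)]
        m = max(vals)
        exps = [math.exp(v - m) for v in vals]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out

n = 4
equal_scores = [[0.0] * n for _ in range(n)]  # hypothetical: every raw score is equal
rows = masked_softmax_rows(equal_scores, causal_mask(n))
for row in rows:
    print([round(w, 2) for w in row])
# [1.0, 0.0, 0.0, 0.0]
# [0.5, 0.5, 0.0, 0.0]
# [0.33, 0.33, 0.33, 0.0]
# [0.25, 0.25, 0.25, 0.25]
```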
Causal Attention Implementation
Creating the causal mask:
```python
import torch


def create_causal_mask(seq_len):
    """Creates a lower triangular mask for causal attention."""
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask  # Shape: (seq_len, seq_len)


# Example:
mask = create_causal_mask(4)
print(mask)
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])
```
Manual implementation with masking:
```python
import torch
import torch.nn.functional as F


def causal_self_attention(Q, K, V):
    """
    Q, K, V: (batch, num_heads, seq_len, d_k)
    """
    d_k = Q.size(-1)
    seq_len = Q.size(-2)

    # Compute attention scores
    scores = (Q @ K.transpose(-2, -1)) / (d_k ** 0.5)
    # Shape: (batch, num_heads, seq_len, seq_len)

    # Create and apply causal mask
    mask = torch.tril(torch.ones(seq_len, seq_len, device=Q.device))
    scores = scores.masked_fill(mask == 0, float("-inf"))

    # Softmax and apply to values
    attention_weights = F.softmax(scores, dim=-1)
    output = attention_weights @ V
    return output, attention_weights
```
Using PyTorch’s built-in (simpler):
```python
# Just add is_causal=True!
attention_output = F.scaled_dot_product_attention(
    Q, K, V,
    is_causal=True,  # That's it!
)
```
PyTorch handles the masking automatically and uses optimized implementations.
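As a sanity check, this sketch (assuming PyTorch 2.0 or later, where F.scaled_dot_product_attention is available) compares the manual masking recipe with is_causal=True on random tensors. It also confirms that changing the last position’s value vector leaves every earlier position’s output unchanged. The tensor sizes are arbitrary small choices.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, num_heads, seq_len, d_k = 2, 4, 6, 8
Q, K, V = (torch.randn(batch, num_heads, seq_len, d_k) for _ in range(3))

# Manual causal attention: scale, mask the future with -inf, softmax, weight the values
scores = Q @ K.transpose(-2, -1) / d_k ** 0.5
mask = torch.tril(torch.ones(seq_len, seq_len))
scores = scores.masked_fill(mask == 0, float("-inf"))
manual = F.softmax(scores, dim=-1) @ V

# PyTorch's built-in version with the causal mask applied internally
fused = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
same_as_manual = torch.allclose(manual, fused, atol=1e-5)

# Perturb only the last ("future") value vector; earlier outputs must not change
V2 = V.clone()
V2[:, :, -1] += 100.0
fused2 = F.scaled_dot_product_attention(Q, K, V2, is_causal=True)
past_unchanged = torch.allclose(fused[:, :, :-1], fused2[:, :, :-1], atol=1e-5)

print(same_as_manual, past_unchanged)
```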
What the attention weights look like (for illustration, assuming every raw score is equal):
```
# Without causal mask:
[[0.25, 0.25, 0.25, 0.25],
 [0.25, 0.25, 0.25, 0.25],
 [0.25, 0.25, 0.25, 0.25],
 [0.25, 0.25, 0.25, 0.25]]

# With causal mask:
[[1.00, 0.00, 0.00, 0.00],   # Token 0: 100% on itself
 [0.50, 0.50, 0.00, 0.00],   # Token 1: split between 0, 1
 [0.33, 0.33, 0.33, 0.00],   # Token 2: split among 0, 1, 2
 [0.25, 0.25, 0.25, 0.25]]   # Token 3: split among itself and all previous tokens
```
With that, we’ve completed our exploration of causal scaled dot-product self-attention.
In the next post in this series, we’ll move on to innovations that build on this foundation, such as Grouped Query Attention (GQA) and Multi-Head Latent Attention (MLA).
Stay tuned! And click like/subscribe if you liked the post.


