Attention from First Principles - 3
Grouped Query Attention (GQA) and Multi-Head Latent Attention (MLA)
In Part 1 and Part 2 of the series Attention from First Principles, we covered the complete attention mechanism. If you have not read them yet, the previous articles are highly recommended.
In this part we explore the improvements made over the last few years. We will cover the following:
Computational Complexity of Self-Attention
LLM Inference and the KV Cache Problem
Grouped Query Attention (GQA): Reducing Memory Overhead
Multi-Head Latent Attention (MLA): Core Concept
Conclusion & References
To recap: in Part 1 of this series we arrived at the basic self-attention mechanism, which has the following steps:
For token i:
1. Compute query: Qᵢ = WQ Xᵢ
2. Compute keys for all tokens: Kⱼ = WK Xⱼ (for all j)
3. Compute values for all tokens: Vⱼ = WV Xⱼ (for all j)
4. Calculate scaled attention scores: scoreᵢ,ⱼ = (Qᵢ · Kⱼ) / √dₖ
5. Normalize with softmax: wᵢ,ⱼ = softmax(scoreᵢ,ⱼ)
6. Mix values: X’ᵢ = Σⱼ wᵢ,ⱼ Vⱼ
In the next section we take a quick detour to look at the computational complexity of the self-attention mechanism.
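As a reference point for the sections that follow, the six recap steps can be sketched in plain Python for a single token. This is a minimal illustration, not the code from Part 1: the helper names (`matvec`, `softmax`, `attend`) and the toy inputs are assumptions made here, and there is no causal mask or batching.

```python
import math

def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(i, X, W_Q, W_K, W_V):
    """Return the updated representation X'_i for token i (no causal mask)."""
    q_i = matvec(W_Q, X[i])                      # step 1: query for token i
    K = [matvec(W_K, x_j) for x_j in X]          # step 2: keys for all tokens
    V = [matvec(W_V, x_j) for x_j in X]          # step 3: values for all tokens
    d_k = len(q_i)
    scores = [sum(a * b for a, b in zip(q_i, k_j)) / math.sqrt(d_k)
              for k_j in K]                      # step 4: scaled dot products
    w = softmax(scores)                          # step 5: normalize
    return [sum(w_j * v_j[c] for w_j, v_j in zip(w, V))
            for c in range(len(V[0]))]           # step 6: mix values

# Toy input: 3 tokens, 2 dimensions, identity projections (illustrative values only).
I2 = [[1.0, 0.0], [0.0, 1.0]]
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out = attend(0, X, I2, I2, I2)
print(out)
```

With identity projections, token 0 attends most strongly to the tokens that share its first coordinate, so the first component of the output ends up larger than the second.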
Computational Complexity of Self-Attention
Breaking down each operation:
For a sequence of length n with embedding dimension d:
1. Computing Q, K, V projections: O(n · d²)
Three matrix multiplications: (n × d) @ (d × d) for each
2. Computing attention scores Q @ K^T: O(n² · d)
Q: (n × d), K^T: (d × n)
Result: (n × n) attention score matrix
Each of the n² entries is a dot product over d dimensions (d multiply-adds)
3. Softmax normalization: O(n²)
Applied to the (n × n) score matrix
4. Applying attention weights to values: O(n² · d)
attention_weights: (n × n), V: (n × d)
Result: (n × d)
This is NOT just scalar multiplication — it’s matrix multiplication
Each output position (n total) is a weighted sum over n values, each of dimension d
Wait, why is step 4 O(n² · d)?
When we compute attention_weights @ V:
For each of n output tokens
We compute a weighted sum of n value vectors
Each value vector has dimension d
Total: n × n × d operations
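To make the count concrete, here is a small sketch that performs attention_weights @ V with explicit loops and counts every multiply-add. The function name and the uniform toy weights are assumptions for illustration; real implementations use fused matrix kernels rather than Python loops.

```python
def weighted_sum_of_values(weights, V):
    """Compute weights @ V with explicit loops and count multiply-adds.

    weights: n x n list of lists, V: n x d list of lists.
    Returns (output, op_count) where output is n x d.
    """
    n, d = len(V), len(V[0])
    out = [[0.0] * d for _ in range(n)]
    ops = 0
    for i in range(n):            # each output token
        for j in range(n):        # each value vector being mixed in
            for c in range(d):    # each dimension of that value vector
                out[i][c] += weights[i][j] * V[j][c]
                ops += 1
    return out, ops

# Uniform attention over 4 tokens with 64-dimensional values (toy numbers).
n, d = 4, 64
weights = [[1.0 / n] * n for _ in range(n)]
V = [[float(j)] * d for j in range(n)]
out, ops = weighted_sum_of_values(weights, V)
print(ops)  # 1024
```

The three nested loops are exactly the n × n × d factor: no loop can be dropped without losing either an output token, a value vector, or a dimension.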
Total complexity: O(n² · d + n · d²)
For typical transformers where n (sequence length) >> d (often d = 64 per head):
The n² · d term dominates
Practical bottleneck: O(n² · d)
Numerical Example: Operation Counts
Let’s trace through a small example: n = 4 tokens, d = 64 dimensions
Step 1: Q, K, V projections
X @ W_Q: (4 × 64) @ (64 × 64) = 4 × 64 × 64 = 16,384 operations
Same for K and V
Total: 3 × 16,384 = 49,152 operations
Complexity: O(n · d²)
Step 2: Q @ K^T (attention scores)
Q @ K^T: (4 × 64) @ (64 × 4) = 4 × 4 × 64 = 1,024 operations
Result: (4 × 4) score matrix
Complexity: O(n² · d)
Step 3: Softmax
Applied to (4 × 4) = 16 values
Complexity: O(n²)
Step 4: attention_weights @ V
(4 × 4) @ (4 × 64) = 4 × 4 × 64 = 1,024 operations
This is NOT just 16 scalar multiplications!
Each output element requires summing 4 weighted values of dimension 64
Complexity: O(n² · d)
Now scale to realistic sizes: n = 2048, d = 64
Step 1 (projections): 2048 × 64² × 3 ≈ 25M operations
Step 2 (Q @ K^T): 2048² × 64 ≈ 268M operations ← dominates!
Step 4 (weights @ V): 2048² × 64 ≈ 268M operations ← dominates!
The n² terms explode as sequence length grows!
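The arithmetic above is easy to check with a few lines of Python. This is a rough sketch that counts one operation per multiply-add and ignores constants such as the softmax exponentials. The function name is made up for this example.

```python
def attention_op_counts(n, d):
    """Multiply-add counts per step for one head, sequence length n, dimension d."""
    return {
        "projections": 3 * n * d * d,   # Q, K, V: three (n x d) @ (d x d) products
        "scores": n * n * d,            # Q @ K^T
        "softmax": n * n,               # one entry per score
        "mix_values": n * n * d,        # attention_weights @ V
    }

small = attention_op_counts(4, 64)
large = attention_op_counts(2048, 64)
print(small)
print(large)
```

Note that for the tiny n = 4 case the projections are actually the largest term. Under this counting, the two quadratic steps together overtake the projections once n exceeds about 1.5 × d, which is why they dominate at n = 2048.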
Mental Model:
The (n × n) attention matrix isn’t just computed — it’s also used to mix n value vectors. Both operations scale quadratically with sequence length.
Next we’ll move on to the innovations that build on this foundation, such as Grouped Query Attention (GQA) and Multi-Head Latent Attention (MLA).
But before diving into those advancements, we first need to understand LLM inference and the role of the KV cache.
LLM Inference and the KV Cache Problem
How autoregressive generation works:
When generating text (like GPT), we produce one token at a time:
Input: “The cat sat”
Step 1: Generate “on” → “The cat sat on”
Step 2: Generate “the” → “The cat sat on the”
Step 3: Generate “mat” → “The cat sat on the mat”
At each step, we run the full transformer forward pass for the entire sequence so far.
The naive approach (wasteful):
Step 1: Compute attention for [“The”, “cat”, “sat”, “on”]
Step 2: Compute attention for [“The”, “cat”, “sat”, “on”, “the”]
Step 3: Compute attention for [“The”, “cat”, “sat”, “on”, “the”, “mat”]
Notice the problem? We’re recomputing K and V for “The”, “cat”, “sat” at every step, even though they never change!
Remember our earlier insight about Q, K, V asymmetry:
Q changes (new token asking questions)
K and V for past tokens stay the same (they already offered their context)
This redundancy becomes expensive for long sequences.
Let’s visualize this with concrete matrices.
Step 1: Generate token 4 (sequence length = 4)
Tokens: [1, 2, 3, 4]
Q₄ = [q₁, q₂, q₃, q₄] # Queries for all 4 tokens
K₄ = [k₁, k₂, k₃, k₄] # Keys for all 4 tokens
V₄ = [v₁, v₂, v₃, v₄] # Values for all 4 tokens
Attention scores (causal):
      k₁   k₂   k₃   k₄
q₁ [  •   -∞   -∞   -∞ ]
q₂ [  •    •   -∞   -∞ ]
q₃ [  •    •    •   -∞ ]
q₄ [  •    •    •    • ]
Step 2: Generate token 5 (sequence length = 5)
Tokens: [1, 2, 3, 4, 5]
Q₅ = [q₁, q₂, q₃, q₄, q₅] # Recompute ALL queries
K₅ = [k₁, k₂, k₃, k₄, k₅] # Recompute k₁, k₂, k₃, k₄ AGAIN!
V₅ = [v₁, v₂, v₃, v₄, v₅] # Recompute v₁, v₂, v₃, v₄ AGAIN!
Attention scores (causal):
      k₁   k₂   k₃   k₄   k₅
q₁ [  •   -∞   -∞   -∞   -∞ ]
q₂ [  •    •   -∞   -∞   -∞ ]
q₃ [  •    •    •   -∞   -∞ ]
q₄ [  •    •    •    •   -∞ ]
q₅ [  •    •    •    •    • ]
The waste: k₁, k₂, k₃, k₄ and v₁, v₂, v₃, v₄ are identical to Step 1! We computed them twice.
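To put a number on that waste, the sketch below counts how many key/value projections the naive loop computes versus how many distinct tokens actually need one. The function name and its parameters are hypothetical, chosen only for this illustration.

```python
def kv_projection_counts(first_len, num_steps):
    """Key/value projections computed by the naive loop vs. the number actually needed.

    first_len: sequence length at the first step; each later step adds one token.
    """
    computed = sum(first_len + step for step in range(num_steps))  # whole sequence every step
    needed = first_len + num_steps - 1                             # one per distinct token
    return computed, needed

print(kv_projection_counts(first_len=4, num_steps=2))     # (9, 5)
print(kv_projection_counts(first_len=1, num_steps=1000))  # (500500, 1000)
```

For the two steps traced above, the naive loop computes 9 projections where 5 would do. Over a thousand steps the gap grows to roughly 500×.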
Enter the KV cache. Here’s how it solves this:
With KV Cache:
Step 1: Generate token 4
Compute: Q₄, K₄, V₄
Cache: Store K₄ and V₄ in memory
Attention for token 4:
q₄ · [k₁, k₂, k₃, k₄] # Use all keys
Step 2: Generate token 5
Compute: Only the new q₅, k₅, v₅
Retrieve from cache: [k₁, k₂, k₃, k₄] and [v₁, v₂, v₃, v₄]
Append: k₅ to cached keys, v₅ to cached values
Cache now: K₅ = [k₁, k₂, k₃, k₄, k₅]
           V₅ = [v₁, v₂, v₃, v₄, v₅]
Attention for token 5:
q₅ · [k₁, k₂, k₃, k₄, k₅] # Reuse cached + new
Step 3: Generate token 6
Compute: Only the new q₆, k₆, v₆
Retrieve: [k₁, k₂, k₃, k₄, k₅] and [v₁, v₂, v₃, v₄, v₅] from cache
Append: k₆ to cached keys, v₆ to cached values
Attention for token 6:
q₆ · [k₁, k₂, k₃, k₄, k₅, k₆] # All from cache except k₆
The savings:
Without cache: Recompute keys and values for all n tokens at every step, so total work across a generation grows quadratically
With cache: Compute only 1 new key and value per step (linear growth)
Memory cost: For each layer and each head, we store (seq_len × d_k) for keys + (seq_len × d_k) for values
This grows with sequence length, which becomes the bottleneck for long contexts.
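The caching steps can be sketched as a tiny single-head cache in plain Python. Everything here is illustrative: the `KVCache` class, its `step` method and the toy weights are assumptions for this example, not a production design (no batching, no layers, no multi-head split). The point is that the cached path gives the same attention output for the latest token as recomputing every key and value from scratch.

```python
import math

def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def attend_last(q, keys, values):
    """Attention output for one query over all keys/values seen so far."""
    d_k = len(q)
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [sum(e / z * v[c] for e, v in zip(exps, values))
            for c in range(len(values[0]))]

class KVCache:
    """Hypothetical single-head cache: one key and one value per past token."""
    def __init__(self):
        self.keys, self.values = [], []
        self.projections = 0                  # how many K/V projections were computed

    def step(self, x, W_Q, W_K, W_V):
        self.keys.append(matvec(W_K, x))      # only the NEW key
        self.values.append(matvec(W_V, x))    # only the NEW value
        self.projections += 1
        return attend_last(matvec(W_Q, x), self.keys, self.values)

# Toy weights and token embeddings (illustrative values only).
W_Q = [[1.0, 0.0], [0.0, 1.0]]
W_K = [[0.5, 0.0], [0.0, 2.0]]
W_V = [[1.0, 1.0], [0.0, 1.0]]
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 1.0]]

cache = KVCache()
for x in tokens:
    cached_out = cache.step(x, W_Q, W_K, W_V)

# Reference: recompute every key and value from scratch for the last token.
full_keys = [matvec(W_K, x) for x in tokens]
full_values = [matvec(W_V, x) for x in tokens]
full_out = attend_last(matvec(W_Q, tokens[-1]), full_keys, full_values)
```

After four tokens the cache has performed four key/value projections in total, one per token, and `cached_out` matches `full_out`.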
Grouped Query Attention (GQA): Reducing Memory Overhead
The KV cache bottleneck:
In multi-head attention with h heads, we cache:
K: (batch, num_heads, seq_len, d_k)
V: (batch, num_heads, seq_len, d_k)
For long sequences (say 100K tokens), this becomes massive. Example:
32 heads × 100K tokens × 128 dimensions × 2 bytes (fp16) ≈ 0.8 GB per layer for K alone, so ≈ 1.6 GB per layer for K and V together
With 80 layers ≈ 131 GB just for the KV cache!
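A back-of-the-envelope calculator makes it easy to try other configurations. This is a sketch under simple assumptions: the function name and its parameters are made up here, batch size is 1, and nothing beyond the raw K and V tensors is counted. A single tensor (keys alone) for the example above is about 0.8 GB per layer. Counting keys and values together doubles that.

```python
def kv_cache_bytes(num_layers, num_kv_heads, seq_len, d_k, bytes_per_value=2, batch=1):
    """Bytes needed to cache keys AND values for every layer."""
    per_tensor = batch * num_kv_heads * seq_len * d_k * bytes_per_value  # K alone (or V alone)
    return 2 * per_tensor * num_layers                                   # K and V

per_layer = kv_cache_bytes(num_layers=1, num_kv_heads=32, seq_len=100_000, d_k=128)
total = kv_cache_bytes(num_layers=80, num_kv_heads=32, seq_len=100_000, d_k=128)
print(per_layer / 1e9)  # 1.6384 (GB per layer, K and V together)
print(total / 1e9)      # 131.072 (GB across 80 layers)
```

The `num_kv_heads` parameter is the one that the techniques below attack: halving it halves the cache.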
The key observation:
Do we really need separate K and V for every head? What if multiple heads could share the same K and V?
Multi-Query Attention (MQA) — The extreme:
All heads share ONE set of K and V
Each head still has its own Q
Memory: h times smaller!
Problem with MQA: Quality degradation — too much sharing hurts model performance.
Grouped Query Attention (GQA) — The sweet spot:
Divide h heads into g groups
Heads within a group share K and V
Each head still has its own Q
Example: 8 heads, 2 groups
Group 1: Heads 1–4 share K₁, V₁
Group 2: Heads 5–8 share K₂, V₂
Memory: 4× reduction (2 KV pairs instead of 8)
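The grouping rule is just integer division. The sketch below (0-based indices, hypothetical function name) maps each query head to the K/V group it reads from. Setting the number of groups to the number of heads gives standard multi-head attention, and a single group gives MQA.

```python
def kv_group_for_head(head, num_heads, num_kv_groups):
    """Index of the shared K/V group used by a given query head (0-based)."""
    assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"
    heads_per_group = num_heads // num_kv_groups
    return head // heads_per_group

mapping = [kv_group_for_head(h, num_heads=8, num_kv_groups=2) for h in range(8)]
print(mapping)  # [0, 0, 0, 0, 1, 1, 1, 1]
```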
Architecture Comparison
Multi-Head Attention (MHA):
Head 1: Q₁, K₁, V₁
Head 2: Q₂, K₂, V₂
Head 3: Q₃, K₃, V₃
Head 4: Q₄, K₄, V₄
Head 5: Q₅, K₅, V₅
Head 6: Q₆, K₆, V₆
Head 7: Q₇, K₇, V₇
Head 8: Q₈, K₈, V₈
KV cache: 8 × (K, V) pairs
Multi-Query Attention (MQA):
Head 1: Q₁ ┐
Head 2: Q₂ │
Head 3: Q₃ ├─→ Shared: K, V
Head 4: Q₄ │
Head 5: Q₅ │
Head 6: Q₆ │
Head 7: Q₇ │
Head 8: Q₈ ┘
KV cache: 1 × (K, V) pair
Memory: 8× smaller
Grouped Query Attention (GQA), 2 groups:
Group 1:
Head 1: Q₁ ┐
Head 2: Q₂ ├─→ Shared: K₁, V₁
Head 3: Q₃ │
Head 4: Q₄ ┘
Group 2:
Head 5: Q₅ ┐
Head 6: Q₆ ├─→ Shared: K₂, V₂
Head 7: Q₇ │
Head 8: Q₈ ┘
KV cache: 2 × (K, V) pairs
Memory: 4× smaller
Quality: Better than MQA
The following is our PyTorch implementation, updated for Grouped Query Attention (GQA):
import torch
import torch.nn as nn
import torch.nn.functional as F


class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads, num_kv_groups):
        super().__init__()
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_kv_groups = num_kv_groups
        self.d_k = d_model // num_heads
        self.heads_per_group = num_heads // num_kv_groups
        # Q projection: full num_heads
        self.W_Q = nn.Linear(d_model, d_model)
        # K, V projections: only num_kv_groups
        self.W_K = nn.Linear(d_model, num_kv_groups * self.d_k)
        self.W_V = nn.Linear(d_model, num_kv_groups * self.d_k)
        # Output projection
        self.W_O = nn.Linear(d_model, d_model)

    def forward(self, X):
        batch_size, seq_len, _ = X.size()
        # Q: full num_heads
        Q = self.W_Q(X)  # (batch, seq_len, d_model)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        # Shape: (batch, num_heads, seq_len, d_k)
        # K, V: only num_kv_groups
        K = self.W_K(X)  # (batch, seq_len, num_kv_groups * d_k)
        V = self.W_V(X)
        K = K.view(batch_size, seq_len, self.num_kv_groups, self.d_k).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_kv_groups, self.d_k).transpose(1, 2)
        # Shape: (batch, num_kv_groups, seq_len, d_k)
        # Repeat K, V for each head in the group
        K = K.repeat_interleave(self.heads_per_group, dim=1)
        V = V.repeat_interleave(self.heads_per_group, dim=1)
        # Shape: (batch, num_heads, seq_len, d_k)
        # Standard attention
        attention_output = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
        # Concatenate and project
        attention_output = attention_output.transpose(1, 2).contiguous()
        attention_output = attention_output.view(batch_size, seq_len, self.d_model)
        output = self.W_O(attention_output)
        return output

Key implementation details:
Fewer K, V parameters: W_K and W_V produce only num_kv_groups * d_k dimensions
Repeat for heads: repeat_interleave expands K, V so each group’s KV is shared across multiple query heads
Memory savings: Cache stores only num_kv_groups KV pairs instead of num_heads
GQA’s primary benefits:
1. KV cache memory reduction (main win)
2. Inference speedup (consequence of #1)
3. Training computation (minimal impact)
During training, we recompute everything anyway (no cache)
Fewer W_K, W_V parameters means slightly less computation
But the Q @ K^T and attention @ V operations still happen for all query heads
So training speedup is modest
GQA is primarily an inference optimization, not a training optimization. It attacks the KV cache bottleneck that only exists during autoregressive generation.
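To see why the training-side effect is limited to the projections, the sketch below counts projection weights for a full multi-head layer and a GQA layer. The sizes (d_model = 4096, 32 heads, 8 KV groups) are illustrative assumptions, biases are ignored, and the function name is made up for this example.

```python
def attention_projection_params(d_model, num_heads, num_kv_groups):
    """Weight counts (biases ignored) for the Q, K, V and output projections."""
    d_k = d_model // num_heads
    kv_out = num_kv_groups * d_k
    return {
        "W_Q": d_model * d_model,
        "W_K": d_model * kv_out,
        "W_V": d_model * kv_out,
        "W_O": d_model * d_model,
    }

mha = attention_projection_params(4096, num_heads=32, num_kv_groups=32)
gqa = attention_projection_params(4096, num_heads=32, num_kv_groups=8)
print(mha["W_K"], gqa["W_K"])  # 16777216 4194304
```

Only W_K and W_V shrink. W_Q, W_O and the per-head Q @ K^T and attention @ V products are the same size as before, which is why the savings show up mainly in the cache.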
Empirical findings (from papers like LLaMA-2, Mistral):
The performance degradation is surprisingly small:
GQA with 4–8 KV heads performs nearly as well as full multi-head attention with 32+ heads
But memory usage drops dramatically (4–8× reduction in KV cache)
Multi-Head Latent Attention (MLA): Core Concept
The KV cache problem:
In standard multi-head attention, we cache full K and V for every token:
K: (num_heads × d_k) = 128 heads × 128 dim = 16,384 values
V: (num_heads × d_k) = 128 heads × 128 dim = 16,384 values
Total: 32,768 values per token per layer
For long contexts (100K tokens), this becomes prohibitive.
MLA’s solution: Low-rank compression through a shared latent bottleneck
Instead of caching full K and V, MLA:
Compresses the input to a small latent vector
Caches only the latent (much smaller)
Reconstructs K and V from the latent when needed
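A quick count shows what this buys per token. The sketch below is a simplification (hypothetical function name, per layer, counting values rather than bytes) and uses the 128 heads × 128 dims from above. The latent size of 512 is the one used in this article's later examples.

```python
def cached_values_per_token(num_heads, d_k, d_compressed=None):
    """Values cached per token per layer: full K and V, or just the MLA latent."""
    if d_compressed is None:
        return 2 * num_heads * d_k    # standard MHA: K and V for every head
    return d_compressed               # MLA: only the shared latent c_KV

mha = cached_values_per_token(num_heads=128, d_k=128)
mla = cached_values_per_token(num_heads=128, d_k=128, d_compressed=512)
print(mha, mla, mha // mla)  # 32768 512 64
```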
The Mathematical Framework
Standard attention (what we cache):
K = X @ W_K (d_model → d_model)
V = X @ W_V (d_model → d_model)
Cache: Full K and V matrices
MLA with factorization (what we actually do):
Step 1: Compress to shared latent
c_KV = X @ W_KV_down (d_model → d_compressed)
Step 2: Expand to K and V separately
K = c_KV @ W_K_up (d_compressed → d_model)
V = c_KV @ W_V_up (d_compressed → d_model)
Cache: Only c_KV (the compressed latent)
Why share W_KV_down for both K and V?
This is the key innovation. K and V both represent the same input token — they contain the same semantic information, just used differently:
K is used for matching (computing attention scores)
V is used for mixing (weighted combination)
By forcing K and V through a single shared compression (W_KV_down), we:
1. Maximize compression: Cache d_compressed instead of 2 × d_model
Example: 512 dims instead of 32,768 dims = 64× reduction
2. Exploit redundancy: K and V share the same semantic core, so we compress that shared information once
3. Maintain flexibility: Separate up-projections (W_K_up, W_V_up) let K and V diverge for their specialized roles
The factorization view:
Implicitly: W_K = W_KV_down @ W_K_up (low-rank factorization)
W_V = W_KV_down @ W_V_up (low-rank factorization)
But they share the first factor (W_KV_down)
Mental Model:
Think of c_KV as a “semantic fingerprint” of the token. It’s a compressed representation that captures the essence. From this fingerprint, we can reconstruct both the key (for matching) and value (for content) by learning different expansion paths.
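The factorization view can be checked numerically on a toy example. The sketch below uses tiny integer matrices (d_model = 4, d_compressed = 2, values chosen arbitrarily) so that equality is exact. It shows that expanding the cached latent gives the same K and V as applying the implied low-rank matrices W_KV_down @ W_K_up and W_KV_down @ W_V_up directly.

```python
def matmul(A, B):
    """Plain-Python matrix product of A (m x k) and B (k x n)."""
    return [[sum(A[i][t] * B[t][j] for t in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

# Toy sizes: d_model = 4, d_compressed = 2 (integers so equality is exact).
X = [[1, 2, 3, 4]]                                   # one token, (1 x 4)
W_KV_down = [[1, 0], [0, 1], [1, 1], [2, -1]]        # (4 x 2)
W_K_up = [[1, 2, 0, 1], [0, 1, 1, 3]]                # (2 x 4)
W_V_up = [[2, 0, 1, 0], [1, 1, 0, 2]]                # (2 x 4)

c_KV = matmul(X, W_KV_down)          # what gets cached: (1 x 2)
K = matmul(c_KV, W_K_up)             # rebuilt from the latent
V = matmul(c_KV, W_V_up)

# Same result as a single low-rank projection, e.g. W_K = W_KV_down @ W_K_up.
K_direct = matmul(X, matmul(W_KV_down, W_K_up))
V_direct = matmul(X, matmul(W_KV_down, W_V_up))
print(c_KV, K, V)
```

Here the cache holds 2 numbers per token instead of 8, and both K and V are recovered from them.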
Dimension Flow: Concrete Example
Setup:
d_model = 4096
num_heads = 128
d_k = 4096 / 128 = 32 per head
d_compressed = 512 (the latent dimension)
seq_len = 1 token (for clarity)
Standard MHA (what we’d cache):
Input: X (1 × 4096)
K = X @ W_K → (1 × 4096)
V = X @ W_V → (1 × 4096)
Split into heads:
K: (1 × 128 × 32) = 4,096 values
V: (1 × 128 × 32) = 4,096 values
Cache per token: 8,192 values (K + V)
MLA (what we actually cache):
Input: X (1 × 4096)
Step 1: Compress to latent
c_KV = X @ W_KV_down → (1 × 512)
↑
This is what we cache! Only 512 values
Step 2: Expand to K and V (during inference, recompute from cache)
K = c_KV @ W_K_up → (1 × 4096)
V = c_KV @ W_V_up → (1 × 4096)
Split into heads:
K: (1 × 128 × 32)
V: (1 × 128 × 32)
Cache per token: 512 values (just c_KV)
Memory savings:
Standard: 8,192 values per token
MLA: 512 values per token
Reduction: 16× smaller cache!
Here’s the PyTorch implementation:
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadLatentAttention(nn.Module):
    def __init__(self, d_model, num_heads, d_compressed):
        """
        d_model: 4096 (model dimension)
        num_heads: 128 (number of attention heads)
        d_compressed: 512 (latent/compressed dimension)
        """
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_compressed = d_compressed
        self.d_k = d_model // num_heads
        # Query: separate compression path
        self.W_Q_down = nn.Linear(d_model, d_compressed, bias=False)
        self.W_Q_up = nn.Linear(d_compressed, d_model, bias=False)
        # KV: shared compression, separate expansion
        self.W_KV_down = nn.Linear(d_model, d_compressed, bias=False)
        self.W_K_up = nn.Linear(d_compressed, d_model, bias=False)
        self.W_V_up = nn.Linear(d_compressed, d_model, bias=False)
        # Output projection
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, X, kv_cache=None):
        """
        X: (batch, seq_len, d_model)
        kv_cache: Optional cached c_KV from previous tokens
        """
        batch_size, seq_len, _ = X.size()
        # === Query path ===
        c_Q = self.W_Q_down(X)  # (batch, seq_len, d_compressed)
        Q = self.W_Q_up(c_Q)    # (batch, seq_len, d_model)
        # === KV path (with caching support) ===
        c_KV_new = self.W_KV_down(X)  # (batch, seq_len, d_compressed)
        # Concatenate with cache if it exists
        if kv_cache is not None:
            c_KV = torch.cat([kv_cache, c_KV_new], dim=1)
        else:
            c_KV = c_KV_new
        # Expand from latent to full K, V
        K = self.W_K_up(c_KV)  # (batch, total_seq_len, d_model)
        V = self.W_V_up(c_KV)  # (batch, total_seq_len, d_model)
        # Split into heads
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        # Attention. The new queries sit at the END of the key sequence, so the causal
        # mask must be aligned to the bottom-right. is_causal=True aligns it to the
        # top-left, which would let a single new query see only the first cached key.
        total_len = K.size(2)
        causal_mask = torch.ones(
            seq_len, total_len, dtype=torch.bool, device=X.device
        ).tril(diagonal=total_len - seq_len)
        output = F.scaled_dot_product_attention(Q, K, V, attn_mask=causal_mask)
        # Reshape and project
        output = output.transpose(1, 2).contiguous()
        output = output.view(batch_size, seq_len, self.d_model)
        output = self.W_O(output)
        return output, c_KV  # Return updated cache

Usage example:
# Initialize
mla = MultiHeadLatentAttention(d_model=4096, num_heads=128, d_compressed=512)
embed = nn.Embedding(1000, 4096)  # placeholder embedding table (vocabulary size is illustrative)
tokens = torch.tensor([1, 2, 3])  # placeholder token ids
# Generation loop
kv_cache = None
for token in tokens:
    x = embed(token).view(1, 1, -1)  # (1, 1, 4096)
    output, kv_cache = mla(x, kv_cache)
    # kv_cache grows: (1, t, 512) where t = number of tokens processed

Did you notice something in the implementation that differs from what we discussed? Take another look if you did not!
The difference: there is also a down-projection of the query vectors Q.
Why Query Compression in MLA?
According to the DeepSeek-V3 paper (Section 2.1.1):
“For the attention queries, we also perform a low-rank compression, which can reduce the activation memory during training”
The reason: Training memory efficiency, not inference cache!
Query compression reduces activation memory during training:
Activation memory: During forward pass, we need to store intermediate activations for backpropagation
Without compression: Store full Q of size (batch × seq_len × d_model)
With compression: Store smaller c_Q of size (batch × seq_len × d_compressed), then recompute Q during backward pass if needed
Key distinction:
KV compression → Reduces inference KV cache (the main win)
Q compression → Reduces training activation memory (secondary benefit)
This is a classic memory-computation tradeoff: compress during forward pass, potentially recompute during backward pass to save memory.
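As a rough illustration of the activation saving, the sketch below counts how many query-path values would be kept per layer for backpropagation with and without the latent. The d_model and d_compressed values are the ones used in this article's example. The batch size, sequence length and function name are assumptions for illustration, and real frameworks store other activations too.

```python
def query_activation_values(batch, seq_len, d_model, d_compressed=None):
    """Number of query-path values kept for backprop, per layer."""
    width = d_model if d_compressed is None else d_compressed
    return batch * seq_len * width

full_q = query_activation_values(batch=8, seq_len=4096, d_model=4096)
latent_q = query_activation_values(batch=8, seq_len=4096, d_model=4096, d_compressed=512)
print(full_q // latent_q)  # 8
```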
Conclusion: From First Principles to Production
We’ve built attention from the ground up, starting with a simple question: How do we make word representations context-aware?
The journey took us through:
The foundation:
Fixed combinations → Dynamic dot products → Learnable projections (Q, K, V)
Scaling for stability, softmax for normalization
The critical separation: different attention patterns in the (seq_len × seq_len) space
Multi-head attention:
Not separate weight matrices — emergent specialization through gradient descent
Parallel attention patterns, each learning different relationships
GPU-level parallelism across heads
From theory to practice:
Causal masking for autoregressive generation
KV caching: the inference bottleneck
O(n²·d) complexity drives the need for efficiency
Modern innovations:
GQA: Share K, V across query heads → 4–8× smaller KV cache
MLA: Low-rank compression through shared latent bottleneck → 16× smaller cache
Both target the same problem: making long-context inference feasible
Attention’s core mechanism remains unchanged since 2017. The innovations are about making it scale — compressing what we cache, sharing what we can, optimizing for the hardware we have.
If this deep dive helped you build intuition for attention mechanisms, give it a Like! 👏👏 And if you’re implementing these in your own projects, I’d love to hear about your experience in the comments.
In the next part of this series we will cover further advancements like DeepSeek Sparse Attention (DSA) and others. Stay tuned!
References
Vaswani, A., et al. (2017). “Attention Is All You Need.” NeurIPS 2017: https://arxiv.org/abs/1706.03762
Shazeer, N. (2019). “Fast Transformer Decoding: One Write-Head is All You Need.” Introduces Multi-Query Attention (MQA): https://arxiv.org/abs/1911.02150
Ainslie, J., et al. (2023). “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints.” Grouped Query Attention: https://arxiv.org/abs/2305.13245
DeepSeek-AI (2024). “DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model.” Introduces Multi-Head Latent Attention (MLA): https://arxiv.org/abs/2405.04434
DeepSeek-AI (2024). “DeepSeek-V3 Technical Report.” Latest MLA implementation details: https://arxiv.org/abs/2412.19437
PyTorch Documentation: torch.nn.functional.scaled_dot_product_attention: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html


