This document describes the attention mechanism implementation in nanochat's GPT model, focusing on the Flash Attention 3 integration, sliding window patterns, and architectural optimizations. For the broader transformer architecture (MLP blocks, residual connections, model configuration), see GPT Transformer Architecture. For inference-specific KV cache details, see Inference Engine and KV Cache.
The nanochat model implements causal self-attention with several modern enhancements: Flash Attention 3 kernels, per-layer sliding window attention, grouped-query attention (GQA), rotary position embeddings (RoPE), QK normalization, and gated value embeddings.
Sources: nanochat/gpt.py:59-118
Flash Attention 3 provides ~9% throughput improvement over Flash Attention 2 by optimizing memory access patterns and supporting Hopper GPU tensor cores.
The nanochat codebase uses a custom abstraction layer (nanochat.flash_attention) that automatically routes to FA3 on Hopper GPUs or falls back to PyTorch's scaled_dot_product_attention (SDPA) on other hardware.
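A minimal sketch of this routing, assuming the FA3 entry point is importable as flash_attn_interface.flash_attn_func (the exact module path and wrapper name in nanochat's abstraction layer may differ):

```python
# Sketch of the FA3-or-SDPA dispatch described above. The import path and
# wrapper name are assumptions for illustration, not nanochat's actual API.
import torch
import torch.nn.functional as F

try:
    from flash_attn_interface import flash_attn_func  # FA3, Hopper GPUs only
    HAS_FA3 = True
except ImportError:
    HAS_FA3 = False

def attention(q, k, v, causal=True):
    """q, k, v: (B, T, H, D) -- the FA3-native layout."""
    if HAS_FA3:
        return flash_attn_func(q, k, v, causal=causal)
    # SDPA expects (B, H, T, D), so transpose in and out.
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
        is_causal=causal,
    )
    return out.transpose(1, 2)
```

The caller always works in (B, T, H, D); only the fallback path pays for the transposes.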
Sources: dev/LOG.md:519-560, nanochat/gpt.py:26
During training, the model uses flash_attn.flash_attn_func for causal attention with optional sliding window:
| Parameter | Value | Purpose |
|---|---|---|
| q, k, v | (B, T, H, D) | Native FA3 layout, no transpose |
| causal | True | Enforce autoregressive masking |
| window_size | (left, 0) | Sliding window (if enabled) |
The training-time call site is at nanochat/gpt.py:100.
During inference, the model uses flash_attn.flash_attn_with_kvcache which manages KV cache in-place:
| Parameter | Value | Purpose |
|---|---|---|
| q | (B, T, H, D) | Query for current tokens |
| k_cache, v_cache | (num_layers, B, T_cache, H, D) | Persistent cache |
| k, v | (B, T, H, D) | New keys/values to append |
| cache_seqlens | int32 tensor | Per-batch position tracker |
| causal | True | Autoregressive masking |
| window_size | (left, 0) | Sliding window |
The function modifies k_cache and v_cache in place, automatically appending the new keys/values and returning the attention output (nanochat/gpt.py:104-110).
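The in-place semantics can be emulated in plain PyTorch. The sketch below illustrates the behavior described above; the per-batch loop and the prefill/decode split are simplifications for clarity, not nanochat's code:

```python
# Plain-PyTorch emulation of flash_attn_with_kvcache's in-place cache update
# (an illustrative sketch, not nanochat's implementation).
import torch
import torch.nn.functional as F

def attend_with_kvcache(q, k_cache, v_cache, k, v, cache_seqlens):
    B, T, H, D = q.shape
    outs = []
    for b in range(B):
        pos = int(cache_seqlens[b])
        k_cache[b, pos:pos + T] = k[b]   # in-place append into the cache
        v_cache[b, pos:pos + T] = v[b]
        qb = q[b].transpose(0, 1)                    # (H, T, D) for SDPA
        kb = k_cache[b, :pos + T].transpose(0, 1)    # (H, pos+T, D)
        vb = v_cache[b, :pos + T].transpose(0, 1)
        if pos == 0:
            # Prefill: standard causal mask over the new tokens.
            o = F.scaled_dot_product_attention(qb, kb, vb, is_causal=True)
        else:
            # Single-token decode: the one query may see the whole prefix.
            assert T == 1, "sketch handles prefill or one-token decode only"
            o = F.scaled_dot_product_attention(qb, kb, vb)
        outs.append(o.transpose(0, 1))               # back to (T, H, D)
    cache_seqlens += T                               # advance position trackers
    return torch.stack(outs)
```

The real kernel fuses the append and the attention, and handles ragged batch positions without a Python loop.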
FA3's native (B, T, H, D) layout eliminates the transpose operations required by FA2's (B, H, T, D) layout.
Sources: dev/LOG.md:800-830, nanochat/gpt.py:80-100
Sliding window attention allows each token to attend only to a limited context window, reducing computation for long sequences while preserving full context in selected layers.
The model supports configurable per-layer window patterns via the window_pattern string in GPTConfig:
| Character | Meaning | Window Size |
|---|---|---|
| L | Long | Full context (sequence_len) |
| S | Short | Half context (sequence_len // 2) |
The pattern is tiled across layers, with the final layer always forced to L (full context). For example, with window_pattern="SSSL" and 20 layers, the pattern repeats five times (S S S L × 5), so every fourth layer, including the last, uses full context.
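The tiling rule can be sketched in a few lines (assumed behavior matching the description; the helper name is hypothetical):

```python
# Hypothetical sketch of the pattern tiling described above: repeat the
# pattern string across layers and force the final layer to 'L'.
def layer_window_chars(window_pattern: str, n_layer: int) -> list[str]:
    chars = [window_pattern[i % len(window_pattern)] for i in range(n_layer)]
    chars[-1] = "L"  # final layer always gets full context
    return chars
```

With window_pattern="SSSL" and 20 layers this yields S S S L repeated five times, the last layer already being L.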
The _compute_window_sizes method (nanochat/gpt.py:260-287) converts the pattern string to per-layer (left, right) tuples:
| Window Type | Tuple Value | Meaning |
|---|---|---|
| Full context | (sequence_len, 0) | Attend to all previous tokens |
| Half window | (sequence_len // 2, 0) | Attend to last N/2 tokens |
These tuples are passed directly to Flash Attention 3's window_size parameter. The value (-1, 0) can also be used to indicate unlimited left context.
The attention FLOPs per layer depend on the effective sequence length, which is capped by the window size; the estimate_flops() method (nanochat/gpt.py:292-317) accounts for this.
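As a rough sketch of the windowed accounting (an assumed back-of-the-envelope formula, not the exact estimate_flops() implementation): the two attention matmuls, QKᵀ and attn·V, each cost about 2 · T · W · (n_head · head_dim) FLOPs per layer, where W is the effective context.

```python
# Assumed back-of-the-envelope attention FLOPs per layer: the effective
# context W is capped by the layer's left window.
def attn_matmul_flops(T, window_left, n_head, head_dim):
    W = min(T, window_left)
    return 2 * 2 * T * W * n_head * head_dim  # QK^T plus attn @ V
```

Under this estimate a half-context layer (W = T/2) costs half the attention FLOPs of a full-context layer, and widening the window beyond T has no effect.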
Sources: nanochat/gpt.py:36-39, nanochat/gpt.py:260-287, dev/LOG.md:784-798
GQA reduces memory bandwidth during inference by sharing key/value heads across multiple query heads.
| Parameter | Description | Default | Effect |
|---|---|---|---|
| n_head | Number of query heads | 6 | Full attention resolution |
| n_kv_head | Number of key/value heads | 6 | Must divide n_head |
| head_dim | Dimension per head | n_embd // n_head | Computed |
Standard attention uses n_kv_head = n_head (all heads independent). GQA uses n_kv_head < n_head, where multiple query heads share the same key/value pair.
The projection layers (nanochat/gpt.py:69-71) have different output sizes:

- Query projection: n_head × head_dim output features
- Key projection: n_kv_head × head_dim output features
- Value projection: n_kv_head × head_dim output features

Flash Attention 3 handles GQA automatically when the K and V tensors have fewer heads than Q.
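For illustration, GQA can be reproduced with plain SDPA by repeating each KV head across its query-head group; FA3 performs this grouping internally without materializing the repeats:

```python
# Illustrative GQA via explicit KV-head repetition (FA3 avoids the copy).
import torch
import torch.nn.functional as F

B, T, n_head, n_kv_head, D = 2, 8, 6, 2, 16
q = torch.randn(B, n_head, T, D)            # (B, H, T, D) for SDPA
k = torch.randn(B, n_kv_head, T, D)
v = torch.randn(B, n_kv_head, T, D)

group = n_head // n_kv_head                 # 3 query heads share each KV head
k_rep = k.repeat_interleave(group, dim=1)   # (B, n_head, T, D)
v_rep = v.repeat_interleave(group, dim=1)
out = F.scaled_dot_product_attention(q, k_rep, v_rep, is_causal=True)
```

Recent PyTorch versions can also skip the explicit repeat via SDPA's enable_gqa flag; the repeat form is shown because it makes the head sharing visible.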
For a model with n_head=12, head_dim=128, sequence_len=2048:
| Configuration | KV Cache per Layer (elements) |
|---|---|
| n_kv_head=12 (standard) | 2 × 12 × 2048 × 128 ≈ 6.3M |
| n_kv_head=4 (GQA 3:1) | 2 × 4 × 2048 × 128 ≈ 2.1M |
| n_kv_head=1 (MQA) | 2 × 1 × 2048 × 128 ≈ 0.5M |

In bf16 (2 bytes per element) these correspond to roughly 12.6 MB, 4.2 MB, and 1.0 MB per layer, respectively.
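A quick sanity check of the element counts in the table; the byte size then depends only on the cache dtype:

```python
# Per-layer KV cache accounting: 2 tensors (K and V), each
# n_kv_head * T_cache * head_dim elements. Bytes = elements * dtype size.
def kv_cache_elements(n_kv_head, seq_len, head_dim):
    return 2 * n_kv_head * seq_len * head_dim

std = kv_cache_elements(12, 2048, 128)  # standard attention
gqa = kv_cache_elements(4, 2048, 128)   # GQA 3:1
mqa = kv_cache_elements(1, 2048, 128)   # multi-query attention
```

The ratios are exactly the head ratios: GQA 3:1 shrinks the cache 3×, MQA 12×.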
Sources: nanochat/gpt.py:33-34, nanochat/gpt.py:63-68
RoPE provides relative position information by applying rotation matrices to query and key vectors.
The model precomputes rotation frequencies during initialization (nanochat/gpt.py:243-258):

- inv_freq = 1.0 / (base ** (channel_range / head_dim))
- freqs = outer(t, inv_freq)
- cos, sin = freqs.cos(), freqs.sin()
- The tables are shaped for broadcasting as (1, seq_len, 1, head_dim/2)

The cache is allocated for sequence_len × 10 to allow dynamic sequence lengths without recomputation.
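The precomputation steps above, sketched with illustrative variable names and toy sizes:

```python
# Sketch of the RoPE frequency precomputation (shapes follow the text;
# sizes here are toy values for illustration).
import torch

head_dim, seq_len, base = 16, 32, 10000.0
channel_range = torch.arange(0, head_dim, 2).float()      # even channel indices
inv_freq = 1.0 / (base ** (channel_range / head_dim))     # (head_dim/2,)
t = torch.arange(seq_len).float()                         # positions
freqs = torch.outer(t, inv_freq)                          # (seq_len, head_dim/2)
cos, sin = freqs.cos(), freqs.sin()
cos = cos[None, :, None, :]                               # (1, seq_len, 1, head_dim/2)
sin = sin[None, :, None, :]
```

The unit axes let the tables broadcast against (B, T, H, D/2) query/key halves without copies.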
The apply_rotary_emb function (nanochat/gpt.py:51-57) rotates query and key vectors.
This applies a 2D rotation to consecutive pairs of dimensions, encoding relative positions through interference patterns.
During inference with a KV cache, the rotary embeddings must be offset to the current cache position (nanochat/gpt.py:396-397).
This ensures new tokens receive embeddings corresponding to their absolute position in the full sequence.
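A hedged sketch of the pairwise rotation and the cache-position offset (nanochat's exact channel pairing and layout may differ):

```python
# Sketch of rotary application: rotate consecutive channel pairs by the
# position-dependent angle. The pairing convention is an assumption.
import torch

def apply_rotary_emb(x, cos, sin):
    # x: (B, T, H, D); cos/sin broadcast as (1, T, 1, D/2)
    x1, x2 = x[..., 0::2], x[..., 1::2]   # consecutive pairs
    y1 = x1 * cos - x2 * sin              # 2D rotation per pair
    y2 = x1 * sin + x2 * cos
    return torch.stack((y1, y2), dim=-1).flatten(-2)

# During cached decoding, slice the tables at the absolute position, e.g.:
# cos_t = cos_cache[:, pos : pos + T]; sin_t = sin_cache[:, pos : pos + T]
```

Because each pair undergoes a pure rotation, vector norms are preserved, which is one reason RoPE composes cleanly with QK normalization.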
Sources: nanochat/gpt.py:51-57, nanochat/gpt.py:182-186, nanochat/gpt.py:243-258
Both queries and keys are normalized using RMSNorm after applying rotary embeddings (nanochat/gpt.py:94). The norm function (nanochat/gpt.py:42-44) is a purely functional RMSNorm with no learnable parameters.
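A minimal parameter-free RMSNorm consistent with that description (a sketch, not the verbatim implementation):

```python
# Purely functional RMSNorm: scale each vector by the reciprocal of its
# root-mean-square over the last dimension. No learnable weight.
import torch

def norm(x, eps=1e-6):
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

q = torch.randn(2, 4, 3, 8)   # (B, T, H, D): normalized over D, per head
qn = norm(q)
```

Applying it over the last dimension of (B, T, H, D) normalizes each head's query/key independently, as the text notes.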
| Issue Without QK Norm | How QK Norm Helps |
|---|---|
| Attention logit magnitude grows with √d | Normalized vectors have unit norm |
| Unstable training with large head_dim | Stable dot products regardless of dimension |
| Need to carefully tune attention scale | Automatic scaling via normalization |
The normalization is applied per-head (last dimension of shape (B, T, H, D)), ensuring each head's queries and keys are independently normalized.
Sources: nanochat/gpt.py:42-44, nanochat/gpt.py:94
Value embeddings add extra capacity by mixing learned token-specific embeddings into the attention values.
The model uses value embeddings at alternating layers, with the last layer always included (nanochat/gpt.py:47-49).
For a 12-layer model, value embeddings are present at layers: 1, 3, 5, 7, 9, 11.
Each value embedding is gated by an input-dependent weight (nanochat/gpt.py:86-89).
The gate network:

- Input: the first 32 channels of the residual stream, x[..., :32]
- Output shape: (B, T, n_kv_head), one gate per KV head
- Range: (0, 2), via 2 * sigmoid(...)

For vocab_size=32768, n_kv_head=6, head_dim=128, each value embedding table contains:

- 32768 × (6 × 128) = 25.2M parameters per table
- 6 × 25.2M = 151M parameters total

This is comparable to the token embedding table itself (25.2M). The large parameter count is justified because value embeddings add capacity at near-zero FLOP cost (just an embedding lookup plus a gated addition).
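A hedged sketch of the gated mix-in (tensor sizes and the exact mixing point are illustrative assumptions, not nanochat's code):

```python
# Sketch of gated value embeddings: a small linear gate reads the first 32
# residual channels and scales a per-token learned embedding that is added
# to the attention values. Sizes here are toy values.
import torch
import torch.nn as nn

B, T, n_embd, n_kv_head, head_dim = 2, 8, 64, 6, 16
gate_proj = nn.Linear(32, n_kv_head)            # reads first 32 channels
x = torch.randn(B, T, n_embd)                   # residual stream
ve = torch.randn(B, T, n_kv_head, head_dim)     # looked up from the ve table

gate = 2 * torch.sigmoid(gate_proj(x[..., :32]))   # (B, T, n_kv_head), in (0, 2)
v = torch.randn(B, T, n_kv_head, head_dim)         # ordinary value projection
v = v + gate.unsqueeze(-1) * ve                    # gated addition, near-zero FLOPs
```

The cost is one tiny matmul plus an elementwise add, which is why the extra 151M parameters barely move the FLOP count.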
Sources: nanochat/gpt.py:47-49, nanochat/gpt.py:73-74, nanochat/gpt.py:86-89, dev/LOG.md:487-495
For non-Hopper GPUs, the system automatically falls back to PyTorch's scaled_dot_product_attention:
| Aspect | Flash Attention 3 | SDPA Fallback |
|---|---|---|
| Layout | (B, T, H, D) native | Transpose to (B, H, T, D) |
| KV cache | In-place update | Manual concatenation |
| Sliding window | Native window_size | Explicit mask tensor |
| Performance | Optimized for H100 | CPU/older GPU compatible |
| Memory | Recompute-optimized | Standard memory usage |
The fallback ensures nanochat can run on any device (CPU, MPS, older CUDA GPUs) but with reduced performance, especially for sliding window attention.
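For sliding windows, the fallback must build the mask tensor explicitly, since SDPA has no window_size parameter. A sketch, using SDPA's boolean attn_mask convention (True = may attend):

```python
# Explicit sliding-window causal mask for the SDPA fallback (a sketch).
import torch
import torch.nn.functional as F

def sliding_window_causal_mask(T, left, device=None):
    i = torch.arange(T, device=device)[:, None]   # query positions
    j = torch.arange(T, device=device)[None, :]   # key positions
    # token i may attend to j when j <= i (causal) and i - j <= left (window)
    return (j <= i) & (i - j <= left)

B, H, T, D = 1, 2, 8, 16
q = torch.randn(B, H, T, D)
mask = sliding_window_causal_mask(T, left=3)
out = F.scaled_dot_product_attention(q, q, q, attn_mask=mask)
```

Materializing this (T, T) mask is exactly the overhead FA3's native window_size parameter avoids.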
Sources: dev/LOG.md:519-560
The attention mechanism integrates into the transformer block's forward pass (nanochat/gpt.py:140-143).
Arguments passed from the main GPT forward (nanochat/gpt.py:403-406):

- ve: value embedding for this layer (if the layer has one)
- cos_sin: rotary embedding tables sliced to the sequence length
- window_size: per-layer sliding window configuration
- kv_cache: inference cache (None during training)

This design keeps the attention layer stateless, with all positional and caching state managed by the calling context.