This page documents the key-value (KV) cache subsystem used during autoregressive text generation. It covers the two-level class hierarchy in src/transformers/cache_utils.py, how concrete implementations are selected via GenerationConfig, and how the cache integrates with the generation loop.
For a higher-level view of how generation uses the cache, see the Generation System page (4); for speculative decoding's use of cache rollback, see page 4.4; for continuous batching and PagedAttentionCache, see page 4.5.
The cache system is organized in two levels:
- CacheLayerMixin and its subclasses represent the cache for a single attention layer. Each instance stores key and value tensors for that layer.
- Cache and its subclasses hold a list of layer-level objects (one per transformer layer) and dispatch calls by layer_idx.

Diagram: Two-Level Cache Architecture
Sources: src/transformers/cache_utils.py1-650
## CacheLayerMixin

CacheLayerMixin is the abstract base class for a single layer's KV cache, defined in src/transformers/cache_utils.py.
Key attributes and abstract interface:
| Member | Type | Description |
|---|---|---|
is_compileable | class var (bool) | True enables torch.compile compatibility |
is_initialized | instance var (bool) | Set to True after first update call |
keys, values | torch.Tensor | None | Stored KV tensors |
lazy_initialization(key_states, value_states) | abstract | Allocates backing tensors on first call |
update(key_states, value_states, cache_kwargs) | abstract | Appends/inserts new KV states; returns full tensors |
get_mask_sizes(cache_position) | abstract | Returns (kv_length, kv_offset) for mask construction |
get_seq_length() | abstract | Returns current number of cached tokens |
get_max_cache_shape() | abstract | Returns maximum capacity (-1 if unbounded) |
Concrete methods on CacheLayerMixin:
| Method | Description |
|---|---|
reset() | Zeros out stored tensors, resets cumulative_length if present |
reorder_cache(beam_idx) | Selects rows by beam index using index_select |
offload() | Moves keys/values to CPU (non-blocking) |
prefetch() | Moves keys/values back to original device (non-blocking) |
Sources: src/transformers/cache_utils.py27-83
## DynamicLayer

The default layer implementation. It grows unboundedly by concatenating new tokens.
- is_compileable = False, is_sliding = False
- lazy_initialization: creates empty tensors with matching dtype and device
- update: torch.cat([self.keys, key_states], dim=-2); returns the full accumulated tensor
- get_max_cache_shape(): returns -1 (no upper bound)
- crop(max_length): truncates self.keys[..., :max_length, :] (used by speculative decoding rollback)
- batch_repeat_interleave(repeats): expands batch dim for beam search setup
- batch_select_indices(indices): contracts batch dim for beam search pruning

Sources: src/transformers/cache_utils.py85-165
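The concatenation-based update can be sketched with a toy stand-in (ToyDynamicLayer and its method names are illustrative, not the actual cache_utils.py implementation):

```python
import torch

class ToyDynamicLayer:
    """Toy sketch of a grow-by-concatenation layer cache."""

    def __init__(self):
        self.keys = None
        self.values = None

    def update(self, key_states, value_states):
        if self.keys is None:
            # lazy initialization: adopt the first states as-is
            self.keys, self.values = key_states, value_states
        else:
            # append new tokens along the sequence dimension (dim=-2)
            self.keys = torch.cat([self.keys, key_states], dim=-2)
            self.values = torch.cat([self.values, value_states], dim=-2)
        # return the full accumulated tensors for attention
        return self.keys, self.values

    def get_seq_length(self):
        return 0 if self.keys is None else self.keys.shape[-2]
```

A prefill of 5 tokens followed by one decode step leaves 6 cached tokens, and each update hands back the whole history.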
## DynamicSlidingWindowLayer

Inherits DynamicLayer. Keeps only the most recent sliding_window tokens in storage but returns the full (concatenated) states from update.
- is_sliding = True
- cumulative_length: the total number of tokens seen, even after truncation
- update: appends new tokens, stores only keys[:, :, -sliding_window+1:, :], but returns the full concatenated tensor for attention computation
- get_seq_length(): returns cumulative_length, not the stored length
- get_mask_sizes: computes kv_offset = max(cumulative_length - sliding_window + 1, 0) to correctly window the mask

Sources: src/transformers/cache_utils.py168-251
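The window arithmetic can be illustrated with plain Python lists standing in for tensors (sliding_window_update and sliding_mask_sizes are hypothetical helper names):

```python
def sliding_window_update(stored, new_tokens, sliding_window):
    """Toy sketch of the sliding-window update: return the full
    concatenation for attention, but keep only the most recent
    sliding_window - 1 tokens in storage (mirroring the
    keys[:, :, -sliding_window+1:, :] slice)."""
    full = stored + new_tokens
    return full, full[-(sliding_window - 1):]

def sliding_mask_sizes(cumulative_length, sliding_window):
    """Offset of the first still-cached token in the full sequence:
    kv_offset = max(cumulative_length - sliding_window + 1, 0)."""
    return max(cumulative_length - sliding_window + 1, 0)
```

For example, after 12 total tokens with a window of 4, only the last 3 tokens remain stored and the mask offset is 9.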
## StaticLayer

Pre-allocates fixed-size tensors of shape (batch_size, num_heads, max_cache_len, head_dim). This is the primary torch.compile-compatible layer.
- is_compileable = True, is_sliding = False
- lazy_initialization: allocates zero-filled tensors and calls torch._dynamo.mark_static_address on them to prevent graph breaks during compilation; skipped when is_torchdynamo_compiling() is true (i.e., never called during the compile trace itself)
- update: writes new states at the positions given by cache_kwargs["cache_position"] via index_copy_(2, cache_position, key_states), with a fallback to self.keys[:, :, cache_position] = key_states for devices that don't support index_copy_
- get_mask_sizes: always returns (max_cache_len, 0); the mask always covers the full allocated window
- get_seq_length: counts non-zero rows in self.keys[0, 0]

Sources: src/transformers/cache_utils.py254-363
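A minimal sketch of the in-place static update (shapes and the static_update helper are illustrative; the real layer also implements the assignment fallback and dynamo bookkeeping):

```python
import torch

# pre-allocated, fixed-size KV buffer: (batch, heads, max_cache_len, head_dim)
max_cache_len, n_heads, head_dim = 8, 2, 4
keys = torch.zeros(1, n_heads, max_cache_len, head_dim)

def static_update(cache, new_states, cache_position):
    """Write new states into the fixed buffer at the absolute
    positions given by cache_position (dim 2 is the sequence dim)."""
    cache.index_copy_(2, cache_position, new_states)
    return cache

# prefill writes positions 0..2, then a decode step writes position 3
static_update(keys, torch.ones(1, n_heads, 3, head_dim), torch.arange(3))
static_update(keys, 2 * torch.ones(1, n_heads, 1, head_dim), torch.tensor([3]))
```

Because the buffer never changes shape, torch.compile can trace a single fixed graph for every decode step.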
## StaticSlidingWindowLayer

Inherits StaticLayer. Pre-allocates tensors of size min(sliding_window, max_cache_len) and uses a rolling update.
- is_sliding = True
- effective_max_cache_len = min(sliding_window, max_cache_len)
- update, once the window is full: roll(-1, dims=-2) to shift existing entries, then overwrites the last position
- update, for a prefill longer than the window: torch.cat, then copies the last max_cache_len tokens
- update, otherwise: index_copy_ like StaticLayer
- all writes use copy_ into the pre-allocated tensors (not assignment) to preserve the static dynamo address

Sources: src/transformers/cache_utils.py365-487
## QuantizedLayer, QuantoQuantizedLayer, HQQQuantizedLayer

QuantizedLayer extends DynamicLayer with a two-tier storage model.
Storage structure:
- self.keys, self.values: unquantized residual buffer (up to residual_length tokens)
- self._quantized_keys, self._quantized_values: quantized storage for older tokens

update logic: new tokens are appended to the residual buffer; once it exceeds residual_length, the layer quantizes the entire accumulated buffer and clears the residual.

Constructor parameters:
| Parameter | Description |
|---|---|
nbits | Quantization bit-width (2 or 4 for quanto; 2, 4, or 8 for HQQ) |
axis_key / axis_value | Quantization axis (0 or -1) |
q_group_size | Group size for per-channel quantization |
residual_length | Max tokens in unquantized buffer before flush |
QuantoQuantizedLayer uses optimum-quanto (qint2/qint4) and HQQQuantizedLayer uses the hqq library.
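A toy sketch of the two-tier flush behavior, with int8 rounding standing in for the real quanto/HQQ backends (ToyQuantizedLayer and all names here are illustrative):

```python
import torch

class ToyQuantizedLayer:
    """Toy two-tier KV store: a small full-precision residual buffer
    plus an int8 'quantized' store for older tokens."""

    def __init__(self, residual_length=4, scale=0.1):
        self.residual = []            # recent, full-precision tokens
        self.quantized = None         # older tokens, int8
        self.residual_length = residual_length
        self.scale = scale

    def full_states(self):
        """Dequantized old tokens + residual, concatenated on the seq dim."""
        parts = []
        if self.quantized is not None:
            parts.append(self.quantized.float() * self.scale)
        if self.residual:
            parts.append(torch.cat(self.residual, dim=-2))
        return torch.cat(parts, dim=-2)

    def update(self, new_states):
        self.residual.append(new_states)
        if sum(t.shape[-2] for t in self.residual) > self.residual_length:
            # flush: quantize everything accumulated, then clear the residual
            buf = torch.cat(self.residual, dim=-2)
            q = torch.clamp((buf / self.scale).round(), -128, 127).to(torch.int8)
            self.quantized = q if self.quantized is None else torch.cat(
                [self.quantized, q], dim=-2)
            self.residual = []
        return self.full_states()
```

With residual_length=2, the third token triggers a flush: all three tokens move to the quantized store and the residual empties.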
Sources: src/transformers/cache_utils.py490-650
## Cache

Cache is the abstract top-level class. It owns a list of CacheLayerMixin instances and routes all operations by layer_idx.
Primary interface:
| Method | Signature | Description |
|---|---|---|
update | (key_states, value_states, layer_idx, cache_kwargs) | Delegates to self.layers[layer_idx].update(...) |
early_initialization | (batch_size, config, device, dtype, max_cache_len) | Forces lazy_initialization ahead of generation (required for torch.export) |
get_seq_length | (layer_idx=0) | Returns cached token count for a layer |
get_mask_sizes | (layer_idx, cache_position) | Returns (kv_length, kv_offset) for mask generation |
get_max_cache_shape | (layer_idx=None) | Returns maximum capacity |
reset | () | Resets all layers |
reorder_cache | (beam_idx) | Reorders all layers per beam index |
crop | (max_length) | Truncates all dynamic layers |
batch_repeat_interleave | (repeats) | Expands batch across all layers |
batch_select_indices | (indices) | Selects batch elements across all layers |
The is_compileable property on Cache aggregates is_compileable across all its layers.
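The dispatch pattern can be sketched as follows (ToyCache is illustrative; it mirrors the routing shape of Cache rather than reproducing the real class):

```python
class ToyCache:
    """Toy top-level cache: holds one layer object per transformer
    layer and routes every call by layer_idx."""

    def __init__(self, layers):
        self.layers = layers

    def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
        # delegate to the per-layer object
        return self.layers[layer_idx].update(key_states, value_states)

    def get_seq_length(self, layer_idx=0):
        return self.layers[layer_idx].get_seq_length()

    @property
    def is_compileable(self):
        # compileable only if every layer is
        return all(layer.is_compileable for layer in self.layers)
```

This is why a single non-compileable layer (e.g. a dynamic one) makes the whole cache non-compileable.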
Sources: src/transformers/cache_utils.py1-650 docs/source/en/internal/generation_utils.md241-259
## DynamicCache

Instantiates one DynamicLayer per non-sliding-window layer and one DynamicSlidingWindowLayer per sliding-window layer, based on the model's config. This is the default cache, used when cache_implementation is not set.
- Constructed as DynamicCache(config=model.config), or with no arguments for a blank cache
- is_compileable = False (prevents torch.compile of the decode step)
- Supports crop, batch_repeat_interleave, batch_select_indices
- The dynamic_full string alias in cache_implementation selects this cache type explicitly, bypassing any model-default override; this is required by assisted/speculative decoding to guarantee rollback is supported

Sources: src/transformers/cache_utils.py1-650 src/transformers/generation/candidate_generator.py190
## StaticCache

Instantiates one StaticLayer per non-sliding-window layer and one StaticSlidingWindowLayer per sliding-window layer.
- Constructed as StaticCache(config=model.config, max_cache_len=N)
- is_compileable = True, enabling 4D mask creation and torch.compile of the decode loop
- Supports an offloaded=True mode, which uses offload()/prefetch() on each layer

Shapes returned by update:
# MHA: (1, 32, max_cache_len, 128)
# GQA: (1, num_key_value_heads, max_cache_len, head_dim)
Sources: src/transformers/cache_utils.py254-363 tests/utils/test_cache_utils.py75-115
## QuantizedCache

Manages a list of QuantoQuantizedLayer or HQQQuantizedLayer instances, depending on the backend field in cache_config. Selected via cache_implementation="quantized".
- Not compatible with torch.compile (is_compileable = False)
- Configured via the cache_config dict passed through GenerationConfig

Sources: src/transformers/cache_utils.py490-650 tests/utils/test_cache_utils.py208-248
## EncoderDecoderCache

Wraps two Cache instances for encoder-decoder models (BART, T5, etc.):
- self_attention_cache: decoder self-attention KV (grows each step)
- cross_attention_cache: encoder cross-attention KV (fixed after the encoder forward pass)

Sources: src/transformers/generation/utils.py29-35 docs/source/en/internal/generation_utils.md259
## GenerationConfig

The GenerationConfig fields that control cache behavior:
| Field | Type | Description |
|---|---|---|
use_cache | bool | Enables/disables caching entirely |
cache_implementation | str | None | Selects the cache class by name |
cache_config | dict | None | Additional kwargs forwarded to the cache constructor |
cache_implementation string-to-class mapping:
| String | Resolves to |
|---|---|
"dynamic" | DynamicCache |
"dynamic_full" | DynamicCache (explicit, for speculative decoding) |
"static" | StaticCache |
"offloaded" | DynamicCache(offloaded=True) |
"offloaded_static" | StaticCache(offloaded=True) |
"quantized" | QuantizedCache (backend from cache_config) |
"sliding_window", "hybrid", etc. | Deprecated aliases (now handled by StaticCache internally) |
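A sketch of how such a string-to-constructor mapping might look (the dict and resolve_cache are illustrative; the real resolution lives in the generation code and constructs actual cache objects from the model config):

```python
# Hypothetical resolver: maps cache_implementation strings to a
# (class name, constructor kwargs) pair instead of real cache objects.
CACHE_IMPLEMENTATIONS = {
    "dynamic": lambda **kw: ("DynamicCache", kw),
    "dynamic_full": lambda **kw: ("DynamicCache", kw),
    "static": lambda **kw: ("StaticCache", kw),
    "offloaded": lambda **kw: ("DynamicCache", {**kw, "offloaded": True}),
    "offloaded_static": lambda **kw: ("StaticCache", {**kw, "offloaded": True}),
    "quantized": lambda **kw: ("QuantizedCache", kw),
}

def resolve_cache(cache_implementation, cache_config=None):
    """Resolve a GenerationConfig.cache_implementation string,
    forwarding cache_config entries as constructor kwargs."""
    return CACHE_IMPLEMENTATIONS[cache_implementation](**(cache_config or {}))
```

Note how the "offloaded" aliases reuse the same classes with an extra flag rather than introducing new cache types.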
Sources: src/transformers/generation/configuration_utils.py45-56 src/transformers/generation/configuration_utils.py146-156
Diagram: Cache Selection from GenerationConfig
Sources: src/transformers/generation/configuration_utils.py45-56 src/transformers/cache_utils.py27-32
Diagram: Cache Interactions During Generation
Sources: src/transformers/generation/utils.py493-591 src/transformers/generation/utils.py556-574
## is_compileable and 4D Mask Creation

When past_key_values.is_compileable is True (i.e., when using StaticCache), prepare_inputs_for_generation converts the 2D attention mask into a 4D causal mask by calling create_masks_for_generate. The static mask shape is required for torch.compile to capture a fixed computation graph without retracing on each step.
Sources: src/transformers/generation/utils.py556-574
## reorder_cache

At each beam search step, GenerationMixin calls cache.reorder_cache(beam_idx). This iterates over all layers and calls CacheLayerMixin.reorder_cache(beam_idx), which performs:
self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device))
self.values = self.values.index_select(0, beam_idx.to(self.values.device))
This reorders the batch dimension so that surviving beams retain the correct cached context.
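A small numeric demonstration of this index_select reordering (shapes are arbitrary): beam 2 survives into slots 0 and 1, beam 0 into slot 2.

```python
import torch

# Cached keys for 3 beams; each beam's entries equal its batch index
keys = torch.arange(3).float().view(3, 1, 1, 1).expand(3, 2, 4, 8).contiguous()

# Surviving beams for the next step, in their new slot order
beam_idx = torch.tensor([2, 2, 0])

# Same operation as reorder_cache: select rows along the batch dimension
reordered = keys.index_select(0, beam_idx)
```

Slots 0 and 1 now carry beam 2's cached context, so continued decoding attends to the right history.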
Sources: src/transformers/cache_utils.py78-83 tests/utils/test_cache_utils.py171-206
## crop

Assisted and speculative decoding may need to roll back the cache when draft tokens are rejected. DynamicLayer.crop(max_length) truncates self.keys and self.values along the sequence dimension. StaticCache does not support crop; rejected tokens leave behind stale values, which are masked out via cache_position.
This is why AssistedCandidateGenerator forces cache_implementation = "dynamic_full" — it must roll back accepted candidate tokens from the main model's cache.
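A sketch of the rollback step (rollback and its exact length arithmetic are illustrative; the real logic lives in the candidate generator):

```python
import torch

def rollback(keys, values, verified_len):
    """Crop cached KV back to the last verified position after some
    draft tokens were rejected, mirroring DynamicLayer.crop."""
    return keys[..., :verified_len, :], values[..., :verified_len, :]

# 10 cached tokens, but only 7 survived verification
keys = torch.zeros(1, 2, 10, 4)
values = torch.zeros(1, 2, 10, 4)
keys, values = rollback(keys, values, verified_len=7)
```

After the crop, the next decode step appends at position 7 as if the rejected tokens had never been cached.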
Sources: src/transformers/cache_utils.py141-153 src/transformers/generation/candidate_generator.py189-191
## batch_repeat_interleave and batch_select_indices

| Method | Use Case |
|---|---|
batch_repeat_interleave(repeats) | Expands batch before beam search (each input repeated num_beams times) |
batch_select_indices(indices) | Contracts/selects batch elements (e.g., multi-sequence return) |
Both delegate to DynamicLayer.batch_repeat_interleave and DynamicLayer.batch_select_indices using repeat_interleave and direct indexing respectively.
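Both operations reduce to standard tensor ops, as this small demonstration shows (shapes are arbitrary):

```python
import torch

keys = torch.randn(2, 4, 6, 8)  # cached keys for a batch of 2 sequences

# Before beam search: repeat each batch element num_beams (=3) times,
# mirroring DynamicLayer.batch_repeat_interleave
expanded = keys.repeat_interleave(3, dim=0)   # shape (6, 4, 6, 8)

# After pruning: keep chosen batch elements, mirroring batch_select_indices
selected = expanded[torch.tensor([0, 3])]     # shape (2, 4, 6, 8)
```

repeat_interleave keeps copies of the same input adjacent (beam order), so index 3 in the expanded batch is the first copy of input 1.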
Sources: src/transformers/cache_utils.py155-165
## Offloading

CacheLayerMixin provides offload() and prefetch() methods. When offloaded=True is passed to DynamicCache or StaticCache, the cache moves tensors to CPU between decode steps and prefetches them before attention computation. This trades transfer bandwidth for reduced GPU memory pressure.
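A minimal sketch of the pattern (ToyOffloadableLayer is illustrative, not the library code):

```python
import torch

class ToyOffloadableLayer:
    """Toy layer that can move its KV tensors to CPU between decode
    steps and bring them back before attention computation."""

    def __init__(self, keys, values):
        self.keys, self.values = keys, values
        self.original_device = keys.device  # remember where to prefetch to

    def offload(self):
        # non_blocking lets the copy overlap with other work when possible
        self.keys = self.keys.to("cpu", non_blocking=True)
        self.values = self.values.to("cpu", non_blocking=True)

    def prefetch(self):
        self.keys = self.keys.to(self.original_device, non_blocking=True)
        self.values = self.values.to(self.original_device, non_blocking=True)
```

The generation loop would prefetch a layer's tensors just before its attention call and offload them again afterwards.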
Sources: src/transformers/cache_utils.py57-67
QuantizedCache uses QuantizedLayer's two-tier approach:
Diagram: QuantizedLayer Storage Model
Sources: src/transformers/cache_utils.py490-568
## Cache Type Summary

| Cache Type | is_compileable | Notes |
|---|---|---|
DynamicCache | False | Default; no torch.compile of decode |
StaticCache | True | Pre-allocates tensors; supports torch.compile |
QuantizedCache | False | Dynamic sizing; no compile |
EncoderDecoderCache | Depends on inner caches | Typically False |
Sources: src/transformers/cache_utils.py27-32 src/transformers/cache_utils.py264-265 docs/source/en/kv_cache.md25-30