This page covers the mechanisms in transformers for exporting models to portable formats and compiling them for optimized inference. It includes the torch.export / ExecuTorch workflow, torch.compile compatibility requirements, and the hub-based kernel injection system.
For cache implementations (StaticCache, HybridCache, etc.) that underpin export compatibility, see page 4.3. For quantization-related compilation hooks, see 6.1.
The library provides three distinct but related mechanisms for accelerating inference beyond standard eager-mode execution:
| Mechanism | Entry Point | Use Case |
|---|---|---|
| torch.export + ExecuTorch | TorchExportableModuleForDecoderOnlyLM | Ahead-of-time model portability (mobile, edge) |
| torch.compile | torch.compile(model.forward, ...) | GPU kernel fusion with StaticCache |
| Hub kernels | use_kernel_forward_from_hub decorator | Drop-in production kernels from HF Hub |
All three paths depend on having predictable tensor shapes. Dynamic caches (variable-length KV) prevent graph capture; StaticCache and HybridCache provide fixed-shape KV tensors required for compilation and export.
The entry point for decoder-only LM export is TorchExportableModuleForDecoderOnlyLM. Its __init__ inspects the model config and dispatches to the appropriate cache-aware wrapper.
ExecuTorch Export Class Hierarchy
Sources: src/transformers/integrations/executorch.py184-222
Defined in src/transformers/integrations/executorch.py184-337
Constructor (__init__):
Takes a PreTrainedModel plus optional batch_size, max_cache_len, and device. Validates config.use_cache and raises ValueError if it is disabled. Dispatches to TorchExportableModuleWithHybridCache when config.layer_types exists and config.sliding_window is non-null; otherwise dispatches to TorchExportableModuleWithStaticCache.
forward(input_ids, inputs_embeds, cache_position):
Delegates to the inner model's forward. Signature is kept minimal (no attention_mask, position_ids) to maintain a clean export boundary compatible with ExecuTorch's LLM runner.
export(input_ids, inputs_embeds, cache_position, dynamic_shapes, strict):
Calls torch.export.export with the module and provided inputs. Returns a torch.export.ExportedProgram. The dynamic_shapes argument lets callers declare which dimensions are dynamic (e.g., variable prefill length).
```python
exported_program = torch.export.export(
    self.model,
    args=(),
    kwargs=input_kwargs,
    dynamic_shapes=dynamic_shapes,
    strict=strict if strict is not None else True,
)
```
generate(exported_program, tokenizer, prompt, max_new_tokens, ...) (static method):
Runs autoregressive generation using an ExportedProgram. Calls exported_program.module() to get the callable, then runs the prompt token-by-token before generating new tokens. Returns the decoded string.
TorchExportableModuleWithStaticCache: used for models without hybrid attention (most standard decoder-only LLMs: LLaMA, Gemma, Qwen2). Wraps the model in a StaticCache-backed forward pass with fixed batch_size and max_cache_len.
TorchExportableModuleWithHybridCache: used for models where some layers use sliding-window attention and others use full attention (e.g., Gemma2). Identified by the presence of config.layer_types and a non-null config.sliding_window.
Config dispatch logic:
Sources: src/transformers/integrations/executorch.py207-222
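The dispatch condition can be sketched with plain config objects. This is a simplified mock (SimpleNamespace stands in for the model config; the returned strings name the real wrapper classes), not the library's code:

```python
from types import SimpleNamespace

def pick_wrapper(config):
    """Mirror of the __init__ dispatch: hybrid cache only when both
    layer_types and a non-null sliding_window are present."""
    if not getattr(config, "use_cache", False):
        raise ValueError("Export requires config.use_cache=True")
    if getattr(config, "layer_types", None) is not None and \
       getattr(config, "sliding_window", None) is not None:
        return "TorchExportableModuleWithHybridCache"
    return "TorchExportableModuleWithStaticCache"

gemma2_like = SimpleNamespace(
    use_cache=True,
    layer_types=["sliding_attention", "full_attention"],
    sliding_window=4096,
)
llama_like = SimpleNamespace(use_cache=True)

print(pick_wrapper(gemma2_like))  # TorchExportableModuleWithHybridCache
print(pick_wrapper(llama_like))   # TorchExportableModuleWithStaticCache
```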
Defined in src/transformers/integrations/executorch.py32-181
Handles Vision-Language Models (e.g. SmolVLM2) by exporting three separate components:
| Component | Export Method | Inputs |
|---|---|---|
vision_encoder | export_vision_encoder() | pixel_values tensor, dynamic H/W |
connector | export_connector() | image_hidden_states, dynamic num_patches |
text_decoder | export_text_decoder() | Uses TorchExportableModuleForDecoderOnlyLM |
Returns a dict {"vision_encoder": ..., "connector": ..., "text_decoder": ...} from export().
Export flow for a decoder-only LM:
Sources: src/transformers/integrations/executorch.py247-337 tests/models/llama/test_modeling_llama.py229-288 tests/models/gemma/test_modeling_gemma.py354-442
For torch.export to succeed, the model must be loaded with a compatible configuration:
Flash Attention 2 is not export-compatible because its CUDA kernel cannot be captured by torch.export. SDPA and eager attention are both supported.
Sources: tests/models/llama/test_modeling_llama.py249-269 tests/models/gemma/test_modeling_gemma.py382-403
torch.compile with fullgraph=True requires static tensor shapes throughout the computation graph. The primary constraint is the KV cache:
- DynamicCache: variable sequence dimension → graph breaks → cannot use fullgraph=True
- StaticCache: fixed max_cache_len → static shapes → fullgraph=True compatible
- SlidingWindowCache: fixed window size → static shapes → compatible

Sources: tests/models/gemma/test_modeling_gemma.py346-352 tests/models/mistral/test_modeling_mistral.py248-264
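The static-shape requirement can be sketched at small scale: a preallocated fixed-size KV buffer lets Dynamo capture the whole decode step as one graph. This is a toy illustration (backend="eager" is used only to avoid needing a compiler toolchain; StaticCache's real buffers are per-layer 4-D tensors):

```python
import torch

MAX_CACHE_LEN, HEAD_DIM = 8, 4
# Preallocated fixed-shape KV buffer, as StaticCache does internally.
k_cache = torch.zeros(MAX_CACHE_LEN, HEAD_DIM)

def decode_step(k_cache, new_k, cache_position):
    # In-place write at a known slot keeps every shape static,
    # so the step compiles without graph breaks.
    k_cache.index_copy_(0, cache_position, new_k)
    return k_cache.sum()

compiled = torch.compile(decode_step, fullgraph=True, backend="eager")
out = compiled(k_cache, torch.ones(1, HEAD_DIM), torch.tensor([0]))
```

With a DynamicCache-style growing tensor (torch.cat per step), the sequence dimension changes every call and fullgraph capture fails.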
When cache_implementation="static" is passed to generate(), the generation loop internally compiles the decoding step. This is documented in test comments as:
generate() internally compiles each decoding step when static cache is used
Sources: tests/models/llama/test_modeling_llama.py220-225
| mode | Use Case |
|---|---|
"reduce-overhead" | Reduces Python dispatch overhead; most common for inference |
"default" | Basic compilation |
"max-autotune" | Maximizes throughput, longer compile time |
torch._dynamo.reset() is required between compile calls with different configurations to clear cached graphs.
Sources: tests/models/mistral/test_modeling_mistral.py248-264
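A minimal sketch of the reset pattern, using backend="eager" and a shape-dynamism change as the stand-in for "different configurations":

```python
import torch

def f(x):
    return x * 2 + 1

# First configuration: static shapes.
compiled_static = torch.compile(f, backend="eager", dynamic=False)
y1 = compiled_static(torch.tensor([1.0, 2.0]))

# Clear Dynamo's cached graphs before compiling with a new configuration.
torch._dynamo.reset()

# Second configuration: dynamic shapes.
compiled_dynamic = torch.compile(f, backend="eager", dynamic=True)
y2 = compiled_dynamic(torch.tensor([1.0, 2.0, 3.0]))
```

Without the reset, stale guards and cached graphs from the first configuration can interfere with (or be silently reused by) the second.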
Some common operations that prevent fullgraph=True:
| Operation | Workaround |
|---|---|
attention_mask with dynamic shapes | Use StaticCache (mask is static-sized) |
max().item() in flash attention | Set TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1 |
DynamicCache updates | Switch to StaticCache |
Sources: src/transformers/modeling_flash_attention_utils.py386-393
The hub kernels system allows layer-level forward passes to be replaced with optimized kernels loaded from Hugging Face Hub repositories. It is implemented in src/transformers/integrations/hub_kernels.py.
use_kernel_forward_from_hub(layer_name)
A class decorator that registers the decorated nn.Module class to have its forward method replaced by a hub kernel when the kernels package is available and USE_HUB_KERNELS is enabled.
When a kernel is available for "RMSNorm" on the current device, the forward method is swapped out at import time.
use_kernel_func_from_hub(func_name)
Function-level equivalent. Decorates standalone functions (e.g., apply_rotary_transformers) rather than class forward methods.
Sources: src/transformers/integrations/hub_kernels.py62-84 src/transformers/models/gpt_oss/modeling_gpt_oss.py47-48 src/transformers/activations.py30-31
The _KERNEL_MAPPING dict maps logical layer names to LayerRepository objects, keyed by device type and optionally Mode:
Mode.TORCH_COMPILE marks kernels that are compatible with torch.compile.
Registered layers and their hub repos:
| Layer Name | Devices | Repository |
|---|---|---|
RMSNorm | cuda, rocm, xpu, mps, npu | kernels-community/liger_kernels, kernels-community/rmsnorm, kernels-community/mlx_rmsnorm |
MLP | cuda | medmekk/triton-llama-mlp |
MegaBlocksMoeMLP | cuda, rocm, xpu, cpu | kernels-community/megablocks, ahadnagy/megablocks |
FastGELU, QuickGELU, NewGELU, SiLU, GeLU, GeluTanh | cuda | kernels-community/activation |
MultiScaleDeformableAttention | cuda | kernels-community/deformable-detr |
Llama4TextMoe | cuda | kernels-community/moe |
rotary_pos_emb | cuda, xpu | kernels-community/rotary |
Sources: src/transformers/integrations/hub_kernels.py86-231
The load_and_register_attn_kernel function loads custom flash attention implementations from hub repos (e.g., kernels-community/flash-attn2, kernels-community/vllm-flash-attn3) and registers them in ALL_ATTENTION_FUNCTIONS.
Attention kernel loading flow:
The is_kernel function uses a regex ^[^/:]+/[^/:]+(?:@[^/:]+)?(?::[^/:]+)?$ to distinguish hub repo strings from built-in implementation names like "flash_attention_2" or "sdpa".
Sources: src/transformers/integrations/hub_kernels.py291-354 src/transformers/modeling_flash_attention_utils.py150-168
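The regex above can be exercised directly; the repo strings below are examples (the second revision/layer-name suffix is illustrative):

```python
import re

# Regex from the text: org/repo, optional @revision, optional :layer_name.
_KERNEL_RE = re.compile(r"^[^/:]+/[^/:]+(?:@[^/:]+)?(?::[^/:]+)?$")

def is_kernel(name: str) -> bool:
    return _KERNEL_RE.match(name) is not None

print(is_kernel("kernels-community/flash-attn2"))           # hub repo
print(is_kernel("some-org/some-repo@main:flash_attention")) # with suffixes
print(is_kernel("flash_attention_2"))                       # built-in name
print(is_kernel("sdpa"))                                    # built-in name
```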
Separate from attention, _HUB_KERNEL_MAPPING handles op-level kernels (SSM layers etc.):
lazy_load_kernel(kernel_name) fetches and caches these modules on first use. Falls back to a pip-installed package if the kernels package is unavailable (e.g., mamba_ssm pip package).
Sources: src/transformers/integrations/hub_kernels.py282-401
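The fetch-then-fallback-then-cache behavior can be sketched with importlib. Everything here is an assumption-laden mock: the repo name is hypothetical, kernels.get_kernel is the hub-fetch entry point of the external kernels package, and the math module stands in for a pip-installed fallback like mamba_ssm:

```python
import importlib
from functools import lru_cache

@lru_cache(maxsize=None)  # cache the resolved module on first use
def lazy_load_kernel(kernel_name, fallback_pip_name=None):
    """Try the hub-backed `kernels` package first; on any failure,
    fall back to a pip-installed module; cache either result."""
    try:
        kernels = importlib.import_module("kernels")
        return kernels.get_kernel(kernel_name)
    except Exception:
        if fallback_pip_name is not None:
            return importlib.import_module(fallback_pip_name)
        return None

# "example-org/example-kernel" is a deliberately nonexistent repo, so the
# fallback path is exercised; "math" stands in for a real pip fallback.
mod = lazy_load_kernel("example-org/example-kernel", fallback_pip_name="math")
```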
| Environment Variable | Default | Effect |
|---|---|---|
USE_HUB_KERNELS | "YES" | Master switch for all hub kernel replacement |
When USE_HUB_KERNELS is set to a falsy value (NO, 0, FALSE), use_kernel_forward_from_hub becomes a no-op decorator and no kernel replacements occur.
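A minimal sketch of that gating, assuming the falsy set named above (the helper name and exact set of accepted values are illustrative):

```python
import os

_FALSY = {"no", "0", "false"}  # assumed falsy spellings, case-insensitive

def hub_kernels_enabled() -> bool:
    # Default "YES": kernels stay on unless explicitly disabled.
    return os.environ.get("USE_HUB_KERNELS", "YES").lower() not in _FALSY

os.environ["USE_HUB_KERNELS"] = "NO"
print(hub_kernels_enabled())  # False

os.environ["USE_HUB_KERNELS"] = "YES"
print(hub_kernels_enabled())  # True
```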
register_kernel_mapping_transformers() must be called before kernelize() (from the kernels package) to register the _KERNEL_MAPPING. Calling it with xpu entries requires kernels >= 0.10.2.
Sources: src/transformers/integrations/hub_kernels.py58-60 src/transformers/integrations/hub_kernels.py236-243
Sources: src/transformers/integrations/executorch.py src/transformers/integrations/hub_kernels.py src/transformers/modeling_utils.py
Export and compile tests are marked with pytest markers to allow selective execution:
| Marker | Test Category |
|---|---|
@pytest.mark.torch_export_test | torch.export / ExecuTorch tests |
@pytest.mark.torch_compile_test | torch.compile tests |
@pytest.mark.flash_attn_test | Flash attention kernel tests |
Export tests exist for LLaMA, Gemma, Gemma2, Qwen2, Qwen3, and Cohere2 under their respective tests/models/*/test_modeling_*.py files. Each follows the same pattern: load with StaticCache config, wrap in TorchExportableModuleForDecoderOnlyLM, call .export(), and verify generation output matches eager baseline.
Compile tests verify that torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) with cache_implementation="static" produces identical output to uncompiled eager generation.
Sources: tests/models/llama/test_modeling_llama.py192-288 tests/models/gemma/test_modeling_gemma.py322-442 tests/models/gemma2/test_modeling_gemma2.py201-314 tests/models/qwen2/test_modeling_qwen2.py198-262