This page documents the CUDA platform implementation in vLLM: the class hierarchy, device capability detection, attention backend selection logic, memory management, and configuration validation specific to NVIDIA GPUs.
For background on the abstract Platform interface that CudaPlatform implements, see Platform Abstraction Layer. For the ROCm implementation that follows a similar but distinct pattern, see ROCm Platform.
The CUDA platform is implemented across three concrete classes and one module-level alias. At import time, vLLM automatically selects between two implementations based on whether NVML (NVIDIA Management Library) is available.
Class hierarchy diagram
The module-level alias CudaPlatform is assigned at the bottom of vllm/platforms/cuda.py:
```python
CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform
```
Sources: vllm/platforms/cuda.py112-718
On import, vllm/platforms/cuda.py attempts pynvml.nvmlInit(). If this succeeds (and the shutdown completes cleanly), nvml_available is set to True and CudaPlatform resolves to NvmlCudaPlatform. Otherwise it falls back to NonNvmlCudaPlatform, which uses torch.cuda calls directly.
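The probe-and-fallback logic can be sketched as follows (the platform classes here are stand-ins; the real probe sits at the bottom of vllm/platforms/cuda.py and may differ in detail):

```python
class NvmlCudaPlatform:  # stand-in for the real NVML-backed class
    pass

class NonNvmlCudaPlatform:  # stand-in for the torch.cuda fallback class
    pass

try:
    # NVML probe: if pynvml is importable and init/shutdown both succeed,
    # the NVML-backed implementation is selected.
    import pynvml
    pynvml.nvmlInit()
    pynvml.nvmlShutdown()
    nvml_available = True
except Exception:
    nvml_available = False

CudaPlatform = NvmlCudaPlatform if nvml_available else NonNvmlCudaPlatform
```

Because the probe runs at import time, the choice is fixed for the lifetime of the process; installing or removing pynvml requires a restart to take effect.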
| Feature | NvmlCudaPlatform | NonNvmlCudaPlatform |
|---|---|---|
| Compute capability | Via pynvml.nvmlDeviceGetCudaComputeCapability | Via torch.cuda.get_device_capability |
| Device name | Via pynvml.nvmlDeviceGetName | Via torch.cuda.get_device_name |
| Total memory | Via pynvml.nvmlDeviceGetMemoryInfo | Via torch.cuda.get_device_properties().total_memory |
| NVLink check | Via pynvml.nvmlDeviceGetP2PStatus | Returns False, logs a warning |
| Device UUID | Via pynvml.nvmlDeviceGetUUID | Not implemented |
| CUDA context init | Avoids it (NVML is independent) | Initializes CUDA context on first call |
| Heterogeneous GPU warning | Yes (log_warnings) | No |
The key advantage of NVML is that it queries device properties without initializing the CUDA context, which matters for Ray-based deployments where workers set CUDA_VISIBLE_DEVICES after the module is imported.
Sources: vllm/platforms/cuda.py581-720
DeviceCapability is a NamedTuple with major and minor fields, defined in vllm/platforms/interface.py. Comparison operators are implemented so that capability checks like has_device_capability(80) work cleanly.
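A minimal sketch of this type and check, assuming an integer encoding of major*10 + minor (the real class in vllm/platforms/interface.py may carry extra helpers, and has_device_capability is illustrated here as a free function):

```python
from typing import NamedTuple

class DeviceCapability(NamedTuple):
    major: int
    minor: int

    def to_int(self) -> int:
        # (8, 0) -> 80, so checks like has_device_capability(80) compare ints
        return self.major * 10 + self.minor

def has_device_capability(required: int, cap: DeviceCapability) -> bool:
    # Hypothetical free-function form of the capability check described above
    return cap.to_int() >= required
```

Note that NamedTuple's built-in lexicographic comparison also happens to match (major, minor) semantics, so tuple comparisons like DeviceCapability(8, 0) >= DeviceCapability(6, 0) behave correctly as well.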
NvmlCudaPlatform.get_device_capability uses the with_nvml_context decorator to wrap pynvml init/shutdown around the call:
```python
@classmethod
@cache
@with_nvml_context
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability | None:
    physical_device_id = cls.device_id_to_physical_device_id(device_id)
    handle = pynvml.nvmlDeviceGetHandleByIndex(physical_device_id)
    major, minor = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
    return DeviceCapability(major=major, minor=minor)
```
The device_id_to_physical_device_id method (inherited from Platform) reads CUDA_VISIBLE_DEVICES to map logical to physical device IDs, so NVML (which ignores CUDA_VISIBLE_DEVICES) gets the correct physical index.
The result is @cache-d per logical device ID to avoid redundant NVML initializations.
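The logical-to-physical translation can be sketched as below. This is an illustrative reimplementation of the behavior described above, not the exact code from Platform:

```python
import os

def device_id_to_physical_device_id(device_id: int) -> int:
    # NVML ignores CUDA_VISIBLE_DEVICES, so a logical CUDA index must be
    # translated to the physical index NVML expects.
    visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible:
        ids = [d.strip() for d in visible.split(",")]
        return int(ids[device_id])
    # No masking in effect: logical and physical indices coincide.
    return device_id
```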
Sources: vllm/platforms/cuda.py585-596 vllm/platforms/interface.py58-98 vllm/platforms/interface.py205-218
| Minimum Capability | Integer | GPU Families | Feature Unlocked |
|---|---|---|---|
| 6.0 | 60 | Pascal, Volta, Turing | FP16 support |
| 8.0 | 80 | Ampere, Hopper | BF16 support |
| 8.9 | 89 | Ada Lovelace, H100 | FP8 support |
| 10.x | 100 | Blackwell | FlashInfer MLA / CUTLASS MLA preferred |
Sources: vllm/platforms/cuda.py124-135 vllm/platforms/cuda.py471-474
CudaPlatformBase.get_attn_backend_cls implements a two-stage selection process:
1. Explicit selection (selected_backend is not None): validate the requested backend against the current device configuration. If invalid, raise a ValueError.
2. Automatic selection (selected_backend is None): build a prioritized list of candidates via _get_backend_priorities, validate each in order, and pick the highest-priority valid one.

Validation is delegated to each backend class via backend_class.validate_configuration(device_capability=..., **attn_selector_config._asdict()).
Backend priority selection flow
Sources: vllm/platforms/cuda.py342-412
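The two-stage process can be sketched as follows. FakeBackend and select_backend are illustrative stand-ins, not vLLM's actual classes; only the validate_configuration contract (a list of problems, empty meaning valid) follows the text above:

```python
class FakeBackend:
    """Stand-in for an attention backend class (illustrative only)."""
    def __init__(self, name: str, valid: bool = True):
        self.name = name
        self._valid = valid

    def validate_configuration(self, **config) -> list[str]:
        # Returns a list of problems; an empty list means the backend is usable.
        return [] if self._valid else [f"{self.name} unsupported here"]

def select_backend(selected, priorities, **config):
    if selected is not None:
        # Stage 1: explicit request -- validate it or fail loudly.
        problems = selected.validate_configuration(**config)
        if problems:
            raise ValueError(f"Backend {selected.name} invalid: {problems}")
        return selected
    # Stage 2: auto-selection -- first valid candidate in priority order wins.
    for candidate in priorities:
        if not candidate.validate_configuration(**config):
            return candidate
    raise ValueError("No valid attention backend found")
```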
_get_backend_priorities (cached) returns different orderings based on use_mla and device_capability.major.
Non-MLA backends
| Priority | Blackwell (CC 10.x) | Other NVIDIA GPUs |
|---|---|---|
| 1 | FLASHINFER | FLASH_ATTN |
| 2 | FLASH_ATTN | FLASHINFER |
| 3 | TRITON_ATTN | TRITON_ATTN |
| 4 | FLEX_ATTENTION | FLEX_ATTENTION |
MLA backends
| Priority | Blackwell (CC 10.x) | Other NVIDIA GPUs |
|---|---|---|
| 1 | FLASHINFER_MLA | FLASH_ATTN_MLA |
| 2 | CUTLASS_MLA | FLASHMLA |
| 3 | FLASH_ATTN_MLA | FLASHINFER_MLA |
| 4 | FLASHMLA | TRITON_MLA |
| 5 | TRITON_MLA | FLASHMLA_SPARSE |
| 6 | FLASHMLA_SPARSE or FLASHINFER_MLA_SPARSE | — |
On Blackwell with MLA, if num_heads <= 16, sparse backends reorder to prefer FLASHINFER_MLA_SPARSE before FLASHMLA_SPARSE, because FlashMLA uses padding that becomes expensive at low head counts.
Sources: vllm/platforms/cuda.py47-97
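The priority tables above can be condensed into a cached helper like the sketch below. Backend names follow the tables; the function shape is illustrative, and the real _get_backend_priorities also factors in num_heads for the sparse reordering described above:

```python
from functools import cache

@cache
def get_backend_priorities(use_mla: bool, major: int) -> tuple[str, ...]:
    blackwell = major == 10
    if not use_mla:
        head = ("FLASHINFER", "FLASH_ATTN") if blackwell else ("FLASH_ATTN", "FLASHINFER")
        return head + ("TRITON_ATTN", "FLEX_ATTENTION")
    if blackwell:
        return ("FLASHINFER_MLA", "CUTLASS_MLA", "FLASH_ATTN_MLA",
                "FLASHMLA", "TRITON_MLA", "FLASHMLA_SPARSE")
    return ("FLASH_ATTN_MLA", "FLASHMLA", "FLASHINFER_MLA",
            "TRITON_MLA", "FLASHMLA_SPARSE")
```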
For vision encoder (ViT) layers, get_vit_attn_backend iterates through supported backends (FLASH_ATTN, TRITON_ATTN, TORCH_SDPA, FLASHINFER) and picks the first one that satisfies supports_head_size, supports_dtype, and supports_compute_capability. Falls back to TORCH_SDPA if none qualify.
Sources: vllm/platforms/cuda.py414-460
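The first-match scan can be sketched as below. ViTBackendStub and pick_vit_attn_backend are hypothetical stand-ins; only the three support-check method names follow the description above:

```python
class ViTBackendStub:
    """Illustrative stand-in exposing the three support checks named above."""
    def __init__(self, name, head_sizes, dtypes, min_cc):
        self.name = name
        self.head_sizes = head_sizes
        self.dtypes = dtypes
        self.min_cc = min_cc

    def supports_head_size(self, hs):
        return hs in self.head_sizes

    def supports_dtype(self, dt):
        return dt in self.dtypes

    def supports_compute_capability(self, cc):
        return cc >= self.min_cc

def pick_vit_attn_backend(candidates, head_size, dtype, capability,
                          fallback="TORCH_SDPA"):
    # Return the first candidate passing all three checks; otherwise fall
    # back to TORCH_SDPA, as described above.
    for backend in candidates:
        if (backend.supports_head_size(head_size)
                and backend.supports_dtype(dtype)
                and backend.supports_compute_capability(capability)):
            return backend.name
    return fallback
```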
CudaPlatformBase.check_and_update_config is called during engine initialization to enforce CUDA-specific constraints on VllmConfig.
Responsibilities:
Worker class: If parallel_config.worker_cls == "auto", sets it to "vllm.v1.worker.gpu_worker.Worker".
KV cache block size defaults: Sets cache_config.block_size = 16 if not already set (unless MLA forces a different value).
MLA block size enforcement: When MLA is active, the selected backend may require a specific block size alignment:
| Backend | Required Block Size |
|---|---|
| FLASHMLA | Multiple of 64 (defaults to 64) |
| CUTLASS_MLA | Multiple of 128 (defaults to 128) |
| FLASHINFER_MLA | 32, or multiple of 64 (defaults to 64) |
| FLASHMLA_SPARSE | Exactly 64 |
| FLASHINFER_MLA_SPARSE | 32 or 64 (defaults to 64) |
Multimodal chunking: If the model uses is_mm_prefix_lm (bidirectional multimodal attention), forces scheduler_config.disable_chunked_mm_input = True.
Backend-config alignment: Writes vllm_config.attention_config.backend when auto-selection forces a specific backend (e.g., FLASHINFER_MLA on Blackwell), so downstream code has a consistent view of the selected backend.
MLA block size decision diagram
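The block-size rules from the table can be expressed as a small validator. This is a hypothetical helper encoding the constraints above; vLLM's actual check_and_update_config mutates cache_config in place rather than returning a boolean:

```python
def mla_block_size_ok(backend: str, block_size: int) -> bool:
    # One rule per MLA backend, mirroring the table of required block sizes.
    rules = {
        "FLASHMLA": lambda b: b % 64 == 0,
        "CUTLASS_MLA": lambda b: b % 128 == 0,
        "FLASHINFER_MLA": lambda b: b == 32 or b % 64 == 0,
        "FLASHMLA_SPARSE": lambda b: b == 64,
        "FLASHINFER_MLA_SPARSE": lambda b: b in (32, 64),
    }
    check = rules.get(backend)
    # Backends without a rule accept any block size.
    return True if check is None else check(block_size)
```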
Sources: vllm/platforms/cuda.py167-297
CudaPlatformBase provides two kinds of GPU memory operations.

Memory usage measurement: empties the CUDA cache before measuring, then returns peak allocated memory in bytes. Used by GPUModelRunner during KV cache sizing.

Block copy operations:
| Method | Operation |
|---|---|
| insert_blocks_to_device | GPU → GPU: src_cache[:, src_block_indices] copied to dst_cache on destination device |
| swap_out_blocks_to_host | GPU → CPU: copies blocks via .cpu() |
These are used by the KV cache manager when swapping blocks between GPU and host memory. See KV Cache Management for how these are invoked.
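A minimal sketch of the indexed copies described above, assuming caches are laid out with blocks along dimension 1 (the real methods in vllm/platforms/cuda.py have their own signatures):

```python
import torch

def insert_blocks_to_device(src_cache, dst_cache,
                            src_block_indices, dst_block_indices):
    # Indexed copy along the block dimension; .to() moves data onto the
    # destination cache's device when the two differ.
    dst_cache[:, dst_block_indices] = (
        src_cache[:, src_block_indices].to(dst_cache.device))

def swap_out_blocks_to_host(src_cache, dst_cache,
                            src_block_indices, dst_block_indices):
    # GPU -> CPU path: materialize the selected blocks on the host via .cpu().
    dst_cache[:, dst_block_indices] = src_cache[:, src_block_indices].cpu()
```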
Sources: vllm/platforms/cuda.py299-306 vllm/platforms/cuda.py544-567
NvmlCudaPlatform.is_fully_connected queries pynvml.nvmlDeviceGetP2PStatus for every pair of physical device IDs:
- Checks the NVML_P2P_CAPS_INDEX_NVLINK capability for each pair.
- Returns True only if all pairs report NVML_P2P_STATUS_OK.
- Returns False and logs a warning if any pair fails or NVLink is absent.

This is used during distributed setup to determine whether all-reduce can use custom fast paths. See Parallelism Strategies for how this feeds into tensor parallel group configuration.
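The pairwise logic can be sketched with the NVML query injected as a callable, so the structure is visible without pynvml (an illustrative sketch, not the real method):

```python
from itertools import combinations

def is_fully_connected(physical_device_ids, p2p_ok) -> bool:
    # p2p_ok(a, b) stands in for pynvml.nvmlDeviceGetP2PStatus with
    # NVML_P2P_CAPS_INDEX_NVLINK returning NVML_P2P_STATUS_OK.
    for a, b in combinations(physical_device_ids, 2):
        if not p2p_ok(a, b):
            return False  # the real code also logs a warning here
    return True
```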
Sources: vllm/platforms/cuda.py630-654
CudaPlatformBase.supported_dtypes returns a capability-gated list:
CC >= 8.0 → [bfloat16, float16, float32]
CC >= 6.0 → [float16, float32]
CC < 6.0 → [float32]
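The capability gating above reduces to a few threshold checks; a sketch, written as a free function rather than vLLM's actual property:

```python
import torch

def supported_dtypes(major: int, minor: int = 0) -> list[torch.dtype]:
    # Encode (major, minor) as an integer and gate dtype support on it.
    cc = major * 10 + minor
    if cc >= 80:
        return [torch.bfloat16, torch.float16, torch.float32]
    if cc >= 60:
        return [torch.float16, torch.float32]
    return [torch.float32]
```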
supports_fp8() returns True when has_device_capability(89), which corresponds to Ada Lovelace and Hopper GPUs.
check_if_supports_dtype raises a descriptive ValueError when bfloat16 is requested on a device below CC 8.0.
Note: On CUDA, is_fp8_fnuz() always returns False and fp8_dtype() returns torch.float8_e4m3fn (OCP FP8 standard). The FNUZ variant is an AMD-specific format.
Sources: vllm/platforms/cuda.py124-135 vllm/platforms/cuda.py471-474 vllm/platforms/cuda.py524-542
CudaPlatformBase configures NCCL as the distributed backend (dist_backend = "nccl").
stateless_init_device_torch_dist_pg creates a PyTorch ProcessGroup using ProcessGroupNCCL without relying on global state — suitable for worker processes that initialize their process groups independently.
The CUDA device communicator class is:
vllm.distributed.device_communicators.cuda_communicator.CudaCommunicator
use_custom_allreduce() returns True for CUDA, enabling the faster custom all-reduce kernel path used in tensor parallelism.
Sources: vllm/platforms/cuda.py462-517
On import of vllm/platforms/cuda.py, several side effects occur:
| Side Effect | Code Location | Purpose |
|---|---|---|
| import vllm._C | vllm/platforms/cuda.py19 | Triggers registration of custom CUDA ops |
| torch.backends.cuda.enable_cudnn_sdp(False) | vllm/platforms/cuda.py44 | Disables cuDNN scaled dot-product attention to avoid crashes on some models |
| NVML availability probe | vllm/platforms/cuda.py706-716 | Determines which platform class to use |
| CudaPlatform.log_warnings() | vllm/platforms/cuda.py720 | Warns if heterogeneous GPUs are detected without CUDA_DEVICE_ORDER=PCI_BUS_ID |
Sources: vllm/platforms/cuda.py1-44 vllm/platforms/cuda.py704-720
| Class / Function | Role |
|---|---|
| CudaPlatformBase | Shared CUDA logic: attention backend selection, config validation, memory ops, dtype checks |
| NvmlCudaPlatform | NVML-backed device queries: capability, name, memory, NVLink, UUID |
| NonNvmlCudaPlatform | torch.cuda-backed fallback; NVLink check always returns False |
| CudaPlatform | Module-level alias assigned at import time based on NVML availability |
| _get_backend_priorities | Returns cached priority-ordered AttentionBackendEnum list per (use_mla, device_capability, num_heads) |
| with_nvml_context | Decorator that wraps pynvml.nvmlInit() / pynvml.nvmlShutdown() around a function |
Sources: vllm/platforms/cuda.py47-110 vllm/platforms/cuda.py112-579 vllm/platforms/cuda.py581-718