This page introduces the components of the training system in the transformers library: the Trainer class, TrainingArguments, the callback system, distributed training backends, and optimization utilities.
For the training loop internals and checkpoint logic, see Trainer Architecture and Training Loop. For all TrainingArguments fields in depth, see TrainingArguments Configuration. For the callback event lifecycle, see Callbacks and Extensibility. For DeepSpeed and FSDP integration details, see Distributed Training Support. For optimizer and scheduler specifics, see Optimization and Scheduling.
The training system is composed of several cooperating components. At its center is the Trainer class, which orchestrates the training loop. Configuration is provided entirely through TrainingArguments. Extensibility is handled by a callback system. Distributed and hardware-specific concerns are delegated to Accelerate and optional backends like DeepSpeed and FSDP.
Component overview:
| Component | File | Role |
|---|---|---|
| Trainer | src/transformers/trainer.py | Main training/evaluation/prediction driver |
| TrainingArguments | src/transformers/training_args.py | All hyperparameters and configuration |
| TrainerCallback | src/transformers/trainer_callback.py | Event hooks for custom logic |
| TrainerState | src/transformers/trainer_callback.py | Mutable training progress state |
| TrainerControl | src/transformers/trainer_callback.py | Signal object for controlling the loop |
| CallbackHandler | src/transformers/trainer_callback.py | Dispatches events to all registered callbacks |
| OptimizerNames | src/transformers/training_args.py | Enum of supported optimizer names |
| trainer_utils.py | src/transformers/trainer_utils.py | Seed control, eval types, checkpoint utilities |
| trainer_pt_utils.py | src/transformers/trainer_pt_utils.py | PyTorch-specific data/distributed helpers |
Sources: src/transformers/trainer.py249-360 src/transformers/training_args.py177-200 src/transformers/trainer_callback.py34-300
Diagram: Training System Component Relationships
Sources: src/transformers/trainer.py362-612 src/transformers/trainer_callback.py1-50 src/transformers/training_args.py110-157
The Trainer class (src/transformers/trainer.py255-348) is the library's primary training abstraction. It is a PyTorch-only class and requires accelerate as a dependency, enforced via the @requires(backends=("torch", "accelerate")) decorator.
The Trainer.__init__ method (src/transformers/trainer.py362-381) accepts:
| Argument | Type | Purpose |
|---|---|---|
| model | PreTrainedModel \| nn.Module | Model to train |
| args | TrainingArguments | All training configuration |
| data_collator | DataCollator | Batch assembly function |
| train_dataset | Dataset \| IterableDataset | Training data |
| eval_dataset | Dataset \| dict | Evaluation data |
| processing_class | Tokenizer/Processor | Used for DataCollatorWithPadding and saving |
| model_init | Callable | Factory for hyperparameter search |
| compute_loss_func | Callable | Custom loss function |
| compute_metrics | Callable[[EvalPrediction], dict] | Metric computation at eval |
| callbacks | list[TrainerCallback] | Additional callbacks |
| optimizers | tuple[Optimizer, LRScheduler] | Pre-built optimizer/scheduler pair |
| optimizer_cls_and_kwargs | tuple[Type, dict] | Optimizer class + kwargs (alternative to optimizers) |
| preprocess_logits_for_metrics | Callable | Transform logits before metric caching |
Trainer.__init__ follows a documented 11-step sequence (src/transformers/trainer.py383-612); the main steps are:

1. Seeds RNGs via set_seed or enable_full_determinism
2. Calls create_accelerator_and_postprocess() and sets up the memory tracker
3. Resolves model vs. model_init, applies the Liger kernel if requested, and calls validate_quantization_for_training
4. Checks accepts_loss_kwargs, determines label names, and creates the LabelSmoother
5. Stores data_collator, the datasets, compute_metrics, and optimizer references
6. Builds the CallbackHandler with default + reporting + user callbacks
7. Calls init_hf_repo() if push_to_hub=True and creates output_dir
8. Initializes TrainerState and TrainerControl
9. Disables use_cache on the model config and handles the XLA FSDPv2 mesh

Key attributes set during initialization:

| Attribute | Description |
|---|---|
| self.model | The unwrapped core model |
| self.model_wrapped | The outermost wrapper (DDP, DeepSpeed, etc.) |
| self.accelerator | The Accelerator instance |
| self.state | TrainerState — current training progress |
| self.control | TrainerControl — signals to the loop |
| self.callback_handler | CallbackHandler dispatching events |
| self.is_model_parallel | True when hf_device_map spans multiple GPUs |
| self.is_deepspeed_enabled | Set by create_accelerator_and_postprocess |
| self.is_fsdp_enabled | Set by create_accelerator_and_postprocess |
Sources: src/transformers/trainer.py330-612
TrainingArguments (src/transformers/training_args.py177-600), a @dataclass, is the single configuration object for all training behavior. It is designed to be parseable from the command line via HfArgumentParser.
Diagram: TrainingArguments Field Groups
OptimizerNames (src/transformers/training_args.py110-156) lists all recognized optimizer identifiers. Key values:
| Enum Value | String Key | Requires |
|---|---|---|
| ADAMW_TORCH | "adamw_torch" | PyTorch built-in |
| ADAMW_TORCH_FUSED | "adamw_torch_fused" | PyTorch >= 2.8 (default there) |
| ADAFACTOR | "adafactor" | Built-in |
| ADAMW_BNB | "adamw_bnb_8bit" | bitsandbytes |
| GALORE_ADAMW | "galore_adamw" | galore_torch |
| LOMO | "lomo" | lomo |
| SCHEDULE_FREE_ADAMW | "schedule_free_adamw" | schedulefree |
SchedulerType / IntervalStrategy / SaveStrategy are defined in src/transformers/trainer_utils.py and used as field types in TrainingArguments.
The effective global batch size is:

effective_batch_size = per_device_train_batch_size × num_devices × gradient_accumulation_steps
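Worked numerically, assuming a hypothetical 8-GPU run:

```python
# Effective global batch size for an illustrative multi-GPU configuration.
per_device_train_batch_size = 16
num_devices = 8
gradient_accumulation_steps = 4

effective_batch_size = (
    per_device_train_batch_size * num_devices * gradient_accumulation_steps
)
print(effective_batch_size)  # 512
```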
Sources: src/transformers/training_args.py177-600 src/transformers/training_args.py110-156
The callback system provides structured injection points throughout the training loop.
Diagram: Callback Event Flow
- TrainerCallback (src/transformers/trainer_callback.py) — Abstract base class with event methods. All event methods receive (args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs).
- TrainerState (src/transformers/trainer_callback.py34-120) — Mutable dataclass tracking global_step, epoch, log_history, best_metric, best_model_checkpoint, and more. Serialized to trainer_state.json at each checkpoint.
- TrainerControl (src/transformers/trainer_callback.py) — Dataclass with boolean flags: should_training_stop, should_epoch_stop, should_save, should_evaluate, should_log. Callbacks set these; the training loop reads them.
- CallbackHandler (src/transformers/trainer_callback.py) — Holds the ordered list of callbacks, dispatches events, and merges TrainerControl signals.

Built-in callbacks:

| Callback Class | Purpose |
|---|---|
| DefaultFlowCallback | Sets should_log, should_evaluate, should_save based on strategy |
| ProgressCallback | tqdm progress bar display |
| PrinterCallback | Simple print-based progress (no tqdm) |
| EarlyStoppingCallback | Stops training when metric stagnates |
| TensorBoardCallback | Logs to TensorBoard |
| WandbCallback | Logs to Weights & Biases |
| MLflowCallback | Logs to MLflow |
| CometCallback | Logs to Comet ML |
| DVCLiveCallback | Logs to DVCLive |
Reporting callbacks are registered automatically from get_reporting_integration_callbacks(args.report_to) (src/transformers/trainer.py553).
ExportableState — Callbacks implementing the ExportableState protocol (src/transformers/trainer_callback.py) have their state serialized to trainer_state.json so training can be resumed exactly. EarlyStoppingCallback and TrainerControl implement this protocol.
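As a sketch of the hook mechanism, here is a hypothetical callback that stops training once the logged loss falls below a threshold. The on_log signature and the TrainerControl flag are from the source; the class itself is illustrative, with a stand-in base so the snippet runs even without transformers installed:

```python
try:
    from transformers import TrainerCallback
except ImportError:  # stand-in base so this sketch is self-contained
    class TrainerCallback:
        pass

class LossThresholdCallback(TrainerCallback):
    """Hypothetical callback: flip should_training_stop once loss is low enough."""

    def __init__(self, threshold: float = 0.05):
        self.threshold = threshold

    def on_log(self, args, state, control, logs=None, **kwargs):
        # The training loop reads control flags after every dispatched event.
        if logs and logs.get("loss", float("inf")) < self.threshold:
            control.should_training_stop = True
        return control
```

It would be registered via `Trainer(..., callbacks=[LossThresholdCallback(0.01)])`.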
Sources: src/transformers/trainer_callback.py1-400 src/transformers/trainer.py552-566 src/transformers/integrations/integration_utils.py1-100
Diagram: Distributed Backend Selection in Trainer
Every Trainer creates an Accelerator instance (src/transformers/trainer.py409). The Accelerator handles device placement, distributed process setup, mixed-precision casting, and gradient accumulation.
AcceleratorConfig (src/transformers/trainer_pt_utils.py) is a dataclass populated from TrainingArguments fields and passed to Accelerator via DataLoaderConfiguration and GradientAccumulationPlugin.
DeepSpeed integration is in src/transformers/integrations/deepspeed.py. Key functions called from Trainer:
- deepspeed_init(trainer, num_training_steps) — Initializes the DeepSpeed engine from the config in args.deepspeed, wraps the model and optimizer.
- deepspeed_load_checkpoint(deepspeed_engine, checkpoint_path) — Resumes optimizer and model states.
- propagate_args_to_deepspeed(trainer, auto_find_batch_size) — Syncs TrainingArguments fields into the DeepSpeed config dict.
- HfDeepSpeedConfig — Must be instantiated before the model when using ZeRO-3; it patches the model init to partition parameters.

FSDP is configured through args.fsdp (a list of FSDPOption strings) and args.fsdp_config (a dict). The Trainer uses Accelerate's FSDP utilities:

- save_fsdp_model / load_fsdp_model
- save_fsdp_optimizer / load_fsdp_optimizer
- get_fsdp_ckpt_kwargs from src/transformers/integrations/fsdp.py

XLA FSDPv2 (for TPU) is handled via torch_xla.distributed.spmd when fsdp_config["xla_fsdp_v2"] is true (src/transformers/trainer.py604-610).
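A minimal ZeRO-2 configuration sketch of the kind passed via args.deepspeed — the field values here are illustrative, and "auto" is the placeholder that lets propagate_args_to_deepspeed fill values in from TrainingArguments:

```python
# Illustrative DeepSpeed config dict; the same content could live in a JSON
# file whose path is passed as TrainingArguments(deepspeed="ds_config.json").
ds_config = {
    "zero_optimization": {"stage": 2},
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "bf16": {"enabled": "auto"},
}
```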
SageMaker Model Parallelism is enabled by is_sagemaker_mp_enabled(). It uses smdistributed.modelparallel.torch and the helper functions smp_forward_backward, smp_forward_only, smp_nested_concat from src/transformers/trainer_pt_utils.py.
Sources: src/transformers/trainer.py407-488 src/transformers/integrations/deepspeed.py src/transformers/integrations/fsdp.py src/transformers/integrations/accelerate.py
Trainer.create_optimizer() calls handlers registered in _OPTIMIZER_HANDLERS (src/transformers/trainer_optimizer.py). For the default adamw_torch case, it creates torch.optim.AdamW with parameter groups that exclude bias and LayerNorm parameters from weight_decay. For specialized optimizers (GaLore, LOMO, bitsandbytes 8-bit, etc.), external packages are invoked.
optimizer_cls_and_kwargs is an alternative to passing a pre-built optimizers tuple; it defers instantiation so the model can be placed on the correct device first.
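A sketch of that deferred-instantiation pattern — torch.optim.AdamW and the kwargs are example choices, and the final call shown here is the step Trainer itself performs once the model is on its device:

```python
import torch
from torch import nn

# The (class, kwargs) pair that Trainer accepts as optimizer_cls_and_kwargs.
optimizer_cls_and_kwargs = (torch.optim.AdamW, {"lr": 5e-5, "weight_decay": 0.01})

# Deferred instantiation: the optimizer class is only called later, once the
# model's parameters exist on the correct device.
model = nn.Linear(4, 2)
opt_cls, opt_kwargs = optimizer_cls_and_kwargs
optimizer = opt_cls(model.parameters(), **opt_kwargs)
```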
get_scheduler(name, optimizer, ...) from src/transformers/optimization.py maps lr_scheduler_type to one of:
| SchedulerType | Behavior |
|---|---|
| linear | Linear decay with warmup |
| cosine | Cosine decay with warmup |
| cosine_with_restarts | Cosine with hard restarts |
| polynomial | Polynomial decay |
| constant | Flat LR |
| constant_with_warmup | Flat after warmup |
| inverse_sqrt | T5-style inverse square root |
| reduce_lr_on_plateau | Metric-driven reduction |
| warmup_stable_decay | Three-phase schedule |
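The linear schedule, for instance, multiplies the base LR by a warmup ramp followed by a linear decay; a plain-Python sketch of that multiplier (mirroring the behavior in the table, not the library's exact code):

```python
def linear_schedule_multiplier(step: int, warmup_steps: int, total_steps: int) -> float:
    """LR multiplier: ramp 0 -> 1 over warmup, then decay linearly to 0."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))

# With 10 warmup steps out of 100 total:
# step 5 -> 0.5, step 10 -> 1.0, step 55 -> 0.5, step 100 -> 0.0
```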
Gradient clipping is applied in the training step via accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm). It is controlled by TrainingArguments.max_grad_norm (default 1.0).
Gradient checkpointing is activated by calling model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs) when args.gradient_checkpointing=True.
validate_quantization_for_training(model) (src/transformers/trainer_utils.py102-150) is called during Trainer.__init__. It raises ValueError if:
- the model uses torch.compile (incompatible with PEFT fine-tuning)
- the quantization method is not trainable (hf_quantizer.is_trainable = False)

Sources: src/transformers/trainer_utils.py102-150 src/transformers/optimization.py src/transformers/trainer.py546-548
Diagram: Data Path through Trainer
- get_train_dataloader() wraps train_dataset in a DataLoader, applying LengthGroupedSampler when group_by_length=True, or a standard RandomSampler / DistributedSampler otherwise.
- IterableDatasetShard (src/transformers/trainer_pt_utils.py) shards iterable datasets across distributed processes.
- RemoveColumnsCollator (src/transformers/trainer_utils.py) strips dataset columns that are not in model.forward's signature, unless remove_unused_columns=False.
- EvalPrediction (src/transformers/trainer_utils.py) is a NamedTuple with fields predictions, label_ids, and optionally inputs, passed to compute_metrics.
EvalLoopOutput wraps the full evaluation result: predictions, label_ids, metrics, and num_samples.
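A compute_metrics sketch consuming the (predictions, label_ids) pair — shown with a local stand-in for the EvalPrediction NamedTuple so it runs without transformers installed, and with accuracy as the assumed metric:

```python
from typing import NamedTuple

class EvalPrediction(NamedTuple):  # stand-in mirroring the transformers NamedTuple
    predictions: list
    label_ids: list

def compute_metrics(eval_pred: EvalPrediction) -> dict:
    # predictions: per-example class scores; label_ids: gold class indices.
    pred_classes = [max(range(len(row)), key=row.__getitem__) for row in eval_pred.predictions]
    correct = sum(p == y for p, y in zip(pred_classes, eval_pred.label_ids))
    return {"accuracy": correct / len(eval_pred.label_ids)}

compute_metrics(EvalPrediction([[0.1, 0.9], [0.8, 0.2]], [1, 1]))  # {'accuracy': 0.5}
```

A function like this is passed as `Trainer(..., compute_metrics=compute_metrics)` and its dict is merged into the evaluation metrics.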
Sources: src/transformers/trainer.py349-358 src/transformers/trainer_pt_utils.py1-100 src/transformers/trainer_utils.py1-100
Checkpoint files written per save event:
| File | Content |
|---|---|
| model.safetensors / pytorch_model.bin | Model weights |
| config.json | PretrainedConfig |
| training_args.bin | Serialized TrainingArguments |
| trainer_state.json | TrainerState (step, epoch, log history) |
| optimizer.pt | Optimizer state dict |
| scheduler.pt | LR scheduler state dict |
| scaler.pt | GradScaler state (FP16) |
| rng_state_*.pth | Per-process RNG state |
- save_strategy controls frequency: "steps", "epoch", "best", "no".
- save_total_limit triggers rotate_checkpoints() (src/transformers/trainer_utils.py) to delete the oldest checkpoints.
- save_only_model=True skips optimizer/scheduler/RNG state, reducing checkpoint size but preventing exact resume.
- enable_jit_checkpoint=True registers a SIGTERM handler via JITCheckpointCallback for preemptible workloads.

Sources: src/transformers/trainer.py238-246 src/transformers/training_args.py446-475 src/transformers/trainer_utils.py1-50
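The save_total_limit rotation can be sketched as sorting checkpoint directories by step number and dropping the oldest beyond the limit (illustrative logic, not the library's implementation):

```python
def checkpoints_to_delete(checkpoint_dirs, save_total_limit):
    """Return the oldest checkpoint-<step> dirs beyond the retention limit."""
    by_step = sorted(checkpoint_dirs, key=lambda d: int(d.rsplit("-", 1)[1]))
    excess = len(by_step) - save_total_limit
    return by_step[:excess] if excess > 0 else []

checkpoints_to_delete(
    ["checkpoint-1500", "checkpoint-500", "checkpoint-1000"], save_total_limit=2
)  # ['checkpoint-500']
```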
TrainingArguments.report_to accepts one or more of:
"azure_ml", "clearml", "codecarbon", "comet_ml", "dagshub", "dvclive", "flyte", "mlflow", "swanlab", "tensorboard", "trackio", "wandb".
Each integration is a TrainerCallback subclass registered via get_reporting_integration_callbacks (src/transformers/integrations/integration_utils.py). All implement on_log to forward the logs dict to the tracking platform.
Sources: src/transformers/training_args.py382-398 src/transformers/integrations/integration_utils.py100-200 src/transformers/trainer.py553
Trainer.hyperparameter_search() supports Optuna, Ray Tune, and SigOpt via the HPSearchBackend enum and the model_init callable. Each trial reinitializes the model from model_init, trains, and returns a BestRun object with the best hyperparameters and metric value. default_compute_objective uses the evaluation loss as the objective by default.
Sources: src/transformers/trainer_utils.py1-50 src/transformers/hyperparameter_search.py