This page covers the Trainer class in src/transformers/trainer.py: its constructor initialization sequence, the training loop (train, _inner_training_loop, training_step, compute_loss), the evaluation loop (evaluate, evaluation_loop, prediction_step), checkpoint management, gradient accumulation, mixed-precision support, and Accelerate integration.
For TrainingArguments field-by-field documentation, see 3.2. For the callback event lifecycle, see 3.3. For DeepSpeed, FSDP, and tensor-parallel distributed training, see 3.4. For optimizers and LR schedulers, see 3.5.
Trainer is defined in src/transformers/trainer.py and requires torch and accelerate as hard dependencies (enforced at class definition by @requires(backends=("torch", "accelerate"))).
| Parameter | Type | Description |
|---|---|---|
| model | PreTrainedModel \| nn.Module | Model to train. Must have a task head — bare base models (those in MODEL_MAPPING_NAMES) are rejected with a helpful error. |
| args | TrainingArguments | All hyperparameters and infrastructure settings. Defaults to TrainingArguments(output_dir="tmp_trainer"). |
| data_collator | DataCollator | Batch collation callable. Defaults to DataCollatorWithPadding(processing_class) when a tokenizer or feature extractor is provided, otherwise default_data_collator. |
| train_dataset | Dataset \| IterableDataset | Training data. HuggingFace datasets.Dataset columns absent from model.forward() are stripped automatically. |
| eval_dataset | Dataset \| dict[str, Dataset] | Evaluation data. Dict input triggers evaluation per key, prefixing metric names. |
| processing_class | PreTrainedTokenizerBase \| BaseImageProcessor \| ProcessorMixin | Saved alongside the model for reproducibility. |
| model_init | Callable[[], PreTrainedModel] | Factory for fresh model instances. Mutually exclusive with model. Required for hyperparameter search. |
| compute_loss_func | Callable | Overrides the default loss. Receives (outputs, labels, num_items_in_batch). |
| compute_metrics | Callable[[EvalPrediction], dict] | Receives an EvalPrediction at evaluation time; returns a dict of metric values. |
| callbacks | list[TrainerCallback] | Appended to the default callback list (not replacing it). |
| optimizers | tuple[Optimizer, LRScheduler] | Pre-built optimizer and LR scheduler. Mutually exclusive with optimizer_cls_and_kwargs. |
| optimizer_cls_and_kwargs | tuple[type, dict] | Optimizer class and kwargs; instantiated after model parameters are on-device. |
| preprocess_logits_for_metrics | Callable[[Tensor, Tensor], Tensor] | Applied to logits before they are cached during evaluation; the result is what compute_metrics sees. |
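The two "mutually exclusive" rules in the table (model vs. model_init, optimizers vs. optimizer_cls_and_kwargs) can be sketched as plain argument validation. This is an illustrative stand-in, not the actual checks in Trainer.__init__, and validate_trainer_args is a hypothetical name:

```python
# Hypothetical sketch of the constructor's mutual-exclusivity rules;
# the real validation lives in Trainer.__init__ and differs in detail.
def validate_trainer_args(model=None, model_init=None,
                          optimizers=(None, None),
                          optimizer_cls_and_kwargs=None):
    if model is not None and model_init is not None:
        raise ValueError("Pass either `model` or `model_init`, not both.")
    if model is None and model_init is None:
        raise ValueError("One of `model` or `model_init` is required.")
    if optimizers[0] is not None and optimizer_cls_and_kwargs is not None:
        raise ValueError(
            "`optimizers` and `optimizer_cls_and_kwargs` are mutually exclusive."
        )

validate_trainer_args(model="my_model")  # passes silently
```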
| Attribute | Description |
|---|---|
self.model | Always the core unwrapped model |
self.model_wrapped | Outermost wrapper (DDP, DeepSpeed, etc.); used for forward passes |
self.accelerator | The accelerate.Accelerator instance |
self.state | TrainerState — serialized at each checkpoint |
self.control | TrainerControl — boolean flags set by callbacks |
self.callback_handler | CallbackHandler — dispatches events to all registered callbacks |
self.is_in_train | True between train() entry and exit |
self.is_model_parallel | True when model spans multiple devices via hf_device_map |
self.place_model_on_device | Whether Trainer moves the model; disabled for DeepSpeed, FSDP, MP, etc. |
self._train_batch_size | Effective per-step batch size; may decrease under auto_find_batch_size |
Sources: src/transformers/trainer.py255-612
The constructor is organized into 11 sequential stages. The inline comment block at src/transformers/trainer.py382-394 documents this order explicitly.
Trainer.__init__ — 11-Step Initialization Flow
Sources: src/transformers/trainer.py382-612
train(resume_from_checkpoint=None):
- Sets self.is_in_train = True.
- Calls create_optimizer_and_scheduler(num_training_steps) if no optimizer was pre-supplied.
- Calls self.accelerator.prepare(model, optimizer, dataloader, lr_scheduler) to apply backend-specific wrapping; updates self.model_wrapped.
- Delegates to _inner_training_loop() (or its find_executable_batch_size-wrapped variant when args.auto_find_batch_size is set).
- Loads the best model when args.load_best_model_at_end is set.
- Sets is_in_train = False and returns TrainOutput(global_step, training_loss, metrics).

_inner_training_loop(): iterates epochs and steps; dispatches callbacks and side-effects on each event.
Training Data Flow — from Batch to Optimizer Step
Sources: src/transformers/trainer.py1-240
Controlled by args.gradient_accumulation_steps. The Trainer uses accelerate.utils.GradientAccumulationPlugin — no manual no_sync() or loss scaling is needed in user code. One optimizer update step spans gradient_accumulation_steps micro-batches. The num_items_in_batch variable tracks the total sample or token count across the window for correct loss normalization.
When args.average_tokens_across_devices is True, an all_reduce synchronizes num_tokens_in_batch across ranks.
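The normalization role of num_items_in_batch can be shown with plain numbers. A minimal sketch (not Trainer code): normalizing once over the whole accumulation window differs from averaging per-micro-batch means whenever micro-batches hold different token counts.

```python
# Sketch of why loss is normalized by the total item count across a
# gradient-accumulation window rather than per micro-batch.
def window_loss(micro_batches):
    """micro_batches: list of (summed_loss, num_items) per micro-batch."""
    num_items_in_batch = sum(n for _, n in micro_batches)
    # Correct: one normalization over the whole window.
    return sum(loss for loss, _ in micro_batches) / num_items_in_batch

def naive_window_loss(micro_batches):
    # Biased when micro-batches differ in size: each micro-batch mean
    # gets equal weight regardless of how many items it contains.
    return sum(loss / n for loss, n in micro_batches) / len(micro_batches)

batches = [(8.0, 4), (3.0, 1)]    # (sum of per-item losses, item count)
print(window_loss(batches))        # 11/5 = 2.2
print(naive_window_loss(batches))  # (2.0 + 3.0)/2 = 2.5
```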
training_step(model, inputs, num_items_in_batch):
- Puts the model in train() mode.
- Moves inputs to the training device.
- Computes the loss via compute_loss().
- Calls self.accelerator.backward(loss) — not loss.backward() directly — so AMP loss scaling is handled by Accelerate.

compute_loss(model, inputs, return_outputs=False, num_items_in_batch=None):
- Extracts labels from inputs using self.label_names.
- Runs outputs = model(**inputs).
- If self.compute_loss_func is set: delegates to it with (outputs, labels, num_items_in_batch).
- Otherwise uses outputs.loss; if self.label_smoother is set, recomputes the loss through it.
- Returns (loss, outputs) or just loss.

Mixed precision is fully managed by Accelerate based on args.fp16 or args.bf16:
- With fp16, accelerator.backward() scales the loss via GradScaler; unscaling happens before the optimizer step. Scaler state is saved to scaler.pt.
- Full-dtype evaluation (args.fp16_full_eval, args.bf16_full_eval): the model is cast to the specified dtype only during evaluation, not training.

When args.gradient_checkpointing is True, the Trainer calls model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs) before training begins. Activations are recomputed during the backward pass rather than stored, trading compute for memory.
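The compute_loss() dispatch described above can be sketched in a few lines. This is simplified pseudocode with a dict-based stand-in model, not Trainer source — the real method also handles label smoothing, model output classes, and loss kwargs:

```python
# Simplified sketch of compute_loss() dispatch (illustrative only).
def compute_loss(model, inputs, compute_loss_func=None,
                 return_outputs=False, num_items_in_batch=None):
    if compute_loss_func is not None:
        # A custom loss function consumes the labels itself.
        labels = inputs.pop("labels", None)
        outputs = model(**inputs)
        loss = compute_loss_func(outputs, labels, num_items_in_batch)
    else:
        # Otherwise the model computes the loss from the labels it receives.
        outputs = model(**inputs)
        loss = outputs["loss"]
    return (loss, outputs) if return_outputs else loss

# Stand-in "model" returning a dict of outputs:
model = lambda **inputs: {"loss": 1.5, "logits": [0.1, 0.9]}
print(compute_loss(model, {"input_ids": [1, 2], "labels": [1]}))  # 1.5
```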
torch.compile

When args.torch_compile is True, the model is compiled via torch.compile(model, backend=args.torch_compile_backend, mode=args.torch_compile_mode). Note: torch.compile is incompatible with quantized PEFT fine-tuning and is rejected by validate_quantization_for_training().
Sources: src/transformers/trainer.py1-248 src/transformers/training_args.py256-320
evaluate(eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
- Selects the evaluation dataset (defaulting to self.eval_dataset).
- Runs evaluation_loop() and returns a flat dict of metric values.
- Fires the on_evaluate callback event.

predict(test_dataset, ignore_keys=None, metric_key_prefix="test"):
- Same as evaluate() but returns PredictionOutput(predictions, label_ids, metrics) — the raw arrays in addition to metrics.

evaluation_loop(dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix):
- Puts the model in eval() mode.
- Calls prediction_step() per batch.
- Accumulates results in an EvalLoopContainer, using nested_gather() to collect tensors across processes.
- Builds an EvalPrediction and passes it to self.compute_metrics.
- Returns EvalLoopOutput(predictions, label_ids, metrics, num_samples).

When args.batch_eval_metrics is True, compute_metrics is called at the end of each batch (with a compute_result flag) to support streaming metric computation that avoids holding all logits in memory.
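The streaming pattern behind batch_eval_metrics can be sketched as a stateful metric object. The class below is illustrative (the real compute_metrics receives an EvalPrediction and a compute_result keyword); it shows the accumulate-then-emit shape that avoids holding all logits in memory:

```python
# Illustrative sketch of a streaming metric for batch_eval_metrics-style
# evaluation: accumulate per batch, emit only when compute_result=True.
class StreamingAccuracy:
    def __init__(self):
        self.correct = 0
        self.total = 0

    def __call__(self, preds, labels, compute_result):
        self.correct += sum(p == l for p, l in zip(preds, labels))
        self.total += len(labels)
        if compute_result:                 # final batch: emit metrics
            acc = self.correct / self.total
            self.correct = self.total = 0  # reset for the next eval run
            return {"accuracy": acc}
        return {}                          # intermediate batches

metric = StreamingAccuracy()
metric([0, 1], [0, 0], compute_result=False)
print(metric([1, 1], [1, 0], compute_result=True))  # {'accuracy': 0.5}
```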
prediction_step(model, inputs, prediction_loss_only, ignore_keys):
- Runs model(**inputs) under torch.no_grad() (and autocast for fp16/bf16 eval).
- Extracts loss, logits, and labels.
- Applies self.preprocess_logits_for_metrics(logits, labels) when set.
- Returns (loss, logits, labels).

EvalPrediction — the Metrics Interface

EvalPrediction (defined in src/transformers/trainer_utils.py) is the NamedTuple passed to compute_metrics:
| Field | Type | Included when |
|---|---|---|
| predictions | np.ndarray \| tuple[np.ndarray] | Always |
| label_ids | np.ndarray \| tuple[np.ndarray] \| None | Dataset has labels |
| inputs | np.ndarray \| tuple[np.ndarray] \| None | "inputs" in args.include_for_metrics |
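A typical compute_metrics takes the predictions/label_ids pair and returns a metrics dict. The sketch below uses a namedtuple stand-in for EvalPrediction (the real class lives in trainer_utils.py and wraps numpy arrays) and computes argmax accuracy over plain lists:

```python
from collections import namedtuple

# Stand-in for transformers' EvalPrediction so the example is
# self-contained without numpy or transformers installed.
EvalPrediction = namedtuple("EvalPrediction", ["predictions", "label_ids"])

def compute_metrics(eval_pred):
    # predictions: per-example logits; label_ids: gold class indices
    pred_ids = [row.index(max(row)) for row in eval_pred.predictions]
    correct = sum(p == l for p, l in zip(pred_ids, eval_pred.label_ids))
    return {"accuracy": correct / len(eval_pred.label_ids)}

ep = EvalPrediction(predictions=[[0.1, 0.9], [0.8, 0.2]], label_ids=[1, 1])
print(compute_metrics(ep))  # {'accuracy': 0.5}
```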
Evaluation Data Flow — from Batches to Metrics
Sources: src/transformers/trainer.py100-148 src/transformers/trainer_utils.py1-170 src/transformers/trainer_pt_utils.py1-100
_save_checkpoint(model, trial) writes to {output_dir}/checkpoint-{global_step}/. The PREFIX_CHECKPOINT_DIR = "checkpoint" constant (from trainer_utils.py) defines the prefix.
| File constant | Filename | Saved when |
|---|---|---|
| SAFE_WEIGHTS_NAME | model.safetensors | Always |
| CONFIG_NAME | config.json | Always |
| TRAINING_ARGS_NAME | training_args.bin | save_only_model=False |
| TRAINER_STATE_NAME | trainer_state.json | save_only_model=False |
| OPTIMIZER_NAME | optimizer.pt | save_only_model=False |
| SCHEDULER_NAME | scheduler.pt | save_only_model=False |
| SCALER_NAME | scaler.pt | save_only_model=False, fp16 |
| (per-process RNG) | rng_state_{rank}.pth | save_only_model=False |
| FSDP_MODEL_NAME | pytorch_model_fsdp/ | FSDP only |
The constants are defined at src/transformers/trainer.py239-246
rotate_checkpoints() enforces args.save_total_limit, always preserving the best checkpoint when args.load_best_model_at_end=True.
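The rotation policy can be sketched as a pure function over checkpoint directory names: keep the newest save_total_limit checkpoints, never deleting the best one. This is an illustration of the policy described above, not the actual _rotate_checkpoints implementation:

```python
import re

# Sketch of checkpoint rotation: return the directories that should be
# deleted given a retention limit and an always-kept best checkpoint.
def checkpoints_to_delete(dirs, save_total_limit, best_checkpoint=None):
    step = lambda d: int(re.search(r"checkpoint-(\d+)", d).group(1))
    ordered = sorted(dirs, key=step)             # oldest first
    if best_checkpoint in ordered:               # best is always retained
        ordered.remove(best_checkpoint)
        keep = max(save_total_limit - 1, 0)
    else:
        keep = save_total_limit
    return ordered[:len(ordered) - keep] if keep < len(ordered) else []

dirs = ["checkpoint-100", "checkpoint-300", "checkpoint-200"]
print(checkpoints_to_delete(dirs, 2, best_checkpoint="checkpoint-100"))
# → ['checkpoint-200']
```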
When train(resume_from_checkpoint=path):
- TrainerState is reloaded from trainer_state.json.
- The dataloader fast-forwards past already-consumed batches via skip_first_batches().
- Per-process RNG state is restored from rng_state_{rank}.pth.

Sources: src/transformers/trainer.py239-247 src/transformers/trainer_utils.py117-145
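The mid-epoch fast-forward during resume amounts to skipping the batches consumed before the checkpoint. A stdlib sketch of the idea (accelerate's real skip_first_batches() additionally preserves sampler and shuffling state, which this does not):

```python
from itertools import islice

# Illustrative stand-in for the resume fast-forward: drop the first
# `num_batches` batches and yield the rest.
def skip_first_batches(dataloader, num_batches):
    return islice(iter(dataloader), num_batches, None)

batches = [[0, 1], [2, 3], [4, 5], [6, 7]]
resumed = list(skip_first_batches(batches, 2))
print(resumed)  # [[4, 5], [6, 7]]
```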
create_accelerator_and_postprocess() reads args.accelerator_config (an AcceleratorConfig from trainer_pt_utils.py) and instantiates an accelerate.Accelerator with:
- GradientAccumulationPlugin(num_steps=args.gradient_accumulation_steps)
- DataLoaderConfiguration for dispatch/split batches
- DistributedDataParallelKwargs for DDP tuning

After accelerator.prepare() is called inside train(), self.model_wrapped is updated to the fully wrapped model (e.g., DeepSpeed → DDP → model). self.model always remains the unwrapped core model.
Trainer ↔ Accelerator Interaction Points
Sources: src/transformers/trainer.py213-230 src/transformers/integrations/accelerate.py1-120
get_train_dataloader() returns a torch.utils.data.DataLoader configured with:
- RandomSampler (single-process) or DistributedSampler (multi-process)
- LengthGroupedSampler when args.group_by_length=True
- IterableDatasetShard for IterableDataset, to handle distributed splitting
- self.data_collator as the collation function
- batch size self._train_batch_size

get_eval_dataloader(eval_dataset=None) returns a DataLoader with:
- SequentialSampler (order preserved)
- batch size args.per_device_eval_batch_size × args.n_gpu

remove_unused_columns(dataset, description) inspects model.forward() parameter names (cached in self._signature_columns) and strips any dataset columns not present in the model signature. This is performed lazily on the first call.
Sources: src/transformers/trainer.py83-116 src/transformers/trainer_pt_utils.py40-110
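The column-stripping logic above can be sketched with inspect.signature: compare dataset columns against the model's forward() parameters and drop anything the model cannot accept. This is a simplified illustration — the real method also special-cases label columns and logs what it removes:

```python
import inspect

# Sketch of signature-based column removal (illustrative names).
def signature_columns(forward_fn):
    return set(inspect.signature(forward_fn).parameters)

def remove_unused_columns(examples, forward_fn):
    keep = signature_columns(forward_fn)
    return [{k: v for k, v in ex.items() if k in keep} for ex in examples]

# Stand-in forward() signature:
def forward(input_ids, attention_mask=None, labels=None):
    ...

rows = [{"input_ids": [1], "labels": [0], "text": "raw string"}]
print(remove_unused_columns(rows, forward))
# → [{'input_ids': [1], 'labels': [0]}]
```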
create_optimizer_and_scheduler(num_training_steps) calls:
- create_optimizer() — dispatches to _OPTIMIZER_HANDLERS[args.optim] in src/transformers/trainer_optimizer.py.
- create_scheduler(num_training_steps, optimizer) — calls optimization.get_scheduler(args.lr_scheduler_type, optimizer, ...).

Both steps are skipped if the user supplied pre-built instances via the optimizers constructor argument.
_OPTIMIZER_HANDLERS is a registry mapping each OptimizerNames enum value to a handler function that constructs the appropriate optimizer (AdamW, Adafactor, 8-bit Adam, GaLore, etc.).
See 3.5 for the full list of available optimizers and scheduler types.
Sources: src/transformers/trainer.py93-98 src/transformers/training_args.py110-157
During __init__, the Trainer introspects the unwrapped model to set:
| Attribute | How determined |
|---|---|
model_accepts_loss_kwargs | Whether forward() accepts **kwargs (checked via inspect.signature) |
label_names | find_labels(model_class) inspects forward() for params like labels, start_positions |
can_return_loss | can_return_loss(model_class) from utils/generic.py |
label_smoother | LabelSmoother(epsilon=args.label_smoothing_factor) when factor > 0 |
model_accepts_loss_kwargs controls whether num_items_in_batch is passed through to the model forward call during training, enabling token-level loss normalization for language models.
Sources: src/transformers/trainer.py493-526
When args.neftune_noise_alpha is set, the Trainer calls activate_neftune(model) (from src/transformers/integrations/neftune.py) before training begins. This patches the embedding layer's forward pass with noise injection to improve instruction-tuning quality. deactivate_neftune(model) is called unconditionally when training ends.
Sources: src/transformers/trainer.py71-72
TrainerState and TrainerControl

Both are defined in src/transformers/trainer_callback.py.

TrainerState — Serialized Training Progress

| Field | Description |
|---|---|
epoch | Current epoch as float; decimal = fraction of epoch completed |
global_step | Number of optimizer update steps completed |
max_steps | Total steps planned |
log_history | List of all logged metric dicts |
best_metric | Best observed metric_for_best_model value |
best_model_checkpoint | Path to the corresponding checkpoint directory |
stateful_callbacks | Callbacks that are ExportableState instances, serialized alongside state |
TrainerState is written to trainer_state.json at each checkpoint via save_state() and restored on resume.
TrainerControl — Callback Signal Flags

| Flag | Effect |
|---|---|
should_training_stop | Halts the outer training loop |
should_epoch_stop | Halts the current epoch |
should_save | Triggers _save_checkpoint() |
should_evaluate | Triggers evaluation_loop() |
should_log | Triggers metric logging |
Callbacks set these flags in their event handlers. The Trainer checks them after each callback dispatch.
Sources: src/transformers/trainer_callback.py34-250
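The flag mechanism can be shown with a minimal mock: a callback flips should_training_stop once the metric stops improving for a set patience, and the loop checks the flag after each dispatch. This mirrors the mechanism only — it is not transformers' EarlyStoppingCallback, and all class names below are stand-ins:

```python
from dataclasses import dataclass

# Mock of the control-flag handshake between callbacks and the loop.
@dataclass
class TrainerControl:
    should_training_stop: bool = False

@dataclass
class PatienceCallback:
    patience: int
    best: float = float("-inf")
    bad_evals: int = 0

    def on_evaluate(self, control, metric):
        if metric > self.best:
            self.best, self.bad_evals = metric, 0
        else:
            self.bad_evals += 1
            if self.bad_evals >= self.patience:
                control.should_training_stop = True  # loop checks this flag

control = TrainerControl()
cb = PatienceCallback(patience=2)
for m in [0.70, 0.72, 0.71, 0.69]:   # metric plateaus after 0.72
    cb.on_evaluate(control, m)
print(control.should_training_stop)  # True
```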
Code Entity Map — Trainer Subsystem
Sources: src/transformers/trainer.py83-148 src/transformers/trainer_callback.py1-50 src/transformers/trainer_utils.py1-160 src/transformers/trainer_pt_utils.py1-65 src/transformers/trainer_optimizer.py1-100