This page documents the callback system that allows custom logic to be injected into the Trainer training loop. It covers TrainerCallback, TrainerState, TrainerControl, CallbackHandler, built-in callbacks, and integration callbacks.
For the overall Trainer architecture and how the training loop itself executes, see section 3.1. For the TrainingArguments configuration fields that control callback-relevant schedules (logging, saving, evaluation), see section 3.2.
The callback system provides a set of event hooks that fire at well-defined points during training. Callbacks receive the current TrainingArguments, TrainerState, and TrainerControl on every event. They can read state, write to logs, or signal the Trainer to change its behavior (save, evaluate, stop) by modifying fields on the returned TrainerControl object.
All callback infrastructure is defined in src/transformers/trainer_callback.py
TrainerState is a @dataclass that carries all mutable training progress information. It is checkpointed alongside the model.
src/transformers/trainer_callback.py34-150
| Field | Type | Description |
|---|---|---|
| `epoch` | `float` or `None` | Current epoch (decimal part = fraction of the epoch completed) |
| `global_step` | `int` | Number of optimizer update steps completed |
| `max_steps` | `int` | Total planned update steps |
| `num_train_epochs` | `int` | Total planned epochs |
| `logging_steps` | `int` | Steps between log events |
| `eval_steps` | `int` | Steps between evaluation events |
| `save_steps` | `int` | Steps between checkpoint saves |
| `train_batch_size` | `int` | Effective per-device batch size |
| `num_input_tokens_seen` | `int` | Cumulative non-padding input tokens |
| `total_flos` | `float` | Cumulative floating-point operations |
| `log_history` | `list[dict]` | List of all logged metrics dicts |
| `best_metric` | `float` or `None` | Best value of `metric_for_best_model` seen |
| `best_model_checkpoint` | `str` or `None` | Path to the checkpoint with the best metric |
| `is_local_process_zero` | `bool` | Whether this is the main process on this node |
| `is_world_process_zero` | `bool` | Whether this is the global main process |
| `stateful_callbacks` | `list[ExportableState]` | Callbacks that implement state save/restore |
TrainerState is saved to trainer_state.json at every checkpoint, and can be loaded with TrainerState.load_from_json.
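The persistence mechanism can be illustrated with a minimal sketch. `MiniTrainerState` below is a toy stand-in (not the real `TrainerState` class, which carries many more fields), but the round trip — dataclass to dict to JSON and back — mirrors how `trainer_state.json` works:

```python
import json
import os
import tempfile
from dataclasses import asdict, dataclass, field


@dataclass
class MiniTrainerState:
    """Toy stand-in for TrainerState with a few of its fields."""
    global_step: int = 0
    epoch: float = 0.0
    log_history: list = field(default_factory=list)

    def save_to_json(self, path):
        # TrainerState persists itself the same way: dataclass -> dict -> JSON
        with open(path, "w") as f:
            json.dump(asdict(self), f, indent=2)

    @classmethod
    def load_from_json(cls, path):
        with open(path) as f:
            return cls(**json.load(f))


path = os.path.join(tempfile.mkdtemp(), "trainer_state.json")
state = MiniTrainerState(global_step=500, epoch=1.25,
                         log_history=[{"loss": 0.42, "step": 500}])
state.save_to_json(path)
restored = MiniTrainerState.load_from_json(path)
assert restored == state
```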
TrainerControl is a @dataclass holding boolean flags. Callbacks set these flags to request actions from the Trainer. The Trainer reads and resets them at appropriate points. TrainerControl also implements ExportableState so its flags are persisted across checkpoints.
src/transformers/trainer_callback.py150-220
| Field | Type | Default | Effect when `True` |
|---|---|---|---|
| `should_training_stop` | `bool` | `False` | Ends training after the current step |
| `should_epoch_stop` | `bool` | `False` | Ends the current epoch after the current step |
| `should_save` | `bool` | `False` | Triggers a checkpoint save |
| `should_evaluate` | `bool` | `False` | Triggers an evaluation run |
| `should_log` | `bool` | `False` | Triggers a log event |
TrainerCallback is the abstract base class for all callbacks. Subclasses override only the event methods they need.
src/transformers/trainer_callback.py220-380
Every event method has the signature:
```python
on_<event>(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> TrainerControl | None
```
The **kwargs are populated differently depending on the event. Common keyword arguments passed are:
| kwarg | Provided in events |
|---|---|
| `model` | Most events |
| `tokenizer` / `processing_class` | Most events |
| `optimizer` | `on_step_end`, `on_train_begin` |
| `lr_scheduler` | `on_step_end`, `on_train_begin` |
| `logs` | `on_log` |
| `metrics` | `on_evaluate` |
Event methods:
| Method | Fires when |
|---|---|
| `on_init_end` | End of `Trainer.__init__` |
| `on_train_begin` | Start of `Trainer.train()` |
| `on_train_end` | End of `Trainer.train()` |
| `on_epoch_begin` | Start of each epoch |
| `on_epoch_end` | End of each epoch |
| `on_step_begin` | Before each optimizer step's forward pass |
| `on_substep_end` | After each gradient-accumulation substep |
| `on_step_end` | After each optimizer update step |
| `on_evaluate` | After each evaluation run completes |
| `on_predict` | After each prediction run completes |
| `on_save` | After each checkpoint is saved |
| `on_log` | When metrics are being logged |
| `on_prediction_step` | After each prediction batch |
If a callback's event method returns a TrainerControl object, the handler adopts it as the current control, passing it to subsequent callbacks and back to the Trainer; returning None leaves the current control unchanged.
Sources: src/transformers/trainer_callback.py
CallbackHandler is itself a TrainerCallback. It holds the list of registered callbacks and dispatches every event to all of them in insertion order.
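The dispatch loop can be sketched as follows. This is a hand-rolled stand-in, not the shipped implementation: `MiniControl`, `MiniCallbackHandler`, and both callbacks are hypothetical, but the loop shows the core contract — every callback sees the event in insertion order, and a non-None return value becomes the control passed onward:

```python
from dataclasses import dataclass


@dataclass
class MiniControl:
    """Stand-in for TrainerControl with a single flag."""
    should_training_stop: bool = False


class SilentCallback:
    def on_log(self, args, state, control, **kwargs):
        return None  # returning None requests no change


class StopOnLogCallback:
    def on_log(self, args, state, control, **kwargs):
        control.should_training_stop = True
        return control


class MiniCallbackHandler:
    """Stand-in for CallbackHandler: dispatch an event to every callback in order."""
    def __init__(self, callbacks):
        self.callbacks = list(callbacks)

    def call_event(self, event, args, state, control, **kwargs):
        for cb in self.callbacks:
            result = getattr(cb, event)(args, state, control, **kwargs)
            if result is not None:
                control = result  # a returned control replaces the current one
        return control


handler = MiniCallbackHandler([SilentCallback(), StopOnLogCallback()])
control = handler.call_event("on_log", None, None, MiniControl())
assert control.should_training_stop
```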
src/transformers/trainer_callback.py380-480
Trainer's callback setup (from Trainer.__init__):
src/transformers/trainer.py552-566
```python
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(args.report_to)
# User-supplied callbacks are appended after the defaults
callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
self.callback_handler = CallbackHandler(
    callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
)
self.add_callback(PrinterCallback if args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
```
CallbackHandler exposes:
- `add_callback(callback)` — append a callback instance or class
- `remove_callback(callback)` — remove by instance or class
- `pop_callback(callback)` — remove and return by instance or class
- `callbacks` — list of registered `TrainerCallback` instances

The Trainer itself exposes the same `add_callback`, `remove_callback`, and `pop_callback` methods, which delegate to its CallbackHandler.
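The instance-or-class semantics can be sketched with a simplified stand-in (`MiniHandler` and the two callback classes below are hypothetical, not the real CallbackHandler):

```python
class MiniHandler:
    """Stand-in sketching CallbackHandler's add/remove/pop semantics."""
    def __init__(self, callbacks=()):
        self.callbacks = []
        for cb in callbacks:
            self.add_callback(cb)

    def add_callback(self, callback):
        # A class gets instantiated; an instance is appended as-is.
        self.callbacks.append(callback() if isinstance(callback, type) else callback)

    def pop_callback(self, callback):
        # Match by class or by exact instance; remove and return the first hit.
        for cb in self.callbacks:
            if cb is callback or (isinstance(callback, type) and type(cb) is callback):
                self.callbacks.remove(cb)
                return cb
        return None

    def remove_callback(self, callback):
        self.pop_callback(callback)


class ProgressCallback: pass
class PrinterCallback: pass


handler = MiniHandler([ProgressCallback, PrinterCallback()])
popped = handler.pop_callback(ProgressCallback)  # remove by class
assert isinstance(popped, ProgressCallback)
assert len(handler.callbacks) == 1
```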
Sources: src/transformers/trainer_callback.py380-480 src/transformers/trainer.py552-566
Callback event firing order within Trainer.train()
Sources: src/transformers/trainer.py src/transformers/trainer_callback.py
Class hierarchy and responsibilities
Sources: src/transformers/trainer_callback.py
src/transformers/trainer_callback.py480-560
Always present (DEFAULT_CALLBACKS = [DefaultFlowCallback]). On each on_step_end and on_epoch_end, it reads TrainingArguments and TrainerState to decide whether to set control.should_log, control.should_evaluate, and control.should_save. It enforces the logging, evaluation, and save strategies ("steps", "epoch", "no", "best").
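The step-based part of that decision reduces to modular arithmetic on `global_step`. The function below is a hypothetical simplification for the `"steps"` strategies only (the real DefaultFlowCallback also honors `"epoch"`, `"no"`, and `"best"`, plus warmup and end-of-training conditions):

```python
def flow_flags_on_step_end(global_step, logging_steps, eval_steps, save_steps):
    """Which flags a 'steps'-strategy flow callback would raise at this step."""
    return {
        "should_log": logging_steps > 0 and global_step % logging_steps == 0,
        "should_evaluate": eval_steps > 0 and global_step % eval_steps == 0,
        "should_save": save_steps > 0 and global_step % save_steps == 0,
    }


# Step 500 with logging every 100, eval every 250, save every 500: all three fire.
assert flow_flags_on_step_end(500, logging_steps=100, eval_steps=250, save_steps=500) == {
    "should_log": True, "should_evaluate": True, "should_save": True,
}
# Step 101 hits none of the schedules.
assert flow_flags_on_step_end(101, logging_steps=100, eval_steps=250, save_steps=500) == {
    "should_log": False, "should_evaluate": False, "should_save": False,
}
```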
src/transformers/trainer_callback.py560-640
Shows tqdm progress bars during training and prediction. Uses on_train_begin/on_train_end to open/close the training bar, and on_prediction_step for the prediction bar. Selected by default when disable_tqdm=False; replaced by PrinterCallback when disable_tqdm=True.
In notebook environments, DEFAULT_PROGRESS_CALLBACK is automatically replaced with NotebookProgressCallback from src/transformers/utils/notebook.py.
src/transformers/trainer_callback.py640-680
A minimal callback that prints log dicts to stdout on on_log. Used when args.disable_tqdm=True.
src/transformers/trainer_callback.py680-780
Monitors a metric specified by args.metric_for_best_model. If the metric does not improve by more than early_stopping_threshold for early_stopping_patience consecutive evaluations, it sets control.should_training_stop = True.
Constructor parameters:
| Parameter | Default | Description |
|---|---|---|
| `early_stopping_patience` | `1` | Number of evaluations without improvement before stopping |
| `early_stopping_threshold` | `0.0` | Minimum absolute improvement to count as better |
EarlyStoppingCallback implements ExportableState, so its patience counter is saved in trainer_state.json.
Requires args.load_best_model_at_end=True and args.metric_for_best_model to be set.
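The patience logic reduces to a counter. The sketch below is a simplified, self-contained stand-in (not the shipped EarlyStoppingCallback, which reads the best metric from TrainerState and sets `control.should_training_stop`):

```python
class MiniEarlyStopping:
    """Count evaluations whose metric fails to improve by more than the threshold."""
    def __init__(self, patience=1, threshold=0.0, greater_is_better=True):
        self.patience = patience
        self.threshold = threshold
        self.greater_is_better = greater_is_better
        self.counter = 0
        self.best = None

    def check(self, metric):
        """Return True when training should stop (the control-flag equivalent)."""
        improved = (
            self.best is None
            or (self.greater_is_better and metric > self.best + self.threshold)
            or (not self.greater_is_better and metric < self.best - self.threshold)
        )
        if improved:
            self.best = metric
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience


stopper = MiniEarlyStopping(patience=2, threshold=0.01)
assert stopper.check(0.80) is False   # first evaluation sets the baseline
assert stopper.check(0.805) is False  # within threshold: 1 stale evaluation
assert stopper.check(0.802) is True   # 2 stale evaluations -> stop
```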
Sources: src/transformers/trainer_callback.py
ExportableState is a mixin that callbacks can implement to make their internal state persistent across checkpoints.
src/transformers/trainer_callback.py780-820
```python
class ExportableState:
    def state(self) -> dict:
        """Return a serializable dict of internal state."""
        ...

    def from_state(self, state: dict):
        """Restore internal state from a dict."""
        ...
```
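A stateful callback that follows this contract might look like the following. `CountingCallback` is hypothetical; it only illustrates the `state()`/`from_state()` round trip that lets a callback's internals survive a checkpoint-and-resume cycle:

```python
class CountingCallback:
    """Hypothetical stateful callback: counts log events across checkpoints."""
    def __init__(self):
        self.log_count = 0

    def on_log(self, args, state, control, **kwargs):
        self.log_count += 1

    # --- ExportableState contract ---
    def state(self):
        return {"log_count": self.log_count}

    def from_state(self, state):
        self.log_count = state["log_count"]


cb = CountingCallback()
for _ in range(3):
    cb.on_log(None, None, None)
snapshot = cb.state()          # what would be persisted in trainer_state.json

restored = CountingCallback()
restored.from_state(snapshot)  # what resuming from a checkpoint would do
assert restored.log_count == 3
```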
When the Trainer initializes TrainerState, it collects all ExportableState callbacks:
src/transformers/trainer.py576-583
TrainerControl itself implements ExportableState. Built-in ExportableState implementors:
- TrainerControl
- EarlyStoppingCallback

Sources: src/transformers/trainer_callback.py780-820 src/transformers/trainer.py576-583
When args.report_to lists one or more integrations, the Trainer calls get_reporting_integration_callbacks(args.report_to) from src/transformers/integrations/__init__.py to obtain matching callback instances.
src/transformers/trainer.py39-41 src/transformers/trainer.py553
All integration callbacks are defined in src/transformers/integrations/integration_utils.py
Integration callback map
Each integration callback typically implements:

- `on_train_begin`
- `on_log`
- `on_save` (where applicable)
- `on_train_end`

TensorBoardCallback additionally logs hyperparameters on `on_train_begin` and closes its SummaryWriter on `on_train_end`.
WandbCallback supports resuming a run when args.resume_from_checkpoint is set, and can log model artifacts based on args.hub_strategy.
Sources: src/transformers/integrations/integration_utils.py96-99 src/transformers/trainer.py39-41 src/transformers/trainer.py553-554
Full callback assembly sequence
Sources: src/transformers/trainer.py181-182 src/transformers/trainer.py552-566 src/transformers/trainer.py184-187
To inject custom logic, subclass TrainerCallback, override the relevant event methods, and pass an instance to Trainer(callbacks=[...]).
Example pattern (from tests/trainer/test_trainer.py):
src/transformers/trainer.py311-314 tests/trainer/test_trainer.py189-198
To signal early stopping, set control.should_training_stop = True and return the control object.
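A complete custom callback following this pattern might look like the sketch below. To keep it self-contained, a local stub stands in for the real base class and control object; with transformers installed you would subclass `transformers.TrainerCallback` instead, and the Trainer would supply the real `TrainerControl`. `StopOnLowLoss` and its `target_loss` parameter are hypothetical:

```python
from dataclasses import dataclass


class TrainerCallbackStub:
    """Local stand-in for transformers.TrainerCallback: events default to no-ops."""
    def on_evaluate(self, args, state, control, **kwargs):
        pass


@dataclass
class ControlStub:
    """Local stand-in for TrainerControl."""
    should_training_stop: bool = False


class StopOnLowLoss(TrainerCallbackStub):
    """Hypothetical callback: stop training once eval_loss drops below a target."""
    def __init__(self, target_loss):
        self.target_loss = target_loss

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if metrics and metrics.get("eval_loss", float("inf")) < self.target_loss:
            control.should_training_stop = True
        return control


control = ControlStub()
cb = StopOnLowLoss(target_loss=0.1)
cb.on_evaluate(None, None, control, metrics={"eval_loss": 0.05})
assert control.should_training_stop
```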
To add or remove callbacks after the Trainer is constructed, use `trainer.add_callback`, `trainer.remove_callback`, or `trainer.pop_callback`; all three delegate to `trainer.callback_handler`.
Sources: src/transformers/trainer_callback.py tests/trainer/test_trainer.py189-214 tests/trainer/test_trainer_callback.py