This page covers the assisted and speculative decoding subsystem: the CandidateGenerator abstraction, all concrete implementations, the candidate verification function _speculative_sampling, ConfidenceCriteria for early assistant stopping, and the dynamic token budget system.
For the GenerationConfig fields that select and configure this mode, see Generation Configuration and Modes. For the base stopping criteria framework that ConfidenceCriteria extends, see Logits Processing Pipeline. For the cache types required for rollback, see Cache System.
Standard autoregressive decoding runs one forward pass through the large target model per generated token. Speculative decoding breaks this bottleneck by using a cheap draft source to generate k candidate tokens at once. The target model then processes all k+1 positions in a single parallel forward pass, accepting tokens where it agrees and resampling at the first rejection. When the draft acceptance rate is high, effective throughput increases significantly because the target model's compute is shared across multiple accepted tokens.
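The throughput gain can be made concrete under the standard idealized model (an assumption for illustration, not something the library computes): if each of the k draft tokens is accepted independently with probability alpha, the expected number of tokens produced per target forward pass is a geometric series.

```python
def expected_tokens_per_target_pass(alpha: float, k: int) -> float:
    """Expected tokens emitted per target forward pass.

    Assumes each of the k draft tokens is accepted independently with
    probability alpha, plus one bonus token sampled on full acceptance
    (the idealized model from the speculative decoding literature).
    """
    if alpha >= 1.0:
        return float(k + 1)
    # Geometric series: 1 + alpha + alpha**2 + ... + alpha**k
    return (1 - alpha ** (k + 1)) / (1 - alpha)

# With an 80% acceptance rate and 5 draft tokens, each target pass
# yields ~3.69 tokens instead of 1.
print(round(expected_tokens_per_target_pass(0.8, 5), 2))
```

When acceptance is low, the expected yield approaches 1 token per pass and speculation adds only overhead, which is why the dynamic budget described later backs off after rejections.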
The framework generalizes this into the CandidateGenerator interface, which abstracts over several draft sources: a separate smaller model, early exit from the target model itself, or n-gram lookup in the input prompt.
`src/transformers/generation/candidate_generator.py` (lines 39-75)
CandidateGenerator is the abstract base class for all candidate generators. It specifies two methods:
| Method | Signature | Purpose |
|---|---|---|
get_candidates | (input_ids) → (candidate_ids, candidate_logits) | Produce up to k draft token IDs and optional logits |
update_candidate_strategy | (input_ids, scores, num_matches) → None | Adapt the draft budget given how many tokens were accepted |
candidate_ids has shape (batch_size, current_length + k) — it is the full sequence including existing tokens followed by new draft tokens. candidate_logits is (batch_size, k, vocab_size) and may be None for generators that cannot produce per-token distributions.
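A toy sketch of this contract (the class and its fixed draft source are hypothetical; real generators return torch tensors, not Python lists):

```python
# Hypothetical FixedCandidateGenerator -- not a library class. It illustrates
# the shape contract of the CandidateGenerator interface using plain lists.
class FixedCandidateGenerator:
    def __init__(self, draft_tokens):
        self.draft_tokens = draft_tokens  # k tokens proposed each step

    def get_candidates(self, input_ids):
        # candidate_ids: the existing tokens followed by the k new draft
        # tokens, i.e. shape (batch_size, current_length + k).
        candidate_ids = [row + self.draft_tokens for row in input_ids]
        # This draft source has no per-token distribution, so
        # candidate_logits is None (as with prompt lookup).
        return candidate_ids, None

    def update_candidate_strategy(self, input_ids, scores, num_matches):
        pass  # "constant" schedule: keep the draft budget fixed

gen = FixedCandidateGenerator([7, 8, 9])
ids, logits = gen.get_candidates([[1, 2, 3]])
print(ids, logits)  # [[1, 2, 3, 7, 8, 9]] None
```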
Sources: src/transformers/generation/candidate_generator.py
Diagram: CandidateGenerator class hierarchy
Sources: src/transformers/generation/candidate_generator.py
`src/transformers/generation/candidate_generator.py` (lines 78-299)
Used when target and assistant share the same tokenizer. Calls assistant_model.generate() to produce up to num_assistant_tokens tokens per iteration. Key initialization decisions:
- Forces `cache_implementation = "dynamic_full"` on the assistant's generation config so the cache can be rolled back on rejection.
- Sets `generation_config.is_assistant = True`, which causes `ConfidenceCriteria` to be injected into the assistant's stopping criteria list when `assistant_confidence_threshold > 0`.
- Removes `MinLengthLogitsProcessor` from the processor list before handing it to the assistant; the main model's `min_length` constraint must not block draft tokens.
- Copies `model_kwargs` via `detach()` or `deepcopy()` per tensor, but passes `encoder_outputs` by reference for encoder-decoder models.
- Sets `generation_config.eos_token_id` on the assistant to match the target, so the assistant terminates on the same EOS tokens.

`get_candidates` internally calls `_calculate_new_tokens` → `_update_past_and_masks` → `_prepare_generation_args` → `_generate_candidates`.
Extends AssistedCandidateGenerator for the case where target and assistant tokenizers differ. After the assistant model produces draft tokens, a re-encoding step is performed: the draft output is decoded to text with the assistant tokenizer and re-encoded with the target tokenizer.
assistant_lookbehind and target_lookbehind control how many additional context tokens surround the re-encoded window to prevent boundary tokenization artifacts.
Extends AssistedCandidateGeneratorDifferentTokenizers with direct vocabulary-level translation, bypassing the decode-then-reencode step. It uses an AssistantToTargetTranslator obtained via AssistantVocabTranslatorCache.
Diagram: Vocabulary translation in UniversalSpeculativeDecodingGenerator
Tokens shared between both vocabularies are mapped directly. Tokens present only in the assistant vocabulary are mapped to SUPPRESS_TOKEN_ID, forcing the target model to reject them. Logits are projected into target vocabulary space for use in _speculative_sampling.
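The mapping rule can be sketched as follows (hypothetical helper and sentinel value; the library's AssistantToTargetTranslator works on tensors and its actual suppress id may differ):

```python
SUPPRESS_TOKEN_ID = -1  # assumption: a sentinel the target will always reject

def build_assistant_to_target_map(assistant_vocab, target_vocab):
    # Map shared surface forms to the target's token ids; assistant-only
    # tokens get the suppress sentinel so the target rejects any draft
    # containing them.
    return {a_id: target_vocab.get(tok, SUPPRESS_TOKEN_ID)
            for tok, a_id in assistant_vocab.items()}

amap = build_assistant_to_target_map({"hello": 0, "zzz": 1}, {"hello": 42})
print(amap)  # {0: 42, 1: -1}
```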
Sources: src/transformers/generation/candidate_generator.py, tests/generation/test_candidate_generator.py
Inherits from AssistedCandidateGenerator. Instead of a separate model, it uses logits from an intermediate layer of the target model itself as the draft. The assistant_early_exit integer in GenerationConfig specifies the layer index at which to exit. Requires a model where the LM head can interpret intermediate hidden states.
Requires no assistant model. At each step it:
- Takes the last `max_matching_ngram_size` tokens of the current sequence as an n-gram query.
- Searches the existing sequence for an earlier occurrence of that n-gram.
- Returns the `num_output_tokens` tokens following the first match as candidates.

It produces no logits (`candidate_logits = None`). The `update_candidate_strategy` method adapts `num_output_tokens` using the same heuristic as `AssistedCandidateGenerator`.
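The lookup can be sketched in pure Python (a simplification for illustration; the library operates on tensors and its match-selection order may differ):

```python
# Minimal sketch of prompt-lookup candidate selection -- not the library
# code. Find an earlier occurrence of the trailing n-gram and return the
# tokens that followed it.
def prompt_lookup(seq, max_matching_ngram_size=2, num_output_tokens=3):
    for n in range(max_matching_ngram_size, 0, -1):  # longest n-gram first
        query = seq[-n:]
        # Scan earlier positions, most recent first, excluding the trailing
        # query itself.
        for start in range(len(seq) - n - 1, -1, -1):
            if seq[start:start + n] == query:
                follow = seq[start + n:start + n + num_output_tokens]
                if follow:
                    return follow
    return []

# The trailing bigram [5, 6] occurred earlier, followed by [7, 8].
print(prompt_lookup([5, 6, 7, 8, 5, 6], num_output_tokens=2))  # [7, 8]
```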
Sources: src/transformers/generation/candidate_generator.py
_assisted_decoding Loop (`src/transformers/generation/utils.py`)
GenerationMixin._assisted_decoding is called when GenerationMode.ASSISTED_GENERATION is selected by generate(). GENERATION_MODES_MAPPING maps this mode to "_assisted_decoding" at `src/transformers/generation/utils.py` (lines 132-143).
Diagram: Assisted decoding iteration
Between iterations, when candidates are partially rejected, the KV cache is rolled back to the accepted prefix length using DynamicLayer.crop().
Three module-level helper functions from candidate_generator.py are used within _assisted_decoding to prepare inputs for the assistant: _prepare_attention_mask, _prepare_position_ids, and _prepare_token_type_ids.
Sources: src/transformers/generation/utils.py, src/transformers/generation/candidate_generator.py
_speculative_sampling (`src/transformers/generation/utils.py`)
_speculative_sampling is a module-level function (visible in test imports at `tests/generation/test_utils.py`, line 106). It implements the token acceptance/rejection logic applied at each iteration.
For each draft token at position i:
- Compute `p(token_i)` from the target model logits (after applying `logits_processor`).
- Compute `q(token_i)` from the assistant logits (if provided).
- Accept the token at position i with probability `min(1, p / q)`.
- On rejection, resample from `normalize(max(0, p − q))`.

If `candidate_logits` is None (e.g., from PromptLookupCandidateGenerator), verification degrades to checking whether the target model's argmax agrees with the candidate token; in this case no acceptance probability is computed.
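The accept/resample rule can be sketched for a single token over a toy vocabulary (illustrative only; `verify_token` is not a library function, and the real implementation is vectorized over all draft positions):

```python
import random

# Sketch of the speculative sampling accept/resample rule for one draft
# token. p and q are target / draft probability lists over a toy vocabulary.
def verify_token(token, p, q, rng):
    accept_prob = min(1.0, p[token] / q[token])
    if rng.random() < accept_prob:
        return token, True  # draft token accepted
    # On rejection, resample from the residual distribution
    # normalize(max(0, p - q)).
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(residual)
    residual = [r / total for r in residual]
    return rng.choices(range(len(p)), weights=residual)[0], False

# Target and draft agree exactly: the token is always accepted.
print(verify_token(0, [0.5, 0.5], [0.5, 0.5], random.Random(0)))
# Target assigns zero mass to the draft token: always rejected,
# and the residual forces token 0.
print(verify_token(1, [1.0, 0.0], [0.1, 0.9], random.Random(0)))
```

This rule guarantees the output distribution matches sampling from the target model alone, regardless of draft quality.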
The function returns a tensor of accepted tokens plus one new token, and the count num_matches which drives update_candidate_strategy.
Sources: src/transformers/generation/utils.py, tests/generation/test_utils.py (line 106)
ConfidenceCriteria is a StoppingCriteria subclass in src/transformers/generation/stopping_criteria.py, listed in generation/__init__.py under _import_structure["stopping_criteria"].
It is added to the assistant's stopping criteria list automatically by AssistedCandidateGenerator when:
- `generation_config.is_assistant` is True (set by the generator), and
- `generation_config.assistant_confidence_threshold > 0`.

At each assistant draft step, ConfidenceCriteria computes the softmax probability of the top-1 token. If it falls below `assistant_confidence_threshold`, the assistant stops generating further draft tokens even if `num_assistant_tokens` has not been exhausted.
The threshold is not static. update_candidate_strategy maintains two tracking lists, self.probs and self.matches, and — when sklearn is available — fits a ROC curve to find the threshold that minimizes a cost function weighting false positives at 25% and false negatives at 75%. This biases the threshold toward accepting more tokens to avoid unnecessary rejections.
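A brute-force sketch of that cost-weighted search (the library fits a ROC curve via sklearn; `fit_threshold` here is a hypothetical pure-Python stand-in applying the 25%/75% weights described above):

```python
# Hypothetical pure-Python stand-in for the sklearn-based threshold fit.
# probs[i] is the assistant's top-1 probability for a past draft token,
# matches[i] records whether the target accepted it.
def fit_threshold(probs, matches, fp_weight=0.25, fn_weight=0.75):
    best_t, best_cost = 0.0, float("inf")
    for t in sorted(set(probs)):
        # Predict "will be accepted" when prob >= t.
        fp = sum(1 for p, m in zip(probs, matches) if p >= t and not m)
        fn = sum(1 for p, m in zip(probs, matches) if p < t and m)
        cost = fp_weight * fp + fn_weight * fn
        if cost < best_cost:
            best_t, best_cost = t, cost
    return best_t

# Accepted tokens had high confidence, rejected ones low: the fitted
# threshold separates them at the lowest accepted probability.
print(fit_threshold([0.9, 0.8, 0.3, 0.2], [True, True, False, False]))
```

Because false negatives (stopping a draft the target would have accepted) cost three times as much as false positives, the fitted threshold leans low, keeping the assistant speculating.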
`src/transformers/generation/candidate_generator.py` (lines 251-280)
Sources: src/transformers/generation/candidate_generator.py, src/transformers/generation/stopping_criteria.py, src/transformers/generation/__init__.py
update_candidate_strategy adjusts self.num_assistant_tokens based on the num_assistant_tokens_schedule field.
| Schedule | Rule | Persistence |
|---|---|---|
"heuristic" | +2 if all k candidates accepted; −1 otherwise | Persists across generate() calls with the same generator |
"heuristic_transient" | Same rule | Reset to initial value after each generate() call |
"constant" | No change | N/A |
The heuristic relies on the observation that a fully accepted batch suggests the assistant is calibrated for the current context, and it is safe to speculate more aggressively.
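The schedule rules from the table can be sketched as a single update function (hypothetical helper; the floor of 1 token on back-off is an assumption):

```python
def update_num_assistant_tokens(budget: int, num_matches: int,
                                schedule: str = "heuristic") -> int:
    # "heuristic" / "heuristic_transient": +2 if every candidate was
    # accepted, otherwise back off by 1 (assumed floor of 1 token).
    if schedule == "constant":
        return budget
    return budget + 2 if num_matches == budget else max(1, budget - 1)

budget = 5
budget = update_num_assistant_tokens(budget, num_matches=5)  # all accepted
print(budget)  # 7
budget = update_num_assistant_tokens(budget, num_matches=3)  # partial
print(budget)  # 6
```

The difference between "heuristic" and "heuristic_transient" is not in this rule but in whether the resulting budget is kept or reset between generate() calls.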
`src/transformers/generation/candidate_generator.py` (lines 225-250)
Sources: src/transformers/generation/candidate_generator.py
Diagram: Generator selection and dispatch in GenerationMixin.generate()
All paths converge in _assisted_decoding, which is the registered handler for GenerationMode.ASSISTED_GENERATION in GENERATION_MODES_MAPPING at `src/transformers/generation/utils.py` (lines 132-143).
Sources: src/transformers/generation/utils.py, src/transformers/generation/candidate_generator.py
All assisted/speculative decoding settings live in GenerationConfig in src/transformers/generation/configuration_utils.py:
`src/transformers/generation/configuration_utils.py` (lines 291-327)
| Parameter | Type | Description |
|---|---|---|
num_assistant_tokens | int | Initial draft token budget per iteration |
num_assistant_tokens_schedule | str | Budget schedule: "heuristic", "heuristic_transient", "constant" |
assistant_confidence_threshold | float | Probability below which the assistant halts early |
prompt_lookup_num_tokens | int | Candidate count for PromptLookupCandidateGenerator |
max_matching_ngram_size | int | N-gram query size for prompt lookup (default 2) |
assistant_early_exit | int | Layer index for EarlyExitCandidateGenerator |
assistant_lookbehind | int | Context window for assistant re-encoding alignment |
target_lookbehind | int | Context window for target re-encoding alignment |
is_assistant | bool | Marks model as draft model; set automatically by generator |
Sources: src/transformers/generation/configuration_utils.py
AssistedCandidateGenerator.__init__ forces cache_implementation = "dynamic_full" on the configuration used by the assistant model:
`src/transformers/generation/candidate_generator.py` (lines 188-192)
The main model's cache must also support rollback. This is why the test suite passes cache_implementation = "dynamic_full" when testing assisted generation (see `tests/generation/test_utils.py`, line 701). When candidates are rejected, DynamicLayer.crop() trims the main model's cache back to the last accepted position:
`src/transformers/cache_utils.py` (lines 141-153)
Static caches (StaticLayer) and sliding-window caches do not support crop(), making them incompatible with this decoding path.
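Rollback-by-crop can be illustrated with a toy cache layer (not the library's DynamicLayer, which crops key/value tensors rather than a Python list):

```python
# Toy sketch of rollback-by-crop: a per-position "cache" trimmed back to the
# accepted prefix length after a partial rejection.
class ToyCacheLayer:
    def __init__(self):
        self.keys = []

    def append(self, k):
        self.keys.append(k)

    def crop(self, max_length):
        # Keep only the first max_length positions (list slicing, mirroring
        # the tensor slicing the real crop performs).
        self.keys = self.keys[:max_length]

layer = ToyCacheLayer()
for t in range(6):          # prefix of 1 token + 5 draft positions
    layer.append(t)
layer.crop(4)               # 2 draft tokens rejected -> roll back
print(layer.keys)           # [0, 1, 2, 3]
```

A statically pre-allocated cache cannot shrink this way, which is the intuition behind the incompatibility noted above.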
Sources: src/transformers/generation/candidate_generator.py, src/transformers/cache_utils.py, tests/generation/test_utils.py
`src/transformers/generation/candidate_generator.py`
A module-level singleton class that caches AssistantToTargetTranslator instances, keyed by weak references to (target_tokenizer, assistant_tokenizer) pairs. Computing the vocabulary mapping is O(vocab_size) and must not be repeated on every generate() call.
Key methods:
| Method | Description |
|---|---|
get_translator(target_tokenizer, assistant_tokenizer, ...) | Return cached translator, or create and store a new one |
cleanup() | Remove entries whose tokenizer weak references have been garbage-collected |
The use of weakref ensures that the cache does not prevent tokenizers from being garbage-collected when they fall out of scope.
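A minimal weakref-keyed cache in the same spirit (hypothetical `TranslatorCache`; the library class has a different signature and keying scheme):

```python
import weakref

# Hypothetical sketch of a weakref-based translator cache: entries hold
# their tokenizers weakly so the cache never keeps them alive.
class TranslatorCache:
    def __init__(self):
        self._cache = {}

    def get(self, target_tok, assistant_tok, factory):
        key = (id(target_tok), id(assistant_tok))
        if key not in self._cache:
            # factory() stands in for the expensive O(vocab_size) build.
            self._cache[key] = (weakref.ref(target_tok),
                                weakref.ref(assistant_tok),
                                factory())
        return self._cache[key][2]

    def cleanup(self):
        # Drop entries whose tokenizers have been garbage-collected.
        self._cache = {k: v for k, v in self._cache.items()
                       if v[0]() is not None and v[1]() is not None}
```

Keying by `id()` while holding only weak references means a collected tokenizer's entry becomes stale; `cleanup()` purges it before the id can be reused.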
Sources: src/transformers/generation/candidate_generator.py, tests/generation/test_candidate_generator.py