Batched Speculative Decoding
This blog is a description of my attempt to implement batched speculative decoding using a target-draft setup. I tried not to make this “yet another blog” on speculative decoding!
Background
Large Language Models (LLMs) generate text autoregressively, one token at a time. Each token requires a full forward pass through the model, making generation memory-bandwidth bound rather than compute-bound. The model weights must be loaded from memory for every single token, regardless of batch size.
Speculative decoding accelerates this process using a simple insight: given the context, some tokens are easier to predict than others, such as code syntax or the answer to “1+1=”. A smaller, faster “draft” model can propose multiple tokens that a larger “target” model verifies in parallel. When the draft model’s predictions align with what the target would have produced, we get multiple tokens from a single target forward pass. What we exploit for the wall-clock speed-up is the transformer’s ability to compute logits for every position of an input sequence in parallel.
The Basic Algorithm
- Draft Phase: A small model generates γ (gamma) candidate tokens autoregressively
- Verify Phase: The target model processes all γ tokens in one forward pass
- Accept/Reject: Using rejection sampling, accept tokens where draft matches target distribution
- Bonus Token: Sample one additional token from the target’s distribution
The key mathematical guarantee: the output distribution is identical to running the target model alone. We’re not approximating; we’re getting exact samples, faster (see the original speculative decoding paper for the proof).
Rejection Sampling
For each drafted token, we compare probabilities:
- p(x) = target model’s probability for token x
- q(x) = draft model’s probability for token x
Accept the token with probability min(1, p(x)/q(x)). If rejected at position k:
- Discard tokens from position k onwards
- Sample a new token from the normalized max(0, p - q) (the “residual” distribution)
This ensures we never accept tokens the target model wouldn’t have produced.
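For a single sequence and a single drafted position, the accept/resample rule fits in a few lines. Here is a self-contained sketch with made-up p and q vectors (illustrative only, not the batched code used later in this post):

import torch

vocab = 8
p = torch.softmax(torch.randn(vocab), dim=-1)      # target distribution at this position
q = torch.softmax(torch.randn(vocab), dim=-1)      # draft distribution at this position

x = torch.multinomial(q, num_samples=1).item()     # token proposed by the draft model

# Accept with probability min(1, p(x) / q(x))
if torch.rand(()) < torch.clamp(p[x] / q[x], max=1.0):
    token = x                                      # keep the drafted token
else:
    residual = torch.clamp(p - q, min=0.0)         # the "residual" distribution
    residual = residual / residual.sum()           # renormalize
    token = torch.multinomial(residual, num_samples=1).item()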
Batched Implementation Challenges
Speculative decoding seems simple enough to implement for B=1. How do we exploit batching to process multiple sequences at once?
Extending speculative decoding to batches introduces significant complexity. The core issue: different sequences accept different numbers of tokens.
Consider a batch of 3 sequences with γ=5:
- Sequence 0: accepts 5 tokens (all drafts correct)
- Sequence 1: accepts 2 tokens (rejection at position 3)
- Sequence 2: accepts 4 tokens
After one iteration, sequences have different lengths. How do we handle this efficiently?
Two Approaches
Approach 1: Per-Sequence Tracking (v1)
- Track individual start_positions per sequence
- Process only active sequences in the draft loop
- Use DynamicCache with pruning to minimum length
- Sequences “catch up” by processing variable token ranges
Approach 2: Synchronized Batch Position (v2)
- Single batch_position for all sequences
- Use attention masking to handle “gaps” from rejections
- Place bonus tokens at max accepted position
- Simpler indexing but requires careful mask management
The v2 Implementation
Design Philosophy
v2 uses a synchronized batch_position pointer and StaticCache. All sequences share the same position in the tensor, with attention masks handling the variable-length reality.
After iteration with different acceptance rates:
Sequence 0: [prompt...][tok][tok][tok][tok][tok][bonus][ ][ ]
Sequence 1: [prompt...][tok][tok][PAD][PAD][PAD][bonus][ ][ ]
Sequence 2: [prompt...][tok][tok][tok][tok][PAD][bonus][ ][ ]
                                                       ^
                                                       batch_position
Rejected positions are masked out (attention_mask=0), creating “gaps” in sequences.
Position IDs with Gaps
With gaps in sequences, position IDs must reflect actual content positions, not tensor indices. We compute this via cumulative sum of the attention mask:
pos_id = (attn_mask[:, :seq_len].cumsum(dim=1) - 1).clamp(min=0)
For a mask like [1,1,1,0,0,1,1], this gives positions [0,1,2,2,2,3,4]. The model sees contiguous positions despite gaps in the tensor.
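As a quick self-contained sanity check of that formula on the example mask:

import torch

attn_mask = torch.tensor([[1, 1, 1, 0, 0, 1, 1]])
pos_id = (attn_mask.cumsum(dim=1) - 1).clamp(min=0)
print(pos_id)   # tensor([[0, 1, 2, 2, 2, 3, 4]])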
The Draft Loop
for k in range(draft_steps):
    # Feed the single previous token; the cache already holds everything before it
    draft_output = draft_model(
        input_ids=input_ids[:, batch_position - 1 + k].unsqueeze(1),
        attention_mask=attn_mask[:, :batch_position + k],
        cache_position=cache_position + k,
        position_ids=pos_id + k,
        past_key_values=draft_cache,
    )
    draft_logits = draft_output.logits

    # Sample from the draft distribution and keep it for later verification
    Q[active_indices, k] = logits_processor(draft_logits[:, -1])
    new_token = logits_processor.sample(Q[active_indices, k])

    # Place the drafted token and extend the attention mask
    input_ids[active_indices, batch_position + k] = new_token
    attn_mask[active_indices, batch_position + k] = 1
Each step:
- Feeds the previous token (at batch_position - 1 + k)
- Extends the attention mask to include the new position
- Stores the draft probabilities in the Q matrix for later comparison
- Places the drafted token at batch_position + k
Parallel Verification
The target model verifies all drafted tokens in one forward pass:
target_output = target_model(
    input_ids=input_ids[:, batch_position-1:batch_position+draft_steps],
    attention_mask=attn_mask[:, :batch_position+draft_steps],
    cache_position=torch.arange(batch_position-1, batch_position+draft_steps),
    position_ids=pos_id[:, batch_position-1:batch_position+draft_steps],
    past_key_values=target_cache,
)
This processes draft_steps + 1 tokens:
- Position batch_position - 1: the last accepted token (to get logits for the first draft)
- Positions batch_position to batch_position + draft_steps - 1: the drafted tokens
Vectorized Acceptance
Instead of sequential rejection sampling, we vectorize:
# Gather the probability of each drafted token under both models
p_tok = p[:, :draft_steps].gather(dim=2, index=drafted_tokens.unsqueeze(-1)).squeeze(-1)
q_tok = q[:, :draft_steps].gather(dim=2, index=drafted_tokens.unsqueeze(-1)).squeeze(-1)

# Log-space rejection sampling: accept where log(u) <= log(p/q), i.e. u <= p/q
log_ratio = torch.log(p_tok) - torch.log(q_tok)
log_r = torch.empty_like(log_ratio).uniform_().log_()

# Cumulative product finds the first rejection; cast to int since bool cumprod
# isn't supported in all torch versions
acceptance_status = (log_r <= log_ratio).int().cumprod(dim=1).bool()
num_accepted = acceptance_status.sum(dim=1)
The cumprod trick: once we hit a rejection (False), all subsequent positions become False, giving us the count of consecutively accepted tokens.
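What the snippet above doesn’t show is where each sequence’s replacement token comes from: the residual sample if a draft was rejected, or the bonus sample if everything was accepted. A sketch of how that can be vectorized, assuming p is a [B, draft_steps+1, vocab] probability tensor, q is [B, draft_steps, vocab], and reusing num_accepted from above (the name extra_tokens matches the placement code in the next section):

B = p.shape[0]
rej_pos = num_accepted.clamp(max=draft_steps - 1)            # slot of the first rejection (if any)

batch_idx = torch.arange(B, device=p.device)
p_rej = p[batch_idx, rej_pos]                                # [B, vocab] target probs at that slot
q_rej = q[batch_idx, rej_pos]                                # [B, vocab] draft probs at that slot

# Normalized residual distribution; the tiny floor keeps multinomial valid
# in the degenerate case p == q
residual = torch.clamp(p_rej - q_rej, min=0.0) + 1e-12
residual = residual / residual.sum(dim=-1, keepdim=True)
resampled = torch.multinomial(residual, num_samples=1).squeeze(-1)

# Sequences that accepted every draft take the bonus token from the extra target position
all_accepted = num_accepted == draft_steps
bonus = torch.multinomial(p[:, draft_steps], num_samples=1).squeeze(-1)
extra_tokens = torch.where(all_accepted, bonus, resampled)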
Bonus Token Placement
Here’s where batching gets tricky. Different sequences accept different amounts:
batch_pos_shift = accepted_draft_length.max().item()
input_ids[active_indices, batch_position + batch_pos_shift] = extra_tokens
attn_mask[active_indices, batch_position + batch_pos_shift] = 1
batch_position += batch_pos_shift + 1
All sequences place their bonus token at the maximum accepted position. Sequences that accepted fewer tokens have gaps between their last accepted token and the bonus. These gaps are masked out, so attention skips them.
Why place at max position?
If we placed bonus tokens at each sequence’s individual position, the next iteration’s input would be at different tensor indices per sequence, breaking batched processing. By aligning to max position, batch_position remains synchronized.
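For completeness, the gaps themselves come from zeroing the attention mask over the rejected draft slots. One loop-free way to do it (a sketch, assuming every sequence in the batch is active; otherwise index with active_indices as in the snippets above):

# Columns batch_position .. batch_position + batch_pos_shift - 1 hold the draft slots.
# A slot is a gap if its offset is >= that sequence's accepted length.
offsets = torch.arange(batch_pos_shift, device=attn_mask.device)        # [batch_pos_shift]
is_gap = offsets.unsqueeze(0) >= accepted_draft_length.unsqueeze(1)     # [B, batch_pos_shift]
attn_mask[:, batch_position:batch_position + batch_pos_shift][is_gap] = 0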
HuggingFace Cache Challenges
StaticCache vs DynamicCache
HuggingFace Transformers offers two KV-cache implementations:
DynamicCache
- Grows dynamically as tokens are generated
- Supports crop(max_length) to truncate
- No cache_position parameter needed; it just appends
- Memory reallocations can cause fragmentation
StaticCache
- Pre-allocated to maximum length
- Requires an explicit cache_position to specify write locations
- More memory efficient for known max lengths
- Can “overwrite” positions (useful for our gap scenario)
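For reference, both caches live in transformers’ cache utilities. A minimal construction sketch; the argument names below match recent transformers releases, but the StaticCache signature has shifted across versions, so treat this as illustrative and check your installed version:

from transformers import AutoModelForCausalLM, DynamicCache, StaticCache

model = AutoModelForCausalLM.from_pretrained("gpt2")

dynamic_cache = DynamicCache()      # grows as tokens are appended; supports .crop(max_length)

static_cache = StaticCache(         # pre-allocated; writes land wherever cache_position points
    config=model.config,
    max_batch_size=4,
    max_cache_len=1024,
    device="cpu",
    dtype=model.dtype,
)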
The Cache Gap Problem
With synchronized batch_position, a subtle bug emerges:
- The draft loop processes positions [batch_position-1, batch_position+draft_steps-2]
- Target verification processes positions [batch_position-1, batch_position+draft_steps-1]
The target processes one more position than the draft loop. After the iteration:
- The target cache has KVs for positions 0 to batch_position + draft_steps - 1
- The draft cache has KVs for positions 0 to batch_position + draft_steps - 2
Gap: one position missing in draft cache.
In the next iteration, when batch_position advances, the draft model tries to attend to a position that was never computed—containing garbage values.
Why v1 Doesn’t Have This Problem
v1 uses per-sequence tracking and processes token ranges:
draft_output = draft_model(
    input_ids=input_ids[active_indices, min_seq_len+k:max_length+k],
    ...
)
When sequences have different lengths, max_length - min_seq_len can be > 1, processing multiple tokens per step. The cache naturally “catches up” because sequences behind process more tokens.
Solutions for v2
Option 1: Extra Forward Pass
After the draft loop, run one more draft forward to fill the gap:
draft_model(input_ids=input_ids[:, batch_position + draft_steps - 1].unsqueeze(1), ...)
Simple but adds latency.
Option 2: Process 2 Tokens on First Step
On k=0 of the next iteration, process 2 tokens instead of 1 to fill the gap:
if k == 0 and draft_cache_len < batch_position - 1:
    draft_model(input_ids[:, draft_cache_len:batch_position], ...)
Same total compute, just redistributed.
Option 3: Use DynamicCache with Cropping
Crop both caches to batch_position - 1 at iteration start, then process normally. The cache regrows each iteration.
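A sketch of option 3, assuming both caches are DynamicCache instances:

# At the start of each speculation iteration, roll both caches back to the last
# position that is guaranteed valid for every sequence in the batch.
draft_cache.crop(batch_position - 1)
target_cache.crop(batch_position - 1)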
Position ID Alignment
With StaticCache, cache_position determines where KVs are stored, while position_ids determines RoPE embeddings. These can differ!
For sequences with gaps:
- Cache position might be 15 (tensor index)
- Position ID might be 13 (actual content position after gaps)
The KV is stored at index 15, but RoPE is computed for position 13. When later attending, the query’s RoPE (based on its position ID) correctly matches the key’s embedded RoPE.
This works because RoPE is “baked into” K at storage time. The cache index is just for storage/retrieval—the positional information is in the embeddings.
Attention Mask Interactions
The attention mask must correctly reflect which positions are valid:
attn_mask[rejected_indices, rejected_positions] = 0 # Mask out rejections
When computing attention, masked positions receive a large negative bias before the softmax, so they get zero attention weight. Even if garbage exists in the cache at those positions, it’s ignored.
Caution: The mask must be consistent between draft and target models. If the target used a different mask during verification, the cached KVs might have different attention patterns than what the draft expects.
Performance Characteristics
Acceptance Rate
The acceptance rate, the proportion of drafted tokens that are accepted, determines the speedup. Per speculation block, roughly:
Speedup ≈ (accepted_tokens + 1) / (γ · c + 1)
where c is the cost of one draft forward pass relative to one target forward pass: tokens produced per block, divided by the block’s cost measured in target-forward-pass units.
Factors affecting acceptance rate:
- Draft/target model alignment (same family helps)
- Task difficulty (factual recall vs creative writing)
- Temperature (lower = more deterministic = higher acceptance)
Block Efficiency
We track “block efficiency”—tokens generated per speculation block:
block_efficiency = total_tokens / num_blocks
With γ=5, perfect acceptance gives efficiency of 6 (5 drafts + 1 bonus). Real-world values of 3-4 are common.
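Under the i.i.d. acceptance model from the original paper, the expected number of tokens per block is (1 - α^(γ+1)) / (1 - α) for a per-token acceptance rate α. A small sketch that turns this into an estimated speedup, assuming a draft-to-target cost ratio c (the numbers below are illustrative, not measured):

def expected_block_efficiency(alpha: float, gamma: int) -> float:
    """Expected tokens per speculation block (Leviathan et al., 2023)."""
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)

def estimated_speedup(alpha: float, gamma: int, c: float) -> float:
    """Tokens per block divided by the block's cost in target-forward-pass units."""
    return expected_block_efficiency(alpha, gamma) / (gamma * c + 1)

print(expected_block_efficiency(0.8, 5))   # ~3.69 tokens per block
print(estimated_speedup(0.8, 5, 0.1))      # ~2.46x over plain autoregressive decoding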
Memory Bandwidth
Speculative decoding shines when memory-bound. The target model’s weights are loaded once to verify γ tokens instead of γ times. For large models on consumer GPUs, this is significant.
Conclusion
Batched speculative decoding requires careful orchestration of:
- Synchronized position tracking across variable-length sequences
- Attention masks that correctly handle gaps from rejections
- KV-caches that align between draft and target models
- Position IDs computed from actual content, not tensor indices
The HuggingFace cache abstractions (StaticCache/DynamicCache) provide building blocks, but their semantics around cache_position vs position_ids require careful handling. The gap problem—where draft and target caches diverge by one position—is a subtle bug that manifests as degraded acceptance rates rather than obvious failures.
Understanding these mechanics enables building efficient, correct batched speculative decoding systems that maintain the mathematical guarantees of the original algorithm while leveraging GPU parallelism.