This post describes my attempt to implement batched speculative decoding with a target-draft setup. I tried not to make this “yet another blog post” on speculative decoding!

Background

Large Language Models (LLMs) generate text autoregressively, one token at a time. Each token requires a full forward pass through the model, making generation memory-bandwidth bound rather than compute-bound. The model weights must be loaded from memory for every single token, regardless of batch size.

Speculative decoding accelerates this process using a simple insight: given context, some tokens are much easier to predict than others, such as routine code syntax or the answer to “1+1=”. A smaller, faster “draft” model can propose multiple tokens that a larger “target” model verifies in parallel. When the draft model’s predictions align with what the target would have produced, we get multiple tokens from a single target forward pass. The wall-clock speed-up comes from the transformer’s ability to compute logits for every position of an input sequence in parallel.

The Basic Algorithm

  1. Draft Phase: A small model generates γ (gamma) candidate tokens autoregressively
  2. Verify Phase: The target model processes all γ tokens in one forward pass
  3. Accept/Reject: Using rejection sampling, accept tokens where draft matches target distribution
  4. Bonus Token: Sample one additional token from the target’s distribution

The key mathematical guarantee: the output distribution is identical to running the target model alone. We’re not approximating; we’re getting exact samples, faster. (See the original speculative decoding paper, Leviathan et al. 2023, for the proof.)
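
As a rough sketch of steps 1 and 2 (greedy drafting, no KV-cache, hypothetical draft_model / target_model handles), before the accept/reject machinery described next:

import torch

@torch.no_grad()
def draft_and_verify(draft_model, target_model, input_ids, gamma=5):
    # 1. Draft phase: the small model proposes gamma tokens autoregressively.
    drafted = input_ids
    for _ in range(gamma):
        next_token = draft_model(drafted).logits[:, -1:].argmax(dim=-1)
        drafted = torch.cat([drafted, next_token], dim=1)
    # 2. Verify phase: one target forward pass scores all drafted positions.
    target_logits = target_model(drafted).logits
    return drafted, target_logits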

Rejection Sampling

For each drafted token, we compare probabilities:

  • p(x) = target model’s probability for token x
  • q(x) = draft model’s probability for token x

Accept the token with probability min(1, p(x)/q(x)). If rejected at position k:

  • Discard the drafted tokens from position k onwards
  • Sample a replacement token from the normalized residual distribution norm(max(0, p − q))

This ensures we never accept tokens the target model wouldn’t have produced.
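
For a single position, the accept/reject step can be sketched like this (p and q here are full-vocabulary probability tensors for that position; a minimal illustration, not the batched implementation):

import torch

def accept_or_resample(p, q, x):
    # Accept drafted token x with probability min(1, p[x] / q[x]);
    # u * q[x] <= p[x] is the same test as u <= p[x] / q[x].
    if torch.rand(()) * q[x] <= p[x]:
        return x, True
    # Otherwise sample a replacement from the normalized residual max(0, p - q).
    residual = (p - q).clamp(min=0)
    residual = residual / residual.sum()
    return torch.multinomial(residual, num_samples=1).item(), False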


Batched Implementation Challenges

Speculative decoding seems simple enough to implement for B=1. How do we exploit batching to process multiple sequences at once?

Extending speculative decoding to batches introduces significant complexity. The core issue: different sequences accept different numbers of tokens.

Consider a batch of 3 sequences with γ=5:

  • Sequence 0: accepts 5 tokens (all drafts correct)
  • Sequence 1: accepts 2 tokens (rejection at position 3)
  • Sequence 2: accepts 4 tokens

After one iteration, sequences have different lengths. How do we handle this efficiently?

Two Approaches

Approach 1: Per-Sequence Tracking (v1)

  • Track individual start_positions per sequence
  • Process only active sequences in draft loop
  • Use DynamicCache with pruning to minimum length
  • Sequences “catch up” by processing variable token ranges

Approach 2: Synchronized Batch Position (v2)

  • Single batch_position for all sequences
  • Use attention masking to handle “gaps” from rejections
  • Place bonus tokens at max accepted position
  • Simpler indexing but requires careful mask management

The v2 Implementation

Design Philosophy

v2 uses a synchronized batch_position pointer and StaticCache. All sequences share the same position in the tensor, with attention masks handling the variable-length reality.

After iteration with different acceptance rates:
Sequence 0: [prompt...][tok][tok][tok][tok][tok][bonus][ ][ ]
Sequence 1: [prompt...][tok][tok][PAD][PAD][PAD][bonus][ ][ ]
Sequence 2: [prompt...][tok][tok][tok][tok][PAD][bonus][ ][ ]
                                                   ^
                                            batch_position

Rejected positions are masked out (attention_mask=0), creating “gaps” in sequences.

Position IDs with Gaps

With gaps in sequences, position IDs must reflect actual content positions, not tensor indices. We compute this via cumulative sum of the attention mask:

pos_id = (attn_mask[:, :seq_len].cumsum(dim=1) - 1).clamp(min=0)

For a mask like [1,1,1,0,0,1,1], this gives positions [0,1,2,2,2,3,4]. The model sees contiguous positions despite gaps in the tensor.
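
A quick way to convince yourself of the cumsum trick (standalone, runnable):

import torch

# One sequence with two rejected ("gap") slots in the middle.
attn_mask = torch.tensor([[1, 1, 1, 0, 0, 1, 1]])
pos_id = (attn_mask.cumsum(dim=1) - 1).clamp(min=0)
print(pos_id)  # tensor([[0, 1, 2, 2, 2, 3, 4]])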

The Draft Loop

for k in range(draft_steps):
    draft_output = draft_model(
        input_ids=input_ids[:, batch_position - 1 + k].unsqueeze(1),
        attention_mask=attn_mask[:, :batch_position + k],
        cache_position=cache_position + k,
        position_ids=pos_id + k,
        past_key_values=draft_cache,
    )
    draft_logits = draft_output.logits  # [B, 1, vocab_size]

    # Store the draft distribution for later comparison, then sample from it
    Q[active_indices, k] = logits_processor(draft_logits[:, -1])
    new_token = logits_processor.sample(Q[active_indices, k])

    # Place the drafted token and mark its slot as valid
    input_ids[active_indices, batch_position + k] = new_token
    attn_mask[active_indices, batch_position + k] = 1

Each step:

  1. Feeds the previous token (at batch_position - 1 + k)
  2. Extends attention mask to include new position
  3. Stores draft probabilities in Q matrix for later comparison
  4. Places drafted token at batch_position + k

Parallel Verification

The target model verifies all drafted tokens in one forward pass:

target_output = target_model(
    input_ids=input_ids[:, batch_position-1:batch_position+draft_steps],
    attention_mask=attn_mask[:, :batch_position+draft_steps],
    cache_position=torch.arange(batch_position-1, batch_position+draft_steps),
    position_ids=pos_id[:, batch_position-1:batch_position+draft_steps],
    past_key_values=target_cache,
)

This processes draft_steps + 1 tokens:

  • Position batch_position - 1: the last accepted token (to get logits for first draft)
  • Positions batch_position to batch_position + draft_steps - 1: the drafted tokens
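
The p and q tensors used below are per-position probability distributions over the vocabulary: p from the target’s verification logits and q from the draft distributions collected in Q during the draft loop. A rough sketch of how they might be built (assuming logits_processor can be applied to a full [B, T, vocab] logits tensor):

# Target distributions over the verified window: [B, draft_steps + 1, vocab]
p = logits_processor(target_output.logits)
# Draft distributions stored during the draft loop: [B, draft_steps, vocab]
q = Q[:, :draft_steps]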

Vectorized Acceptance

Instead of sequential rejection sampling, we vectorize:

# Gather probabilities for drafted tokens
p_tok = p[:, :draft_steps].gather(dim=2, index=drafted_tokens.unsqueeze(-1)).squeeze(-1)
q_tok = q[:, :draft_steps].gather(dim=2, index=drafted_tokens.unsqueeze(-1)).squeeze(-1)

# Log-space rejection sampling
log_ratio = torch.log(p_tok) - torch.log(q_tok)
log_r = torch.empty_like(log_ratio).uniform_().log_()

# Cumulative product finds first rejection
acceptance_status = (log_r <= log_ratio).cumprod(dim=1).bool()
num_accepted = acceptance_status.sum(dim=1)

The cumprod trick: once we hit a rejection (False), the cumulative product zeroes out every subsequent position, so summing along the draft dimension gives the number of consecutively accepted tokens.
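
A standalone toy check of the trick (casting to int before the cumulative product):

import torch

# True until the first rejection, then arbitrary afterwards.
accepted = torch.tensor([[True, True, False, True, True]])
status = accepted.int().cumprod(dim=1).bool()
print(status)             # tensor([[ True,  True, False, False, False]])
print(status.sum(dim=1))  # tensor([2])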

Bonus Token Placement

Here’s where batching gets tricky. Different sequences accept different amounts:

# Advance everyone to the furthest-accepted sequence in the batch
batch_pos_shift = accepted_draft_length.max().item()
# extra_tokens: the bonus tokens sampled from the target's distribution
input_ids[active_indices, batch_position + batch_pos_shift] = extra_tokens
attn_mask[active_indices, batch_position + batch_pos_shift] = 1
batch_position += batch_pos_shift + 1

All sequences place their bonus token at the maximum accepted position. Sequences that accepted fewer tokens have gaps between their last accepted token and the bonus. These gaps are masked out, so attention skips them.
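
The gaps come from zeroing the attention mask over each sequence’s rejected draft slots, done before the bonus token is written (which then marks its own slot valid again). A minimal sketch of that step, using num_accepted from the acceptance code above (the real implementation may index only the active sequences):

# Draft slot k stays valid only if k < number of accepted drafts for that row.
slot = torch.arange(draft_steps, device=attn_mask.device)      # [draft_steps]
keep = slot.unsqueeze(0) < num_accepted.unsqueeze(1)           # [B, draft_steps]
attn_mask[:, batch_position:batch_position + draft_steps] = keep.long()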

Why place at max position?

If we placed bonus tokens at each sequence’s individual position, the next iteration’s input would be at different tensor indices per sequence, breaking batched processing. By aligning to max position, batch_position remains synchronized.


HuggingFace Cache Challenges

StaticCache vs DynamicCache

HuggingFace Transformers offers two KV-cache implementations:

DynamicCache

  • Grows dynamically as tokens are generated
  • Supports crop(max_length) to truncate
  • No cache_position parameter needed—just appends
  • Memory reallocations can cause fragmentation

StaticCache

  • Pre-allocated to maximum length
  • Requires explicit cache_position to specify write locations
  • No reallocations during generation; plays well with torch.compile / CUDA graph capture
  • Can “overwrite” positions (useful for our gap scenario)
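
For reference, constructing one of each with a recent transformers version (purely to show the constructors; the StaticCache keyword names have shifted across versions, so treat them as an assumption, and batch_size / max_total_length are placeholders):

from transformers import DynamicCache, StaticCache

# DynamicCache needs no sizing; it grows as KVs are appended.
draft_cache = DynamicCache()

# StaticCache is pre-allocated up front.
target_cache = StaticCache(
    config=target_model.config,
    max_batch_size=batch_size,
    max_cache_len=max_total_length,
    device=target_model.device,
    dtype=target_model.dtype,
)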

The Cache Gap Problem

With synchronized batch_position, a subtle bug emerges:

  • Draft loop processes positions [batch_position-1, batch_position+draft_steps-2]
  • Target verification processes positions [batch_position-1, batch_position+draft_steps-1]

The target processes one more position than the draft loop. After the iteration:

  • Target cache has KVs for positions 0 to batch_position + draft_steps - 1
  • Draft cache has KVs for positions 0 to batch_position + draft_steps - 2

Gap: one position missing in draft cache.

In the next iteration, when batch_position advances, the draft model tries to attend to a position that was never computed—containing garbage values.
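
For concreteness, take batch_position = 10 and draft_steps = 5: the draft loop feeds positions 9–13, so the draft cache ends at 13, while verification feeds positions 9–14, so the target cache ends at 14. The fifth drafted token (position 14) was placed in input_ids but never run through the draft model, so its slot in the draft StaticCache holds garbage that the next iteration will happily attend to.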

Why v1 Doesn’t Have This Problem

v1 uses per-sequence tracking and processes token ranges:

draft_output = draft_model(
    input_ids=input_ids[active_indices, min_seq_len+k:max_length+k],
    ...
)

When sequences have different lengths, max_length - min_seq_len can be > 1, processing multiple tokens per step. The cache naturally “catches up” because sequences behind process more tokens.

Solutions for v2

Option 1: Extra Forward Pass

After the draft loop, run one more draft forward to fill the gap:

draft_model(input_ids[:, batch_position + draft_steps - 1].unsqueeze(1), ...)

Simple but adds latency.

Option 2: Process 2 Tokens on First Step

On k=0 of the next iteration, process 2 tokens instead of 1 to fill the gap:

if k == 0 and draft_cache_len < batch_position - 1:
    draft_model(input_ids[:, draft_cache_len:batch_position], ...)

Same total compute, just redistributed.

Option 3: Use DynamicCache with Cropping

Crop both caches to batch_position - 1 at iteration start, then process normally. The cache regrows each iteration.
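
A sketch of Option 3, relying on DynamicCache.crop and assuming both caches are DynamicCache instances (variable names as in the earlier snippets):

# At the start of each iteration, truncate both caches to the last position
# that is guaranteed valid for every sequence; they regrow during the
# draft loop and the verification pass.
draft_cache.crop(batch_position - 1)
target_cache.crop(batch_position - 1)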

Position ID Alignment

With StaticCache, cache_position determines where KVs are stored, while position_ids determines RoPE embeddings. These can differ!

For sequences with gaps:

  • Cache position might be 15 (tensor index)
  • Position ID might be 13 (actual content position after gaps)

The KV is stored at index 15, but RoPE is computed for position 13. When later attending, the query’s RoPE (based on its position ID) correctly matches the key’s embedded RoPE.

This works because RoPE is “baked into” K at storage time. The cache index is just for storage/retrieval—the positional information is in the embeddings.

Attention Mask Interactions

The attention mask must correctly reflect which positions are valid:

attn_mask[rejected_indices, rejected_positions] = 0  # Mask out rejections

When computing attention, masked positions receive a large negative bias, so their softmax weight is effectively zero. Even if garbage exists in the cache at those positions, it’s ignored.

Caution: The mask must be consistent between draft and target models. If the target saw a different mask during verification, the hidden states behind its cached KVs were computed under a different visibility pattern than the one the draft model assumes.


Performance Characteristics

Acceptance Rate

The acceptance rate—proportion of drafted tokens accepted—determines speedup:

Speedup ≈ (accepted_tokens + 1) / (draft_cost + target_cost)

with both costs measured in units of a single target forward pass: one speculation block costs roughly γ·c for the draft steps plus 1 for the verification pass, where c is the draft-to-target cost ratio.
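
As a rough sanity check: with γ = 5, a draft model about 10× cheaper per token (c ≈ 0.1), and an average of 3 accepted drafts per block, one block yields 4 tokens for a cost of about 5·0.1 + 1 = 1.5 target-equivalent passes, i.e. roughly a 2.7× speedup over plain autoregressive decoding (ignoring sampling and bookkeeping overhead).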

Factors affecting acceptance rate:

  • Draft/target model alignment (same family helps)
  • Task difficulty (factual recall vs creative writing)
  • Temperature (lower = more deterministic = higher acceptance)

Block Efficiency

We track “block efficiency”—tokens generated per speculation block:

block_efficiency = total_tokens / num_blocks

With γ=5, perfect acceptance gives efficiency of 6 (5 drafts + 1 bonus). Real-world values of 3-4 are common.

Memory Bandwidth

Speculative decoding shines when generation is memory-bound. The target model’s weights are loaded once per speculation block to verify γ drafted tokens, instead of once per token. For large models on consumer GPUs, this is significant.


Conclusion

Batched speculative decoding requires careful orchestration of:

  • Synchronized position tracking across variable-length sequences
  • Attention masks that correctly handle gaps from rejections
  • KV-caches that align between draft and target models
  • Position IDs computed from actual content, not tensor indices

The HuggingFace cache abstractions (StaticCache/DynamicCache) provide building blocks, but their semantics around cache_position vs position_ids require careful handling. The gap problem—where draft and target caches diverge by one position—is a subtle bug that manifests as degraded acceptance rates rather than obvious failures.

Understanding these mechanics enables building efficient, correct batched speculative decoding systems that maintain the mathematical guarantees of the original algorithm while leveraging GPU parallelism.