"""
Continuous batching = iteration-level scheduling + ragged (packed) batching.

Two approaches are compared (both run BATCH_SIZE sequences concurrently, so the
comparison is slot-for-slot fair):

1. Static batching (baseline):
   Prompts are processed BATCH_SIZE at a time. Each wave is padded to a
   common length and run together until the LONGEST request in that wave
   finishes; a hard "batch barrier" then has to clear before the next wave
   starts. Short requests sit idle behind the barrier.

2. Continuous batching (production-aligned):
   Two ideas combine to keep the GPU busy.

   (a) Iteration-level scheduling: the moment a sequence finishes it frees
       its slot, and the next queued prompt is admitted on the SAME step --
       no waiting for the rest of the batch.

   (b) Ragged / packed batching -- the part that makes it truly "continuous":
       instead of padding every sequence into a rectangular [B, max_len]
       tensor, ALL in-flight tokens are concatenated into a single unpadded
       [1, total_tokens] row and run in ONE forward pass. A block-diagonal
       causal attention mask stops tokens from attending across sequence
       boundaries, so packing is mathematically identical to running each
       sequence on its own (verified: greedy output matches per-prompt
       generation token-for-token).

   Because attention is governed entirely by the mask, a newly admitted
   prompt's multi-token PREFILL rides along in the same forward pass as
   every other sequence's single-token DECODE step. Prefill and decode are
   fused: no padding, no separate prefill pass.

   KV cache: each sequence keeps its own DynamicCache; every step the caches
   are concatenated along the time axis into one packed cache, and the newly
   computed KV is scattered back per sequence. (Real engines store the
   cache in fixed-size pages -- "paged attention" -- to avoid this per-step
   reassembly, but the attention/masking logic is exactly what you see here.)
"""
import time
import torch
from dataclasses import dataclass, field
from typing import Optional
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
from transformers.cache_utils import DynamicLayer
# Hugging Face model id to load; swap for any causal LM.
MODEL_ID = "openai-community/gpt2"
# Maximum number of sequences that run concurrently (slots).
BATCH_SIZE = 3
def _device_sync(model) -> None:
    """Block until queued GPU work finishes, so timings are accurate.

    Reads ``model.device.type``. CUDA and MPS devices get an explicit
    synchronize call; any other device type is a no-op.
    """
    device_type = model.device.type
    if device_type == "cuda":
        torch.cuda.synchronize()
    elif device_type == "mps":
        torch.mps.synchronize()
def static_batching(requests: list[tuple[str, int]], tokenizer, model) -> list[str]:
    """Baseline: generate for ``requests`` in fixed waves of BATCH_SIZE.

    Each wave is tokenized together (left-padded) and decoded by a single
    ``model.generate`` call whose ``max_new_tokens`` is the largest cap in the
    wave, so every slot waits for the wave's longest request (the "batch
    barrier"). Each output row is then trimmed back to its own cap.

    Downside: short requests in a wave idle until the wave's longest is done,
    and no slot can be refilled until the whole wave clears the barrier.

    Args:
        requests: ``(prompt, max_new_tokens)`` pairs.
        tokenizer: tokenizer used to encode prompts and decode outputs. Its
            ``padding_side`` is set to "left" for the duration of the call and
            restored to its previous value afterwards.
        model: causal LM exposing ``generate`` and ``device``.

    Returns:
        One ``prompt + completion`` string per request, in request order.
    """
    if not requests:
        return []

    results: dict[int, str] = {}
    indexed = list(enumerate(requests))  # (req_id, (prompt, cap))

    # Left padding puts the pad tokens before the prompt, so generated tokens
    # start at the same column (``width``) for every row. Restore the caller's
    # setting afterwards instead of leaving the tokenizer mutated.
    previous_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        for wave_start in range(0, len(indexed), BATCH_SIZE):
            wave = indexed[wave_start: wave_start + BATCH_SIZE]
            wave_max = max(cap for _, (_, cap) in wave)

            # Show which request occupies each slot in this wave.
            for slot, (req_id, (prompt, cap)) in enumerate(wave):
                print(f" ++ slot {slot} <- req {req_id} ({cap} tok cap): {prompt!r}", flush=True)

            prompts = [p for _, (p, _) in wave]
            inputs = tokenizer(
                prompts, return_tensors="pt", padding=True, truncation=True
            ).to(model.device)

            with torch.no_grad():
                output_ids = model.generate(
                    **inputs,
                    max_new_tokens=wave_max,  # whole wave decodes to the longest
                    pad_token_id=tokenizer.eos_token_id,
                    do_sample=False,
                )

            # Padded prompt width: everything from this column on is generated.
            width = inputs.input_ids.shape[1]
            print(
                f" *** batch barrier: all {len(wave)} slots wait for the longest "
                f"({wave_max} tokens) ***",
                flush=True,
            )
            for slot, ((req_id, (prompt, cap)), row) in enumerate(zip(wave, output_ids)):
                # Keep only this request's own cap of generated tokens.
                text = prompt + tokenizer.decode(row[width:width + cap], skip_special_tokens=True)
                results[req_id] = text
                print(
                    f" — slot {slot} done req {req_id} ({cap}/{wave_max} tokens): {text[:90]}",
                    flush=True,
                )
    finally:
        tokenizer.padding_side = previous_padding_side

    return [results[k] for k in sorted(results)]
@dataclass
class Sequence:
    """State for a single in-flight sequence."""

    req_id: int  # original request index (for ordering results)
    prompt: str
    max_new_tokens: int  # per-request cap so short requests finish early
    # Tokens to feed on the NEXT step: the whole prompt right after admission
    # (prefill), then a single token per step (decode).
    pending_ids: list[int]
    # Per-sequence KV-cache; None until this sequence has run once.
    # Forward reference: the class definition itself does not need the cache
    # type to be bound.
    kv_cache: Optional["DynamicCache"] = None
    kv_len: int = 0  # number of cached tokens (prompt + generated)
    tokens_generated: int = 0
    # Generated token ids, in order; default_factory gives each instance its own list.
    output_ids: list[int] = field(default_factory=list)
def _make_cache(layers_kv: list[tuple[torch.Tensor, torch.Tensor]]) -> DynamicCache:
    """Build a DynamicCache from explicit per-layer (keys, values) tensors.

    We SET the tensors directly instead of calling DynamicLayer.update() (which
    would append), because we are assembling caches from scratch each step.

    Args:
        layers_kv: one ``(keys, values)`` pair per layer, in layer order.

    Returns:
        A new DynamicCache with one DynamicLayer per input pair.
    """
    cache = DynamicCache()
    for k, v in layers_kv:
        layer = DynamicLayer()
        # NOTE(review): lazy_initialization's signature and what it sets up
        # depend on the installed transformers version -- confirm against the
        # pinned release. keys/values are overwritten explicitly right after.
        layer.lazy_initialization(k, v)
        layer.keys = k
        layer.values = v
        cache.layers.append(layer)
    return cache
def _ragged_step(seqs: list[Sequence], model, device, dtype) -> list[int]:
    """Run ONE packed forward pass over every active sequence.

    All sequences are flattened into a single row (batch dim = 1):
        input_ids        [1, total_q]  -- every sequence's pending tokens
        position_ids     [1, total_q]  -- each token's position in ITS sequence
        attention_mask   [1, 1, total_q, total_kv + total_q]  -- block-diagonal causal
        past_key_values  packed cache [1, H, total_kv, D]
    where
        total_q  = sum of pending tokens (1 per decoding seq, prompt_len per new seq)
        total_kv = sum of already-cached tokens across sequences

    Side effects: every sequence's ``kv_cache`` is replaced by a cache that
    also holds this step's keys/values, and its ``kv_len`` grows by the number
    of tokens it fed.

    Returns the next greedy token for each sequence (same order as ``seqs``).

    Raises:
        ValueError: if any sequence has no pending tokens; there would be no
            logits row of its own to read its next token from.
    """
    if not seqs:
        # Nothing to run; avoid sending an empty [1, 0] row through the model.
        return []

    q_lens = [len(s.pending_ids) for s in seqs]
    if any(ql == 0 for ql in q_lens):
        raise ValueError("every sequence must have at least one pending token")
    total_q = sum(q_lens)
    total_kv = sum(s.kv_len for s in seqs)

    # Packed inputs: concatenate every sequence's pending tokens into one row.
    flat_ids = [t for s in seqs for t in s.pending_ids]
    input_ids = torch.tensor([flat_ids], dtype=torch.long, device=device)

    # Tag every KEY and every QUERY token with (sequence index, position-in-sequence).
    # Key space is laid out as [ cached tokens | this step's new tokens ], matching
    # how the model appends new KV to the end of the packed cache.
    key_seq, key_pos = [], []
    for si, s in enumerate(seqs):  # cached block
        key_seq.extend([si] * s.kv_len)
        key_pos.extend(range(s.kv_len))
    q_seq, q_pos = [], []
    for si, s in enumerate(seqs):  # new block (also queries)
        for j in range(q_lens[si]):
            q_seq.append(si)
            q_pos.append(s.kv_len + j)
    key_seq.extend(q_seq)
    key_pos.extend(q_pos)

    q_seq_t = torch.tensor(q_seq, device=device)
    q_pos_t = torch.tensor(q_pos, device=device)
    key_seq_t = torch.tensor(key_seq, device=device)
    key_pos_t = torch.tensor(key_pos, device=device)

    # Each token's positional embedding uses its own sequence position, not its
    # offset in the packed row.
    position_ids = q_pos_t.unsqueeze(0)  # [1, total_q]

    # Block-diagonal causal mask: a query may attend to a key only if they belong
    # to the SAME sequence (block-diagonal) and the key is not in the future
    # (causal). This is the whole trick -- it makes packing equivalent to running
    # each sequence separately. 0.0 = attend, large-negative = blocked (additive).
    same = q_seq_t[:, None] == key_seq_t[None, :]
    causal = key_pos_t[None, :] <= q_pos_t[:, None]
    allowed = same & causal  # [total_q, total_kv + total_q]
    attn_mask = torch.zeros(1, 1, total_q, total_kv + total_q, dtype=dtype, device=device)
    attn_mask.masked_fill_(~allowed[None, None], torch.finfo(dtype).min)

    # Packed KV-cache: concatenate each sequence's cache along the time axis.
    # Freshly admitted sequences (kv_len == 0) contribute nothing here.
    cached = [s for s in seqs if s.kv_len > 0]
    if cached:
        num_layers = len(cached[0].kv_cache.layers)
        layers_kv = []
        for l in range(num_layers):
            ks = torch.cat([s.kv_cache.layers[l].keys for s in cached], dim=2)
            vs = torch.cat([s.kv_cache.layers[l].values for s in cached], dim=2)
            layers_kv.append((ks, vs))
        past = _make_cache(layers_kv)
    else:
        past = DynamicCache()

    with torch.no_grad():
        out = model(
            input_ids=input_ids,
            attention_mask=attn_mask,
            position_ids=position_ids,
            past_key_values=past,
            use_cache=True,
        )

    # Greedy next token for each sequence: read the logits at its LAST pending
    # token (for a prefilling sequence that is the final prompt token).
    logits = out.logits[0]  # [total_q, vocab]
    offsets, last_idx, off = [], [], 0
    for ql in q_lens:
        offsets.append(off)
        last_idx.append(off + ql - 1)
        off += ql
    next_tokens = [int(logits[i].argmax()) for i in last_idx]

    # Scatter the newly computed KV back to each sequence. The output cache is
    # [ old packed block | new packed block ]; slice this step's new block per
    # sequence and append it to that sequence's own cache.
    out_kv = out.past_key_values
    num_layers = len(out_kv.layers)
    for si, s in enumerate(seqs):
        o, ql = offsets[si], q_lens[si]
        lo, hi = total_kv + o, total_kv + o + ql
        layers_kv = []
        for l in range(num_layers):
            k_new = out_kv.layers[l].keys[:, :, lo:hi, :]
            v_new = out_kv.layers[l].values[:, :, lo:hi, :]
            if s.kv_cache is None:
                layers_kv.append((k_new, v_new))
            else:
                layers_kv.append((
                    torch.cat([s.kv_cache.layers[l].keys, k_new], dim=2),
                    torch.cat([s.kv_cache.layers[l].values, v_new], dim=2),
                ))
        s.kv_cache = _make_cache(layers_kv)
        s.kv_len += ql

    return next_tokens
def visualize_ragged_step(seqs: list[Sequence], tokenizer, title: str, slot_ids: list[int]) -> None:
    """Illustrative print of ONE packed step: the concatenated input row and the
    block-diagonal causal attention mask.

    This mirrors the masking logic in _ragged_step (recomputed here as a boolean
    grid purely for display) so you can SEE that sequences are packed together
    yet isolated by the mask. Each sequence is labelled by a letter derived
    from its req_id (0 -> A, 1 -> B, ...).

        #  = a query may attend to that key        .  = blocked

    Args:
        seqs: the sequences about to be fed, in packing order.
        tokenizer: used only to decode pending token ids for display.
        title: heading printed above the visualization.
        slot_ids: slot index of each entry in ``seqs`` (same order).
    """
    labels = [chr(ord("A") + s.req_id) for s in seqs]
    q_lens = [len(s.pending_ids) for s in seqs]
    total_q = sum(q_lens)
    total_kv = sum(s.kv_len for s in seqs)

    print(f"\n{'=' * 72}\n {title}")
    print(f" total_q={total_q} tokens fed this step | total_kv={total_kv} cached")
    print(f" {len(seqs)} sequences packed into ONE unpadded row of shape [1, {total_q}]:\n")

    # The concatenated tokens, grouped per sequence (this is the "ragged" row).
    for i, s in enumerate(seqs):
        kind = f"PREFILL({q_lens[i]})" if s.kv_len == 0 else f"decode({q_lens[i]})"
        toks = " ".join(repr(tokenizer.decode([t])) for t in s.pending_ids)
        if len(toks) > 66:
            toks = toks[:63] + "…"
        print(f" {labels[i]} = slot {slot_ids[i]} {kind:<11} {toks}")

    # Rebuild the block-diagonal causal mask as a boolean grid for display.
    key_seq, key_pos = [], []
    for si, s in enumerate(seqs):  # cached keys
        key_seq += [si] * s.kv_len
        key_pos += list(range(s.kv_len))
    q_seq, q_pos = [], []
    for si, s in enumerate(seqs):  # new keys / queries
        for j in range(q_lens[si]):
            q_seq.append(si)
            q_pos.append(s.kv_len + j)
    key_seq += q_seq
    key_pos += q_pos

    q_seq_t, q_pos_t = torch.tensor(q_seq), torch.tensor(q_pos)
    key_seq_t, key_pos_t = torch.tensor(key_seq), torch.tensor(key_pos)
    # Same rule as the real mask: same sequence AND key position not in the future.
    # Converted to nested lists once so the print loop does plain list lookups.
    allowed = (
        (q_seq_t[:, None] == key_seq_t[None, :]) & (key_pos_t[None, :] <= q_pos_t[:, None])
    ).tolist()

    K = len(key_seq)

    def row_str(cells):
        # Space between sequence groups; ' | ' at the cached -> new-tokens split.
        out = []
        for ki in range(K):
            if total_kv > 0 and ki == total_kv:
                out.append(" | ")
            elif ki > 0 and key_seq[ki] != key_seq[ki - 1]:
                out.append(" ")
            out.append(cells[ki])
        return "".join(out)

    def line(left, cells):
        return f"{left:>7} " + row_str(cells)

    print("\n block-diagonal causal mask (row = query, col = key) # attend . blocked")
    if total_kv > 0:
        print(" key layout: [ cached KV | this step’s new tokens ]")
    print(line("keys:", [labels[key_seq[ki]] for ki in range(K)]))
    for qi in range(total_q):
        cells = ["#" if allowed[qi][ki] else "." for ki in range(K)]
        print(line(f"{labels[q_seq[qi]]} p{q_pos[qi]}", cells))
def continuous_batching(requests: list[tuple[str, int]], tokenizer, model) -> list[str]:
    """Ragged continuous batching: dynamic scheduling + packed prefill/decode.

    Scheduling policy:
    - Up to BATCH_SIZE sequences run concurrently.
    - A newly admitted sequence is queued with its full prompt as the next
      tokens to feed; its prefill then happens packed into the next step
      alongside everyone else's decode.
    - Every step runs ONE packed forward pass across all active slots.
    - When a sequence finishes it is immediately replaced by the next prompt.

    The admission log shows slots being reused (iteration-level scheduling).
    Two representative steps are visualized: the first step (all prompts being
    prefilled at once) and the first step that fuses a new prompt's prefill with
    other sequences' decode tokens.

    Args:
        requests: ``(prompt, max_new_tokens)`` pairs.
        tokenizer: tokenizer used to encode prompts and decode outputs.
        model: causal LM; its ``device`` and parameter dtype are used to build
            the packed inputs.

    Returns:
        One ``prompt + completion`` string per request, in request order.

    Raises:
        ValueError: if a prompt tokenizes to nothing and the tokenizer has no
            EOS token to seed it with.
    """
    device = model.device
    dtype = next(model.parameters()).dtype

    # (req_id, (prompt, max_new_tokens)), consumed front to back. An iterator
    # gives O(1) "take next" without repeatedly popping the head of a list.
    pending = iter(enumerate(requests))
    slots: list[Optional[Sequence]] = [None] * BATCH_SIZE
    results: dict[int, str] = {}

    # Step counter; read by _admit for its log line (step 0 = before any decode).
    step = 0

    def _admit(slot_idx: int) -> None:
        """Place the next queued request in ``slot_idx``, or clear the slot."""
        nxt = next(pending, None)
        if nxt is None:
            slots[slot_idx] = None
            return
        req_id, (prompt, max_new_tokens) = nxt
        prompt_ids = list(tokenizer(prompt)["input_ids"])
        if not prompt_ids:
            # A prompt with no tokens would have no logits row of its own in
            # the packed pass; seed it with EOS so it has one token to feed.
            if tokenizer.eos_token_id is None:
                raise ValueError(f"prompt for request {req_id} tokenized to nothing")
            prompt_ids = [tokenizer.eos_token_id]
        slots[slot_idx] = Sequence(
            req_id=req_id,
            prompt=prompt,
            max_new_tokens=max_new_tokens,
            pending_ids=prompt_ids,  # prefill rides the next step
        )
        print(
            f" ++ [step {step:3d}] slot {slot_idx} <- admit req {req_id} "
            f"({max_new_tokens} tok cap): {prompt!r}",
            flush=True,
        )

    # Fill the pool with the first batch of prompts.
    for i in range(BATCH_SIZE):
        _admit(i)

    printed_mixed = False
    while any(s is not None for s in slots):
        step += 1
        active = [(i, s) for i, s in enumerate(slots) if s is not None]
        seqs = [s for _, s in active]
        slot_ids = [i for i, _ in active]

        # Visualize a couple of representative steps so the packing is visible
        # (printing every step would be far too much output).
        mixed = any(s.kv_len == 0 for s in seqs) and any(s.kv_len > 0 for s in seqs)
        if step == 1:
            visualize_ragged_step(
                seqs, tokenizer, f"STEP {step} – prompts packed together (all PREFILL)", slot_ids)
        elif mixed and not printed_mixed:
            visualize_ragged_step(
                seqs, tokenizer, f"STEP {step} – PREFILL + DECODE fused in one pass", slot_ids)
            printed_mixed = True

        # ONE packed forward pass (prefill + decode fused, no padding).
        next_tokens = _ragged_step(seqs, model, device, dtype)

        for (slot_idx, seq), tok in zip(active, next_tokens):
            seq.output_ids.append(tok)
            seq.tokens_generated += 1
            seq.pending_ids = [tok]  # next step: a single decode token

            finished = (
                tok == tokenizer.eos_token_id
                or seq.tokens_generated >= seq.max_new_tokens
            )
            if finished:
                result_text = seq.prompt + \
                    tokenizer.decode(seq.output_ids, skip_special_tokens=True)
                results[seq.req_id] = result_text
                print(
                    f" — [step {step:3d}] slot {slot_idx} done req {seq.req_id} "
                    f"({seq.tokens_generated}/{seq.max_new_tokens} tokens): {result_text[:90]}",
                    flush=True,
                )
                # Refill the freed slot right away; the new prompt's prefill
                # joins the very next packed step.
                _admit(slot_idx)

    return [results[k] for k in sorted(results)]
def main():
    """Load the model, run both batching strategies on the same requests, and
    print the wall-clock time of each plus how many outputs agree."""
    print(f"Loading {MODEL_ID}")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    # Reuse EOS as the pad token so padded batches can be built.
    tokenizer.pad_token = tokenizer.eos_token

    # Pick the fastest available device. On Apple Silicon (M1/M2/...) this is
    # the MPS GPU. We keep float32 on MPS on purpose: float16 there flips a few
    # greedy ties, which would break the "static == continuous, token-for-token"
    # property this demo relies on.
    if torch.cuda.is_available():
        device, dtype = "cuda", torch.float16
    elif torch.backends.mps.is_available():
        device, dtype = "mps", torch.float32
    else:
        device, dtype = "cpu", torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        dtype=dtype,
        attn_implementation="eager",  # use our custom 4D mask directly
    )
    model.eval()
    model.to(device)
    print(f"Running on {device} ({dtype})\n")

    # (prompt, max_new_tokens): deliberately uneven caps so short requests
    # finish long before the longest one in their wave.
    requests = [
        ("The capital of France is", 6),
        ("Today’s weather is so", 50),
        ("In machine learning, a transformer is", 300),
        ("Once upon a time in a land far away,", 30),
        ("Quantum computing differs from classical computing because", 180),
        ("The history of the Roman Empire began", 45),
    ]

    print("=== Static batching ===")
    _device_sync(model)  # drain queued device work so the timer starts clean
    start = time.perf_counter()
    static_outputs = static_batching(requests, tokenizer, model)
    _device_sync(model)  # wait for the device before stopping the timer
    static_elapsed = time.perf_counter() - start
    print(f"\nStatic batching elapsed: {static_elapsed:.2f}s\n")

    print("=== Continuous batching (ragged) ===")
    _device_sync(model)
    start = time.perf_counter()
    continuous_outputs = continuous_batching(requests, tokenizer, model)
    _device_sync(model)
    continuous_elapsed = time.perf_counter() - start
    print(f"\nContinuous batching elapsed: {continuous_elapsed:.2f}s")

    # Report, rather than assume, how many requests got identical text from
    # both strategies.
    identical = sum(a == b for a, b in zip(static_outputs, continuous_outputs))
    print(f"Outputs identical for {identical}/{len(requests)} requests")
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
