TL;DR: We tour a single GEMM kernel (C = A @ B) written in Gluon for the NVIDIA Blackwell B200. It uses two CTAs cooperating on one matrix multiply (tcgen05), pulls tiles from HBM with the Tensor Memory Accelerator (TMA), accumulates in Tensor Memory (TMEM), and overlaps loads with compute using an mbarrier-driven software pipeline. By the end you'll be able to read every line of the kernel and know why it's there.
The kernel lives here: kernels/gluon_gemm.py. This post explains it as a learning exercise — if you’ve written a Triton kernel before but never touched Blackwell’s new primitives, this is for you.
What problem are we solving?
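A GEMM (GEneral Matrix Multiply) computes:
$$C = A B, \qquad A \in \mathbb{R}^{M \times K},\; B \in \mathbb{R}^{K \times N},\; C \in \mathbb{R}^{M \times N}$$
Each output element is a dot product over the shared K dimension:
$$C_{ij} = \sum_{k=0}^{K-1} A_{ik} B_{kj}$$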
This is the single most important kernel in deep learning: every linear layer, attention projection, and MLP is a GEMM. So the entire game is keeping the tensor cores fed with data fast enough that they never stall. The kernel we’re studying is a hand-pipelined, hardware-accelerated answer to that on Blackwell.
🧠 GPU concept: why tiling?
Large matrices don't fit in fast memory. The trick of every fast GEMM is tiling: chop A, B, and C into small blocks that do fit in on-chip shared memory, multiply the blocks, and accumulate. We stream K in chunks of BK, so we only ever hold a thin slice of A and B on-chip at once. The output tile (BM × BN) stays resident and is accumulated into across all K steps.
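Here is a minimal CPU sketch in NumPy that makes the loop structure concrete. It is an illustration, not the kernel. It computes C one BM × BN tile at a time, streams K in BK-wide chunks, and accumulates into an fp32 tile that stays resident for the whole K loop. The name `tiled_gemm` and the small default tile sizes are our own choices, picked for readability.

```python
import numpy as np

def tiled_gemm(A, B, BM=64, BN=64, BK=16):
    """C = A @ B, computed one BM x BN output tile at a time (CPU sketch).

    For each output tile, K is streamed in BK-wide chunks and the products
    are accumulated into an fp32 tile that stays resident for the whole
    K loop. Edge tiles are simply smaller here (NumPy slicing clips them);
    the GPU kernel instead relies on TMA zero-fill.
    """
    M, K = A.shape
    K2, N = B.shape
    assert K == K2, "inner dimensions must match"
    C = np.empty((M, N), dtype=A.dtype)
    for off_m in range(0, M, BM):              # row-band of C
        for off_n in range(0, N, BN):          # column-band of C
            rows, cols = min(BM, M - off_m), min(BN, N - off_n)
            acc = np.zeros((rows, cols), dtype=np.float32)  # resident output tile
            for off_k in range(0, K, BK):      # thin slices of A and B
                a = A[off_m:off_m + rows, off_k:off_k + BK].astype(np.float32)
                b = B[off_k:off_k + BK, off_n:off_n + cols].astype(np.float32)
                acc += a @ b                   # accumulate this K-chunk
            C[off_m:off_m + rows, off_n:off_n + cols] = acc.astype(A.dtype)
    return C

rng = np.random.default_rng(0)
A = rng.standard_normal((200, 96)).astype(np.float32)
B = rng.standard_normal((96, 130)).astype(np.float32)
print(np.abs(tiled_gemm(A, B) - A @ B).max())  # small: only the summation order differs
```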
This kernel runs each output tile on a cluster of two CTAs that split the tile's rows in half. With the defaults (BM=256, BN=256, BK=64), each CTA owns 128 of the tile's 256 rows. The two CTAs cooperate on a single tcgen05 MMA, so the B tile is shared between them. The launch grid has one program (one cluster) per output tile:
```python
grid = (triton.cdiv(M, BM), triton.cdiv(N, BN))
```
So program_id(0) picks the row-band of C and program_id(1) picks the column-band.
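Here is a small plain-Python sketch of that mapping. `cluster_tiles` is a hypothetical helper, not part of the kernel. It assumes each CTA owns a contiguous half of the cluster's BM rows, as described above, and it uses the same ceiling division as `triton.cdiv`.

```python
def cdiv(a, b):
    """Ceiling division for positive integers (what triton.cdiv computes)."""
    return (a + b - 1) // b

def cluster_tiles(M, N, BM=256, BN=256):
    """Yield (pid_m, pid_n, off_m, off_n, cta_rows) for every program in the grid.

    cta_rows gives the [start, stop) rows each of the cluster's two CTAs
    owns, assuming each CTA takes a contiguous half of the BM rows. Rows at
    or beyond M belong to a partially out-of-bounds edge tile.
    """
    cta_m = BM // 2
    for pid_m in range(cdiv(M, BM)):        # program_id(0): row-band of C
        for pid_n in range(cdiv(N, BN)):    # program_id(1): column-band of C
            off_m, off_n = pid_m * BM, pid_n * BN
            cta_rows = [(off_m + r * cta_m, off_m + (r + 1) * cta_m) for r in range(2)]
            yield pid_m, pid_n, off_m, off_n, cta_rows

tiles = list(cluster_tiles(4000, 4096))
print(len(tiles), "clusters; last:", tiles[-1])
```

For the 4000 × 4096 output used in the correctness tests, this gives a 16 × 16 grid. In the last row-band, the second CTA's rows 3968 through 4095 run past M = 4000.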
Why Gluon (and not plain Triton)?
You’ve probably seen Triton kernels: you write tl.load, tl.dot, tl.store, and the compiler decides how to lay out data in shared memory, how to pipeline, and how to schedule warps. That’s wonderful for productivity but it hides the hardware.
Gluon is Triton’s experimental lower-level dialect. It exposes the Blackwell primitives directly: you allocate shared-memory buffers yourself, you place memory barriers yourself, you issue the async copy and the tensor-core MMA as explicit instructions, and you own the pipeline. It’s closer to writing CUTLASS than to writing Triton — but in Python.
```python
from triton.experimental import gluon
from triton.experimental.gluon import language as gl
from triton.experimental.gluon.nvidia.hopper import TensorDescriptor
from triton.experimental.gluon.language.nvidia.blackwell import (
    TensorMemoryLayout, allocate_tensor_memory, tma, mbarrier,
    tcgen05_mma, tcgen05_mma_barrier_count, fence_async_shared, get_tmem_reg_layout,
)
```
That import block is a great table of contents for what’s special here:
| Primitive | What it is |
|---|---|
| TensorDescriptor | A handle describing a tensor + tile shape for TMA copies |
| tma | The Tensor Memory Accelerator: async bulk DMA between HBM and SMEM |
| mbarrier | Hardware async barriers to coordinate producers/consumers |
| allocate_tensor_memory / TensorMemoryLayout | Blackwell’s new Tensor Memory (TMEM) for accumulators |
| tcgen05_mma | The Blackwell 5th-gen tensor core matrix-multiply instruction |
Let’s introduce each one before we read the kernel.
A 60-second tour of the Blackwell memory path
The whole kernel is a choreography of data moving through a hierarchy, from slow-and-big to fast-and-small: HBM (global memory, well over 100 GB on a B200) → SMEM (shared, ~228 KB per SM) → TMEM (tensor memory) → back out. Three pieces of that path are brand-new on Blackwell or unfamiliar from textbook CUDA, so let's give each a short box.
🧠 GPU concept: TMA — the Tensor Memory Accelerator
Classically, to load a tile you'd have every thread in the block compute an address and issue a load, then call __syncthreads(). That burns registers and instruction slots. TMA (introduced on Hopper) is a dedicated copy engine. You hand it a descriptor ("this 256×64 tile of A, starting at these coordinates") and it streams the whole tile from HBM into shared memory asynchronously, in the background, from a single instruction. It also handles tiles that hang off the edge of the tensor (ragged shapes) by zero-filling the out-of-bounds part, which is why this kernel stays correct on sizes like 4000×4096. When the copy finishes, TMA signals an mbarrier.
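Zero-fill is what makes a ragged K safe: a padded column of A only ever multiplies a padded row of B, which is also zero. The NumPy sketch below models only the data a TMA load leaves in shared memory. It does not model the asynchronous copy or the barrier. `tma_like_load` and `k_padded_matmul` are hypothetical names used for illustration.

```python
import numpy as np

def tma_like_load(X, coords, block_shape):
    """Return the tile a TMA load would leave in shared memory (sketch).

    Copies the in-bounds part of X starting at coords=(row, col) and
    zero-fills everything that falls outside X. Models the data only,
    not the asynchronous copy or its mbarrier signal.
    """
    r0, c0 = coords
    bm, bk = block_shape
    tile = np.zeros((bm, bk), dtype=X.dtype)
    rows = max(0, min(bm, X.shape[0] - r0))   # in-bounds rows
    cols = max(0, min(bk, X.shape[1] - c0))   # in-bounds columns
    tile[:rows, :cols] = X[r0:r0 + rows, c0:c0 + cols]
    return tile

def k_padded_matmul(A, B, BK=64):
    """Accumulate A @ B using only full BK-wide tiles, as the kernel does."""
    M, K = A.shape
    N = B.shape[1]
    acc = np.zeros((M, N), dtype=np.float32)
    for off_k in range(0, K, BK):
        a = tma_like_load(A, (0, off_k), (M, BK))   # padded columns of A are 0
        b = tma_like_load(B, (off_k, 0), (BK, N))   # matching rows of B are 0 too
        acc += a.astype(np.float32) @ b.astype(np.float32)
    return acc

rng = np.random.default_rng(1)
A = rng.standard_normal((8, 70)).astype(np.float32)   # K = 70: the second chunk is ragged
B = rng.standard_normal((70, 5)).astype(np.float32)
print(np.allclose(k_padded_matmul(A, B), A @ B, atol=1e-4))  # True
```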
🧠 GPU concept: Tensor Memory (TMEM)
On Hopper and earlier, tensor-core accumulators lived in registers, spread across the threads of a warp (or warpgroup). Blackwell adds Tensor Memory: a dedicated on-chip memory physically next to the tensor cores, addressed as a 2-D (rows × columns) array. The big BM × BN fp32 accumulator now sits in TMEM instead of hogging the register file, which frees registers and makes larger accumulator tiles practical. You allocate it explicitly (allocate_tensor_memory) and copy the results out to registers with acc.load(...) when you're done.
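A back-of-the-envelope estimate shows why this matters. The per-SM capacities below are our own assumptions, taken from NVIDIA's public descriptions: a register file of 64K 32-bit registers, and Tensor Memory organized as 128 lanes × 512 columns of 32-bit cells. Treat the result as an estimate. On those numbers, one CTA's 128 × 256 fp32 accumulator would take half the register file. In TMEM it uses 256 of the 512 columns and leaves the registers untouched.

```python
# Per-SM capacities, as we understand them from NVIDIA's public docs (assumptions).
REGISTER_FILE_BYTES = 64 * 1024 * 4         # 64K 32-bit registers
TMEM_LANES, TMEM_COLUMNS = 128, 512         # Tensor Memory: 128 lanes x 512 columns
TMEM_BYTES = TMEM_LANES * TMEM_COLUMNS * 4  # each cell holds 32 bits

def accumulator_footprint(cta_m=128, tile_n=256, acc_bytes=4):
    """Return (bytes, TMEM columns) for one CTA's fp32 accumulator tile.

    Assumes a 32-bit accumulator with col_stride=1: row i of the tile lives
    in TMEM lane i, and each fp32 element occupies one 32-bit column.
    """
    assert cta_m <= TMEM_LANES and acc_bytes == 4
    nbytes = cta_m * tile_n * acc_bytes
    columns = tile_n
    return nbytes, columns

nbytes, columns = accumulator_footprint()
print(f"{nbytes // 1024} KB = {nbytes / REGISTER_FILE_BYTES:.0%} of the register file")
print(f"{columns} of {TMEM_COLUMNS} TMEM columns")
```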
🧠 GPU concept: tcgen05 — the 5th-gen tensor core MMA
Tensor cores are the units that do the actual D = A·B + C on small matrix tiles, and tcgen05 is Blackwell's instruction family for them. Two things make it special here. First, it reads its operands straight from shared memory (no manual register staging) and writes the accumulator to TMEM. Second, it can run in 2-CTA mode, where two thread blocks cooperate on one larger MMA (the row split described above).
🧠 GPU concept: CTAs and clusters
A CTA (Cooperative Thread Array) is what CUDA calls a thread block: a group of warps that share SMEM and run on one SM. Hopper introduced the thread-block cluster, a small group of CTAs (here, 2) that run on neighboring SMs, can read each other's shared memory, and can share barriers. This kernel launches with num_ctas=2, so each grid program is a 2-CTA cluster.
With that vocabulary, the kernel reads like prose. Let’s go.
The kernel, section by section
1. Deriving the tile shapes
```python
@gluon.jit
def _gemm_2cta_kernel(a_desc, b_desc, c_desc, NB: gl.constexpr, num_warps: gl.constexpr):
    cluster_m: gl.constexpr = a_desc.block_type.shape[0]  # BM = 256 (rows for the whole cluster)
    BK: gl.constexpr = a_desc.block_type.shape[1]         # 64 (K-chunk size)
    tile_n: gl.constexpr = b_desc.block_type.shape[1]     # BN (output columns)
    cta_m: gl.constexpr = cluster_m // 2                  # 128 (rows each CTA owns)
    dtype: gl.constexpr = a_desc.dtype
    K = a_desc.shape[1]
    num_k = (K + BK - 1) // BK                            # number of K-steps to loop over
```
The kernel reads its tile shapes out of the descriptors instead of taking them as arguments — the host already baked BM/BK/BN into the TMA descriptors. The key line is cta_m = cluster_m // 2: the cluster owns 256 rows, each of its two CTAs owns 128. num_k is how many BK-wide slices of the K dimension we’ll stream through.
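The ceiling division matters for the ragged shapes tested later. Here is a quick check in plain Python; the helper names are ours, not the kernel's. With K = 4000 and BK = 64 the loop runs 63 steps, and the last chunk reaches 32 columns past K, which TMA fills with zeros.

```python
def k_steps(K, BK=64):
    """Number of BK-wide K-chunks the main loop visits (the kernel's num_k)."""
    return (K + BK - 1) // BK

def padded_columns(K, BK=64):
    """K-columns of the last chunk that fall past K (TMA zero-fills them)."""
    return k_steps(K, BK) * BK - K

for K in (4096, 4000, 128):
    print(K, k_steps(K), padded_columns(K))
```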
2. Picking this program’s output tile
```python
pid_m = gl.program_id(0); pid_n = gl.program_id(1)
off_m = pid_m * cluster_m; off_n = pid_n * tile_n
```
(off_m, off_n) is the top-left corner of the BM × BN block of C this cluster is responsible for. Every TMA copy below will be relative to these offsets.
3. Allocating the staging buffers and barriers
```python
a_bufs = gl.allocate_shared_memory(dtype, [NB] + a_desc.block_type.shape, a_desc.layout)
b_bufs = gl.allocate_shared_memory(dtype, [NB] + b_desc.block_type.shape, b_desc.layout)
ready = mbarrier.allocate_mbarrier(batch=NB, two_ctas=True)
empty = mbarrier.allocate_mbarrier(batch=NB)
cnt: gl.constexpr = tcgen05_mma_barrier_count([a_bufs.index(0), b_bufs.index(0)], False)
for i in gl.static_range(NB):
    mbarrier.init(ready.index(i), count=1)
    mbarrier.init(empty.index(i), count=cnt)
```
This is the heart of the pipeline setup. We allocate NB copies of the A-tile and B-tile buffers in shared memory — a ring buffer with NB slots (default NB=3, i.e. triple-buffering). While the tensor cores chew on slot 0, TMA can be filling slots 1 and 2.
Two sets of mbarriers coordinate the ring:
- ready[i] flips when the TMA load into slot i has fully arrived. The MMA waits on this before consuming. (two_ctas=True because in 2-CTA mode both CTAs must see the data.)
- empty[i] flips when the MMA is done with slot i, so TMA is allowed to overwrite it with the next K-tile.
cnt is how many arrivals the empty barrier should expect per MMA — the helper tcgen05_mma_barrier_count computes it from the operand shapes so we don’t hard-code it.
🧠 GPU concept: mbarriers and the producer/consumer dance
An mbarrier is a hardware barrier sitting in shared memory with a phase bit that flips each time the expected number of arrivals lands. When it has been armed with expect, the expected number of transaction bytes must land too. Async engines (TMA, tensor cores) “arrive” on it when they finish, and threads “wait” on a phase. This is how you build a lock-free producer/consumer pipeline: the producer (TMA) signals ready, the consumer (MMA) signals empty, and nobody spins on a global lock. It's the GPU-native version of a bounded queue.
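The same handshake is easy to see on the CPU. This sketch swaps in `threading.Semaphore` for the mbarriers, so it illustrates the ready/empty protocol over an NB-slot ring rather than mbarrier phase semantics. The producer thread plays TMA and the consumer thread plays the MMA. `pipelined_sum` is a hypothetical name.

```python
import threading

def pipelined_sum(chunks, NB=3):
    """Sum every chunk, passing them through an NB-slot ring buffer.

    CPU analogy only: semaphores stand in for the mbarriers (a counter
    instead of a phase bit). ready[i] means "slot i holds data", empty[i]
    means "slot i may be overwritten".
    """
    slots = [None] * NB
    ready = [threading.Semaphore(0) for _ in range(NB)]   # no slot is loaded yet
    empty = [threading.Semaphore(1) for _ in range(NB)]   # every slot starts free
    total = [0]

    def producer():                        # plays the role of TMA
        for k, chunk in enumerate(chunks):
            buf = k % NB
            empty[buf].acquire()           # wait until the consumer released the slot
            slots[buf] = list(chunk)       # "copy" the tile into the slot
            ready[buf].release()           # signal: slot is full

    def consumer():                        # plays the role of the MMA
        for k in range(len(chunks)):
            buf = k % NB
            ready[buf].acquire()           # wait until the slot is full
            total[0] += sum(slots[buf])    # "compute" on the tile
            empty[buf].release()           # signal: slot may be reused

    threads = [threading.Thread(target=producer), threading.Thread(target=consumer)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return total[0]

print(pipelined_sum([[i, i + 1] for i in range(10)]))
```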
4. Allocating the accumulator in Tensor Memory
```python
acc_layout: gl.constexpr = TensorMemoryLayout((cta_m, tile_n), col_stride=1,
                                              cta_split_num=(2, 1), two_ctas=True)
acc = allocate_tensor_memory(gl.float32, [cluster_m, tile_n], acc_layout)
a_pc: gl.constexpr = a_desc.block_type.nbytes // 2  # bytes of A per CTA (half the tile)
b_pc: gl.constexpr = b_desc.block_type.nbytes // 2  # bytes of B per CTA (half the tile)
```
The accumulator is a 256 × BN fp32 tile in TMEM. cta_split_num=(2, 1) says: split it across 2 CTAs along the M dimension, 1 along N — so each CTA physically holds its own 128 × BN half, but the pair addresses it as one 256 × BN logical tile. Accumulating in fp32 (even though A/B are fp16/bf16) is what keeps the result accurate.
a_pc / b_pc are “bytes per CTA” — half of a full tile, because in 2-CTA mode each CTA’s TMA only pulls its half. We’ll feed these to mbarrier.expect so the barrier knows exactly how many bytes constitute “done.”
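Here is a quick sketch of the numbers `mbarrier.expect` receives, computed in plain Python and assuming 2-byte (fp16/bf16) elements. `expect_bytes` is a hypothetical helper that mirrors a_pc + b_pc. With the default tiles, each CTA expects 16 KB of A plus 16 KB of B per stage.

```python
def expect_bytes(BM=256, BN=256, BK=64, elem_bytes=2):
    """Bytes one CTA arms ready[i] with per stage: (a_pc, b_pc, a_pc + b_pc).

    Each CTA's TMA pulls half of the BM x BK A tile and half of the
    BK x BN B tile, matching a_pc / b_pc in the kernel.
    """
    a_pc = BM * BK * elem_bytes // 2
    b_pc = BK * BN * elem_bytes // 2
    return a_pc, b_pc, a_pc + b_pc

print(expect_bytes())          # default 256 x 256 x 64 tiles
print(expect_bytes(BN=128))    # one of the BN=128 autotuner configs
```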
5. Priming the pipeline (prologue)
```python
for i in gl.static_range(NB):
    mbarrier.expect(ready.index(i), a_pc + b_pc)
    tma.async_copy_global_to_shared(a_desc, [off_m, i * BK], ready.index(i), a_bufs.index(i))
    tma.async_copy_global_to_shared(b_desc, [i * BK, off_n], ready.index(i), b_bufs.index(i))
```
Before the main loop, we kick off the first NB loads so the pipeline starts full. For each slot i: arm the barrier with expect (“await this many bytes”), then fire two async TMA copies — A’s tile at K-offset i*BK and B’s tile at the same offset. These calls return immediately; the copies run in the background.
6. The main loop — load/compute overlap
This is where the magic happens:
```python
for k in range(num_k):
    buf = k % NB; ph = (k // NB) & 1   # ring slot and barrier parity for this step
    mbarrier.wait(ready.index(buf), ph, deps=[a_bufs.index(buf), b_bufs.index(buf)])
    tcgen05_mma(a_bufs.index(buf), b_bufs.index(buf), acc, use_acc=(k > 0),
                mbarriers=[empty.index(buf)])
    kk = k + NB
    if kk < num_k:
        mbarrier.wait(empty.index(buf), ph)
        mbarrier.expect(ready.index(buf), a_pc + b_pc)
        tma.async_copy_global_to_shared(a_desc, [off_m, kk * BK], ready.index(buf), a_bufs.index(buf))
        tma.async_copy_global_to_shared(b_desc, [kk * BK, off_n], ready.index(buf), b_bufs.index(buf))
```
Trace it for a K dimension split into 6 chunks with NB=3 buffers. The prologue loads chunks 0, 1, and 2 into slots 0, 1, and 2. Each MMA then finds its slot already filled. Once MMA k has finished with slot k % 3, the loop refills that slot with chunk k + 3. Every chunk the tensor cores consume was requested from HBM three steps earlier, or in the prologue.
In steady state the tensor cores don't wait for HBM. By the time the MMA needs slot k % NB, TMA has had NB steps to fill it, so as long as NB steps of compute cover one load's latency, the data is already there. That overlap is the entire point, because a GEMM that stalls on memory runs at a fraction of peak. A couple of details from the code:
- use_acc=(k > 0) is a neat trick. On the first K-step we overwrite acc, so there is no need to zero it first. On every step after that we accumulate.
- The if kk < num_k block is the prefetch. As soon as the MMA is issued, we queue the load NB steps ahead into the same slot. We do this only after mbarrier.wait(empty[buf], ph) confirms that the MMA reading that slot (the one just issued) has finished.
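You can check the slot and phase bookkeeping without a GPU. This plain-Python sketch replays the loop's issue order (prologue loads, MMAs, guarded prefetches) and records the buffer and parity each step uses. It models program order only, not timing. `pipeline_schedule` and `check_schedule` are our own names.

```python
def pipeline_schedule(num_k, NB=3):
    """Replay the kernel's issue order as a list of (event, k, buf, ph) tuples.

    "load"       -> TMA copy of K-chunk k into slot buf (prologue or prefetch)
    "mma"        -> tcgen05_mma on chunk k after waiting on ready[buf] with parity ph
    "wait_empty" -> the wait on empty[buf] (parity ph) that guards a prefetch
    """
    events = []
    for i in range(NB):                     # prologue: the kernel fills all NB slots
        events.append(("load", i, i, None))
    for k in range(num_k):
        buf, ph = k % NB, (k // NB) & 1     # same formulas as the kernel
        events.append(("mma", k, buf, ph))
        kk = k + NB
        if kk < num_k:                      # refill this slot with the chunk NB steps ahead
            events.append(("wait_empty", k, buf, ph))
            events.append(("load", kk, buf, None))
    return events

def check_schedule(events, NB=3):
    """Assert every MMA reads the chunk it expects, i.e. no slot was overwritten early."""
    slot = [None] * NB                      # which K-chunk each slot currently holds
    for ev, k, buf, _ in events:
        if ev == "load":
            slot[buf] = k
        elif ev == "mma":
            assert slot[buf] == k, f"MMA {k} would read chunk {slot[buf]}"
    return True

for ev in pipeline_schedule(6, NB=3):
    print(ev)
```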
🧠 GPU concept: software pipelining
Memory is slow (hundreds of cycles); compute is fast. If you load-then-compute serially, the tensor cores idle during every load. Software pipelining breaks that dependency by keeping NB iterations “in flight” at once: iteration k's compute overlaps the loads for iterations k+1..k+NB-1. More stages (NB) hide more latency but cost more shared memory, which is exactly why NB is one of the autotuned knobs.
7. Draining the pipeline (epilogue)
```python
last = (num_k - 1) % NB; lph = ((num_k - 1) // NB) & 1
mbarrier.wait(empty.index(last), lph)
```
After the loop issues the final MMA, we must wait for it to actually finish before touching acc. We compute which slot and phase the last K-step used and wait on its empty barrier. Now the accumulator holds the complete C tile.
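You can cross-check the parity arithmetic by counting. Slot `last` is reused once every NB steps, and waiting for the j-th (0-based) completion of a barrier uses parity j & 1. This small plain-Python sketch, with helper names of our own, confirms that the closed form matches the count for many K sizes, including num_k < NB.

```python
def last_slot_and_phase(num_k, NB=3):
    """The slot and parity the epilogue waits on (same formulas as the kernel)."""
    return (num_k - 1) % NB, ((num_k - 1) // NB) & 1

def last_slot_and_phase_by_counting(num_k, NB=3):
    """Same answer, derived by counting how many MMAs used the final slot."""
    last = (num_k - 1) % NB
    uses = sum(1 for k in range(num_k) if k % NB == last)
    # Waiting for the j-th (0-based) completion of a barrier uses parity j & 1.
    return last, (uses - 1) & 1

for num_k in range(1, 200):
    for NB in (2, 3, 4, 6):
        assert last_slot_and_phase(num_k, NB) == last_slot_and_phase_by_counting(num_k, NB)
print(last_slot_and_phase(6, 3))
```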
8. Writing C back out
```python
reg_layout: gl.constexpr = get_tmem_reg_layout(gl.float32, [cluster_m, tile_n], acc_layout,
                                               num_warps, cga_layout=[(1, 0)])
out = acc.load(reg_layout)                                       # TMEM → registers
c_smem = gl.allocate_shared_memory(dtype, c_desc.block_type.shape, c_desc.layout)
c_smem.store(out.to(dtype))                                      # registers → SMEM, cast fp32 → fp16/bf16
fence_async_shared()
tma.async_copy_shared_to_global(c_desc, [off_m, off_n], c_smem)  # SMEM → HBM
tma.store_wait(pendings=0)                                       # make sure the store finished
```
The output journey reverses the input one. The tile moves from TMEM to registers (acc.load), then from registers to shared memory (c_smem.store, after casting with .to(dtype)), and finally from shared memory to global memory through a TMA store.
fence_async_shared() ensures the SMEM writes are visible to the TMA engine before the store launches, and tma.store_wait(pendings=0) blocks until the store drains, so the kernel doesn’t exit with an in-flight copy.
The host side
Setting up descriptors and launching
```python
import torch
import triton
from triton.experimental.gluon import language as gl
from triton.experimental.gluon.nvidia.hopper import TensorDescriptor

def gemm(A, B, C=None, *, BM=256, BN=256, BK=64, NB=3, num_warps=4):
    """C = A @ B on B200 via 2-CTA tcgen05. A,B 2-D fp16/bf16. Returns C.
    BM must be 2*instrM (256); tcgen05 caps the MMA N (=BN) at 256."""
    assert A.dtype == B.dtype and A.dtype in (torch.float16, torch.bfloat16)
    M, K = A.shape; K2, N = B.shape; assert K == K2
    if C is None:
        C = torch.empty(M, N, device=A.device, dtype=A.dtype)
    gd: gl.constexpr = gl.float16 if A.dtype == torch.float16 else gl.bfloat16
    # Swizzled SMEM layouts; cga_layout says how each tile is split across the 2 CTAs.
    a_layout = gl.NVMMASharedLayout.get_default_for([BM, BK], gd, cga_layout=[(1, 0)])
    b_layout = gl.NVMMASharedLayout.get_default_for([BK, BN], gd, cga_layout=[(0, 1)])
    c_layout = gl.NVMMASharedLayout.get_default_for([BM, BN], gd, cga_layout=[(1, 0)])
    # TMA descriptors: the tensor plus the tile shape each copy moves.
    a_desc = TensorDescriptor.from_tensor(A, [BM, BK], a_layout)
    b_desc = TensorDescriptor.from_tensor(B, [BK, BN], b_layout)
    c_desc = TensorDescriptor.from_tensor(C, [BM, BN], c_layout)
    grid = (triton.cdiv(M, BM), triton.cdiv(N, BN))  # one 2-CTA cluster per output tile
    _gemm_2cta_kernel[grid](a_desc, b_desc, c_desc, NB=NB, num_warps=num_warps, num_ctas=2)
    return C
```
A few things worth calling out:
- NVMMASharedLayout picks a swizzled shared-memory layout that the tensor cores can read without bank conflicts. The cga_layout argument tells it how the tile is split across the 2-CTA cluster: A and C are split along M ([(1,0)]), B along N ([(0,1)]).
- TensorDescriptor.from_tensor(A, [BM, BK], ...) builds the TMA descriptor: “this tensor, tiled into BM×BK blocks.” This object is what the kernel's tma.async_copy_* calls dereference.
- The launch passes num_ctas=2, and that single argument is what turns each grid point into a 2-CTA cluster. The constraints in the docstring (BM=256, BN≤256) come straight from the fixed shapes of the tcgen05 2-CTA instruction.
Autotuning
The fastest (BM, BN, BK, NB) depends on the matrix shape, so there’s a tiny autotuner:
```python
_SPACE = [(256, 256, 64, 2), (256, 256, 64, 3), (256, 256, 64, 4),
          (256, 128, 64, 3), (256, 128, 64, 4), (256, 128, 64, 6)]
_BEST = {}

def gemm_auto(A, B, C=None):
    """Autotuned 2-CTA GEMM. First call for a shape sweeps `_SPACE` and caches."""
    M, K = A.shape; N = B.shape[1]; key = (M, N, K, A.dtype)
    cfg = _BEST.get(key)
    if cfg is None:
        # 128 MB scratch buffer; writing it evicts A and B from L2.
        flush = torch.empty(128 * 1024 * 1024, device=A.device, dtype=torch.int8)
        best = None
        for (BM, BN, BK, NB) in _SPACE:
            try:
                for _ in range(5): gemm(A, B, C, BM=BM, BN=BN, BK=BK, NB=NB)  # warmup + compile
                torch.cuda.synchronize()
                t = 1e9
                for _ in range(3):
                    flush.zero_()
                    s = torch.cuda.Event(True); e = torch.cuda.Event(True); s.record()
                    for _ in range(20): gemm(A, B, C, BM=BM, BN=BN, BK=BK, NB=NB)
                    e.record(); torch.cuda.synchronize()
                    t = min(t, s.elapsed_time(e))  # keep the fastest of the repeats
                if best is None or t < best[0]: best = (t, (BM, BN, BK, NB))
            except Exception:
                pass  # config failed to compile or launch; skip it
        if best is None:
            raise RuntimeError(f"no config in _SPACE ran for shape {key}")
        cfg = best[1]; _BEST[key] = cfg
    BM, BN, BK, NB = cfg
    return gemm(A, B, C, BM=BM, BN=BN, BK=BK, NB=NB)
```
Two details that separate a toy benchmark from an honest one:
- flush.zero_() on a 128 MB buffer before each timed batch wipes the L2 cache. Without it, the timed calls would find A/B already cached from warmup and report fantasy throughput. The flush runs once per batch of 20 back-to-back calls, so only the first call in each batch is guaranteed to start with a cold L2.
- It uses CUDA events (not Python time) and takes the min over repeats, which rejects noise from clock-boost warmup and scheduler jitter.
The try/except is there because some configs (e.g. a BN that doesn’t divide cleanly, or one that overflows shared memory) simply fail to compile or launch — we skip them and keep the ones that work.
Correctness against ragged shapes
```python
shapes = [(256, 128, 128), (4096, 4096, 4096), (2048, 2048, 2048),
          (4000, 4096, 4096), (4096, 4096, 4000), (1536, 6144, 2048)]
```
Notice 4000×4096×4096 and 4096×4096×4000: dimensions that aren’t multiples of the tile size. These exist on purpose — they test that TMA’s boundary handling (zero-padding the out-of-range part of edge tiles) is correct. The check compares against an fp32 reference and asserts the relative error is below a dtype-appropriate tolerance:
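```python
# A, B: fp16/bf16 inputs; C: the kernel's output; dt: the input dtype.
ref = A.float() @ B.float()
rel = ((C.float() - ref).abs().max() / ref.abs().max()).item()
tol = 5e-2 if dt == torch.float16 else 1e-1
assert rel < tol
```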
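The tolerances are deliberately loose. The reference uses the same fp16/bf16 inputs, so the remaining differences come from two sources. One is rounding the fp32 result back to fp16/bf16 on output: bf16 keeps only 7 explicit mantissa bits, fp16 keeps 10. The other is a different fp32 summation order. A correct kernel lands well inside these bounds. The fp32 accumulator is what keeps the error this small, even though each output sums thousands of fp16/bf16 products.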
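Here is a rough NumPy illustration of that last point, with two simplifications. NumPy has no bf16, so this uses fp16 only. The "fp16 accumulate" path is a crude emulation that rounds the running sum to fp16 after every K step; no real kernel is being modeled. `rel_err` reuses the metric from the check above, against a float64 reference.

```python
import numpy as np

def rel_err(C, ref):
    """Max abs error divided by max abs reference, as in the correctness check."""
    return float(np.abs(C.astype(np.float64) - ref).max() / np.abs(ref).max())

def matmul_fp16_accumulate(A, B):
    """Crude emulation: the running sum is rounded to fp16 after every K step."""
    acc = np.zeros((A.shape[0], B.shape[1]), dtype=np.float16)
    for k in range(A.shape[1]):
        acc = (acc + np.outer(A[:, k], B[k, :])).astype(np.float16)
    return acc

def matmul_fp32_accumulate(A, B):
    """fp16 inputs, fp32 running sum, one rounding to fp16 at the end (like the kernel)."""
    return (A.astype(np.float32) @ B.astype(np.float32)).astype(np.float16)

rng = np.random.default_rng(0)
A = rng.standard_normal((64, 2048)).astype(np.float16)
B = rng.standard_normal((2048, 64)).astype(np.float16)
ref = A.astype(np.float64) @ B.astype(np.float64)

err16 = rel_err(matmul_fp16_accumulate(A, B), ref)
err32 = rel_err(matmul_fp32_accumulate(A, B), ref)
print(f"fp16 accumulate: {err16:.1e}   fp32 accumulate: {err32:.1e}")
```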
Why this design is fast
Putting it together, every performance-critical idea in modern GEMM is present in this one file:
| Technique | Where | Pays off by… |
|---|---|---|
| Tiling | BM×BN×BK blocks | keeping working set in fast on-chip memory |
| TMA async copy | tma.async_copy_* | freeing threads from address math; overlapping DMA |
| Software pipelining | NB-slot ring + mbarriers | hiding HBM latency behind tensor-core compute |
| 2-CTA tcgen05 | num_ctas=2, cta_split_num | one MMA spanning two SMs → bigger tiles, more reuse |
| TMEM accumulation | allocate_tensor_memory | a huge fp32 accumulator without register pressure |
| fp32 accumulate | gl.float32 acc | accuracy despite fp16/bf16 inputs |
| L2-flushed autotuning | gemm_auto | picking the best tiling per shape, measured honestly |
The benchmark in main() reports throughput as a percentage of cuBLAS — which is the right yardstick, since cuBLAS is the heavily-tuned vendor library. Getting within striking distance of it with a readable Python kernel is the whole appeal of Gluon.
Where to go next
- Read the Modal GPU Glossary end-to-end — it’s the best free reference for the hardware terms used above (SM, tensor core, thread-block cluster, shared memory).
- Revisit the main loop with NB in mind: with fewer buffers, the MMA would stall waiting on TMA loads.
- Compare this with a plain-Triton tl.dot GEMM to see exactly which decisions the compiler was making for you, and which ones Gluon hands back.
The full kernel is on GitHub: Tanmaypatil123/kronos · kernels/gluon_gemm.py. Clone it, run it on a B200, and watch the tensor cores stay fed.