
Inside a Blackwell 2-CTA GEMM: A Gluon Kernel Tour

TL;DR: We tour a single GEMM kernel (C = A @ B) written in Gluon for the NVIDIA Blackwell B200. It uses two CTAs cooperating on one matrix-multiply (tcgen05), pulls tiles from HBM with the Tensor Memory Accelerator (TMA), accumulates in Tensor Memory (TMEM), and overlaps loads with compute using an mbarrier-driven software pipeline. There are three interactive diagrams below — step through them. By the end you’ll be able to read every line of the kernel and know why it’s there.

The kernel lives here: kernels/gluon_gemm.py. This post explains it as a learning exercise — if you’ve written a Triton kernel before but never touched Blackwell’s new primitives, this is for you.

What problem are we solving?

A GEMM — GEneral Matrix Multiply — computes:

$$C = A \cdot B, \qquad A \in \mathbb{R}^{M\times K},\; B \in \mathbb{R}^{K\times N},\; C \in \mathbb{R}^{M\times N}$$

Each output element is a dot product over the shared K dimension:

$$C_{ij} = \sum_{k=1}^{K} A_{ik}\, B_{kj}$$

This is the single most important kernel in deep learning: every linear layer, attention projection, and MLP is a GEMM. So the entire game is keeping the tensor cores fed with data fast enough that they never stall. The kernel we’re studying is a hand-pipelined, hardware-accelerated answer to that on Blackwell.

🧠 GPU concept: why tiling?

Matrices don’t fit in fast memory. The trick of every fast GEMM is tiling: chop A, B, and C into small blocks that do fit in on-chip shared memory, multiply the blocks, and accumulate. We stream K in chunks of BK so we only ever hold a thin slice of A and B on-chip at once. The output tile (BM × BN) stays resident and gets accumulated into across all K steps.
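To make the tiling idea concrete, here is a minimal pure-Python sketch. It is not the GPU kernel: the function name `tiled_matmul` and the toy tile sizes are assumptions made for illustration. The loop structure mirrors the description above: one output tile at a time, updated once per BK-wide slice of K.

```python
def tiled_matmul(A, B, BM=2, BN=2, BK=2):
    """Reference tiled GEMM on nested lists. Tile sizes are toy values."""
    M, K, N = len(A), len(A[0]), len(B[0])
    C = [[0.0] * N for _ in range(M)]
    for m0 in range(0, M, BM):              # row band of C
        for n0 in range(0, N, BN):          # column band of C
            # The output tile stays "resident" while K is streamed in BK chunks.
            for k0 in range(0, K, BK):
                for i in range(m0, min(m0 + BM, M)):
                    for j in range(n0, min(n0 + BN, N)):
                        for k in range(k0, min(k0 + BK, K)):
                            C[i][j] += A[i][k] * B[k][j]
    return C


A = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]
B = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
C = tiled_matmul(A, B)
```

The shapes here (3×3 times 3×2) are deliberately not multiples of the tile size, so the `min(...)` bounds play the role that boundary handling plays in the real kernel.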

This kernel runs each output tile on a cluster of two CTAs that split the tile’s rows in half. The diagram shows who owns what (defaults BM=256, BN=256, BK=64):

[Diagram: 2-CTA tiling, who computes what. A is [BM × BK], with its rows split between CTA 0 and CTA 1. B is [BK × BN], shared by both CTAs and shown as a left half and a right half. The output C = A @ B is [BM × BN]: CTA 0 produces the top 128 output rows, CTA 1 the bottom 128.]

Each grid program is a 2-CTA cluster. The 256-row output tile is split: CTA 0 owns the top 128 rows, CTA 1 the bottom 128. They cooperate on one tcgen05 MMA, so the B tile is shared between them.

The launch grid has one program (one cluster) per output tile:

grid = (triton.cdiv(M, BM), triton.cdiv(N, BN))

So program_id(0) picks the row-band of C and program_id(1) picks the column-band.
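As a rough illustration of that mapping, the sketch below computes the grid and one edge tile’s extent for a ragged shape. The helper names `cdiv` and `tile_extent` are hypothetical stand-ins written for this example (the first mimics the rounding of `triton.cdiv`); nothing here touches the GPU.

```python
def cdiv(a, b):
    """Ceiling division, the same rounding triton.cdiv performs."""
    return (a + b - 1) // b


def tile_extent(pid_m, pid_n, M, N, BM=256, BN=256):
    """Return (off_m, off_n, valid_rows, valid_cols) for one cluster's tile."""
    off_m, off_n = pid_m * BM, pid_n * BN
    return off_m, off_n, min(BM, M - off_m), min(BN, N - off_n)


M, N = 4000, 4096
grid = (cdiv(M, 256), cdiv(N, 256))
last_row_band = tile_extent(grid[0] - 1, 0, M, N)
```

For M=4000 the last row band starts at row 3840 and only 160 of its 256 rows are real; the rest is the out-of-bounds region that the hardware has to handle.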


Why Gluon (and not plain Triton)?

You’ve probably seen Triton kernels: you write tl.load, tl.dot, tl.store, and the compiler decides how to lay out data in shared memory, how to pipeline, and how to schedule warps. That’s wonderful for productivity but it hides the hardware.

Gluon is Triton’s experimental lower-level dialect. It exposes the Blackwell primitives directly: you allocate shared-memory buffers yourself, you place memory barriers yourself, you issue the async copy and the tensor-core MMA as explicit instructions, and you own the pipeline. It’s closer to writing CUTLASS than to writing Triton — but in Python.

import torch     # host-side tensors, CUDA events (used by the launcher and autotuner)
import triton    # triton.cdiv for the launch grid
from triton.experimental import gluon
from triton.experimental.gluon import language as gl
from triton.experimental.gluon.nvidia.hopper import TensorDescriptor
from triton.experimental.gluon.language.nvidia.blackwell import (
    TensorMemoryLayout, allocate_tensor_memory, tma, mbarrier,
    tcgen05_mma, tcgen05_mma_barrier_count, fence_async_shared, get_tmem_reg_layout,
)

That import block is a great table of contents for what’s special here:

| Primitive | What it is |
| --- | --- |
| TensorDescriptor | A handle describing a tensor + tile shape for TMA copies |
| tma | The Tensor Memory Accelerator: async bulk DMA between HBM and SMEM |
| mbarrier | Hardware async barriers to coordinate producers/consumers |
| allocate_tensor_memory / TensorMemoryLayout | Blackwell’s new Tensor Memory (TMEM) for accumulators |
| tcgen05_mma | The Blackwell 5th-gen tensor core matrix-multiply instruction |

Let’s introduce each one before we read the kernel.


A 60-second tour of the Blackwell memory path

The whole kernel is a choreography of data moving through a hierarchy, from slow-and-big to fast-and-small: HBM (global, well over 100 GB on a B200) → SMEM (shared, ~228 KB per SM) → TMEM (tensor memory) → back out. Several pieces of that path are brand-new on Blackwell or unfamiliar from textbook CUDA, so let’s give each a short box.

🧠 GPU concept: TMA — the Tensor Memory Accelerator

Classically, to load a tile you’d have every thread in the block compute an address and issue a load, then __syncthreads(). That burns registers and instruction slots. TMA (introduced on Hopper) is a dedicated copy engine: you hand it a descriptor (“this 256×64 tile of A, starting at these coordinates”) and it streams the whole tile from HBM into shared memory asynchronously, in the background, with a single instruction. It even handles out-of-bounds tiles (ragged shapes) by zero-filling — which is why this kernel passes on sizes like 4000×4096. When the copy finishes, TMA signals an mbarrier.
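Why is zero-filling a safe way to handle a ragged K dimension? A small sketch of the arithmetic, with hypothetical helper names (`dot`, `pad_to_multiple`) that exist only for this example: padding both operands with zeros up to the next multiple of BK adds only zero products to the sum. This models the math, not the TMA engine itself.

```python
def dot(a, b):
    """Plain dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def pad_to_multiple(v, block):
    """Zero-pad a vector so its length is a multiple of `block`."""
    return v + [0.0] * (-len(v) % block)


BK = 4
a_row = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]   # K = 6, not a multiple of BK
b_col = [6.0, 5.0, 4.0, 3.0, 2.0, 1.0]
exact = dot(a_row, b_col)
padded = dot(pad_to_multiple(a_row, BK), pad_to_multiple(b_col, BK))
```

Both sums agree, which is the property the kernel relies on when the last K-slice hangs past the edge of the matrix.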

🧠 GPU concept: Tensor Memory (TMEM)

On Hopper and earlier, tensor-core accumulators lived in registers, spread across a warp. Blackwell adds Tensor Memory — a dedicated on-chip memory bank physically next to the tensor cores, addressed in a 2-D (rows × columns) layout. The big BM × BN fp32 accumulator now sits in TMEM instead of hogging the register file, which frees registers and lets the MMA run wider. You allocate it explicitly (allocate_tensor_memory) and copy results out with acc.load(...) when you’re done.
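A back-of-the-envelope sketch of how much Tensor Memory the accumulator takes. The capacity figures (128 lanes × 512 columns of 32-bit cells per SM) are an assumption stated here for illustration, not something the kernel source confirms, and `tmem_columns_needed` is a hypothetical helper that assumes one accumulator row per lane.

```python
TMEM_LANES = 128      # assumed rows (lanes) of Tensor Memory per SM
TMEM_COLUMNS = 512    # assumed 32-bit columns per lane


def tmem_columns_needed(cta_m, tile_n, bytes_per_elem=4):
    """Columns used by a cta_m x tile_n accumulator, one row per lane."""
    assert cta_m <= TMEM_LANES
    return tile_n * bytes_per_elem // 4


cols = tmem_columns_needed(cta_m=128, tile_n=256)
fraction = cols / TMEM_COLUMNS
```

Under those assumptions, each CTA’s 128 × 256 fp32 half of the accumulator occupies half of its SM’s Tensor Memory.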

🧠 GPU concept: tcgen05 — the 5th-gen tensor core MMA

Tensor cores are the units that do the actual D = A·B + C on small matrix tiles. tcgen05 is Blackwell’s instruction family for them. Two things make it special here: it reads its operands straight from shared memory (no manual register staging), and it can run in 2-CTA mode, where two thread blocks cooperate on one larger MMA (the split you toggled in the diagram above).

🧠 GPU concept: CTAs and clusters

A CTA (Cooperative Thread Array) is what CUDA calls a thread block — a group of warps that share SMEM and run on one SM. Hopper introduced the thread-block cluster: a small group of CTAs (here, 2) that run on neighboring SMs and can read each other’s shared memory and share barriers. This kernel launches with num_ctas=2, so each grid program is a 2-CTA cluster.

With that vocabulary, the kernel reads like prose. Let’s go.


The kernel, section by section

1. Deriving the tile shapes

@gluon.jit
def _gemm_2cta_kernel(a_desc, b_desc, c_desc, NB: gl.constexpr, num_warps: gl.constexpr):
    cluster_m: gl.constexpr = a_desc.block_type.shape[0]   # BM = 256 (rows for the whole cluster)
    BK: gl.constexpr = a_desc.block_type.shape[1]          # 64  (K-chunk size)
    tile_n: gl.constexpr = b_desc.block_type.shape[1]      # BN  (output columns)
    cta_m: gl.constexpr = cluster_m // 2                   # 128 (rows each CTA owns)
    dtype: gl.constexpr = a_desc.dtype
    K = a_desc.shape[1]
    num_k = (K + BK - 1) // BK                             # number of K-steps to loop over

The kernel reads its tile shapes out of the descriptors instead of taking them as arguments — the host already baked BM/BK/BN into the TMA descriptors. The key line is cta_m = cluster_m // 2: the cluster owns 256 rows, each of its two CTAs owns 128. num_k is how many BK-wide slices of the K dimension we’ll stream through.

2. Picking this program’s output tile

    pid_m = gl.program_id(0); pid_n = gl.program_id(1)
    off_m = pid_m * cluster_m; off_n = pid_n * tile_n

(off_m, off_n) is the top-left corner of the BM × BN block of C this cluster is responsible for. Every TMA copy below will be relative to these offsets.

3. Allocating the staging buffers and barriers

    a_bufs = gl.allocate_shared_memory(dtype, [NB] + a_desc.block_type.shape, a_desc.layout)
    b_bufs = gl.allocate_shared_memory(dtype, [NB] + b_desc.block_type.shape, b_desc.layout)
    ready = mbarrier.allocate_mbarrier(batch=NB, two_ctas=True)
    empty = mbarrier.allocate_mbarrier(batch=NB)
    cnt: gl.constexpr = tcgen05_mma_barrier_count([a_bufs.index(0), b_bufs.index(0)], False)
    for i in gl.static_range(NB):
        mbarrier.init(ready.index(i), count=1)
        mbarrier.init(empty.index(i), count=cnt)

This is the heart of the pipeline setup. We allocate NB copies of the A-tile and B-tile buffers in shared memory — a ring buffer with NB slots (default NB=3, i.e. triple-buffering). While the tensor cores chew on slot 0, TMA can be filling slots 1 and 2.

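Two sets of mbarriers coordinate the ring. ready[i] is the producer-to-consumer signal: TMA arrives on it once slot i’s A and B tiles have landed in shared memory, and the MMA waits on it before reading the slot. empty[i] is the consumer-to-producer signal: the tensor core arrives on it when the MMA that read slot i has finished, and the loader waits on it before overwriting the slot.

cnt is how many arrivals the empty barrier should expect per MMA. The helper tcgen05_mma_barrier_count computes it from the operand buffers so we don’t hard-code it.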

🧠 GPU concept: mbarriers and the producer/consumer dance

An mbarrier is a hardware barrier sitting in shared memory with a phase bit that flips each time the expected number of arrivals lands. Async engines (TMA, tensor cores) “arrive” on it when they finish; threads “wait” on a phase. This is how you build a lock-free producer/consumer pipeline: the producer (TMA) signals ready, the consumer (MMA) signals empty, and nobody spins on a global lock. It’s the GPU-native version of a bounded queue.
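The phase bit is easier to grasp with a toy model. The class below is a software sketch, not the hardware or the Gluon API: `ToyMBarrier` and its methods are hypothetical names, and `is_complete` polls where a real wait would block.

```python
class ToyMBarrier:
    """Software model of an mbarrier: an arrival counter plus a phase bit."""

    def __init__(self, count):
        self.count = count        # arrivals needed to complete one phase
        self.pending = count
        self.phase = 0            # parity of the phase currently in progress

    def arrive(self):
        self.pending -= 1
        if self.pending == 0:     # phase complete: flip parity and re-arm
            self.phase ^= 1
            self.pending = self.count

    def is_complete(self, parity):
        """True once the phase with the given parity has finished."""
        return self.phase != parity


bar = ToyMBarrier(count=2)
before = bar.is_complete(0)       # nothing has arrived yet
bar.arrive()
after_one = bar.is_complete(0)    # still one arrival short
bar.arrive()
after_two = bar.is_complete(0)    # phase 0 done, barrier now in phase 1
```

Because the barrier re-arms itself, the same object can be reused every time its ring slot comes around again; the waiter only has to remember which parity it expects.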

4. Allocating the accumulator in Tensor Memory

    acc_layout: gl.constexpr = TensorMemoryLayout((cta_m, tile_n), col_stride=1,
                                                  cta_split_num=(2, 1), two_ctas=True)
    acc = allocate_tensor_memory(gl.float32, [cluster_m, tile_n], acc_layout)
    a_pc: gl.constexpr = a_desc.block_type.nbytes // 2     # bytes of A per CTA (half the tile)
    b_pc: gl.constexpr = b_desc.block_type.nbytes // 2

The accumulator is a 256 × BN fp32 tile in TMEM. cta_split_num=(2, 1) says: split it across 2 CTAs along the M dimension, 1 along N — so each CTA physically holds its own 128 × BN half, but the pair addresses it as one 256 × BN logical tile. Accumulating in fp32 (even though A/B are fp16/bf16) is what keeps the result accurate.

a_pc / b_pc are “bytes per CTA” — half of a full tile, because in 2-CTA mode each CTA’s TMA only pulls its half. We’ll feed these to mbarrier.expect so the barrier knows exactly how many bytes constitute “done.”
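A worked version of that byte count for the default tile sizes, as a hedged sketch: `bytes_per_cta` is a hypothetical helper, and it assumes 2-byte elements (fp16 or bf16) and an even split of each tile across the two CTAs.

```python
def bytes_per_cta(rows, cols, elem_bytes=2, num_ctas=2):
    """Bytes of one tile that land in each CTA's shared memory."""
    return rows * cols * elem_bytes // num_ctas


BM, BN, BK = 256, 256, 64
a_pc = bytes_per_cta(BM, BK)      # A tile is [BM, BK]
b_pc = bytes_per_cta(BK, BN)      # B tile is [BK, BN]
expected = a_pc + b_pc            # value passed to mbarrier.expect per stage
```

With these defaults each CTA expects 16 KiB of A plus 16 KiB of B per pipeline stage.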

5. Priming the pipeline (prologue)

    for i in gl.static_range(NB):
        mbarrier.expect(ready.index(i), a_pc + b_pc)
        tma.async_copy_global_to_shared(a_desc, [off_m, i * BK], ready.index(i), a_bufs.index(i))
        tma.async_copy_global_to_shared(b_desc, [i * BK, off_n], ready.index(i), b_bufs.index(i))

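Before the main loop, we kick off the first NB loads so the pipeline starts full. For each slot i: arm the barrier with expect (“await this many bytes”), then fire two async TMA copies, one for A’s tile at K-offset i*BK and one for B’s tile at the same K-offset. These calls return immediately; the copies run in the background. Note that the prologue always issues NB loads, even when num_k < NB; in that case the extra tiles lie entirely past the end of K, so TMA zero-fills them and the main loop never consumes them.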

6. The main loop — load/compute overlap

This is where the magic happens:

    for k in range(num_k):
        buf = k % NB; ph = (k // NB) & 1
        mbarrier.wait(ready.index(buf), ph, deps=[a_bufs.index(buf), b_bufs.index(buf)])
        tcgen05_mma(a_bufs.index(buf), b_bufs.index(buf), acc, use_acc=(k > 0),
                    mbarriers=[empty.index(buf)])
        kk = k + NB
        if kk < num_k:
            mbarrier.wait(empty.index(buf), ph)
            mbarrier.expect(ready.index(buf), a_pc + b_pc)
            tma.async_copy_global_to_shared(a_desc, [off_m, kk * BK], ready.index(buf), a_bufs.index(buf))
            tma.async_copy_global_to_shared(b_desc, [kk * BK, off_n], ready.index(buf), b_bufs.index(buf))

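Rather than describe it only in prose, step through it. The diagram below runs the kernel for a K split into 6 chunks with NB=3 buffers. The tensor cores (green) always have a buffer ready, because the TMA load (amber) for that buffer was issued three steps earlier:

[Interactive diagram: software pipeline, step by step. Three ring-buffer slots (buffer 0, buffer 1, buffer 2) each cycle through four states: loading (TMA), ready, MMA (compute), free.]

In steady state the tensor cores do not wait for HBM, because the TMA load for slot k % NB was issued NB steps before the MMA needs it. That overlap is the entire point: a GEMM that stalls on memory runs at a fraction of peak. A couple of details from the code:

ph = (k // NB) & 1 is the phase parity. Each slot’s barriers are reused every NB steps, and the parity tells wait which use of the barrier it is waiting for.

use_acc=(k > 0) makes the first MMA overwrite the accumulator instead of adding to it, so acc never needs to be zeroed explicitly.

The refill waits on empty[buf] before re-arming ready[buf], so TMA never overwrites a slot the tensor core is still reading.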

🧠 GPU concept: software pipelining

Memory is slow (hundreds of cycles); compute is fast. If you load-then-compute serially, the tensor cores idle during every load. Software pipelining breaks that dependency by running NB iterations “in flight” at once: iteration k’s compute overlaps iterations k+1..k+NB-1’s loads. More stages (NB) hides more latency but costs more shared memory — which is exactly why NB is one of the autotuned knobs.
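The index arithmetic of the ring can be replayed on the CPU. The sketch below is a plain-Python trace, not the kernel: `pipeline_schedule` is a hypothetical helper that reuses the main loop’s formulas for the slot, the phase parity, and the K-step that refills the slot.

```python
def pipeline_schedule(num_k, NB):
    """List (k, slot, phase, refill_k) for each main-loop iteration.

    refill_k is the K-step whose tiles are loaded into the slot right after
    the MMA for step k is issued, or None when no loads remain.
    """
    steps = []
    for k in range(num_k):
        buf = k % NB
        ph = (k // NB) & 1
        kk = k + NB
        steps.append((k, buf, ph, kk if kk < num_k else None))
    return steps


schedule = pipeline_schedule(num_k=6, NB=3)
for k, buf, ph, refill in schedule:
    print(f"k={k} slot={buf} phase={ph} refill_with={refill}")
```

For 6 K-steps and 3 slots, steps 0 to 2 each trigger a refill for steps 3 to 5, and the second pass over the ring runs with the opposite parity.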

7. Draining the pipeline (epilogue)

    last = (num_k - 1) % NB; lph = ((num_k - 1) // NB) & 1
    mbarrier.wait(empty.index(last), lph)

After the loop issues the final MMA, we must wait for it to actually finish before touching acc. We compute which slot and phase the last K-step used and wait on its empty barrier. Now the accumulator holds the complete C tile.
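A quick sanity check of the epilogue formula, sketched in plain Python with hypothetical helper names: the slot and parity it computes should match what the final loop iteration used, for any number of K-steps and any ring depth.

```python
def last_slot_and_phase(num_k, NB):
    """Slot and phase parity used by the final K-step (epilogue formula)."""
    return (num_k - 1) % NB, ((num_k - 1) // NB) & 1


def final_loop_state(num_k, NB):
    """Replay the main loop's index math and return its last (buf, ph)."""
    buf = ph = None
    for k in range(num_k):
        buf, ph = k % NB, (k // NB) & 1
    return buf, ph


agree = all(
    last_slot_and_phase(n, nb) == final_loop_state(n, nb)
    for nb in range(1, 7)
    for n in range(1, 40)
)
example = last_slot_and_phase(num_k=6, NB=3)
```

With num_k=6 and NB=3 the last MMA used slot 2 on the ring’s second pass, so the epilogue waits on empty[2] with parity 1.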

8. Writing C back out

    reg_layout: gl.constexpr = get_tmem_reg_layout(gl.float32, [cluster_m, tile_n], acc_layout,
                                                   num_warps, cga_layout=[(1, 0)])
    out = acc.load(reg_layout)                  # TMEM → registers
    c_smem = gl.allocate_shared_memory(dtype, c_desc.block_type.shape, c_desc.layout)
    c_smem.store(out.to(dtype))                 # registers → SMEM, cast fp32 → fp16/bf16
    fence_async_shared()
    tma.async_copy_shared_to_global(c_desc, [off_m, off_n], c_smem)   # SMEM → HBM
    tma.store_wait(pendings=0)                  # make sure the store finished

The output journey reverses the input one, stage by stage:

[Interactive diagram: writing C back, TMEM → HBM. TMEM (acc, fp32) → acc.load() → registers (out) → c_smem.store() + to(dtype) → shared memory (c_smem, fp16) → tma.async_copy shared → global → global memory (C matrix).]

The tile stays fp32 from the accumulator through the registers and is narrowed to fp16/bf16 only at the shared-memory hop.

fence_async_shared() ensures the SMEM writes are visible to the TMA engine before the store launches, and tma.store_wait(pendings=0) blocks until the store drains, so the kernel doesn’t exit with an in-flight copy.


The host side

Setting up descriptors and launching

def gemm(A, B, C=None, *, BM=256, BN=256, BK=64, NB=3, num_warps=4):
    """C = A @ B on B200 via 2-CTA tcgen05. A,B 2-D fp16/bf16. Returns C.
    BM must be 2*instrM (256); tcgen05 caps the MMA N (=BN) at 256."""
    assert A.dtype == B.dtype and A.dtype in (torch.float16, torch.bfloat16)
    M, K = A.shape; K2, N = B.shape; assert K == K2
    if C is None:
        C = torch.empty(M, N, device=A.device, dtype=A.dtype)
    gd: gl.constexpr = gl.float16 if A.dtype == torch.float16 else gl.bfloat16
    a_layout = gl.NVMMASharedLayout.get_default_for([BM, BK], gd, cga_layout=[(1, 0)])
    b_layout = gl.NVMMASharedLayout.get_default_for([BK, BN], gd, cga_layout=[(0, 1)])
    c_layout = gl.NVMMASharedLayout.get_default_for([BM, BN], gd, cga_layout=[(1, 0)])
    a_desc = TensorDescriptor.from_tensor(A, [BM, BK], a_layout)
    b_desc = TensorDescriptor.from_tensor(B, [BK, BN], b_layout)
    c_desc = TensorDescriptor.from_tensor(C, [BM, BN], c_layout)
    grid = (triton.cdiv(M, BM), triton.cdiv(N, BN))
    _gemm_2cta_kernel[grid](a_desc, b_desc, c_desc, NB=NB, num_warps=num_warps, num_ctas=2)
    return C

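A few things worth calling out:

The block shapes passed to TensorDescriptor.from_tensor ([BM, BK], [BK, BN], [BM, BN]) are exactly what the kernel reads back through block_type.shape, so the host and the kernel cannot disagree about tile sizes.

num_ctas=2 is what turns each grid program into a 2-CTA cluster; the cga_layout arguments describe how each tile is divided between the two CTAs.

The grid uses triton.cdiv, so shapes that are not multiples of BM or BN still get a (partially filled) edge tile.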

Autotuning

The fastest (BM, BN, BK, NB) depends on the matrix shape, so there’s a tiny autotuner:

_SPACE = [(256, 256, 64, 2), (256, 256, 64, 3), (256, 256, 64, 4),
          (256, 128, 64, 3), (256, 128, 64, 4), (256, 128, 64, 6)]
_BEST = {}

def gemm_auto(A, B, C=None):
    """Autotuned 2-CTA GEMM. First call for a shape sweeps `_SPACE` and caches."""
    M, K = A.shape; N = B.shape[1]; key = (M, N, K, A.dtype)
    cfg = _BEST.get(key)
    if cfg is None:
        flush = torch.empty(128 * 1024 * 1024, device=A.device, dtype=torch.int8)
        best = None
        for (BM, BN, BK, NB) in _SPACE:
            try:
                for _ in range(5): gemm(A, B, C, BM=BM, BN=BN, BK=BK, NB=NB)
                torch.cuda.synchronize()
                t = 1e9
                for _ in range(3):
                    flush.zero_()
                    s = torch.cuda.Event(True); e = torch.cuda.Event(True); s.record()
                    for _ in range(20): gemm(A, B, C, BM=BM, BN=BN, BK=BK, NB=NB)
                    e.record(); torch.cuda.synchronize()
                    t = min(t, s.elapsed_time(e))
                if best is None or t < best[0]: best = (t, (BM, BN, BK, NB))
            except Exception:
                pass
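        if best is None:  # every config failed to compile or launch
            raise RuntimeError(f"no config in _SPACE compiled and ran for shape {key}")
        cfg = best[1]; _BEST[key] = cfg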
    BM, BN, BK, NB = cfg
    return gemm(A, B, C, BM=BM, BN=BN, BK=BK, NB=NB)

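Two details that separate a toy benchmark from an honest one:

Warm-up first. Each config runs 5 untimed iterations before measurement, so compilation and first-launch overhead stay out of the timing.

Flush L2 between repetitions. Zeroing the 128 MiB flush buffer before each timed repetition is meant to evict A and B from L2, so each repetition starts from HBM rather than from a warm cache; the best of 3 repetitions is kept.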

The try/except is there because some configs (e.g. a BN that doesn’t divide cleanly, or one that overflows shared memory) simply fail to compile or launch — we skip them and keep the ones that work.
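To see which configs are at risk of overflowing shared memory, here is a rough budget calculation. It is a sketch under stated assumptions: a 227 KiB usable limit per CTA, 2-byte elements, A split along M and B along N across the two CTAs, a separate C staging buffer, and barrier storage ignored. `smem_per_cta` is a hypothetical helper, and the compiler’s real allocation may differ.

```python
SMEM_LIMIT = 227 * 1024   # assumed usable shared memory per CTA, in bytes


def smem_per_cta(BM, BN, BK, NB, elem_bytes=2):
    """Approximate bytes of shared memory one CTA needs (barriers ignored)."""
    a_stage = (BM // 2) * BK * elem_bytes      # A split along M across 2 CTAs
    b_stage = BK * (BN // 2) * elem_bytes      # B split along N across 2 CTAs
    c_tile = (BM // 2) * BN * elem_bytes       # epilogue staging buffer
    return NB * (a_stage + b_stage) + c_tile


SPACE = [(256, 256, 64, 2), (256, 256, 64, 3), (256, 256, 64, 4),
         (256, 128, 64, 3), (256, 128, 64, 4), (256, 128, 64, 6)]
usage = {cfg: smem_per_cta(*cfg) for cfg in SPACE}
fits = {cfg: used <= SMEM_LIMIT for cfg, used in usage.items()}
```

Under these assumptions every entry of the search space fits, with (256, 256, 64, 4) the tightest at 192 KiB, while a hypothetical (256, 256, 64, 6) would need 256 KiB and would be one of the configs the try/except skips.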

Correctness against ragged shapes

    shapes = [(256, 128, 128), (4096, 4096, 4096), (2048, 2048, 2048),
              (4000, 4096, 4096), (4096, 4096, 4000), (1536, 6144, 2048)]

Notice 4000×4096×4096 and 4096×4096×4000: dimensions that aren’t multiples of the tile size. These exist on purpose — they test that TMA’s boundary handling (zero-padding the out-of-range part of edge tiles) is correct. The check compares against an fp32 reference and asserts the relative error is below a dtype-appropriate tolerance:

    ref = A.float() @ B.float()
    rel = ((C.float() - ref).abs().max() / ref.abs().max()).item()
    tol = 5e-2 if dt == torch.float16 else 1e-1

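The tolerances are deliberately loose. Products of fp16/bf16 inputs are accumulated in fp32, so the dominant error is usually the final rounding of each output to fp16 or bf16 (relative error up to about 2^-11 ≈ 5e-4 for fp16 and 2^-8 ≈ 4e-3 for bf16), which leaves plenty of headroom under these thresholds. The fp32 accumulator is what keeps the error that small.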
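For a sense of scale, the sketch below measures the rounding error of the final cast to fp16 on a few hand-picked values, using only the standard library’s half-precision packing. `to_fp16` is a hypothetical helper and the sample values are arbitrary; this illustrates the output cast only, not the kernel’s full error.

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


UNIT_ROUNDOFF_FP16 = 2.0 ** -11     # half of the spacing between fp16 values near 1

samples = [0.1, 1.0 / 3.0, 3.14159, 100.7, 1234.5]
rel_errors = [abs(to_fp16(x) - x) / abs(x) for x in samples]
worst = max(rel_errors)
```

Every sample lands within 2^-11 of its true value in relative terms, roughly two orders of magnitude below the 5e-2 threshold used for fp16.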


Why this design is fast

Putting it together, every performance-critical idea in modern GEMM is present in this one file:

| Technique | Where | Pays off by… |
| --- | --- | --- |
| Tiling | BM×BN×BK blocks | keeping working set in fast on-chip memory |
| TMA async copy | tma.async_copy_* | freeing threads from address math; overlapping DMA |
| Software pipelining | NB-slot ring + mbarriers | hiding HBM latency behind tensor-core compute |
| 2-CTA tcgen05 | num_ctas=2, cta_split_num | one MMA spanning two SMs → bigger tiles, more reuse |
| TMEM accumulation | allocate_tensor_memory | a huge fp32 accumulator without register pressure |
| fp32 accumulate | gl.float32 acc | accuracy despite fp16/bf16 inputs |
| L2-flushed autotuning | gemm_auto | picking the best tiling per shape, measured honestly |

The benchmark in main() reports throughput as a percentage of cuBLAS — which is the right yardstick, since cuBLAS is the heavily-tuned vendor library. Getting within striking distance of it with a readable Python kernel is the whole appeal of Gluon.
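Turning a timing into a throughput figure is simple arithmetic, sketched below. The helper names are hypothetical and the timings are made-up placeholders, not measurements from this kernel; a GEMM is counted as 2·M·N·K floating-point operations.

```python
def gemm_tflops(M, N, K, elapsed_ms, iters=1):
    """Achieved TFLOP/s, counting one multiply and one add per inner-product term."""
    flops = 2 * M * N * K * iters
    return flops / (elapsed_ms * 1e-3) / 1e12


def percent_of_baseline(ours_ms, baseline_ms):
    """Throughput relative to a baseline timed on the same shape."""
    return 100.0 * baseline_ms / ours_ms


rate = gemm_tflops(4096, 4096, 4096, elapsed_ms=1.0)      # placeholder timing
ratio = percent_of_baseline(ours_ms=1.25, baseline_ms=1.0)  # placeholder timings
```

The `iters` parameter matters when a single elapsed time covers several launches, as in the autotuner’s 20-iteration timed loop.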


Where to go next

The full kernel is on GitHub: Tanmaypatil123/kronos · kernels/gluon_gemm.py. Clone it, run it on a B200, and watch the tensor cores stay fed.

