TL;DR: We derive the RMSNorm backward pass in plain math (no scary notation) and implement forward and backward Triton kernels.
Why RMSNorm?
RMSNorm does no mean subtraction, so it needs fewer ops and has simpler math than LayerNorm, and it is used in modern LLMs. That gives it good "real world" relevance over LayerNorm.
Cheat-sheet: tiny math facts we’ll use
- Square: $\frac{d}{dx}\,x^2 = 2x$
- Square root: $\frac{d}{dx}\,\sqrt{x} = \frac{1}{2\sqrt{x}}$
- Reciprocal: $\frac{d}{dx}\,\frac{1}{x} = -\frac{1}{x^2}$
- Product rule: $(uv)' = u'v + uv'$
- Kronecker delta: $\delta_{ij}$ (1 if $i = j$, else 0)
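As a quick sanity check, the derivative facts can be confirmed numerically with PyTorch autograd (a minimal sketch; the evaluation point x = 3 is arbitrary):

```python
import torch

# Check the cheat-sheet derivatives at an arbitrary point, x = 3.
x = torch.tensor(3.0, requires_grad=True)
for f, analytic in [
    (lambda t: t ** 2,   lambda t: 2 * t),               # d/dx x^2     = 2x
    (lambda t: t.sqrt(), lambda t: 1 / (2 * t.sqrt())),  # d/dx sqrt(x) = 1 / (2 sqrt(x))
    (lambda t: 1 / t,    lambda t: -1 / t ** 2),         # d/dx (1/x)   = -1 / x^2
]:
    (grad,) = torch.autograd.grad(f(x), x)
    assert torch.allclose(grad, analytic(x.detach()))
```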
Forward pass: RMSNorm
We’ll start with one feature vector (think a single row of shape [N]).
RMSNorm scales by the root-mean-square and (optionally) applies a per-feature scale $\gamma$ and shift $\beta$.
Equations
$$
\mu_2 = \frac{1}{N}\sum_{i=1}^{N} x_i^2,\qquad
\mathrm{RMS}(x) = \sqrt{\mu_2 + \epsilon},\qquad
\hat{x}_i = \frac{x_i}{\mathrm{RMS}(x)},\qquad
y_i = \gamma_i\,\hat{x}_i + \beta_i
$$
What each symbol means
- $x$: input vector (last/feature dimension, length $N$)
- $\mu_2$: mean of squared elements of $x$
- $\epsilon$: small constant for numerical stability
- $\mathrm{RMS}(x) = \sqrt{\mu_2 + \epsilon}$: root-mean-square of $x$
- $\hat{x}$: normalized $x$
- $\gamma, \beta$ (or scalars): per-feature scale and shift
- $y$: output
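To make the equations concrete, here is a minimal PyTorch reference of the forward pass (the name `rmsnorm_ref` is mine, not from any library; like many LLM implementations, and like the Triton kernel below, it skips the shift $\beta$):

```python
import torch

def rmsnorm_ref(x: torch.Tensor, gamma: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # mu_2: mean of squared elements over the feature (last) dimension
    mu2 = (x * x).mean(dim=-1, keepdim=True)
    # inv_rms = 1 / RMS(x) = 1 / sqrt(mu_2 + eps)
    inv_rms = torch.rsqrt(mu2 + eps)
    # y = gamma * x_hat, with x_hat = x * inv_rms (no beta/shift here)
    return gamma * (x * inv_rms)
```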
Triton Kernel: Forward Pass
# Imports needed for the kernels below.
import torch
import triton
import triton.language as tl

@triton.jit
def rms_norm_forward(
    Y, Y_stride: tl.constexpr,
    X, X_stride: tl.constexpr,
    gamma, gamma_stride: tl.constexpr,
    r, r_stride: tl.constexpr,
    N: tl.constexpr,
    eps: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # One program handles one row.
    pid = tl.program_id(0)
    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < N
    # Advance the pointers to this program's row.
    Y += pid * Y_stride
    X += pid * X_stride
    r += pid * r_stride
    # Load x and gamma from HBM (gamma is shared across rows, so it is not offset by pid).
    x_row = tl.load(X + offs, mask=mask, other=0).to(tl.float32)
    gamma_row = tl.load(gamma + offs, mask=mask, other=0)
    # mu_2 = sum(x^2) / N
    row_var = tl.sum(x_row * x_row, axis=0) / N
    # inv_rms = 1 / sqrt(mu_2 + eps)
    inv_rms = tl.rsqrt(row_var + eps)
    # Store inv_rms; the backward pass will need it again.
    tl.store(r, inv_rms)
    # y = x * inv_rms * gamma
    norm = x_row * inv_rms * gamma_row
    tl.store(Y + offs, norm, mask=mask)
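A quick way to exercise the kernel before wiring it into autograd; this is a sketch that assumes a CUDA device and the `rmsnorm_ref` helper from above, and the shapes are arbitrary:

```python
n_rows, n_cols = 4, 512
x = torch.randn(n_rows, n_cols, device="cuda")
gamma = torch.randn(n_cols, device="cuda")
eps = 1e-6

y = torch.empty_like(x)
inv_rms = torch.empty(n_rows, dtype=torch.float32, device="cuda")

# One program per row; BLOCK_SIZE must cover a full row.
rms_norm_forward[(n_rows,)](
    y, y.stride(0),
    x, x.stride(0),
    gamma, gamma.stride(0),
    inv_rms, inv_rms.stride(0),
    n_cols, eps,
    BLOCK_SIZE=triton.next_power_of_2(n_cols),
)
print(torch.allclose(y, rmsnorm_ref(x, gamma, eps), atol=1e-5))
```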
Backward Pass: RMSNorm
We work per row: $x \in \mathbb{R}^N$. Everything flows through $r = (\mu_2 + \epsilon)^{-1/2}$.
Let $\hat{x}_i = x_i\,r$ and $h_i = \frac{\partial L}{\partial y_i}\,\gamma_i$.
Jacobian of $\hat{x}$ w.r.t. $x$. For each pair $(i, j)$,
$$
\frac{\partial \hat{x}_i}{\partial x_j}
= \delta_{ij}\,r + x_i\,\frac{\partial r}{\partial x_j}
= \delta_{ij}\,r - \frac{x_i\,x_j}{N}\,(\mu_2 + \epsilon)^{-3/2},
$$
i.e.
$$
\frac{\partial \hat{x}_i}{\partial x_j} = r\left(\delta_{ij} - \frac{\hat{x}_i\,\hat{x}_j}{N}\right).
$$
Matrix form. Stacking gives the Jacobian
$$
J = \frac{\partial \hat{x}}{\partial x} = r\left(I - \frac{\hat{x}\,\hat{x}^\top}{N}\right).
$$
Index / layout convention.
We treat $x$ as a column vector ($x \in \mathbb{R}^{N \times 1}$).
Then the Jacobian is an $N \times N$ matrix with $J_{ij} = \partial \hat{x}_i / \partial x_j$,
so index $i$ corresponds to the row (output component) and index $j$ to the column (input component).
If you prefer a row-vector convention for $x$, the Jacobian appears transposed accordingly.
Chain rule → $\partial L / \partial x$ (dX)
Vector form (note $J$ is symmetric, so $J^\top = J$):
$$
\frac{\partial L}{\partial x} = J^\top h = r\left(h - \frac{\hat{x}\,(\hat{x}^\top h)}{N}\right).
$$
Using $h_i = \frac{\partial L}{\partial y_i}\,\gamma_i$ and $\hat{x}_i = x_i\,r$, componentwise
$$
\frac{\partial L}{\partial x_j} = \frac{r}{N}\left(N\,h_j - \hat{x}_j \sum_{i=1}^{N} h_i\,\hat{x}_i\right),
$$
which is exactly what the backward kernel computes.
$h$ in the derivation is just shorthand for the upstream grad ($\partial L / \partial y$) scaled by the weight ($\gamma$).
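Before writing any Triton, the closed form can be checked numerically against autograd (a sketch using the `rmsnorm_ref` helper from earlier; shapes and the upstream gradient are arbitrary, and float64 keeps the comparison tight):

```python
torch.manual_seed(0)
x = torch.randn(8, 64, dtype=torch.float64, requires_grad=True)
gamma = torch.randn(64, dtype=torch.float64)
dy = torch.randn(8, 64, dtype=torch.float64)  # upstream grad dL/dy
eps = 1e-6

# Gradient of the reference forward via autograd.
(dx_auto,) = torch.autograd.grad(rmsnorm_ref(x, gamma, eps), x, grad_outputs=dy)

# Closed form: dL/dx = (r / N) * (N * h - x_hat * sum_i(h_i * x_hat_i)), with h = dy * gamma.
with torch.no_grad():
    inv_rms = torch.rsqrt((x * x).mean(dim=-1, keepdim=True) + eps)
    x_hat = x * inv_rms
    h = dy * gamma
    N = x.shape[-1]
    dx_manual = inv_rms / N * (N * h - x_hat * (h * x_hat).sum(dim=-1, keepdim=True))

print(torch.allclose(dx_auto, dx_manual))
```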
Triton Kernel: Backward Pass
@triton.jit
def rms_norm_backward(
    dY, dY_stride: tl.constexpr,
    dX, dX_stride: tl.constexpr,
    X, X_stride: tl.constexpr,
    gamma, gamma_stride: tl.constexpr,
    r, r_stride: tl.constexpr,
    N: tl.constexpr,
    eps: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # One program handles one row.
    pid = tl.program_id(0)
    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < N
    # Advance the pointers to this program's row.
    # (In the wrapper below, dX aliases dY, so the gradient is written in place.)
    dY += pid * dY_stride
    dX += pid * dX_stride
    X += pid * X_stride
    r += pid * r_stride
    dY_row = tl.load(dY + offs, mask=mask, other=0).to(tl.float32)
    X_row = tl.load(X + offs, mask=mask, other=0).to(tl.float32)
    gamma_row = tl.load(gamma + offs, mask=mask, other=0).to(tl.float32)
    # Reload inv_rms = 1 / sqrt(mu_2 + eps) saved by the forward pass.
    inv_rms = tl.load(r).to(tl.float32)
    # x_hat = x * inv_rms
    normed = X_row * inv_rms
    # h = dY * gamma
    dY_gamma = dY_row * gamma_row
    # sum_i h_i * x_hat_i
    rowsum_dY_normed = tl.sum(dY_gamma * normed, axis=0)
    # dL/dx = inv_rms / N * (N * h - x_hat * sum_i(h_i * x_hat_i))
    output = inv_rms / N * (N * dY_gamma - normed * rowsum_dY_normed)
    # Store the gradient.
    tl.store(dX + offs, output, mask=mask)
Use the Triton kernels with PyTorch
class RMS_Layernorm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X: torch.Tensor, W: torch.Tensor, eps: float):
        shape = X.shape
        dim = shape[-1]
        # Collapse all leading dimensions into rows.
        X = X.view(-1, dim)
        n_rows, n_cols = X.shape
        BLOCK_SIZE = triton.next_power_of_2(n_cols)
        device = X.device
        Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=device)
        r = torch.empty(n_rows, dtype=torch.float32, device=device)
        with torch.cuda.device(device):
            rms_norm_forward[(n_rows,)](
                Y, Y.stride(0),
                X, X.stride(0),
                W, W.stride(0),
                r, r.stride(0),
                n_cols, eps,
                BLOCK_SIZE=BLOCK_SIZE,
            )
        ctx.eps = eps
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.save_for_backward(X, W, r)
        return Y.view(*shape)

    @staticmethod
    def backward(ctx, dY: torch.Tensor):
        shape = dY.shape
        dim = shape[-1]
        dY = dY.view(-1, dim)
        X, W, r = ctx.saved_tensors
        n_rows, n_cols = dY.shape
        # Reuse dY's buffer for dX: the kernel overwrites the upstream gradient in place.
        dX = dY
        with torch.cuda.device(dY.device):
            rms_norm_backward[(n_rows,)](
                dY, dY.stride(0),
                dX, dX.stride(0),
                X, X.stride(0),
                W, W.stride(0),
                r, r.stride(0),
                n_cols, ctx.eps,
                BLOCK_SIZE=ctx.BLOCK_SIZE,
            )
        dX = dX.view(*shape)
        # forward() took (X, W, eps), so backward() must return exactly three gradients;
        # this kernel does not compute dW, hence None for W (and None for eps).
        return dX, None, None
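Finally, an end-to-end check of the autograd wrapper against the eager reference (a sketch assuming a CUDA device and the `rmsnorm_ref` helper from earlier; tolerances are loose because the kernel accumulates in float32):

```python
torch.manual_seed(0)
x = torch.randn(2, 16, 1024, device="cuda", requires_grad=True)
w = torch.randn(1024, device="cuda")
x_ref = x.detach().clone().requires_grad_(True)

# Triton path
y = RMS_Layernorm.apply(x, w, 1e-6)
y.sum().backward()

# Eager PyTorch path
y_ref = rmsnorm_ref(x_ref, w, 1e-6)
y_ref.sum().backward()

print(torch.allclose(y, y_ref, atol=1e-4),
      torch.allclose(x.grad, x_ref.grad, atol=1e-4))
```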