
RMSNorm Backward: From Derivation to a Triton Kernel


TL;DR: We derive the RMSNorm backward pass in plain math (no scary notation) and implement forward and backward Triton kernels.

Why RMSNorm?

RMSNorm drops the mean subtraction of LayerNorm: fewer ops and simpler math, and it is the normalization used in modern LLMs, which gives it good real-world relevance over LayerNorm.
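
Side by side, the only difference is the statistic each uses to normalize; RMSNorm has no mean $\mu$ to subtract (here $\mu$ and $\sigma^2$ are the per-row mean and variance):

$$\mathrm{LayerNorm}:\quad y_i = \gamma_i\,\frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta_i, \qquad\qquad \mathrm{RMSNorm}:\quad y_i = \gamma_i\,\frac{x_i}{\sqrt{\tfrac{1}{N}\sum_{k} x_k^{2} + \epsilon}} + \beta_i.$$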


Cheat-sheet: tiny math facts we’ll use
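
Concretely, the derivation below only leans on a few standard rules:

$$\frac{d}{du}\,u^{2} = 2u, \qquad \frac{d}{du}\,\sqrt{u} = \frac{1}{2\sqrt{u}}, \qquad \frac{d}{du}\,\frac{1}{u} = -\frac{1}{u^{2}}, \qquad \frac{\partial x_i}{\partial x_j} = \delta_{ij}\ \text{(Kronecker delta)},$$

plus the chain rule, $\frac{d}{dx}\,f(g(x)) = f'(g(x))\,g'(x)$.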


Forward pass: RMSNorm

We'll start with one feature vector $x \in \mathbb{R}^{N}$ (think a single row of shape [N]).
RMSNorm scales by the root-mean-square and (optionally) applies a per-feature scale $\gamma$ and shift $\beta$.

Equations

$$a \;=\; \frac{1}{N}\sum_{i=1}^{N} x_i^2 \qquad\text{(mean of squares)} \tag{1}$$

$$\mathrm{rms} \;=\; \sqrt{a + \epsilon} \qquad\text{(root-mean-square with stability $\epsilon>0$)} \tag{2}$$

$$\widehat{x}_i \;=\; \frac{x_i}{\mathrm{rms}} \qquad\text{(normalized activations)} \tag{3}$$

$$y_i \;=\; \gamma_i\,\widehat{x}_i \;+\; \beta_i \qquad\text{(scale \& shift)} \tag{4}$$

What each symbol means

$x \in \mathbb{R}^N$: the input feature vector (one row), with components $x_i$.
$N$: the feature (hidden) dimension.
$a$: the mean of the squared features.
$\epsilon$: a small constant for numerical stability.
$\mathrm{rms}$: the root-mean-square of the row, $\sqrt{a+\epsilon}$.
$\widehat{x}_i$: the normalized activation $x_i/\mathrm{rms}$.
$\gamma_i,\ \beta_i$: learnable per-feature scale and shift.
$y_i$: the output.
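
Before writing the kernel, a plain PyTorch reference of equations (1)-(4) is handy for checking results later (a sketch; the name rmsnorm_ref is ours, not something defined elsewhere in this post):

import torch

def rmsnorm_ref(x, gamma, beta=None, eps=1e-6):
    # (1) mean of squares over the feature dimension
    a = x.pow(2).mean(dim=-1, keepdim=True)
    # (2) root-mean-square with stability epsilon
    rms = torch.sqrt(a + eps)
    # (3) normalized activations
    x_hat = x / rms
    # (4) scale and (optional) shift
    y = gamma * x_hat
    return y if beta is None else y + beta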

Triton Kernel: Forward Pass

# Imports needed for the kernels below.
import torch
import triton
import triton.language as tl

@triton.jit
def rms_norm_forward(
    Y, Y_stride: tl.constexpr,
    X, X_stride: tl.constexpr,
    gamma, gamma_stride: tl.constexpr,
    r, r_stride: tl.constexpr,
    N: tl.constexpr,
    eps: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # One program handles one row; BLOCK_SIZE >= N so the whole row fits in a single block.
    pid = tl.program_id(0)
    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < N

    # Advance the base pointers to this program's row.
    Y += pid * Y_stride
    X += pid * X_stride
    r += pid * r_stride

    # Load x and gamma from HBM (compute in float32 for stability).
    x_row = tl.load(X + offs, mask=mask, other=0).to(tl.float32)
    gamma_row = tl.load(gamma + offs, mask=mask, other=0)

    # a = sum(x^2) / N  (mean of squares)
    row_var = tl.sum(x_row * x_row, axis=0) / N

    # inv_rms = 1 / sqrt(a + eps)
    inv_rms = tl.rsqrt(row_var + eps)

    # Save inv_rms; the backward pass reuses it instead of recomputing.
    tl.store(r, inv_rms)

    # y = x * inv_rms * gamma
    norm = x_row * inv_rms * gamma_row
    tl.store(Y + offs, norm, mask=mask)
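
To see the launch shape concretely, here is a minimal host-side call (a sketch with made-up sizes; the full autograd wrapper comes later):

# Hypothetical direct launch: one program per row, BLOCK_SIZE covers the whole row.
n_rows, n_cols = 4, 1024
x = torch.randn(n_rows, n_cols, device="cuda")
gamma = torch.ones(n_cols, device="cuda")
y = torch.empty_like(x)
r = torch.empty(n_rows, dtype=torch.float32, device="cuda")

rms_norm_forward[(n_rows,)](
    y, y.stride(0),
    x, x.stride(0),
    gamma, gamma.stride(0),
    r, r.stride(0),
    n_cols, 1e-6,
    BLOCK_SIZE=triton.next_power_of_2(n_cols),
)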

Backward Pass: RMSNorm

We work per row: $x \in \mathbb{R}^N$. Everything flows through $a \rightarrow \mathrm{rms} \rightarrow \mathrm{inv} \rightarrow \hat{x}$.

  1. $a=\frac{1}{N}\sum_i x_i^2 \;\Rightarrow\; \frac{\partial a}{\partial x_j}=\frac{2}{N}x_j.$

  2. $\mathrm{rms}=\sqrt{a+\epsilon} \;\Rightarrow\; \frac{\partial\,\mathrm{rms}}{\partial x_j} = \frac{1}{2\sqrt{a+\epsilon}}\cdot\frac{\partial a}{\partial x_j} = \frac{1}{2\sqrt{a+\epsilon}}\cdot\frac{2}{N}x_j = \frac{x_j}{N\,\mathrm{rms}}.$

  3. $\mathrm{inv}=\frac{1}{\mathrm{rms}} \;\Rightarrow\; \frac{\partial\,\mathrm{inv}}{\partial x_j} = -\frac{1}{\mathrm{rms}^2}\,\frac{\partial\,\mathrm{rms}}{\partial x_j} = -\frac{x_j}{N\,\mathrm{rms}^3}.$


Let $x\in\mathbb{R}^N$ and

$$a=\frac{1}{N}\sum_{k=1}^N x_k^2,\qquad \mathrm{rms}=\sqrt{a+\epsilon},\qquad \mathrm{inv}=\frac{1}{\mathrm{rms}},\qquad \hat{x}=x\cdot \mathrm{inv}.$$

Jacobian of $\hat{x}$ w.r.t. $x$. For each pair $(i,j)$,

$$\frac{\partial \hat{x}_i}{\partial x_j} = \frac{\partial (x_i\,\mathrm{inv})}{\partial x_j} = \left(\frac{\partial x_i}{\partial x_j}\right)\mathrm{inv} + x_i \left(\frac{\partial\,\mathrm{inv}}{\partial x_j}\right) = \delta_{ij}\,\mathrm{inv} + x_i\!\left(-\,\frac{x_j}{N\,\mathrm{rms}^3}\right),$$

i.e.

$$\boxed{\; \frac{\partial \hat{x}_i}{\partial x_j} = \frac{\delta_{ij}}{\mathrm{rms}} - \frac{x_i x_j}{N\,\mathrm{rms}^3} \;}$$

Matrix form. Stacking $\partial \hat{x}_i/\partial x_j$ gives the Jacobian

$$\boxed{\; \frac{\partial \hat{x}}{\partial x} = \frac{1}{\mathrm{rms}}\,I - \frac{1}{N\,\mathrm{rms}^3}\,x x^{\top} \;}$$

Index / layout convention. We treat $x$ as a column vector ($N\times 1$).
Then the Jacobian $J=\partial \hat{x}/\partial x$ is an $N\times N$ matrix with

$$J_{ij}=\frac{\partial \hat{x}_i}{\partial x_j},$$

so index $i$ corresponds to the row (output component) and index $j$ to the column (input component).
If you prefer a row-vector convention for $x^\top$, the Jacobian appears transposed accordingly.
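
A quick way to trust the boxed formula is to compare it against autograd's Jacobian on a random row (a sketch; the names and sizes are only for illustration):

import torch

N, eps = 8, 1e-6
x = torch.randn(N, dtype=torch.float64)

def x_hat(v):
    return v * torch.rsqrt(v.pow(2).mean() + eps)

# Autograd Jacobian of x_hat w.r.t. x
J_auto = torch.autograd.functional.jacobian(x_hat, x)

# Closed form: I / rms - x x^T / (N rms^3)
rms = torch.sqrt(x.pow(2).mean() + eps)
J_formula = torch.eye(N, dtype=torch.float64) / rms - torch.outer(x, x) / (N * rms**3)

print(torch.allclose(J_auto, J_formula))  # expected: True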

Chain rule → (dX)

Write $h$ for the upstream gradient scaled by the weight, $h_i = \gamma_i\,\mathrm{d}y_i$ (more on this below). The chain rule then gives

$$\frac{\partial \mathcal{L}}{\partial x_j} = \sum_{i=1}^N h_i \left(\frac{\delta_{ij}}{\mathrm{rms}} - \frac{x_i x_j}{N\,\mathrm{rms}^3}\right) = \frac{h_j}{\mathrm{rms}} - \frac{x_j}{N\,\mathrm{rms}^3}\sum_i h_i x_i.$$

Vector form:

$$\boxed{\, dX \;=\; \frac{h}{\mathrm{rms}} \;-\; \frac{x\,\langle h,x\rangle}{N\,\mathrm{rms}^3} \,} \quad\text{where}\quad \langle h,x\rangle=\sum_i h_i x_i.$$

Using $\hat{x}=x\cdot\mathrm{inv}$ and $\sum_i h_i \hat{x}_i=\langle h,x\rangle\cdot\mathrm{inv}$,

$$dX \;=\; \frac{\mathrm{inv}}{N}\Big(Nh - \hat{x}\sum_i h_i\hat{x}_i\Big).$$

Here $h$ is just shorthand for the upstream gradient $\mathrm{d}y$ scaled element-wise by the weight $\gamma$, i.e. $h_i = \gamma_i\,\mathrm{d}y_i$; it shows up as dY_gamma in the kernel below.
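
The same trick verifies the vector form of $dX$, including the rewrite in terms of $\hat{x}$ (again a sketch with illustrative names):

import torch

N, eps = 8, 1e-6
x = torch.randn(N, dtype=torch.float64, requires_grad=True)
gamma = torch.randn(N, dtype=torch.float64)
dY = torch.randn(N, dtype=torch.float64)

inv = torch.rsqrt(x.pow(2).mean() + eps)
y = gamma * (x * inv)

# Autograd: dL/dx for upstream gradient dY
dX_auto, = torch.autograd.grad(y, x, grad_outputs=dY)

# Closed form: dX = inv / N * (N h - x_hat * sum_i h_i x_hat_i), with h = dY * gamma
with torch.no_grad():
    x_hat = x * inv
    h = dY * gamma
    dX_formula = inv / N * (N * h - x_hat * (h * x_hat).sum())

print(torch.allclose(dX_auto, dX_formula))  # expected: True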

Triton Kernel: Backward Pass

@triton.jit
def rms_norm_backward(
    dY, dY_stride: tl.constexpr,
    dX, dX_stride: tl.constexpr,
    X, X_stride: tl.constexpr,
    gamma, gamma_stride: tl.constexpr,
    r, r_stride: tl.constexpr,
    N: tl.constexpr,
    eps: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < N

    dY += pid * dY_stride
    X += pid * X_stride
    r += pid * r_stride

    # Reuse dY's buffer for the output: dX is written in place over dY.
    dX = dY

    dY_row = tl.load(dY + offs, mask=mask, other=0).to(tl.float32)
    X_row = tl.load(X + offs, mask=mask, other=0).to(tl.float32)
    gamma_row = tl.load(gamma + offs, mask=mask, other=0).to(tl.float32)

    # Load the saved 1 / rms from the forward pass.
    inv_rms = tl.load(r).to(tl.float32)
    # x_hat = x * inv_rms
    normed = X_row * inv_rms
    # h = dY * gamma (upstream gradient scaled by the weight)
    dY_gamma = dY_row * gamma_row
    # sum_i h_i * x_hat_i
    rowsum_dY_normed = tl.sum(dY_gamma * normed, axis=0)

    # dX = inv_rms / N * (N * h - x_hat * sum_i h_i * x_hat_i)
    output = inv_rms / N * (N * dY_gamma - normed * rowsum_dY_normed)

    # Store the result (over dY's memory).
    tl.store(dX + offs, output, mask=mask)

Using the Triton Kernels with PyTorch

class RMS_Layernorm(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X: torch.Tensor, W: torch.Tensor, eps: float):
        shape = X.shape
        dim = shape[-1]
        X = X.view(-1, dim)
        n_rows, n_cols = X.shape
        # One program per row; the block must cover the whole row.
        BLOCK_SIZE = triton.next_power_of_2(n_cols)
        device = X.device
        Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=device)
        r = torch.empty(n_rows, dtype=torch.float32, device=device)

        with torch.cuda.device(device):
            rms_norm_forward[(n_rows,)](
                Y, Y.stride(0),
                X, X.stride(0),
                W, W.stride(0),
                r, r.stride(0),
                n_cols, eps,
                BLOCK_SIZE=BLOCK_SIZE,
            )
        ctx.eps = eps
        ctx.BLOCK_SIZE = BLOCK_SIZE
        ctx.save_for_backward(X, W, r)
        return Y.view(*shape)

    @staticmethod
    def backward(ctx, dY: torch.Tensor):
        shape = dY.shape
        dim = shape[-1]
        dY = dY.view(-1, dim)
        X, W, r = ctx.saved_tensors
        n_rows, n_cols = dY.shape
        # The kernel writes dX in place over dY's buffer.
        dX = dY

        with torch.cuda.device(dY.device):
            rms_norm_backward[(n_rows,)](
                dY, dY.stride(0),
                dX, dX.stride(0),
                X, X.stride(0),
                W, W.stride(0),
                r, r.stride(0),
                n_cols, ctx.eps,
                BLOCK_SIZE=ctx.BLOCK_SIZE,
            )
        dX = dX.view(*shape)
        # One gradient per forward input (X, W, eps); gamma's gradient is not computed here.
        return dX, None, None
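
A minimal end-to-end check against eager PyTorch (a sketch assuming a CUDA device; since the fused backward overwrites the incoming gradient buffer in place, we pass it a clone of g):

if torch.cuda.is_available():
    eps = 1e-6
    x = torch.randn(4, 512, device="cuda", requires_grad=True)
    w = torch.randn(512, device="cuda")
    g = torch.randn(4, 512, device="cuda")

    # Triton path (clone g: the backward kernel writes dX over dY)
    y_triton = RMS_Layernorm.apply(x, w, eps)
    dx_triton, = torch.autograd.grad(y_triton, x, grad_outputs=g.clone())

    # Eager reference path
    x_ref = x.detach().clone().requires_grad_(True)
    y_ref = w * (x_ref * torch.rsqrt(x_ref.pow(2).mean(-1, keepdim=True) + eps))
    dx_ref, = torch.autograd.grad(y_ref, x_ref, grad_outputs=g)

    print(torch.allclose(y_triton, y_ref, atol=1e-4),
          torch.allclose(dx_triton, dx_ref, atol=1e-4))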
