
Learning Triton One Kernel at a Time: Softmax


In the previous article of this series, we covered one of the most widely used operations in all fields of computer science: matrix multiplication. It is heavily used in neural networks to compute the activation of linear layers. However, activations on their own are difficult to interpret, since their values and statistics (mean, variance, min-max amplitude) can vary wildly from layer to layer. This is one of the reasons why we use activation functions, for example the logistic function (aka sigmoid), which maps any real number into the [0, 1] range.

The softmax function, also known as the normalised exponential function, is a multi-dimensional generalisation of the sigmoid. It converts a vector of raw scores (logits) into a probability distribution over M classes. We can interpret it as a weighted average that behaves as a smooth function and can be conveniently differentiated. It is a crucial component of dot-product attention, language modeling, and multinomial logistic regression.

In this article, we’ll cover:

  1. Implementing an efficient softmax kernel in Triton.
  2. Implementing the backward pass (autograd).
  3. Optimisation: cache modifiers and auto-tuning.

If you aren’t familiar with Triton yet, refer to the previous articles!

Disclaimer: all the illustrations and animations are made by the author unless specified otherwise.

Definition

The softmax is defined as follows:
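For an input vector z of M logits:

\mathrm{softmax}(z)_i = \frac{e^{z_i}}{\sum_{j=1}^{M} e^{z_j}}, \qquad i = 1, \dots, M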

The normalisation ensures that the vector sums to 1, so that it can be interpreted as a valid probability distribution.

Note that this formulation of the softmax is highly sensitive to numerical overflow. Recall that the maximum value a standard float16 can represent is 65 504, which is roughly exp(11). This means that any input value greater than ~11 will result in exp(z_i) exceeding the representable range, leading to overflow.

A common trick to mitigate this issue is to subtract the maximum value of the input vector from every element, such that the new maximum is 0 before exponentiation and 1 after.
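In symbols, the numerically stable form is:

\mathrm{softmax}(z)_i = \frac{e^{z_i - m}}{\sum_{j=1}^{M} e^{z_j - m}}, \qquad m = \max_k z_k

which is mathematically identical, since the common factor e^{-m} cancels between the numerator and the denominator.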

Naive Implementation

As you can see, computing the softmax involves two reduction operations, a max and a sum. A naive algorithm requires three separate passes over the input vector: first to compute the maximum, then the sum, and finally the normalised outputs.

Here’s what a naive Numpy implementation looks like:
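A minimal sketch along those lines, with explicit loops to make the three passes visible:

import numpy as np

def softmax_naive(x: np.ndarray) -> np.ndarray:
    n = x.shape[0]

    # pass 1: global maximum
    m = -np.inf
    for i in range(n):
        m = max(m, x[i])

    # pass 2: sum of shifted exponentials
    d = 0.0
    for i in range(n):
        d += np.exp(x[i] - m)

    # pass 3: normalised outputs
    out = np.empty_like(x)
    for i in range(n):
        out[i] = np.exp(x[i] - m) / d
    return out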

A recurrent theme in this Triton series is minimising high-latency global memory access. Our current Numpy implementation requires three separate memory reads of the full input vector, which is highly inefficient.

Online Softmax

Fortunately, we can use a clever trick, known as the online softmax, to fuse the max and sum steps, reducing the number of memory reads to 2.

First, we define the sum of exponentials recursively. In the following set of equalities, m_i refers to the maximum over x up to the i-th index.
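Writing d_i for the sum of exponentials over the first i elements, with m_i = max(m_{i-1}, x_i):

d_i = \sum_{k=1}^{i} e^{x_k - m_i}
    = \left(\sum_{k=1}^{i-1} e^{x_k - m_{i-1}}\right) e^{\,m_{i-1} - m_i} + e^{x_i - m_i}
    = d_{i-1}\, e^{\,m_{i-1} - m_i} + e^{x_i - m_i}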

This equality allows us to compute the sum of exponentials iteratively using the maximum value so far. We can leverage it to fuse the first and second loop in the naive implementation and compute the maximum and sum of exponentials iteratively.

Our algorithm therefore becomes a single fused pass that maintains the running maximum m and the running sum of exponentials d, followed by a second pass that normalises every element by d.

This is easily translated to Numpy:
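A sketch of that translation:

import numpy as np

def softmax_online(x: np.ndarray) -> np.ndarray:
    n = x.shape[0]

    # fused pass: running maximum m and running sum of exponentials d
    m = -np.inf
    d = 0.0
    for i in range(n):
        new_m = max(m, x[i])
        d = d * np.exp(m - new_m) + np.exp(x[i] - new_m)
        m = new_m

    # final pass: normalise
    out = np.empty_like(x)
    for i in range(n):
        out[i] = np.exp(x[i] - m) / d
    return out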

Now that we understand the main principles behind the softmax, we’ll implement it in Triton, starting with the simple single-block version and building up to the online, multi-block formulation. In the end, we want our kernel to behave like a PyTorch module and be compatible with autograd.

Unfortunately, from PyTorch’s point of view, Triton kernels behave like black boxes: the operations they perform are not traced by autograd. This requires us to implement the backward pass ourselves and explicitly specify how gradients should be computed. Let’s brush up on our beloved chain rule and derive the softmax gradient.

Gradient

Since the outputs of the softmax are strictly positive, we can use the logarithmic derivative to make the derivation of the gradient easier. Here, we take the derivative of the log of the output and apply the chain rule:
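Writing s = softmax(z) and using \log s_i = z_i - \log \sum_k e^{z_k}, we get two expressions for the same derivative:

\frac{\partial \log s_i}{\partial z_j} = \frac{1}{s_i}\,\frac{\partial s_i}{\partial z_j},
\qquad
\frac{\partial \log s_i}{\partial z_j} = \frac{\partial}{\partial z_j}\left(z_i - \log\sum_k e^{z_k}\right) = \delta_{ij} - s_j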

From there, we rearrange the terms and follow these steps:
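Equating the two expressions above and multiplying by s_i yields the softmax Jacobian:

\frac{1}{s_i}\,\frac{\partial s_i}{\partial z_j} = \delta_{ij} - s_j
\;\;\Longrightarrow\;\;
\frac{\partial s_i}{\partial z_j} = s_i\,(\delta_{ij} - s_j)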

Now assume that we have some upstream gradient, for example generated by a loss function L (e.g. a cross-entropy loss). We get the following expression of the gradient:
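Denoting the upstream gradient by \partial L / \partial s_j and applying the chain rule once more (the last two lines are numbered (9) and (10) to match the references below):

\frac{\partial L}{\partial z_i}
= \sum_j \frac{\partial L}{\partial s_j}\,\frac{\partial s_j}{\partial z_i}
= \sum_j \frac{\partial L}{\partial s_j}\, s_j\,(\delta_{ij} - s_i)

= \sum_j \frac{\partial L}{\partial s_j}\, s_j\,\delta_{ij} \;-\; s_i \sum_j \frac{\partial L}{\partial s_j}\, s_j \qquad (9)

= s_i\left(\frac{\partial L}{\partial s_i} - \sum_j \frac{\partial L}{\partial s_j}\, s_j\right) \qquad (10)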

The simplification of the left term in (9) is due to the fact that δ_ij will only be equal to 1 for the i-th element, collapsing the sum over j to a single term.

Triton Implementation

Single Block Softmax

Now that we worked through the derivation of the gradient, we can write the forward and backward softmax kernels. First, let’s focus on the PyTorch wrapper to understand how the single block implementation works at a high level. Given a 2D input tensor, the forward and backward kernels are going to process all rows in parallel.

For simplicity, we’ll define the BLOCK_SIZE to be large enough to handle all columns at once. Specifically, we’ll set it to the next power of 2 greater than or equal to the number of columns, as required by Triton.

Then, we’ll define our `grid` to be the number of rows (it could potentially also handle a batch dimension).

The PyTorch wrapper for our SoftmaxSingleBlock is a class inheriting from torch.autograd.Function that implements forward and backward. Both methods take a ctx argument, which we’ll use to cache the softmax outputs during the forward pass and reuse them during the backward pass.
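A sketch of what this wrapper can look like. The kernel names softmax_fwd_sb_kernel and softmax_bwd_sb_kernel are placeholders for the Triton kernels shown in the next section, and calculate_settings is the helper discussed below:

import torch

class SoftmaxSingleBlock(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        n_rows, n_cols = x.shape
        y = torch.empty_like(x)
        # one program per row, with a block wide enough to cover all columns
        BLOCK_SIZE, num_warps = calculate_settings(n_cols)
        softmax_fwd_sb_kernel[(n_rows,)](
            x, y, x.stride(0), y.stride(0), n_cols,
            BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps,
        )
        # cache the outputs: the backward pass only needs y and the upstream gradient
        ctx.save_for_backward(y)
        ctx.BLOCK_SIZE, ctx.num_warps = BLOCK_SIZE, num_warps
        return y

    @staticmethod
    def backward(ctx, dy: torch.Tensor) -> torch.Tensor:
        (y,) = ctx.saved_tensors
        dy = dy.contiguous()
        n_rows, n_cols = y.shape
        dx = torch.empty_like(y)
        softmax_bwd_sb_kernel[(n_rows,)](
            dy, y, dx, dy.stride(0), y.stride(0), dx.stride(0), n_cols,
            BLOCK_SIZE=ctx.BLOCK_SIZE, num_warps=ctx.num_warps,
        )
        return dx

softmax_sb = SoftmaxSingleBlock.apply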

Both kernels are pretty straightforward: we start by loading the row inputs using the same syntax as in my previous vector addition article. Notice that BLOCK_SIZE and num_warps are computed using a calculate_settings function. This function comes from the Unsloth library and was reused in other kernel libraries such as LigerKernel (which the kernels in this article are loosely based on). It provides heuristics to tune both variables:

import triton

def calculate_settings(n: int) -> tuple[int, int]:
    MAX_FUSED_SIZE = 65536  # maximum number of columns a single fused block will handle
    BLOCK_SIZE = triton.next_power_of_2(n)
    if BLOCK_SIZE > MAX_FUSED_SIZE:
        # note: we remove this check later in the article to benchmark very wide inputs
        raise RuntimeError(
            f"Cannot launch Triton kernel since n = {n} exceeds "
            f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}."
        )
    num_warps = 4
    if BLOCK_SIZE >= 32768:
        num_warps = 32
    elif BLOCK_SIZE >= 8192:
        num_warps = 16
    elif BLOCK_SIZE >= 2048:
        num_warps = 8
    return BLOCK_SIZE, num_warps

Then, we implement the regular softmax for the forward pass and equation (10) for the backward pass. The only novelty here compared to previous articles is the use of cache modifiers, which tell the compiler how to cache and evict data. For now, we’ll only focus on three cache modifiers:

  • .ca (Cache at all levels): Tells the compiler to load the data in both L1 and L2 cache, suggesting that it might be reused soon. This modifier should be used when the data is small enough to fit into L1 (~128–192KB per SM on an A100) and will likely be accessed repeatedly.
  • .cs (Streaming): Treats the data as streaming; it will be used once and then discarded to free up space in L1.
  • .wb (Write-back): Normal cached write; the data will remain in the cache hierarchy, which is good if the output may be reused.

In the following kernels, we’ll use the .ca modifier for loads, since we perform multiple operations on the loaded data. For storing, we’ll use .cs in the forward pass, since the outputs won’t be immediately reused, and .wb in the backward pass, since in the context of autograd (i.e. the chain rule) the gradient outputs will be consumed by downstream kernels.
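Here is a sketch of both single-block kernels with the cache modifiers placed as described above (the argument layout matches the wrapper sketch from earlier; the exact signatures are my own):

import triton
import triton.language as tl

@triton.jit
def softmax_fwd_sb_kernel(
    x_ptr, y_ptr,
    x_row_stride, y_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    # one program per row, the block covers the whole row
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    # .ca: the row is reused by the max, sum and division below
    x = tl.load(x_ptr + row_idx * x_row_stride + col_offsets,
                mask=mask, other=-float("inf"), cache_modifier=".ca")

    # numerically stable softmax
    x = x - tl.max(x, axis=0)
    num = tl.exp(x)
    y = num / tl.sum(num, axis=0)

    # .cs: the outputs are not reused by this kernel
    tl.store(y_ptr + row_idx * y_row_stride + col_offsets,
             y, mask=mask, cache_modifier=".cs")


@triton.jit
def softmax_bwd_sb_kernel(
    dy_ptr, y_ptr, dx_ptr,
    dy_row_stride, y_row_stride, dx_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    row_idx = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols

    dy = tl.load(dy_ptr + row_idx * dy_row_stride + col_offsets,
                 mask=mask, other=0.0, cache_modifier=".ca")
    y = tl.load(y_ptr + row_idx * y_row_stride + col_offsets,
                mask=mask, other=0.0, cache_modifier=".ca")

    # equation (10): dx_i = y_i * (dy_i - sum_j dy_j * y_j)
    dx = y * (dy - tl.sum(dy * y, axis=0))

    # .wb: the gradients will be consumed by downstream kernels
    tl.store(dx_ptr + row_idx * dx_row_stride + col_offsets,
             dx, mask=mask, cache_modifier=".wb")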

Multi-Block Softmax

Now, let’s take a look at the online formulation of the softmax. In this section, we implement a multi-block variant of the previous kernel. This version uses a BLOCK_SIZE smaller than the number of columns; in other words, we’ll only load a tile of BLOCK_SIZE elements at a time, similar to how we handled tiled GEMM in the last tutorial. Now you might ask: “how do we select the block size?”

This is a great occasion to introduce Triton’s autotune utility. Provided with a list of configurations, autotune will perform a grid search to determine and cache the best configuration for a specific input shape. This process is repeated every time a new input shape is passed to the kernel.

Here, we perform a grid search over the block size and number of warps using the following utility function:

from itertools import product

import triton

# --- Multi Block Tuning ---
BLOCK_SIZES = [256, 512, 1024, 2048, 4096, 8192]
NUM_WARPS = [2, 4, 8, 16]

def get_autotune_config(
    block_sizes: list[int], num_warps: list[int]
) -> list[triton.Config]:
    return [
        triton.Config(kwargs={"BLOCK_SIZE": bs}, num_warps=nw)
        for (bs, nw) in list(product(block_sizes, num_warps))
    ]

We can now decorate our multi-block kernels with autotune and pass the list of configs; key="n_cols" indicates that the optimal config depends on the number of columns of the input.

The implementation of these kernels is conceptually very close to the online softmax we covered before. The main difference is that we iterate over tiles (not over single elements like in Numpy), which requires some adjustments: we add a sum over the tile in the d update, and the backward kernel now requires two passes as well.
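A sketch of the multi-block forward kernel to illustrate the structure (the backward kernel tiles its two loads in the same way); the signature mirrors the single-block sketch:

import triton
import triton.language as tl

@triton.autotune(configs=get_autotune_config(BLOCK_SIZES, NUM_WARPS), key=["n_cols"])
@triton.jit
def softmax_fwd_mb_kernel(
    x_ptr, y_ptr,
    x_row_stride, y_row_stride,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    row_idx = tl.program_id(0)
    x_row_ptr = x_ptr + row_idx * x_row_stride
    y_row_ptr = y_ptr + row_idx * y_row_stride

    # first pass: online maximum m and sum of exponentials d, one tile at a time
    m = tl.full((1,), value=float("-inf"), dtype=tl.float32)
    d = tl.zeros((1,), dtype=tl.float32)
    for tile in range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
        offs = tile * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offs < n_cols
        x = tl.load(x_row_ptr + offs, mask=mask, other=-float("inf"),
                    cache_modifier=".ca")
        new_m = tl.maximum(m, tl.max(x, axis=0))
        # rescale the running sum to the new maximum, then add the tile's contribution
        d = d * tl.exp(m - new_m) + tl.sum(tl.exp(x - new_m), axis=0)
        m = new_m

    # second pass: normalise each tile and store it
    for tile in range(0, tl.cdiv(n_cols, BLOCK_SIZE)):
        offs = tile * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offs < n_cols
        x = tl.load(x_row_ptr + offs, mask=mask, other=-float("inf"),
                    cache_modifier=".ca")
        y = tl.exp(x - m) / d
        tl.store(y_row_ptr + offs, y, mask=mask, cache_modifier=".cs")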

Note: the PyTorch wrapper is exactly the same except we delete the line where BLOCK_SIZE and num_warps are declared (since they are picked by autotune).

Testing and Benchmarking

We can now execute a forward and backward pass with both kernels and ensure they match the PyTorch baselines:

from copy import deepcopy

import torch

def validate_kernel(kernel_fn: callable) -> None:
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    torch.random.manual_seed(0)

    # Generate inputs
    x = torch.randn((256, 512), device=device) # triton input
    x.requires_grad = True
    xt = deepcopy(x) # torch input

    triton_output = kernel_fn(x)
    torch_output = torch.softmax(xt, dim=1)
    torch.testing.assert_close(triton_output, torch_output) # test fwd kernel

    # Setup fake labels
    y = torch.zeros_like(x)
    inds = (torch.arange(0, y.shape[0]), torch.randint(0, 3, (y.shape[0],)))
    y[inds] = 1

    # Define loss and run backward pass
    loss_fn = torch.nn.CrossEntropyLoss()
    loss = loss_fn(torch_output, y)
    loss.backward()

    # Save gradient tensor for later
    torch_xgrad = xt.grad.detach().clone()
    triton_loss = loss_fn(triton_output, y)
    triton_loss.backward()
    torch.testing.assert_close(x.grad, torch_xgrad) # test grad outputs

validate_kernel(softmax_sb)
validate_kernel(softmax_mb)

Finally, we benchmark our implementation against the PyTorch baseline using the following snippet:

# --- Source: Triton softmax tutorial ---
DEVICE = torch.device("cuda")  # the benchmark assumes a CUDA device
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["N"],  # argument names to use as an x-axis for the plot
        x_vals=[
            128 * i for i in range(2, 100)
        ],  # different possible values for `x_name`
        line_arg="provider",  # argument name whose value corresponds to a different line in the plot
        line_vals=[
            "triton_single_block",
            "triton_multi_block",
            "torch",
        ],  # possible values for `line_arg`
        line_names=[
            "Triton_single_block",
            "Triton_multi_block",
            "Torch",
        ],  # label name for the lines
        styles=[("blue", "-"), ("green", "-"), ("red", "-")],
        ylabel="GB/s",  # label name for the y-axis
        plot_name="softmax-performance",  # name for the plot. Used also as a file name for saving the plot.
        args={"M": 4096},  # values for function arguments not in `x_names` and `y_name`
    )
)
def benchmark(M, N, provider):
    x = torch.randn(M, N, device=DEVICE, dtype=torch.float32)
    stream = getattr(torch, DEVICE.type).Stream()
    getattr(torch, DEVICE.type).set_stream(stream)
    if provider == "torch":
        ms = triton.testing.do_bench(lambda: torch.softmax(x, axis=-1))
    if provider == "triton_single_block":
        torch.cuda.synchronize()
        ms = triton.testing.do_bench(lambda: softmax_sb(x))
        torch.cuda.synchronize()
    if provider == "triton_multi_block":
        torch.cuda.synchronize()
        ms = triton.testing.do_bench(lambda: softmax_mb(x))
        torch.cuda.synchronize()
    gbps = lambda ms: 2 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms)

benchmark.run(show_plots=True, print_data=True)

Good news! Our single-block kernel consistently outperforms the PyTorch baseline, while the multi-block variant falls off for inputs with more than 6k columns.

Considering larger inputs, we can make several observations:

  1. The multi-block kernel eventually stabilises around 900GB/s of throughput, surpassing the PyTorch baseline for inputs with more than 30k columns. 
  2. Interestingly, it seems like the multi-block variant will dominate for inputs with more than 60k columns.
  3. Even though we exceed the maximum block size with the single-block variant, the kernel still runs smoothly. This is because Triton automatically manages the block size under the hood: when n_cols is larger than the hardware limit, Triton breaks down the input and iterates over it. However, this appears to be slower than the multi-block approach.

To go further, we could combine both approaches in a single kernel that explicitly selects the optimal kernel based on the input size. This way, we would benefit from the high performance of the single-block kernel for small inputs and the higher throughput of the multi-block variant for inputs with more than 60k columns.
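For example, a minimal sketch of such a dispatcher (the 60k threshold comes from the benchmark above and would need to be tuned for each GPU):

import torch

SINGLE_BLOCK_MAX_COLS = 60_000  # crossover point observed in the benchmark above

def softmax(x: torch.Tensor) -> torch.Tensor:
    # small and medium rows: single-block kernel; very wide rows: multi-block kernel
    if x.shape[-1] <= SINGLE_BLOCK_MAX_COLS:
        return softmax_sb(x)
    return softmax_mb(x)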

This concludes the third episode of this Triton series, thanks again for your support!

In the next article, we’ll leverage the online softmax formulation in the context of Flash Attention.

Until next time! 👋
