
Increasing Transformer Model Efficiency Through Attention Layer Optimization | by Chaim Rand | Nov, 2024


How paying “better” attention can drive ML cost savings

Photograph by Andrew Seaman on Unsplash

Introduced in the landmark 2017 paper “Attention Is All You Need” (Vaswani et al., 2017), the Transformer architecture is widely regarded as one of the most influential scientific breakthroughs of the past decade. At the core of the Transformer is the attention mechanism, a novel approach that enables AI models to comprehend complex structures by focusing on different parts of input sequences based on the task at hand. Originally demonstrated in the world of natural language processing, the Transformer architecture has quickly spread to many other domains, including speech recognition, scene understanding, reinforcement learning, protein structure prediction, and more. However, attention layers are highly resource-intensive, and as these layers become the standard across increasingly large models, the costs associated with their training and deployment have surged. This has created an urgent need for techniques that reduce the computational cost of this core layer, in order to increase the efficiency and scalability of Transformer-based AI models.

In this post, we will explore several tools for optimizing attention in PyTorch. Our focus will be on methods that maintain the accuracy of the attention layer. These will include PyTorch SDPA, FlashAttention, TransformerEngine Attention, FlexAttention, and xFormer attention. We will not consider methods that reduce the computational cost by approximating the attention calculation (e.g., DeepSpeed’s Sparse Attention, Longformer, Linformer, and more). We also will not discuss general optimization techniques that benefit attention performance but are not specific to the attention computation itself (e.g., FP8 training, model sharding, and more).

Importantly, attention optimization is an active area of research, and new methods come out on a fairly regular basis. Our goal is to increase your awareness of some of the existing solutions and to give you a foundation for further exploration and experimentation. The code we share below is intended for demonstrative purposes only. We make no claims regarding its accuracy, optimality, or robustness. Please do not interpret our mention of any platforms, libraries, or optimization techniques as an endorsement for their use. The best options for you will depend greatly on the specifics of your own use case.

Many thanks to Yitzhak Levi for his contributions to this post.

To facilitate our discussion, we build a Vision Transformer (ViT)-backed classification model using the popular timm Python package (version 0.9.7). We will use this model to illustrate the performance impact of various attention kernels.

We start by defining a simplified Transformer block whose attention function is passed in through its constructor. Different attention implementations expect different input tensor formats, so we also include an option for controlling the format. This ensures compatibility with the attention kernel of our choosing.

# general imports
import os, time, functools

# torch imports
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

# timm imports
from timm.models.vision_transformer import VisionTransformer
from timm.layers import Mlp

IMG_SIZE = 224
BATCH_SIZE = 128

# Define ViT settings
NUM_HEADS = 16
HEAD_DIM = 64
DEPTH = 24
PATCH_SIZE = 16
SEQ_LEN = (IMG_SIZE // PATCH_SIZE)**2 # 196

class MyAttentionBlock(nn.Module):
    def __init__(
            self,
            attn_fn,
            format = None,
            dim: int = 768,
            num_heads: int = 12,
            **kwargs  # absorbs the extra block args passed by timm
    ) -> None:
        super().__init__()
        self.attn_fn = attn_fn
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=dim * 4,
        )
        # default layout: q, k, v each of shape (batch, heads, seq, head_dim)
        permute = (2, 0, 3, 1, 4)
        self.permute_attn = functools.partial(torch.transpose,dim0=1,dim1=2)

        if format == 'bshd':
            # q, k, v each of shape (batch, seq, heads, head_dim)
            permute = (2, 0, 1, 3, 4)
            self.permute_attn = nn.Identity()
        self.permute_qkv = functools.partial(torch.permute,dims=permute)

    def forward(self, x_in: torch.Tensor) -> torch.Tensor:
        x = self.norm1(x_in)
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        # permute tensor based on the specified format
        qkv = self.permute_qkv(qkv)
        q, k, v = qkv.unbind(0)
        # use the attention function specified by the user
        x = self.attn_fn(q, k, v)
        # permute output according to the specified format
        x = self.permute_attn(x).reshape(B, N, C)
        x = self.proj(x)
        x = x + x_in
        x = x + self.mlp(self.norm2(x))
        return x
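To make the two layouts concrete, here is a minimal sketch that uses only the standard library. It traces how the packed qkv tensor of shape (B, N, 3, heads, head_dim) is rearranged by each of the two permutation tuples used in MyAttentionBlock. The helper permuted_shape is hypothetical and written only for this illustration. It mimics the shape effect of torch.permute without touching any real tensor.

```python
def permuted_shape(shape, dims):
    """Return the shape torch.permute(x, dims) would produce for x of `shape`."""
    return tuple(shape[d] for d in dims)

B, N, H, D = 128, 196, 16, 64     # batch, tokens, heads, head_dim
packed = (B, N, 3, H, D)          # shape of self.qkv(x).reshape(...)

# Default permutation: dim 0 selects q/k/v, then (batch, heads, seq, head_dim)
default = permuted_shape(packed, (2, 0, 3, 1, 4))
# 'bshd' permutation: dim 0 selects q/k/v, then (batch, seq, heads, head_dim)
bshd = permuted_shape(packed, (2, 0, 1, 3, 4))

# Shapes of each of q, k, v after qkv.unbind(0)
print(default[1:])  # (128, 16, 196, 64)
print(bshd[1:])     # (128, 196, 16, 64)
```

This also explains the output handling in forward. In the default path the attention output is (B, heads, N, head_dim), so dims 1 and 2 must be swapped before reshaping to (B, N, C). In the 'bshd' path the output is already sequence-major and can be reshaped directly.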

We define a randomly generated dataset that we will use to feed our model during training.

# Use random data
class FakeDataset(Dataset):
    def __len__(self):
        return 1000000

    def __getitem__(self, index):
        rand_image = torch.randn([3, IMG_SIZE, IMG_SIZE],
                                 dtype=torch.float32)
        label = torch.tensor(data=index % 1000, dtype=torch.int64)
        return rand_image, label

Next, we define our ViT training function. Our example focuses on a training workload, but optimizing the attention layer is equally important, if not more so, during model inference.

The training function we define accepts the customized Transformer block and a flag that controls the use of torch.compile.

def train_fn(block_fn, compile):
    torch.random.manual_seed(0)
    device = torch.device("cuda:0")
    torch.set_float32_matmul_precision("high")

    # Create dataset and dataloader
    train_set = FakeDataset()
    train_loader = DataLoader(
        train_set, batch_size=BATCH_SIZE,
        num_workers=12, pin_memory=True, drop_last=True)

    model = VisionTransformer(
        img_size=IMG_SIZE,
        patch_size=PATCH_SIZE,
        embed_dim=NUM_HEADS*HEAD_DIM,
        depth=DEPTH,
        num_heads=NUM_HEADS,
        class_token=False,
        global_pool="avg",
        block_fn=block_fn
    ).to(device)

    if compile:
        model = torch.compile(model)

    # Define loss and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters())

    model.train()

    t0 = time.perf_counter()
    summ = 0
    count = 0
    for step, data in enumerate(train_loader):
        # Copy data to GPU
        inputs = data[0].to(device=device, non_blocking=True)
        label = data[1].to(device=device, non_blocking=True)
        with torch.amp.autocast('cuda', enabled=True, dtype=torch.bfloat16):
            outputs = model(inputs)
            loss = criterion(outputs, label)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # Capture step time
        batch_time = time.perf_counter() - t0
        if step > 20:  # Skip first steps (warmup / compilation)
            summ += batch_time
            count += 1
        t0 = time.perf_counter()
        if step > 100:
            break
    print(f'average step time: {summ / count}')

# define compiled and uncompiled variants of our train function
train = functools.partial(train_fn, compile=False)
train_compile = functools.partial(train_fn, compile=True)

In the code block below, we define a PyTorch-native attention function and use it to train our ViT model:

def attn_fn(q, k, v):
    scale = HEAD_DIM ** -0.5
    q = q * scale
    attn = q @ k.transpose(-2, -1)
    attn = attn.softmax(dim=-1)
    x = attn @ v
    return x

block_fn = functools.partial(MyAttentionBlock, attn_fn=attn_fn)

print('Default Attention')
train(block_fn)
print('Compiled Default Attention')
train_compile(block_fn)

We ran this on an NVIDIA H100 with CUDA 12.4 and PyTorch 2.5.1. The uncompiled variant resulted in an average step time of 370 milliseconds (ms), while the compiled variant improved to 242 ms. We will use these results as a baseline for comparison as we consider alternative solutions for performing the attention computation.

One of the easiest ways to boost the performance of our attention layers in PyTorch is to use the scaled_dot_product_attention (SDPA) function. PyTorch SDPA is currently in beta. It consolidates multiple kernel-level optimizations and dynamically selects the most efficient one based on the input’s properties. Supported backends (as of now) include FlashAttention-2, Memory-Efficient Attention, a C++-based Math Attention, and CuDNN. These backends fuse high-level operations together and employ GPU-level optimizations to increase compute efficiency and improve memory utilization.
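Whichever backend SDPA selects, its output should match the quantity our attn_fn computes, softmax(QKᵀ/√d)V, up to floating-point differences. As a point of reference, the sketch below computes that quantity for a single head using plain Python lists. The function name reference_attention is ours, for illustration only. It is meant for checking tiny hand-made cases, not as a stand-in for any SDPA backend.

```python
import math

def softmax(row):
    # subtract the max for numerical stability
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def reference_attention(q, k, v):
    """Single-head attention: q is [Nq][d], k is [Nk][d], v is [Nk][dv]."""
    scale = len(q[0]) ** -0.5
    out = []
    for q_row in q:
        # scaled dot product of this query with every key
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        weights = softmax(scores)
        # weighted sum of the value rows
        out.append([sum(w * v_row[j] for w, v_row in zip(weights, v))
                    for j in range(len(v[0]))])
    return out

q = [[0.0, 0.0]]                 # a zero query scores every key equally
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[2.0, 4.0], [6.0, 8.0]]
print(reference_attention(q, k, v))  # [[4.0, 6.0]] (the mean of the value rows)
```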

SDPA is continuously evolving, and new and improved backend implementations are introduced regularly. Staying up to date with the latest PyTorch releases is key to leveraging the most recent performance improvements. For example, PyTorch 2.5 introduced an updated CuDNN backend featuring a specialized SDPA primitive tailored for training on NVIDIA Hopper architecture GPUs.

In the code block below, we iterate through the list of supported backends and assess the runtime performance of training with each one. We use a helper function, set_sdpa_backend, to select the SDPA backend:

from torch.nn.functional import scaled_dot_product_attention as sdpa

def set_sdpa_backend(backend):
    # disable all backends, then re-enable the requested one(s)
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_math_sdp(False)
    torch.backends.cuda.enable_cudnn_sdp(False)

    if backend in ['flash_sdp','all']:
        torch.backends.cuda.enable_flash_sdp(True)
    if backend in ['mem_efficient_sdp','all']:
        torch.backends.cuda.enable_mem_efficient_sdp(True)
    if backend in ['math_sdp','all']:
        torch.backends.cuda.enable_math_sdp(True)
    if backend in ['cudnn_sdp','all']:
        torch.backends.cuda.enable_cudnn_sdp(True)

for backend in ['flash_sdp', 'mem_efficient_sdp',
                'math_sdp', 'cudnn_sdp']:
    set_sdpa_backend(backend)
    block_fn = functools.partial(MyAttentionBlock,
                                 attn_fn=sdpa)

    print(f'PyTorch SDPA - {backend}')
    train(block_fn)
    print(f'Compiled PyTorch SDPA - {backend}')
    train_compile(block_fn)

We summarize our interim results in the table below:

Step times for various attention functions (lower is better), by Author

The choice of SDPA backend has a noticeable impact on performance in eager mode. With model compilation, however, the compiler's optimizations appear to overshadow the differences between the attention kernels. Once again, we caution against drawing any conclusions from these results, because the performance impact of different attention functions can vary significantly with the specific model and use case.

PyTorch SDPA is a great place to start, but third-party attention kernels can help accelerate your ML workloads further. These alternatives often come with added flexibility, offering a wider range of configuration options for attention. Some may also include optimizations tailored for specific hardware accelerators or newer GPU architectures.

In this section, we will explore some of the available third-party attention kernels and evaluate their potential impact on runtime performance.

FlashAttention-3

PyTorch SDPA supports a FlashAttention backend, but more advanced FlashAttention implementations can be found in the flash-attn library. Here we will explore the FlashAttention-3 beta release, which boasts a speedup of up to 2x compared to FlashAttention-2. Because it is still early in its development, FlashAttention-3 can only be installed directly from the GitHub repository, and its use is limited to certain head dimensions. Additionally, it does not yet support model compilation. In the following code block, we configure our transformer block to use flash-attn-3. We also set the attention input format to “bshd” (batch, sequence, head, depth) to meet the expectations of the library.

# flash attention 3
from flash_attn_interface import flash_attn_func as fa3
# the beta API returns a tuple; the attention output is the first element
attn_fn = lambda q,k,v: fa3(q,k,v)[0]
block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=attn_fn,
                             format='bshd')

print(f'Flash Attention 3')
train(block_fn)

The resulting step time was 240 ms, making it 5% faster than the SDPA flash-attn backend.

Transformer Engine

Transformer Engine (TE) is a specialized library designed to accelerate Transformer models on NVIDIA GPUs. TE is updated regularly with optimizations that leverage the capabilities of the latest NVIDIA hardware and software offerings. This gives users access to specialized kernels long before they are integrated into general-purpose frameworks such as PyTorch.

In the code block below, we use DotProductAttention from TE version 1.11.0. Like PyTorch SDPA, TE supports a number of backends, which are controlled via environment variables. Here we demonstrate the use of the NVTE_FUSED_ATTN backend.

def set_te_backend(backend):
    # must be applied before first use of
    # transformer_engine.pytorch.attention
    os.environ["NVTE_FLASH_ATTN"] = '0'
    os.environ["NVTE_FUSED_ATTN"] = '0'
    os.environ["NVTE_UNFUSED_ATTN"] = '0'
    if backend == 'flash':
        os.environ["NVTE_FLASH_ATTN"] = '1'
    if backend == 'fused':
        os.environ["NVTE_FUSED_ATTN"] = '1'
    if backend == 'unfused':
        os.environ["NVTE_UNFUSED_ATTN"] = '1'

from transformer_engine.pytorch.attention import DotProductAttention
set_te_backend('fused')
attn_fn = DotProductAttention(NUM_HEADS, HEAD_DIM, NUM_HEADS,
                              qkv_format='bshd',
                              # disable masking (default is causal mask)
                              attn_mask_type='no_mask')

block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=attn_fn,
                             format='bshd')

print(f'Transformer Engine Attention')
train(block_fn)
print(f'Compiled Transformer Engine Attention')
train_compile(block_fn)

TE attention resulted in average step times of 243 ms and 204 ms for the eager and compiled model variants, respectively.

XFormer Attention

The memory-efficient backend of PyTorch SDPA is built on an attention kernel provided by the xFormers library. Once again, we can go to the source to benefit from the latest kernel optimizations and the full set of API capabilities. In the following code block, we use the memory_efficient_attention operator from xFormers version 0.0.28.

# xformer memory efficient attention
from xformers.ops import memory_efficient_attention as mea
block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=mea,
                             format='bshd')

print(f'xFormer Attention ')
train(block_fn)
print(f'Compiled xFormer Attention ')
train_compile(block_fn)

The eager model variant resulted in an average step time of 246 ms, making it 10.5% faster than the SDPA memory-efficient kernel. The compiled variant resulted in a step time of 203 ms.

Results

The table below summarizes our experiments:

Step times for various attention functions (lower is better), by Author

The winner in eager mode was flash-attn-3, with an average step time 54% faster than our baseline model (a 1.54x speedup). This translates to a roughly 35% reduction in training costs (240 ms versus 370 ms per step). In compiled mode, the performance of the optimized kernels was more or less equal. The fastest implementations reached 202 ms, a 20% improvement over the baseline experiment.

As mentioned above, the precise impact of these optimizations depends greatly on the model definition. To assess this variability, we reran the experiments with modified settings that increased the attention sequence length to 3136 tokens.

IMG_SIZE = 224
BATCH_SIZE = 8

# Define ViT settings
NUM_HEADS = 12
HEAD_DIM = 64
DEPTH = 6
PATCH_SIZE = 4
SEQ_LEN = (IMG_SIZE // PATCH_SIZE)**2 # 3136

The results are summarized in the table below:

Results for large seqlen (lower is better), by Author

Our immediate observation is that with the larger sequence length, the performance impact of the attention kernels is far more pronounced. Once again, flash-attn-3 came out in front in eager execution mode, this time with a ~5x increase in performance compared to the PyTorch-native function. For the compiled model, the TE kernel broke away from the pack with an overall best step time of 53 ms.
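One way to see why the gap widens with sequence length is to look at memory. The PyTorch-native attn_fn materializes the full score matrix, whose size grows quadratically with the number of tokens N. Fused kernels such as FlashAttention avoid materializing that matrix. The back-of-the-envelope sketch below is our own arithmetic, not a measurement. It counts only a single B×H×N×N bf16 score tensor per layer and ignores the softmax output and the tensors saved for the backward pass. It compares the two configurations used in this post.

```python
def score_matrix_bytes(batch, heads, seq_len, bytes_per_elem=2):
    """Size of one (batch, heads, seq_len, seq_len) score tensor; 2 bytes = bf16."""
    return batch * heads * seq_len * seq_len * bytes_per_elem

# (batch size, heads, sequence length) for the two settings in this post
configs = {
    "default (B=128, H=16, N=196)": (128, 16, 196),
    "long sequence (B=8, H=12, N=3136)": (8, 12, 3136),
}
for name, (b, h, n) in configs.items():
    gb = score_matrix_bytes(b, h, n) / 1e9
    print(f"{name}: {gb:.2f} GB per layer")
```

Under these assumptions, the long-sequence configuration produces a score tensor of about 1.89 GB per layer, versus about 0.16 GB for the default configuration, even though its batch size is 16x smaller.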

So far, we have focused on the standard attention function. Sometimes, however, we may want a variant of the typical attention computation that masks out some values of the intermediate tensors or applies some operation to them. Such changes may prevent us from using the optimized attention kernels covered above. In this section, we discuss some ways to address this:

Leverage Advanced Kernel APIs:
Many optimized attention kernels provide extensive APIs with controls for customizing the attention computation. Before implementing a new solution, explore these APIs to determine whether they already support your desired functionality.

Implement a custom kernel:
If the existing APIs do not meet your needs, you could consider creating your own custom attention implementation. In previous posts (e.g., here) we discussed some of the pros and cons of custom kernel development. Achieving optimal performance can be extremely difficult. If you do go down this path, one approach might be to start with an existing (optimal) kernel and apply minimal changes to integrate the desired modification.

Use FlexAttention:
FlexAttention is a recent addition to PyTorch that lets users implement a wide variety of attention variants without compromising on performance. Let score denote the result of the dot product between the query and key tokens. flex_attention lets you supply either a score_mod function or a block_mask, which is automatically applied to the score tensor. See the documentation as well as the accompanying attention-gym repository for examples of the types of operations the API enables.

FlexAttention works by compiling the score_mod operator into the attention operator, creating a single fused kernel. It also leverages the sparsity of the block_mask to avoid unnecessary computations. The benchmarks reported in the FlexAttention documentation show considerable performance gains for a variety of use cases.

Let’s see both the score_mod and the block_mask in action.
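Conceptually, a score_mod is a function of (score, b, h, q_idx, kv_idx) that is applied to every entry of the score matrix. A mask_mod is a predicate of (b, h, q_idx, kv_idx) that decides which entries participate in the softmax. The dependency-free sketch below spells out those semantics for a single batch entry and head, using small Python lists. It is only a mental model, and the helper name flex_attention_reference is hypothetical. FlexAttention itself compiles these modifications into a fused kernel rather than materializing the score matrix as done here.

```python
import math

def flex_attention_reference(q, k, v, score_mod=None, mask_mod=None, b=0, h=0):
    """Naive single-head model of score_mod / mask_mod semantics."""
    scale = len(q[0]) ** -0.5
    out = []
    for q_idx, q_row in enumerate(q):
        scores = []
        for kv_idx, k_row in enumerate(k):
            s = scale * sum(a * c for a, c in zip(q_row, k_row))
            if score_mod is not None:
                s = score_mod(s, b, h, q_idx, kv_idx)   # modify the score
            if mask_mod is not None and not mask_mod(b, h, q_idx, kv_idx):
                s = -math.inf                            # exclude this pair
            scores.append(s)
        m = max(scores)  # assumes at least one unmasked key per query
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * v_row[j] for w, v_row in zip(weights, v)) / total
                    for j in range(len(v[0]))])
    return out

# Example: a mask that lets each query attend only to the key at its own index
diagonal = lambda b, h, q_idx, kv_idx: q_idx == kv_idx
q = k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
print(flex_attention_reference(q, k, v, mask_mod=diagonal))  # [[1.0, 2.0], [3.0, 4.0]]
```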

Score Mod Example: Soft-Capping with Tanh

Soft-capping is a common technique used to control the size of the logits (e.g., see here). The following code block extends our PyTorch-native attention kernel with soft-capping:

def softcap_attn(q, k, v):
    scale = HEAD_DIM ** -0.5
    q = q * scale
    attn = q @ k.transpose(-2, -1)
    # apply soft-capping
    attn = 30 * torch.tanh(attn/30)
    attn = attn.softmax(dim=-1)
    x = attn @ v
    return x
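To see what the cap does to individual logits, the stdlib-only sketch below applies the same 30 * tanh(x / 30) transform used in softcap_attn to a few sample values. The helper softcap is ours, for illustration. Small logits pass through almost unchanged, while large ones saturate just below ±30.

```python
import math

CAP = 30.0  # same cap value as in softcap_attn

def softcap(x, cap=CAP):
    # smooth, monotonic squashing of x into the range [-cap, cap]
    return cap * math.tanh(x / cap)

for logit in [1.0, 10.0, 30.0, 100.0, 1000.0]:
    print(f"{logit:>7.1f} -> {softcap(logit):.4f}")
```

Because tanh is smooth and monotonic, the ordering of the scores is preserved and gradients remain defined everywhere. Only the extreme values are compressed.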

In the code block below, we train our model first with our PyTorch-native kernel and then with the optimized Flex Attention API. These experiments were run with the 3136-length sequence settings.

# flex attention imports
from torch.nn.attention.flex_attention import (
    create_block_mask,
    create_mask,
    flex_attention
)
compiled_flex = torch.compile(flex_attention)

# score_mod definition
def tanh_softcap(score, b, h, q_idx, kv_idx):
    return 30 * torch.tanh(score/30)

block_fn = functools.partial(MyAttentionBlock, attn_fn=softcap_attn)

print(f'Attention with Softcap')
train(block_fn)
print(f'Compiled Attention with Softcap')
train_compile(block_fn)

flex_fn = functools.partial(flex_attention, score_mod=tanh_softcap)
compiled_flex_fn = functools.partial(compiled_flex, score_mod=tanh_softcap)

block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=flex_fn)
compiled_block_fn = functools.partial(MyAttentionBlock,
                                      attn_fn=compiled_flex_fn)

# eager model: use the separately compiled flex_attention kernel;
# compiled model: torch.compile fuses flex_fn as part of the full graph
print(f'Flex Attention with Softcap')
train(compiled_block_fn)
print(f'Compiled Flex Attention with Softcap')
train_compile(block_fn)

The results of the experiments are captured in the table below:

Soft-cap step time results (lower is better), by Author

The impact of the Flex Attention kernel is clearly evident, delivering performance boosts of roughly 3.5x in eager mode and 1.5x in compiled mode.

Mask Mod Example: Neighborhood Masking

We assess the mask_mod functionality by applying a sparse mask to our attention score. Recall that each token in our sequence represents a patch in our 2D input image. We modify our kernel so that each token attends only to tokens whose row and column positions in the 2-D token array each differ from its own by less than 5. This corresponds to a 9×9 neighborhood centered on the token.

# convert the token id to a 2d index
def seq_indx_to_2d(idx):
    n_row_patches = IMG_SIZE // PATCH_SIZE
    r_ind = idx // n_row_patches
    c_ind = idx % n_row_patches
    return r_ind, c_ind

# only attend to tokens whose row and column offsets are both < 5,
# i.e., a 9x9 window around each token in our 2D token array
def mask_mod(b, h, q_idx, kv_idx):
    q_r, q_c = seq_indx_to_2d(q_idx)
    kv_r, kv_c = seq_indx_to_2d(kv_idx)
    return torch.logical_and(torch.abs(q_r-kv_r)<5, torch.abs(q_c-kv_c)<5)
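block_mask works at the granularity of tiles, so its benefit depends on how many (query block, key block) tiles contain at least one unmasked pair. Tiles that are entirely masked can be skipped. The rough pure-Python estimate below counts those tiles for a neighborhood mask like the one above. It assumes square 128×128 tiles, which is our understanding of create_block_mask's default BLOCK_SIZE; please verify this against the documentation for your PyTorch version. The helper active_tile_fraction is hypothetical and written only for this estimate.

```python
def active_tile_fraction(grid, radius, block=128):
    """Fraction of (q_block, kv_block) tiles with at least one unmasked pair.

    Tokens are laid out row-major on a grid x grid patch array. Query q may
    attend to key kv when both row and column differ by at most `radius`
    (radius=4 matches the `< 5` comparisons in mask_mod).
    """
    n = grid * grid
    num_blocks = -(-n // block)  # ceiling division
    active = 0
    for qb in range(num_blocks):
        kv_blocks = set()
        for q in range(qb * block, min((qb + 1) * block, n)):
            r, c = divmod(q, grid)  # same mapping as seq_indx_to_2d
            # mark the key block of every key in the (clipped) neighborhood
            for kr in range(max(0, r - radius), min(grid, r + radius + 1)):
                for kc in range(max(0, c - radius), min(grid, c + radius + 1)):
                    kv_blocks.add((kr * grid + kc) // block)
        active += len(kv_blocks)
    return active / (num_blocks * num_blocks)

# 56 x 56 patches = 3136 tokens, as in the long-sequence settings
print(f"active tiles: {active_tile_fraction(56, 4):.1%}")
```

A fraction well below 100% means most tiles can be skipped outright. This is one plausible contributor to the speedups reported below, though the estimate alone does not prove it.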

As a baseline for our experiment, we use PyTorch SDPA, which supports passing in an attention mask. The following block includes the masked SDPA experiment followed by the Flex Attention implementation:

# materialize the mask to use in SDPA
mask = create_mask(mask_mod, 1, 1, SEQ_LEN, SEQ_LEN, device='cuda')

set_sdpa_backend('all')
masked_sdpa = functools.partial(sdpa, attn_mask=mask)
block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=masked_sdpa)
print(f'Masked SDPA Attention')
train(block_fn)
print(f'Compiled Masked SDPA Attention')
train_compile(block_fn)

block_mask = create_block_mask(mask_mod, None, None, SEQ_LEN, SEQ_LEN)
flex_fn = functools.partial(flex_attention, block_mask=block_mask)
compiled_flex_fn = functools.partial(compiled_flex, block_mask=block_mask)

block_fn = functools.partial(MyAttentionBlock,
                             attn_fn=flex_fn)
compiled_block_fn = functools.partial(MyAttentionBlock,
                                      attn_fn=compiled_flex_fn)

print(f'Masked Flex Attention')
train(compiled_block_fn)
print(f'Compiled Masked Flex Attention')
train_compile(block_fn)

The results of the experiments are captured below:

Masked attention step time results (lower is better), by Author

Once again, Flex Attention offers a considerable performance boost, amounting to 2.19x in eager mode and 2.59x in compiled mode.

Flex Attention Limitations

Although we have demonstrated the power and potential of Flex Attention, there are a few limitations that should be noted:

  1. Limited Scope of Modifications: With Flex Attention you can (as of the time of this writing) only modify the attention score, i.e., the result of the dot product between the query and key tokens. It does not support changes at other stages of the attention computation.
  2. Dependency on torch.compile: Given the reliance on torch.compile, great care must be taken to avoid excessive recompilations, which could greatly degrade runtime performance. For instance, while the support for Document Masking is very compelling, it will only perform as expected if the sum of the lengths of all of the documents remains fixed.
  3. No Support for Trainable Parameters in score_mod: At the time of this writing, Flex Attention does not support a score_mod implementation that includes trainable parameters. For example, the documentation highlights support for relative position encodings, but these are commonly implemented with trainable parameters (rather than fixed values), which cannot currently be accommodated.

In the face of these limitations, we can fall back on one of the other optimization opportunities discussed above.

As reliance on transformer architectures and attention layers in ML models grows, so does the need for tools and techniques that optimize these components. In this post, we explored a number of attention kernel variants, each with its own properties, capabilities, and limitations. Importantly, one size does not fit all: different models and use cases will warrant different kernels and different optimization techniques. This underscores the importance of having a wide variety of tools and techniques for optimizing attention layers.

In a future post, we hope to explore attention layer optimization further by applying some of the tools we discussed to the challenge of handling variable-sized input sequences. Stay tuned…
