This post is part of a series of posts on the topic of analyzing and optimizing PyTorch models. Throughout the series, we have advocated for using the PyTorch Profiler in AI model development and demonstrated the potential impact of performance optimization on the speed and cost of running AI/ML workloads. One common phenomenon we have seen is how seemingly innocent code can hamper runtime performance. In this post, we explore some of the penalties associated with the naive use of variable-shaped tensors, that is, tensors whose shape depends on preceding computations and/or inputs. Although it is not always possible, the use of variable-shaped tensors can sometimes be avoided, though this may come at the expense of additional compute and/or memory. We will demonstrate the tradeoffs of these alternatives on a toy implementation of data sampling in PyTorch.
Three Downsides of Variable-Shaped Tensors
We motivate the discussion by presenting three disadvantages to the use of variable-shaped tensors:
Host-Device Sync Events
In an ideal scenario, the CPU and GPU run in parallel in an asynchronous manner. The CPU continuously feeds the GPU with input samples, allocates the required GPU memory, and loads GPU compute kernels. Meanwhile, the GPU executes the loaded kernels on the provided inputs using the allocated memory. The presence of dynamic-shaped tensors throws a wrench into this parallelism. To allocate the appropriate amount of memory, the CPU must wait for the GPU to report the tensor's shape, and then the GPU must wait for the CPU to allocate the memory and proceed with the kernel loading. The overhead of this sync event can cause a drop in GPU utilization and slow down execution.
We saw an example of this in part three of this series when we studied a naive implementation of the common cross-entropy loss that included calls to torch.nonzero and torch.unique. Both APIs return tensors with shapes that are dynamic and dependent on the contents of the input. When these functions are run on the GPU, a host-device synchronization event occurs. In the case of the cross-entropy loss, we discovered the inefficiency through the use of PyTorch Profiler and were able to easily overcome it with an alternative implementation that avoided the use of variable-shaped tensors and demonstrated much better runtime performance.
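To make the data dependence concrete, here is a small sketch comparing torch.nonzero with the static-shape torch.count_nonzero. The helper name output_shapes is ours, for illustration only. On a CUDA device, the sketch also enables torch.cuda.set_sync_debug_mode("warn"), which asks PyTorch to warn about synchronizing CUDA calls; we would expect the torch.nonzero call to be flagged.

```python
import torch


def output_shapes(labels: torch.Tensor):
    """Compare the output shape of a dynamic-shape op with a static-shape op."""
    mask = labels == 1
    # torch.nonzero: the output length equals the number of True entries,
    # so it is known only after the data itself has been inspected
    dynamic = torch.nonzero(mask, as_tuple=True)[0]
    # torch.count_nonzero: always a 0-dim tensor, whatever the contents
    static = torch.count_nonzero(mask)
    return tuple(dynamic.shape), tuple(static.shape)


a = torch.tensor([1, 0, 1, 1])
b = torch.tensor([0, 0, 1, 0])
print(output_shapes(a))  # ((3,), ())
print(output_shapes(b))  # ((1,), ())

if torch.cuda.is_available():
    a_gpu = a.cuda()  # copy first so the transfer itself is not flagged
    # "warn" mode: PyTorch warns whenever a CUDA op forces the CPU to wait
    torch.cuda.set_sync_debug_mode("warn")
    output_shapes(a_gpu)
    torch.cuda.set_sync_debug_mode("default")
```

The shape of the torch.nonzero result changes with the label values, while the torch.count_nonzero result is always 0-dimensional. That difference is what lets the second op run without a host-device sync.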
Graph Compilation
In a recent post we explored the performance benefits of applying just-in-time (JIT) compilation using the torch.compile API. One of our observations was that graph compilation provided much better results when the graph was static. The presence of dynamic shapes in the graph limits the extent of optimization via compilation: in some cases, compilation fails completely; in others, it yields smaller performance gains. The same implications also apply to other forms of graph compilation, such as XLA, ONNX, OpenVINO, and TensorRT.
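One way to check whether dynamic shapes are getting in the way of compilation is to request a single, break-free graph. The sketch below uses illustrative function names of our own. It compiles with fullgraph=True, which turns any graph break into an error, and with the lightweight "eager" backend, which exercises only graph capture and skips code generation. Under default TorchDynamo settings, we would expect the torch.nonzero version to fail and the mask-based version to succeed.

```python
import torch


def dynamic_fn(x):
    # selection via torch.nonzero: the index length depends on the data
    idx = torch.nonzero(x > 0, as_tuple=True)[0]
    return x[idx].sum()


def static_fn(x):
    # same result with a fixed-shape mask: zero out non-positive entries
    return torch.where(x > 0, x, 0.0).sum()


def compiles_to_single_graph(fn, *args):
    """Return True if torch.compile captures fn as one graph (no graph breaks)."""
    # fullgraph=True raises instead of silently splitting the graph;
    # the "eager" backend runs the captured graph without code generation
    compiled = torch.compile(fn, backend="eager", fullgraph=True)
    try:
        compiled(*args)
        return True
    except Exception:
        return False


x = torch.randn(1000)
print(compiles_to_single_graph(dynamic_fn, x))  # expected: False
print(compiles_to_single_graph(static_fn, x))   # expected: True
```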
Data Batching
Another optimization we have encountered in several of our posts (e.g., here) is sample-batching. Batching improves performance in two primary ways:
- Reducing overhead of kernel loading: Rather than loading the GPU kernels required for the computation pipeline once per input sample, the CPU can load the kernels once per batch.
- Maximizing parallelization across compute units: GPUs are highly parallel compute engines. The more we are able to parallelize computation, the more we can saturate the GPU and increase its utilization. By batching we can potentially increase the degree of parallelization by a factor of the batch size.
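As a toy illustration of the two benefits listed above, consider a matrix-multiply workload; it is ours and not part of the sampling example. The two functions below compute the same result. The first issues one matmul per sample from a Python loop, while the second issues a single matmul over the whole batch.

```python
import torch


def per_sample(xs, w):
    # one matrix multiply (and its kernel launches) per input sample
    return torch.stack([x @ w for x in xs])


def batched(xs, w):
    # a single matrix multiply covers the whole batch
    return xs @ w


torch.manual_seed(0)
xs = torch.randn(32, 128, 64)  # 32 samples, each 128 x 64
w = torch.randn(64, 16)
assert torch.allclose(per_sample(xs, w), batched(xs, w), atol=1e-4)
```

Batching like this is only possible because every sample has the same shape, so the samples can be stacked into one tensor.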
Despite their downsides, the use of variable-shaped tensors is often unavoidable. But sometimes we can modify our model implementation to circumvent them. Sometimes these changes will be straightforward (as in the cross-entropy loss example). Other times they may require some creativity in coming up with a different sequence of fixed-shape PyTorch APIs that provide the same numerical result. Often, this effort can deliver meaningful rewards in runtime and costs.
In the next sections, we will study the use of variable-shaped tensors in the context of the data sampling operation. We will start with a trivial implementation and analyze its performance. We will then propose a GPU-friendly alternative that avoids the use of variable-shaped tensors.
To compare our implementations, we will use an Amazon EC2 g6e.xlarge instance with an NVIDIA L40S GPU, running an AWS Deep Learning AMI (DLAMI) with PyTorch (2.8). The code we will share is intended for demonstration purposes. Please do not rely on it for accuracy or optimality. Please do not interpret our mention of any framework, library, or platform as an endorsement of its use.
Sampling in AI Model Workloads
In the context of this post, sampling refers to the selection of a subset of items from a large set of candidates for the purposes of computational efficiency, balancing of datatypes, or regularization. Sampling is common in many AI/ML models, such as detection, ranking, and contrastive learning systems.
We define a simple variation of the sampling problem: Given a list of N tensors, each with a binary label, we are asked to return a subset of K tensors containing both positive and negative examples, in random order. If the input list contains enough samples of each label (K/2), the returned subset should be evenly split. If it lacks enough samples of one label, the shortfall should be filled with random samples of the other label.
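The quota rule can be checked with plain arithmetic. The helper below is a hypothetical function (sample_counts), not part of the post's code. It computes how many positives and negatives to draw for a subset of size K, following the fill-in rule just described. The extra min() caps are our own addition; they only matter when the input holds fewer than K samples.

```python
def sample_counts(n_pos: int, n_neg: int, k: int) -> tuple[int, int]:
    """Return (num_pos, num_neg) to draw for a subset of size k."""
    half = k // 2
    num_pos = min(n_pos, half)
    num_neg = min(n_neg, half)
    if num_neg < half:
        # too few negatives: fill the remainder with positives
        num_pos = min(n_pos, k - num_neg)
    else:
        # enough negatives: they also cover any missing positives
        num_neg = min(n_neg, k - num_pos)
    return num_pos, num_neg


print(sample_counts(5000, 5000, 1000))  # (500, 500)
print(sample_counts(300, 9700, 1000))   # (300, 700)
print(sample_counts(9900, 100, 1000))   # (900, 100)
```

For example, with K = 1,000 and only 300 positives, the 200 missing positive slots are filled with negatives.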
The code block below contains a PyTorch implementation of our sampling function. The implementation is inspired by the popular Detectron2 library (e.g., see here and here). For the experiments in this post, we will fix the sampling ratio to 1:10.
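import torch

INPUT_SAMPLES = 10000
SUB_SAMPLE = INPUT_SAMPLES // 10
FEATURE_DIM = 16

def sample_data(input_array, labels):
    device = labels.device
    positive = torch.nonzero(labels == 1, as_tuple=True)[0]
    negative = torch.nonzero(labels == 0, as_tuple=True)[0]
    num_pos = min(positive.numel(), SUB_SAMPLE//2)
    num_neg = min(negative.numel(), SUB_SAMPLE//2)
    # NOTE: reconstructed from the sampling rules above (the original listing is cut off here)
    if num_neg < SUB_SAMPLE//2:
        num_pos = SUB_SAMPLE - num_neg  # too few negatives: fill with positives
    else:
        num_neg = SUB_SAMPLE - num_pos  # negatives cover any missing positives
    # pick random positives and negatives, then shuffle them together
    pos_idxs = positive[torch.randperm(positive.numel(), device=device)[:num_pos]]
    neg_idxs = negative[torch.randperm(negative.numel(), device=device)[:num_neg]]
    sampled_idxs = torch.cat([pos_idxs, neg_idxs])
    sampled_idxs = sampled_idxs[torch.randperm(sampled_idxs.numel(), device=device)]
    return input_array[sampled_idxs], labels[sampled_idxs]

Performance Analysis With PyTorch Profiler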
Even when not immediately obvious, the use of dynamic shapes is easily identifiable in the PyTorch Profiler Trace view. We use the following function to enable PyTorch Profiler:
def profile(fn, input, labels):

    def export_trace(p):
        p.export_chrome_trace(f"{fn.__name__}.json")

    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU,
                    torch.profiler.ProfilerActivity.CUDA],
        with_stack=True,
        schedule=torch.profiler.schedule(wait=0, warmup=10, active=5),
        on_trace_ready=export_trace
    ) as prof:
        for _ in range(20):
            fn(input, labels)
            torch.cuda.synchronize()  # explicit sync for trace readability
            prof.step()

# create random input
input_samples = torch.randn((INPUT_SAMPLES, FEATURE_DIM), device='cuda')
labels = torch.randint(0, 2, (INPUT_SAMPLES,),
                       device='cuda', dtype=torch.int64)

# run with profiler
profile(sample_data, input_samples, labels)

The image below was captured with an input size of ten million samples. It clearly shows the presence of sync events coming from the torch.nonzero call, as well as the corresponding drops in GPU utilization:

The use of torch.nonzero in our implementation is not ideal, but can it be avoided?
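Besides the Trace view, the profiler's aggregated statistics offer a quick text-based way to confirm which ops a function calls. The sketch below profiles a single call and counts occurrences of selected ops via key_averages(). The helper name op_counts and the choice of op names are ours, for illustration. On a GPU run, any occurrence of aten::nonzero marks a candidate host-device sync worth inspecting in the trace.

```python
import torch


def op_counts(fn, *args, names=("aten::nonzero", "aten::count_nonzero")):
    """Profile a single call of fn and report how many times each named op ran."""
    activities = [torch.profiler.ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(torch.profiler.ProfilerActivity.CUDA)
    with torch.profiler.profile(activities=activities) as prof:
        fn(*args)
    # key_averages() aggregates the recorded events by op name
    counts = {evt.key: evt.count for evt in prof.key_averages()}
    return {name: counts.get(name, 0) for name in names}


demo_labels = torch.randint(0, 2, (1000,))
print(op_counts(lambda l: torch.nonzero(l == 1), demo_labels))
print(op_counts(lambda l: torch.count_nonzero(l == 1), demo_labels))
```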
A GPU-Friendly Data Sampler
We propose an alternative implementation of our sampling function that replaces the dynamic torch.nonzero function with a creative combination of the static torch.count_nonzero, torch.topk, and other APIs:
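def opt_sample_data(input, labels):
    pos_mask = labels == 1
    neg_mask = labels == 0
    num_pos_idxs = torch.count_nonzero(pos_mask, dim=-1)
    num_neg_idxs = torch.count_nonzero(neg_mask, dim=-1)
    half_samples = labels.new_full((), SUB_SAMPLE // 2)
    num_pos = torch.minimum(num_pos_idxs, half_samples)
    num_neg = torch.minimum(num_neg_idxs, half_samples)
    # NOTE: the original listing is cut off here; the lines through the
    # batching check below are reconstructed from how their results are used
    num_pos = torch.where(
        num_neg < SUB_SAMPLE // 2,
        SUB_SAMPLE - num_neg,
        num_pos
    )
    num_neg = SUB_SAMPLE - num_pos
    # random score in [0, 1) per entry; -1 marks entries of the other label
    rand = torch.rand(labels.shape, device=labels.device)
    pos_rand = torch.where(pos_mask, rand, -1)
    neg_rand = torch.where(neg_mask, rand, -1)
    # the top-k scores form a random, fixed-size subset of positive entries
    top_pos_rand, top_pos_idx = torch.topk(pos_rand, k=SUB_SAMPLE)
    arange = torch.arange(SUB_SAMPLE, device=labels.device)
    if labels.dim() > 1:
        # unsqueeze to support batched input
        arange = arange.unsqueeze(0)
        num_pos = num_pos.unsqueeze(-1)
        num_neg = num_neg.unsqueeze(-1)
    # invalidate picks beyond the positive quota
    top_pos_rand = torch.where(arange >= num_pos, -1, top_pos_rand)
    # repeat for neg entries
    top_neg_rand, top_neg_idx = torch.topk(neg_rand, k=SUB_SAMPLE)
    top_neg_rand = torch.where(arange >= num_neg, -1, top_neg_rand)
    # combine and mix together positive and negative idxs
    cat_rand = torch.cat([top_pos_rand, top_neg_rand], dim=-1)
    cat_idx = torch.cat([top_pos_idx, top_neg_idx], dim=-1)
    topk_rand_idx = torch.topk(cat_rand, k=SUB_SAMPLE)[1]
    sampled_idxs = torch.gather(cat_idx, dim=-1, index=topk_rand_idx)
    # expand the index across the feature dimension; an unexpanded (..., K, 1)
    # index would make gather return only the first feature column
    feat_idx = sampled_idxs.unsqueeze(-1).expand(*sampled_idxs.shape, input.size(-1))
    sampled_input = torch.gather(input, dim=-2, index=feat_idx)
    sampled_labels = torch.gather(labels, dim=-1, index=sampled_idxs)
    return sampled_input, sampled_labels

Clearly, this function requires more memory and more operations than our first implementation. The question is: Do the performance benefits of a static, synchronization-free implementation outweigh the extra cost in memory and compute?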
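The heart of opt_sample_data is a trick for drawing a random subset of fixed size. Every eligible entry gets a random score, ineligible entries are pushed to -1, and the top-k scores are kept. The standalone sketch below isolates that step; fixed_size_random_subset is our illustrative name, not a PyTorch API.

```python
import torch


def fixed_size_random_subset(mask: torch.Tensor, k: int):
    """Select k indices at random, preferring positions where mask is True.

    Returns (idx, valid): idx holds k indices, and valid flags which of them
    point at True entries. Both outputs have shape (..., k), regardless of
    how many entries of mask are True.
    """
    scores = torch.rand(mask.shape, device=mask.device)
    # True entries keep a random score in [0, 1); False entries get -1,
    # so topk picks True entries first, in random order
    scores = torch.where(mask, scores, -1.0)
    top_scores, idx = torch.topk(scores, k=k, dim=-1)
    return idx, top_scores >= 0


mask = torch.tensor([True, False, True, False, False, True, False, False])
idx, valid = fixed_size_random_subset(mask, k=4)
print(idx.shape, valid.tolist())  # torch.Size([4]) [True, True, True, False]
```

Because torch.topk sorts by the random scores, the selected indices also come out in random order. This is why opt_sample_data can apply the same trick a second time to shuffle the combined positive and negative picks.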
To assess the tradeoffs between the two implementations, we introduce the following benchmarking utility:
def benchmark(fn, input, labels):
    # warm-up
    for _ in range(20):
        _ = fn(input, labels)

    iters = 100
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        _ = fn(input, labels)
    end.record()
    torch.cuda.synchronize()
    avg_time = start.elapsed_time(end) / iters
    print(f"{fn.__name__} average step time: {(avg_time):.4f} ms")

benchmark(sample_data, input_samples, labels)
benchmark(opt_sample_data, input_samples, labels)

The following table compares the average runtime of each of the implementations for a variety of input sample sizes:

For most of the input sample sizes, the overhead of the host-device sync event is comparable to, or lower than, the additional compute of the static implementation. Disappointingly, we only see a major benefit from the sync-free alternative when the input sample size reaches ten million. Sample sizes that large are uncommon in AI/ML settings. But we are not inclined to give up so easily. As noted above, the static implementation enables other optimizations, such as graph compilation and input batching.
Graph Compilation
Unlike the original function, which fails to compile, our static implementation is fully compatible with torch.compile:

benchmark(torch.compile(opt_sample_data), input_samples, labels)

The following table includes the runtimes of our compiled function:

The results are significantly better — providing a 70–75 percent boost over the original sampler implementation in the 1–10 thousand range. But we still have one more optimization up our sleeve.
Maximizing Performance with Batched Input
Because the original implementation contains variable-shaped operations, it cannot handle batched input directly. To process a batch, we have no choice but to apply it to each input individually, in a Python loop:
BATCH_SIZE = 32

def batched_sample_data(inputs, labels):
    sampled_inputs = []
    sampled_labels = []
    for i in range(inputs.size(0)):
        inp, lab = sample_data(inputs[i], labels[i])
        sampled_inputs.append(inp)
        sampled_labels.append(lab)
    return torch.stack(sampled_inputs), torch.stack(sampled_labels)

In contrast, our optimized function supports batched inputs as is, with no changes necessary.
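The same limitation appears if we try to vectorize a per-sample function automatically. In the sketch below, the helper names are illustrative. It uses torch.func.vmap, PyTorch's automatic batching transform. We would expect vmap to reject a function that returns a variable-shaped tensor, while a fixed-shape version batches without issue.

```python
import torch
from torch.func import vmap


def positive_idxs(labels):
    # variable-shaped output: one entry per positive label
    return torch.nonzero(labels == 1)


def positive_count(labels):
    # fixed-shaped output: a 0-dim count
    return (labels == 1).sum()


def vmap_works(fn, batch):
    """Return True if vmap can batch fn over the leading dimension of batch."""
    try:
        vmap(fn)(batch)
        return True
    except RuntimeError:
        return False


batch = torch.randint(0, 2, (32, 100))
print(vmap_works(positive_idxs, batch))   # expected: False
print(vmap_works(positive_count, batch))  # expected: True
```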
input_batch = torch.randn((BATCH_SIZE, INPUT_SAMPLES, FEATURE_DIM),
                          device='cuda')
labels = torch.randint(0, 2, (BATCH_SIZE, INPUT_SAMPLES),
                       device='cuda', dtype=torch.int64)

benchmark(batched_sample_data, input_batch, labels)
benchmark(opt_sample_data, input_batch, labels)
benchmark(torch.compile(opt_sample_data), input_batch, labels)

The table below compares the step times of our sampling functions on a batch size of 32:

Now the results are definitive: by using a static implementation of the data sampler, we are able to boost performance by 2X to 52X(!!) relative to the variable-shaped option, depending on the input sample size.
Note that although our experiments were run on a GPU device, the model compilation and input batching optimizations also apply to a CPU environment. Thus, avoiding variable shapes could have implications for AI/ML model performance on CPU as well.
Summary
The optimization process we demonstrated in this post generalizes beyond the specific case of data sampling:
- Discovery via Performance Profiling: Using the PyTorch Profiler we were able to identify drops in GPU utilization and discover their source: the presence of variable-shaped tensors resulting from the torch.nonzero operation.
- An Alternate Implementation: Our profiling findings allowed us to develop an alternative implementation that accomplished the same goal while avoiding the use of variable-shaped tensors. However, this step came at the cost of additional compute and memory overhead. As seen in our initial benchmarks, the sync-free alternative demonstrated worse performance on common input sizes.
- Unlocking Further Potential for Optimization: The true breakthrough came because the static-shaped implementation was compilation-friendly and supported batching. These optimizations provided performance gains that dwarfed the initial overhead, leading to a 2x to 52x speedup over the original implementation.
Naturally, not all stories will end as happily as ours. In many cases, we may come across PyTorch code that performs poorly on the GPU but does not have an alternative implementation, or it may have one that requires significantly more compute resources. However, given the potential for meaningful gains in performance and reductions in cost, the process of identifying runtime inefficiencies and exploring alternative implementations is an essential part of AI/ML development.