Speed Up PyTorch With Custom Kernels. But It Gets Progressively Darker
PyTorch offers remarkable flexibility, allowing you to code complex GPU-accelerated operations in a matter of seconds. However, this convenience comes at a cost. PyTorch executes your code eagerly, one operation at a time, which often leaves performance on the table. This translates into slower model training, which hurts the iteration cycle of your experiments, your team's productivity, your compute bill, and so on.
In this post, I’ll explore three strategies for accelerating your PyTorch operations. Each method uses softmax
as our “Hello World” demonstration, but you can swap it with any function you like, and the discussed methods would still apply.
We’ll begin with torch.compile
, move on to writing a custom Triton kernel, and finally dive into designing a CUDA kernel.
So, this post may get complicated, but bear with me.
torch.compile — A Quick Way to Boost Performance? Yes.
torch.compile is a relatively new API in PyTorch that uses runtime graph capture and kernel fusion under the hood. With a single decorator, you can often see speed improvements without significant changes to your code.
Put simply, we can speed up calculations by merging operations into a single GPU kernel, which removes the overhead of separate GPU calls. Or, even better, the compiler can replace a whole chain of operations with one equivalent operation! Such optimizations are not possible in PyTorch's regular (eager) execution mode, which simply runs each operation exactly as it is called in the code.
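For instance, a chain of pointwise operations like the sketch below can typically be fused into a single GPU kernel instead of launching one kernel per op (a minimal illustration with made-up function names; the actual fusion decisions are up to the compiler backend):

import torch

def bias_gelu_scale(x, bias):
    # Three pointwise ops: in eager mode each is a separate kernel launch,
    # while the compiler is free to fuse them into one.
    y = x + bias
    y = torch.nn.functional.gelu(y)
    return y * 0.5

compiled_chain = torch.compile(bias_gelu_scale)

x = torch.randn(1024, 1024, device="cuda")
bias = torch.randn(1024, device="cuda")
out = compiled_chain(x, bias)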
Softmax Implementation with torch.compile
Below is a simple example showing how to implement and compile a softmax function using torch.compile
. Replace it in your model’s forward pass, and your code (hopefully) runs faster.
import torch

# Our softmax function in PyTorch land
def softmax_pytorch(x):
    # Avoid numerical instability by subtracting max
    x_max = torch.max(x, dim=-1, keepdim=True).values
    x_exp = torch.exp(x - x_max)
    return x_exp / torch.sum(x_exp, dim=-1, keepdim=True)

# Let's compile it with torch.compile
@torch.compile
def compiled_softmax(x):
    return softmax_pytorch(x)

if __name__ == "__main__":
    # Example usage:
    input_tensor = torch.randn((2, 4), device="cuda")
    output = compiled_softmax(input_tensor)

    print("Input:", input_tensor)
    print("Compiled Softmax Output:", output)
Pros:
- One line to enable the compiler.
- No black magic rituals needed (except for the dynamic shapes maybe).
Cons:
- The first pass can be slower while it compiles; afterwards, it picks up speed (see the timing sketch after this list).
- Doesn’t always produce dramatic speed-ups for all models and can occasionally break if your code is too creative.
- Still has problems with handling dynamic shapes.
Debugging these issues is a whole other article.
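To see the warm-up cost from the first con for yourself, here is a minimal timing sketch. It assumes the compiled_softmax from the example above and a CUDA device; the exact numbers will vary by GPU and model:

import time
import torch

def timed(fn, *args):
    # Wall-clock timing with explicit GPU synchronisation
    torch.cuda.synchronize()
    start = time.perf_counter()
    out = fn(*args)
    torch.cuda.synchronize()
    return out, time.perf_counter() - start

x = torch.randn(4096, 4096, device="cuda")

_, t_first = timed(compiled_softmax, x)   # includes compilation time
_, t_second = timed(compiled_softmax, x)  # reuses the compiled kernel
print(f"first call: {t_first:.3f} s, second call: {t_second:.6f} s")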
Triton — Write GPU Kernels With a Pythonic Breeze
Why Use Triton?
Triton is a language that compiles to efficient GPU kernels while letting you write Pythonic code. It’s used under the hood of PyTorch’s dynamo/inductor stack, but you can also write your own custom ops! For many matrix/tensor operations — like softmax — you can get huge speed-ups. Because why wait for official PyTorch kernels when you can write your own?
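To get a feel for how Pythonic it is before we tackle softmax, here is a minimal element-wise addition kernel, a sketch following the standard Triton tutorial pattern (the function names are mine):

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program instance handles one contiguous block of elements
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = out.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=1024)
    return out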
Softmax in Triton
Here's a minimal snippet that shows how we might do a naive softmax forward in Triton. I'll keep it short and sweet for demonstration. In a real project, you'd likely do more advanced tiling and block management. Check out the official Triton tutorials!
import torch
import triton
import triton.language as tl
@triton.autotune(
    configs=[
        triton.Config(
            kwargs=dict(
                BLOCK_SIZE_ROWS=BLOCK_SIZE_ROWS,
                num_stages=num_stages,
            ),
            num_warps=num_warps,
            num_stages=num_stages,
        )
        for BLOCK_SIZE_ROWS in (16, 32, 64, 128)
        for num_stages in (2, 3, 4)
        for num_warps in (2, 4, 8)
    ],
    key=['N_COLS'],
)
@triton.heuristics(
    values=dict(
        BLOCK_SIZE_COLS=lambda args: triton.next_power_of_2(args['N_COLS'])
    )
)
@triton.jit
def softmax_kernel(
    input_ptr: tl.tensor,
    output_ptr: tl.tensor,
    input_row_stride: int,
    output_row_stride: int,
    n_rows: int,
    N_COLS: tl.constexpr,
    BLOCK_SIZE_ROWS: tl.constexpr,
    BLOCK_SIZE_COLS: tl.constexpr,
    num_stages: tl.constexpr
):
    # Block pointers let us address (rows, cols) tiles of the input and output
    input_ptr = tl.make_block_ptr(
        base=input_ptr,
        shape=(n_rows, N_COLS),
        strides=(input_row_stride, 1),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_ROWS, BLOCK_SIZE_COLS),
        order=(1, 0),
    )
    output_ptr = tl.make_block_ptr(
        base=output_ptr,
        shape=(n_rows, N_COLS),
        strides=(output_row_stride, 1),
        offsets=(0, 0),
        block_shape=(BLOCK_SIZE_ROWS, BLOCK_SIZE_COLS),
        order=(1, 0),
    )

    cols_mask = tl.arange(0, BLOCK_SIZE_COLS) < N_COLS
    # Each program instance processes its own tile of BLOCK_SIZE_ROWS rows
    row_idx = tl.program_id(0) * BLOCK_SIZE_ROWS
    in_tile_ptr = tl.advance(input_ptr, (row_idx, 0))
    row = tl.load(pointer=in_tile_ptr, boundary_check=(0, 1))

    # Subtract maximum for numerical stability
    row_minus_max = row - tl.max(row, axis=1, keep_dims=True)
    row_minus_max = tl.where(cols_mask, row_minus_max, -float('inf'))
    numerator = tl.exp(row_minus_max)
    denominator = tl.sum(numerator, axis=1, keep_dims=True)
    softmax_output = numerator / denominator

    out_tile_ptr = tl.advance(output_ptr, (row_idx, 0))
    tl.store(out_tile_ptr, softmax_output, boundary_check=(0, 1))
def softmax(x: torch.Tensor):
    x_orig_shape = x.shape
    x = x.view(-1, x_orig_shape[-1])
    n_rows, n_cols = x.shape
    y = torch.empty_like(x, memory_format=torch.contiguous_format)

    # Launch one program per tile of BLOCK_SIZE_ROWS rows
    grid = lambda args: (
        triton.cdiv(n_rows, args['BLOCK_SIZE_ROWS']),
        1,
        1,
    )
    softmax_kernel[grid](
        input_ptr=x,
        output_ptr=y,
        input_row_stride=x.stride(0),
        output_row_stride=y.stride(0),
        n_rows=n_rows,
        N_COLS=n_cols,
    )
    return y.view(*x_orig_shape)
Indeed, it looks complicated. But the core of the algorithm is summarized in a few lines.
row_minus_max = row - tl.max(row, axis=1, keep_dims=True)
row_minus_max = tl.where(cols_mask, row_minus_max, -float('inf'))
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=1, keep_dims=True)
softmax_output = numerator / denominator
Everything else is just data management and housekeeping.
If we benchmark it across different input lengths, we'll see that we match the performance of torch.nn.functional.softmax (which is a highly optimized kernel!) and dramatically outperform the naive torch implementation. A rough sketch of such a benchmark is shown below.
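This is only a sketch under some assumptions: it reuses the softmax wrapper defined above, relies on triton.testing.do_bench for warm-up and timing, and the exact numbers will of course depend on your GPU:

import torch
import triton

def naive_softmax(x):
    x_max = torch.max(x, dim=-1, keepdim=True).values
    x_exp = torch.exp(x - x_max)
    return x_exp / torch.sum(x_exp, dim=-1, keepdim=True)

x = torch.randn(4096, 8192, device="cuda")

# `softmax` is the Triton wrapper defined above;
# do_bench handles warm-up and returns timings in milliseconds
for name, fn in [
    ("triton kernel", lambda: softmax(x)),
    ("torch.nn.functional.softmax", lambda: torch.nn.functional.softmax(x, dim=-1)),
    ("naive torch", lambda: naive_softmax(x)),
]:
    ms = triton.testing.do_bench(fn)
    print(f"{name:<30} {ms:.3f} ms")

# Sanity check: all three should agree
assert torch.allclose(softmax(x), torch.nn.functional.softmax(x, dim=-1), atol=1e-5)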
You can find the full code for the kernel and the benchmark in the following GitHub file.
Pros:
- Potentially huge speed-ups by fusing ops and optimizing memory access patterns.
- More control than torch.compile.
- Easy to write efficient code (we matched the torch implementation!)
- Easy to write inefficient code (if you don't know what you're doing).
Cons:
- You’re now the kernel developer, which means debugging if something goes sideways. Which is tough. Really.
- If you go further with custom backward passes, you might need a second coffee… or more. That's because torch autograd cannot differentiate through Triton kernels, so you will need to define the backward pass yourself (a sketch of how to wire that up follows this list).
- Subscribe so you don't miss a post about using Triton kernels, autograd, and torch.compile in tandem.
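As a minimal sketch of that wiring (assuming the Triton softmax wrapper defined above; for brevity the backward here uses plain PyTorch ops, whereas in practice you'd likely write a second Triton kernel for it):

import torch

class TritonSoftmax(torch.autograd.Function):
    # Ties the Triton forward to a hand-written backward

    @staticmethod
    def forward(ctx, x):
        y = softmax(x)  # the Triton wrapper defined above
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        (y,) = ctx.saved_tensors
        # Analytic softmax gradient: y * (g - sum(g * y))
        return y * (grad_output - (grad_output * y).sum(dim=-1, keepdim=True))

def triton_softmax(x: torch.Tensor) -> torch.Tensor:
    return TritonSoftmax.apply(x)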
Pure CUDA (a.k.a. Going Hardcore)
Sometimes even Triton won't cut it, or you just enjoy living on the edge. In that case, you can write a custom CUDA kernel in C++, compile it, and tie it into PyTorch via a custom extension. Projects like this fused CUDA softmax reference show how people build specialized kernels for maximum speed.
Softmax in Custom CUDA
You’ll typically have a setup.py
that compiles a .cu
or .cpp
file and exposes a Python function as an extension.
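As a rough sketch of what that build glue could look like (only the packaging side, not the kernel itself; the file and module names here are hypothetical):

# setup.py (hypothetical names)
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="custom_softmax_cuda",
    ext_modules=[
        CUDAExtension(
            name="custom_softmax_cuda",
            sources=["softmax_cuda.cpp", "softmax_cuda_kernel.cu"],
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)

After pip install -e ., you could call the exposed function from Python like any other extension module.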
Check out CudaSoftmax for a self-explanatory example.
I won't provide the kernel code itself in this post, and that fact alone speaks for itself. This approach is quite complicated, requires solid justification, and is usually the last thing you should try. It's very easy to write inefficient, buggy, and unsafe code.
Pros:
- Maximum control. “If you want something done right, do it yourself.”
- Potential for the fastest possible kernel if well-optimized.
Cons:
- Requires deep CUDA understanding.
- Memory management, block sizes, shared memory—those are hard!
- Maintenance overhead can be extremely high.
Conclusion
When it comes to speeding up PyTorch operations, you can choose from progressively more intricate methods:
- torch.compile: Minimal code changes needed.
- Triton kernel: More control over kernel behaviour, still reasonably easy to write.
- Pure CUDA: Maximum optimisation potential, but much higher complexity.
If you’re looking for the simplest improvement, start with torch.compile
. If that’s insufficient, explore Triton. For advanced users, writing a custom CUDA kernel can yield further gains, though it demands deep GPU programming skills.
References
- Compiling the optimizer with torch.compile (PyTorch Docs)
- How should I use torch.compile properly? (PyTorch discussion)
- Using User-Defined Triton Kernels with torch.compile (PyTorch Docs)
- Torch.compile with custom Triton kernel (PyTorch discussion)
- GitHub: fattorib/CudaSoftmax
Choose the path that fits your project’s needs and your comfort level. Good luck optimizing!