I Contributed to PyTorch. Here's What I Learned

When you see something that does not work in an omnipresent framework, you believe it can't be completely broken, right?

Alex Dremov
Let me do it for you

The Issue Must Not Be That Bad

That's what I thought when I ran into a PyTorch problem during one of my college assignments. The Jupyter kernel kept dying because of a bug in the LSTM implementation for MPS.

MPS (Metal Performance Shaders) is an acceleration backend for macOS that uses the GPU for computations.

After a quick investigation, I discovered that the crash was caused by the batch_first flag. The MPS backend did not handle it correctly and crashed the entire kernel.
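For context, batch_first only changes the expected input layout, so an LSTM with the flag set should give the same numbers as one without it, just with the first two dimensions swapped. Here is a minimal sketch of that expectation on the CPU. The sizes and variable names are made up for illustration, and this is not the test from the PyTorch repository.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Two LSTMs that differ only in the input layout they expect.
lstm_seq_first = nn.LSTM(input_size=4, hidden_size=8, num_layers=2)
lstm_batch_first = nn.LSTM(input_size=4, hidden_size=8, num_layers=2, batch_first=True)
lstm_batch_first.load_state_dict(lstm_seq_first.state_dict())  # share the same weights

x = torch.randn(5, 3, 4)  # (seq_len=5, batch=3, features=4)

with torch.no_grad():
    out_seq, (h_seq, c_seq) = lstm_seq_first(x)
    out_batch, (h_batch, c_batch) = lstm_batch_first(x.transpose(0, 1))

# The outputs should match once the first two dimensions are swapped back.
outputs_match = torch.allclose(out_seq, out_batch.transpose(0, 1), atol=1e-5)
# Hidden and cell states keep the (num_layers, batch, hidden) layout either way.
states_match = torch.allclose(h_seq, h_batch, atol=1e-5) and torch.allclose(c_seq, c_batch, atol=1e-5)
print(outputs_match, states_match)
```

A backend that mishandles the flag would fail a comparison like this, or, as in my case, take the whole process down with it.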

"Easy fix"
P.S. After that phrase, Alex spent the next two days fixing what looked like an "easy fix"

The PR was merged pretty quickly. Thanks, PyTorch team, for that! And the story could've ended here, but I discovered a funny detail in the MPS tests.

@unittest.skipIf(True, "Backward of lstm returns wrong result")
def test_lstm_2(self, device="mps", dtype=torch.float32):
    ...  # test body omitted
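To see why a decorator like this is easy to overlook, here is a small standalone sketch using only the standard unittest module. The class and test names are hypothetical. A test skipped with a condition that is always True never runs, and the suite still reports success.

```python
import unittest


class TestExample(unittest.TestCase):
    @unittest.skipIf(True, "Backward of lstm returns wrong result")
    def test_always_skipped(self):
        # This body is never executed, so the failure below is never reported.
        self.fail("never reached")

    def test_runs(self):
        self.assertEqual(1 + 1, 2)


suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestExample)
result = unittest.TestResult()
suite.run(result)

print("skipped:", len(result.skipped))            # tests that were skipped
print("failures:", len(result.failures))          # tests that failed
print("successful:", result.wasSuccessful())      # skips do not count as failures
```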

And the LSTM really was bad: it scored far worse when trained on MPS than when trained on CUDA or CPU.

It Was Bad. Really Bad

It turned out that LSTM on MPS was completely broken. The forward pass had a bug with the batch_first flag and hidden cell initialization.

The backward pass used the first layer's weights for the last layers, mixing up all gradients. It did not calculate gradients for hidden states. And my favourite: the backward function returned tensors initialized with garbage, screwing up all subsequent training. It was a mess that I kept investigating for several days.

Eventually, I fixed the LSTM and its tests in a massive PR, making sure its results are consistent with the CPU implementation.
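A consistency check of this kind can be sketched roughly as follows. This is an illustration and not the actual test from the PR: the helper name lstm_outputs_and_grads, the sizes, and the tolerance are assumptions. The sketch falls back to the CPU when MPS is not available, in which case the comparison passes trivially.

```python
import copy

import torch
from torch import nn


def lstm_outputs_and_grads(lstm, x, device):
    """Run forward and backward on `device`; return results moved back to the CPU."""
    lstm = copy.deepcopy(lstm).to(device)
    x = x.detach().clone().to(device).requires_grad_(True)
    out, (h, c) = lstm(x)
    # Include hidden and cell states in the loss so their gradients are exercised too.
    (out.sum() + h.sum() + c.sum()).backward()
    grads = [p.grad.cpu() for p in lstm.parameters()]
    return out.detach().cpu(), x.grad.cpu(), grads


torch.manual_seed(0)
reference = nn.LSTM(input_size=4, hidden_size=8, num_layers=2, batch_first=True)
x = torch.randn(3, 5, 4)  # (batch=3, seq_len=5, features=4)
device = "mps" if torch.backends.mps.is_available() else "cpu"

cpu_out, cpu_x_grad, cpu_grads = lstm_outputs_and_grads(reference, x, "cpu")
dev_out, dev_x_grad, dev_grads = lstm_outputs_and_grads(reference, x, device)

forward_ok = torch.allclose(cpu_out, dev_out, atol=1e-4)
backward_ok = torch.allclose(cpu_x_grad, dev_x_grad, atol=1e-4) and all(
    torch.allclose(a, b, atol=1e-4) for a, b in zip(cpu_grads, dev_grads)
)
print(device, forward_ok, backward_ok)
```

Comparing gradients for the input and for every parameter, not only the forward output, is what would catch bugs like mixed-up layer weights in the backward pass.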

What I Learned

  • Big projects also have garbage code. A broken implementation lived in stable releases for almost a year, generating several related GitHub issues.
  • Contributing to a big project is fun and challenging. And it eventually helps a lot of developers, which keeps me warm during cold winter nights. Contributing to PyTorch in particular is extremely simple. Thanks, PyTorch team, for arranging that!
  • Deploying untested code that looks right is extremely dangerous. I listed the pretty severe mistakes that I found while scrutinizing the LSTM sources for several days. There's no way they could have been discovered without extensive testing. Even though the issues were severe, they were also subtle. The code looked right.


I was able to complete the college PyTorch assignment even though it required rewriting PyTorch's LSTM MPS implementation. Consider solving open issues in your favourite framework or project, too. At the end of the day, it is a lot more fun than LeetCode problems.

See My Work

[MPS] Fix LSTM backward and forward pass by AlexRoar · Pull Request #95137 · pytorch/pytorch
Fixes #91694. Fixes #92615. Several transpositions were missing for backward graph in case of batch_first=True. The #91694 is not reproduced with batch_first=False. After fixing transpose issue, I fi...
[MPS] Fix bidirectional LSTM & small one-direction LSTM fix by AlexRoar · Pull Request #95563 · pytorch/pytorch
Fixes #94754. With this PR I hope to finish my breathtaking journey of fixing MPS LSTM. Here, I enable bidirectional on MPS. Also, I’ve noticed that cache key did not account for all parameters, so ...
[MPS] LSTM grad_y missing fix by AlexRoar · Pull Request #96601 · pytorch/pytorch
Fixes #96416. Added tests that do not use LSTM output similarly to the issue. Seems like this fix once again introduces backward incompatibility.
[MPS] LogSoftmax numerical stability by AlexRoar · Pull Request #95091 · pytorch/pytorch
Fixes #94043. Calculations are now consistent with the numerically stable formula and CPU: $\mathrm{LogSoftmax}(X, \dim) = X - \max(X, \dim) - \log(\mathrm{sum}(\exp(X - \max(X, \dim)), \dim))$ @malfet
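
The idea behind a numerically stable log-softmax can be sketched in plain Python. This is not the MPS kernel from the PR, only an illustration of the standard trick: subtract the maximum before exponentiating, so the largest exponent is zero and nothing overflows. The function names are hypothetical.

```python
import math


def log_softmax(xs):
    """Numerically stable log-softmax over a list of floats."""
    m = max(xs)
    shifted = [x - m for x in xs]  # the largest shifted value is 0
    log_sum = math.log(sum(math.exp(s) for s in shifted))
    return [s - log_sum for s in shifted]


def naive_log_softmax(xs):
    """Direct formula; math.exp overflows for large inputs."""
    log_total = math.log(sum(math.exp(x) for x in xs))
    return [x - log_total for x in xs]


values = [1000.0, 1001.0, 1002.0]
stable = log_softmax(values)

try:
    naive_log_softmax(values)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True

print(stable, naive_overflowed)
```

Because log-softmax is invariant to adding a constant to every input, the stable version gives the same result for these values as it does for 0, 1, 2.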

