I’ve recently been digging into the PyTorch FSDP implementation. It’s powerful and highly optimized, which also means the codebase is large and isn’t always easy to navigate. Along the way, I wrote a minimal implementation based on what I found, mainly to show FSDP’s different states and its pre- and post-forward/backward steps, all in a single place!
Hope this helps others!