I believe I identified a bug in partitioned activation checkpointing and I applied a “fix” here:
This is my GitHub repo:
@awaelchli I know you helped me on this topic before; could anyone take a look?
Here is the original issue with a description of the possible bug:
From PyTorch Lightning, this is called like:
The mpu object is set to None in PyTorch Lightning (I think this is a PyTorch Lightning bug). To work around it, I initialize an mpu object inside my custom script, so that part should not be the problem.
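For context, this is roughly how I set up the mpu object in my script. It is a minimal sketch, not the exact code: `SimpleMPU` is my own illustrative shim built on `torch.distributed`, and DeepSpeed's checkpointing may query more methods than the three shown here (e.g. an RNG tracker).

```python
import torch.distributed as dist
import deepspeed


class SimpleMPU:
    """Bare-bones stand-in exposing the model-parallel queries DeepSpeed makes."""

    def __init__(self, model_parallel_group):
        self._group = model_parallel_group

    def get_model_parallel_group(self):
        return self._group

    def get_model_parallel_rank(self):
        return dist.get_rank(group=self._group)

    def get_model_parallel_world_size(self):
        return dist.get_world_size(group=self._group)


# After torch.distributed is initialized (illustrative: one model-parallel group over all ranks):
# mpu = SimpleMPU(dist.new_group(ranks=list(range(dist.get_world_size()))))
# deepspeed.checkpointing.configure(mpu_=mpu, partition_activations=True)
```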
Still, when I run activation_checkpointing without partitioning, I see the same memory usage as when partitioning is enabled.
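This is how I compare the two configurations: I measure peak CUDA memory for one forward/backward step with partitioning enabled and then disabled. A rough sketch of the measurement, where `run_step` is a placeholder for my actual training step:

```python
import torch


def measure_peak_memory_mib(run_step, device="cuda"):
    """Run one forward+backward step and report peak allocated CUDA memory in MiB."""
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
    run_step()
    torch.cuda.synchronize(device)
    return torch.cuda.max_memory_allocated(device) / 2**20


# print(f"peak allocated: {measure_peak_memory_mib(lambda: run_step(model, batch)):.1f} MiB")
```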
- I checked, and the partition_activations() method shards the tensor correctly,
- the backward all_gather also seems to work correctly (checkpointing.py#L663),
- the get_partitioned_activations_for_backward() method seems to save only the partitions (a toy sketch of this partition/gather pattern follows this list).
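To make sure I understand what those internals are supposed to do, here is a toy version of the pattern as I read it (not DeepSpeed's code): after forward, each model-parallel rank keeps only its slice of the checkpointed activation, and an all_gather rebuilds the full tensor for recomputation in backward.

```python
import torch
import torch.distributed as dist


def partition(full, group):
    """Keep only this rank's contiguous slice of the flattened activation."""
    world = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)
    flat = full.contiguous().view(-1)
    part_numel = flat.numel() // world  # assumes numel divides evenly across ranks
    return flat.narrow(0, rank * part_numel, part_numel).clone()


def gather(part, full_shape, group):
    """Reassemble the full activation from every rank's slice before recompute."""
    world = dist.get_world_size(group=group)
    parts = [torch.empty_like(part) for _ in range(world)]
    dist.all_gather(parts, part, group=group)
    return torch.cat(parts).view(full_shape)
```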
Still, I suspect that somewhere the whole checkpointed activation is kept in memory from the end of the forward pass until the backward pass, instead of just its partition.
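One thing I am trying in order to confirm this is logging what autograd actually keeps alive between forward and backward, using saved_tensors_hooks. This is my own diagnostic and it only sees tensors that go through the autograd engine via save_for_backward; anything stashed on the side in plain Python lists would not show up here.

```python
import torch


def pack(tensor):
    # Called when autograd saves a tensor for backward; log its shape and size.
    mib = tensor.element_size() * tensor.nelement() / 2**20
    print(f"saved for backward: shape={tuple(tensor.shape)} dtype={tensor.dtype} size={mib:.2f} MiB")
    return tensor


def unpack(tensor):
    # Called when the saved tensor is retrieved during backward.
    return tensor


# with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
#     loss = model(batch).sum()
# loss.backward()
```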
Any help is much appreciated.