DeepSpeed partitioned activation checkpointing issues

I believe I have identified a bug in partitioned activation checkpointing, and I applied a “fix” here:
This is my GitHub repo:

@awaelchli I know you helped me on this topic before; could anyone take a look?

Here is the original issue with a description of the possible bug:

My fix:
From PyTorch Lightning this is called as follows:

The mpu object is set to None in PyTorch Lightning, which I think is a PyTorch Lightning bug. To address this, I initialize an mpu object inside my custom script, so there should be no problem.
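For context, the shim I initialize looks roughly like the sketch below. This is a hypothetical illustration, not DeepSpeed or Lightning code: the method names follow the Megatron-style model-parallel-unit interface that DeepSpeed's checkpointing configuration can be handed via its `mpu_` argument, and the trivial single-group values are assumptions for a one-rank setup.

```python
# Hypothetical minimal mpu shim for DeepSpeed activation checkpointing.
# Method names mirror the Megatron-style mpu interface; the single-group
# return values below are illustrative assumptions, not real defaults.

class MinimalMPU:
    """Single model-parallel group: one rank, one partition."""

    def __init__(self, rank: int = 0, world_size: int = 1):
        self._rank = rank
        self._world_size = world_size

    def get_model_parallel_rank(self) -> int:
        # Which shard of a partitioned activation this rank owns.
        return self._rank

    def get_model_parallel_world_size(self) -> int:
        # How many shards each activation is split into.
        return self._world_size

    def get_model_parallel_group(self):
        # In a real run this would return a torch.distributed process group;
        # None here only keeps the sketch dependency-free.
        return None
```

With partitioning across more than one rank, `world_size` would be the model-parallel group size and `get_model_parallel_group()` would return the actual process group used for the backward all_gather.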
Relevant line:

Still, when I run activation checkpointing without partitioning, I get the same memory usage as when partitioning is enabled.

  • I checked that the partition_activations() method shards the tensor correctly,
  • the all_gather in the backward pass also seems to work correctly,
  • and get_partitioned_activations_for_backward() seems to save only the partitions.

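To convince myself the shard-then-gather logic in the points above is at least arithmetically sound, I mocked the round trip in pure Python. This is my own illustration, not DeepSpeed code: it splits a flattened activation into contiguous per-rank shards of numel // world_size elements and then reassembles them the way the backward-pass all_gather would.

```python
# Pure-Python illustration (not DeepSpeed code) of the forward
# partition / backward all_gather round trip for one activation.

def partition(flat, rank, world_size):
    """Return this rank's contiguous shard of a flattened activation."""
    chunk = len(flat) // world_size  # assumes numel divisible by world_size
    return flat[rank * chunk : (rank + 1) * chunk]

def all_gather(shards):
    """Reassemble the full activation from every rank's shard."""
    full = []
    for shard in shards:
        full.extend(shard)
    return full

activation = list(range(16))   # stand-in for a flattened tensor
world_size = 4
shards = [partition(activation, r, world_size) for r in range(world_size)]

# Each rank holds only numel // world_size elements between fwd and bwd...
assert all(len(s) == len(activation) // world_size for s in shards)
# ...and gathering the shards losslessly restores the full activation.
assert all_gather(shards) == activation
```

The point of the mock is the invariant: between forward and backward, each rank should be holding only its shard, and the gather should restore the exact original tensor.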
Still, I suspect that somewhere the whole checkpointed layer is kept in memory from the forward pass until the backward pass, instead of just a partition of it.
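As a sanity check on what I should be observing, here is a back-of-the-envelope estimate of per-rank memory held for saved checkpoint activations. All the concrete numbers (hidden size, dtype width, layer count, world size) are hypothetical; the only point is that partitioning should divide the saved-activation footprint by the model-parallel world size, which is exactly the difference I am not seeing.

```python
# Back-of-the-envelope per-rank memory for saved checkpoint activations.
# All sizes below are hypothetical; only the ratio matters.

def checkpoint_bytes(numel, bytes_per_elem, layers, world_size, partitioned):
    """Bytes each rank holds between forward and backward for checkpoints."""
    per_layer = numel * bytes_per_elem
    if partitioned:
        per_layer //= world_size  # each rank keeps only its own shard
    return per_layer * layers

full = checkpoint_bytes(numel=4096 * 4096, bytes_per_elem=2,
                        layers=24, world_size=4, partitioned=False)
part = checkpoint_bytes(numel=4096 * 4096, bytes_per_elem=2,
                        layers=24, world_size=4, partitioned=True)

# Partitioning should cut the checkpoint footprint by world_size.
assert full == part * 4
```

Observing `full`-sized usage even with partitioning enabled would be consistent with the full layer being retained somewhere between forward and backward.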

Any help is much appreciated.