DeepSpeed partitioned activation checkpointing issues

Hi!
I believe I have identified a bug in partitioned activation checkpointing, and I applied a "fix" in my GitHub repo:
github.com/andrasiani/deepspeed_lightning_gpt_partition_activations_checkpointing/tree/master

@awaelchli I know you helped me before on this topic; could anyone take a look?

Here is the original issue with a description of the possible bug:
lightning.ai/forums/t/deepspeed-stage-3-partition-activations-brings-no-benefit/2915

My fix:
From PyTorch Lightning this is called roughly like this:
deepspeed.checkpointing.configure(
    mpu_=None,
    partition_activations=True,
    contiguous_checkpointing=False,
    checkpoint_in_cpu=False,
    profile=checkpoint_config.get("profile"),
)

The mpu object is set to None by PyTorch Lightning (I think this is a PyTorch Lightning bug). To work around it, I initialize an mpu object inside my custom script (sketched below), so that should not be the problem.
Relevant line:
checkpointing_.py#L521
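
For reference, this is roughly how I build the mpu object in my custom script before calling configure. It is a minimal sketch under the assumption that DeepSpeed's checkpointing only queries the model-parallel rank, world size and group from it, and it simply treats all ranks as one model-parallel group:

import deepspeed
import torch.distributed as dist

class SimpleMPU:
    # Minimal model-parallel-unit stand-in: all ranks form one model-parallel
    # group, so activations can be partitioned across every GPU.
    def get_model_parallel_group(self):
        return dist.group.WORLD

    def get_model_parallel_rank(self):
        return dist.get_rank()

    def get_model_parallel_world_size(self):
        return dist.get_world_size()

# Same call as above, but with a real mpu object instead of None.
deepspeed.checkpointing.configure(
    mpu_=SimpleMPU(),
    partition_activations=True,
    contiguous_checkpointing=False,
    checkpoint_in_cpu=False,
    profile=True,
)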

Still, when I run activation checkpointing without partitioning, I get the same memory usage as when partitioning is enabled. The relevant code is in deepspeed/runtime/activation_checkpointing/checkpointing.py.
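
For the comparison I measure peak CUDA memory around a single training step, roughly like below. This is a generic sketch, nothing DeepSpeed-specific; run_one_training_step is a placeholder for my actual forward + backward:

import torch

def measure_peak_memory(run_one_training_step):
    # run_one_training_step: placeholder callable that performs one
    # forward + backward pass with the current checkpointing settings.
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    run_one_training_step()
    torch.cuda.synchronize()
    peak_gib = torch.cuda.max_memory_allocated() / 1024 ** 3
    print(f"peak allocated: {peak_gib:.2f} GiB")
    return peak_gib

The number printed here is essentially the same whether partition_activations is True or False.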

  • I checked that the partition_activations() method shards the tensors correctly (see the simplified sketch after this list),
  • the all_gather in the backward pass also seems to work correctly (checkpointing_.py#L663),
  • and get_partitioned_activations_for_backward() appears to save only the partitions.
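
To convince myself of the first two points, I reproduced the partition/gather round trip in isolation. This is a simplified sketch of what I understand that code to be doing, not the actual DeepSpeed implementation, and it assumes the element count is divisible by the model-parallel world size:

import torch
import torch.distributed as dist

def partition_activation(item, mp_rank, mp_size):
    # Keep only this rank's contiguous slice of the flattened activation.
    flat = item.detach().contiguous().view(-1)
    partition_size = flat.numel() // mp_size  # assumes divisibility
    return flat.narrow(0, mp_rank * partition_size, partition_size).clone()

def gather_activation(partition, original_shape, group=None):
    # Reassemble the full activation from every rank's partition before backward.
    mp_size = dist.get_world_size(group=group)
    flat = torch.empty(partition.numel() * mp_size,
                       dtype=partition.dtype, device=partition.device)
    dist.all_gather(list(flat.chunk(mp_size)), partition, group=group)
    return flat.view(original_shape)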

Still, I suspect that somewhere the full activations of the checkpointed layer are kept in memory from the end of the forward pass until the backward pass, instead of just this rank's partition.
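
One way I am trying to verify this suspicion is to count the bytes of every tensor that autograd keeps alive between forward and backward, using torch.autograd.graph.saved_tensors_hooks. This is a generic PyTorch sketch rather than anything DeepSpeed-specific, and it will miss tensors that are stashed on the autograd ctx directly instead of via save_for_backward; model and batch are placeholders for my module and one data batch:

import torch

def measure_saved_activation_bytes(model, batch):
    # Count the bytes of tensors that autograd saves during the forward pass.
    saved_bytes = 0

    def pack_hook(tensor):
        nonlocal saved_bytes
        saved_bytes += tensor.numel() * tensor.element_size()
        return tensor

    def unpack_hook(tensor):
        return tensor

    with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
        loss = model(batch).sum()  # stand-in for the real training step's forward
    loss.backward()
    print(f"tensors saved for backward: {saved_bytes / 1024 ** 2:.1f} MiB")
    return saved_bytes

If partitioning worked as I expect, this number should drop roughly by the model-parallel world size when partition_activations=True.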

Any help is much appreciated.