deepspeed

Functions

convert_zero_checkpoint_to_fp32_state_dict

Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict file that can be loaded with torch.load(file) + load_state_dict() and used for training without DeepSpeed.

ds_checkpoint_dir

Return type:

str

Utilities that can be used with DeepSpeed.

lightning.pytorch.utilities.deepspeed.convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None)

Convert a ZeRO stage 2 or 3 checkpoint into a single consolidated fp32 state_dict file that can be loaded with torch.load(file) + load_state_dict() and used for training without DeepSpeed. The original DeepSpeed conversion script is copied into the top-level checkpoint directory, so the conversion can be done at any point in the future. Once extracted, the weights don't require DeepSpeed and can be used in any application. Additionally, this version of the script keeps the Lightning state inside the state dict, so the result can be loaded with LightningModule.load_from_checkpoint('...').
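Once converted, the file is plain PyTorch. The sketch below shows the torch.load + load_state_dict() round trip. It assumes the consolidated weights sit under a "state_dict" key, as in a regular Lightning checkpoint, and it writes a small stand-in file (the "epoch" entry represents other Lightning state). It does not call the real conversion function, so no DeepSpeed checkpoint is needed.

```python
import os
import tempfile

import torch
from torch import nn

# Stand-in for the converted file: fp32 weights under "state_dict"
# plus some other (illustrative) Lightning state.
model = nn.Linear(4, 2)
with tempfile.TemporaryDirectory() as tmp:
    output_file = os.path.join(tmp, "lightning_model.pt")
    torch.save({"state_dict": model.state_dict(), "epoch": 0}, output_file)

    # Reading it back needs only PyTorch, not DeepSpeed or Lightning.
    checkpoint = torch.load(output_file, map_location="cpu")
    restored = nn.Linear(4, 2)
    restored.load_state_dict(checkpoint["state_dict"])
```

With a real LightningModule, the keys in the state_dict follow the module's attribute names (for example layer.weight), so the target model must define matching submodules for load_state_dict() to succeed.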

Parameters:
  • checkpoint_dir (_PATH) – path to the desired checkpoint folder (the one that contains the tag folder, such as global_step14).

  • output_file (_PATH) – path to the PyTorch fp32 state_dict output file (e.g. path/pytorch_model.bin).

  • tag (str | None) – checkpoint tag used as a unique identifier for the checkpoint. If not provided, the tag (e.g. global_step14) is read from the file named latest in the checkpoint folder.
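The tag lookup described above can be sketched in plain Python. resolve_tag_dir is a hypothetical helper written for illustration, not part of Lightning's public API. It assumes the layout described for checkpoint_dir: a latest file holding the tag name next to a folder of that name.

```python
import os
import tempfile
from typing import Optional


def resolve_tag_dir(checkpoint_dir: str, tag: Optional[str] = None) -> str:
    """Hypothetical helper: return the tag folder inside a ZeRO checkpoint dir.

    If tag is None, the tag is read from the 'latest' file in checkpoint_dir.
    """
    if tag is None:
        latest_path = os.path.join(checkpoint_dir, "latest")
        if not os.path.isfile(latest_path):
            raise ValueError(f"Unable to find 'latest' file at {latest_path}")
        with open(latest_path) as fd:
            tag = fd.read().strip()  # e.g. "global_step14"

    tag_dir = os.path.join(checkpoint_dir, tag)
    if not os.path.isdir(tag_dir):
        raise FileNotFoundError(f"Directory '{tag_dir}' doesn't exist")
    return tag_dir


# Usage: build a fake checkpoint layout and resolve the tag from 'latest'.
with tempfile.TemporaryDirectory() as ckpt:
    os.makedirs(os.path.join(ckpt, "global_step14"))
    with open(os.path.join(ckpt, "latest"), "w") as fd:
        fd.write("global_step14\n")
    print(resolve_tag_dir(ckpt))  # a path ending in global_step14
```

Passing tag explicitly skips the latest file, which is useful when the checkpoint folder holds several tag folders and you want an older one.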

Return type:

dict[str, Any]

Examples:

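# With DeepSpeed, Lightning saves the checkpoint as a directory instead of a single file
from lightning.pytorch.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict

convert_zero_checkpoint_to_fp32_state_dict(
    "lightning_logs/version_0/checkpoints/epoch=0-step=0.ckpt/",
    "lightning_model.pt",
)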