Source code for pytorch_lightning.utilities.deepspeed
#!/usr/bin/env python
# Copyright 2020 The PyTorch Lightning team and Microsoft Corporation. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities that can be used with DeepSpeed."""

from __future__ import annotations

import os

import torch

from pytorch_lightning.utilities import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.types import _PATH

if _DEEPSPEED_AVAILABLE:
    from deepspeed.utils.zero_to_fp32 import (
        get_fp32_state_dict_from_zero_checkpoint,
        get_model_state_file,
        get_optim_files,
    )

CPU_DEVICE = torch.device("cpu")


def ds_checkpoint_dir(checkpoint_dir: _PATH, tag: str | None = None) -> str:
    if tag is None:
        latest_path = os.path.join(checkpoint_dir, "latest")
        if os.path.isfile(latest_path):
            with open(latest_path) as fd:
                tag = fd.read().strip()
        else:
            raise ValueError(f"Unable to find 'latest' file at {latest_path}")

    directory = os.path.join(checkpoint_dir, tag)
    if not os.path.isdir(directory):
        raise FileNotFoundError(f"Directory '{directory}' doesn't exist")
    return directory
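# A minimal usage sketch for ``ds_checkpoint_dir``. The directory and tag names
# below are illustrative, assuming the layout Lightning's DeepSpeed integration
# saves (a checkpoint directory containing a ``latest`` file and a tag folder):
#
#     epoch=0-step=0.ckpt/
#     ├── latest            <- text file whose contents name the tag, e.g. "global_step14"
#     └── global_step14/    <- tag folder holding the model/optimizer shard files
#
# Without a tag, the helper reads the "latest" file to discover one:
#     ds_checkpoint_dir("epoch=0-step=0.ckpt")                       # -> ".../global_step14"
# Passing an explicit tag skips the "latest" lookup:
#     ds_checkpoint_dir("epoch=0-step=0.ckpt", tag="global_step14")  # -> ".../global_step14"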
# Modified script from https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/utils/zero_to_fp32.py
def convert_zero_checkpoint_to_fp32_state_dict(
    checkpoint_dir: _PATH, output_file: _PATH, tag: str | None = None
) -> None:
    """Convert a ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be loaded with
    ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.

    It gets copied into the top level checkpoint dir, so the user can easily do the conversion at any point in the
    future. Once extracted, the weights don't require DeepSpeed and can be used in any application. Additionally,
    the script has been modified to keep the Lightning state inside the state dict so that
    ``LightningModule.load_from_checkpoint('...')`` still works.

    Args:
        checkpoint_dir: path to the desired checkpoint folder
            (one that contains the tag-folder, like ``global_step14``)
        output_file: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
        tag: checkpoint tag used as a unique identifier for the checkpoint. If not provided, will attempt to load
            the tag from the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``

    Examples:
        >>> from pytorch_lightning.utilities.deepspeed import (
        ...     convert_zero_checkpoint_to_fp32_state_dict
        ... )
        >>> # Lightning deepspeed has saved a directory instead of a file
        >>> save_path = "lightning_logs/version_0/checkpoints/epoch=0-step=0.ckpt/"  # doctest: +SKIP
        >>> output_path = "lightning_model.pt"  # doctest: +SKIP
        >>> convert_zero_checkpoint_to_fp32_state_dict(save_path, output_path)  # doctest: +SKIP
        Saving fp32 state dict to lightning_model.pt
    """
    state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)

    # Additional logic to ensure we keep the Lightning state dict as well, taken from rank 0.
    deepspeed_states = [
        "module",
        "optimizer",
        "lr_scheduler",
        "csr_tensor_module_names",
        "skipped_steps",
        "global_steps",
        "dp_world_size",
        "mp_world_size",
    ]
    checkpoint_dir = ds_checkpoint_dir(checkpoint_dir)
    optim_files = get_optim_files(checkpoint_dir)
    optim_state = torch.load(optim_files[0], map_location=CPU_DEVICE)
    zero_stage = optim_state["optimizer_state_dict"]["zero_stage"]
    model_file = get_model_state_file(checkpoint_dir, zero_stage)
    client_state = torch.load(model_file, map_location=CPU_DEVICE)
    client_state = {key: value for key, value in client_state.items() if key not in deepspeed_states}
    # State dict keys will include a reference to the wrapper LightningDeepSpeedModule;
    # delete the `module.` prefix before saving.
    state_dict = {k.partition("module.")[2]: state_dict[k] for k in state_dict.keys()}
    client_state["state_dict"] = state_dict

    print(f"Saving fp32 state dict to {output_file}")
    torch.save(client_state, output_file)
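# A hedged end-to-end sketch of consuming the converted checkpoint afterwards.
# ``MyModel`` is an illustrative ``LightningModule`` subclass, not part of this
# module; the paths mirror the docstring example above:
#
#     convert_zero_checkpoint_to_fp32_state_dict(
#         "lightning_logs/version_0/checkpoints/epoch=0-step=0.ckpt/",
#         "lightning_model.pt",
#     )
#     # The output is a regular PyTorch checkpoint, loadable without DeepSpeed:
#     model = MyModel.load_from_checkpoint("lightning_model.pt")
#     # The raw fp32 weights can also be inspected directly, since the converter
#     # stores them under the "state_dict" key:
#     state_dict = torch.load("lightning_model.pt", map_location="cpu")["state_dict"]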