# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional

import torch
from torch.nn import DataParallel, Module

import pytorch_lightning as pl
from pytorch_lightning.overrides.data_parallel import LightningParallelModule
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import _METRIC_COLLECTION, STEP_OUTPUT


class DataParallelStrategy(ParallelStrategy):
    """Implements data-parallel training in a single process, i.e., the model gets replicated to each device and
    each gets a split of the data."""

    strategy_name = "dp"

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.accelerator.Accelerator"] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
    ):
        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            cluster_environment=None,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )

    @property
    def global_rank(self) -> int:
        return 0

    @property
    def local_rank(self) -> int:
        return 0

    @property
    def node_rank(self) -> int:
        return 0

    @property
    def world_size(self) -> int:
        return 1
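    # Illustrative usage sketch (not part of the upstream file): this strategy is typically
    # selected through its registered name. The exact Trainer flags shown here are an assumption
    # and may differ between Lightning releases:
    #
    #     trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="dp")
    #     trainer.fit(model)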
    def setup(self, trainer: "pl.Trainer") -> None:
        # model needs to be moved to the device before it is wrapped
        self.model_to_device()
        self.model = self._setup_model(LightningParallelModule(self.model))
        super().setup(trainer)
    def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
        """Moves the batch to the correct device. The input and the output is the same type.

        Args:
            batch: The batch of samples to move to the correct device
            device: The target device
            dataloader_idx: The index of the dataloader to which the batch belongs.
        """
        # DataParallel handles the transfer of batch to the device
        return batch
    def _setup_model(self, model: Module) -> DataParallel:
        """Wraps the given model into a :class:`~torch.nn.parallel.DataParallel` module."""
        return DataParallel(module=model, device_ids=self.parallel_devices)
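    # Sketch of what the ``DataParallel`` wrapper does at call time (device list is hypothetical):
    # the wrapped module scatters the input batch across ``device_ids``, runs a model replica on
    # each chunk, and gathers the outputs onto the first device, which is why ``batch_to_device``
    # above can return the batch unchanged:
    #
    #     wrapped = DataParallel(module=model, device_ids=[torch.device("cuda:0"), torch.device("cuda:1")])
    #     output = wrapped(batch)  # scatter -> parallel forward -> gather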
    def reduce(self, collection: _METRIC_COLLECTION, *args, **kwargs) -> _METRIC_COLLECTION:
        """Reduces a collection of tensors from all processes. It can be applied to just a single tensor.

        Args:
            collection: The collection of tensors to sync and reduce.
            *args: ignored for DP
            **kwargs: ignored for DP

        Return:
            Reduced tensor values or the same value if it was not or did not contain a tensor.
        """

        def mean(t: torch.Tensor) -> torch.Tensor:
            original_dtype = t.dtype
            return t.float().mean().to(original_dtype)

        return apply_to_collection(collection, torch.Tensor, mean)
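    # Sketch of the reduction above (values are made up): after the DataParallel gather, a metric
    # tensor holds one entry per device, and ``reduce`` replaces each tensor in the collection with
    # its mean while preserving the original dtype:
    #
    #     strategy.reduce({"loss": torch.tensor([0.4, 0.6])})  # -> {"loss": tensor(0.5000)}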
    def teardown(self) -> None:
        super().teardown()
        if self.root_device.type == "cuda":
            # GPU teardown
            self.lightning_module.cpu()
            # clean up memory
            torch.cuda.empty_cache()
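
# Sketch of constructing the strategy explicitly instead of by name (the device list and the
# Trainer call are assumptions; accelerator, checkpoint IO and precision plugin are attached by
# the Trainer when left as ``None``):
#
#     strategy = DataParallelStrategy(parallel_devices=[torch.device("cuda:0"), torch.device("cuda:1")])
#     trainer = pl.Trainer(accelerator="gpu", devices=2, strategy=strategy)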