# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Union

import torch
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch.nn import DataParallel, Module

import pytorch_lightning as pl
from lightning_fabric.plugins import CheckpointIO
from lightning_fabric.utilities.distributed import ReduceOp
from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
from pytorch_lightning.overrides.data_parallel import LightningParallelModule
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.parallel import ParallelStrategy
from pytorch_lightning.strategies.strategy import TBroadcast, TReduce
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import STEP_OUTPUT

class DataParallelStrategy(ParallelStrategy):
    """Implements data-parallel training in a single process, i.e., the model gets replicated to each device and
    each gets a split of the data."""

    strategy_name = "dp"

    def __init__(
        self,
        accelerator: Optional["pl.accelerators.Accelerator"] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
    ):
        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            cluster_environment=None,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )

    @property
    def global_rank(self) -> int:
        return 0

    @property
    def local_rank(self) -> int:
        return 0

    @property
    def node_rank(self) -> int:
        return 0

    @property
    def world_size(self) -> int:
        return 1
    def setup(self, trainer: "pl.Trainer") -> None:
        # model needs to be moved to the device before it is wrapped
        self.model_to_device()
        assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
        self.model = self._setup_model(LightningParallelModule(self.model))
        super().setup(trainer)
    def batch_to_device(self, batch: Any, device: Optional[torch.device] = None, dataloader_idx: int = 0) -> Any:
        """Moves the batch to the correct device. The input and the output are of the same type.

        Args:
            batch: The batch of samples to move to the correct device.
            device: The target device.
            dataloader_idx: The index of the dataloader to which the batch belongs.
        """
        # DataParallel handles the transfer of the batch to the device
        return batch
    def _setup_model(self, model: Module) -> DataParallel:
        """Wraps the given model into a :class:`~torch.nn.parallel.DataParallel` module."""
        return DataParallel(module=model, device_ids=self.parallel_devices)
    def reduce(
        self, collection: TReduce, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
    ) -> TReduce:
        """Reduces a collection of tensors from all processes. It can be applied to just a single tensor.

        Args:
            collection: The collection of tensors to sync and reduce.
            group: ignored for DP
            reduce_op: ignored for DP

        Return:
            Reduced tensor values, or the same value if it was not a tensor or did not contain one.
        """

        def mean(t: Tensor) -> Tensor:
            original_dtype = t.dtype
            return t.float().mean().to(original_dtype)

        return apply_to_collection(collection, Tensor, mean)
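

# Usage sketch (illustrative, not part of the original module): run single-process
# data-parallel training by selecting this strategy on the Trainer. It assumes at
# least two CUDA devices are visible; ``BoringModel`` is the demo LightningModule
# shipped with PyTorch Lightning and stands in for a user-defined model.
if __name__ == "__main__":
    from pytorch_lightning.demos.boring_classes import BoringModel

    # ``strategy="dp"`` resolves to DataParallelStrategy via its registered
    # ``strategy_name``; passing ``strategy=DataParallelStrategy()`` is equivalent.
    trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="dp", max_epochs=1)
    trainer.fit(BoringModel())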