# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict

import torch

from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.types import _METRIC_COLLECTION
class DDP2Strategy(DDPStrategy):
    """DDP2 behaves like DP in one node, but synchronization across nodes behaves like in DDP."""

    strategy_name = "ddp2"

    @property
    def global_rank(self) -> int:
        return self.node_rank

    @property
    def world_size(self) -> int:
        return self.num_nodes
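The two overridden properties are what give DDP2 its hybrid topology: each node, not each device, counts as one rank, so `global_rank` collapses to `node_rank` and `world_size` to `num_nodes`. A small arithmetic sketch of that rank mapping (the cluster shape and variable names are hypothetical, for illustration only):

```python
# Hypothetical cluster: 2 nodes with 4 GPUs each (illustrative numbers).
num_nodes = 2
gpus_per_node = 4

# Under plain DDP, every GPU is its own process/rank:
ddp_world_size = num_nodes * gpus_per_node  # 8 ranks

# Under DDP2, each node is a single rank, so the properties above yield:
ddp2_world_size = num_nodes  # world_size == num_nodes -> 2 ranks
ddp2_global_rank_on_second_node = 1  # global_rank == node_rank
```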
    def reduce(self, collection: _METRIC_COLLECTION, *args, **kwargs) -> _METRIC_COLLECTION:
        """Reduces a collection of tensors from all processes. It can be applied to just a single tensor.

        In DDP2, the reduction here is only across local devices within the node.

        Args:
            collection: The collection of tensors to sync and reduce.
            *args: ignored for DDP2
            **kwargs: ignored for DDP2

        Return:
            Reduced tensor values or the same value if it was not or did not contain a tensor.
        """

        def mean(t: torch.Tensor) -> torch.Tensor:
            original_dtype = t.dtype
            return t.float().mean().to(original_dtype)

        return apply_to_collection(collection, torch.Tensor, mean)
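The `reduce` method only averages the tensors it finds inside the collection and passes every other value through untouched; that selective traversal is what `apply_to_collection` provides. A torch-free, pure-Python sketch of the same pattern (the helper, the list-as-tensor stand-in, and the sample metrics dict are all illustrative, not Lightning's actual implementation):

```python
from typing import Any, Callable, Type


def apply_to_collection(data: Any, dtype: Type, func: Callable) -> Any:
    # Recursively walk dicts/lists/tuples, applying func only to dtype leaves.
    if isinstance(data, dtype):
        return func(data)
    if isinstance(data, dict):
        return {k: apply_to_collection(v, dtype, func) for k, v in data.items()}
    if isinstance(data, (list, tuple)):
        return type(data)(apply_to_collection(v, dtype, func) for v in data)
    return data  # non-matching leaves pass through unchanged


# Stand-in for a tensor: a list of floats, reduced by taking its mean.
def mean(values: list) -> float:
    return sum(values) / len(values)


metrics = {"loss": [1.0, 3.0], "acc": [0.5, 0.7], "step": 10}
reduced = apply_to_collection(metrics, list, mean)
# "loss" and "acc" are averaged; the int "step" is returned unchanged.
```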