Source code for pytorch_lightning.plugins.layer_sync
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod

import torch
from torch import Tensor
from torch.nn import Module


class LayerSync(ABC):
    """Abstract base class for creating plugins that wrap layers of a model with synchronization logic for
    multiprocessing."""

    @abstractmethod
    def apply(self, model: Module) -> Module:
        """Override this method to apply synchronization to the layers of this model."""

    @abstractmethod
    def revert(self, model: Module) -> Module:
        """Override this method to undo all modifications made in :meth:`apply`."""


class NativeSyncBatchNorm(LayerSync):
    """A plugin that wraps all batch normalization layers of a model with synchronization logic for
    multiprocessing.

    This plugin has no effect in single-device operation.
    """

    def apply(self, model: Module) -> Module:
        """Add global batchnorm for a model spread across multiple GPUs and nodes.

        Override this method to synchronize batchnorm layers between specific process groups instead
        of the whole world.

        Args:
            model: Reference to the current LightningModule

        Return:
            LightningModule with batchnorm layers synchronized within the process groups.
        """
        return torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    def revert(self, model: Module) -> Module:
        """Convert the wrapped batchnorm layers back to regular batchnorm layers.

        Args:
            model: Reference to the current LightningModule

        Return:
            LightningModule with regular batchnorm layers that will no longer sync across processes.
        """
        # Code adapted from https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547
        # Original author: Kapil Yedidi (@kapily)
        converted_module = model
        if isinstance(model, torch.nn.modules.batchnorm.SyncBatchNorm):
            # Unfortunately, LayerSync does not store the original class - if it did
            # we could return the one that was originally created.
            converted_module = _BatchNormXd(
                model.num_features, model.eps, model.momentum, model.affine, model.track_running_stats
            )
            if model.affine:
                with torch.no_grad():
                    converted_module.weight = model.weight
                    converted_module.bias = model.bias
            converted_module.running_mean = model.running_mean
            converted_module.running_var = model.running_var
            converted_module.num_batches_tracked = model.num_batches_tracked
            if hasattr(model, "qconfig"):
                converted_module.qconfig = model.qconfig
        for name, child in model.named_children():
            converted_module.add_module(name, self.revert(child))
        del model
        return converted_module
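

# --- Illustrative sketch (not part of the original module) ---
# The ``apply`` docstring above suggests overriding it to synchronize batchnorm
# layers within a specific process group instead of the whole world. A minimal
# sketch of such a subclass follows, assuming the caller creates a
# ``torch.distributed`` process group and passes it in; ``GroupedSyncBatchNorm``
# and its constructor are hypothetical names, not part of the original API.
class GroupedSyncBatchNorm(NativeSyncBatchNorm):
    def __init__(self, process_group=None):
        self.process_group = process_group

    def apply(self, model: Module) -> Module:
        # ``convert_sync_batchnorm`` accepts an optional process group that
        # limits synchronization to the ranks in that group.
        return torch.nn.SyncBatchNorm.convert_sync_batchnorm(model, self.process_group)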


class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
    def _check_input_dim(self, input: Tensor) -> None:
        # The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc
        # is this method that is overwritten by the subclass.
        # Here, we are bypassing some tensor sanity checks and trusting that the user
        # provides the right input dimensions at inference.
        return
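

# --- Illustrative usage sketch (not part of the original module) ---
# A quick roundtrip on a toy model: ``apply`` replaces the ``BatchNorm2d`` layer
# with a ``SyncBatchNorm`` layer, and ``revert`` converts it back to a plain,
# non-synchronizing batchnorm. The toy model below is an assumption made for
# illustration only; inside Lightning this conversion normally happens when the
# Trainer is configured to synchronize batchnorm across processes.
if __name__ == "__main__":
    toy_model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8))
    plugin = NativeSyncBatchNorm()

    synced = plugin.apply(toy_model)
    assert isinstance(synced[1], torch.nn.SyncBatchNorm)

    reverted = plugin.revert(synced)
    assert not isinstance(reverted[1], torch.nn.SyncBatchNorm)
    assert isinstance(reverted[1], torch.nn.modules.batchnorm._BatchNorm)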