# Source code for lightning.pytorch.plugins.layer_sync

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod

import torch
from torch import Tensor
from torch.nn import Module
from typing_extensions import override


class LayerSync(ABC):
    """Interface for plugins that add multiprocess synchronization to the layers of a model.

    A concrete plugin installs the synchronization in :meth:`apply` and removes it again in
    :meth:`revert`; both hooks must be implemented before the plugin can be instantiated.
    """

    @abstractmethod
    def apply(self, model: Module) -> Module:
        """Wrap the layers of ``model`` with synchronization logic and return the resulting model.

        Subclasses must override this hook.
        """

    @abstractmethod
    def revert(self, model: Module) -> Module:
        """Undo every change made by :meth:`apply` and return the resulting model.

        Subclasses must override this hook.
        """
class TorchSyncBatchNorm(LayerSync):
    """A plugin that wraps all batch normalization layers of a model with synchronization logic for
    multiprocessing.

    This plugin has no effect in single-device operation.

    """

    @override
    def apply(self, model: Module) -> Module:
        """Add global batchnorm for a model spread across multiple GPUs and nodes.

        Override this method to synchronize batchnorm layers between specific process groups instead
        of the whole world.

        Args:
            model: Reference to the current LightningModule

        Return:
            LightningModule with batchnorm layers synchronized within the process groups.

        """
        return torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    @override
    def revert(self, model: Module) -> Module:
        """Convert the wrapped batchnorm layers back to regular batchnorm layers.

        Walks ``model`` recursively and replaces every ``SyncBatchNorm`` with a ``_BatchNormXd``.
        The replacement reuses the original parameters and running-statistic buffers, and keeps
        the original train/eval mode.

        Args:
            model: Reference to the current LightningModule

        Return:
            LightningModule with regular batchnorm layers that will no longer sync across processes.

        """
        # Code adapted from https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547
        # Original author: Kapil Yedidi (@kapily)
        converted_module = model
        if isinstance(model, torch.nn.modules.batchnorm.SyncBatchNorm):
            # Unfortunately, LayerSync does not store the original class - if it did
            # we could return the one that was originally created.
            converted_module = _BatchNormXd(
                model.num_features, model.eps, model.momentum, model.affine, model.track_running_stats
            )
            if model.affine:
                with torch.no_grad():
                    # Reuse the existing Parameter objects instead of copying their values.
                    converted_module.weight = model.weight
                    converted_module.bias = model.bias
            converted_module.running_mean = model.running_mean
            converted_module.running_var = model.running_var
            converted_module.num_batches_tracked = model.num_batches_tracked
            # A freshly constructed module always starts in training mode. Carry over the original
            # mode so that a model reverted while in eval mode does not silently switch back to
            # training behaviour (mirrors what `convert_sync_batchnorm` does in the other direction).
            converted_module.training = model.training
            if hasattr(model, "qconfig"):
                converted_module.qconfig = model.qconfig

        # Recurse so that nested SyncBatchNorm layers anywhere in the tree are converted as well.
        for name, child in model.named_children():
            converted_module.add_module(name, self.revert(child))

        return converted_module
class _BatchNormXd(torch.nn.modules.batchnorm._BatchNorm):
    """Batch normalization layer that accepts input of any dimensionality.

    ``BatchNorm1d``, ``BatchNorm2d`` and ``BatchNorm3d`` differ from one another only in the
    ``_check_input_dim`` override each of them provides. This class replaces that check with a
    no-op, so it can stand in for whichever concrete class a layer originally had.
    """

    @override
    def _check_input_dim(self, input: Tensor) -> None:
        """Accept any input without validating its number of dimensions.

        The original concrete batchnorm class is not known here, so the user is trusted to
        provide inputs with the right dimensions at inference.
        """
        return None