Plugins

Plugins allow custom integrations with the internals of the Trainer, such as a custom precision or distributed implementation.

Under the hood, the Lightning Trainer uses plugins in the training routine; they are added automatically depending on the provided Trainer arguments. For example:

# accelerator: GPUAccelerator
# training type: DDPPlugin
# precision: NativeMixedPrecisionPlugin
trainer = Trainer(gpus=4, precision=16)
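
Individual plugins can also be overridden explicitly while the remaining ones are still chosen automatically. A minimal sketch; the find_unused_parameters flag is just an example of an option forwarded to DistributedDataParallel:

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin

# precision is still selected automatically (NativeMixedPrecisionPlugin),
# but the DDP training type plugin is customized explicitly
trainer = Trainer(gpus=4, precision=16, strategy=DDPPlugin(find_unused_parameters=False))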

We expose Accelerators and Plugins mainly for expert users who want to extend Lightning for:

  • New hardware (like the TPU plugin)

  • Distributed backends (e.g. a backend not yet supported by PyTorch itself)

  • Clusters (e.g. customized access to the cluster’s environment interface)

There are two types of Plugins in Lightning with different responsibilities:

TrainingTypePlugin

  • Launching and teardown of training processes (if applicable)

  • Set up communication between processes (NCCL, GLOO, MPI, …)

  • Provide a unified communication interface for reduction, broadcast, etc.

  • Provide access to the wrapped LightningModule

PrecisionPlugin

  • Perform pre- and post-backward/optimizer-step operations, such as scaling gradients

  • Provide context managers for forward, training_step, etc.

  • Gradient clipping
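
As an illustration of where these responsibilities live, the sketch below subclasses PrecisionPlugin and intercepts gradient clipping. The clip_gradients hook exists on PrecisionPlugin, but its exact signature varies between Lightning versions, so the arguments are simply forwarded:

from pytorch_lightning.plugins import PrecisionPlugin


class LoggingPrecisionPlugin(PrecisionPlugin):
    # hypothetical example: report every gradient clipping call,
    # then defer to the default behaviour
    def clip_gradients(self, *args, **kwargs):
        print("clipping gradients")
        return super().clip_gradients(*args, **kwargs)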

Furthermore, for multi-node training Lightning provides cluster environment plugins that allow the advanced user to configure Lightning to integrate with a custom cluster, as sketched below.
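
For example, a cluster environment plugin can be passed to the Trainer explicitly instead of relying on auto-detection; a fully custom cluster is handled by subclassing ClusterEnvironment (see the Cluster Environments list below). A minimal sketch, assuming a SLURM-managed allocation:

from pytorch_lightning import Trainer
from pytorch_lightning.plugins.environments import SLURMEnvironment

# select the SLURM cluster environment explicitly for a 2-node, 4-GPU-per-node run
trainer = Trainer(gpus=4, num_nodes=2, plugins=[SLURMEnvironment()])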

[Image: overview.svg]

Create a custom plugin

Expert users may choose to extend an existing plugin by overriding its methods …

from pytorch_lightning.plugins import DDPPlugin


class CustomDDPPlugin(DDPPlugin):
    def configure_ddp(self):
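        # wrap the LightningModule in a user-defined DistributedDataParallel subclass;
        # MyCustomDistributedDataParallel is a placeholder, not a Lightning class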
        self._model = MyCustomDistributedDataParallel(
            self.model,
            device_ids=...,
        )

or by subclassing the base classes TrainingTypePlugin or PrecisionPlugin to create new ones. These custom plugins can then be passed into the Trainer directly or via a (custom) accelerator:

# custom plugins
trainer = Trainer(strategy=CustomDDPPlugin(), plugins=[CustomPrecisionPlugin()])

# fully custom accelerator and plugins
accelerator = MyAccelerator(
    precision_plugin=CustomPrecisionPlugin(),
    training_type_plugin=CustomDDPPlugin(),
)
trainer = Trainer(accelerator=accelerator)

The full list of built-in plugins is given below.

Warning

The Plugin API is in beta and subject to change. For help setting up custom plugins/accelerators, please reach out to us at support@pytorchlightning.ai


Training Type Plugins

TrainingTypePlugin

Base class for all training type plugins that change the behaviour of the training, validation and test loop.

SingleDevicePlugin

Plugin that handles communication on a single device.

ParallelPlugin

Plugin for training with multiple processes in parallel.

DataParallelPlugin

Implements data-parallel training in a single process, i.e., the model gets replicated to each device and each gets a split of the data.

DDPPlugin

Plugin for multi-process single-device training on one or multiple nodes.

DDP2Plugin

DDP2 behaves like DP in one node, but synchronization across nodes behaves like in DDP.

DDPShardedPlugin

Optimizer and gradient sharded training provided by FairScale.

DDPSpawnShardedPlugin

Optimizer sharded training provided by FairScale.

DDPSpawnPlugin

Spawns processes using the torch.multiprocessing.spawn() method and joins processes after training finishes.

DeepSpeedPlugin

Provides capabilities to run training using the DeepSpeed library, with training optimizations for large billion parameter models.

HorovodPlugin

Plugin for Horovod distributed training integration.

SingleTPUPlugin

Plugin for training on a single TPU device.

TPUSpawnPlugin

Plugin for training multiple TPU devices using the torch.multiprocessing.spawn() method.
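
In typical use, a training type plugin is selected through the corresponding Trainer shorthand or passed as an instance. A minimal sketch (the GPU count is only an illustration):

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPSpawnPlugin

# equivalent ways of requesting spawn-based DDP on 2 GPUs
trainer = Trainer(gpus=2, strategy="ddp_spawn")
trainer = Trainer(gpus=2, strategy=DDPSpawnPlugin())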

Precision Plugins

PrecisionPlugin

Base class for all plugins handling the precision-specific parts of the training.

MixedPrecisionPlugin

Base Class for mixed precision.

NativeMixedPrecisionPlugin

Plugin for Native Mixed Precision (AMP) training with torch.autocast.

ShardedNativeMixedPrecisionPlugin

Native AMP for Sharded Training.

ApexMixedPrecisionPlugin

Mixed Precision Plugin based on Nvidia/Apex (https://github.com/NVIDIA/apex).

DeepSpeedPrecisionPlugin

Precision plugin for DeepSpeed integration.

TPUPrecisionPlugin

Precision plugin for TPU integration.

TPUBf16PrecisionPlugin

Plugin that enables bfloat16 on TPUs.

DoublePrecisionPlugin

Plugin for training with double (torch.float64) precision.

FullyShardedNativeMixedPrecisionPlugin

Native AMP for Fully Sharded Training.

IPUPrecisionPlugin

Precision plugin for IPU integration.
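
Precision plugins are usually selected through Trainer flags rather than instantiated directly. A minimal sketch (the amp_backend="apex" line assumes NVIDIA Apex is installed):

from pytorch_lightning import Trainer

trainer = Trainer(gpus=1, precision=16)                      # NativeMixedPrecisionPlugin
trainer = Trainer(gpus=1, precision=16, amp_backend="apex")  # ApexMixedPrecisionPlugin
trainer = Trainer(precision=64)                              # DoublePrecisionPlugin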

Cluster Environments

ClusterEnvironment

Specification of a cluster environment.

LightningEnvironment

The default environment used by Lightning for a single node or free cluster (not managed).

LSFEnvironment

An environment for running on clusters managed by the LSF resource manager.

TorchElasticEnvironment

Environment for fault-tolerant and elastic training with torchelastic.

KubeflowEnvironment

Environment for distributed training using the PyTorchJob operator from Kubeflow.

SLURMEnvironment

Cluster environment for training on a cluster managed by SLURM.