accelerators

Accelerator

The Accelerator base class for Lightning PyTorch.

CPUAccelerator

Accelerator for CPU devices.

CUDAAccelerator

Accelerator for NVIDIA CUDA devices.

XLAAccelerator

Accelerator for XLA devices, normally TPUs.

callbacks

BackboneFinetuning

Finetune a backbone model based on a user-defined learning rate schedule.

BaseFinetuning

This class implements the base logic for writing your own Finetuning Callback.

BasePredictionWriter

Base class to implement how the predictions should be stored.

BatchSizeFinder

Finds the largest batch size supported by a given model before encountering an out of memory (OOM) error.

Callback

Abstract base class used to build new callbacks.
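
A minimal sketch of a custom callback (the hook names follow the Callback API; the print statements are illustrative):

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import Callback

class PrintingCallback(Callback):
    def on_train_start(self, trainer, pl_module):
        print("Training is starting")

    def on_train_end(self, trainer, pl_module):
        print("Training is ending")

trainer = Trainer(callbacks=[PrintingCallback()])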

DeviceStatsMonitor

Automatically monitors and logs device stats during the training, validation, and testing stages.

EarlyStopping

Monitor a metric and stop training when it stops improving.
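
A minimal usage sketch (assuming the LightningModule logs a "val_loss" metric via self.log):

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping

# Stop when "val_loss" has not improved for 3 consecutive validation runs.
early_stop = EarlyStopping(monitor="val_loss", mode="min", patience=3)
trainer = Trainer(callbacks=[early_stop])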

GradientAccumulationScheduler

Change gradient accumulation factor according to scheduling.
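
For example, with a schedule keyed by epoch:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import GradientAccumulationScheduler

# Accumulate 8 batches from epoch 0, 4 from epoch 4, and stop accumulating from epoch 8.
accumulator = GradientAccumulationScheduler(scheduling={0: 8, 4: 4, 8: 1})
trainer = Trainer(callbacks=[accumulator])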

LambdaCallback

Create a simple callback on the fly using lambda functions.
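
Hook names are passed as keyword arguments, for example:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import LambdaCallback

# Each keyword argument maps a hook name to a function with that hook's signature.
trainer = Trainer(callbacks=[LambdaCallback(setup=lambda *args, **kwargs: print("setup"))])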

LearningRateFinder

The LearningRateFinder callback performs a range test over initial learning rates, reducing the guesswork in picking a good starting learning rate.

LearningRateMonitor

Automatically monitors and logs the learning rate of learning rate schedulers during training.

ModelCheckpoint

Save the model periodically by monitoring a quantity.
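
For example (assuming a logged "val_loss" metric; the dirpath is illustrative):

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

# Keep the 3 checkpoints with the lowest logged "val_loss".
checkpoint = ModelCheckpoint(dirpath="checkpoints/", monitor="val_loss", mode="min", save_top_k=3)
trainer = Trainer(callbacks=[checkpoint])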

ModelPruning

Model pruning callback using PyTorch's torch.nn.utils.prune utilities.

ModelSummary

Generates a summary of all layers in a LightningModule.

OnExceptionCheckpoint

Used to save a checkpoint on exception.

ProgressBar

The base class for progress bars in Lightning.

RichModelSummary

Generates a summary of all layers in a LightningModule with rich text formatting.

RichProgressBar

Create a progress bar with rich text formatting.

StochasticWeightAveraging

Implements the Stochastic Weight Averaging (SWA) callback to average model weights.
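
For example:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import StochasticWeightAveraging

# Average weights over the tail of training with a constant SWA learning rate.
trainer = Trainer(callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)])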

SpikeDetection

Detects spikes in the training loss and raises an error when one occurs.

ThroughputMonitor

Computes and logs throughput using the Throughput utility.

Timer

The Timer callback tracks the time spent in the training, validation, and test loops and interrupts the Trainer if the given time limit for the training loop is reached.
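
For example, to stop after at most 12 hours:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import Timer

# The duration string uses the "DD:HH:MM:SS" format.
trainer = Trainer(callbacks=[Timer(duration="00:12:00:00")])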

TQDMProgressBar

This is the default progress bar used by Lightning.

cli

LightningCLI

Implementation of a configurable command line tool for pytorch-lightning.
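
A minimal sketch (MyModel and MyDataModule are hypothetical user-defined classes; run e.g. python main.py fit --trainer.max_epochs=10):

from lightning.pytorch.cli import LightningCLI

def main():
    # Subcommands (fit, validate, test, predict) and all Trainer/model
    # arguments are exposed on the command line automatically.
    LightningCLI(MyModel, MyDataModule)

if __name__ == "__main__":
    main()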

LightningArgumentParser

Extension of jsonargparse's ArgumentParser for pytorch-lightning.

SaveConfigCallback

Saves a LightningCLI config to the log_dir when training starts.

core

CheckpointHooks

Hooks to be used with Checkpointing.

DataHooks

Hooks related to data loading and preparation.

ModelHooks

Hooks to be used in LightningModule.

LightningDataModule

A DataModule standardizes the training, validation, and test splits, data preparation, and transforms.
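
A minimal sketch over random tensors (illustration only):

import torch
from torch.utils.data import DataLoader, TensorDataset

from lightning.pytorch import LightningDataModule

class RandomDataModule(LightningDataModule):
    def setup(self, stage=None):
        # Called on every process; assign the train/val splits here.
        self.train_set = TensorDataset(torch.randn(64, 32))
        self.val_set = TensorDataset(torch.randn(16, 32))

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=8)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=8)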

LightningModule

The base class for organizing your PyTorch model, training, validation, and test code into a standardized, hardware-agnostic interface.

HyperparametersMixin

Mixin that provides the save_hyperparameters() method and the hparams property.

LightningOptimizer

This class wraps the user's optimizers and properly handles the backward and optimizer_step logic across accelerators, AMP, and accumulate_grad_batches.

loggers

logger

Abstract base class used to build new loggers.

comet

Comet Logger

csv_logs

CSV logger

mlflow

MLflow Logger

neptune

Neptune Logger

tensorboard

TensorBoard Logger

wandb

Weights and Biases Logger
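
All loggers follow the same pattern: instantiate one and pass it to the Trainer. A sketch with the TensorBoard logger (save_dir and name are illustrative):

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import TensorBoardLogger

logger = TensorBoardLogger(save_dir="logs/", name="my_experiment")
trainer = Trainer(logger=logger)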

plugins

precision

DeepSpeedPrecision

Precision plugin for DeepSpeed integration.

DoublePrecision

Plugin for training with double (torch.float64) precision.

HalfPrecision

Plugin for training with half precision.

FSDPPrecision

Precision plugin for training with Fully Sharded Data Parallel (FSDP).

MixedPrecision

Plugin for Automatic Mixed Precision (AMP) training with torch.autocast.

Precision

Base class for all plugins handling the precision-specific parts of the training.

XLAPrecision

Plugin for training with XLA.

TransformerEnginePrecision

Plugin for training with fp8 precision via NVIDIA's Transformer Engine.

BitsandbytesPrecision

Plugin for quantizing weights with bitsandbytes.
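
Precision plugins are usually selected implicitly via the Trainer's precision flag, but a configured instance can also be passed explicitly. A sketch:

from lightning.pytorch import Trainer
from lightning.pytorch.plugins import HalfPrecision

# The string shorthand selects the matching plugin (here, MixedPrecision) ...
trainer = Trainer(precision="16-mixed")

# ... or pass a plugin instance directly.
trainer = Trainer(plugins=[HalfPrecision(precision="16-true")])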

environments

ClusterEnvironment

Specification of a cluster environment.

KubeflowEnvironment

Environment for distributed training using the PyTorchJob operator from Kubeflow.

LightningEnvironment

The default environment used by Lightning for a single node or free cluster (not managed).

LSFEnvironment

An environment for running on clusters managed by the LSF resource manager.

MPIEnvironment

An environment for running on clusters with processes created through MPI.

SLURMEnvironment

Cluster environment for training on a cluster managed by SLURM.

TorchElasticEnvironment

Environment for fault-tolerant and elastic training with torchelastic.

XLAEnvironment

Cluster environment for training on a TPU Pod with the PyTorch/XLA library.

io

AsyncCheckpointIO

AsyncCheckpointIO enables saving the checkpoints asynchronously in a thread.

CheckpointIO

Interface to save/load checkpoints as they are saved through the Strategy.

TorchCheckpointIO

CheckpointIO that utilizes torch.save() and torch.load() to save and load checkpoints respectively, common for most use cases.

XLACheckpointIO

CheckpointIO that utilizes xm.save to save checkpoints for TPU training strategies.

others

LayerSync

Abstract base class for creating plugins that wrap layers of a model with synchronization logic for multiprocessing.

TorchSyncBatchNorm

A plugin that wraps all batch normalization layers of a model with synchronization logic for multiprocessing.

profiler

AdvancedProfiler

This profiler uses Python's cProfile to record more detailed information about time spent in each function call recorded during a given action.

PassThroughProfiler

This class should be used when you don't want the (small) overhead of profiling.

Profiler

If you wish to write a custom profiler, you should inherit from this class.

PyTorchProfiler

This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of different operators inside your model - both on the CPU and GPU.

SimpleProfiler

This profiler simply records the duration of actions (in seconds) and reports the mean duration of each action and the total time spent over the entire training run.

XLAProfiler

XLA Profiler will help you debug and optimize training workload performance for your models using Cloud TPU performance tools.

trainer

Trainer

Customize every aspect of training via flags.
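
For example (model is a hypothetical LightningModule instance):

from lightning.pytorch import Trainer

# Flags configure hardware, duration, precision, and more.
trainer = Trainer(max_epochs=10, accelerator="auto", devices="auto", precision="16-mixed")
trainer.fit(model)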

strategies

DDPStrategy

Strategy for multi-process single-device training on one or multiple nodes.
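
For example:

from lightning.pytorch import Trainer
from lightning.pytorch.strategies import DDPStrategy

# Select the strategy by name ...
trainer = Trainer(strategy="ddp", accelerator="gpu", devices=4)

# ... or pass a configured instance.
trainer = Trainer(strategy=DDPStrategy(find_unused_parameters=False), accelerator="gpu", devices=4)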

DeepSpeedStrategy

Provides capabilities to run training using the DeepSpeed library, with training optimizations for large, billion-parameter models.

FSDPStrategy

Strategy for Fully Sharded Data Parallel provided by torch.distributed.

ModelParallelStrategy

Enables user-defined parallelism applied to a model.

ParallelStrategy

Strategy for training with multiple processes in parallel.

SingleDeviceStrategy

Strategy that handles communication on a single device.

SingleDeviceXLAStrategy

Strategy for training on a single XLA device.

Strategy

Base class for all strategies that change the behaviour of the training, validation, and test loops.

XLAStrategy

Strategy for training multiple TPU devices using the torch_xla.distributed.xla_multiprocessing.spawn() method.

tuner

Tuner

Tuner class to tune your model.
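
A sketch (model is a hypothetical LightningModule instance):

from lightning.pytorch import Trainer
from lightning.pytorch.tuner import Tuner

trainer = Trainer()
tuner = Tuner(trainer)

# Find the largest batch size that fits in memory, then run an LR range test.
tuner.scale_batch_size(model, mode="power")
lr_finder = tuner.lr_find(model)
print(lr_finder.suggestion())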

utilities

combined_loader

Utilities for combining multiple dataloaders into one iterable (CombinedLoader).

data

Utilities for handling dataloaders and batches.

deepspeed

Utilities that can be used with DeepSpeed.

memory

Utilities related to memory.

model_summary

Utilities to summarize the layers and parameters of a model.

parsing

Utilities used for parameter parsing.

rank_zero

Utilities that can be used for calling functions on a particular rank.

seed

Utilities to help with reproducibility of models.
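
For example:

from lightning.pytorch import seed_everything

# Seeds the Python, NumPy, and PyTorch RNGs; workers=True also seeds dataloader workers.
seed_everything(42, workers=True)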

warnings

Warning-related utilities.

lightning.pytorch.utilities.measure_flops(model, forward_fn, loss_fn=None)[source]

Utility to compute the total number of FLOPs used by a module during training or during inference.

It’s recommended to create a meta-device model for this:

Example:

import torch

from lightning.pytorch.utilities import measure_flops

# MyModel is the user's module; instantiating it on the meta device
# avoids allocating real memory for the weights.
with torch.device("meta"):
    model = MyModel()
    x = torch.randn(2, 32)

model_fwd = lambda: model(x)
fwd_flops = measure_flops(model, model_fwd)

model_loss = lambda y: y.sum()
fwd_and_bwd_flops = measure_flops(model, model_fwd, model_loss)

Parameters:
  • model (Module) – The model whose FLOPs should be measured.

  • forward_fn (Callable[[], Tensor]) – A function that runs forward on the model and returns the result.

  • loss_fn (Optional[Callable[[Tensor], Tensor]]) – A function that computes the loss given the forward_fn output. If provided, the loss and backward FLOPs will be included in the result.

Return type:

int