accelerators
The Accelerator base class for Lightning PyTorch.

Accelerator for CPU devices.

Accelerator for NVIDIA CUDA devices.

Accelerator for XLA devices, normally TPUs.
callbacks
Finetune a backbone model based on a user-defined learning rate schedule.

This class implements the base logic for writing your own finetuning callback.

Base class to implement how the predictions should be stored.

Finds the largest batch size supported by a given model before encountering an out-of-memory (OOM) error.

Abstract base class used to build new callbacks.

Automatically monitors and logs device stats during the training, validation and testing stages.

Monitor a metric and stop training when it stops improving.

Change the gradient accumulation factor according to a schedule.

Create a simple callback on the fly using lambda functions.

The

Automatically monitors and logs the learning rate of learning rate schedulers during training.

Save the model periodically by monitoring a quantity.

Model pruning callback, using PyTorch's prune utilities.

Generates a summary of all layers in a

Used to save a checkpoint on exception.

The base class for progress bars in Lightning.

Generates a summary of all layers in a

Create a progress bar with rich text formatting.

Implements the Stochastic Weight Averaging (SWA) callback to average a model.

Computes and logs throughput with the

The Timer callback tracks the time spent in the training, validation, and test loops and interrupts the Trainer if the given time limit for the training loop is reached.

This is the default progress bar used by Lightning.
cli
Implementation of a configurable command line tool for pytorch-lightning.

Extension of jsonargparse's ArgumentParser for pytorch-lightning.

Saves a LightningCLI config to the log_dir when training starts.
core
Hooks to be used with checkpointing.

Hooks to be used for data-related logic.

Hooks to be used in LightningModule.

A DataModule standardizes the training, validation and test splits, data preparation and transforms.

This class wraps the user's optimizers and properly handles the backward and optimizer_step logic across accelerators, AMP and accumulate_grad_batches.
loggers
Abstract base class used to build new loggers.

Comet Logger

CSV logger

MLflow Logger

Neptune Logger

TensorBoard Logger

Weights and Biases Logger
plugins
precision
Precision plugin for DeepSpeed integration.

Plugin for training with double (

Plugin for training with half precision.

Precision plugin for training with Fully Sharded Data Parallel (FSDP).

Plugin for Automatic Mixed Precision (AMP) training with

Base class for all plugins handling the precision-specific parts of the training.

Plugin for training with XLA.

Plugin for training with fp8 precision via NVIDIA's Transformer Engine.

Plugin for quantizing weights with bitsandbytes.
environments
Specification of a cluster environment.

Environment for distributed training using the PyTorchJob operator from Kubeflow.

The default environment used by Lightning for a single node or an unmanaged cluster.

An environment for running on clusters managed by the LSF resource manager.

An environment for running on clusters with processes created through MPI.

Cluster environment for training on a cluster managed by SLURM.

Environment for fault-tolerant and elastic training with torchelastic.

Cluster environment for training on a TPU Pod with the PyTorch/XLA library.
io
Interface to save/load checkpoints as they are saved through the

CheckpointIO that utilizes

CheckpointIO that utilizes
others
Abstract base class for creating plugins that wrap layers of a model with synchronization logic for multiprocessing.

A plugin that wraps all batch normalization layers of a model with synchronization logic for multiprocessing.
profiler
This profiler uses Python's cProfile module to record more detailed information about the time spent in each function call during a given action.

This class should be used when you don't want the (small) overhead of profiling.

If you wish to write a custom profiler, you should inherit from this class.

This profiler uses PyTorch's Autograd Profiler and lets you inspect the cost of different operators inside your model, both on the CPU and GPU.

This profiler records the duration of actions (in seconds) and reports the mean duration of each action and the total time spent over the entire training run.

The XLA profiler helps you debug and optimize the training workload performance of your models using Cloud TPU performance tools.
trainer
Customize every aspect of training via flags.
strategies
Strategy for multi-process single-device training on one or multiple nodes.

Provides capabilities to run training using the DeepSpeed library, with training optimizations for large, billion-parameter models.

Strategy for Fully Sharded Data Parallel provided by torch.distributed.

Enables user-defined parallelism applied to a model.

Strategy for training with multiple processes in parallel.

Strategy that handles communication on a single device.

Strategy for training on a single XLA device.

Base class for all strategies that change the behaviour of the training, validation and test loops.

Strategy for training multiple TPU devices using the
tuner
Tuner class to tune your model.
utilities
Utilities that can be used with DeepSpeed.

Utilities related to memory.

Utilities used for parameter parsing.

Utilities that can be used for calling functions on a particular rank.

Utilities to help with reproducibility of models.

Warning-related utilities.
- lightning.pytorch.utilities.measure_flops(model, forward_fn, loss_fn=None)
Utility to compute the total number of FLOPs used by a module during training or during inference.
It’s recommended to create a meta-device model for this:
Example:

    import torch
    from lightning.pytorch.utilities import measure_flops

    with torch.device("meta"):
        model = MyModel()  # your own nn.Module
        x = torch.randn(2, 32)

    model_fwd = lambda: model(x)
    fwd_flops = measure_flops(model, model_fwd)

    model_loss = lambda y: y.sum()
    fwd_and_bwd_flops = measure_flops(model, model_fwd, model_loss)
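- Parameters:
  - model: the module whose FLOPs are measured.
  - forward_fn: a zero-argument callable that runs the forward pass and returns its output.
  - loss_fn: an optional callable that takes the forward output and returns a loss; when given, the backward-pass FLOPs are included in the count.
- Return type:
  int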