
Lightning 2.0: Fast, Flexible, Stable

Lightning AI is excited to announce the release of Lightning 2.0 ⚡

Over the last couple of years, PyTorch Lightning has become the preferred deep learning framework for researchers and ML developers around the world, with close to 50 million downloads and 18k OSS projects, from top universities to leading labs.

With the help of over 800 contributors, we have added many features and functionalities to make it the most complete research toolkit possible, but some of these changes also introduced issues:

  • API changes to the trainer
  • Trainer code became harder to follow
  • Many integrations made Lightning appear bloated
  • The trainer became harder to customize, taking away control over the parts users need to tweak.

To make the research experience better, we are introducing 2.0:

  • No API changes – We commit to backward compatibility in the 2.0 series
  • Simplified abstraction layers, removed legacy functionality, integrations out of the main repo. This improves the project’s readability and debugging experience.
  • Introducing Fabric. Scale any PyTorch model with just a few lines of code. Read on!

Highlights

PyTorch 2.0 and torch.compile

Lightning 2.0 is best friends with PyTorch 2.0. You can torch.compile your LightningModules now!

import torch
import lightning as L

model = LitModel()
# This will compile forward and {training,validation,test,predict}_step 
compiled_model = torch.compile(model)

trainer = L.Trainer()
trainer.fit(compiled_model)

PyTorch reports that on average, “the model runs 43% faster in training on an NVIDIA A100 GPU. At Float32 precision, it runs 21% faster on average and at AMP Precision it runs 51% faster on average” (source). If you want to learn more about torch.compile and how such speedups can be achieved, read the official PyTorch 2.0 blog post.

Automatic accelerator selection (#16847)

The Trainer now chooses accelerator="auto", strategy="auto", devices="auto" as defaults. This automatically detects the best hardware on your system (TPUs, GPUs, Apple Silicon, etc.) and chooses as many devices as are available.

import lightning as L

# Selects accelerator, devices and strategy automatically!
trainer = L.Trainer()

# Same as:
trainer = L.Trainer(accelerator="auto", strategy="auto", devices="auto")

For example, on an 8-GPU server, this will implicitly select Trainer(accelerator="cuda", strategy="ddp", devices=8).

Support for arbitrary iterables (#16726)

Previously, the Trainer supported DataLoader-like iterables. However, with this release, users can now work with any iterable that implements the Python iterable definition. This includes custom data structures, such as user-defined classes and generators, as well as built-in Python objects.

To use this new feature, return any iterable (or collection of iterables) from the dataloader hooks.

def train_dataloader(self):  # or any other *_dataloader hook
    # Return ONE of the following (shown together only for brevity):

    # a single DataLoader
    return DataLoader(...)

    # any Python iterable, such as a list
    return list(range(1000))

    # pass loaders as a dict. This will create batches like this:
    # {'a': batch_from_loader_a, 'b': batch_from_loader_b}
    return {"a": DataLoader(...), "b": DataLoader(...)}

    # pass loaders as a list. This will create batches like this:
    # [batch_from_dl_1, batch_from_dl_2]
    return [DataLoader(...), DataLoader(...)]

    # arbitrary nesting
    # {'a': [batch_from_dl_1, batch_from_dl_2], 'b': [batch_from_dl_3, batch_from_dl_4]}
    return {"a": [dl1, dl2], "b": [dl3, dl4]}
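The nesting rule can be modeled in plain Python. The sketch below is a hypothetical illustration, not Lightning's implementation: nested_batches and its helpers are made-up names, only dict and list count as containers (so leaf iterables must be something else, such as a DataLoader, a range, or a generator), and iteration stops as soon as the shortest leaf runs out.

```python
from typing import Any, Iterator, List


def _leaves(tree: Any) -> List[Any]:
    """Collect the leaf iterables of a nested dict/list in traversal order."""
    if isinstance(tree, dict):
        return [leaf for value in tree.values() for leaf in _leaves(value)]
    if isinstance(tree, list):
        return [leaf for value in tree for leaf in _leaves(value)]
    return [tree]


def _rebuild(tree: Any, items: Iterator[Any]) -> Any:
    """Rebuild the nesting of `tree`, replacing each leaf with the next item."""
    if isinstance(tree, dict):
        return {key: _rebuild(value, items) for key, value in tree.items()}
    if isinstance(tree, list):
        return [_rebuild(value, items) for value in tree]
    return next(items)


def nested_batches(tree: Any) -> Iterator[Any]:
    """Yield one batch per step, shaped like `tree`; stop at the shortest leaf."""
    for step_items in zip(*_leaves(tree)):
        yield _rebuild(tree, iter(step_items))


# Plain ranges stand in for DataLoaders here
loaders = {"a": [range(0, 3), range(10, 13)], "b": range(100, 102)}
for batch in nested_batches(loaders):
    print(batch)
```

Each printed batch has the same dict/list shape as the returned collection, with one item per leaf iterable in place of the iterable itself.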

Read our data section for more information.

Redesigned multi-dataloader support (#16743, #16784, #16939)

Lightning automatically collates the batches from multiple iterables based on a “mode”. This is done with our newly revamped CombinedLoader class.

import lightning as L
from torch.utils.data import DataLoader
from lightning.pytorch.utilities import CombinedLoader

iterables = {"a": DataLoader(...), "b": DataLoader(...)}
# Lightning uses this under the hood, but this way you can change the "mode"
combined_loader = CombinedLoader(iterables, mode="min_size")

model = ...
trainer = L.Trainer()
trainer.fit(model, combined_loader)

The following modes are supported:

  • min_size: stops after the shortest iterable (the one with the lowest number of items) is done.
  • max_size_cycle: stops after the longest iterable (the one with most items) is done, while cycling through the rest of the iterables.
  • max_size: stops after the longest iterable (the one with most items) is done, while returning None for the exhausted iterables.
  • sequential: completely consumes each iterable sequentially, and returns a triplet (data, idx, iterable_idx).

If you need a different “mode”, feel free to open a feature request! Adding new modes is now much simpler. These improvements also allowed us to simplify the trainer’s loops by moving this logic into the CombinedLoader.
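As a rough model of these semantics, here is a pure-Python sketch of the four modes. It is a hypothetical stand-in for CombinedLoader (the function name combine is made up), and it assumes sized iterables such as lists, since max_size_cycle needs the longest length up front; the real class also handles nested collections and other details omitted here.

```python
from itertools import cycle, zip_longest
from typing import Iterator, Sequence


def combine(iterables: Sequence[Sequence], mode: str) -> Iterator:
    """Hypothetical stand-in for CombinedLoader's modes (sized iterables only)."""
    if mode == "min_size":
        # Stop as soon as the shortest iterable is exhausted
        for items in zip(*iterables):
            yield list(items)
    elif mode == "max_size_cycle":
        # Run for the length of the longest iterable, restarting the shorter ones
        longest = max(len(it) for it in iterables)
        cycled = [cycle(it) for it in iterables]
        for _ in range(longest):
            yield [next(c) for c in cycled]
    elif mode == "max_size":
        # Run for the length of the longest iterable, padding exhausted ones with None
        for items in zip_longest(*iterables, fillvalue=None):
            yield list(items)
    elif mode == "sequential":
        # Drain one iterable after another, tagging items with (idx, iterable_idx)
        for iterable_idx, it in enumerate(iterables):
            for idx, data in enumerate(it):
                yield data, idx, iterable_idx
    else:
        raise ValueError(f"Unknown mode: {mode!r}")


a, b = [1, 2, 3], ["x", "y"]
for mode in ("min_size", "max_size_cycle", "max_size", "sequential"):
    print(mode, list(combine([a, b], mode)))
```

With a three-item and a two-item iterable, min_size yields two batches, while both max_size variants yield three and differ only in what fills the gap left by the shorter iterable.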

Barebones Trainer mode (#16854)

A new Trainer argument Trainer(barebones=...) was added (default is False) to disable all features that may impact the raw speed of the training loop. This allows users to quickly and fairly compare the runtime of a Lightning script with a raw PyTorch script.

This is how you enable it:

import lightning as L

# Default: False
trainer = L.Trainer(barebones=True)

A message informs you about the changed settings:

You are running in `Trainer(barebones=True)` mode. All features that may impact raw speed have been disabled to facilitate analyzing the Trainer overhead. Specifically, the following features are deactivated:
 - Checkpointing: `Trainer(enable_checkpointing=True)`
 - Progress bar: `Trainer(enable_progress_bar=True)`
 - Model summary: `Trainer(enable_model_summary=True)`
 - Logging: `Trainer(logger=True)`, `Trainer(log_every_n_steps>0)`, `LightningModule.log(...)`, `LightningModule.log_dict(...)`
 - Sanity checking: `Trainer(num_sanity_val_steps>0)`
 - Development run: `Trainer(fast_dev_run=True)`
 - Anomaly detection: `Trainer(detect_anomaly=True)`
 - Profiling: `Trainer(profiler=...)`

Tip: This feature is also very useful for unit testing!

Better progress bar (#16695)

Based on feedback from users, we decided to separate the training progress bar from the validation bar. This greatly improves the time estimates (since validation is usually faster) and resolves confusion around the total batches being processed in an epoch.

This is how the bar looked in versions before 2.0:

Epoch 3:  21%|██        | 28/128 [00:36<01:32, 23.12it/s, loss=0.163]
Validation DataLoader 0:  38%|███      | 12/32 [00:12<00:20,  1.01s/it]

Note how the total number of batches (128) is the sum of the training batches (32) and the batches of the three validation runs (3 x 32). And this is what the progress bar looks like now:

Epoch 3:  50%|█████     | 16/32 [00:36<01:32, 23.12it/s]
Validation DataLoader 0:  38%|███      | 12/32 [00:12<00:20,  1.01s/it]

Note how the batch counts are now separate. The training progress bar pauses until validation is completed.

Lightning Fabric

Lightning 2.0 is the official release of Lightning Fabric.

Fabric is the fast and lightweight way to scale PyTorch models without boilerplate code.

  • Easily switch from running on CPU to GPU (Apple Silicon, CUDA, …), TPU, multi-GPU or even multi-node training
  • State-of-the-art distributed training strategies (DDP, FSDP, DeepSpeed) and mixed precision out of the box
  • Handles all the boilerplate device logic for you
  • Brings useful tools to help you build a trainer (callbacks, logging, checkpoints, …)
  • Designed with multi-billion parameter models in mind

Go to Fabric documentation

  import torch
  import torch.nn as nn
  from torch.utils.data import DataLoader, Dataset

+ from lightning.fabric import Fabric

  class PyTorchModel(nn.Module):
      ...

  class PyTorchDataset(Dataset):
      ...

+ fabric = Fabric(accelerator="cuda", devices=8, strategy="ddp")
+ fabric.launch()

- device = "cuda" if torch.cuda.is_available() else "cpu"
  model = PyTorchModel(...)
  optimizer = torch.optim.SGD(model.parameters())
+ model, optimizer = fabric.setup(model, optimizer)
  dataloader = DataLoader(PyTorchDataset(...), ...)
+ dataloader = fabric.setup_dataloaders(dataloader)
  model.train()

  for epoch in range(num_epochs):
      for batch in dataloader:
          input, target = batch
-         input, target = input.to(device), target.to(device)
          optimizer.zero_grad()
          output = model(input)
          loss = loss_fn(output, target)
-         loss.backward()
+         fabric.backward(loss)
          optimizer.step()
          lr_scheduler.step()

Backward Incompatible Changes

This section outlines notable changes that are not backward compatible with previous versions. The full list of changes and removals can be found in the Full Changelog below.

Since 2.0 is a major release, we took the opportunity to take our APIs to the next level and make considerable changes now, so that fewer backward-incompatible changes are needed in the future. To ease the transition, we commit to continue supporting the 1.9.x line with bug-fix releases for any important fixes.

PyTorch

The *_epoch_end hooks were removed (#16520)

Since the very beginning of Lightning, the LightningModule offered the convenient *_epoch_end hooks in which users could reduce metrics collected across the entire epoch to log them. For this to work, Lightning had to store all the outputs returned from the *_step methods internally to be able to send them to the *_epoch_end hook.

This “silent” accumulation of memory led to many users scratching their heads when, after training for 20 hours, the epoch crashed with a seemingly random out-of-memory error. Avoiding this behaviour also required code changes, because merely overriding the hook triggered the accumulation, regardless of whether you used the outputs. This was exacerbated by users not knowing the difference between the on_train_epoch_end (does not store outputs) and training_epoch_end (does store outputs) hooks.

Based on this feedback, we decided to remove this mechanism completely. Lightning 2.0 favors simplicity and speed over convenience features.

Before:

import torch
import lightning as L

class LitModel(L.LightningModule):
    
    def training_step(self, batch, batch_idx):
        ...
        return {"loss": loss, "banana": banana}
    
    # `outputs` is a list of all dicts returned by `training_step` in the epoch
    def training_epoch_end(self, outputs):
        avg_banana = torch.cat([out["banana"] for out in outputs]).mean()

Now:

import torch
import lightning as L

class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        # 1. Create a list to hold the outputs of `*_step`
        self.bananas = []
    
    def training_step(self, batch, batch_idx):
        ...
        # 2. Add the outputs to the list
        # You should be aware of the implications on memory usage
        self.bananas.append(banana)
        return loss
    
    # 3. Rename the hook to `on_*_epoch_end`
    def on_train_epoch_end(self):
        # 4. Do something with all outputs
        avg_banana = torch.cat(self.bananas).mean()
        # Don't forget to clear the memory for the next epoch!
        self.bananas.clear()

The new way requires users to manage their own list of outputs, and with it the responsibility of managing memory correctly. If an OOM happens, the user can identify the issue in their own code rather than having to guess what the Trainer does.

The prediction epoch-end hook is a special case of this change: the outputs argument has been removed from on_predict_epoch_end(trainer, pl_module), but the outputs can still be accessed via the attribute trainer.predict_loop.outputs.

You can find a migration guide for this change in this PR’s description.
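To see why the clear() call matters, the same accumulate-then-clear pattern can be simulated without Lightning. In this sketch, BananaTracker and the driver loop are hypothetical; they only mimic the order in which the Trainer calls training_step and then on_train_epoch_end, using plain floats instead of tensors.

```python
from statistics import mean
from typing import List


class BananaTracker:
    """Mimics the 2.0 pattern: collect in the step, reduce and clear at epoch end."""

    def __init__(self) -> None:
        self.bananas: List[float] = []
        self.epoch_means: List[float] = []

    def training_step(self, banana: float) -> None:
        # Every stored output stays in memory until it is cleared
        self.bananas.append(banana)

    def on_train_epoch_end(self) -> None:
        self.epoch_means.append(mean(self.bananas))
        # Without this, outputs from all epochs would pile up
        self.bananas.clear()


# Hypothetical driver standing in for the Trainer's loop over two epochs
tracker = BananaTracker()
for epoch_outputs in ([1.0, 2.0, 3.0], [10.0, 20.0]):
    for banana in epoch_outputs:
        tracker.training_step(banana)
    tracker.on_train_epoch_end()

print(tracker.epoch_means)
```

Removing the clear() call would make the second epoch's mean include the first epoch's outputs, and the list would keep growing for the whole run.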

Working with multiple optimizers (#16539)

Lightning 2.0 removed the special optimizer loop that managed multiple optimizers in automatic optimization mode, in favor of training loop code that is easier to understand and debug.
Training with multiple optimizers is now restricted to the “manual optimization mode”:

Before:

import lightning as L

class LitModel(L.LightningModule):
    
    def configure_optimizers(self):
        ...
        return optimizer1, optimizer2

    def training_step(self, batch, batch_idx, optimizer_idx):
        if optimizer_idx == 0:
            ...
            return loss0
        if optimizer_idx == 1:
            ...
            return loss1

Now:

import lightning as L

class LitModel(L.LightningModule):
    
    def __init__(self):
        super().__init__()
        # 1. Switch to manual optimization mode
        self.automatic_optimization = False
    
    def configure_optimizers(self):
        ...
        # 2. Return multiple optimizers, same as before
        return optimizer1, optimizer2

    # 3. Remove the `optimizer_idx` argument from `training_step`
    def training_step(self, batch, batch_idx):
        
        # 4. Grab all optimizers you want to work with
        optimizer1, optimizer2 = self.optimizers()
        ...
        
        # 5. Perform backward manually, step optimizers, etc.
        self.manual_backward(loss0)
        optimizer1.step()
        optimizer1.zero_grad()
        
        # 6. In manual optimization, you don't need to return anything
        return None

You may also find the utility methods self.toggle_optimizer() and self.untoggle_optimizer() useful if you need to restrict parameters that require gradients to a specific optimizer. For a complete example, see our simple GAN implementation.

Truncated backpropagation through time (TBPTT) (#16172)

Similar to the multi-optimizer loop mentioned above, truncated backpropagation through time (TBPTT) was a loop that added a lot of complexity to the Trainer. Over time, TBPTT has fallen out of fashion and today the demand from users is so low that we decided to drop special support in the framework in favor of simplifying the Trainer. TBPTT can still be done in manual optimization.

Before:

import lightning as L

class LitModel(L.LightningModule):

    def __init__(self):
        super().__init__()
        self.truncated_bptt_steps = 10
        self.my_rnn = ...
        
    def training_step(self, batch, batch_idx, hiddens):
        ...
        loss, hiddens = self.my_rnn(..., hiddens)
        ...
        return loss, hiddens

Now:

import lightning as L

class LitModel(L.LightningModule):

    def __init__(self):
        super().__init__()
        
        # 1. Switch to manual optimization
        self.automatic_optimization = False
        
        self.truncated_bptt_steps = 10
        self.my_rnn = ...
        
    # 2. Remove the `hiddens` argument
    def training_step(self, batch, batch_idx):
        optimizer = self.optimizers()
        
        # 3. Split the batch in chunks along the time dimension
        #    (`split_batch` is a helper you write yourself)
        split_batches = split_batch(batch, self.truncated_bptt_steps)
        
        # 4. Choose the initial hidden state
        hiddens = ...
        for chunk in split_batches:
            # 5. Perform the optimization in a loop
            loss, hiddens = self.my_rnn(chunk, hiddens)
            self.manual_backward(loss)
            optimizer.step()
            optimizer.zero_grad()
            
            # 6. "Truncate"
            hiddens = hiddens.detach()
        
        # 7. Remove the return of `hiddens`
        # Returning loss in manual optimization is not needed
        return None
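The split_batch helper in this example is left for you to write. A minimal pure-Python sketch of one, assuming the batch is batch-first and given as a list of equal-length sequences (the name and signature are hypothetical). For a tensor of shape (batch, time, ...), batch.split(split_size, dim=1) produces the equivalent chunks.

```python
from typing import List, Sequence


def split_batch(batch: Sequence[Sequence], split_size: int) -> List[List[Sequence]]:
    """Split a batch-first batch of sequences into chunks along the time dimension.

    Each chunk holds, for every sample, the next `split_size` time steps;
    the last chunk is shorter if the sequence length is not divisible.
    """
    seq_len = len(batch[0])
    return [
        [sample[start:start + split_size] for sample in batch]
        for start in range(0, seq_len, split_size)
    ]


batch = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]  # 2 samples, 5 time steps each
for chunk in split_batch(batch, split_size=2):
    print(chunk)
```

Every chunk keeps all samples of the batch, so the hidden state carried from one chunk to the next stays aligned with the same samples.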

Working with multiple dataloaders (#16800, #16753)

To simplify the Trainer interface and with the goal of simpler iterable support inside the Trainer, we removed the Trainer(multiple_trainloader_mode=...) argument. The mode is no longer tied to a trainer stage (previously it applied only to training), and it is easier to debug and understand because the logic is all encapsulated in the CombinedLoader.

Before:

import lightning as L
from torch.utils.data import DataLoader

class LitModel(L.LightningModule):
    def train_dataloader(self):
        return [DataLoader(...), DataLoader(...)]

model = LitModel()
trainer = L.Trainer(multiple_trainloader_mode="min_size")
trainer.fit(model)

Now:

import lightning as L
from torch.utils.data import DataLoader
from lightning.pytorch.utilities import CombinedLoader

class LitModel(L.LightningModule):
    def train_dataloader(self):
        iterables = [DataLoader(...), DataLoader(...)]
        return CombinedLoader(iterables, mode="min_size")

model = LitModel()
trainer = L.Trainer()
trainer.fit(model)

Related to this, we cleaned up which hooks need the dataloader_idx as an input argument. Now it’s only required if you use multiple dataloaders. Don’t worry, the Trainer will automatically check if it’s required for you and tell you about it.

Accessing dataloaders (#16726, #16800)

If you need access to the DataLoader or Dataset objects, the iterables for each stage can be accessed via the trainer properties Trainer.train_dataloader, Trainer.val_dataloaders, Trainer.test_dataloaders, and Trainer.predict_dataloaders.

These properties match exactly what was returned in your *_dataloader hooks or passed to the Trainer, meaning that if you returned a dictionary of dataloaders, these will return a dictionary of dataloaders. This was not the case before 2.0:

Before:

# Passing 1 dataloader per stage
train_dataset = trainer.train_dataloader.loaders.dataset
val_dataset = trainer.val_dataloaders[0].dataset

# Passing 2 dataloaders per stage
train_dataset = trainer.train_dataloader.loaders[0].dataset
val_dataset = trainer.val_dataloaders[0].dataset

Now:

# Passing 1 dataloader per stage
train_dataset = trainer.train_dataloader.dataset
val_dataset = trainer.val_dataloaders.dataset

# Passing 2 dataloaders per stage
train_dataset = trainer.train_dataloader[0].dataset
val_dataset = trainer.val_dataloaders[0].dataset

The Tuner and Trainer broke up (#16462)

The Tuner and Trainer are no longer together. The two Trainer arguments Trainer(auto_lr_find=..., auto_scale_batch_size=...) and the Trainer.tune() method were removed to make the Trainer leaner and easier to work with.

Before:

import lightning as L

# Tune learning rate
trainer = L.Trainer(auto_lr_find=True)
trainer.tune(model)

# Tune batch size
trainer = L.Trainer(auto_scale_batch_size=True)
trainer.tune(model)

# Fit using tuned settings
trainer.fit(model)

Now:

import lightning as L

# 1. Create the Trainer
trainer = L.Trainer()

# 2. Create the Tuner
tuner = L.pytorch.tuner.Tuner(trainer)

# 3. Tune learning rate
tuner.lr_find(...)

# 4. Tune batch size
tuner.scale_batch_size(...)

# Fit using tuned settings
trainer.fit(model)

You can find more documentation about the tuner here.

Standardized device selection and automation

In Lightning 1.6.0, we simplified the Trainer’s signature by collapsing the four accelerator-specific device arguments into a single one called Trainer(devices=...). In 2.0, we are now dropping support for the old Trainer(gpus=..., tpu_cores=..., ipus=..., num_processes=...).

Before:

import lightning as L

# Multiple CPU processes
trainer = L.Trainer(num_processes=2)

# GPU selection
trainer = L.Trainer(gpus=4)

# TPU core selection
trainer = L.Trainer(tpu_cores=[1])

# Graphcore IPU devices
trainer = L.Trainer(ipus=1)

Now:

import lightning as L

# Select devices and accelerator separately
trainer = L.Trainer(accelerator="cpu", devices=2)
trainer = L.Trainer(accelerator="gpu", devices=4)
trainer = L.Trainer(accelerator="tpu", devices=[1])
trainer = L.Trainer(accelerator="ipu", devices=1)

# Or let Lightning detect the accelerator automatically
trainer = L.Trainer(accelerator="auto", devices=2)

# `accelerator="auto"` is the default
trainer = L.Trainer(devices=2)
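If you have many scripts or configs to migrate, the mapping can be written down as a small helper. This is a hypothetical sketch (migrate_device_args is not part of Lightning) that turns the removed arguments into accelerator/devices keyword arguments, following the pairs in the before/after examples:

```python
from typing import Any, Dict

# Hypothetical lookup: removed 1.x argument -> accelerator name used in 2.0
_OLD_DEVICE_ARGS = {
    "num_processes": "cpu",
    "gpus": "gpu",
    "tpu_cores": "tpu",
    "ipus": "ipu",
}


def migrate_device_args(**old_kwargs: Any) -> Dict[str, Any]:
    """Translate removed device arguments into accelerator=/devices= kwargs."""
    new_kwargs = dict(old_kwargs)
    found = [name for name in _OLD_DEVICE_ARGS if name in new_kwargs]
    if len(found) > 1:
        raise ValueError(f"Pass only one of {found}")
    if found:
        name = found[0]
        new_kwargs["accelerator"] = _OLD_DEVICE_ARGS[name]
        # The device count or index list moves over unchanged
        new_kwargs["devices"] = new_kwargs.pop(name)
    return new_kwargs


print(migrate_device_args(gpus=4))
print(migrate_device_args(tpu_cores=[1], max_epochs=3))
```

Unrelated arguments pass through untouched, so the returned dictionary could then be handed on as L.Trainer(**kwargs).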

In addition, the Trainer(auto_select_gpus=...) argument was removed (#16184). This was a problematic feature that was not well documented, often misunderstood, and led to DDP stalls due to race conditions in the device selection. We recommend using devices="auto" instead or, if the exact previous behavior is desired, the find_usable_cuda_devices utility function.

Before:

import lightning as L

trainer = L.Trainer(auto_select_gpus=True, devices=2)

Now:

import lightning as L

# Recommended
trainer = L.Trainer(devices="auto")

# Alternatively, use the utility function (with care!)
from lightning.pytorch.accelerators import find_usable_cuda_devices

trainer = L.Trainer(devices=find_usable_cuda_devices(2), strategy="ddp_spawn")

The slow and clunky data-parallel strategy (#16748)

PyTorch and Lightning have discouraged the use of the old-style DataParallel (DP) training for more than two years now. This method of multi-GPU training is slow and has many limitations that impact users, even more so in Lightning. Since DP has fallen out of fashion, and other strategies (DDP, DeepSpeed, etc.) have emerged without the same limitations, Lightning is now dropping DP completely.

Before:

import lightning as L

# Enables DP, but with many limitations
trainer = L.Trainer(strategy="dp", devices=8)

Now:

import lightning as L

# DDP is recommended for multi-GPU training
trainer = L.Trainer(strategy="ddp", devices=8)

# If devices > 1, it selects ddp for you
trainer = L.Trainer(devices=8)

# This is the same
trainer = L.Trainer(strategy="auto", devices=8)

Note that the DDP strategy now gets selected automatically when multiple devices are selected.
In Jupyter notebooks (Google Colab, Kaggle, etc.), Lightning will automatically select a fork-based DDP strategy (strategy="ddp_notebook").

By extension, the LightningModule no longer has the training_step_end(), validation_step_end(), and test_step_end() hooks, because they were only used for reducing the outputs in DP. If you have other code that needs to run at the end of a step, you can migrate it, for example, to the corresponding on_*_batch_end hook.

Loop customization has a new face

Loop customization was our attempt to make the Trainer more customizable, hackable, and extensible. The modularity it brought had many pros, but it turned out that many users didn’t need customization and favored readability and simplicity instead. Hence, the concept of loop customization was completely removed from 2.0.

For users who like to build custom training loops, there is now a new paradigm with Lightning Fabric and “Build your own Trainer” (BYOT). Check out the Fabric documentation and the super hackable, 500-line trainer template.

Mixed precision overhaul (#16783)

Based on feedback, we decided to make the names for precision backends in Trainer(precision=...) clearer and less ambiguous. For example, the previous notation Trainer(precision=16) (which is still allowed) suggested to some users that all of the weights and tensors would be stored in a float16 format, which is not true. To solve this misunderstanding, we now distinguish these modes with “true” and “mixed” suffixes in the names:

Recommended value   Short form
-----------------   ----------
"64-true"           "64", 64
"32-true"           "32", 32
"16-mixed"          "16", 16
"bf16-mixed"        "bf16"

All documentation and examples are now recommending the new, less ambiguous names.
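Lightning accepts both forms, so normalizing is optional, but if you want configs or logs to use the recommended names consistently, a lookup based on the table is enough. recommended_precision is a hypothetical helper, not a Lightning API:

```python
from typing import Dict, Union

# Short forms from the table, mapped to the recommended, unambiguous names
_PRECISION_ALIASES: Dict[Union[str, int], str] = {
    "64": "64-true", 64: "64-true",
    "32": "32-true", 32: "32-true",
    "16": "16-mixed", 16: "16-mixed",
    "bf16": "bf16-mixed",
}


def recommended_precision(value: Union[str, int]) -> str:
    """Return the recommended precision name for a short or recommended value."""
    if value in _PRECISION_ALIASES:
        return _PRECISION_ALIASES[value]
    if value in _PRECISION_ALIASES.values():
        # Already a recommended name
        return str(value)
    raise ValueError(f"Unknown precision value: {value!r}")


print(recommended_precision(16))
print(recommended_precision("bf16"))
```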

Apex mixed precision gets replaced with AMP (#16149)

In addition to reworking the precision settings, we removed the NVIDIA/Apex integration which got deprecated in 1.9.0. Apex itself has deprecated the mixed precision module and recommends the native torch.amp module in PyTorch.
For Lightning, this means you should switch to Trainer(precision="16-mixed").

Before:

import lightning as L

# This required installing nvidia/apex
trainer = L.Trainer(amp_backend="apex", amp_level="O2")

Now:

import lightning as L

# Rely on PyTorch's native mixed precision
trainer = L.Trainer(precision="16-mixed")

Native FSDP replaces Fairscale FSDP (#16400)

With the recent announcement that FSDP is production-ready in PyTorch 2.0, we are dropping support for the experimental Fairscale version of FSDP and going all in on the native implementation instead.

Before:

import lightning as L

# Shorthand names (fairscale)
trainer = L.Trainer(strategy="ddp_sharded")
trainer = L.Trainer(strategy="ddp_fully_sharded")
trainer = L.Trainer(strategy="fsdp")

# Shorthand names (native PyTorch)
trainer = L.Trainer(strategy="fsdp_native")
trainer = L.Trainer(strategy="fsdp_native_full_shard_offload")

# Or using the strategy instance
from lightning.pytorch.strategies import DDPShardedStrategy, DDPFullyShardedStrategy

trainer = L.Trainer(strategy=DDPShardedStrategy(...))
trainer = L.Trainer(strategy=DDPFullyShardedStrategy(...))

Now:

import lightning as L

# Shorthand names (native PyTorch)
trainer = L.Trainer(strategy="fsdp")
trainer = L.Trainer(strategy="fsdp_cpu_offload")

# Or using the strategy instance
from lightning.pytorch.strategies import FSDPStrategy

trainer = L.Trainer(strategy=FSDPStrategy(...))

Resuming from checkpoints (#16167)

Resuming from a checkpoint is no longer done by specifying the filename in the Trainer constructor. The Trainer(resume_from_checkpoint=...) argument was renamed and moved to the individual Trainer methods.

Before:

import lightning as L

trainer = L.Trainer(resume_from_checkpoint="/path/to/checkpoint.ckpt")
trainer.fit(model)
trainer.test(model)
...

Now:

import lightning as L

# 1. Remove `resume_from_checkpoint` from the Trainer
trainer = L.Trainer()

# 2. Add the path to any of the trainer methods you want to run
trainer.fit(model, ckpt_path="/path/to/checkpoint.ckpt")
trainer.test(model, ckpt_path="/path/to/checkpoint.ckpt")
...

We also added support for setting the checkpoint path statefully:

Now:

import lightning as L

trainer = L.Trainer()

# set the checkpoint path with a setter
trainer.ckpt_path = "/path/to/checkpoint.ckpt"
trainer.fit(model)
# remember to clear it before continuing
trainer.ckpt_path = None
trainer.test()
...

Logging the loss to the progress bar (#16192)

In previous versions of Lightning, the Trainer used to automatically compute a running mean of the training loss and log it to the progress bar. We are removing this feature to favor speed over automation. The loss will no longer appear in the progress bar unless the user explicitly adds it.

Before:

Epoch 8:  53%|█████    | 17/32 [5.13/s, v_num=2, loss=0.5643]

Now:

Epoch 8:  53%|█████    | 17/32 [5.13/s, v_num=2]

def training_step(self, batch, batch_idx):
    ...

    # Add this if you need the loss to be displayed in the progress bar
    self.log("loss", loss, prog_bar=True, on_step=True)
    return loss

The brittle argument parsing utilities (#16708)

In previous versions of Lightning, the Trainer.add_argparse_args and Trainer.from_argparse_args utility functions helped the user construct a command-line interface using Python’s argparse library. Over time, new functionality was added to the Trainer that allowed most of the frequently used arguments to accept many different types. This made the argument-parsing utilities very brittle and limited in their scope. Ultimately, we decided to remove them in 2.0 in favor of more robust solutions like LightningCLI or third-party command-line tools (Click, Hydra, etc.).

Before:

import lightning as L
import argparse

parser = argparse.ArgumentParser()
parser = L.Trainer.add_argparse_args(parser)
args = parser.parse_args()
trainer = L.Trainer.from_argparse_args(args)
...

Now:

Example using the LightningCLI:

from lightning.pytorch.cli import LightningCLI

LightningCLI(MyLitModel, MyLitDataModule)

Run it with python train.py fit --trainer.max_epochs=2 for example.
Alternatively, you can add the argparse arguments you want manually:

import lightning as L
import argparse

parser = argparse.ArgumentParser()

# 1. Add the arguments you need
parser.add_argument("--accelerator", type=str, default="cuda")
args = parser.parse_args()

# 2. Pass them into the Trainer
trainer = L.Trainer(**vars(args))
...
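A slightly fuller version of the manual approach is sketched below. The argument names mirror Trainer arguments, while the defaults and the int_or_auto converter are assumptions for illustration; it parses an explicit list so it runs anywhere, and the final Trainer call is left as a comment so the sketch does not require Lightning.

```python
import argparse
from typing import Union


def int_or_auto(value: str) -> Union[int, str]:
    """Accept an integer device count or the string "auto"."""
    return value if value == "auto" else int(value)


parser = argparse.ArgumentParser()
# 1. Add only the Trainer arguments you need (defaults here are illustrative)
parser.add_argument("--accelerator", type=str, default="auto")
parser.add_argument("--devices", type=int_or_auto, default="auto")
parser.add_argument("--max_epochs", type=int, default=None)

# An explicit list keeps the sketch runnable; call parse_args() in a real script
args = parser.parse_args(["--accelerator", "cuda", "--devices", "2"])
trainer_kwargs = vars(args)
print(trainer_kwargs)

# 2. In a real script: trainer = L.Trainer(**trainer_kwargs)
```

Because you own the parser, mixed-type arguments like devices get exactly the conversion you choose, which is the part the removed utilities struggled with.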

Gradient norm tracking (#16745)

Until now, automatic tracking of the gradient norm was possible through the Trainer(track_grad_norm=...) argument. This functionality has now moved to a utility function and can be easily added in two lines of code to the LightningModule.

Before:

import lightning as L

trainer = L.Trainer(track_grad_norm=2)

# Optionally customize logging in LightningModule
def log_grad_norm(self, grad_norm_dict):
    self.log(...)

Now:

from lightning.pytorch.utilities import grad_norm

# 1. Override LightningModule hook
def on_before_optimizer_step(self, optimizer):
    # 2. Inspect the (unscaled) gradients here
    self.log_dict(grad_norm(self, norm_type=2))

This enables users to customize how the gradient norm is computed and logged, without needing to wrangle with the Trainer or override the log_grad_norm hook.
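For reference, the quantity logged here is the p-norm of the gradients. The plain-Python sketch below shows the arithmetic for a per-parameter norm and for the total norm over all parameters, which equals the p-norm of the per-parameter norms. grad_norms and the "total" key are hypothetical; the real utility works on the module's parameter tensors and chooses its own key names.

```python
from typing import Dict, List


def grad_norms(grads: Dict[str, List[float]], norm_type: float = 2.0) -> Dict[str, float]:
    """Per-parameter p-norms plus the total p-norm over all parameters."""
    norms = {
        name: sum(abs(g) ** norm_type for g in values) ** (1.0 / norm_type)
        for name, values in grads.items()
    }
    # The total norm is the p-norm of the per-parameter norms
    norms["total"] = sum(n ** norm_type for n in norms.values()) ** (1.0 / norm_type)
    return norms


print(grad_norms({"weight": [3.0, 4.0], "bias": [0.0]}))
```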

Speeding up DDP with find_unused_parameters (#16611)

When your model doesn’t use all parameters in the forward-backward pass, PyTorch’s DDP wrapper will freak out and inform you about it with an error. For this reason, and for the longest time, Lightning has set find_unused_parameters=True so that this error can be avoided. However, depending on the model, this can have significant performance impact (slowing down your training). With Lightning 2.0, we switch this flag back to find_unused_parameters=False (default) and favor speed over convenience.

Before:

import lightning as L

# Previously, you had to manually override this for speed
trainer = L.Trainer(strategy="ddp_find_unused_parameters_false")

Now:

import lightning as L

# Now, you get the best speed by default for all ddp variants
trainer = L.Trainer(strategy="ddp")
trainer = L.Trainer(strategy="ddp_spawn")

# We now have these if you need them
trainer = L.Trainer(strategy="ddp_find_unused_parameters_true")

However, users can still run into problems with unused parameters in their model. Lightning now overrides PyTorch’s default error message with a custom one to help users resolve the problem.

Sampler replacement in distributed strategies (#16829)

We renamed the Trainer(replace_sampler_ddp=...) argument to Trainer(use_distributed_sampler=...) to communicate that the sampler gets created not only for the DDP strategies, but all distributed strategies that need it. Its function is still the same as before, and most users don’t need to change its default value.

Before:

import lightning as L

# Hmm, I wonder, does this only apply to DDP?
trainer = L.Trainer(replace_sampler_ddp=True)

Now:

import lightning as L

# Aha! Lightning uses a distributed sampler by default, got it!
trainer = L.Trainer(use_distributed_sampler=True)
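Conceptually, a distributed sampler gives every process its own equally sized share of the dataset. The sketch below models that slicing in pure Python, after the behavior of torch.utils.data.DistributedSampler with shuffle=False and drop_last=False; shuffling and epoch seeding are left out, and shard_indices is a made-up name.

```python
import math
from typing import List


def shard_indices(dataset_len: int, num_replicas: int, rank: int) -> List[int]:
    """Indices one rank would see, modeled on DistributedSampler(shuffle=False)."""
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    # Pad by repeating indices from the start so every rank gets num_samples items
    while len(indices) < total_size:
        indices += indices[: total_size - len(indices)]
    # Each rank takes every num_replicas-th index, starting at its own rank
    return indices[rank:total_size:num_replicas]


for rank in range(4):
    print(rank, shard_indices(dataset_len=10, num_replicas=4, rank=rank))
```

With 10 samples and 4 processes, the last two ranks receive indices 0 and 1 a second time: the dataset is padded so that every rank processes the same number of samples.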

PyTorch 1.10 and Python 3.7 no longer supported (#16492, #16579)

Whenever there is a new PyTorch or Python version, it is time to say goodbye to the oldest one we support. With the introduction of PyTorch 2.0, we are dropping support for PyTorch 1.10 to keep our support window at the four latest versions: 1.11, 1.12, 1.13 and 2.0. Similarly, with Lightning 2.0 we support the latest three versions of Python: 3.8, 3.9, and 3.10 (3.11 is coming soon).

If you are interested which range of PyTorch or Python versions a particular Lightning version supports, see our compatibility matrix.

Removed experimental fault-tolerance support (#16516, #16533)

To simplify reading and debugging the codebase, we removed the experimental support for fault-tolerance which was under the PL_FAULT_TOLERANT_TRAINING= environment flag. We are looking at ways to re-implement this. If you are interested in this feature, don’t hesitate to reach out to us or create a feature request.

Some of the features it included have now been ported to stable APIs. One of them is the new callback to save a checkpoint on exception:

Now:

import lightning as L
from lightning.pytorch.callbacks import OnExceptionCheckpoint

on_exception = OnExceptionCheckpoint(".")
trainer = L.Trainer(callbacks=on_exception)

Another feature is automatic SIGTERM handling:

Now:

import lightning as L

trainer = L.Trainer()

if trainer.received_sigterm:
    ...

Removed support for self.log()ing a dictionary (#16389)

Our logging mechanism previously supported log("key", {"something": 123}) (not using log_dict). However, this added significant complexity to the implementation with little benefit, as these keys could not be monitored by our Callbacks and most logger implementations do not support this notation. If you were using this feature with a compatible logger, you can still publish data directly to the Logger using self.logger.log_metrics().

Removed trainer.reset_*_dataloader() methods (#16726)

These methods were not intended for public use as they could leave the Trainer in an undefined state. As a result, we have removed them. To achieve the same functionality, you can use the Trainer(reload_dataloaders_every_n_epochs=...) argument.

Removed the Trainer(move_metrics_to_cpu=True) argument (#16358)

This flag was designed to reduce device memory allocation at the end of an epoch. However, due to design issues, it could leave non-CPU runs in a non-functional state. Since the memory savings were minimal compared to other components and users can still manually control their metrics, we decided to remove this flag.

Separate the Gradient Accumulation Scheduler from Trainer (#16729)

We removed support for passing a scheduling dictionary to Trainer(accumulate_grad_batches=...). The same functionality can still be achieved by passing the GradientAccumulationScheduler callback instead. This simplifies the Trainer and its overall validation logic.

Before:

from lightning.pytorch import Trainer

trainer = Trainer(accumulate_grad_batches={1: 5, 10: 3})

Now:

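from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import GradientAccumulationScheduler

# Keys are epochs (ints), values are the number of batches to accumulate
trainer = Trainer(callbacks=GradientAccumulationScheduler(scheduling={1: 5, 10: 3}))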
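The scheduling dictionary maps an epoch (counted from 0) to the accumulation factor used from that epoch onward. The pure-Python sketch below illustrates that lookup. The function name is hypothetical and this is not the callback's code; it assumes, as we understand the callback's behavior, that epochs before the first scheduled key use a factor of 1.

```python
from typing import Dict


def accumulation_factor(scheduling: Dict[int, int], epoch: int) -> int:
    """Hypothetical helper: accumulation factor in effect at `epoch` (0-indexed)."""
    factor = 1  # assumed default before the first scheduled epoch
    for start_epoch in sorted(scheduling):
        if epoch >= start_epoch:
            # A later start epoch overrides an earlier one once it is reached.
            factor = scheduling[start_epoch]
    return factor


schedule = {1: 5, 10: 3}
print([accumulation_factor(schedule, e) for e in (0, 1, 9, 10, 20)])
# [1, 5, 5, 3, 3]
```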

Fabric

LightningLite is dead, long live Lightning Fabric!

Over time, LightningLite has evolved from a simple onboarding tool to a powerful toolbox enabling users to build performant and hackable trainers. In 1.9.0, we renamed it to Lightning Fabric and gave users early access to its new features. In 2.0, we are dropping LightningLite from the package completely.

Before:

# There were two different ways of importing Lite in <= 1.9.0
from pytorch_lightning.lite import LightningLite
from lightning.lite import LightningLite

# You had to subclass LightningLite and implement `run()`
class MyTrainer(LightningLite):

    def run(self):
        ...
        self.backward(loss)
        ...

lite = MyTrainer(...)
lite.run()

Now:

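# 1. Import Fabric directly from the lightning package
import lightning as L

# 2. Instantiate Fabric directly, without subclassing, and launch it
fabric = L.Fabric(...)
fabric.launch()

# 3. Let Fabric set up your model and optimizer
model, optimizer = fabric.setup(model, optimizer)

# 4. Use it in your training loop in place of loss.backward()
fabric.backward(loss)
...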

Learn more about Fabric and what it can do in the new docs!

CHANGELOG

PyTorch

Added
  • Added migration logic to warn about checkpoints with apex AMP state (#16161)
  • Added the Trainer.ckpt_path = ... setter to statefully set the checkpoint path to load. This can act as a replacement for the removed Trainer(resume_from_checkpoint=...) flag (#16187)
  • Added an argument include_cuda in pytorch_lightning.utilities.seed.isolate_rng to disable managing torch.cuda's RNG (#16423)
  • Added Tuner.lr_find(attr_name=...) to specify custom learning rate attribute names (#16462)
  • Added an OnExceptionCheckpoint callback to save a checkpoint on exception (#16512)
  • Added support for running the MLFlowLogger with the mlflow-skinny package (#16513)
  • Added a Trainer.received_sigterm property to check whether a SIGTERM signal was received (#16501)
  • Added support for cascading a SIGTERM signal to launched processes after the launching process (rank 0) receives it (#16525)
  • Added a kill method to launchers to kill all launched processes (#16525)
  • Added a suffix option to DDP strategy names to enable find_unused_parameters=True, for example strategy="ddp_find_unused_parameters_true" (#16611)
  • Added a new method Strategy.on_exception to the strategy base interface (#16646)
  • Added support for predict_step(dataloader_iter, batch_index) (#16726)
  • Added support for arbitrary iterables as dataloaders (#16726)
  • Added “sequential” mode support to CombinedLoader to consume multiple iterables in sequence (#16743, #16784)
  • Added “max_size” mode support to CombinedLoader to consume multiple iterables entirely without cycling (#16939)
  • Added a Trainer(barebones=True) argument where all features that may impact raw speed are disabled (#16854)
  • Added support for writing logs to remote file systems with the CSVLogger (#16880)
  • Added DDPStrategy(start_method=...) argument, defaulting to ‘popen’ (#16809)
  • Added checks for whether the iterables used by the loops are valid (#17007)
Changed
  • The Trainer’s signal handlers are now registered for trainer.{validate,test,predict} (#17017)
  • Renamed ProgressBarBase to ProgressBar (#17058)
  • The Trainer now chooses accelerator="auto", strategy="auto", devices="auto" as defaults (#16847)
  • “Native” suffix removal (#16490)
    • strategy="fsdp_native" is now strategy="fsdp"
    • strategy="fsdp_native_full_shard_offload" is now strategy="fsdp_cpu_offload"
    • pytorch_lightning.strategies.fully_sharded_native.DDPFullyShardedNativeStrategy is now pytorch_lightning.strategies.fsdp.FSDPStrategy
    • pytorch_lightning.plugins.precision.fsdp_native_native_amp.FullyShardedNativeNativeMixedPrecisionPlugin is now pytorch_lightning.plugins.precision.fsdp.FSDPMixedPrecisionPlugin
    • pytorch_lightning.plugins.precision.native_amp is now pytorch_lightning.plugins.precision.amp
    • NativeSyncBatchNorm is now TorchSyncBatchNorm
  • Changed the default of LearningRateFinder(update_attr=...) and Tuner.lr_find(update_attr=...) to True (#16462)
  • Renamed the pl.utilities.exceptions.GracefulExitException to SIGTERMException (#16501)
  • The Callback.on_train_epoch_end hook now runs after the LightningModule.on_train_epoch_end hook for instances of EarlyStopping and Checkpoint callbacks (#16567)
  • The LightningModule.{un}toggle_optimizer methods no longer accept an optimizer_idx argument to select the relevant optimizer. Instead, the optimizer object can be passed in directly (#16560)
  • Manual optimization is now required for working with multiple optimizers (#16539)
  • DDP’s find_unused_parameters now defaults to False (#16611)
  • The strategy selected by accelerator="hpu" now defaults to find_unused_parameters=False (#16611)
  • The main progress bar displayed during training no longer includes the combined progress for validation (#16695)
  • Renamed TQDMProgressBar.main_progress_bar to TQDMProgressBar.train_progress_bar (#16695)
  • Marked the progress tracking classes as protected (#17009)
  • Marked the lightning.pytorch.trainer.configuration_validator.verify_loop_configurations function as protected (#17009)
  • Marked the lightning.pytorch.utilities.distributed.register_ddp_comm_hook function as protected (#17009)
  • Marked lightning.pytorch.utilities.supporters.CombinedDataset as protected (#16714)
  • Marked the {Accelerator,Signal,Callback,Checkpoint,Data,Logger}Connector classes as protected (#17008)
  • Marked the lightning.pytorch.trainer.connectors.signal_connector.HandlersCompose class as protected (#17008)
  • Disabled strict loading in multiprocessing launcher (“ddp_spawn”, etc.) when loading weights back into the main process (#16365)
  • Renamed CombinedLoader.loaders to CombinedLoader.iterables (#16743)
  • Renamed Trainer(replace_sampler_ddp=...) to Trainer(use_distributed_sampler=...) (#16829)
  • Moved the CombinedLoader class from lightning.pytorch.trainer.supporters to lightning.pytorch.combined_loader (#16819)
  • The top-level loops now own the data sources and combined dataloaders (#16726)
  • The trainer.*_dataloader properties now return what the user returned in their LightningModule.*_dataloader() hook (#16726, #16800)
  • The dataloader_idx argument is now optional for the on_{validation,test,predict}_batch_{start,end} hooks. Remove it or default it to 0 if you don’t use multiple dataloaders (#16753)
  • Renamed TPUSpawnStrategy to XLAStrategy (#16781)
  • Renamed strategy='tpu_spawn' to strategy='xla' and strategy='tpu_spawn_debug' to strategy='xla_debug' (#16781)
  • Changed arguments for precision settings (from [64|32|16|bf16] to ["64-true"|"32-true"|"16-mixed"|"bf16-mixed"]) (#16783)
  • When using multiple devices, the strategy now defaults to “ddp” instead of “ddp_spawn” when none is set (#16780)
  • The selection Trainer(strategy="ddp_spawn", ...) no longer falls back to “ddp” when a cluster environment gets detected (#16780)
  • Predict’s custom BatchSampler that tracks the batch indices no longer consumes the entire batch sampler at the beginning (#16826)
  • Gradient norm tracking with track_grad_norm no longer rounds the norms to 4 digits, but instead logs them at full resolution (#16877)
  • Merged the DDPSpawnStrategy into DDPStrategy (#16809)
  • The NeptuneLogger now requires neptune>=1.0.0 (#16888)
  • Changed minimum supported version of rich from 10.14.0 to 12.13.0 (#16798)
  • Removed the lightning.pytorch.overrides.torch_distributed.broadcast_object_list function (#17011)
  • The ServableModule is now an abstract interface (#17000)
  • The psutil package is now required for CPU monitoring (#17010)
  • The Trainer no longer accepts positional arguments (#17022)
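Several of the renames in this list are purely mechanical. If you keep Trainer arguments in configuration files or dictionaries, a helper along the lines of the hypothetical migrate_trainer_kwargs below can rewrite some of them. It is not part of Lightning, covers only the strategy, precision, and sampler renames listed above, assumes the 1.x values were spelled exactly as shown, and leaves every other key untouched.

```python
from typing import Any, Dict

# Renames taken from the changelog entries above.
_STRATEGY_RENAMES = {
    "fsdp_native": "fsdp",
    "fsdp_native_full_shard_offload": "fsdp_cpu_offload",
    "tpu_spawn": "xla",
    "tpu_spawn_debug": "xla_debug",
}
_PRECISION_RENAMES = {64: "64-true", 32: "32-true", 16: "16-mixed", "bf16": "bf16-mixed"}


def migrate_trainer_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: map 1.x Trainer kwargs to their 2.0 spellings."""
    new = dict(kwargs)  # do not mutate the caller's dictionary
    if "replace_sampler_ddp" in new:
        new["use_distributed_sampler"] = new.pop("replace_sampler_ddp")
    if new.get("strategy") in _STRATEGY_RENAMES:
        new["strategy"] = _STRATEGY_RENAMES[new["strategy"]]
    precision = new.get("precision")
    # Config files often store precision as a string such as "16".
    if isinstance(precision, str) and precision.isdigit():
        precision = int(precision)
    if precision in _PRECISION_RENAMES:
        new["precision"] = _PRECISION_RENAMES[precision]
    return new


print(migrate_trainer_kwargs({"strategy": "tpu_spawn", "precision": 16, "replace_sampler_ddp": False}))
# {'strategy': 'xla', 'precision': '16-mixed', 'use_distributed_sampler': False}
```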
Removed
  • Removed support for PyTorch 1.10 (#16492)
  • Removed support for Python 3.7 (#16579)
  • Removed the pytorch_lightning.lite module in favor of lightning_fabric (#15953)
  • nvidia/apex removal (#16149)
    • Removed pytorch_lightning.plugins.NativeMixedPrecisionPlugin in favor of pytorch_lightning.plugins.MixedPrecisionPlugin
    • Removed the LightningModule.optimizer_step(using_native_amp=...) argument
    • Removed the Trainer(amp_backend=...) argument
    • Removed the Trainer.amp_backend property
    • Removed the Trainer(amp_level=...) argument
    • Removed the pytorch_lightning.plugins.ApexMixedPrecisionPlugin class
    • Removed the pytorch_lightning.utilities.enums.AMPType enum
    • Removed the DeepSpeedPrecisionPlugin(amp_type=..., amp_level=...) arguments
  • Removed Trainer(strategy='horovod') support (#16150)
  • FairScale removal (in favor of PyTorch’s FSDP implementation) (#16400)
    • Removed the pytorch_lightning.overrides.fairscale.LightningShardedDataParallel class
    • Removed the pytorch_lightning.plugins.precision.fully_sharded_native_amp.FullyShardedNativeMixedPrecisionPlugin class
    • Removed the pytorch_lightning.plugins.precision.sharded_native_amp.ShardedNativeMixedPrecisionPlugin class
    • Removed the pytorch_lightning.strategies.fully_sharded.DDPFullyShardedStrategy (fsdp) class
    • Removed the pytorch_lightning.strategies.sharded.DDPShardedStrategy (ddp_sharded) class
    • Removed the pytorch_lightning.strategies.sharded_spawn.DDPSpawnShardedStrategy (ddp_sharded_spawn) class
  • Removed legacy device arguments in Trainer (#16171)
    • Removed the Trainer(gpus=...) argument
    • Removed the Trainer(tpu_cores=...) argument
    • Removed the Trainer(ipus=...) argument
    • Removed the Trainer(num_processes=...) argument
  • Removed the deprecated pytorch_lightning.utilities.AllGatherGrad class (#16360)
  • Removed the deprecated resume_from_checkpoint Trainer argument (#16167)
  • Removed the deprecated pytorch_lightning.profiler module (#16359)
  • Removed deadlock detection / process reconciliation (PL_RECONCILE_PROCESS=1) (#16204)
  • Removed the {training,validation,test}_epoch_end hooks, which retained step outputs in memory. Implement the corresponding on_*_epoch_end hooks instead (#16520)
  • Removed the outputs argument from the on_predict_epoch_end hook. You can access them via trainer.predict_loop.predictions (#16655)
  • Removed support for the experimental PL_FAULT_TOLERANT_TRAINING environment flag (#16516, #16533)
  • Removed the deprecated LightningCLI arguments (#16380)
    • save_config_filename
    • save_config_overwrite
    • save_config_multifile
    • description
    • env_prefix
    • env_parse
  • Removed the deprecated pl.strategies.utils.on_colab_kaggle function (#16437)
  • Removed the deprecated code in:
    • pl.core.mixins (#16424)
    • pl.utilities.distributed (#16390)
    • pl.utilities.apply_func (#16413)
    • pl.utilities.xla_device (#16404)
    • pl.utilities.data (#16440)
    • pl.utilities.device_parser (#16412)
    • pl.utilities.optimizer (#16439)
    • pl.utilities.seed (#16422)
    • pl.utilities.cloud_io (#16438)
  • Removed the deprecated Accelerator.setup_environment method (#16436)
  • Marked the forward_module argument as required (#16386)
    • Removed the deprecated pl_module argument from the distributed module wrappers
    • Removed the deprecated pytorch_lightning.overrides.base.unwrap_lightning_module function
    • Removed the pytorch_lightning.overrides.distributed.LightningDistributedModule class
    • Removed the deprecated pytorch_lightning.overrides.fairscale.unwrap_lightning_module_sharded function
    • Removed the pytorch_lightning.overrides.fairscale.LightningDistributedModule class
  • Removed the deprecated automatic GPU selection (#16184)
    • Removed the Trainer(auto_select_gpus=...) argument
    • Removed the pytorch_lightning.tuner.auto_gpu_select.{pick_single_gpu,pick_multiple_gpus} functions
  • Removed support for loop customization
    • Removed Loop.replace() (#16361)
    • Removed Loop.connect() (#16384)
    • Removed the trainer.{fit,validate,test,predict}_loop properties (#16384)
    • Removed the default Loop.run() implementation (#16384)
    • The loop classes are now marked as protected (#16445)
    • The fetching classes are now marked as protected (#16664)
  • The lightning.pytorch.overrides.distributed.IndexBatchSamplerWrapper class is now marked as protected (#16826)
  • Removed the DataLoaderLoop, EvaluationEpochLoop, and PredictionEpochLoop classes (#16726)
  • Removed trainer.reset_*_dataloader() methods in favor of Loop.setup_data() for the top-level loops (#16726)
  • Removed special support for truncated backpropagation through time (TBPTT) (#16172)
    • Removed the LightningModule.truncated_bptt_steps attribute
    • Removed the LightningModule.tbptt_split_batch hook
    • The LightningModule.training_step no longer accepts a hiddens argument
    • Removed the pytorch_lightning.loops.batch.TrainingBatchLoop
    • Removed the FitLoop.split_idx property
    • Removed the LoggerConnector.on_train_split_start method
  • Removed the experimental PL_INTER_BATCH_PARALLELISM environment flag (#16355)
  • Removed the Trainer(move_metrics_to_cpu=True) argument (#16358)
  • Removed the LightningModule.precision attribute (#16203)
  • Removed the automatic addition of a moving average of the training_step loss in the progress bar. Use self.log("loss", ..., prog_bar=True) instead. (#16192)
  • Removed support for passing a dictionary value to self.log() (#16389)
  • Removed Trainer.model setter (#16462)
  • Removed the argument Trainer(multiple_trainloader_mode=...). You can use CombinedLoader(..., mode=...) directly now (#16800)
  • Removed the unused lightning.pytorch.utilities.finite_checks.print_nan_gradients function (#16682)
  • Removed the unused lightning.pytorch.utilities.finite_checks.detect_nan_parameters function (#16682)
  • Removed the unused lightning.pytorch.utilities.parsing.flatten_dict function (#16744)
  • Removed the unused lightning.pytorch.utilities.metrics.metrics_to_scalars function (#16681)
  • Removed the unused lightning.pytorch.utilities.supporters.{SharedCycleIteratorState,CombinedLoaderIterator} classes (#16714)
  • Tuner removal
    • Removed the deprecated trainer.tuning property (#16379)
    • Removed the deprecated TrainerFn.TUNING and RunningStage.TUNING enums (#16379)
    • Removed Trainer.tune() in favor of Tuner(trainer).{lr_find,scale_batch_size} (#16462)
    • Removed Trainer(auto_scale_batch_size=...) in favor of Tuner(trainer).scale_batch_size() (#16462)
    • Removed Trainer(auto_lr_find=...) in favor of Tuner(trainer).lr_find() (#16462)
  • Removed the on_tpu argument from LightningModule.optimizer_step hook (#16537)
  • Removed the using_lbfgs argument from LightningModule.optimizer_step hook (#16538)
  • Removed the Trainer.data_parallel property. Use isinstance(trainer.strategy, ParallelStrategy) instead (#16703)
  • Removed the Trainer.prediction_writer_callbacks property (#16759)
  • Removed support for multiple optimizers in automatic optimization mode (#16539)
    • Removed opt_idx argument from BaseFinetuning.finetune_function callback method
    • Removed opt_idx argument from Callback.on_before_optimizer_step callback method
    • Removed optimizer_idx as an optional argument in LightningModule.training_step
    • Removed optimizer_idx argument from LightningModule.on_before_optimizer_step
    • Removed optimizer_idx argument from LightningModule.configure_gradient_clipping
    • Removed optimizer_idx argument from LightningModule.optimizer_step
    • Removed optimizer_idx argument from LightningModule.optimizer_zero_grad
    • Removed optimizer_idx argument from LightningModule.lr_scheduler_step
    • Removed support for declaring optimizer frequencies in the dictionary returned from LightningModule.configure_optimizers
    • Removed arguments optimizer and optimizer_idx from LightningModule.backward
    • Removed optimizer_idx argument from PrecisionPlugin.optimizer_step and all of its overrides in subclasses
    • Removed optimizer_idx argument from PrecisionPlugin.{optimizer_step,backward} and all of its overrides in subclasses
    • Removed optimizer_idx argument from Strategy.{optimizer_step,backward} and all of its overrides in subclasses
    • Removed Trainer.optimizer_frequencies attribute
  • Removed Strategy.dispatch (#16618)
  • Removed PrecisionPlugin.dispatch (#16618)
  • Removed legacy argparse utilities (#16708)
    • Removed LightningDataModule methods: add_argparse_args(), from_argparse_args(), parse_argparser(), get_init_arguments_and_types()
    • Removed class methods from Trainer: default_attributes(), from_argparse_args(), parse_argparser(), match_env_arguments(), add_argparse_args()
    • Removed functions from lightning.pytorch.utilities.argparse: from_argparse_args(), parse_argparser(), parse_env_variables(), get_init_arguments_and_types(), add_argparse_args()
    • Removed functions from lightning.pytorch.utilities.parsing: str_to_bool(), str_to_bool_or_int(), str_to_bool_or_str()
  • Removed support for passing a scheduling dictionary to Trainer(accumulate_grad_batches=...) (#16729)
  • Removed support for DataParallel (strategy='dp') and the LightningParallelModule wrapper (#16748)
  • Removed ProgressBarBase.{train_batch_idx,val_batch_idx,test_batch_idx,predict_batch_idx} properties (#16760)
  • Removed the fit_loop.{min,max}_steps setters (#16803)
  • Removed the Trainer(track_grad_norm=...) argument (#16745)
  • Removed the LightningModule.log_grad_norm() hook method (#16745)
  • Removed the QuantizationAwareTraining callback (#16750)
  • Removed the ColossalAIStrategy and ColossalAIPrecisionPlugin in favor of the new lightning-colossalai package (#16757, #16778)
  • Removed the training_step_end, validation_step_end, and test_step_end hooks from the LightningModule in favor of the *_batch_end hooks (#16791)
  • Removed the lightning.pytorch.strategies.DDPSpawnStrategy in favor of DDPStrategy(start_method='spawn') (merged both classes) (#16809)
  • Removed registration of ShardedTensor state dict hooks in LightningModule.__init__ with torch>=2.1 (#16892)
  • Removed the lightning.pytorch.core.saving.ModelIO class interface (#16999)
  • Removed the unused lightning.pytorch.utilities.memory.get_model_size_mb function (#17001)
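Scripts that used the removed str_to_bool() parsing helper as an argparse type= converter can define a small replacement locally. The function below is a hypothetical sketch using only the standard library; the accepted spellings are our own choice and may not exactly match the removed helper.

```python
import argparse


def str_to_bool(value: str) -> bool:
    """Hypothetical replacement: parse common true/false spellings."""
    normalized = value.strip().lower()
    if normalized in {"y", "yes", "t", "true", "on", "1"}:
        return True
    if normalized in {"n", "no", "f", "false", "off", "0"}:
        return False
    # argparse reports the message of an ArgumentTypeError as a usage error.
    raise argparse.ArgumentTypeError(f"invalid boolean value: {value!r}")


parser = argparse.ArgumentParser()
parser.add_argument("--use_distributed_sampler", type=str_to_bool, default=True)
args = parser.parse_args(["--use_distributed_sampler", "false"])
print(args.use_distributed_sampler)  # False
```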
Fixed
  • Fixed an issue where DistributedSampler.set_epoch wasn’t getting called during trainer.predict (#16785, #16826)
  • Fixed an issue with comparing torch versions when using a version of torch built from source (#17030)
  • Improved the error message for installing tensorboard or tensorboardx (#17053)

Fabric

Added
  • Added Fabric.all_reduce (#16459)
  • Added support for saving and loading DeepSpeed checkpoints through Fabric.save/load() (#16452)
  • Added support for automatically calling set_epoch on the dataloader.batch_sampler.sampler (#16841)
  • Added support for writing logs to remote file systems with the CSVLogger (#16880)
  • Added support for frozen dataclasses in the optimizer state (#16656)
  • Added lightning.fabric.is_wrapped to check whether a module, optimizer, or dataloader was already wrapped by Fabric (#16953)
Changed
  • Fabric now chooses accelerator="auto", strategy="auto", devices="auto" as defaults (#16842)
  • Checkpoint saving and loading redesign (#16434)
    • Changed the method signature of Fabric.save and Fabric.load
    • Changed the method signature of Strategy.save_checkpoint and Strategy.load_checkpoint
    • Fabric.save accepts a state that can contain model and optimizer references
    • Fabric.load can now load state in-place onto models and optimizers
    • Fabric.load returns a dictionary of objects that weren’t loaded into the state
    • Strategy.save_checkpoint and Strategy.load_checkpoint are now responsible for accessing the state of the model and optimizers
  • DataParallelStrategy.get_module_state_dict() and DDPStrategy.get_module_state_dict() now correctly extract the state dict without keys prefixed with 'module' (#16487)
  • “Native” suffix removal (#16490)
    • strategy="fsdp_full_shard_offload" is now strategy="fsdp_cpu_offload"
    • lightning.fabric.plugins.precision.native_amp is now lightning.fabric.plugins.precision.amp
  • Enabled all shorthand strategy names that can be supported in the CLI (#16485)
  • Renamed strategy='tpu_spawn' to strategy='xla' and strategy='tpu_spawn_debug' to strategy='xla_debug' (#16781)
  • Changed arguments for precision settings (from [64|32|16|bf16] to ["64-true"|"32-true"|"16-mixed"|"bf16-mixed"]) (#16767)
  • The selection Fabric(strategy="ddp_spawn", ...) no longer falls back to “ddp” when a cluster environment gets detected (#16780)
  • Renamed setup_dataloaders(replace_sampler=...) to setup_dataloaders(use_distributed_sampler=...) (#16829)
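The checkpoint redesign above is easiest to picture with plain Python objects. The sketch below approximates the described behavior: stateful objects are saved through state_dict(), restored in place on load, plain values are replaced, and checkpoint entries with no counterpart in the state are returned. save_state, load_state, and Counter are hypothetical stand-ins, not Fabric's API or implementation.

```python
from typing import Any, Dict


class Counter:
    """Stand-in for a model or optimizer: anything with state_dict()/load_state_dict()."""

    def __init__(self) -> None:
        self.count = 0

    def state_dict(self) -> Dict[str, int]:
        return {"count": self.count}

    def load_state_dict(self, sd: Dict[str, int]) -> None:
        self.count = sd["count"]


def save_state(state: Dict[str, Any]) -> Dict[str, Any]:
    # Stateful objects are serialized through state_dict(); plain values are copied.
    return {k: v.state_dict() if hasattr(v, "state_dict") else v for k, v in state.items()}


def load_state(checkpoint: Dict[str, Any], state: Dict[str, Any]) -> Dict[str, Any]:
    remainder = dict(checkpoint)
    for name, obj in list(state.items()):
        if name not in remainder:
            continue
        if hasattr(obj, "load_state_dict"):
            obj.load_state_dict(remainder.pop(name))  # restore in place
        else:
            state[name] = remainder.pop(name)  # plain values are replaced
    return remainder  # entries that were not loaded into `state`


model = Counter()
model.count = 7
ckpt = save_state({"model": model, "step": 100})
ckpt["extra"] = "not in state"

restored = Counter()
state = {"model": restored, "step": 0}
leftover = load_state(ckpt, state)
print(restored.count, state["step"], leftover)  # 7 100 {'extra': 'not in state'}
```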
Removed
  • Removed support for PyTorch 1.10 (#16492)
  • Removed support for Python 3.7 (#16579)
Fixed
  • Fixed an issue where the wrapped dataloader iter() would be called twice (#16841)
  • Improved the error message for installing tensorboard or tensorboardx (#17053)

App

Added
  • Added a --zip option to the lightning cp command to copy content from the Cloud Platform Filesystem as a zipfile
Changed
  • Changed minimum supported version of rich from 10.14.0 to 12.13.0 (#16798)
Removed
  • Removed support for Python 3.7 (#16579)

Full commit list: 1.9.0...2.0.0

Contributors

Veteran

@aniketmaurya @Atharva-Phatak @awaelchli @Borda @carmocca @dmitsf @edenlightning @hhsecond @janpawlowskiof @justusschock @krshrimali @leoleoasd @martenlienen @narJH27 @rusmux @SauravMaheshkar @shenoynikhil @tshu-w @wouterzwerink

New

@coreyjadams @dconathan @RuRo @sergeevii123 @BrianPulfer @akkefa @belerico @tupini07 @jihoonkim2100 @eamonn-zh @lightningforever @mauvilsa @muhammadanas0716 @AdityaKane2001 @dtuit @themattinthehatt @janpawlowskiof @rusmux @Erotemic @janEbert