Fabric

class lightning.fabric.fabric.Fabric(*, accelerator='auto', strategy='auto', devices='auto', num_nodes=1, precision=None, plugins=None, callbacks=None, loggers=None)

Bases: object

Fabric accelerates your PyTorch training or inference code with minimal changes required.

  • Automatic placement of models and data onto the device.

  • Automatic support for mixed and double precision (mixed precision gives a smaller memory footprint).

  • Seamless switching between hardware (CPU, GPU, TPU) and distributed training strategies (data-parallel training, sharded training, etc.).

  • Automated spawning of processes, no launch utilities required.

  • Multi-node support.

Parameters:
  • accelerator (Union[str, Accelerator]) – The hardware to run on. Possible choices are: "cpu", "cuda", "mps", "gpu", "tpu", "auto".

  • strategy (Union[str, Strategy]) – Strategy for how to run across multiple devices. Possible choices are: "dp", "ddp", "ddp_spawn", "deepspeed", "fsdp".

  • devices (Union[List[int], str, int]) – Number of devices to train on (int), which GPUs to train on (list or str), or "auto". The value applies per node.

  • num_nodes (int) – Number of GPU nodes for distributed training.

  • precision (Union[Literal[64, 32, 16], Literal['transformer-engine', 'transformer-engine-float16', '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', '32-true', '64-true'], Literal['64', '32', '16', 'bf16'], None]) – Double precision ("64"), full precision ("32"), half precision AMP ("16-mixed"), or bfloat16 precision AMP ("bf16-mixed").

  • plugins (Union[Precision, ClusterEnvironment, CheckpointIO, List[Union[Precision, ClusterEnvironment, CheckpointIO]], None]) – One or several custom plugins.

  • callbacks (Union[List[Any], Any, None]) – A single callback or a list of callbacks. A callback can contain any arbitrary methods that can be invoked through call() by the user.

  • loggers (Union[Logger, List[Logger], None]) – A single logger or a list of loggers. See log() for more information.

_setup_dataloader(dataloader, use_distributed_sampler=True, move_to_device=True)

Set up a single dataloader for accelerated training.

Parameters:
  • dataloader (DataLoader) – The dataloader to accelerate.

  • use_distributed_sampler (bool) – If set to True (default), automatically wraps or replaces the sampler on the dataloader for distributed training. If you have a custom sampler defined, set this argument to False.

  • move_to_device (bool) – If set to True (default), moves the data returned by the dataloader automatically to the correct device. Set this to False to instead call to_device() manually on the returned data.

Return type:

DataLoader

Returns:

The wrapped dataloader.

all_gather(data, group=None, sync_grads=False)

Gather tensors or collections of tensors from multiple processes.

This method needs to be called on all processes and the tensors need to have the same shape across all processes, otherwise your program will stall forever.

Parameters:
  • data (Union[Tensor, Dict, List, Tuple]) – int, float, tensor of shape (batch, …), or a (possibly nested) collection thereof.

  • group (Optional[Any]) – the process group to gather results from. Defaults to all processes (world).

  • sync_grads (bool) – flag that allows users to synchronize gradients for the all_gather operation

Return type:

Union[Tensor, Dict, List, Tuple]

Returns:

A tensor of shape (world_size, batch, …), or if the input was a collection the output will also be a collection with tensors of this shape.

all_reduce(data, group=None, reduce_op='mean')

Reduce tensors or collections of tensors from multiple processes.

The reduction on tensors is applied in-place, meaning the result will be placed back into the input tensor. This method needs to be called on all processes and the tensors need to have the same shape across all processes, otherwise your program will stall forever.

Parameters:
  • data (Union[Tensor, Dict, List, Tuple]) – int, float, tensor of shape (batch, …), or a (possibly nested) collection thereof. Tensor will be modified in-place.

  • group (Optional[Any]) – the process group to reduce results across. Defaults to all processes (world).

  • reduce_op (Union[ReduceOp, str, None]) – the reduction operation. Defaults to ‘mean’. Can also be a string ‘sum’ or ReduceOp. Some strategies may limit the choices here.

Return type:

Union[Tensor, Dict, List, Tuple]

Returns:

A tensor of the same shape as the input with values reduced pointwise across processes. The same is applied to tensors in a collection if a collection is given as input.

autocast()

A context manager to automatically convert operations for the chosen precision.

Use this only if the forward method of your model does not cover all operations you wish to run with the chosen precision setting.

Return type:

ContextManager

backward(tensor, *args, model=None, **kwargs)

Replaces loss.backward() in your training loop. Handles precision and strategy-specific details automatically for you.

Parameters:
  • tensor (Tensor) – The tensor (loss) to back-propagate gradients from.

  • *args (Any) – Optional positional arguments passed to the underlying backward function.

  • model (Optional[_FabricModule]) – Optional model instance for plugins that require the model for backward().

  • **kwargs (Any) – Optional named keyword arguments passed to the underlying backward function.

Return type:

None

Note

When using strategy="deepspeed" and multiple models were set up, you must pass the model as an argument here.

barrier(name=None)

Wait for all processes to enter this call.

Use this to synchronize all parallel processes, but only when necessary; otherwise, the synchronization overhead will slow down your program. This method needs to be called on all processes. Failing to do so will cause your program to stall forever.

Return type:

None

broadcast(obj, src=0)

Send an object (most efficiently, a tensor) from one process to all others.

This method needs to be called on all processes. Failing to do so will cause your program to stall forever.

Parameters:
  • obj (TypeVar(TBroadcast)) – The object to broadcast to all other members. Any serializable object is supported, but it is most efficient with the object being a Tensor.

  • src (int) – The (global) rank of the process that should send the data to all others.

Return type:

TypeVar(TBroadcast)

Returns:

The transferred data, the same value on every rank.

call(hook_name, *args, **kwargs)

Trigger the callback methods with the given name and arguments.

Objects registered via Fabric(callbacks=...) do not all need to implement a method with the given name; only those with a matching method name get called.

Parameters:
  • hook_name (str) – The name of the callback method.

  • *args (Any) – Optional positional arguments that get passed down to the callback method.

  • **kwargs (Any) – Optional keyword arguments that get passed down to the callback method.

Return type:

None

Example:

from lightning.fabric import Fabric

class MyCallback:
    def on_train_epoch_end(self, results):
        print("epoch results:", results)

fabric = Fabric(callbacks=[MyCallback()])
fabric.call("on_train_epoch_end", results={"loss": 0.1})  # hypothetical results

clip_gradients(module, optimizer, clip_val=None, max_norm=None, norm_type=2.0, error_if_nonfinite=True)

Clip the gradients of the model to a given max value or max norm.

Parameters:
  • module (Union[Module, _FabricModule]) – The module whose parameters should be clipped.

  • optimizer (Union[Optimizer, _FabricOptimizer]) – The optimizer referencing the parameters to be clipped.

  • clip_val (Union[float, int, None]) – If passed, gradients will be clipped to this value.

  • max_norm (Union[float, int, None]) – If passed, clips the gradients so that their total p-norm is no larger than the given value.

  • norm_type (Union[float, int]) – The type of norm if max_norm was passed. Can be 'inf' for infinity norm. Default is the 2-norm.

  • error_if_nonfinite (bool) – If True (default), an error is raised if the total norm of the gradients is NaN or infinite.

Return type:

Optional[Tensor]

init_module(empty_init=None)

Instantiate the model and its parameters under this context manager to reduce peak memory usage.

The parameters are created directly on the device and with the right data type, so no memory is wasted on unnecessary allocations. The automatic device placement under this context manager is only supported with PyTorch 2.0 and newer.

Parameters:

empty_init (Optional[bool]) – Whether to initialize the model with empty weights (uninitialized memory). If None, the strategy will decide. Some strategies may not support all options. Set this to True if you are loading a checkpoint into a large model.

Return type:

ContextManager

init_tensor()

Tensors that you instantiate under this context manager will be created on the device right away and have the right data type depending on the precision setting in Fabric.

The automatic device placement under this context manager is only supported with PyTorch 2.0 and newer.

Return type:

ContextManager

launch(function=<function _do_nothing>, *args, **kwargs)

Launch and initialize all the processes needed for distributed execution.

Parameters:
  • function (Callable[[Fabric], Any]) – Optional function to launch when using a spawn/fork-based strategy, for example, when using the XLA strategy (accelerator="tpu"). The function must accept at least one argument, to which the Fabric object itself will be passed.

  • *args (Any) – Optional positional arguments to be passed to the function.

  • **kwargs (Any) – Optional keyword arguments to be passed to the function.

Return type:

Any

Returns:

Returns the output of the function that ran in worker process with rank 0.

The launch() method should only be used if you intend to specify accelerator, devices, and so on in the code (programmatically). If you are launching with the Lightning CLI, lightning run model ..., remove launch() from your code.

The launch() is a no-op when called multiple times and no function is passed in.

load(path, state=None, strict=True)

Load a checkpoint from a file and restore the state of objects (modules, optimizers, etc.)

The strategy determines how, and on which processes, the checkpoint gets loaded. This method must be called on all processes!

Parameters:
  • path (Union[str, Path]) – A path to where the file is located

  • state (Optional[Dict[str, Union[Module, Optimizer, Any]]]) – A dictionary of objects whose state will be restored in-place from the checkpoint path. If no state is given, then the checkpoint will be returned in full.

  • strict (bool) – Whether to enforce that the keys in state match the keys in the checkpoint.

Return type:

Dict[str, Any]

Returns:

The remaining items that were not restored into the given state dictionary. If no state dictionary is given, the full checkpoint will be returned.

load_raw(path, obj, strict=True)

Load the state of a module or optimizer from a single state-dict file.

Use this for loading a raw PyTorch model checkpoint created without Fabric. This is conceptually equivalent to obj.load_state_dict(torch.load(path)), but is agnostic to the strategy being used.

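Parameters:
  • path (Union[str, Path]) – A path to where the file is located.

  • obj (Union[Module, Optimizer]) – The module or optimizer whose state will be restored in-place.

  • strict (bool) – Whether to enforce that the keys in the module's state-dict match the keys in the checkpoint. Does not apply to optimizers.

Return type: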

None

log(name, value, step=None)

Log a scalar to all loggers that were added to Fabric.

Parameters:
  • name (str) – The name of the metric to log.

  • value (Any) – The metric value to collect. If the value is a torch.Tensor, it gets detached from the graph automatically.

  • step (Optional[int]) – Optional step number. Most Logger implementations auto-increment the step value by one with every log call. You can specify your own value here.

Return type:

None

log_dict(metrics, step=None)

Log multiple scalars at once to all loggers that were added to Fabric.

Parameters:
  • metrics (Mapping[str, Any]) – A dictionary where the key is the name of the metric and the value the scalar to be logged. Any torch.Tensor in the dictionary gets detached from the graph automatically.

  • step (Optional[int]) – Optional step number. Most Logger implementations auto-increment this value by one with every log call. You can specify your own value here.

Return type:

None

no_backward_sync(module, enabled=True)

Skip gradient synchronization during backward to avoid redundant communication overhead.

Use this context manager when performing gradient accumulation to speed up training with multiple devices.

Example:

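# Accumulate gradients over 8 batches at a time
for batch_idx, (inputs, targets) in enumerate(dataloader):
    is_accumulating = (batch_idx + 1) % 8 != 0
    # Skip the gradient sync on every batch except each 8th one
    with fabric.no_backward_sync(model, enabled=is_accumulating):
        output = model(inputs)
        loss = ...  # compute the loss from output and targets
        fabric.backward(loss)
    if not is_accumulating:
        # Gradients were synchronized on this batch, so step now
        optimizer.step()
        optimizer.zero_grad()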

For those strategies that don’t support it, a warning is emitted. For single-device strategies, it is a no-op. Both the model’s .forward() and the fabric.backward() call need to run under this context.

Parameters:
  • module (_FabricModule) – The module for which to control the gradient synchronization.

  • enabled (bool) – Whether the context manager is enabled or not. True means skip the sync, False means do not skip.

Return type:

ContextManager

print(*args, **kwargs)

Print something only on the first process. If running on multiple machines, it will print from the first process in each machine.

Arguments passed to this method are forwarded to the Python built-in print() function.

Return type:

None

rank_zero_first(local=False)

The code block under this context manager gets executed first on the main process (rank 0) and only when completed, the other processes get to run the code in parallel.

Parameters:

local (bool) – Set this to True if the local rank should be the one going first. Useful if you are downloading data and the filesystem isn’t shared between the nodes.

Return type:

Generator

Example:

from torchvision.datasets import MNIST

with fabric.rank_zero_first():
    dataset = MNIST("datasets/", download=True)

run(*args, **kwargs)

All the code inside this run method gets accelerated by Fabric.

You can pass arbitrary arguments to this function when overriding it.

Return type:

Any

save(path, state, filter=None)

Save checkpoint contents to a file.

The strategy determines how, and on which processes, the checkpoint gets saved. For example, the ddp strategy saves checkpoints only on process 0, while the fsdp strategy saves files from every rank. This method must be called on all processes!

Parameters:
  • path (Union[str, Path]) – A path to where the file(s) should be saved

  • state (Dict[str, Union[Module, Optimizer, Any]]) – A dictionary with contents to be saved. If the dict contains modules or optimizers, their state-dict will be retrieved and converted automatically.

  • filter (Optional[Dict[str, Callable[[str, Any], bool]]]) – An optional dictionary containing filter callables that return a boolean indicating whether the given item should be saved (True) or filtered out (False). Each filter key should match a key in state; its filter is then applied to the state-dict generated for that entry.

Return type:

None

static seed_everything(seed=None, workers=None)

Helper function to seed everything without explicitly importing Lightning.

See lightning.fabric.utilities.seed.seed_everything() for more details.

Return type:

int

setup(module, *optimizers, move_to_device=True, _reapply_compile=True)

Set up a model and its optimizers for accelerated training.

Parameters:
  • module (Module) – A torch.nn.Module to set up

  • *optimizers (Optimizer) – The optimizer(s) to set up (passing no optimizers is also possible).

  • move_to_device (bool) – If set to True (default), moves the model to the correct device. Set this to False to instead call to_device() manually.

  • _reapply_compile (bool) – If True (default) and the model was compiled with torch.compile before, the corresponding torch._dynamo.OptimizedModule wrapper will be removed and reapplied with the same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP, FSDP, etc.). Only applies on PyTorch >= 2.1. Set it to False if compiling DDP/FSDP is causing issues.

Return type:

Any

Returns:

A tuple containing the wrapped module and the wrapped optimizers, in the same order they were passed in.

setup_dataloaders(*dataloaders, use_distributed_sampler=True, move_to_device=True)

Set up one or multiple dataloaders for accelerated training. If you need different settings for each dataloader, call this method individually for each one.

Parameters:
  • *dataloaders (DataLoader) – A single dataloader or a sequence of dataloaders.

  • use_distributed_sampler (bool) – If set to True (default), automatically wraps or replaces the sampler on the dataloader(s) for distributed training. If you have a custom sampler defined, set this argument to False.

  • move_to_device (bool) – If set to True (default), moves the data returned by the dataloader(s) automatically to the correct device. Set this to False to instead call to_device() manually on the returned data.

Return type:

Union[DataLoader, List[DataLoader]]

Returns:

The wrapped dataloaders, in the same order they were passed in.

setup_module(module, move_to_device=True, _reapply_compile=True)

Set up a model for accelerated training or inference.

This is the same as calling .setup(model) with no optimizers. It is useful for inference or for certain strategies like FSDP that require setting up the module before the optimizer can be created and set up. See also setup_optimizers().

Parameters:
  • module (Module) – A torch.nn.Module to set up

  • move_to_device (bool) – If set to True (default), moves the model to the correct device. Set this to False to instead call to_device() manually.

  • _reapply_compile (bool) – If True (default) and the model was compiled with torch.compile before, the corresponding torch._dynamo.OptimizedModule wrapper will be removed and reapplied with the same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP, FSDP, etc.). Only applies on PyTorch >= 2.1. Set it to False if compiling DDP/FSDP is causing issues.

Return type:

_FabricModule

Returns:

The wrapped model.

setup_optimizers(*optimizers)

Set up one or more optimizers for accelerated training.

Some strategies do not allow setting up model and optimizer independently. For them, you should call .setup(model, optimizer, ...) instead to jointly set them up.

Parameters:

*optimizers (Optimizer) – One or more optimizers to set up.

Return type:

Union[_FabricOptimizer, Tuple[_FabricOptimizer, ...]]

Returns:

The wrapped optimizer(s).

sharded_model()

Instantiate a model under this context manager to prepare it for model-parallel sharding.

Return type:

ContextManager

Deprecated: This context manager is deprecated in favor of init_module(); use it instead.

to_device(obj)

Move a torch.nn.Module or a collection of tensors to the current device, if it is not already on that device.

Parameters:

obj (Union[Module, Tensor, Any]) – An object to move to the device. Can be an instance of torch.nn.Module, a tensor, or a (nested) collection of tensors (e.g., a dictionary).

Return type:

Union[Module, Tensor, Any]

Returns:

A reference to the object that was moved to the new device.

property device: device

The current device this process runs on.

Use this to create tensors directly on the device if needed.

property global_rank: int

The global index of the current process across all devices and nodes.

property is_global_zero: bool

Whether this process is the one with global rank zero.

property local_rank: int

The index of the current process among the processes running on the local node.

property logger: Logger

Returns the first logger in the list passed to Fabric, which is considered the main logger.

property loggers: List[Logger]

Returns all loggers passed to Fabric.

property node_rank: int

The index of the current node.

property world_size: int

The total number of processes running across all devices and nodes.