Fabric
- class lightning.fabric.fabric.Fabric(*, accelerator='auto', strategy='auto', devices='auto', num_nodes=1, precision=None, plugins=None, callbacks=None, loggers=None)
Bases: object
Fabric accelerates your PyTorch training or inference code with minimal changes required. It provides:
- Automatic placement of models and data onto the device.
- Automatic support for mixed and double precision (smaller memory footprint).
- Seamless switching between hardware (CPU, GPU, TPU) and distributed training strategies (data-parallel training, sharded training, etc.).
- Automated spawning of processes, no launch utilities required.
- Multi-node support.
- Parameters:
accelerator (Union[str, Accelerator]) – The hardware to run on. Possible choices are: "cpu", "cuda", "mps", "gpu", "tpu", "auto".
strategy (Union[str, Strategy]) – Strategy for how to run across multiple devices. Possible choices are: "dp", "ddp", "ddp_spawn", "deepspeed", "fsdp".
devices (Union[list[int], str, int]) – Number of devices to train on (int), which GPUs to train on (list or str), or "auto". The value applies per node.
num_nodes (int) – Number of GPU nodes for distributed training.
precision (Union[Literal[64, 32, 16], Literal['transformer-engine', 'transformer-engine-float16', '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', '32-true', '64-true'], Literal['64', '32', '16', 'bf16'], None]) – Double precision ("64"), full precision ("32"), half precision AMP ("16-mixed"), or bfloat16 precision AMP ("bf16-mixed").
plugins (Union[Precision, ClusterEnvironment, CheckpointIO, list[Union[Precision, ClusterEnvironment, CheckpointIO]], None]) – One or several custom plugins.
callbacks (Union[list[Any], Any, None]) – A single callback or a list of callbacks. A callback can contain any arbitrary methods that can be invoked through call() by the user.
loggers (Union[Logger, list[Logger], None]) – A single logger or a list of loggers. See log() for more information.
- _setup_dataloader(dataloader, use_distributed_sampler=True, move_to_device=True)
Set up a single dataloader for accelerated training.
- Parameters:
dataloader (DataLoader) – The dataloader to accelerate.
use_distributed_sampler (bool) – If set True (default), automatically wraps or replaces the sampler on the dataloader for distributed training. If you have a custom sampler defined, set this argument to False.
move_to_device (bool) – If set True (default), moves the data returned by the dataloader automatically to the correct device. Set this to False and alternatively use to_device() manually on the returned data.
- Return type:
- Returns:
The wrapped dataloader.
- all_gather(data, group=None, sync_grads=False)
Gather tensors or collections of tensors from multiple processes.
This method needs to be called on all processes and the tensors need to have the same shape across all processes, otherwise your program will stall forever.
- Parameters:
data (Union[Tensor, dict, list, tuple]) – int, float, tensor of shape (batch, …), or a (possibly nested) collection thereof.
group (Optional[Any]) – The process group to gather results from. Defaults to all processes (world).
sync_grads (bool) – Flag that allows users to synchronize gradients for the all_gather operation.
- Return type:
- Returns:
A tensor of shape (world_size, batch, …), or if the input was a collection the output will also be a collection with tensors of this shape. For the special case where world_size is 1, no additional dimension is added to the tensor(s).
- all_reduce(data, group=None, reduce_op='mean')
Reduce tensors or collections of tensors from multiple processes.
The reduction on tensors is applied in-place, meaning the result will be placed back into the input tensor. This method needs to be called on all processes and the tensors need to have the same shape across all processes, otherwise your program will stall forever.
- Parameters:
data (Union[Tensor, dict, list, tuple]) – int, float, tensor of shape (batch, …), or a (possibly nested) collection thereof. Tensor will be modified in-place.
group (Optional[Any]) – The process group to reduce results across. Defaults to all processes (world).
reduce_op (Union[ReduceOp, str, None]) – The reduction operation. Defaults to 'mean'. Can also be the string 'sum' or a ReduceOp. Some strategies may limit the choices here.
- Return type:
- Returns:
A tensor of the same shape as the input with values reduced pointwise across processes. The same is applied to tensors in a collection if a collection is given as input.
- autocast()
A context manager to automatically convert operations for the chosen precision.
Use this only if the forward method of your model does not cover all operations you wish to run with the chosen precision setting.
- Return type:
- backward(tensor, *args, model=None, **kwargs)
Replaces loss.backward() in your training loop. Handles precision automatically for you.
- Parameters:
tensor (Tensor) – The tensor (loss) to back-propagate gradients from.
*args (Any) – Optional positional arguments passed to the underlying backward function.
model (Optional[_FabricModule]) – Optional model instance for plugins that require the model for backward().
**kwargs (Any) – Optional named keyword arguments passed to the underlying backward function.
- Return type:
Note: When using strategy="deepspeed" and multiple models were set up, it is required to pass in the model as an argument here.
- barrier(name=None)
Wait for all processes to enter this call.
Use this to synchronize all parallel processes, but only if necessary, otherwise the overhead of synchronization will cause your program to slow down. This method needs to be called on all processes. Failing to do so will cause your program to stall forever.
- Return type:
- broadcast(obj, src=0)
Send a tensor from one process to all others.
This method needs to be called on all processes. Failing to do so will cause your program to stall forever.
- Parameters:
- Return type: TypeVar(TBroadcast)
- Returns:
The transferred data, the same value on every rank.
- call(hook_name, *args, **kwargs)
Trigger the callback methods with the given name and arguments.
Not all objects registered via Fabric(callbacks=...) must implement a method with the given name. The ones that have a matching method name will get called.
- Parameters:
- Return type:
Example:
    class MyCallback:
        def on_train_epoch_end(self, results):
            ...

    fabric = Fabric(callbacks=[MyCallback()])
    fabric.call("on_train_epoch_end", results={...})
- clip_gradients(module, optimizer, clip_val=None, max_norm=None, norm_type=2.0, error_if_nonfinite=True)
Clip the gradients of the model to a given max value or max norm.
- Parameters:
module (Union[Module, _FabricModule]) – The module whose parameters should be clipped.
optimizer (Union[Optimizer, _FabricOptimizer]) – The optimizer referencing the parameters to be clipped.
clip_val (Union[float, int, None]) – If passed, gradients will be clipped to this value.
max_norm (Union[float, int, None]) – If passed, clips the gradients in such a way that the p-norm of the clipped gradients is no larger than the given value.
norm_type (Union[float, int]) – The type of norm if max_norm was passed. Can be 'inf' for infinity norm. Default is the 2-norm.
error_if_nonfinite (bool) – An error is raised if the total norm of the gradients is NaN or infinite.
- Return type:
- Returns:
The total norm of the gradients (before clipping was applied) as a scalar tensor if max_norm was passed, otherwise None.
- init_module(empty_init=None)
Instantiate the model and its parameters under this context manager to reduce peak memory usage.
The parameters are created directly on the target device and with the right data type, so no memory is wasted on intermediate allocations.
- init_tensor()
Tensors that you instantiate under this context manager will be created on the device right away and have the right data type depending on the precision setting in Fabric.
- Return type:
- launch(function=<function _do_nothing>, *args, **kwargs)
Launch and initialize all the processes needed for distributed execution.
- Parameters:
function (Callable[[Fabric], Any]) – Optional function to launch when using a spawn/fork-based strategy, for example, when using the XLA strategy (accelerator="tpu"). The function must accept at least one argument, to which the Fabric object itself will be passed.
*args (Any) – Optional positional arguments to be passed to the function.
**kwargs (Any) – Optional keyword arguments to be passed to the function.
- Return type:
- Returns:
Returns the output of the function that ran in worker process with rank 0.
The launch() method should only be used if you intend to specify accelerator, devices, and so on in the code (programmatically). If you are launching with the Lightning CLI, fabric run ..., remove launch() from your code.
The launch() method is a no-op when called multiple times and no function is passed in.
- load(path, state=None, strict=True)
Load a checkpoint from a file and restore the state of objects (modules, optimizers, etc.).
How and which processes load gets determined by the strategy. This method must be called on all processes!
- Parameters:
path (Union[str, Path]) – A path to where the file is located.
state (Optional[dict[str, Union[Module, Optimizer, Any]]]) – A dictionary of objects whose state will be restored in-place from the checkpoint path. If no state is given, then the checkpoint will be returned in full.
strict (bool) – Whether to enforce that the keys in state match the keys in the checkpoint.
- Return type:
- Returns:
The remaining items that were not restored into the given state dictionary. If no state dictionary is given, the full checkpoint will be returned.
- load_raw(path, obj, strict=True)
Load the state of a module or optimizer from a single state-dict file.
Use this for loading a raw PyTorch model checkpoint created without Fabric. This is conceptually equivalent to obj.load_state_dict(torch.load(path)), but is agnostic to the strategy being used.
- log(name, value, step=None)
Log a scalar to all loggers that were added to Fabric.
- Parameters:
- Return type:
- log_dict(metrics, step=None)
Log multiple scalars at once to all loggers that were added to Fabric.
- Parameters:
metrics (Mapping[str, Any]) – A dictionary where the key is the name of the metric and the value the scalar to be logged. Any torch.Tensor in the dictionary gets detached from the graph automatically.
step (Optional[int]) – Optional step number. Most Logger implementations auto-increment this value by one with every log call. You can specify your own value here.
- Return type:
- no_backward_sync(module, enabled=True)
Skip gradient synchronization during backward to avoid redundant communication overhead.
Use this context manager when performing gradient accumulation to speed up training with multiple devices.
Example:
    # Accumulate gradient 8 batches at a time
    with fabric.no_backward_sync(model, enabled=(batch_idx % 8 != 0)):
        output = model(input)
        loss = ...
        fabric.backward(loss)
        ...
For those strategies that don’t support it, a warning is emitted. For single-device strategies, it is a no-op. Both the model’s .forward() and the fabric.backward() call need to run under this context.
- Parameters:
- Return type:
- print(*args, **kwargs)
Print something only on the first process. If running on multiple machines, it will print from the first process in each machine.
Arguments passed to this method are forwarded to the Python built-in print() function.
- Return type:
- rank_zero_first(local=False)
The code block under this context manager gets executed first on the main process (rank 0) and only when completed, the other processes get to run the code in parallel.
- Parameters:
local (bool) – Set this to True if the local rank should be the one going first. Useful if you are downloading data and the filesystem isn’t shared between the nodes.
- Return type:
Example:
    with fabric.rank_zero_first():
        dataset = MNIST("datasets/", download=True)
- run(*args, **kwargs)
All the code inside this run method gets accelerated by Fabric.
You can pass arbitrary arguments to this function when overriding it.
- Return type:
- save(path, state, filter=None)
Save checkpoint contents to a file.
How and which processes save gets determined by the strategy. For example, the ddp strategy saves checkpoints only on process 0, while the fsdp strategy saves files from every rank. This method must be called on all processes!
- Parameters:
path (Union[str, Path]) – A path to where the file(s) should be saved.
state (dict[str, Union[Module, Optimizer, Any]]) – A dictionary with contents to be saved. If the dict contains modules or optimizers, their state-dict will be retrieved and converted automatically.
filter (Optional[dict[str, Callable[[str, Any], bool]]]) – An optional dictionary containing filter callables that return a boolean indicating whether the given item should be saved (True) or filtered out (False). Each filter key should match a state key, where its filter will be applied to the state_dict generated.
- Return type:
- static seed_everything(seed=None, workers=None, verbose=True)
Helper function to seed everything without explicitly importing Lightning.
See seed_everything() for more details.
- Return type:
- setup(module, *optimizers, move_to_device=True, _reapply_compile=True)
Set up a model and its optimizers for accelerated training.
- Parameters:
module (Module) – A torch.nn.Module to set up.
*optimizers (Optimizer) – The optimizer(s) to set up (no optimizers is also possible).
move_to_device (bool) – If set True (default), moves the model to the correct device. Set this to False and alternatively use to_device() manually.
_reapply_compile (bool) – If True (default), and the model was torch.compiled before, the corresponding torch._dynamo.OptimizedModule wrapper will be removed and reapplied with the same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP, FSDP etc.). Set it to False if compiling DDP/FSDP is causing issues.
- Return type:
- Returns:
The tuple containing wrapped module and the optimizers, in the same order they were passed in.
- setup_dataloaders(*dataloaders, use_distributed_sampler=True, move_to_device=True)
Set up one or multiple dataloaders for accelerated training. If you need different settings for each dataloader, call this method individually for each one.
- Parameters:
*dataloaders (DataLoader) – A single dataloader or a sequence of dataloaders.
use_distributed_sampler (bool) – If set True (default), automatically wraps or replaces the sampler on the dataloader(s) for distributed training. If you have a custom sampler defined, set this argument to False.
move_to_device (bool) – If set True (default), moves the data returned by the dataloader(s) automatically to the correct device. Set this to False and alternatively use to_device() manually on the returned data.
- Return type:
- Returns:
The wrapped dataloaders, in the same order they were passed in.
- setup_module(module, move_to_device=True, _reapply_compile=True)
Set up a model for accelerated training or inference.
This is the same as calling .setup(model) with no optimizers. It is useful for inference or for certain strategies like FSDP that require setting up the module before the optimizer can be created and set up. See also setup_optimizers().
- Parameters:
module (Module) – A torch.nn.Module to set up.
move_to_device (bool) – If set True (default), moves the model to the correct device. Set this to False and alternatively use to_device() manually.
_reapply_compile (bool) – If True (default), and the model was torch.compiled before, the corresponding torch._dynamo.OptimizedModule wrapper will be removed and reapplied with the same settings after the model was set up by the strategy (e.g., after the model was wrapped by DDP, FSDP etc.). Set it to False if compiling DDP/FSDP is causing issues.
- Return type: _FabricModule
- Returns:
The wrapped model.
- setup_optimizers(*optimizers)
Set up one or more optimizers for accelerated training.
Some strategies do not allow setting up model and optimizer independently. For them, you should call .setup(model, optimizer, ...) instead to jointly set them up.
- sharded_model()
Instantiate a model under this context manager to prepare it for model-parallel sharding.
- Return type: AbstractContextManager
Deprecated: This context manager is deprecated in favor of init_module(); use it instead.
- to_device(obj)
Move a torch.nn.Module or a collection of tensors to the current device, if it is not already on that device.
- property device: device
The current device this process runs on.
Use this to create tensors directly on the device if needed.
- property local_rank: int
The index of the current process among the processes running on the local node.
- property logger: Logger
Returns the first logger in the list passed to Fabric, which is considered the main logger.
- property loggers: list[lightning.fabric.loggers.logger.Logger]
Returns all loggers passed to Fabric.