
Fabric Arguments


accelerator

Choose one of "cpu", "gpu", "tpu", "auto".

from lightning.fabric import Fabric

# CPU accelerator
fabric = Fabric(accelerator="cpu")

# Running with GPU Accelerator using 2 GPUs
fabric = Fabric(devices=2, accelerator="gpu")

# Running with TPU Accelerator using 8 TPU cores
fabric = Fabric(devices=8, accelerator="tpu")

# Running with GPU Accelerator using the DistributedDataParallel strategy
fabric = Fabric(devices=4, accelerator="gpu", strategy="ddp")

The "auto" option recognizes the machine you are on and selects the available accelerator.

# If your machine has GPUs, it will use the GPU Accelerator
fabric = Fabric(devices=2, accelerator="auto")
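To make the "auto" behavior concrete, here is a simplified, hypothetical sketch of such a resolution step. The function `resolve_accelerator` and its availability flags are illustrative assumptions, not Fabric's API. The preference order (TPU, then GPU, then CPU) is also an assumption and simplifies what Fabric actually checks.

```python
# Hypothetical helper, not part of Fabric: maps the requested accelerator
# string to a concrete choice, given which hardware is present.
VALID_ACCELERATORS = {"cpu", "gpu", "tpu", "auto"}


def resolve_accelerator(requested: str, has_tpu: bool = False, has_gpu: bool = False) -> str:
    if requested not in VALID_ACCELERATORS:
        raise ValueError(
            f"Unknown accelerator {requested!r}; choose one of {sorted(VALID_ACCELERATORS)}"
        )
    if requested != "auto":
        # An explicit choice is returned unchanged.
        return requested
    # "auto": prefer the most specialized hardware that is available
    # (assumed order: TPU, then GPU, then CPU as the fallback).
    if has_tpu:
        return "tpu"
    if has_gpu:
        return "gpu"
    return "cpu"
```

For example, `resolve_accelerator("auto", has_gpu=True)` returns `"gpu"`, and on a machine without accelerators `resolve_accelerator("auto")` falls back to `"cpu"`.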


strategy

Choose a training strategy: "dp", "ddp", "ddp_spawn", "xla", "deepspeed", or "fsdp".

# Running with the DistributedDataParallel strategy on 4 GPUs
fabric = Fabric(strategy="ddp", accelerator="gpu", devices=4)

# Running with the DDP Spawn strategy using 4 CPU processes
fabric = Fabric(strategy="ddp_spawn", accelerator="cpu", devices=4)

Additionally, you can pass in a strategy instance to configure additional parameters.

from lightning.fabric.strategies import DeepSpeedStrategy

fabric = Fabric(strategy=DeepSpeedStrategy(stage=2), accelerator="gpu", devices=2)


devices

Configure the devices to run on. The value can be of type:

  • int: the number of devices (e.g., GPUs) to train on

  • list of int: which device index (e.g., GPU ID) to train on (0-indexed)

  • str: a string representation of one of the above

# default used by Fabric, i.e., use the CPU
fabric = Fabric(devices=None)

# equivalent
fabric = Fabric(devices=0)

# int: run on two GPUs
fabric = Fabric(devices=2, accelerator="gpu")

# list: run on GPUs 1, 4 (by bus ordering)
fabric = Fabric(devices=[1, 4], accelerator="gpu")
fabric = Fabric(devices="1, 4", accelerator="gpu")  # equivalent

# -1: run on all GPUs
fabric = Fabric(devices=-1, accelerator="gpu")
fabric = Fabric(devices="-1", accelerator="gpu")  # equivalent
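The accepted forms above can be summarized by a small normalization routine. The sketch below is hypothetical: `parse_devices` and `num_available` are not Fabric names. It assumes, as the examples suggest, that an int `n` selects the first `n` devices, a list selects explicit indices, `-1` selects all devices, and `None` or `0` means no accelerator devices (run on the CPU).

```python
from typing import List, Sequence, Union

DevicesArg = Union[None, int, str, Sequence[int]]


def parse_devices(devices: DevicesArg, num_available: int) -> List[int]:
    """Normalize a `devices` value into a list of device indices.

    Hypothetical helper. An empty list means "no accelerator devices".
    """
    if isinstance(devices, str):
        text = devices.strip()
        if "," in text:
            # "1, 4" -> [1, 4]; empty pieces (e.g. a trailing comma) are skipped
            devices = [int(part) for part in text.split(",") if part.strip()]
        else:
            devices = int(text)  # "2" -> 2, "-1" -> -1
    if devices is None or devices == 0:
        return []
    if isinstance(devices, int):
        if devices == -1:
            # -1 selects every available device
            return list(range(num_available))
        if devices < 0 or devices > num_available:
            raise ValueError(f"Requested {devices} devices, but only {num_available} are available")
        # An int n selects the first n devices
        return list(range(devices))
    indices = list(devices)
    for index in indices:
        if not 0 <= index < num_available:
            raise ValueError(f"Device index {index} is out of range for {num_available} devices")
    return indices
```

On a machine with 8 GPUs, `parse_devices("1, 4", 8)` and `parse_devices([1, 4], 8)` both yield `[1, 4]`, while `parse_devices(2, 8)` yields `[0, 1]`.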


num_nodes

Number of cluster nodes for distributed operation.

# Default used by Fabric
fabric = Fabric(num_nodes=1)

# Run on 8 nodes
fabric = Fabric(num_nodes=8)

Learn more about distributed multi-node training on clusters here.
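With one process per device, as in DDP-style strategies, the total number of processes grows with `num_nodes`. The arithmetic below is a sketch that assumes every node runs the same number of devices. The helper names are illustrative, not Fabric functions.

```python
def world_size(num_nodes: int, devices_per_node: int) -> int:
    """Total number of processes when each device runs one process."""
    return num_nodes * devices_per_node


def global_rank(node_rank: int, local_rank: int, devices_per_node: int) -> int:
    """Unique rank of a process across all nodes.

    Nodes are numbered 0..num_nodes-1, and local ranks 0..devices_per_node-1.
    """
    return node_rank * devices_per_node + local_rank
```

Under these assumptions, 8 nodes with 4 devices each would run 8 × 4 = 32 processes. The process with local rank 3 on node 2 would get global rank 2 × 4 + 3 = 11.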


precision

Fabric supports double precision (64), full precision (32), or half precision (16) operation (including bfloat16). Mixed precision is the combined use of 32-bit and 16-bit floating-point types to reduce the memory footprint during model training. This can improve performance and achieve significant speedups on modern GPUs.

# Default used by Fabric
fabric = Fabric(precision=32, devices=1)

# 16-bit (mixed) precision
fabric = Fabric(precision=16, devices=1)

# 16-bit bfloat precision
fabric = Fabric(precision="bf16", devices=1)

# 64-bit (double) precision
fabric = Fabric(precision=64, devices=1)
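To put numbers on the memory argument, the sketch below uses only the standard library. `struct` gives the storage size of 64-, 32-, and 16-bit IEEE floats. bfloat16 is emulated by keeping the upper 16 bits of a float32, which rounds toward zero and simplifies the round-to-nearest conversion that hardware usually performs. The helper names are hypothetical. The figures describe raw tensor storage only, not the full memory use of a mixed-precision training run.

```python
import struct
from typing import Dict, Union

# Bytes needed to store one element at each `precision` value shown above.
# bfloat16 has no struct format code; like float16 it occupies 2 bytes.
BYTES_PER_ELEMENT: Dict[Union[int, str], int] = {
    64: struct.calcsize("d"),  # float64
    32: struct.calcsize("f"),  # float32
    16: struct.calcsize("e"),  # IEEE float16
    "bf16": 2,                 # bfloat16
}


def tensor_megabytes(num_elements: int, precision: Union[int, str]) -> float:
    """Storage for a tensor with `num_elements` entries, in MiB."""
    return num_elements * BYTES_PER_ELEMENT[precision] / 2**20


def to_bfloat16(x: float) -> float:
    """Emulate bfloat16 by truncating a float32 to its top 16 bits."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    (y,) = struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))
    return y


def to_float16(x: float) -> float:
    """Round-trip through IEEE float16; raises OverflowError above its range."""
    (y,) = struct.unpack("<e", struct.pack("<e", x))
    return y
```

A tensor of 2**20 elements needs 4 MiB at precision 32 but 2 MiB at 16 or "bf16". The last two helpers show how the 16-bit formats differ. `to_float16(70000.0)` raises `OverflowError` because float16 tops out at 65504. `to_bfloat16(70000.0)` returns `69632.0`: bfloat16 keeps float32's exponent range and gives up precision instead.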


plugins

Plugins allow you to connect arbitrary backends, precision libraries, clusters, etc. To define your own behavior, subclass the relevant class and pass it in. Here’s an example linking up your own ClusterEnvironment.

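from lightning.fabric import Fabric
from lightning.fabric.plugins.environments import ClusterEnvironment

class MyCluster(ClusterEnvironment):
    @property
    def main_address(self):
        return your_main_address  # placeholder: hostname or IP of the main node

    @property
    def main_port(self):
        return your_main_port  # placeholder: a free port on the main node

    def world_size(self):
        return the_world_size  # placeholder: total number of processes

    # ... also implement the remaining abstract members
    # (e.g., creates_processes_externally, detect, global_rank, local_rank, node_rank)

fabric = Fabric(plugins=[MyCluster()], ...)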
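The placeholder values in a custom cluster environment usually come from the job launcher. Below is a hypothetical, standard-library-only stand-in that reads them from environment variables exported by launchers such as torchrun (`MASTER_ADDR`, `MASTER_PORT`, `WORLD_SIZE`, `RANK`, `LOCAL_RANK`). `EnvVarCluster` does not subclass Lightning's `ClusterEnvironment`, so it cannot be passed to Fabric as is. The fallback values `127.0.0.1` and `29500` are assumptions for single-machine runs.

```python
import os
from typing import Mapping, Optional


class EnvVarCluster:
    """Hypothetical stand-in for a ClusterEnvironment subclass.

    Mirrors the member names from the example above, but reads every
    value from a mapping of environment variables.
    """

    def __init__(self, env: Optional[Mapping[str, str]] = None) -> None:
        # Default to the real process environment; a plain dict is handy for tests.
        self._env = os.environ if env is None else env

    @property
    def main_address(self) -> str:
        return self._env.get("MASTER_ADDR", "127.0.0.1")  # assumed fallback

    @property
    def main_port(self) -> int:
        return int(self._env.get("MASTER_PORT", "29500"))  # assumed fallback

    def world_size(self) -> int:
        return int(self._env.get("WORLD_SIZE", "1"))

    def global_rank(self) -> int:
        return int(self._env.get("RANK", "0"))

    def local_rank(self) -> int:
        return int(self._env.get("LOCAL_RANK", "0"))
```

In a real plugin, the same lookups would live in the methods of your `ClusterEnvironment` subclass.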


callbacks

A callback class is a collection of methods that the training loop can call at a specific point in time, for example, at the end of an epoch. Add callbacks to Fabric to inject logic into your training loop from an external callback class.

class MyCallback:
    def on_train_epoch_end(self, results):
        ...  # your logic here

You can then register this callback, or multiple ones, directly in Fabric:

fabric = Fabric(callbacks=[MyCallback()])

Then, in your training loop, you can call a hook by its name. Any callback objects that have this hook will execute it:

# Call any hook by name
fabric.call("on_train_epoch_end", results={...})
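Calling a hook by name amounts to looking up that method on each registered callback. The sketch below imitates this with plain Python. `call_hook` and both callback classes are hypothetical and do not reproduce Fabric's implementation, including its error handling.

```python
from typing import Any, Iterable, List


def call_hook(callbacks: Iterable[Any], hook_name: str, *args: Any, **kwargs: Any) -> None:
    """Call `hook_name` on every callback that defines it; skip the others."""
    for callback in callbacks:
        method = getattr(callback, hook_name, None)
        if callable(method):
            method(*args, **kwargs)


class RecordingCallback:
    """Hypothetical callback that remembers the results it receives."""

    def __init__(self) -> None:
        self.seen: List[Any] = []

    def on_train_epoch_end(self, results):
        self.seen.append(results)


class UnrelatedCallback:
    """Hypothetical callback without an `on_train_epoch_end` hook."""

    def on_validation_end(self):
        raise AssertionError("not triggered by on_train_epoch_end")
```

Calling `call_hook([RecordingCallback(), UnrelatedCallback()], "on_train_epoch_end", results={"loss": 0.5})` runs the hook only on the first object, because the second one does not define it.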


loggers

Attach one or several loggers/experiment trackers to Fabric for convenient logging of metrics.

# Default used by Fabric, no loggers are active
fabric = Fabric(loggers=[])

# Log to a single logger
from lightning.fabric.loggers import TensorBoardLogger
fabric = Fabric(loggers=TensorBoardLogger(...))

# Or multiple instances
fabric = Fabric(loggers=[logger1, logger2, ...])

Anywhere in your training loop, you can log metrics to all loggers at once:

fabric.log("loss", loss)
fabric.log_dict({"loss": loss, "accuracy": acc})
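Logging to all loggers at once is a simple fan-out. This standard-library sketch mirrors `fabric.log` and `fabric.log_dict` with hypothetical `log` and `log_dict` functions and a hypothetical `InMemoryLogger`. The `log_metrics(metrics, step)` method is modeled on Lightning's logger interface, but the classes here are illustrative only.

```python
from typing import Dict, List, Optional, Sequence, Tuple


class InMemoryLogger:
    """Hypothetical logger that stores every metrics dict it receives."""

    def __init__(self) -> None:
        self.records: List[Tuple[Optional[int], Dict[str, float]]] = []

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        # Copy so later mutation by the caller does not change the record.
        self.records.append((step, dict(metrics)))


def log_dict(loggers: Sequence[InMemoryLogger], metrics: Dict[str, float], step: Optional[int] = None) -> None:
    """Forward the same metrics to every attached logger."""
    for logger in loggers:
        logger.log_metrics(metrics, step=step)


def log(loggers: Sequence[InMemoryLogger], name: str, value: float, step: Optional[int] = None) -> None:
    """Log a single named value by wrapping it in a one-entry dict."""
    log_dict(loggers, {name: value}, step=step)
```

After `log([a, b], "loss", 0.25, step=1)`, both `a.records` and `b.records` contain `(1, {"loss": 0.25})`.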

