lightning.fabric.utilities.seed.seed_everything(seed=None, workers=False)[source]

Function that sets the seed for pseudo-random number generators in: torch, numpy, and Python’s random module. In addition, sets the following environment variables:

  • PL_GLOBAL_SEED: will be passed to spawned subprocesses (e.g. ddp_spawn backend).

  • PL_SEED_WORKERS: (optional) is set to 1 if workers=True.

Parameters:

  • seed (Optional[int]) – the integer value seed for global random state in Lightning. If None, it will read the seed from PL_GLOBAL_SEED env variable. If None and the PL_GLOBAL_SEED env variable is not set, then the seed defaults to 0.

  • workers (bool) – if set to True, will properly configure all dataloaders passed to the Trainer with a worker_init_fn. If the user already provides such a function for their dataloaders, setting this argument will have no influence. See also: pl_worker_init_function().
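The seed fallback order described above can be sketched with the standard library alone. This is a simplified stand-in, not Lightning's implementation: the helper name `resolve_seed` is hypothetical, and it seeds only Python's `random` module, whereas `seed_everything` also seeds torch and numpy.

```python
import os
import random


def resolve_seed(seed=None, workers=False):
    """Hypothetical stand-in mirroring the documented fallback order."""
    if seed is None:
        # Fall back to PL_GLOBAL_SEED, then to 0 if the variable is unset.
        seed = int(os.environ.get("PL_GLOBAL_SEED", 0))
    # Export the seed so that spawned subprocesses can pick it up.
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    if workers:
        # Documented behavior: PL_SEED_WORKERS is set to 1 if workers=True.
        os.environ["PL_SEED_WORKERS"] = "1"
    random.seed(seed)  # the real function also seeds torch and numpy
    return seed
```

Calling `resolve_seed(42)` and later `resolve_seed()` in the same process (or in a child process that inherits the environment) resolves to the same seed, which is the point of exporting it.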

Return type:

int


lightning.fabric.utilities.seed.pl_worker_init_function(worker_id, rank=None)[source]

The worker_init_fn that Lightning automatically adds to your dataloader if you previously set the seed with seed_everything(seed, workers=True).

See also the PyTorch documentation on randomness in DataLoaders.

Return type:

None


lightning.fabric.utilities.data.suggested_max_num_workers(local_world_size)[source]

Suggests an upper bound of num_workers to use in a PyTorch DataLoader based on the number of CPU cores available on the system and the number of distributed processes on the current machine.

Parameters:

  • local_world_size (int) – The number of distributed processes running on the current machine. Set this to the number of devices configured in Fabric/Trainer.
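The idea of bounding workers by cores per local process can be sketched as follows. The helper `sketch_max_num_workers` and its `cpu_count` argument are hypothetical; Lightning's actual heuristic may differ, for example by reserving cores for the main process.

```python
import os


def sketch_max_num_workers(local_world_size, cpu_count=None):
    """Hypothetical helper: split the CPU cores evenly across local processes."""
    if local_world_size < 1:
        raise ValueError("local_world_size must be >= 1")
    if cpu_count is None:
        cpu_count = os.cpu_count() or 1  # os.cpu_count() may return None
    # Each process gets an equal share of the cores, but never fewer than one worker.
    return max(1, cpu_count // local_world_size)
```

With 16 cores and 4 local processes this sketch suggests 4 workers per process; with more processes than cores it still suggests 1.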

Return type:

int


lightning.fabric.utilities.distributed.is_shared_filesystem(strategy, path=None, timeout=3)[source]

Checks whether the filesystem under the given path is shared across all processes.

This function should only be used in a context where distributed is initialized.

  • strategy (Strategy) – The strategy being used, either from Fabric (fabric.strategy) or from Trainer (trainer.strategy).

  • path (Union[str, Path, None]) – The path to check. Defaults to the current working directory. The user must have permissions to write to this path or the parent folder, and the filesystem must be writable.

  • timeout (int) – If any of the processes can’t list the file created by rank 0 within this many seconds, the filesystem is determined to be not shared.
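The timeout behavior described above can be illustrated with a single-process sketch of the polling half of such a check: one process creates a marker file, and every other process waits up to `timeout` seconds for it to appear. The names `visible_within` and `poll_interval` are hypothetical, and the step that combines the per-process results across ranks is omitted.

```python
import tempfile
import time
from pathlib import Path


def visible_within(path, timeout=3.0, poll_interval=0.1):
    """Return True if `path` shows up before `timeout` seconds elapse."""
    deadline = time.monotonic() + timeout
    while True:
        if Path(path).exists():
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(poll_interval)


with tempfile.TemporaryDirectory() as tmp:
    marker = Path(tmp) / ".shared_fs_check"
    missing = visible_within(marker, timeout=0.2)  # nothing has been created yet
    marker.touch()                                 # what rank 0 would do
    found = visible_within(marker, timeout=0.2)    # the file is now visible
```

On a filesystem that is not shared, the non-zero ranks would stay in the first situation until the timeout expires, which is why the check needs write access to the path.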

Return type:

bool



lightning.fabric.utilities.warnings.disable_possible_user_warnings(module='')[source]

Ignore warnings of the category PossibleUserWarning from Lightning.

For more granular control over which warnings to ignore, use warnings.filterwarnings() directly.


Parameters:

  • module (str) – Name of the module for which the warnings should be ignored (e.g., 'lightning.pytorch.strategies'). Default: Disables warnings from all modules.
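For the more granular control mentioned above, the same effect can be approximated with `warnings.filterwarnings()` directly. The sketch below uses a locally defined stand-in category (`PossibleUserWarningStandIn`, hypothetical) so that it runs without Lightning installed; in real code you would pass Lightning's `PossibleUserWarning` class instead.

```python
import warnings


class PossibleUserWarningStandIn(UserWarning):
    """Local stand-in for Lightning's PossibleUserWarning category."""


def ignore_stand_in_warnings(module=""):
    # An empty `module` pattern matches every module; otherwise it is a
    # regular expression matched against the start of the module name.
    warnings.filterwarnings("ignore", module=module, category=PossibleUserWarningStandIn)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # record everything unless explicitly ignored
    ignore_stand_in_warnings()
    warnings.warn("possibly a mistake", PossibleUserWarningStandIn)  # suppressed
    warnings.warn("unrelated", UserWarning)                          # still recorded
```

Only the unrelated `UserWarning` ends up in `caught`, because the filter targets the stand-in category and not its parent class.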

Return type:

None


lightning.fabric.utilities.throughput.measure_flops(model, forward_fn, loss_fn=None)[source]

Utility to compute the total number of FLOPs used by a module during training or during inference.

It’s recommended to create a meta-device model for this:


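import torch
from lightning.fabric.utilities.throughput import measure_flops

# Creating the model on the meta device avoids allocating real memory.
with torch.device("meta"):
    model = MyModel()  # MyModel stands in for your own torch.nn.Module
    x = torch.randn(2, 32)

# Forward-only FLOPs
model_fwd = lambda: model(x)
fwd_flops = measure_flops(model, model_fwd)

# Forward, loss, and backward FLOPs
model_loss = lambda y: y.sum()
fwd_and_bwd_flops = measure_flops(model, model_fwd, model_loss)

Parameters:

  • model (Module) – The model whose FLOPs should be measured.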

  • forward_fn (Callable[[], Tensor]) – A function that runs forward on the model and returns the result.

  • loss_fn (Optional[Callable[[Tensor], Tensor]]) – A function that computes the loss given the forward_fn output. If provided, the loss and backward FLOPs will be included in the result.

Return type:

int



class lightning.fabric.utilities.data.AttributeDict[source]

Bases: Dict

A container to store state variables of your program.

This is a drop-in replacement for a Python dictionary, with the additional functionality to access and modify keys through attribute lookup for convenience.

Use this to define the state of your program, then pass it to save() and load().
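The attribute-style access can be understood through a minimal dict subclass. This is a sketch of the general technique, not Lightning's implementation; the class name `AttrDictSketch` is hypothetical.

```python
class AttrDictSketch(dict):
    """Hypothetical minimal dict with attribute-style access to its keys."""

    def __getattr__(self, key):
        # Only called when normal attribute lookup fails, so dict methods still work.
        try:
            return self[key]
        except KeyError as err:
            raise AttributeError(f"Missing attribute '{key}'") from err

    def __setattr__(self, key, value):
        # Attribute writes are stored as regular dictionary items.
        self[key] = value


state = AttrDictSketch(iter_num=0)
state.iter_num += 1  # attribute write updates the underlying key
state["lr"] = 0.1    # plain dict access keeps working
```

Because the object is still a `dict`, it can be passed anywhere a dictionary of state is expected.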


>>> import torch
>>> model = torch.nn.Linear(2, 2)
>>> state = AttributeDict(model=model, iter_num=0)
>>> state.model
Linear(in_features=2, out_features=2, bias=True)
>>> state.iter_num += 1
>>> state.iter_num
1
>>> state
"iter_num": 1
"model":    Linear(in_features=2, out_features=2, bias=True)

class lightning.fabric.utilities.throughput.ThroughputMonitor(fabric, **kwargs)[source]

Bases: Throughput

Computes throughput.

This class automatically keeps a count of the number of log calls (step), but this count can be overridden as desired. For manual logging, consider using Throughput directly.


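from time import time
import torch
from lightning.fabric import Fabric
from lightning.fabric.utilities.throughput import ThroughputMonitor

logger = ...
fabric = Fabric(logger=logger)
throughput = ThroughputMonitor(fabric)
t0 = time()
for i in range(1, 100):
    ...  # do the work for one batch here
    if torch.cuda.is_available(): torch.cuda.synchronize()  # required or else time() won't be correct
    throughput.update(time=time() - t0, batches=i, samples=i)
    if i % 10 == 0:
        throughput.compute_and_log(step=i)

compute_and_log(step=None, **kwargs)[source]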

See Throughput.compute()

Return type:

Dict[str, Union[int, float]]

class lightning.fabric.utilities.throughput.Throughput(available_flops=None, world_size=1, window_size=100, separator='/')[source]

Bases: object

Computes throughput.



Key | Value

batches_per_sec | Rolling average (over window_size most recent updates) of the number of batches processed per second

samples_per_sec | Rolling average (over window_size most recent updates) of the number of samples processed per second

items_per_sec | Rolling average (over window_size most recent updates) of the number of items processed per second

flops_per_sec | Rolling average (over window_size most recent updates) of the number of flops processed per second

device/batches_per_sec | batches_per_sec divided by world size

device/samples_per_sec | samples_per_sec divided by world size

device/items_per_sec | items_per_sec divided by world size. This may include padding depending on the data

device/flops_per_sec | flops_per_sec divided by world size

device/mfu | device/flops_per_sec divided by available_flops

time | Total elapsed time

batches | Total batches seen

samples | Total samples seen

lengths | Total items seen


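from time import time
import torch
from lightning.fabric.utilities.throughput import Throughput

throughput = Throughput()
t0 = time()
for i in range(1000):
    ...  # do the work for one batch here
    if torch.cuda.is_available(): torch.cuda.synchronize()  # required or else time() won't be correct
    throughput.update(time=time() - t0, batches=i, samples=i)
    if i % 10 == 0:
        print(throughput.compute())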


Notes:

  • The implementation assumes that all devices have the same FLOPs, since it normalizes by the world size and takes only a single available_flops value.

  • items_per_sec, flops_per_sec and MFU do not account for padding if present. We suggest using samples_per_sec or batches_per_sec to measure throughput under this circumstance.

Parameters:

  • available_flops (Optional[float]) – Number of theoretical flops available for a single device.

  • world_size (int) – Number of devices available across hosts. Global metrics are not included if the world size is 1.

  • window_size (int) – Number of batches to use for a rolling average.

  • separator (str) – Key separator to use when creating per-device and global metrics.
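How window_size and world_size interact can be illustrated with a small rolling-window tracker. This is a simplified sketch under stated assumptions, not Lightning's implementation: the class `RollingRateSketch` is hypothetical, it tracks samples only, and it assumes the totals passed to `update` are cumulative per-device values.

```python
from collections import deque


class RollingRateSketch:
    """Hypothetical rolling-window rate tracker for per-device sample counts."""

    def __init__(self, window_size=100, world_size=1):
        self.world_size = world_size
        # Each deque keeps only the `window_size` most recent updates.
        self._time = deque(maxlen=window_size)
        self._samples = deque(maxlen=window_size)

    def update(self, *, time, samples):
        # Both values are cumulative totals, as in the documented update().
        self._time.append(time)
        self._samples.append(samples)

    def compute(self):
        if len(self._time) < 2:
            return {}  # a rate needs at least two points
        elapsed = self._time[-1] - self._time[0]
        if elapsed <= 0:
            return {}
        per_device = (self._samples[-1] - self._samples[0]) / elapsed
        metrics = {"device/samples_per_sec": per_device}
        if self.world_size > 1:
            # Global metrics are only reported when there is more than one device.
            metrics["samples_per_sec"] = per_device * self.world_size
        return metrics


tracker = RollingRateSketch(window_size=3, world_size=2)
for t, n in [(1.0, 10), (2.0, 20), (3.0, 40), (4.0, 80)]:
    tracker.update(time=t, samples=n)
result = tracker.compute()  # window covers t=2..4: (80 - 20) / (4 - 2) = 30 per device
```

A small window reacts quickly to changes in speed, while a large one smooths out noisy iterations.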


compute()[source]

Compute throughput metrics.

Return type:

Dict[str, Union[int, float]]

update(*, time, batches, samples, lengths=None, flops=None)[source]

Update throughput metrics.

Parameters:

  • time (float) – Total elapsed time in seconds. It should monotonically increase by the iteration time with each call.

  • batches (int) – Total batches seen per device. It should monotonically increase with each call.

  • samples (int) – Total samples seen per device. It should monotonically increase by the batch size with each call.

  • lengths (Optional[int]) – Total length of the samples seen. It should monotonically increase by the lengths of a batch with each call.

  • flops (Optional[int]) – Flops elapsed per device since the last update() call. You can easily compute this by using measure_flops() and multiplying it by the number of batches that have been processed. The value might be different on each device if the batch size is not the same.
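The arithmetic behind the flops argument, and the utilization figure that can be derived from it, can be sketched as follows. Both helper names are hypothetical and all numbers are illustrative, not measurements. The utilization formula assumes the conventional definition of MFU: achieved FLOPs per second divided by the theoretical peak of one device.

```python
def flops_since_last_update(flops_per_batch, batches_since_last_update):
    """FLOPs to report in one update, given a per-batch FLOP count."""
    return flops_per_batch * batches_since_last_update


def model_flops_utilization(flops_per_sec, available_flops):
    """Achieved FLOPs per second as a fraction of the theoretical peak."""
    return flops_per_sec / available_flops


# Illustrative numbers only.
flops_per_batch = 6_000_000_000  # e.g. a value obtained from measure_flops()
flops = flops_since_last_update(flops_per_batch, batches_since_last_update=10)
elapsed_seconds = 2.0
mfu = model_flops_utilization(flops / elapsed_seconds, available_flops=100_000_000_000)
```

Here 10 batches at 6 GFLOPs each take 2 seconds, giving 30 GFLOP/s, which is 30% of a hypothetical 100 GFLOP/s peak.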

Return type:

None