lightning.fabric.utilities
- lightning.fabric.utilities.seed.seed_everything(seed=None, workers=False, verbose=True)
Function that sets the seed for pseudo-random number generators in: torch, numpy, and Python’s random module. In addition, sets the following environment variables:
PL_GLOBAL_SEED: will be passed to spawned subprocesses (e.g. ddp_spawn backend).
PL_SEED_WORKERS: (optional) is set to 1 if workers=True.
- Parameters:
seed (Optional[int]) – the integer value seed for global random state in Lightning. If None, it will read the seed from the PL_GLOBAL_SEED env variable. If None and the PL_GLOBAL_SEED env variable is not set, then the seed defaults to 0.
workers (bool) – if set to True, will properly configure all dataloaders passed to the Trainer with a worker_init_fn. If the user already provides such a function for their dataloaders, setting this argument will have no influence. See also: pl_worker_init_function().
verbose (bool) – Whether to print a message on each rank with the seed being set.
- Return type: int
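A minimal usage sketch (the seed value is arbitrary):

from lightning.fabric.utilities.seed import seed_everything

# Seeds torch, numpy and Python's random module; workers=True also configures dataloader workers
seed_everything(42, workers=True)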
- lightning.fabric.utilities.seed.pl_worker_init_function(worker_id, rank=None)
The worker_init_fn that Lightning automatically adds to your dataloader if you previously set the seed with seed_everything(seed, workers=True).
See also the PyTorch documentation on randomness in DataLoaders.
- Return type: None
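For illustration, the function can also be attached to a DataLoader manually (normally Fabric/Trainer wires it up for you when workers=True is passed to seed_everything); a minimal sketch:

import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning.fabric.utilities.seed import seed_everything, pl_worker_init_function

seed_everything(42, workers=True)
dataset = TensorDataset(torch.randn(100, 32))
# Each worker process receives a distinct, reproducible seed derived from the global seed
loader = DataLoader(dataset, batch_size=8, num_workers=2, worker_init_fn=pl_worker_init_function)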
- lightning.fabric.utilities.data.suggested_max_num_workers(local_world_size)
Suggests an upper bound of num_workers to use in a PyTorch DataLoader based on the number of CPU cores available on the system and the number of distributed processes in the current machine.
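A short sketch (the local_world_size of 2 is an assumed per-node process count):

import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning.fabric.utilities.data import suggested_max_num_workers

dataset = TensorDataset(torch.randn(100, 32))
# e.g. 2 training processes (devices) running on this machine
num_workers = suggested_max_num_workers(local_world_size=2)
loader = DataLoader(dataset, batch_size=8, num_workers=num_workers)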
- lightning.fabric.utilities.distributed.is_shared_filesystem(strategy, path=None, timeout=3)
Checks whether the filesystem under the given path is shared across all processes.
This function should only be used in a context where distributed is initialized.
- Parameters:
strategy (Strategy) – The strategy being used, either from Fabric (fabric.strategy) or from Trainer (trainer.strategy).
path (Union[str, Path, None]) – The path to check. Defaults to the current working directory. The user must have permissions to write to this path or the parent folder, and the filesystem must be writable.
timeout (int) – If any of the processes can’t list the file created by rank 0 within this many seconds, the filesystem is determined to be not shared.
- Return type: bool
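A hedged sketch of calling it from a launched Fabric function (the 2-process CPU setup and the path are illustrative assumptions):

from lightning.fabric import Fabric
from lightning.fabric.utilities.distributed import is_shared_filesystem

def main(fabric):
    # Distributed is initialized at this point; with path=None the current working directory is checked
    shared = is_shared_filesystem(fabric.strategy, path="/mnt/checkpoints")
    fabric.print(f"shared filesystem: {shared}")

fabric = Fabric(accelerator="cpu", devices=2)
fabric.launch(main)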
- lightning.fabric.utilities.warnings.disable_possible_user_warnings(module='')
Ignore warnings of the category PossibleUserWarning from Lightning.
For more granular control over which warnings to ignore, use warnings.filterwarnings() directly.
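A minimal sketch; call it once near the top of your script:

from lightning.fabric.utilities.warnings import disable_possible_user_warnings

# Silences PossibleUserWarning messages emitted by Lightning (e.g. hints about suboptimal settings)
disable_possible_user_warnings()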
- lightning.fabric.utilities.throughput.measure_flops(model, forward_fn, loss_fn=None)
Utility to compute the total number of FLOPs used by a module during training or during inference.
It’s recommended to create a meta-device model for this:
Example:
with torch.device("meta"):
    model = MyModel()
    x = torch.randn(2, 32)

model_fwd = lambda: model(x)
fwd_flops = measure_flops(model, model_fwd)

model_loss = lambda y: y.sum()
fwd_and_bwd_flops = measure_flops(model, model_fwd, model_loss)
- Parameters:
model (Module) – The model whose FLOPs should be measured.
forward_fn (Callable) – A function that runs a forward pass of the model and returns the result.
loss_fn (Optional[Callable]) – A function that computes the loss given the forward_fn output. If provided, the backward FLOPs are included in the count.
- Return type: int
- class lightning.fabric.utilities.data.AttributeDict
Bases:
Dict
A container to store state variables of your program.
This is a drop-in replacement for a Python dictionary, with the additional functionality to access and modify keys through attribute lookup for convenience.
Use this to define the state of your program, then pass it to save() and load().
Example
>>> import torch
>>> model = torch.nn.Linear(2, 2)
>>> state = AttributeDict(model=model, iter_num=0)
>>> state.model
Linear(in_features=2, out_features=2, bias=True)
>>> state.iter_num += 1
>>> state.iter_num
1
>>> state
"iter_num": 1
"model":    Linear(in_features=2, out_features=2, bias=True)
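The save() and load() referenced above are Fabric's checkpoint methods; a short sketch (the checkpoint path is illustrative):

import torch
from lightning.fabric import Fabric
from lightning.fabric.utilities.data import AttributeDict

fabric = Fabric()
state = AttributeDict(model=torch.nn.Linear(2, 2), iter_num=0)
fabric.save("checkpoint.ckpt", state)   # serializes every entry of the dict
fabric.load("checkpoint.ckpt", state)   # restores the entries in place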
- class lightning.fabric.utilities.throughput.ThroughputMonitor(fabric, **kwargs)
Bases:
Throughput
Computes throughput.
This class will automatically keep a count of the number of log calls (step). But that can be modified as desired. For manual logging, using Throughput directly might be desired.
Example:
logger = ...
fabric = Fabric(logger=logger)
throughput = ThroughputMonitor(fabric)
t0 = time()
for i in range(1, 100):
    do_work()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # required or else time() won't be correct
    throughput.update(time=time() - t0, batches=i, samples=i)
    if i % 10 == 0:
        throughput.compute_and_log(step=i)
- Parameters:
**kwargs (Any) – See available parameters in Throughput.
- class lightning.fabric.utilities.throughput.Throughput(available_flops=None, world_size=1, window_size=100, separator='/')
Bases:
object
Computes throughput.
Key | Value
batches_per_sec | Rolling average (over window_size most recent updates) of the number of batches processed per second
samples_per_sec | Rolling average (over window_size most recent updates) of the number of samples processed per second
items_per_sec | Rolling average (over window_size most recent updates) of the number of items processed per second
flops_per_sec | Rolling average (over window_size most recent updates) of the number of flops processed per second
device/batches_per_sec | batches_per_sec divided by world size
device/samples_per_sec | samples_per_sec divided by world size
device/items_per_sec | items_per_sec divided by world size. This may include padding depending on the data
device/flops_per_sec | flops_per_sec divided by world size
device/mfu | device/flops_per_sec divided by available_flops
time | Total elapsed time
batches | Total batches seen
samples | Total samples seen
lengths | Total items seen
Example:
throughput = Throughput()
t0 = time()
for i in range(1000):
    do_work()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # required or else time() won't be correct
    throughput.update(time=time() - t0, samples=i)
    if i % 10 == 0:
        print(throughput.compute())
Notes
The implementation assumes that the FLOPs are the same across all devices, since it normalizes by the world size and only takes a single available_flops value.
items_per_sec, flops_per_sec and MFU do not account for padding if present. We suggest using samples_per_sec or batches_per_sec to measure throughput under this circumstance.
- Parameters:
available_flops (Optional[float]) – Number of theoretical flops available for a single device.
world_size (int) – Number of devices available across hosts. Global metrics are not included if the world size is 1.
window_size (int) – Number of batches to use for a rolling average.
separator (str) – Key separator to use when creating per-device and global metrics.
- update(*, time, batches, samples, lengths=None, flops=None)
Update throughput metrics.
- Parameters:
time (float) – Total elapsed time in seconds. It should monotonically increase by the iteration time with each call.
batches (int) – Total batches seen per device. It should monotonically increase with each call.
samples (int) – Total samples seen per device. It should monotonically increase by the batch size with each call.
lengths (Optional[int]) – Total length of the samples seen. It should monotonically increase by the lengths of a batch with each call.
flops (Optional[int]) – Flops elapsed per device since the last update() call. You can easily compute this by using measure_flops() and multiplying it by the number of batches that have been processed (see the sketch below). The value might be different on each device if the batch size is not the same.
- Return type: None
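A hedged sketch of how the flops argument pairs with available_flops to report device/mfu; the linear model, batch shape, and the 312e12 peak-FLOPs figure are illustrative assumptions, not a recommendation:

import time
import torch
from lightning.fabric.utilities.throughput import Throughput, measure_flops

model = torch.nn.Linear(32, 2)
batch = torch.randn(4, 32)
# FLOPs of one forward + backward pass over a single batch
flops_per_batch = measure_flops(model, lambda: model(batch), lambda out: out.sum())

# available_flops is the theoretical peak of one device (e.g. roughly 312e12 for an A100 in bf16)
throughput = Throughput(available_flops=312e12, window_size=50)
t0 = time.time()
for i in range(1, 101):
    model(batch)  # stand-in for the real training step
    throughput.update(
        time=time.time() - t0,   # cumulative elapsed time
        batches=i,               # cumulative batches seen on this device
        samples=i * 4,           # cumulative samples seen on this device
        flops=flops_per_batch,   # flops since the previous update() call (one batch here)
    )
    if i % 10 == 0:
        metrics = throughput.compute()  # contains device/mfu once the rolling window is full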