XLAStrategy

class lightning.fabric.strategies.XLAStrategy(accelerator=None, parallel_devices=None, checkpoint_io=None, precision=None, sync_module_states=True)

Bases: ParallelStrategy

Strategy for training on multiple TPU devices using the torch_xla.distributed.xla_multiprocessing.spawn() method.

_configure_launcher()

Attaches the launcher appropriate for this strategy.

Return type:

None

all_gather(tensor, group=None, sync_grads=False)

Gathers a tensor from several distributed processes.

Parameters:
  • tensor (Tensor) – tensor to all-gather.

  • group (Optional[Any]) – unused.

  • sync_grads (bool) – flag that allows users to synchronize gradients for the all-gather operation.

Return type:

Tensor

Returns:

A tensor of shape (world_size, …)

all_reduce(output, group=None, reduce_op=None)

Reduces the given tensor across processes (e.g., across TPU devices).

Parameters:
  • output – the tensor to sync and reduce

  • group (Optional[Any]) – the process group to reduce

  • reduce_op (Union[ReduceOp, str, None]) – the reduction operation. Defaults to ‘mean’. Can also be a string ‘sum’ or ReduceOp.

Return type:

Tensor
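The difference between the 'sum' and 'mean' reductions can be sketched without any devices. The following is a hypothetical plain-Python simulation, not the strategy's code: `simulated_all_reduce` is an illustrative helper, tensors are flat lists of floats, and a None reduce_op is mapped to 'mean' only to mirror the parameter description above. Values are summed element-wise across simulated ranks, and 'mean' then divides by the world size.

```python
from typing import List, Optional


def simulated_all_reduce(
    per_rank_values: List[List[float]], reduce_op: Optional[str] = "mean"
) -> List[float]:
    """Element-wise reduction of equally shaped per-rank vectors."""
    # Treat None as the documented default, 'mean' (an assumption of this sketch).
    op = "mean" if reduce_op is None else reduce_op.lower()
    if op not in ("sum", "mean"):
        raise ValueError(f"unsupported reduce_op: {reduce_op!r}")
    # Sum position by position across ranks.
    summed = [sum(column) for column in zip(*per_rank_values)]
    if op == "mean":
        world_size = len(per_rank_values)
        return [x / world_size for x in summed]
    return summed


per_rank = [[1.0, 2.0], [3.0, 4.0]]  # two simulated ranks
print(simulated_all_reduce(per_rank, "sum"))  # [4.0, 6.0]
print(simulated_all_reduce(per_rank))         # [2.0, 3.0]
```

In a real all-reduce every rank ends up holding the same reduced result. The simulation returns it once for brevity.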

barrier(name=None, *args, **kwargs)

Synchronizes all processes, blocking each one until the whole group has entered this function.

Parameters:

name (Optional[str]) – an optional name to pass to the barrier.

Return type:

None
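The blocking guarantee of a barrier can be demonstrated with Python's standard threading.Barrier. In this sketch threads stand in for processes, and no XLA calls are involved. Every worker records a "before" event, waits at the barrier, and then records an "after" event. Because no worker can pass the barrier until all have arrived, every "before" event precedes every "after" event.

```python
import threading
from typing import List

WORLD_SIZE = 4
barrier = threading.Barrier(WORLD_SIZE)
events: List[str] = []
lock = threading.Lock()


def worker(rank: int) -> None:
    with lock:
        events.append(f"before:{rank}")
    # Block until all WORLD_SIZE workers have reached this point.
    barrier.wait()
    with lock:
        events.append(f"after:{rank}")


threads = [threading.Thread(target=worker, args=(r,)) for r in range(WORLD_SIZE)]
for t in threads:
    t.start()
for t in threads:
    t.join()

last_before = max(i for i, e in enumerate(events) if e.startswith("before"))
first_after = min(i for i, e in enumerate(events) if e.startswith("after"))
print(last_before < first_after)  # True
```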

broadcast(obj, src=0)

Broadcasts an object to all processes.

Parameters:
  • obj (TypeVar(TBroadcast)) – the object to broadcast

  • src (int) – source rank

Return type:

TypeVar(TBroadcast)
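The effect of broadcasting from rank src can be modelled as follows. This is a hypothetical simulation: `simulated_broadcast` is an illustrative helper, and the per-rank deep copy is an assumption meant to mirror the fact that separate processes never share the same Python object. After the call, every simulated rank holds a copy of the object that the source rank held, and whatever the other ranks passed in is discarded.

```python
import copy
from typing import List, TypeVar

T = TypeVar("T")


def simulated_broadcast(per_rank_objects: List[T], src: int = 0) -> List[T]:
    """Return what each simulated rank holds after broadcasting from `src`."""
    if not 0 <= src < len(per_rank_objects):
        raise ValueError(f"src={src} is not a valid rank")
    # Each rank receives its own copy of the source rank's object.
    return [copy.deepcopy(per_rank_objects[src]) for _ in per_rank_objects]


objs = [{"seed": 42}, None, None]  # only rank 0 holds meaningful data
result = simulated_broadcast(objs, src=0)
print(result)  # [{'seed': 42}, {'seed': 42}, {'seed': 42}]
```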

module_to_device(module)

Moves the model to the correct device.

Return type:

None

process_dataloader(dataloader)

Wraps the dataloader if necessary.

Parameters:

dataloader (DataLoader) – an iterable, ideally of type torch.utils.data.DataLoader.

Return type:

MpDeviceLoader
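The returned MpDeviceLoader wraps the original loader so that batches arrive on the XLA device. The sketch below shows only the general wrapping pattern, in plain Python. `DeviceLoaderSketch` is a hypothetical class, and the `to_device` callable stands in for the real host-to-device transfer. The actual MpDeviceLoader may do more, for example loading batches in the background, and none of that is modelled here.

```python
from typing import Any, Callable, Iterable, Iterator, Sized


class DeviceLoaderSketch:
    """Hypothetical wrapper that moves each batch to a device on iteration."""

    def __init__(self, loader: Iterable[Any], to_device: Callable[[Any], Any]) -> None:
        self.loader = loader
        self.to_device = to_device

    def __iter__(self) -> Iterator[Any]:
        # Transfer each batch lazily, only when it is requested.
        for batch in self.loader:
            yield self.to_device(batch)

    def __len__(self) -> int:
        # Assumes the wrapped loader is sized, as a DataLoader usually is.
        assert isinstance(self.loader, Sized)
        return len(self.loader)


batches = [[1, 2], [3, 4]]
wrapped = DeviceLoaderSketch(batches, to_device=lambda b: ("xla:0", b))
print(list(wrapped))  # [('xla:0', [1, 2]), ('xla:0', [3, 4])]
```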

save_checkpoint(path, state, storage_options=None, filter=None)

Save model, optimizer, and other state as a checkpoint file.

Parameters:
  • path (Union[str, Path]) – A path to where the file(s) should be saved

  • state (Dict[str, Union[Module, Optimizer, Any]]) – A dictionary with contents to be saved. If the dict contains modules or optimizers, their state-dict will be retrieved and converted automatically.

  • storage_options (Optional[Any]) – Additional options for the CheckpointIO plugin

  • filter (Optional[Dict[str, Callable[[str, Any], bool]]]) – An optional dictionary of the same format as state mapping keys to callables that return a boolean indicating whether the given parameter should be saved (True) or filtered out (False).

Return type:

None
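The filter argument mirrors the keys of state, and each callable receives a parameter name and value and returns whether to keep that entry. The following plain-Python sketch shows how such a filter could be applied. `apply_filter` is a hypothetical helper, not the strategy's internal code. Modules are represented by already-extracted state dicts (plain dicts) as a simplifying assumption, and the argument is named `filters` to avoid shadowing Python's built-in filter.

```python
from typing import Any, Callable, Dict


def apply_filter(
    state: Dict[str, Any], filters: Dict[str, Callable[[str, Any], bool]]
) -> Dict[str, Any]:
    """Keep only the entries of each filtered state dict whose predicate returns True."""
    filtered: Dict[str, Any] = {}
    for key, value in state.items():
        if key in filters and isinstance(value, dict):
            keep = filters[key]
            filtered[key] = {name: v for name, v in value.items() if keep(name, v)}
        else:
            # Keys without a filter are saved unchanged.
            filtered[key] = value
    return filtered


state = {"model": {"encoder.weight": 1, "decoder.weight": 2}, "step": 100}
filters = {"model": lambda name, value: name.startswith("encoder")}
print(apply_filter(state, filters))  # {'model': {'encoder.weight': 1}, 'step': 100}
```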

setup_environment()

Sets up any processes or distributed connections.

This must be called by the framework at the beginning of every process, before any distributed communication takes place.

Return type:

None

setup_module(module)

Performs setup for the model, e.g., by wrapping it in another class.

Return type:

Module

property root_device: device

Returns the root device.