DDPFullyShardedStrategy

class pytorch_lightning.strategies.DDPFullyShardedStrategy(accelerator=None, cpu_offload=False, flatten_parameters=True, reshard_after_forward=True, move_grads_to_cpu=None, fp32_reduce_scatter=None, compute_dtype=None, bucket_cap_mb=25, min_num_params=100000000, state_dict_to_cpu=True, parallel_devices=None, cluster_environment=None, checkpoint_io=None, precision_plugin=None, process_group_backend=None)

Bases: pytorch_lightning.strategies.ddp.DDPStrategy

Strategy for Fully Sharded Data Parallel training provided by FairScale.

Warning

DDPFullyShardedStrategy is in beta and subject to change.

Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model size while using efficient communication to reduce overhead. In practice, this means throughput can remain roughly on par with PyTorch DDP while the maximum trainable model size grows substantially. The technique is similar to ZeRO Stage 3, but has been built with upstreaming to PyTorch in mind.

For more information, see FairScale’s documentation.
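As a rough mental model of what fully sharded training does with the parameters, the pure-Python sketch below flattens them into one list, pads it, and gives each rank one equal, contiguous shard; an all-gather rebuilds the full list when a layer needs it. The helper names `shard_flat_params` and `all_gather` are hypothetical, Python lists stand in for tensors, and no real communication takes place. FairScale’s actual implementation differs in many details.

```python
"""Toy illustration of fully sharded parameters (ZeRO Stage 3 style).

Assumption: parameters are flattened into one list of floats, zero-padded so
the length divides evenly by world_size, and each rank keeps one contiguous
shard. Plain lists stand in for tensors and for collective communication.
"""
from typing import List


def shard_flat_params(flat: List[float], world_size: int) -> List[List[float]]:
    """Split a flat parameter list into world_size equal shards (zero-padded)."""
    if world_size <= 0:
        raise ValueError("world_size must be positive")
    shard_len = -(-len(flat) // world_size)  # ceiling division
    padded = flat + [0.0] * (shard_len * world_size - len(flat))
    return [padded[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]


def all_gather(shards: List[List[float]], numel: int) -> List[float]:
    """Simulate an all-gather: concatenate every rank's shard, drop the padding."""
    full = [x for shard in shards for x in shard]
    return full[:numel]


if __name__ == "__main__":
    params = [float(i) for i in range(10)]  # 10 "parameters"
    shards = shard_flat_params(params, world_size=4)
    # Each rank stores only ceil(10 / 4) = 3 values instead of all 10.
    print([len(s) for s in shards])  # [3, 3, 3, 3]
    # The full parameters are gathered when needed for computation and can be
    # freed again afterwards (cf. reshard_after_forward), leaving only the shard.
    assert all_gather(shards, len(params)) == params
```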

Sensible defaults have been set and the main options are exposed, but you may need to tune them depending on how you want to trade off memory efficiency against speed. We suggest having a look at this PR for more information.

Many of the parameter descriptions below come from the original FairScale documentation.

Parameters:
  • cpu_offload (bool) – Offload FP32 parameters to CPU. Only usable with precision=16. (Default: False).

  • move_grads_to_cpu (Optional[bool]) – Moves gradient shards to CPU after reduction. This is useful when combined with CPU-based optimizers. (Default: the value of cpu_offload).

  • flatten_parameters (bool) – Flattens parameters into a single contiguous tensor for speed efficiency. (Default: True).

  • reshard_after_forward (bool) – Reshard parameters after the forward pass, which saves memory but slows down training. This is only relevant when resharding individual layers. (Default: True).

  • fp32_reduce_scatter (Optional[bool]) – Reduce-Scatter gradients in FP32. Only relevant in mixed precision (Default: None).

  • compute_dtype (Optional[dtype]) – dtype of the full parameters used for computation. Defaults to torch.float32, or to torch.float16 when using mixed precision. (Default: None).

  • bucket_cap_mb (int) – Bucket parameters so that gradient reduction can overlap with backward computation. bucket_cap_mb controls the bucket size in megabytes (MB). Buckets are sub-divided based on world_size, so the maximum shard size is roughly bucket_cap_mb / world_size. Values <= 0 disable bucketing. (Default: 25).

  • min_num_params (int) – Minimum number of parameters a submodule must have to be wrapped when using FairScale’s auto_wrap. (Default: 1e8).

  • state_dict_to_cpu (bool) – Whether the parameters returned by state_dict() are placed on the CPU. If False, they are returned on compute_device. (Default: True).
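The sketch below restates some of the parameter descriptions above as small functions: how move_grads_to_cpu and compute_dtype fall back when left as None, the approximate per-rank shard size implied by bucket_cap_mb, and a threshold check in the spirit of min_num_params. All function names are hypothetical (they are not Lightning or FairScale APIs), dtypes are plain strings to avoid a torch dependency, and the `>=` comparison in `should_auto_wrap` is an assumption about FairScale’s default auto-wrap policy.

```python
"""Hypothetical helpers mirroring the documented defaults and size knobs."""
from typing import Optional


def resolve_move_grads_to_cpu(cpu_offload: bool, move_grads_to_cpu: Optional[bool]) -> bool:
    # None means "follow cpu_offload"; an explicit bool always wins.
    return cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu


def resolve_compute_dtype(compute_dtype: Optional[str], mixed_precision: bool) -> str:
    # None means float32, or float16 when training in mixed precision.
    if compute_dtype is not None:
        return compute_dtype
    return "float16" if mixed_precision else "float32"


def approx_shard_mb(bucket_cap_mb: int, world_size: int) -> Optional[float]:
    # Buckets are split across ranks; values <= 0 disable bucketing (None).
    if bucket_cap_mb <= 0:
        return None
    return bucket_cap_mb / world_size


def should_auto_wrap(num_params: int, min_num_params: int = 100_000_000) -> bool:
    # Assumed threshold check: wrap a submodule once its not-yet-wrapped
    # parameter count reaches min_num_params.
    return num_params >= min_num_params


print(resolve_move_grads_to_cpu(cpu_offload=True, move_grads_to_cpu=None))  # True
print(resolve_compute_dtype(None, mixed_precision=True))  # float16
print(approx_shard_mb(25, world_size=8))  # 3.125
print(should_auto_wrap(50_000_000))  # False
```

For example, with the default bucket_cap_mb=25 on 8 GPUs, each rank’s piece of a bucket is roughly 3.125 MB.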

connect(model)

Called by the accelerator to connect the accelerator and the model with this plugin.

Return type:

None

lightning_module_state_dict()

Returns model state.

Return type:

Dict[str, Any]

model_sharded_context()

Provides a hook to create modules in a distributed-aware context. This is useful when you want to shard the model immediately as it is created, which for extremely large models can save memory and initialization time.

Returns: Model parallel context.

Return type:

Generator

model_to_device()

Moves the model to the correct device.

Return type:

None

predict_step(*args, **kwargs)

The actual predict step.

See LightningModule.predict_step() for more details.

Return type:

Union[Tensor, Dict[str, Any]]

setup(trainer)

Sets up the plugins for trainer.fit() and creates the optimizers.

Parameters:

trainer (Trainer) – the trainer instance

Return type:

None

test_step(*args, **kwargs)

The actual test step.

See LightningModule.test_step() for more details.

Return type:

Union[Tensor, Dict[str, Any], None]

training_step(*args, **kwargs)

The actual training step.

See LightningModule.training_step() for more details.

Return type:

Union[Tensor, Dict[str, Any]]

validation_step(*args, **kwargs)

The actual validation step.

See LightningModule.validation_step() for more details.

Return type:

Union[Tensor, Dict[str, Any], None]