thunder.distributed.fsdp

thunder.distributed.fsdp(model, *, device=None, broadcast_from=None, sharding_strategy=FSDPType.ZERO2, bucketing_strategy=FSDPBucketingStrategy.NONE, move_state_dict_to_cpu=None)[source]

Apply Fully Sharded Data Parallel (FSDP) to model.

This splits each of model’s parameters along its first dimension into world_size chunks and has rank i host the i-th chunk of each. This makes the implementation different from torch.distributed.fsdp.FullyShardedDataParallel, which creates what is called a torch.distributed.fsdp._flat_param.FlatParameter as of https://github.com/pytorch/pytorch/tree/647f14e7. PyTorch, however, appears to be moving toward per-parameter sharding, as per https://github.com/pytorch/pytorch/issues/114299.
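As an illustration only (not Thunder’s actual implementation), the sketch below shows the per-parameter scheme described above: a parameter is chunked along its first dimension and rank i keeps the i-th chunk. shard_param is a hypothetical helper:

    import torch

    def shard_param(param: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
        # Split along dim 0 into world_size chunks; rank i keeps chunk i.
        # Note: torch.chunk may produce unevenly sized (or fewer) chunks when
        # the first dimension is not divisible by world_size, so a real
        # implementation would need padding.
        chunks = torch.chunk(param, world_size, dim=0)
        return chunks[rank].clone()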

To apply bucketing to the collective communications, pass either FSDPBucketingStrategy.LAYER or FSDPBucketingStrategy.BLOCK as bucketing_strategy. The latter uses one collective communication, be it an AllGather to unshard parameters or a ReduceScatter to shard gradients, per Transformer block. The former uses one per layer, such as torch.nn.Linear and torch.nn.LayerNorm.
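For example, assuming fsdp and FSDPBucketingStrategy are both importable from thunder.distributed, block-level bucketing could be requested like this:

    from thunder.distributed import fsdp, FSDPBucketingStrategy

    # One AllGather / ReduceScatter per Transformer block ("model" defined elsewhere).
    sharded_model = fsdp(model, bucketing_strategy=FSDPBucketingStrategy.BLOCK)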

See the FSDP Tutorial for how parameters are sharded across devices and how communication calls are inserted.

Parameters:
  • model – The model to convert.

Keyword Arguments:
  • device – The corresponding model shard will be moved to this device. We recommend setting this to torch.cuda.current_device().

  • broadcast_from – The rank of the process hosting the parameters to broadcast. If None (the default), broadcasting is skipped. This is useful for a model whose weights have been loaded from a checkpoint on a single rank.

  • sharding_strategy – The FSDPType to apply; defaults to FSDPType.ZERO2.

  • bucketing_strategy – The FSDPBucketingStrategy controlling how collective communications are bucketed (see above); defaults to FSDPBucketingStrategy.NONE.

  • move_state_dict_to_cpu – Move all-gathered parameters of original_state_dict() to CPU as each all-gather is finished.

Returns:

torch.nn.Module

Return type:

Module | ThunderModule
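A minimal end-to-end sketch, assuming a torchrun launch, a hypothetical MyTransformer module and checkpoint path, and that the sharded module is subsequently compiled with thunder.jit:

    import os

    import torch
    import torch.distributed as dist

    import thunder
    from thunder.distributed import fsdp

    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    model = MyTransformer()  # hypothetical model definition
    if dist.get_rank() == 0:
        # Weights loaded on a single rank only; broadcast_from replicates them.
        model.load_state_dict(torch.load("checkpoint.pt"))

    sharded = fsdp(
        model,
        device=torch.cuda.current_device(),
        broadcast_from=0,
    )
    jitted = thunder.jit(sharded)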