thunder.distributed.fsdp¶
- thunder.distributed.fsdp(model, *, device=None, broadcast_from=None, sharding_strategy=FSDPType.ZERO2, bucketing_strategy=FSDPBucketingStrategy.NONE, move_state_dict_to_cpu=None)[source]¶
Convert model into Fully Sharded Data Parallel.
This splits model’s parameters in their first dimension into world_size chunks, then has rank i host the i-th chunk of them. This means the implementation is different from torch.distributed.fsdp.FullyShardedDataParallel, which creates what’s called torch.distributed.fsdp._flat_param.FlatParameter as of https://github.com/pytorch/pytorch/tree/647f14e7. PyTorch, however, seems to be interested in per-parameter sharding as per https://github.com/pytorch/pytorch/issues/114299.
To apply bucketing of collective communications, specify either FSDPBucketingStrategy.LAYER or FSDPBucketingStrategy.BLOCK as bucketing_strategy. The latter uses one collective communication, be it AllGather to unshard parameters or ReduceScatter to shard gradients, per Transformer block. The former uses one per layer, such as torch.nn.Linear and torch.nn.LayerNorm.
See the FSDP Tutorial for how parameters are sharded across devices and how communication calls are inserted.
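A minimal usage sketch is shown below. It assumes a torchrun-style launch with one process per GPU; the import path thunder.distributed for FSDPType and FSDPBucketingStrategy and the follow-up thunder.jit compilation step reflect common usage and are assumptions here, not part of this function’s contract.

```python
# Sketch: shard a small model across ranks with thunder.distributed.fsdp.
# Assumes a torchrun-style launch (LOCAL_RANK set, one process per GPU).
import os

import torch
import torch.distributed as dist

import thunder
from thunder.distributed import fsdp, FSDPBucketingStrategy, FSDPType


def main():
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")

    # A self-contained stand-in for a real model.
    model = torch.nn.Sequential(
        torch.nn.Linear(512, 512),
        torch.nn.GELU(),
        torch.nn.Linear(512, 512),
    )

    # Parameters are split along their first dimension into world_size chunks;
    # BLOCK bucketing issues one AllGather/ReduceScatter per Transformer block
    # (for this toy Sequential model the effect is close to no bucketing).
    sharded = fsdp(
        model,
        device=torch.cuda.current_device(),
        sharding_strategy=FSDPType.ZERO2,
        bucketing_strategy=FSDPBucketingStrategy.BLOCK,
    )

    # Compiling with thunder.jit is the usual next step (an assumption, not
    # part of fsdp() itself); the returned module is then run like any other.
    jitted = thunder.jit(sharded)
    out = jitted(torch.randn(8, 512, device="cuda"))

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```

Launched, for example, with torchrun --nproc_per_node=2 example.py.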
- Parameters:
model¶ (torch.nn.modules.module.Module | thunder.core.module.ThunderModule) – The model to convert.
- Keyword Arguments:
device¶ – The corresponding model shard will be moved to this device. We recommend setting this to torch.cuda.current_device().
broadcast_from¶ – The rank of the device hosting the parameters to broadcast. If None is passed, broadcasting is skipped (default). Enabling this can be useful for models whose weights have been loaded from a checkpoint on a single rank; see the checkpoint-loading sketch below.
sharding_strategy¶ (FSDPType) – The sharding strategy to use. Defaults to FSDPType.ZERO2.
bucketing_strategy¶ (FSDPBucketingStrategy) – How collective communications are bucketed. Defaults to FSDPBucketingStrategy.NONE; see the description above for LAYER and BLOCK.
move_state_dict_to_cpu¶ – Move all-gather’ed parameters of original_state_dict() to CPU as each all-gather is finished.
- Returns:
The converted model.
- Return type:
Module | ThunderModule
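The broadcast_from keyword pairs naturally with loading a checkpoint on a single rank. A hedged sketch follows; the path checkpoint.pt and the surrounding process-group setup are assumptions for illustration only.

```python
# Sketch: materialize checkpoint weights on rank 0 only, then let fsdp()
# broadcast them to the other ranks before sharding. "checkpoint.pt" is a
# hypothetical path; process-group initialization is assumed to have happened.
import torch
import torch.distributed as dist

from thunder.distributed import fsdp


def load_and_shard(model: torch.nn.Module) -> torch.nn.Module:
    if dist.get_rank() == 0:
        # Only rank 0 touches the checkpoint on disk.
        state = torch.load("checkpoint.pt", map_location="cpu")
        model.load_state_dict(state)

    # broadcast_from=0 replicates rank 0's weights to every rank before the
    # parameters are sharded, so the other ranks may keep their random init.
    return fsdp(
        model,
        device=torch.cuda.current_device(),
        broadcast_from=0,
    )
```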