thunder.distributed.fsdp¶
- thunder.distributed.fsdp(model, *, device=None, broadcast_from=None, sharding_strategy=FSDPType.ZERO2, bucketing_strategy=FSDPBucketingStrategy.NONE, move_state_dict_to_cpu=None)[source]¶
Convert model into Fully Sharded Data Parallel.
This splits model's parameters in their first dimension into world_size chunks, then has rank i host the i-th chunk of each. This means the implementation is different from torch.distributed.fsdp.FullyShardedDataParallel, which creates what's called torch.distributed.fsdp._flat_param.FlatParameter as of https://github.com/pytorch/pytorch/tree/647f14e7. PyTorch, however, seems to be moving toward per-parameter sharding, as per https://github.com/pytorch/pytorch/issues/114299.
To apply bucketing of collective communications, specify either FSDPBucketingStrategy.LAYER or FSDPBucketingStrategy.BLOCK as bucketing_strategy. The latter uses one collective communication, be it AllGather to unshard parameters or ReduceScatter to shard gradients, per Transformer block. The former uses one per layer, such as torch.nn.Linear and torch.nn.LayerNorm.
See the FSDP Tutorial for how parameters are sharded across devices and how communication calls are inserted.
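A minimal usage sketch follows (not part of the API reference itself). It assumes a standard torchrun launch with the NCCL backend, that fsdp and FSDPBucketingStrategy are importable from thunder.distributed as the signature above suggests, and that the sharded module is then compiled with thunder.jit; the ordering of fsdp and thunder.jit is an assumption.

# Hypothetical usage sketch; launch with: torchrun --nproc-per-node=2 train.py
import os

import torch
import torch.distributed as dist

import thunder
from thunder.distributed import fsdp, FSDPBucketingStrategy  # assumed import location


def main():
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = torch.nn.Sequential(
        torch.nn.Linear(4096, 4096),
        torch.nn.GELU(),
        torch.nn.Linear(4096, 4096),
    )

    # Shard each parameter along its first dimension across all ranks, move this
    # rank's shard to the current device, and bucket collectives per layer.
    sharded = fsdp(
        model,
        device=torch.cuda.current_device(),
        bucketing_strategy=FSDPBucketingStrategy.LAYER,
    )

    # Compile the sharded module with Thunder (this ordering is an assumption;
    # the signature above also accepts an already-jitted ThunderModule).
    tmodel = thunder.jit(sharded)

    x = torch.randn(8, 4096, device="cuda")
    tmodel(x).sum().backward()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()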
- Parameters:
model¶ (torch.nn.modules.module.Module | thunder.core.module.ThunderModule) – The model to convert.
sharding_strategy (FSDPType) – See Keyword Arguments.
bucketing_strategy (FSDPBucketingStrategy) – See Keyword Arguments.
- Keyword Arguments:
device¶ – The corresponding model shard will be moved to this device. We recommend setting this to torch.cuda.current_device().
broadcast_from¶ – The rank of the device hosting the parameters to broadcast. If None is passed, broadcasting will be skipped (default). Enabling this can be useful for models whose weights have been loaded from a checkpoint on a single rank.
sharding_strategy¶ – The FSDPType to use. Defaults to FSDPType.ZERO2.
bucketing_strategy¶ – The FSDPBucketingStrategy to use. Defaults to FSDPBucketingStrategy.NONE; see the description above for LAYER and BLOCK.
move_state_dict_to_cpu¶ – Move all-gather'ed parameters of original_state_dict() to CPU as each all-gather is finished. (A sketch illustrating broadcast_from and move_state_dict_to_cpu follows the field list below.)
- Returns:
- The model converted for Fully Sharded Data Parallel execution.
- Return type:
- Module | ThunderModule
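For loading a checkpoint on a single rank and offloading all-gathered parameters, here is a hedged sketch; the checkpoint path, the import location of fsdp, and the fsdp/thunder.jit ordering are assumptions, and original_state_dict() is the method referenced in the move_state_dict_to_cpu entry above.

# Hypothetical sketch for the broadcast_from and move_state_dict_to_cpu keyword arguments.
import torch
import torch.distributed as dist

import thunder
from thunder.distributed import fsdp  # assumed import location


def build_sharded_model(model: torch.nn.Module):
    # Only rank 0 loads real weights from a checkpoint (hypothetical path); the
    # other ranks keep their initialization until fsdp broadcasts rank 0's weights.
    if dist.get_rank() == 0:
        model.load_state_dict(torch.load("checkpoint.pt", map_location="cpu"))

    sharded = fsdp(
        model,
        device=torch.cuda.current_device(),
        broadcast_from=0,             # broadcast rank 0's parameters to all ranks
        move_state_dict_to_cpu=True,  # offload all-gathered params to CPU as each
                                      # all-gather for original_state_dict() finishes
    )
    return thunder.jit(sharded)


# Later, to recover the full (unsharded) state dict without holding every
# all-gathered parameter on the GPU (method name taken from the entry above):
# full_state_dict = tmodel.original_state_dict()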