thunder.plugins.FSDP
- class thunder.plugins.FSDP(device=None, broadcast_from=None, sharding_strategy=FSDPType.ZERO2, bucketing_strategy=FSDPBucketingStrategy.NONE, move_state_dict_to_cpu=False, ddp_bucket_size_in_mb=25.0, process_group=None)
Bases: Plugin
Plugin for enabling Fully Sharded Data Parallel (FSDP) training in Thunder.
This plugin shards model parameters across workers using PyTorch’s FSDP API, optionally combined with DDP for grouped communication in multi-dimensional meshes. It handles initialization broadcasts, parameter materialization, and state dict management according to specified sharding and bucketing strategies.
See https://github.com/pytorch/pytorch/blob/v2.7.0/torch/distributed/fsdp/fully_sharded_data_parallel.py#L117 for more details.
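Conceptually, each worker keeps only a 1/world_size slice of every parameter and rebuilds the full tensor with an all-gather when it is needed for computation. The single-process sketch below simulates that idea with plain Python lists. It assumes flat, zero-padded shards, and the helpers `shard` and `all_gather` are hypothetical names chosen for illustration. It is not Thunder's implementation, which works on torch tensors and real collectives.

```python
from typing import List


def shard(flat: List[float], world_size: int) -> List[List[float]]:
    """Split a flat parameter into world_size equal shards, zero-padding the tail.

    Hypothetical helper for illustration only; not part of Thunder.
    """
    shard_len = -(-len(flat) // world_size)  # ceiling division
    padded = flat + [0.0] * (shard_len * world_size - len(flat))
    return [padded[r * shard_len:(r + 1) * shard_len] for r in range(world_size)]


def all_gather(shards: List[List[float]], numel: int) -> List[float]:
    """Concatenate every rank's shard and drop the padding, as an all-gather would."""
    full = [x for s in shards for x in s]
    return full[:numel]


params = [float(i) for i in range(10)]
shards = shard(params, world_size=4)
print([len(s) for s in shards])  # [3, 3, 3, 3]
print(shards[3])                 # [9.0, 0.0, 0.0]  (last rank holds the padding)
print(all_gather(shards, len(params)) == params)  # True
```

In a real run each rank would hold only its own entry of `shards`. The sharding strategies differ mainly in how long the gathered tensor lives: with a SHARD_GRAD_OP-style strategy it is kept from the forward pass through the backward pass, while a FULL_SHARD-style strategy frees it after use and gathers it again for the backward pass.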
- Parameters:
  - device (torch.device | None, default None) – Device on which to place the sharded modules. If None, modules remain on their existing devices.
  - broadcast_from (int | None, default None) – Global rank from which to broadcast parameters before sharding. If None, no broadcast is performed.
  - sharding_strategy (FSDPType, default FSDPType.ZERO2) – Strategy for parameter sharding. ZERO2 is similar to PyTorch's ShardingStrategy.SHARD_GRAD_OP (gradients and optimizer state are sharded, and gathered parameters are kept through the backward pass); ZERO3 is similar to FULL_SHARD.
  - bucketing_strategy (FSDPBucketingStrategy, default FSDPBucketingStrategy.NONE) – Strategy for grouping parameters into buckets for FSDP collective communication (all-gather and reduce-scatter). NONE disables bucketing.
  - move_state_dict_to_cpu (bool, default False) – Whether to move the state dict parameters to CPU after serialization to reduce GPU memory usage.
  - ddp_bucket_size_in_mb (float, default 25.0) – Bucket size, in megabytes, for the DDP transform when FSDP is combined with DDP on a multi-dimensional mesh.
  - process_group (ProcessGroup | DeviceMesh | None, default None) – Process group or device mesh used for distributed communication. If None, the current default process group is used.
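`ddp_bucket_size_in_mb` caps how many bytes of gradients are grouped into one communication bucket for the DDP part of a combined mesh. A greedy, size-capped grouping like the one below is a common way to form such buckets (similar in spirit to the `bucket_cap_mb` option of PyTorch's DistributedDataParallel). Treat it as an illustrative sketch: the function `bucket_by_size` is hypothetical, and the conversion of 1 MB to 2**20 bytes is an assumption rather than something this page specifies.

```python
from typing import List


def bucket_by_size(numels: List[int], bytes_per_elem: int, bucket_size_in_mb: float) -> List[List[int]]:
    """Greedily group tensor indices into buckets of at most bucket_size_in_mb.

    Hypothetical helper; a tensor that alone exceeds the cap gets its own bucket.
    """
    cap = int(bucket_size_in_mb * 1024 * 1024)  # assumption: 1 MB = 2**20 bytes
    buckets: List[List[int]] = []
    current: List[int] = []
    current_bytes = 0
    for i, n in enumerate(numels):
        nbytes = n * bytes_per_elem
        # Close the current bucket if adding this tensor would overflow it.
        if current and current_bytes + nbytes > cap:
            buckets.append(current)
            current, current_bytes = [], 0
        current.append(i)
        current_bytes += nbytes
    if current:
        buckets.append(current)
    return buckets


# Four fp32 gradients of 4M, 2M, 1M and 8M elements (16, 8, 4 and 32 MiB).
grads = [4 * 2**20, 2 * 2**20, 1 * 2**20, 8 * 2**20]
print(bucket_by_size(grads, bytes_per_elem=4, bucket_size_in_mb=25.0))  # [[0, 1], [2], [3]]
```

Larger buckets mean fewer, bigger collectives (less per-call overhead), while smaller buckets let communication start earlier and overlap more with the backward pass.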
- __init__(device=None, broadcast_from=None, sharding_strategy=FSDPType.ZERO2, bucketing_strategy=FSDPBucketingStrategy.NONE, move_state_dict_to_cpu=False, ddp_bucket_size_in_mb=25.0, process_group=None)
Methods
- __init__([device, broadcast_from, ...])
- setup_executors()
- setup_lookasides()
- setup_transforms() – Constructs the list of graph-level transforms.
Attributes
policy