thunder.distributed.ddp
- thunder.distributed.ddp(model, *, broadcast_from=0, bucket_size_in_mb=25.0)
Thunder’s Distributed Data Parallel.
This function does two things. First, it broadcasts the parameters hosted on the rank specified by broadcast_from to all other ranks in the default process group. Second, it updates backward trace generation and optimization so that each gradient is pre-averaged, i.e., divided by the world size, and asynchronously all-reduced. A conceptual sketch of this gradient treatment follows the field list below.
- Parameters:
  - model – the module to apply Thunder's Distributed Data Parallel transform to.
- Keyword Arguments:
  - broadcast_from – the rank whose parameters are broadcast to all other ranks. Defaults to 0.
  - bucket_size_in_mb – the size, in megabytes, of the buckets used to group gradients for the asynchronous all-reduce. Defaults to 25.0.
- Return type:
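Conceptually, the per-gradient behavior described above amounts to the plain torch.distributed sketch below. This is illustrative only, not Thunder's implementation: the helper name _preaverage_and_allreduce is made up, and in practice Thunder groups gradients into buckets of up to bucket_size_in_mb megabytes and issues the all-reduces from the optimized backward trace.

# Illustrative sketch only: roughly what ddp() arranges for each gradient
# inside the backward trace. Assumes torch.distributed is already initialized.
import torch
import torch.distributed as tdist


def _preaverage_and_allreduce(grad: torch.Tensor):
    # Pre-average: divide the local gradient by the world size, then launch an
    # asynchronous all-reduce; summing pre-averaged gradients yields the global mean.
    grad = grad / tdist.get_world_size()
    handle = tdist.all_reduce(grad, op=tdist.ReduceOp.SUM, async_op=True)
    return grad, handle  # the caller waits on `handle` before the optimizer reads the gradient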
ddp_example.py

# $ torchrun --nproc-per-node=<N_GPU> ddp_example.py
import os
import math

import torch
import torch.distributed as tdist
import torch.nn as nn
import torch.nn.functional as F

import thunder
import thunder.distributed as dist


LOCAL_RANK = int(os.environ["LOCAL_RANK"])
BATCH_SIZE = 8
IN_FEATURES = 32
OUT_FEATURES = 64
N_CLASSES = 4


def get_batch() -> tuple[torch.Tensor, torch.Tensor]:
    x = torch.randn(BATCH_SIZE, IN_FEATURES, device=f"cuda:{LOCAL_RANK}", requires_grad=True)
    y = torch.randn(BATCH_SIZE, N_CLASSES, device=f"cuda:{LOCAL_RANK}").softmax(dim=1).requires_grad_()
    return x, y


def new_gelu(a: torch.Tensor):
    return 0.5 * a * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (a + 0.044715 * torch.pow(a, 3.0))))


class MyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.l1 = nn.Linear(IN_FEATURES, OUT_FEATURES)
        self.l2 = nn.Linear(OUT_FEATURES, N_CLASSES)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = new_gelu(self.l1(x))
        return self.l2(h)


def main():
    tdist.init_process_group(backend="nccl")

    model = MyModel().to(LOCAL_RANK)
    # Broadcast the parameters from rank 0 and set up pre-averaged, asynchronously
    # all-reduced gradients in the backward trace.
    compiled = dist.ddp(thunder.jit(model))
    optimizer = torch.optim.AdamW(compiled.parameters())
    losses = []
    loss_all_reduce_workers = []

    for _ in range(10):
        optimizer.zero_grad()
        x, y = get_batch()
        out = compiled(x)
        loss = F.cross_entropy(out, y)  # logits first, soft targets second
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            # Average the loss across ranks for logging; gradients are handled by ddp().
            losses.append(loss.detach())
            loss_all_reduce_workers.append(tdist.all_reduce(losses[-1], op=tdist.ReduceOp.AVG, async_op=True))

    if LOCAL_RANK == 0:
        for i, (loss, worker) in enumerate(zip(losses, loss_all_reduce_workers)):
            assert worker.wait()
            print(f"# {i}-th loss: {loss.item()}")


if __name__ == "__main__":
    main()
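Not part of the example above, but a minimal sanity check you could run on every rank right after dist.ddp(...) to confirm that the parameters broadcast from broadcast_from (rank 0 by default) are identical everywhere. The helper name assert_params_synced is made up for this sketch and assumes the process group is already initialized.

import torch
import torch.distributed as tdist


def assert_params_synced(module: torch.nn.Module) -> None:
    # Broadcast rank 0's copy of each parameter into a scratch buffer and
    # check that the local parameter matches it on every rank.
    for param in module.parameters():
        reference = param.detach().clone()
        tdist.broadcast(reference, src=0)
        torch.testing.assert_close(param.detach(), reference)

Calling assert_params_synced(compiled) on every rank should pass, since ddp broadcasts the parameters hosted on broadcast_from before training starts.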