thunder.distributed.ddp

thunder.distributed.ddp(model, *, broadcast_from=0, bucket_size_in_mb=25.0)[source]

Thunder’s Distributed Data Parallel.

This function does two things. First, it broadcasts the parameters hosted on the rank specified by broadcast_from to all the other ranks in the default process group. Second, it updates backward trace generation and optimization so that each gradient is pre-averaged, i.e., divided by the world size, and asynchronously all-reduced.
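To make the second point concrete, the sketch below (plain PyTorch, a hypothetical helper rather than Thunder's internal code) illustrates what pre-averaged, asynchronous all-reduce means for a single gradient tensor.

import torch
import torch.distributed as tdist

def pre_averaged_all_reduce(grad: torch.Tensor):
    # Pre-average: divide the local gradient by the world size so that a SUM
    # all-reduce across ranks produces the averaged gradient.
    grad.div_(tdist.get_world_size())
    # Launch the all-reduce asynchronously and return the work handle.
    return tdist.all_reduce(grad, op=tdist.ReduceOp.SUM, async_op=True)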

Parameters:
  • model (Module) – A model before thunder.jit is applied.

Keyword Arguments:
  • broadcast_from (int | None) – The rank of the device hosting the parameters to broadcast. If None is passed, broadcasting is skipped, which is useful for models whose weights have already been loaded from a checkpoint. Defaults to 0.

  • bucket_size_in_mb (float) – Size in megabytes of each bucket used to group gradients for all-reduce. Defaults to 25.0.

Return type:

Module
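A minimal usage sketch, assuming a process group is already initialized and model is an ordinary nn.Module placed on this rank's device: call ddp() on the eager model first, then compile the result with thunder.jit.

import thunder
import thunder.distributed

ddp_model = thunder.distributed.ddp(model)   # broadcasts parameters from rank 0 by default
# thunder.distributed.ddp(model, broadcast_from=None)  # skips the broadcast, e.g. after loading a checkpoint
compiled = thunder.jit(ddp_model)            # jit is applied after ddp()

The full example below shows the same pattern inside a training loop launched with torchrun.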

ddp_example.py
# $ torchrun --nproc-per-node=<N_GPU> ddp_example.py
import os
import math

import torch
import torch.distributed as tdist
import torch.nn as nn
import torch.nn.functional as F

import thunder
import thunder.distributed as dist


LOCAL_RANK = int(os.environ["LOCAL_RANK"])
BATCH_SIZE = 8
IN_FEATURES = 32
OUT_FEATURES = 64
N_CLASSES = 4


def get_batch() -> tuple[torch.Tensor, torch.Tensor]:
    x = torch.randn(BATCH_SIZE, IN_FEATURES, device=f"cuda:{LOCAL_RANK}", requires_grad=True)
    y = torch.randn(BATCH_SIZE, N_CLASSES, device=f"cuda:{LOCAL_RANK}").softmax(dim=1).requires_grad_()
    return x, y


def new_gelu(a: torch.Tensor):
    return 0.5 * a * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (a + 0.044715 * torch.pow(a, 3.0))))


class MyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.l1 = nn.Linear(IN_FEATURES, OUT_FEATURES)
        self.l2 = nn.Linear(OUT_FEATURES, N_CLASSES)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = new_gelu(self.l1(x))
        return self.l2(h)


def main():
    tdist.init_process_group(backend="nccl")

    model = MyModel().to(LOCAL_RANK)
    # Apply ddp() to the eager model, then compile the result with thunder.jit.
    compiled = thunder.jit(dist.ddp(model))
    optimizer = torch.optim.AdamW(compiled.parameters())
    losses = []
    loss_all_reduce_workers = []

    for _ in range(10):
        optimizer.zero_grad()
        x, y = get_batch()
        out = compiled(x)
        loss = F.cross_entropy(out, y)
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            losses.append(loss.detach())
            loss_all_reduce_workers.append(tdist.all_reduce(losses[-1], op=tdist.ReduceOp.AVG, async_op=True))

    if LOCAL_RANK == 0:
        for i, (loss, worker) in enumerate(zip(losses, loss_all_reduce_workers)):
            assert worker.wait()
            print(f"# {i}-th loss: {loss.item()}")


if __name__ == "__main__":
    main()
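Note that every rank launches the asynchronous loss all-reduce inside the loop, while rank 0 alone waits on the returned work handles and prints the averaged loss for each iteration.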