thunder.distributed.row_parallel¶
- thunder.distributed.row_parallel(thunder_module, target_modules, process_group=None, *, device=None)[source]¶
Convert specified modules into row-wise parallel ones.
- This method has two effects:
  - Chunks the target modules’ parameters along their 1st dimension (see the sketch below).
  - Inserts preprocessing and postprocessing around the modified modules’ operations.
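As a point of reference, the snippet below is only a conceptual sketch using plain PyTorch, not thunder’s internal implementation: chunking an nn.Linear weight of shape (out_features, in_features) along its 1st dimension (dim=1) gives each rank a slice of the input features, and in typical row-parallel schemes the partial results per rank are then combined by the inserted postprocessing.

import torch

# Conceptual illustration only (not thunder's implementation): an nn.Linear weight
# has shape (out_features, in_features); chunking along dim 1 splits the input
# features across ranks.
world_size = 4
weight = torch.randn(8, 16)                      # (out_features=8, in_features=16)
shards = torch.chunk(weight, world_size, dim=1)  # each shard: (8, 16 // 4) = (8, 4)
assert all(shard.shape == (8, 4) for shard in shards)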
- Parameters:
  - thunder_module (ThunderModule)
  - target_modules (Sequence[str]) – Names of modules to convert into row-wise parallel ones.
  - process_group (ProcessGroup | None)
  - device (torch.device | None)
- Return type:
  ThunderModule
Example
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import distributed_c10d

import thunder
from thunder.distributed import row_parallel


class Model(nn.Module):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        n_hidden: int,
        n_out: int,
    ) -> None:
        super().__init__()
        self.embed = nn.Embedding(num_embeddings, embedding_dim)
        self.l1 = nn.Linear(embedding_dim, n_hidden)
        self.l2 = nn.Linear(n_hidden, n_out)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        feature = self.embed(tokens)
        h = F.gelu(self.l1(feature), approximate='tanh')
        return self.l2(h)


world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{local_rank}")
distributed_c10d.init_process_group()

num_embeddings, embedding_dim, n_hidden, n_out = 1024, 128, 256, 32  # example sizes
model = Model(num_embeddings, embedding_dim, n_hidden, n_out).to(device)
jitted_model = thunder.jit(model)

# ``embedding_dim`` and ``l2``'s input size (= ``n_hidden``) need to be divisible by ``world_size``
tp_model = row_parallel(
    jitted_model,
    target_modules=("embed", "l2",),
)

tokens = torch.randint(0, num_embeddings, (4,), device=device)
out = tp_model(tokens)  # shape: [4, n_out]
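The divisibility requirement in the comment above can also be checked up front. The guard below is only a sketch using the example’s variables; it is not part of the thunder API:

# Sketch of an optional pre-check (not part of the thunder API): both sharded
# dimensions must be divisible by ``world_size`` before calling ``row_parallel``.
assert embedding_dim % world_size == 0  # ``embed.weight`` is chunked along dim 1 (embedding_dim)
assert n_hidden % world_size == 0       # ``l2.weight`` is chunked along dim 1 (in_features = n_hidden)

The example also assumes it is started by a distributed launcher such as torchrun, which sets the WORLD_SIZE and LOCAL_RANK environment variables it reads.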