thunder.distributed.column_parallel

thunder.distributed.column_parallel(thunder_module, target_modules, process_group=None, *, device=None)[source]

Convert the specified modules into column-wise parallel ones.

This method has two effects:
  1. Chunks each target module’s parameters along the 0-th dimension.

  2. Inserts preprocessing and postprocessing around the modified modules’ ops.

Parameters:

    thunder_module (ThunderModule) – the jitted module whose submodules are to be converted.
    target_modules (Sequence[str]) – names of the submodules to convert into column-wise parallel ones, e.g. ("embed", "l2").
    process_group (ProcessGroup, optional) – process group to use for the parallelism; defaults to None.
    device (torch.device, optional) – keyword-only; defaults to None.

Return type:

ThunderModule
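
For intuition, the single-process sketch below (a conceptual illustration, not Thunder's actual implementation) mimics these two effects on an nn.Linear weight: chunking the [out_features, in_features] weight along dim 0 gives each rank a slice of the output features, and concatenating the per-rank partial outputs along the last dimension (an all-gather in the real distributed setting) reproduces the unsharded result.

import torch

world_size = 2
torch.manual_seed(0)

x = torch.randn(4, 8)        # [batch, in_features]
weight = torch.randn(6, 8)   # nn.Linear stores its weight as [out_features, in_features]

full = x @ weight.t()        # [4, 6], what the unsharded linear computes

# "Chunk the parameter along dim 0": each rank holds a slice of the output features.
shards = weight.chunk(world_size, dim=0)    # two [3, 8] shards
partials = [x @ w.t() for w in shards]      # each [4, 3]

# "Postprocess": gather the per-rank slices along the feature dimension
# (an all-gather in the real distributed setting).
gathered = torch.cat(partials, dim=-1)      # [4, 6]

assert torch.allclose(full, gathered, atol=1e-6)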

Example

import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed import distributed_c10d

import thunder
from thunder.distributed import column_parallel


class Model(nn.Module):
    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        n_hidden: int,
        n_out: int,
    ) -> None:
        super().__init__()
        self.embed = nn.Embedding(num_embeddings, embedding_dim)
        self.l1 = nn.Linear(embedding_dim, n_hidden)
        self.l2 = nn.Linear(n_hidden, n_out)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        feature = self.embed(tokens)
        h = F.gelu(self.l1(feature), approximate='tanh')
        return self.l2(h)

world_size = int(os.environ["WORLD_SIZE"])
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{local_rank}")
distributed_c10d.init_process_group()
# Toy hyperparameters; `num_embeddings` and `n_out` are chosen so that they are
# divisible by typical `world_size` values (see the comment below).
num_embeddings = 1024
embedding_dim = 128
n_hidden = 256
n_out = 8

model = Model(num_embeddings, embedding_dim, n_hidden, n_out).to(device)
jitted_model = thunder.jit(model)
# The 0-th dim of each target's weight is chunked, so `embed`'s `num_embeddings`
# and `l2`'s output size (= n_out) need to be divisible by `world_size`.
tp_model = column_parallel(
    jitted_model,
    target_modules=("embed", "l2",),
)

tokens = torch.randint(num_embeddings, (4,), device=device)  # integer token ids for the embedding
out = tp_model(tokens)  # shape: [4, n_out]
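
The example reads WORLD_SIZE and LOCAL_RANK from the environment, so it is meant to be launched with a distributed launcher such as torchrun, e.g. torchrun --nproc_per_node=<num_gpus> your_script.py (the script name is a placeholder).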