Training sharded HuggingFace models on multiple GPUs (DeepSpeed)

Hi there! I have three questions about my project that have stumped me for a while, and they may well be related to one another.

  1. I am trying to train a BART model from the HuggingFace Transformers library and I am running into some issues. After the training session starts, every odd-numbered GPU (1, 3, etc.) sits constantly at 0% usage. Am I not initializing the model properly (see code below)?

  2. I cannot seem to properly shard the model, even with DeepSpeed stage 3. The model has around 45 million parameters, and if I try to increase the batch size from 128 to 256 it goes OOM. I have tried to implement another suggestion I've seen earlier in the forums here, but still no luck.

  3. How does model compilation (introduced in PyTorch 2.0) work for sharded models? Would torch.compile do anything for my code below? Before, I was calling it right after initializing the model in __init__, but I've seen a recommendation to do it after creating the trainer:

trainer = Trainer(...)
torch.compile(model)
trainer.fit(model, datamodule=dm)

I didn’t really observe much difference between the two though.
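From what I understand, torch.compile returns a new compiled module rather than modifying the model in place, so presumably the return value has to be assigned back before fitting, roughly like this (dm stands for my datamodule, the rest is simplified):

model = BartForMoleculeGeneration(config)
# torch.compile wraps the model in a new compiled module; if the return value
# is dropped, the original, uncompiled model keeps being used.
model = torch.compile(model)

trainer = Trainer(...)
trainer.fit(model, datamodule=dm)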

Thanks in advance for any help!

Log for a 4-GPU run:

[2023-11-04 14:20:45,614] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2023-11-04 14:20:45,615] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
initializing deepspeed distributed: GLOBAL_RANK: 1, MEMBER: 2/2
initializing deepspeed distributed: GLOBAL_RANK: 0, MEMBER: 1/2
Enabling DeepSpeed FP16. Model parameters and inputs will be cast to `float16`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
Using /gpfs/home1/avoinea/.cache/torch_extensions/py310_cu121 as PyTorch extensions root...
Using /gpfs/home1/avoinea/.cache/torch_extensions/py310_cu121 as PyTorch extensions root...
Emitting ninja build file /gpfs/home1/avoinea/.cache/torch_extensions/py310_cu121/cpu_adam/build.ninja...
Building extension module cpu_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
[WARNING] cpu_adam cuda is missing or is incompatible with installed torch, only cpu ops can be compiled!
ninja: no work to do.
Loading extension module cpu_adam...
Loading extension module cpu_adam...

  | Name  | Type                         | Params | Params per Device
---------------------------------------------------------------------------
0 | model | BartForConditionalGeneration | 45.1 M | 22.6 M           
---------------------------------------------------------------------------
45.1 M    Trainable params
0         Non-trainable params
45.1 M    Total params
180.445   Total estimated model params size (MB)
Time to load cpu_adam op: 2.611788034439087 seconds
Parameter Offload: Total persistent parameters: 100352 in 160 params
[WARNING] cpu_adam cuda is missing or is incompatible with installed torch, only cpu ops can be compiled!
Time to load cpu_adam op: 2.555509090423584 seconds
/gpfs/home1/avoinea/linguamol/.venv/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py:1286: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:83.)
  total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])
/gpfs/home1/avoinea/linguamol/.venv/lib/python3.10/site-packages/deepspeed/runtime/zero/stage3.py:1286: UserWarning: The torch.cuda.*DtypeTensor constructors are no longer recommended. It's best to use methods such as torch.tensor(data, dtype=*, device='cuda') to create tensors. (Triggered internally at ../torch/csrc/tensor/python_tensor.cpp:83.)
  total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)])

nvidia-smi report for a 2-GPU run (similar to the 4-GPU run, where every even-numbered GPU sits at 99% usage):

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.12             Driver Version: 535.104.12   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-SXM4-40GB          On  | 00000000:CA:00.0 Off |                  Off |
| N/A   48C    P0             256W / 400W |  30320MiB / 40960MiB |     99%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM4-40GB          On  | 00000000:E3:00.0 Off |                  Off |
| N/A   30C    P0              50W / 400W |      8MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
                                                                                         
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A   1704411      C   .../user/project/.venv/bin/python    30306MiB |
+---------------------------------------------------------------------------------------+

Code for the Lightning module:

from transformers import BartForConditionalGeneration, BartConfig
from transformers import get_cosine_schedule_with_warmup
from deepspeed.ops.adam import DeepSpeedCPUAdam
import lightning as L
import torch

class BartForMoleculeGeneration(L.LightningModule):
    def __init__(self, config: "DictConfig"):
        super().__init__()
        self.save_hyperparameters(config)
        self.config = config
        self.model = None  # built lazily in configure_sharded_model so DeepSpeed can shard it

    def configure_sharded_model(self) -> None:
        # Constructing the model inside this hook lets DeepSpeed shard the weights as they are created.
        model_config = BartConfig.from_dict(self.config.model)
        self.model = BartForConditionalGeneration(model_config)

    def training_step(self, batch, batch_idx):
        input_ids, attention_mask, labels = batch

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )

        loss = outputs.loss
        self.log("train/loss", loss, prog_bar=True, logger=True, sync_dist=True)

        # Step the LR scheduler manually every 5 batches.
        sch = self.lr_schedulers()
        if (batch_idx + 1) % 5 == 0:
            sch.step()

        return loss
  
    ...
    def configure_optimizers(self):
        optimizer = DeepSpeedCPUAdam(self.model.parameters(), lr=3e-4)
        lr_scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=5000)

        return [optimizer], [lr_scheduler]

This is my trainer configuration:

...

train_strategy = DeepSpeedStrategy(
    logging_batch_size_per_gpu=128,
    stage=2,
    offload_optimizer=True,
    offload_parameters=False,
    allgather_bucket_size=5e8,
    reduce_bucket_size=5e8,
)

trainer = L.Trainer(
    accelerator="gpu",
    strategy=train_strategy,
    precision="16-mixed",
    devices=2,
    num_nodes=1,
    max_epochs=5,
    log_every_n_steps=5,
    logger=wandb_logger,
    enable_progress_bar=False,
    callbacks=[checkpoint_callback, lr_monitor, early_stopping],
)
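
For question 2, my understanding is that full parameter sharding only happens with ZeRO stage 3 (stage 2 shards optimizer states and gradients but keeps a full copy of the parameters on every GPU). The stage 3 variant I have been trying looks roughly like this; the exact values may differ from what I last ran, and enabling offload_parameters is my guess at the relevant knob:

train_strategy = DeepSpeedStrategy(
    logging_batch_size_per_gpu=128,
    stage=3,                    # ZeRO stage 3: shard optimizer states, gradients, and parameters
    offload_optimizer=True,
    offload_parameters=True,    # also offload the sharded parameters to CPU
    allgather_bucket_size=5e8,
    reduce_bucket_size=5e8,
)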

Disregard this. As a helpful dev pointed out to me on GitHub, I was using the wrong sbatch configuration.
Instead of using:

#SBATCH --nodes=1
#SBATCH --gpus=4

Make sure you use:

#SBATCH --nodes=1
#SBATCH --gpus=4
#SBATCH --ntasks-per-node=4
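
As far as I understand it, when running under Slurm, Lightning expects one task per GPU and picks the ranks up from the Slurm environment variables, so --ntasks-per-node has to match the number of GPUs per node. A minimal job script along these lines (with train.py standing in for the actual entry point):

#!/bin/bash
#SBATCH --nodes=1
#SBATCH --gpus=4
#SBATCH --ntasks-per-node=4   # one Slurm task per GPU, so all four ranks get launched

# srun starts one process per task; Lightning/DeepSpeed read rank and world
# size from the Slurm environment.
srun python train.py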