Lightning Trainer works on one GPU but goes OOM on more

I’m trying to fine-tune a quantized Llama-2-13b-chat on a system with three 3090s. On a single GPU it works but it’s very slow, so I was hoping to use all of the GPUs to parallelize training and speed it up, but as soon as I increase the number of GPUs it runs out of memory.

╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /home/michele/MT-finetuning/MT-finetuning/Trainer.py:130 in <module>                             │
│                                                                                                  │
│   127 train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_siz   │
│   128 test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size,   │
│   129                                                                                            │
│ ❱ 130 trainer.fit(model, train_data_loader, test_data_loader)                                    │
│   131                                                                                            │
│                                                                                                  │
│ /home/michele/miniconda3/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:540   │
│ in fit                                                                                           │
│                                                                                                  │
│    537 │   │   model = _maybe_unwrap_optimized(model)                                            │
│    538 │   │   self.strategy._lightning_module = model                                           │
│    539 │   │   _verify_strategy_supports_compile(model, self.strategy)                           │
│ ❱  540 │   │   call._call_and_handle_interrupt(                                                  │
│    541 │   │   │   self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule,  │
│    542 │   │   )                                                                                 │
│    543                                                                                           │
│                                                                                                  │
│ /home/michele/miniconda3/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:42 in    │
│ _call_and_handle_interrupt                                                                       │
│                                                                                                  │
│    39 │   """                                                                                    │
│    40 │   try:                                                                                   │
│    41 │   │   if trainer.strategy.launcher is not None:                                          │
│ ❱  42 │   │   │   return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer,    │
│    43 │   │   return trainer_fn(*args, **kwargs)                                                 │
│    44 │                                                                                          │
│    45 │   except _TunerExitException:                                                            │
│                                                                                                  │
│ /home/michele/miniconda3/lib/python3.10/site-packages/lightning/pytorch/strategies/launchers/sub │
│ process_script.py:91 in launch                                                                   │
│                                                                                                  │
│    88 │   │   """                                                                                │
│    89 │   │   if not self.cluster_environment.creates_processes_externally:                      │
│    90 │   │   │   self._call_children_scripts()                                                  │
│ ❱  91 │   │   return function(*args, **kwargs)                                                   │
│    92 │                                                                                          │
│    93 │   def kill(self, signum: _SIGNUM) -> None:                                               │
│    94 │   │   for proc in self.procs:                                                            │
│                                                                                                  │
│ /home/michele/miniconda3/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:579   │
│ in _fit_impl                                                                                     │
│                                                                                                  │
│    576 │   │   │   model_provided=True,                                                          │
│    577 │   │   │   model_connected=self.lightning_module is not None,                            │
│    578 │   │   )                                                                                 │
│ ❱  579 │   │   self._run(model, ckpt_path=ckpt_path)                                             │
│    580 │   │                                                                                     │
│    581 │   │   assert self.state.stopped                                                         │
│    582 │   │   self.training = False                                                             │
│                                                                                                  │
│ /home/michele/miniconda3/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:961   │
│ in _run                                                                                          │
│                                                                                                  │
│    958 │   │   self._logger_connector.reset_metrics()                                            │
│    959 │   │                                                                                     │
│    960 │   │   # strategy will configure model and move it to the device                         │
│ ❱  961 │   │   self.strategy.setup(self)                                                         │
│    962 │   │                                                                                     │
│    963 │   │   # hook                                                                            │
│    964 │   │   if self.state.fn == TrainerFn.FITTING:                                            │
│                                                                                                  │
│ /home/michele/miniconda3/lib/python3.10/site-packages/lightning/pytorch/strategies/ddp.py:155 in │
│ setup                                                                                            │
│                                                                                                  │
│   152 │   │   self.accelerator.setup(trainer)                                                    │
│   153 │   │                                                                                      │
│   154 │   │   # move the model to the correct device                                             │
│ ❱ 155 │   │   self.model_to_device()                                                             │
│   156 │   │                                                                                      │
│   157 │   │   # skip wrapping the model if we are not fitting as no gradients need to be excha   │
│   158 │   │   trainer_fn = trainer.state.fn                                                      │
│                                                                                                  │
│ /home/michele/miniconda3/lib/python3.10/site-packages/lightning/pytorch/strategies/ddp.py:312 in │
│ model_to_device                                                                                  │
│                                                                                                  │
│   309 │   def model_to_device(self) -> None:                                                     │
│   310 │   │   log.debug(f"{self.__class__.__name__}: moving model to device [{self.root_device   │
│   311 │   │   assert self.model is not None                                                      │
│ ❱ 312 │   │   self.model.to(self.root_device)                                                    │
│   313 │                                                                                          │
│   314 │   def reduce(                                                                            │
│   315 │   │   self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[Red   │
│                                                                                                  │
│ /home/michele/miniconda3/lib/python3.10/site-packages/lightning/fabric/utilities/device_dtype_mi │
│ xin.py:54 in to                                                                                  │
│                                                                                                  │
│    51 │   │   # this converts `str` device to `torch.device`                                     │
│    52 │   │   device, dtype = torch._C._nn._parse_to(*args, **kwargs)[:2]                        │
│    53 │   │   self.__update_properties(device=device, dtype=dtype)                               │
│ ❱  54 │   │   return super().to(*args, **kwargs)                                                 │
│    55 │                                                                                          │
│    56 │   def cuda(self, device: Optional[Union[torch.device, int]] = None) -> Self:             │
│    57 │   │   """Moves all model parameters and buffers to the GPU. This also makes associated   │
│                                                                                                  │
│ /home/michele/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1160 in to      │
│                                                                                                  │
│   1157 │   │   │   │   │   │   │   non_blocking, memory_format=convert_to_format)                │
│   1158 │   │   │   return t.to(device, dtype if t.is_floating_point() or t.is_complex() else No  │
│   1159 │   │                                                                                     │
│ ❱ 1160 │   │   return self._apply(convert)                                                       │
│   1161 │                                                                                         │
│   1162 │   def register_full_backward_pre_hook(                                                  │
│   1163 │   │   self,                                                                             │
│                                                                                                  │
│ /home/michele/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:810 in _apply   │
│                                                                                                  │
│    807 │   def _apply(self, fn, recurse=True):                                                   │
│    808 │   │   if recurse:                                                                       │
│    809 │   │   │   for module in self.children():                                                │
│ ❱  810 │   │   │   │   module._apply(fn)                                                         │
│    811 │   │                                                                                     │
│    812 │   │   def compute_should_use_set_data(tensor, tensor_applied):                          │
│    813 │   │   │   if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):           │
│                                                                                                  │
│ /home/michele/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:810 in _apply   │
│                                                                                                  │
│    807 │   def _apply(self, fn, recurse=True):                                                   │
│    808 │   │   if recurse:                                                                       │
│    809 │   │   │   for module in self.children():                                                │
│ ❱  810 │   │   │   │   module._apply(fn)                                                         │
│    811 │   │                                                                                     │
│    812 │   │   def compute_should_use_set_data(tensor, tensor_applied):                          │
│    813 │   │   │   if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):           │
│                                                                                                  │
│ /home/michele/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:810 in _apply   │
│                                                                                                  │
│    807 │   def _apply(self, fn, recurse=True):                                                   │
│    808 │   │   if recurse:                                                                       │
│    809 │   │   │   for module in self.children():                                                │
│ ❱  810 │   │   │   │   module._apply(fn)                                                         │
│    811 │   │                                                                                     │
│    812 │   │   def compute_should_use_set_data(tensor, tensor_applied):                          │
│    813 │   │   │   if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):           │
│                                                                                                  │
│ /home/michele/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:810 in _apply   │
│                                                                                                  │
│    807 │   def _apply(self, fn, recurse=True):                                                   │
│    808 │   │   if recurse:                                                                       │
│    809 │   │   │   for module in self.children():                                                │
│ ❱  810 │   │   │   │   module._apply(fn)                                                         │
│    811 │   │                                                                                     │
│    812 │   │   def compute_should_use_set_data(tensor, tensor_applied):                          │
│    813 │   │   │   if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):           │
│                                                                                                  │
│ /home/michele/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:810 in _apply   │
│                                                                                                  │
│    807 │   def _apply(self, fn, recurse=True):                                                   │
│    808 │   │   if recurse:                                                                       │
│    809 │   │   │   for module in self.children():                                                │
│ ❱  810 │   │   │   │   module._apply(fn)                                                         │
│    811 │   │                                                                                     │
│    812 │   │   def compute_should_use_set_data(tensor, tensor_applied):                          │
│    813 │   │   │   if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):           │
│                                                                                                  │
│ /home/michele/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:810 in _apply   │
│                                                                                                  │
│    807 │   def _apply(self, fn, recurse=True):                                                   │
│    808 │   │   if recurse:                                                                       │
│    809 │   │   │   for module in self.children():                                                │
│ ❱  810 │   │   │   │   module._apply(fn)                                                         │
│    811 │   │                                                                                     │
│    812 │   │   def compute_should_use_set_data(tensor, tensor_applied):                          │
│    813 │   │   │   if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):           │
│                                                                                                  │
│ /home/michele/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:810 in _apply   │
│                                                                                                  │
│    807 │   def _apply(self, fn, recurse=True):                                                   │
│    808 │   │   if recurse:                                                                       │
│    809 │   │   │   for module in self.children():                                                │
│ ❱  810 │   │   │   │   module._apply(fn)                                                         │
│    811 │   │                                                                                     │
│    812 │   │   def compute_should_use_set_data(tensor, tensor_applied):                          │
│    813 │   │   │   if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):           │
│                                                                                                  │
│ /home/michele/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:810 in _apply   │
│                                                                                                  │
│    807 │   def _apply(self, fn, recurse=True):                                                   │
│    808 │   │   if recurse:                                                                       │
│    809 │   │   │   for module in self.children():                                                │
│ ❱  810 │   │   │   │   module._apply(fn)                                                         │
│    811 │   │                                                                                     │
│    812 │   │   def compute_should_use_set_data(tensor, tensor_applied):                          │
│    813 │   │   │   if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):           │
│                                                                                                  │
│ /home/michele/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:833 in _apply   │
│                                                                                                  │
│    830 │   │   │   # track autograd history of `param_applied`, so we have to use                │
│    831 │   │   │   # `with torch.no_grad():`                                                     │
│    832 │   │   │   with torch.no_grad():                                                         │
│ ❱  833 │   │   │   │   param_applied = fn(param)                                                 │
│    834 │   │   │   should_use_set_data = compute_should_use_set_data(param, param_applied)       │
│    835 │   │   │   if should_use_set_data:                                                       │
│    836 │   │   │   │   param.data = param_applied                                                │
│                                                                                                  │
│ /home/michele/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1158 in convert │
│                                                                                                  │
│   1155 │   │   │   if convert_to_format is not None and t.dim() in (4, 5):                       │
│   1156 │   │   │   │   return t.to(device, dtype if t.is_floating_point() or t.is_complex() els  │
│   1157 │   │   │   │   │   │   │   non_blocking, memory_format=convert_to_format)                │
│ ❱ 1158 │   │   │   return t.to(device, dtype if t.is_floating_point() or t.is_complex() else No  │
│   1159 │   │                                                                                     │
│   1160 │   │   return self._apply(convert)                                                       │
│   1161                                                                                           │
│                                                                                                  │
│ /home/michele/miniconda3/lib/python3.10/site-packages/bitsandbytes/nn/modules.py:336 in to       │
│                                                                                                  │
│   333 │   │   │   return self.cuda(device)                                                       │
│   334 │   │   else:                                                                              │
│   335 │   │   │   new_param = Int8Params(                                                        │
│ ❱ 336 │   │   │   │   super().to(                                                                │
│   337 │   │   │   │   │   device=device, dtype=dtype, non_blocking=non_blocking                  │
│   338 │   │   │   │   ),                                                                         │
│   339 │   │   │   │   requires_grad=self.requires_grad,                                          │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
OutOfMemoryError: CUDA out of memory. Tried to allocate 68.00 MiB. GPU 2 has a total capacty of 23.70 GiB of which 70.69 MiB is free. Process 489946 has 5.36 GiB memory in use. Including non-PyTorch memory, this process has 12.90 GiB 
memory in use. Process 490341 has 5.36 GiB memory in use. Of the allocated memory 11.89 GiB is allocated by PyTorch, and 197.41 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting 
max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

I found a similar issue, but the suggestion there was to switch from DataParallel to DistributedDataParallel, which I’m already using. Here is the trainer:

trainer = L.Trainer(
    devices=args.num_devices,
    accelerator="gpu",
    strategy="ddp",
    #strategy=DeepSpeedStrategy(
    #     stage=3,
    #     offload_optimizer=True,
    #     offload_parameters=True,
    # ),
    max_steps=10,
    #precision="16-mixed",
    enable_checkpointing=True,
    accumulate_grad_batches=args.accumulate_grad_batches,
    log_every_n_steps=20,
    val_check_interval=30,
    limit_val_batches=200,
    #default_root_dir="checkpoints",
    callbacks=[checkpoint_callback],
    gradient_clip_val=args.gradient_clip_val,
    logger=mlflow_logger,
)
print(args.quantized)
print(quantization_config)
print(type(args.peft), args.peft)
quantization_config = None if not args.quantized else quantization_config
load_in_8bit = False if not quantization_config else True
model = MTModel(model_name=args.model_name, pad_token_id=tokenizer.pad_token_id, inference=False,
                learning_rate=args.learning_rate, weight_decay=args.weight_decay,
                betas=args.betas, quantization_config=quantization_config, peft=args.peft, load_in_8bit=load_in_8bit)

# Note that the training effective batch size is num_nodes * num_gpus * batch_size * accumulate_grad_batches
train_data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=32)
test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=32)

trainer.fit(model, train_data_loader, test_data_loader)

and here is the model code:

class MTModel(L.LightningModule):
    def __init__(self, model_name: str, pad_token_id: int, inference: bool,
                 learning_rate: float = 1e-4, weight_decay: float = 0.0,
                 betas: tuple = (0.9, 0.95), quantization_config: LoraConfig = None,
                 load_in_8bit: bool = False, peft: bool = False):
        super().__init__()
        self.save_hyperparameters()
        
        self.model = AutoModelForCausalLM.from_pretrained(model_name,trust_remote_code=True,device_map="auto",
                                                          quantization_config=quantization_config,load_in_8bit=load_in_8bit)
        
        #self.model = dispatch_model(self.model,device_map)
        if quantization_config: self.model = prepare_model_for_int8_training(self.model)
        #self.model = self.model.to(self.device)
        if peft:
            peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=inference, r=4, lora_alpha=4, lora_dropout=0.01)
            self.model = get_peft_model(self.model, peft_config)
            
        

    def training_step(self, batch, batch_idx):
        get_accelerator().empty_cache()
        input_ids, target_start_idx = batch
        #print(self.device)
        #print(input_ids.get_device())
        logits = self.model(input_ids).logits
        loss = mt_loss(logits, input_ids, target_start_idx, self.hparams.pad_token_id)
        self.log("train_loss", loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        input_ids, target_start_idx = batch
        #print(self.device)
        #breakpoint()
        #print(input_ids.get_device())
        logits = self.model(input_ids).logits
        loss = mt_loss(logits, input_ids, target_start_idx, self.hparams.pad_token_id)
        self.log("val_loss", loss, sync_dist=True)
        self.log("val_ppl", torch.exp(loss), sync_dist=True)
        return loss

    def generate(self, batch, **kwargs):
        return self.model.generate(batch,**kwargs)
  
    def configure_optimizers(self):
        return DeepSpeedCPUAdam(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay, betas=self.hparams.betas)

Can you try initializing the model inside configure_model? That way the weights are only created during strategy setup, after each process has been assigned its device, instead of in __init__ on every rank.

Example:

class LightningGPTModule(L.LightningModule):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.config = config
        self.module: Optional[torch.nn.Module] = None
        self.measured_flops: Optional[int] = None

    def configure_model(self) -> None:
        self.module = GPT(self.config)
        self.module.apply(self.module._init_weights)
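
For the MTModel above, a rough sketch of the same change could look like the following (untested; it keeps the hyperparameters you already save, and it leaves out device_map="auto", since with DDP the strategy moves the module to each rank's device itself):

class MTModel(L.LightningModule):
    def __init__(self, model_name: str, pad_token_id: int, inference: bool,
                 learning_rate: float = 1e-4, weight_decay: float = 0.0,
                 betas: tuple = (0.9, 0.95), quantization_config=None,
                 load_in_8bit: bool = False, peft: bool = False):
        super().__init__()
        self.save_hyperparameters()
        # Don't load the pretrained weights here; defer to configure_model.
        self.model = None

    def configure_model(self) -> None:
        # Called by the Trainer during strategy setup; it may run more than
        # once, so keep it idempotent.
        if self.model is not None:
            return
        self.model = AutoModelForCausalLM.from_pretrained(
            self.hparams.model_name,
            trust_remote_code=True,
            quantization_config=self.hparams.quantization_config,
            load_in_8bit=self.hparams.load_in_8bit,
        )
        if self.hparams.quantization_config:
            self.model = prepare_model_for_int8_training(self.model)
        if self.hparams.peft:
            peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM,
                                     inference_mode=self.hparams.inference,
                                     r=4, lora_alpha=4, lora_dropout=0.01)
            self.model = get_peft_model(self.model, peft_config)

Your training_step, validation_step, and configure_optimizers can stay as they are.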