Problem when running with DDP

I have been trying to pretrain wav2vec2 using Hugging Face Transformers.

In my first attempt I used a single GPU (Tesla V100, 16 GB) and found that I could train the 94.5-million-parameter wav2vec2 model with a batch size of only 3. I then tried DDP over 4 Tesla V100 GPUs on the same server.

As I understood it, the batch size should be set to at least 12 so that it could be divided into 3 samples for each GPU process, but with that setting I was getting OOM errors on the machine. I started decreasing the batch size to find the maximum that would train on the server, and that turned out to be 3 again.
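
To spell out the arithmetic I had in mind (just my assumption about how the batch would be split, not something I verified against the Lightning docs):

# my assumption: the DataLoader batch size is the *global* batch,
# so DDP would split it evenly across the processes
world_size = 4        # 4 x Tesla V100 on one server
global_batch = 12     # batch size I wanted to pass to the DataLoader
per_gpu_batch = global_batch // world_size
print(per_gpu_batch)  # 3 -> what I expected each GPU to get per step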

I am confused: does this mean that every GPU is training its copy of the model on the same batch of data?
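
To see what data each process would actually receive, here is a standalone toy sketch (a dummy dataset of 24 items and a hard-coded world size of 4, made up just for illustration) that prints the indices a DistributedSampler would hand to each rank:

import torch
from torch.utils.data import TensorDataset, DistributedSampler

dataset = TensorDataset(torch.arange(24))   # 24 dummy samples

for rank in range(4):                       # pretend world_size = 4
    sampler = DistributedSampler(dataset, num_replicas=4, rank=rank, shuffle=False)
    print(rank, list(sampler))              # each rank gets a disjoint slice of indices

If the sampler in my DataModule behaves the same way, the four GPUs should be seeing different samples rather than the same batch, which is what I would like to confirm.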
Here’s my training code:

import torch
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForPreTraining, Wav2Vec2Config
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices

from torch.utils.data import Dataset, DataLoader

import librosa, os
import numpy as np
import pandas as pd

from glob import glob
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint


import pytorch_lightning as pl



class Audios(Dataset):
    def __init__(self, data):
        super(Audios, self).__init__()
        
        self.data = data
        self.feat_extractor = Wav2Vec2FeatureExtractor(feature_size=1, 
                                             sampling_rate=16000, 
                                             padding_value=0.0, 
                                             do_normalize=True, 
                                             return_attention_mask=False)
        
    def __getitem__(self, idx):
        wav, _ = librosa.load("../../datasets/VoxPopuli/"+ self.data[idx], sr = 16000)
        input_values = self.feat_extractor(wav, sampling_rate = 16000).input_values[0]
        
        return input_values
    
    def __len__(self,):
        return len(self.data)
        

class DataModel(pl.LightningDataModule):
    def __init__(self, batch_size):
        super(DataModel, self).__init__()
        self.bs = batch_size
    
    def collate_fn(self,batch):
        batch = [torch.from_numpy(i) for i in batch]
        return pad_sequence(batch, batch_first = True)
        
    def setup(self,stage=None):
        train_filenames = []
        for csv in glob("../../datasets/VoxPopuli/*/train_10.tsv"):
            train_filenames.extend(pd.read_csv(csv, sep = '\t').path.values)
        train = Audios(train_filenames)
        t_sampler = torch.utils.data.distributed.DistributedSampler(train, shuffle=True)

        self.train_loader = DataLoader(train, batch_size=self.bs, sampler = t_sampler, collate_fn = self.collate_fn, num_workers = 4)
        
        val_filenames = []
        for csv in glob("../../datasets/VoxPopuli/*/dev_10.tsv"):
            val_filenames.extend(pd.read_csv(csv, sep = '\t').path.values)
        val = Audios(val_filenames)
        v_sampler = torch.utils.data.distributed.DistributedSampler(val, shuffle=False)
        self.val_loader = DataLoader(val, sampler = v_sampler, batch_size=self.bs, collate_fn = self.collate_fn, num_workers = 4)
        
    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.val_loader

class Network(pl.LightningModule):
    def __init__(self, config, lr):
        super(Network, self).__init__()
        self.lr = lr
        self.model = Wav2Vec2ForPreTraining(config)
    
    def forward(self, x, mask):
        return self.model(x, mask_time_indices=mask)

    def training_step(self, batch, idx):
        batch_size, raw_sequence_length = batch.shape
        seq_len = self.model._get_feat_extract_output_lengths(raw_sequence_length)
        mask_time_indices = _compute_mask_indices((batch_size, seq_len), mask_prob=0.2, mask_length=2, device=batch.get_device())
        loss = self(batch, mask_time_indices).loss

        self.log('loss', loss, on_step=True, on_epoch=True)
        return loss
    
    def validation_step(self, batch, idx):
        batch_size, raw_sequence_length = batch.shape
        seq_len = self.model._get_feat_extract_output_lengths(raw_sequence_length)
        mask_time_indices = _compute_mask_indices((batch_size, seq_len), mask_prob=0.2, mask_length=2, device=batch.get_device())
        outputs = self(batch, mask_time_indices)
        cosine_sim = torch.cosine_similarity(
            outputs.projected_states, outputs.projected_quantized_states, dim=-1
        ).mean()

        self.log('similarity', cosine_sim, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
    
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(),
                                      lr=self.lr, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,
                                                  max_lr=self.lr,
                                                  steps_per_epoch=5241,
                                                  epochs=50,
                                                  anneal_strategy='linear')
        return [optimizer], [scheduler]

if __name__ == "__main__":
    conf = Wav2Vec2Config(vocab_size = 16,
                      activation_dropout = 0.1,
                      apply_spec_augment = True,
                      architectures = ["Wav2Vec2ForPreTraining"],
                      attention_dropout = 0.1,
                      bos_token_id = 1,
                      classifier_proj_size = 256,
                      codevector_dim = 256,
                      contrastive_logits_temperature = 0.1,
                      conv_bias = False,
                      conv_dim = [512, 512, 512, 512, 512, 512, 512],
                      conv_kernel = [6, 3, 3, 2, 2, 2, 2],
                      conv_stride = [3, 2, 2, 2, 2, 2, 2],
                      ctc_loss_reduction = "mean",
                      ctc_zero_infinity = True,
                      diversity_loss_weight = 0.1,
                      do_stable_layer_norm = False,
                      eos_token_id = 2,
                      feat_extract_activation = "gelu",
                      feat_extract_dropout = 0.0,
                      feat_extract_norm = "group",
                      feat_proj_dropout = 0.0,
                      feat_quantizer_dropout = 0.0,
                      final_dropout = 0.1,
                      hidden_act = "gelu",
                      hidden_dropout = 0.1,
                      hidden_dropout_prob = 0.1,
                      hidden_size = 768,
                      initializer_range = 0.02,
                      intermediate_size = 3072,
                      layer_norm_eps = 1e-05,
                      layerdrop = 0.1,
                      mask_feature_length = 10,
                      mask_feature_prob = 0.0,
                      mask_time_length = 10,
                      mask_time_prob = 0.05,
                      num_attention_heads = 12,
                      num_codevector_groups = 2,
                      num_codevectors_per_group = 320,
                      num_conv_pos_embedding_groups = 16,
                      num_conv_pos_embeddings = 128,
                      num_feat_extract_layers = 7,
                      num_hidden_layers = 12,
                      num_negatives = 100,
                      pad_token_id = 15,
                      proj_codevector_dim = 256,
                      use_weighted_layer_sum = False,
                     )
    
    model = Network(config=conf, lr=1e-4)

    dm = DataModel(batch_size = 12)
    
    early_stop_callback = EarlyStopping(
       monitor='similarity',
       min_delta=0.00,
       patience=10,
       verbose=False,
       mode='max'
    )

    model_ckpt = ModelCheckpoint(
        monitor='similarity',
        dirpath='./',
        filename='CKPT/wav2vec2/VOX/pretraining-vox-gpus-{similarity:.2f}',
        mode='max')

    logger = TensorBoardLogger("CKPT/wav2vec2/VOX", name="x7")
    
    trainer = pl.Trainer(gpus = -1,
                         max_epochs = 200,
                         num_sanity_val_steps = 1,
                         logger = logger,
                         callbacks = [early_stop_callback, model_ckpt],
                         # auto_scale_batch_size="binsearch",
                         profiler = "simple",
                         # auto_lr_find=True,
                         strategy = "ddp",
                         prepare_data_per_node=True
                         # precision = 16,
                        )
    trainer.fit(model, dm)
    model.model.save_pretrained("Vox-Xtra")