I have been trying to pretrain wav2vec2 using huggingface transformers.
On my first try, I used a single GPU (Tesla V100, 16 GB) and found that I could train the 94.5-million-parameter wav2vec2 model only with a batch size of 3. After that, I tried using DDP across 4 Tesla V100 GPUs on the same server.
As far as I understand, the batch size should then be set to at least 12, so that it can be divided into a batch of 3 for each GPU process, but when I did that I got an OOM error on the machine. I started decreasing the batch size to find the maximum batch size that can be used for training on the server, and that turned out to be 3.
I am confused: does this mean that every GPU is training the model on the same batch of data?
Here’s my code for the training:
import torch
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForPreTraining, Wav2Vec2Config
from transformers.models.wav2vec2.modeling_wav2vec2 import _compute_mask_indices
from torch.utils.data import Dataset, DataLoader
import librosa, os
import numpy as np
import pandas as pd
from glob import glob
from torch import nn
from torch.nn.utils.rnn import pad_sequence
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import pytorch_lightning as pl
class Audios(Dataset):
    """Dataset of raw waveforms prepared for wav2vec2 pretraining.

    Each item is loaded from disk with ``librosa.load`` at ``sampling_rate``
    and passed through a ``Wav2Vec2FeatureExtractor`` configured with
    ``do_normalize=True`` and no attention mask.

    Args:
        data: Sequence of audio file paths, relative to ``root``.
        root: Directory the relative paths are resolved against. Defaults to
            the location the script originally hard-coded.
        sampling_rate: Rate in Hz used both for loading the audio and for the
            feature extractor.
    """

    def __init__(self, data, root="../../datasets/VoxPopuli/", sampling_rate=16000):
        super().__init__()
        self.data = data
        self.root = root
        self.sampling_rate = sampling_rate
        self.feat_extractor = Wav2Vec2FeatureExtractor(
            feature_size=1,
            sampling_rate=sampling_rate,
            padding_value=0.0,
            do_normalize=True,
            return_attention_mask=False,
        )

    def _path(self, idx):
        """Return the on-disk path of the ``idx``-th audio file."""
        return os.path.join(self.root, self.data[idx])

    def __getitem__(self, idx):
        """Load one file and return the extractor's first ``input_values`` entry.

        The returned object is presumably a 1-D numpy array (the collate
        function in this file feeds it to ``torch.from_numpy``) -- TODO confirm.
        """
        wav, _ = librosa.load(self._path(idx), sr=self.sampling_rate)
        features = self.feat_extractor(wav, sampling_rate=self.sampling_rate)
        return features.input_values[0]

    def __len__(self):
        """Return the number of audio files in the dataset."""
        return len(self.data)
class DataModel(pl.LightningDataModule):
    """LightningDataModule serving the VoxPopuli ``train_10`` / ``dev_10`` splits.

    Args:
        batch_size: Batch size handed to each ``DataLoader``. Every process
            that calls ``setup`` builds its own loaders, so under DDP this is
            the per-process (per-GPU) batch size, not the global one.
    """

    def __init__(self, batch_size):
        super().__init__()
        self.bs = batch_size

    def collate_fn(self, batch):
        """Zero-pad variable-length waveforms and stack them batch-first."""
        return pad_sequence([torch.from_numpy(wav) for wav in batch], batch_first=True)

    @staticmethod
    def _read_filenames(pattern):
        """Collect the ``path`` column of every TSV manifest matching ``pattern``.

        Raises:
            FileNotFoundError: If no manifest matches or none lists any file.
        """
        filenames = []
        # glob order is not guaranteed; sort so every DDP rank indexes the
        # dataset identically, which DistributedSampler relies on.
        for manifest in sorted(glob(pattern)):
            filenames.extend(pd.read_csv(manifest, sep='\t')["path"].tolist())
        if not filenames:
            raise FileNotFoundError(f"no audio files listed by manifests matching {pattern!r}")
        return filenames

    def _loader(self, dataset, shuffle):
        """Build a DataLoader, sharding across ranks only when DDP is active."""
        sampler = None
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            # Constructing a DistributedSampler without an initialised process
            # group raises, so only do it when one actually exists.
            sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle)
        return DataLoader(
            dataset,
            batch_size=self.bs,
            sampler=sampler,
            # `shuffle` and `sampler` are mutually exclusive in DataLoader.
            shuffle=shuffle and sampler is None,
            collate_fn=self.collate_fn,
            num_workers=4,
        )

    def setup(self, stage=None):
        """Create the training (shuffled) and validation (ordered) loaders."""
        train = Audios(self._read_filenames("../../datasets/VoxPopuli/*/train_10.tsv"))
        self.train_loader = self._loader(train, shuffle=True)

        val = Audios(self._read_filenames("../../datasets/VoxPopuli/*/dev_10.tsv"))
        self.val_loader = self._loader(val, shuffle=False)

    def train_dataloader(self):
        """Return the loader built in ``setup`` for training."""
        return self.train_loader

    def val_dataloader(self):
        """Return the loader built in ``setup`` for validation."""
        return self.val_loader
class Network(pl.LightningModule):
    """LightningModule wrapping ``Wav2Vec2ForPreTraining``.

    Args:
        config: ``Wav2Vec2Config`` used to build the model.
        lr: Learning rate for AdamW; also the peak (``max_lr``) of the
            OneCycleLR schedule.
        mask_prob: ``mask_prob`` passed to ``_compute_mask_indices``.
        mask_length: ``mask_length`` passed to ``_compute_mask_indices``.
        steps_per_epoch: Fallback steps per epoch for the LR schedule, used
            only when the trainer cannot report its total step count.
        epochs: Fallback epoch count for the LR schedule, used only when the
            trainer reports neither a total step count nor ``max_epochs``.
    """

    def __init__(self, config, lr, mask_prob=0.2, mask_length=2,
                 steps_per_epoch=5241, epochs=50):
        super().__init__()
        self.lr = lr
        self.mask_prob = mask_prob
        self.mask_length = mask_length
        self.fallback_steps_per_epoch = steps_per_epoch
        self.fallback_epochs = epochs
        # Build from the argument, not from a module-level global, so the
        # class also works when imported or constructed elsewhere.
        self.model = Wav2Vec2ForPreTraining(config)

    def forward(self, x, mask):
        """Run the wrapped model on ``x`` with ``mask`` as ``mask_time_indices``."""
        return self.model(x, mask_time_indices=mask)

    def _mask_time_indices(self, batch):
        """Compute time-mask indices for a 2-D ``batch`` of padded waveforms."""
        batch_size, raw_sequence_length = batch.shape
        seq_len = self.model._get_feat_extract_output_lengths(raw_sequence_length)
        # `batch.device` is valid on CPU too, whereas `get_device()` yields -1.
        # NOTE(review): the `device` keyword is only accepted by the
        # transformers release this script was written against -- confirm
        # before upgrading.
        return _compute_mask_indices(
            (batch_size, seq_len),
            mask_prob=self.mask_prob,
            mask_length=self.mask_length,
            device=batch.device,
        )

    def training_step(self, batch, idx):
        """Return the model's pretraining loss for one batch and log it as ``loss``."""
        loss = self(batch, self._mask_time_indices(batch)).loss
        self.log('loss', loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, idx):
        """Log the mean cosine similarity of projected vs. quantized states."""
        outputs = self(batch, self._mask_time_indices(batch))
        cosine_sim = torch.cosine_similarity(
            outputs.projected_states, outputs.projected_quantized_states, dim=-1
        ).mean()
        self.log('similarity', cosine_sim, on_step=True, on_epoch=True,
                 prog_bar=True, sync_dist=True)

    def _total_training_steps(self):
        """Return the number of optimizer steps the LR schedule should span."""
        trainer = self.trainer
        # Available on newer Lightning releases; already accounts for the
        # number of devices, gradient accumulation and max_epochs.
        estimated = getattr(trainer, "estimated_stepping_batches", None)
        if estimated and estimated != float("inf"):
            return int(estimated)
        max_epochs = getattr(trainer, "max_epochs", None)
        if not max_epochs or max_epochs < 0:
            max_epochs = self.fallback_epochs
        return self.fallback_steps_per_epoch * max_epochs

    def configure_optimizers(self):
        """Return AdamW with a OneCycleLR schedule that is stepped every batch."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.lr,
            total_steps=self._total_training_steps(),
            anneal_strategy='linear',
        )
        # OneCycleLR is defined per optimizer step; Lightning's default
        # interval is "epoch", which would leave the schedule almost unmoved.
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }
if __name__ == "__main__":
    # Model hyper-parameters for Wav2Vec2ForPreTraining.
    conf = Wav2Vec2Config(
        vocab_size=16,
        activation_dropout=0.1,
        apply_spec_augment=True,
        architectures=["Wav2Vec2ForPreTraining"],
        attention_dropout=0.1,
        bos_token_id=1,
        classifier_proj_size=256,
        codevector_dim=256,
        contrastive_logits_temperature=0.1,
        conv_bias=False,
        conv_dim=[512, 512, 512, 512, 512, 512, 512],
        conv_kernel=[6, 3, 3, 2, 2, 2, 2],
        conv_stride=[3, 2, 2, 2, 2, 2, 2],
        ctc_loss_reduction="mean",
        ctc_zero_infinity=True,
        diversity_loss_weight=0.1,
        do_stable_layer_norm=False,
        eos_token_id=2,
        feat_extract_activation="gelu",
        feat_extract_dropout=0.0,
        feat_extract_norm="group",
        feat_proj_dropout=0.0,
        feat_quantizer_dropout=0.0,
        final_dropout=0.1,
        hidden_act="gelu",
        hidden_dropout=0.1,
        hidden_dropout_prob=0.1,
        hidden_size=768,
        initializer_range=0.02,
        intermediate_size=3072,
        layer_norm_eps=1e-05,
        layerdrop=0.1,
        mask_feature_length=10,
        mask_feature_prob=0.0,
        mask_time_length=10,
        mask_time_prob=0.05,
        num_attention_heads=12,
        num_codevector_groups=2,
        num_codevectors_per_group=320,
        num_conv_pos_embedding_groups=16,
        num_conv_pos_embeddings=128,
        num_feat_extract_layers=7,
        num_hidden_layers=12,
        num_negatives=100,
        pad_token_id=15,
        proj_codevector_dim=256,
        use_weighted_layer_sum=False,
    )
    model = Network(config=conf, lr=1e-4)

    # Under DDP every process builds its own DataLoader, so this value is the
    # batch size PER GPU, not the global one. Each rank receives a different
    # shard of the data; the effective batch is per_gpu_batch_size multiplied
    # by the number of GPUs. 3 is the largest size that fit on one 16 GB V100;
    # 12 here would put 12 samples on every GPU and run out of memory.
    per_gpu_batch_size = 3
    dm = DataModel(batch_size=per_gpu_batch_size)

    early_stop_callback = EarlyStopping(
        monitor='similarity',
        min_delta=0.00,
        patience=10,
        verbose=False,
        mode='max',
    )
    model_ckpt = ModelCheckpoint(
        monitor='similarity',
        dirpath='./',
        filename='CKPT/wav2vec2/VOX/pretrainig-vox-gpus-{similarity:.2f}',
        mode='max',
    )
    logger = TensorBoardLogger("CKPT/wav2vec2/VOX", name="x7")

    trainer = pl.Trainer(
        gpus=-1,  # use every visible GPU
        max_epochs=200,
        num_sanity_val_steps=1,
        logger=logger,
        callbacks=[early_stop_callback, model_ckpt],
        profiler="simple",
        strategy="ddp",
    )
    trainer.fit(model, dm)

    # With DDP this script runs once per GPU; export the weights from a single
    # process so the ranks do not write the same files concurrently.
    if trainer.is_global_zero:
        model.model.save_pretrained("Vox-Xtra")