Hi,
I would like to implement UDA (Unsupervised Data Augmentation, a paper by Google) with PyTorch Lightning. It already works with native PyTorch, but I wonder how I can make it work with PL. Compared to a regular training loop, UDA takes in a supervised batch and an unsupervised batch in each forward pass. Moreover, the batch sizes are different (I'm using 8 for the supervised batch and 24 for the unsupervised batch). So I have two dataloaders (sup_iter and unsup_iter).
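For context, the two loaders are created roughly like this (a simplified sketch; labeled_dataset and unlabeled_dataset are placeholders for my actual datasets):
from torch.utils.data import DataLoader

# Sketch only -- labeled_dataset and unlabeled_dataset stand in for my real datasets
sup_iter = DataLoader(labeled_dataset, batch_size=8, shuffle=True)      # supervised batches
unsup_iter = DataLoader(unlabeled_dataset, batch_size=24, shuffle=True)  # unsupervised batches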
The guide explaining how to use multiple training dataloaders is really helpful. However, it doesn't say anything about how the combined batch arrives in the training step. Should the docs be updated accordingly?
Code:
from pytorch_lightning.core.lightning import LightningModule
from transformers import BertForSequenceClassification, AdamW
import torch.nn as nn
class LightningUDA(LightningModule):
    def __init__(self):
        super().__init__()
        # selected_classes is defined elsewhere in my script
        self.model = BertForSequenceClassification.from_pretrained(
            "bert-base-uncased", num_labels=len(selected_classes)
        )
        self.sup_criterion = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, token_type_ids):
        outputs = self.model(input_ids, attention_mask, token_type_ids)
        return outputs

    def training_step(self, batch, batch_idx):
        sup_batch = batch['sup_batch']
        unsup_batch = batch['unsup_batch']
        (...)
        return final_loss

    def train_dataloader(self):
        # taken from https://pytorch-lightning.readthedocs.io/en/stable/advanced/multiple_loaders.html
        loaders = {'sup_batch': sup_iter, 'unsup_batch': unsup_iter}
        return loaders
Is this the correct way to do it? Or should the training_step accept the keys defined in the train_dataloader method?
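For reference, the part I elided with (...) roughly computes the UDA objective: a supervised cross-entropy loss on the labeled batch plus a consistency (KL divergence) loss between the model's predictions on the original and augmented unlabeled examples. Below is a simplified sketch of that method, not my exact code; the batch field names, lambda_u, and the handling of the model output are placeholders:
import torch
import torch.nn.functional as F

# Sketch of training_step inside LightningUDA (placeholder names, not my exact code)
def training_step(self, batch, batch_idx):
    sup_batch = batch['sup_batch']
    unsup_batch = batch['unsup_batch']

    # Supervised loss: cross-entropy on the labeled batch (batch size 8)
    sup_logits = self(sup_batch['input_ids'],
                      sup_batch['attention_mask'],
                      sup_batch['token_type_ids']).logits  # outputs[0] on older transformers versions
    sup_loss = self.sup_criterion(sup_logits, sup_batch['labels'])

    # Consistency loss on the unlabeled batch (batch size 24): KL divergence between
    # the prediction on the original example (no gradient) and on its augmented version
    with torch.no_grad():
        orig_logits = self(unsup_batch['orig_input_ids'],
                           unsup_batch['orig_attention_mask'],
                           unsup_batch['orig_token_type_ids']).logits
        orig_probs = F.softmax(orig_logits, dim=-1)
    aug_logits = self(unsup_batch['aug_input_ids'],
                      unsup_batch['aug_attention_mask'],
                      unsup_batch['aug_token_type_ids']).logits
    unsup_loss = F.kl_div(F.log_softmax(aug_logits, dim=-1), orig_probs, reduction='batchmean')

    lambda_u = 1.0  # placeholder weight for the consistency term
    final_loss = sup_loss + lambda_u * unsup_loss
    return final_loss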