How to define training_step when using multiple training dataloaders

Hi,

I would like to implement UDA (Unsupervised Data Augmentation, a paper by Google) with PyTorch Lightning. It already works with native PyTorch, but I wonder how I can make it work with PL. Compared to a regular training loop, UDA consumes a supervised batch and an unsupervised batch in each forward pass. Moreover, the batch sizes are different (I'm using 8 for the supervised batch and 24 for the unsupervised batch), so I have two dataloaders (sup_iter and unsup_iter).
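For context, this is roughly how I build the two dataloaders at the moment; sup_dataset and unsup_dataset are placeholders for my actual labelled and unlabelled datasets:

from torch.utils.data import DataLoader

# Placeholder dataset names; in my real code these are the labelled and unlabelled sets.
sup_iter = DataLoader(sup_dataset, batch_size=8, shuffle=True)     # supervised batches of 8
unsup_iter = DataLoader(unsup_dataset, batch_size=24, shuffle=True)  # unsupervised batches of 24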

The guide explaining how to use multiple training dataloaders is really helpful. However, it doesn't say anything about how the training step should receive the combined batch. Should the docs be updated accordingly?

Code:

from pytorch_lightning.core.lightning import LightningModule
from transformers import BertForSequenceClassification, AdamW
import torch.nn as nn

class LightningUDA(LightningModule):

    def __init__(self):
        super().__init__()
        # selected_classes is defined elsewhere in my script
        self.model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=len(selected_classes))
        self.sup_criterion = nn.CrossEntropyLoss()


    def forward(self, input_ids, attention_mask, token_type_ids):     
        outputs = self.model(input_ids, attention_mask, token_type_ids)
        return outputs

    def training_step(self, batch, batch_idx):
        sup_batch = batch['sup_batch']
        unsup_batch = batch['unsup_batch']

        (...)
        return final_loss

    def train_dataloader(self):
        # taken from https://pytorch-lightning.readthedocs.io/en/stable/advanced/multiple_loaders.html
        loaders = {'sup_batch': sup_iter, 'unsup_batch': unsup_iter}
        return loaders

Is this the correct way to do it? Or should the training_step accept the keys defined in the train_dataloader method?
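
In case it clarifies what I mean, this is how I'm assuming the combined batch arrives in training_step when train_dataloader returns a dict of loaders (just my reading of the docs, not verified; the unpacking only reflects what my own dataloaders happen to yield):

    def training_step(self, batch, batch_idx):
        # Assumption: `batch` is a dict keyed like the dict returned by
        # train_dataloader, holding one batch from each loader per step.
        sup_input_ids, sup_attention_mask, sup_token_type_ids, sup_labels = batch['sup_batch']
        unsup_input_ids, unsup_attention_mask, unsup_token_type_ids = batch['unsup_batch']
        ...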