Unexpected keyword argument 'multiprocessing_context'

Hi, I'm seeing an "unexpected keyword argument 'multiprocessing_context'" error when calling `trainer.fit(model)`.

I'm using pytorch-lightning 1.2.7 with a fairly standard dataset object. The model is as follows:

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import DataLoader

# SModel and create_dataset() are defined elsewhere in my project.


class SimpleModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = SModel()

    def forward(self, x):
        # Inference: only x is given, no labels
        return self.model(x)

    def train_dataloader(self):
        train_loader = DataLoader(create_dataset(), num_workers=20, batch_size=256)
        return train_loader

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x)

        loss = F.cross_entropy(logits, y)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=1e-3)
        return optimizer
```

Main code:

```python
from pytorch_lightning.plugins import DDPPlugin


def main():
    model = SimpleModel()

    n_gpus = torch.cuda.device_count()
    print(f"Num GPUs: {n_gpus}")
    trainer = pl.Trainer(
        gpus=n_gpus,
        precision=16,
        accelerator='ddp',
        plugins=DDPPlugin(find_unused_parameters=False),
    )
    trainer.fit(model)


if __name__ == '__main__':
    main()
```

And this is the traceback:

```
packages/pytorch_lightning/trainer/data_loading.py", line 114, in auto_add_sampler
    dataloader = self.replace_sampler(dataloader, sampler)
  File "/home/users/industry/shared_pylibs/lib/python3.8/site-packages/pytorch_lightning/trainer/data_loading.py", line 181, in replace_sampler
    dataloader = type(dataloader)(**dl_args)
TypeError: intercept_args() got an unexpected keyword argument 'multiprocessing_context'
```

It appears to happen when DDP replaces the sampler with a `DistributedSampler`: `replace_sampler` rebuilds the loader with `type(dataloader)(**dl_args)`, and that constructor call ends up in `intercept_args()`, which rejects the `multiprocessing_context` argument. What I can't figure out is what exactly should be modified. I suspect the dataset or the dataloader, but the dataset object (returned by `create_dataset()`) is just a regular `Dataset` subclass that overrides `__getitem__()` and `__len__()`. Any ideas on how to solve this? Thanks!
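To make the failure mode concrete, here is a stdlib-only mimic of what the traceback shows: the loader's stored attributes are collected as keyword arguments and its class is called again with a new sampler. This is a simplified sketch, not Lightning's actual code. `rebuild_with_sampler`, `PlainLoader`, `WrappedLoader` and the module-level `intercept_args` are all hypothetical stand-ins, the attribute-collection rule is an assumption, and the string `"dist"` is a placeholder for a real `DistributedSampler`.

```python
def rebuild_with_sampler(loader, sampler):
    """Hypothetical, simplified mimic of a sampler swap: gather the loader's
    public instance attributes and call its class again with them."""
    skip = {"sampler", "batch_sampler"}
    kwargs = {
        k: v for k, v in vars(loader).items()
        if not k.startswith("_") and k not in skip
    }
    kwargs["sampler"] = sampler
    kwargs["shuffle"] = False
    return type(loader)(**kwargs)


class PlainLoader:
    """Stand-in loader whose __init__ accepts every attribute it stores."""

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
                 num_workers=0, multiprocessing_context=None):
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.sampler = sampler
        self.num_workers = num_workers
        self.multiprocessing_context = multiprocessing_context


def intercept_args(self, dataset, batch_size=1, shuffle=False, sampler=None,
                   num_workers=0):
    """Hypothetical narrower constructor: it has no multiprocessing_context
    parameter, yet the instance still ends up with that attribute."""
    PlainLoader.__init__(self, dataset, batch_size, shuffle, sampler, num_workers)


class WrappedLoader(PlainLoader):
    # The class's __init__ has been swapped for the narrower function.
    __init__ = intercept_args


if __name__ == "__main__":
    data = list(range(10))

    ok = rebuild_with_sampler(PlainLoader(data, batch_size=4), sampler="dist")
    print(type(ok).__name__, ok.sampler)  # prints: PlainLoader dist

    try:
        rebuild_with_sampler(WrappedLoader(data, batch_size=4), sampler="dist")
    except TypeError as err:
        # The rebuilt call passes multiprocessing_context=..., which the
        # narrower constructor does not accept.
        print("TypeError:", err)
```

The point of the sketch is that the rebuild only works if the class's constructor accepts every attribute that gets collected. Comparing `inspect.signature(type(train_loader).__init__)` against the loader's attributes is one way to check whether the same mismatch applies to the real loader.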