I wanted to write a custom DataModule, but I got the error `'TranslationDataModule' object has no attribute 'on_init_start'`.
Here is my DataModule definition:
Can anybody please help? I have used DataModules before and never ran into this error.
import pytorch_lightning as pl
import torch
from torchvision import transforms
class TranslationDataModule(pl.LightningDataModule):
    """LightningDataModule wrapping a ``ParallelDataset`` for training.

    Only a training split is provided; ``setup`` builds it and
    ``train_dataloader`` serves it in shuffled, padded batches.

    Args:
        train_path: Path passed straight to ``ParallelDataset``.
        batch_size: Samples per training batch (default 4, as before).
        num_workers: DataLoader worker processes (default 0, i.e. load in
            the main process, as before).
    """

    def __init__(self, train_path, batch_size=4, num_workers=0):
        super().__init__()
        self.train_path = train_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        # NOTE(review): TokenizeAndNumericalise, n_en and n_te are defined
        # elsewhere (presumably the source/target vocabularies) — confirm.
        self.transform = transforms.Compose([TokenizeAndNumericalise(n_en, n_te)])

    def prepare_data(self):
        """Instantiate the dataset once for its side effects only.

        The original comment says this step downloads the dataset; the
        constructed object is discarded, so no state is stored here.
        """
        ParallelDataset(self.train_path, transform=self.transform)

    def setup(self, stage=None):
        """Build the training dataset when fitting (or when no stage is given)."""
        if stage in ("fit", None):
            self.train_dataset = ParallelDataset(self.train_path, transform=self.transform)

    def train_dataloader(self):
        """Return a shuffled DataLoader over the training dataset.

        ``collate_fn_padd`` is defined elsewhere; presumably it pads
        variable-length sequences into a batch — TODO confirm.
        """
        return torch.utils.data.DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            collate_fn=collate_fn_padd,
        )