I have the PyTorch Lightning module below, in which `train_dataloader` returns two train dataloaders:
class CustomModel(pl.LightningModule):
    """Lightning module that trains from two separate dataloaders.

    ``train_dataloader`` returns a tuple of two ``DataLoader``s — one over
    ``self.train_dataset`` and one over ``self.mbuffer_dataset``.
    """

    def __init__(self, **kwargs):
        # FIX: the original body was `pass`, which overrides
        # LightningModule.__init__ without calling it and leaves Lightning's
        # internal state uninitialized. Call super().__init__() explicitly.
        # **kwargs is accepted but not used here — presumably consumed by a
        # fuller __init__ elsewhere; TODO confirm.
        super().__init__()

    def train_dataloader(self):
        """Return both training dataloaders as a tuple.

        Returns:
            tuple[DataLoader, DataLoader]: a loader over ``self.train_dataset``
            and a loader over ``self.mbuffer_dataset``.
        """
        # NOTE(review): collate_fn_0() / collate_fn_1() are *called* here, so
        # they look like factories that return the actual collate callables —
        # confirm that is intended (vs. passing the functions themselves).
        # self.train_dataset / self.mbuffer_dataset are not assigned in the
        # visible __init__; they must be set elsewhere — TODO confirm.
        return (
            DataLoader(
                self.train_dataset,
                collate_fn=collate_fn_0(),
            ),
            DataLoader(
                self.mbuffer_dataset,
                collate_fn=collate_fn_1(),
            ),
        )
# Standard Lightning entry point: construct the trainer, then fit the model.
runner = Trainer(**kwargs)
model = CustomModel()
runner.fit(model)
and I would like to use it in the following fashion:
# Manual training loop over both dataloaders, one after the other per epoch.
# FIX: `train_dataloaders` was referenced but never defined — obtain the
# (loader_0, loader_1) tuple from the module's train_dataloader() first.
train_dataloaders = model.train_dataloader()
for epoch in range(num_epochs):
    ...  # per-epoch setup code
    # First pass: batches from the train_dataset loader.
    for batch in train_dataloaders[0]:
        outputs = model(batch)
        # do something
    # Second pass: batches from the mbuffer_dataset loader.
    for batch in train_dataloaders[1]:
        outputs = model(batch)
How can I do it? (Note: when `train_dataloader` returns multiple loaders, Lightning's `Trainer.fit` combines them and delivers a batch from each per step; iterating them sequentially as above requires a manual loop or Lightning's `CombinedLoader` with `mode="sequential"`.)