Hi all.
I am trying Lightning since it makes using DDP much simpler than plain PyTorch. However, when I launch training I get out-of-memory errors with the same batch size I used before. Does Lightning use more memory than vanilla PyTorch, or am I missing something?
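If it is useful, this is the kind of check I could run on each process to compare the two setups (a minimal sketch using only standard torch.cuda calls; report_gpu_memory is just a name I made up):

import torch

def report_gpu_memory(tag):
    # Current and peak allocation on this process's GPU, in MiB.
    allocated = torch.cuda.memory_allocated() / 2**20
    peak = torch.cuda.max_memory_allocated() / 2**20
    print(f"[{tag}] allocated={allocated:.0f} MiB, peak={peak:.0f} MiB")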
Below is my code:
from datetime import datetime

import torch
import torch.nn.functional as F
import pytorch_lightning as pl

# resunext, Syn_noisemaps, Real, just_gaussian and test_my_mixed_set
# come from my own modules.


class LightningResunet(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = resunext(8, 29, 10, 64)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        noise_img, clean_img, _, _ = batch
        denoised_img = self.model(noise_img)
        loss = F.l1_loss(denoised_img, clean_img)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[20, 200, 500], gamma=0.1)
        return [optimizer], [scheduler]

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.model(x)
        loss = F.l1_loss(y_hat, y)
        self.log("validation_loss", loss, on_step=True, on_epoch=True, sync_dist=True)
if __name__ == '__main__':
    torch.set_printoptions(linewidth=120)
    now = datetime.now()
    current_time = now.strftime("%H_%M_%S")
    path = "/home/bledc/my_remote_folder/denoiser/models/Apr4_resunet_custom_notpretrain_continue1_{}".format(current_time)
    text_path = path + "/" + current_time + ".txt"
    train_dataset = (Syn_noisemaps('/home/bledc/dataset/syn_train2/Syn_train/', 800, 128)
                     + Real('/home/bledc/dataset/SIDD_train/', 320, 128)
                     + Syn_noisemaps('/home/bledc/dataset/MIT/MIT/MIT/mit_all/', 3500, 128)
                     + just_gaussian('/home/bledc/dataset/MIT/MIT/MIT/gaussian/', 500, 128))
    train_size = int(0.95 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_dataset, _ = torch.utils.data.random_split(
        train_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42))
    test_set = test_my_mixed_set("/home/bledc/dataset/my_test_set", 128, patch_size=128)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=1,
        shuffle=False, num_workers=8,
        pin_memory=True, drop_last=True)
    data_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=24,
        shuffle=True, num_workers=8,
        pin_memory=True, drop_last=True)
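    # Note: my understanding is that with strategy="ddp" each of the 4
    # processes builds its own copy of this DataLoader, so batch_size=24
    # means 24 samples per GPU (96 per optimizer step in total).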
    trainer = pl.Trainer(max_epochs=1000, gpus=4, strategy="ddp")
    model = LightningResunet()
    trainer.fit(model, train_dataloaders=data_loader,
                val_dataloaders=test_loader)
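In case it helps with diagnosing this, I could also log the peak memory on each rank from inside the LightningModule, roughly like this (a sketch using only standard torch.cuda and Lightning logging calls):

def on_train_epoch_end(self):
    # Peak allocation on this rank's GPU since the last reset, in MiB.
    peak_mib = torch.cuda.max_memory_allocated() / 2**20
    self.log("peak_gpu_mem_mib", peak_mib, sync_dist=True)
    torch.cuda.reset_peak_memory_stats()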
Thank you!