Hi Folks,
I recently moved my code from PyTorch to PyTorch Lightning to make it easier to implement distributed data parallel (DDP) training.
As per the title, my validation loss explodes to the order of 10^10 when using an L1 loss; this happens around epoch 20. I am confused because I am also reporting PSNR on my validation dataset, and those values are not great but acceptable (~30 dB). My model is a U-Net (encoder-decoder) with ~58M parameters. The training loss does not show this behaviour.
Is it possible that it is overfitting to my training data? I have previously trained models a fraction of this size without issue, and they achieved better results. My validation set is a subset of my training set. I've also run an integrity check on my validation images, and nothing is corrupted.
Below is my Lightning code:
class LightningResunet(pl.LightningModule):
    """LightningModule wrapping a ``resunetplus`` denoiser trained with an L1 loss.

    Training batches are unpacked as ``(noisy, clean, _, _)``; validation
    batches are unpacked as ``(noisy, clean)``. Validation additionally reports
    PSNR through the external ``test_psnr`` helper.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): the meaning of the `3` argument is defined by
        # resunetplus (presumably the number of image channels) — confirm.
        self.model = resunetplus(3)

    def forward(self, x):
        """Run the wrapped denoiser on ``x`` and return its output."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the L1 loss between the denoised and the clean images.

        Args:
            batch: sequence of four items; only the first two (noisy image,
                clean image) are used.
            batch_idx: index of the batch (unused).

        Returns:
            The scalar L1 loss tensor that Lightning back-propagates.
        """
        noise_img, clean_img, _, _ = batch
        # Route through self(...) so the module's own forward() is used,
        # consistent with any hooks registered on this LightningModule.
        denoised_img = self(noise_img)
        loss = F.l1_loss(denoised_img, clean_img)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer and a MultiStepLR schedule.

        The learning rate is multiplied by 0.1 at each milestone. With this
        return form, Lightning steps the scheduler once per epoch by default,
        so the first drop (1e-4 -> 1e-5) happens at epoch 20.
        NOTE(review): that is the same epoch at which the validation loss is
        reported to blow up — worth checking whether the two are related.
        """
        optimizer = torch.optim.AdamW(self.parameters(), lr=0.0001)
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[20, 150, 500], gamma=0.1)
        return [optimizer], [scheduler]

    def validation_step(self, batch, batch_idx):
        """Compute L1 loss and PSNR for one validation batch and log both.

        Returns:
            dict with keys ``'val_loss'`` and ``'val_psnr'``, consumed by
            :meth:`validation_epoch_end`.
        """
        noisy, clean = batch
        denoised = self(noisy)
        # The L1 loss is taken on the raw, unclamped model output.
        # NOTE(review): if test_psnr clamps its inputs to the valid pixel
        # range, a few extreme output values can inflate the L1 mean
        # enormously while PSNR still looks reasonable — verify test_psnr.
        val_loss = F.l1_loss(denoised, clean)
        psnr = test_psnr(clean, denoised)
        self.log("validation_loss", val_loss, on_epoch=True, on_step=False, sync_dist=True, prog_bar=True, logger=True)
        self.log("validation_psnr", psnr, on_epoch=True, on_step=False, sync_dist=True, prog_bar=True, logger=True)
        return {'val_loss': val_loss, 'val_psnr': psnr}

    def validation_epoch_end(self, outputs):
        """Log the unweighted mean of the per-batch validation loss and PSNR.

        Args:
            outputs: list of the dicts returned by :meth:`validation_step`.
        """
        if not outputs:
            # torch.stack raises on an empty list; there is nothing to average.
            return
        avg_loss = torch.stack([torch.as_tensor(x['val_loss']) for x in outputs]).mean()
        avg_psnr = torch.stack([torch.as_tensor(x['val_psnr']) for x in outputs]).mean()
        self.log('avg_val_loss', avg_loss, on_epoch=True, sync_dist=True, prog_bar=True)
        self.log('avg_val_psnr', avg_psnr, on_epoch=True, sync_dist=True, prog_bar=True)
if __name__ == '__main__':
    # Entry point: build the datasets and loaders, configure checkpointing,
    # and launch (resumed) 4-GPU DDP training.
    torch.set_printoptions(linewidth=120)
    now = datetime.now()
    current_time = now.strftime("%H_%M_%S")
    # The run directory is keyed on the time of day only; os.mkdir raises
    # FileExistsError if it already exists (or if the parent is missing).
    path = "/home/bledc/denoiser/models/Apr14_singleconnection_{}".format(current_time)
    os.mkdir(path)

    train_dataset = Syn_noisemaps('/home/bledc/datasets/Syn_train/', 800, 96) + Real('/home/bledc/datasets/SIDD_train/', 320, 96) \
        + Syn_noisemaps('/home/bledc/datasets/mit_all/', 3500, 96) + just_gaussian('/home/bledc/datasets/gaussian/', 500, 96)
    train_size = int(0.95 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    # NOTE(review): the held-out 5% split is discarded here. Validation
    # actually runs on test_set (see val_dataloaders below), which is not a
    # subset of the training data.
    train_dataset, _ = torch.utils.data.random_split(train_dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42))

    test_set = test_my_mixed_set("/home/bledc/datasets/my_test_set/", 96, patch_size=96)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=1,
        shuffle=False, num_workers=8,
        pin_memory=True, drop_last=True)
    data_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=16,
        shuffle=True, num_workers=8,
        pin_memory=True, drop_last=False)

    # NOTE(review): this checkpoint comes from the "noconnections" run while
    # this run is named "singleconnection" — confirm the architectures match.
    # Resuming also restores the scheduler state; at epoch 2657 every
    # MultiStepLR milestone (20, 150, 500) has already passed.
    chk_path = "/home/bledc/denoiser/models/Apr14_noconnections_resunet_02_30_56/model_E_epoch=2657-validation_psnr=30.16.ckpt"
    checkpoint_callback = ModelCheckpoint(
        save_top_k=10,
        monitor="validation_psnr",
        mode="max",
        dirpath=path,
        filename="model_E_{epoch:02d}-{validation_psnr:.2f}")
    # (A stray, unbalanced fragment of an earlier Trainer call that made this
    # script a SyntaxError was removed here.)
    trainer = pl.Trainer(max_epochs=5000, gpus=4, strategy="ddp", callbacks=[checkpoint_callback], accumulate_grad_batches=2, gradient_clip_val=0.2, resume_from_checkpoint=chk_path)
    model = LightningResunet()
    trainer.fit(model, train_dataloaders=data_loader,
                val_dataloaders=test_loader)
Any suggestions are welcome! Thanks folks!