Proper image logging callback with DDP

I recently switched to Lightning and was wondering what the proper way is to log images during long DDP runs. For example, in image reconstruction or segmentation you often need some model-dependent routine to process the images; for that, my modules often have a model.forward_visualize() function. How would one integrate this in a callback when the model is distributed? I.e.:

class ImageLoggingCallback(Callback):
    def __init__(self, every_n_step):
        self.every_n_step = every_n_step
        self.img = transform(some_image)  # preprocessed input image

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if trainer.global_step % self.every_n_step == 0 and trainer.global_step != 0:
            # like this, for a DDP-wrapped model?
            recon_image = trainer.model.forward_visualize(self.img)


Hey @JakobDexl

Your approach here looks good and should already work. Since you don’t need to sync gradients, you can call the method on the LightningModule directly:

def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
    with torch.no_grad():
        recon_image = pl_module.forward_visualize(self.img)
        # instead of going through the DDP wrapper:
        # recon_image = trainer.model.forward_visualize(self.img)

In addition to that, it is best to log only from rank zero:

if trainer.global_rank == 0:

Note that above, the forward_visualize() call runs on every rank, and by logging only on rank 0 we discard the other ranks’ images. This is fine: all of the processes run in parallel, so no time is wasted, and it is preferred to have all processes do the same amount of work to keep them in sync.
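Putting the pieces together, a minimal sketch of the full callback could look like the following. The forward_visualize() method and the already-transformed input image come from the question above; the add_image() call assumes the TensorBoard logger (other loggers expose a different experiment API), and the import guards are just there so the sketch also runs standalone.

```python
# Fall back to plain classes/contexts if Lightning or torch are absent,
# so this sketch stays importable on its own.
try:
    from lightning.pytorch.callbacks import Callback  # Lightning >= 2.0
except ImportError:
    Callback = object

try:
    import torch
    _no_grad = torch.no_grad
except ImportError:
    from contextlib import nullcontext as _no_grad


class ImageLoggingCallback(Callback):
    def __init__(self, image, every_n_step):
        self.every_n_step = every_n_step
        self.img = image  # an already-transformed image tensor

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        step = trainer.global_step
        if step == 0 or step % self.every_n_step != 0:
            return
        # All ranks run the same forward pass so the processes stay in sync;
        # no gradients are needed for visualization.
        with _no_grad():
            recon = pl_module.forward_visualize(self.img.to(pl_module.device))
        # Only rank zero writes the image to the logger.
        if trainer.global_rank == 0:
            trainer.logger.experiment.add_image("reconstruction", recon, step)
```

The hypothetical "reconstruction" tag and the constructor taking a ready-made image (instead of calling transform() inside __init__) are illustration choices, not part of the Lightning API.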

Let me know if you have any more questions about this.

Hi @awaelchli
Thanks a lot for your clear answer!
