Try... except statement with DDPSpawn

Hi all,

My code does a bunch of stuff, then trains multiple neural nets in succession. Sometimes one of the models fails to train because of an OOM error, so I wrapped trainer.fit() in a try…except block, like so:

# <SOME CODE, MOSTLY ON CPU>
# Training:
for model in models:
    try:
        trainer.fit(model)
        with open(os.path.join(metrics_path, "metric.txt"), "w") as f:
            f.write(str(trainer.callback_metrics["best_val_acc"].item()))
        return 1
    except RuntimeError:
        print("RuntimeError in trainer.fit(). The model probably exceeded the GPU's memory.")
        return 0
    # <DO SOME STUFF WITH THE TRAINED MODEL>

# <SOME MORE CODE>

The bits of code before and after the loop are mostly executed on CPU.

For the training part, I’m using ddp_spawn on 3 GPUs. I tried ddp, but then the bits of code before and after NN training were executed 3 times as well… So I guess that’s problem number 1, and if anyone has a solution to that, that would be great.
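The only workaround I can think of for that is to guard the CPU-only sections so that only rank 0 executes them, but I’m not sure it’s the right approach. A rough, untested sketch (prepare_inputs and postprocess_results are just placeholders for my actual pre- and post-processing code):

from pytorch_lightning.utilities import rank_zero_only

@rank_zero_only
def prepare_inputs():
    # <SOME CODE, MOSTLY ON CPU>: should run only on global rank 0 under ddp
    ...

@rank_zero_only
def postprocess_results():
    # <SOME MORE CODE>: likewise restricted to global rank 0
    ...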
For now, however, I’ve found that ddp_spawn does what I need, albeit more slowly according to the documentation. My issue (problem number 2!) is that whenever a model causes a RuntimeError, the whole program stops instead of just printing the warning message and moving on to the next model. I suspect this is due to the multiprocessing nature of ddp_spawn, but I don’t know how to solve it.

I saw some suggestions online to use torch.multiprocessing to spawn the processes manually, but that defeats the point of using PL and I really don’t want to go there…
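In case it helps pin down problem 2: my guess is that the OOM happens inside one of the spawned worker processes and is re-raised in the main process as a different exception type, so my except RuntimeError never matches. A rough sketch of what I mean (assuming a recent PyTorch, where a worker failure is re-raised as torch.multiprocessing.spawn.ProcessRaisedException; I haven’t verified that this is what actually happens):

from torch.multiprocessing.spawn import ProcessRaisedException

for model in models:
    try:
        trainer.fit(model)
    except (RuntimeError, ProcessRaisedException) as err:
        # Catch both a direct OOM and a failure re-raised from a spawned worker,
        # then skip this model and move on to the next one.
        print(f"trainer.fit() failed, probably GPU OOM: {err}")
        continue
    # <DO SOME STUFF WITH THE TRAINED MODEL>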

Thanks in advance!

Thanks for releasing the post from the spam filter. The problem has been solved on Slack in the meantime. Sorry for cluttering the forum.

Sorry that your post was auto-flagged, and glad you were able to solve the issue in the meantime. I hope this doesn’t discourage you from using the forum in the future. The team and I will be around to help here 🙂
