Hi all,
My code does a bunch of stuff, then trains multiple neural nets in succession. Sometimes one of the models fails to train due to an OOM error, so I wrapped trainer.fit() in a try…except block, like so:
# <SOME CODE, MOSTLY ON CPU>

# Training:
for model in models:
    try:
        trainer.fit(model)
        # Save the best validation accuracy reported by the trainer:
        with open(os.path.join(metrics_path, "metric.txt"), "w") as f:
            f.write(str(trainer.callback_metrics["best_val_acc"].item()))
        return 1
    except RuntimeError:
        print("RuntimeError in trainer.fit(). The model probably exceeded the GPU's memory.")
        return 0

# <DO SOME STUFF WITH THE TRAINED MODEL>
# <SOME MORE CODE>
The bits of code before and after the loop are mostly executed on CPU.
For this training part, I’m using ddp_spawn on 3 GPUs. I tried plain ddp first, but the bits of code before and after NN training were executed 3 times as well. As far as I understand, ddp re-launches the entire script once per GPU, so everything outside trainer.fit() runs in every process. So I guess that’s problem number 1, and if anyone has a solution to that (the best I’ve come up with is the rank-zero guard sketched below), that would be great.
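For reference, here is a minimal sketch of the rank-zero guard I mean, under plain ddp. heavy_cpu_preprocessing() and postprocess_results() are hypothetical stand-ins for my CPU-only code, and the Trainer arguments just mirror my setup:

import os

import pytorch_lightning as pl

def main():
    # With strategy="ddp", Lightning re-launches this whole script once per
    # GPU, so everything here runs in every process. Lightning sets LOCAL_RANK
    # in the processes it launches, so I can guard the CPU-only parts:
    if os.environ.get("LOCAL_RANK", "0") == "0":
        heavy_cpu_preprocessing()  # hypothetical stand-in for my CPU code

    trainer = pl.Trainer(accelerator="gpu", devices=3, strategy="ddp")
    trainer.fit(model)  # model comes from elsewhere in my script

    # Only post-process once, on the main process:
    if trainer.is_global_zero:
        postprocess_results()  # hypothetical stand-in for my CPU code

if __name__ == "__main__":
    main()

That still feels fragile, though: the non-zero ranks would also need whatever the preprocessing produces, so I’d probably have to write it to disk and add a barrier before fit().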
For now, however, I found that ddp_spawn does what I need, albeit more slowly, as the documentation warns. My issue (problem number 2!) is that whenever a model causes a RuntimeError, the whole program stops instead of just printing the warning message and moving on to the next model. I suspect this is due to the multiprocessing nature of ddp_spawn (see my guess in the sketch below), but I don’t know how to solve it. I saw some suggestions online to use torch.multiprocessing to spawn processes manually, but that defeats the point of using PL and I really don’t want to go there…
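My working theory, which I haven’t been able to confirm, is that the OOM happens inside a spawned worker and reaches my main process wrapped in a different exception type, so my except RuntimeError never matches. If that’s right, broadening the except clause might be enough. Something like this sketch, where ProcessRaisedException is my guess at the wrapper type (it exists in recent torch versions, but I’m not sure it’s what ddp_spawn actually raises):

from torch.multiprocessing.spawn import ProcessRaisedException

for model in models:
    try:
        trainer.fit(model)
    except (RuntimeError, ProcessRaisedException) as e:
        # Assumption: with ddp_spawn, an OOM in a worker may surface here
        # as ProcessRaisedException rather than as a plain RuntimeError.
        print(f"trainer.fit() failed, skipping this model: {e}")
        continue
    # <DO SOME STUFF WITH THE TRAINED MODEL>

Does that sound plausible, or is the spawn context torn down before the exception can propagate at all?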
Thanks in advance!