I am incorporating a PyTorch-based model into the PyTorch Lightning (PL) framework for DDP training.
I have a LightningModule:

    class ZfoldLightning(pl.LightningModule):
        def __init__(self, hparams):
            ...
            self.model = XFold(MODEL_PARAM)

which instantiates the XFold model as self.model in its constructor.
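For context, here is a minimal sketch (plain `torch.nn`, no Lightning import) of how the wrapped model's own parameters are placed. `LightningModule` is a subclass of `nn.Module`, so assigning `self.model = XFold(...)` registers XFold as a child module, and moving the wrapper moves every child parameter and buffer with it. `TinyXFold` and `Wrapper` are hypothetical stand-ins for `XFold` and `ZfoldLightning`. The cast uses a dtype rather than a device only so the example runs on CPU; the same `Module.to` mechanism applies to devices. Tensors created on the fly inside `forward` are not covered by this, which is why the `.to(...)` calls inside XFold matter at all.

```python
import torch
from torch import nn


class TinyXFold(nn.Module):
    """Hypothetical stand-in for XFold: one layer plus a registered buffer."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)
        # Buffers are moved/cast together with parameters by Module.to().
        self.register_buffer("offset", torch.zeros(2))

    def forward(self, x):
        return self.linear(x) + self.offset


class Wrapper(nn.Module):
    """Hypothetical stand-in for ZfoldLightning (LightningModule is an nn.Module)."""

    def __init__(self):
        super().__init__()
        # Assigning an nn.Module attribute registers it as a child module.
        self.model = TinyXFold()

    def forward(self, x):
        return self.model(x)


wrapper = Wrapper()
# Casting (or moving) the wrapper also casts (or moves) all child parameters and buffers.
wrapper.to(torch.float64)
print([name for name, _ in wrapper.named_parameters()])
print(wrapper.model.offset.dtype)
out = wrapper(torch.ones(3, 4, dtype=torch.float64))
```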
However, the XFold model contains many explicit device-transfer calls such as b = torch.randn(1).to(a.device), a pattern the PL documentation advises against.
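A small CPU sketch comparing that pattern with two alternatives: `type_as`, which PL recommends, and passing `device=`/`dtype=` at creation time. Here `a` is an assumed stand-in for a tensor that, under DDP, would already live on the current process's GPU. Note that `.to(a.device)` follows whatever device `a` is on, so by itself it is device-agnostic; the visible difference from `type_as` is that it keeps the default dtype, which matters mainly for half or double precision. Calls that name a device explicitly (for example `"cuda:0"`) are the ones that would place tensors on the same GPU in every process.

```python
import torch

# Stand-in for an input tensor; under DDP it would live on this process's GPU.
a = torch.zeros(3, dtype=torch.float64)

# Pattern used in XFold: follows a's device, keeps the default float32 dtype.
b = torch.randn(1).to(a.device)

# PL-recommended pattern: follows a's device *and* dtype.
c = torch.randn(1).type_as(a)

# Allocate directly with the right device/dtype, avoiding an extra copy.
d = torch.randn(1, device=a.device, dtype=a.dtype)

print(b.device, b.dtype)  # cpu torch.float32
print(c.device, c.dtype)  # cpu torch.float64
print(d.device, d.dtype)  # cpu torch.float64
```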
I tried to increase the batch size and train this model on two GPUs, but it does not work: an OOM (out-of-memory) error appears. It turns out that even with DDP, I can only use the same batch size as on a single GPU. I suspect the reason is that all tensors end up on one GPU no matter how many GPUs are used.
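Some background that may be relevant to the OOM, sketched with a toy dataset on CPU: in DDP every process holds a full replica of the model, and the `batch_size` given to the `DataLoader` is per process. Lightning normally attaches a `DistributedSampler` so that each process reads a disjoint shard. Per-GPU memory therefore scales with the per-process batch size just as on a single GPU, while adding GPUs raises the effective batch size per optimizer step. The dataset size, world size, and batch size below are arbitrary assumptions; `num_replicas` and `rank` are passed explicitly so no process group is needed.

```python
from torch.utils.data import DataLoader, DistributedSampler

dataset = list(range(16))  # toy dataset: 16 integer samples
world_size = 2             # e.g. two GPUs, one DDP process each
per_process_batch = 4      # the batch_size each process passes to DataLoader

per_rank = {}
for rank in range(world_size):
    # num_replicas/rank given explicitly, so no process group is initialized here.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=False)
    loader = DataLoader(dataset, batch_size=per_process_batch, sampler=sampler)
    per_rank[rank] = [batch.tolist() for batch in loader]
    print(f"rank {rank}: {per_rank[rank]}")

# Each optimizer step consumes one batch from every process.
effective_batch = per_process_batch * world_size
print("effective batch size:", effective_batch)
```

With two GPUs and a per-process batch_size of 4, each step processes 8 samples in total, so keeping the single-GPU batch_size already doubles the global batch.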
One solution is to refactor the device-transfer code to use the recommended pattern a.type_as(b), but there is too much code to refactor.
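If refactoring is still needed, a rough heuristic scan can narrow the work to calls that name a device explicitly, leaving `.to(x.device)` and similar device-following code alone. The patterns and the `find_hardcoded_devices` helper below are hypothetical and not exhaustive. Bare `"cuda"` and `.cuda()` resolve to each process's current CUDA device, so they are flagged for review rather than as definite bugs, whereas an explicit index such as `"cuda:0"` pins tensors to the first GPU in every process.

```python
import re
from typing import List, Tuple

# Heuristic patterns (assumed, not exhaustive) for device-specific placement.
HARDCODED_DEVICE_PATTERNS = [
    r"\.cuda\(",                 # tensor.cuda() / module.cuda()
    r"[\"']cuda(?::\d+)?[\"']",  # "cuda" or "cuda:0" literals, e.g. torch.device("cuda")
]
HARDCODED_DEVICE = re.compile("|".join(HARDCODED_DEVICE_PATTERNS))


def find_hardcoded_devices(source: str) -> List[Tuple[int, str]]:
    """Return (line_number, stripped_line) pairs that look like explicit device placement."""
    hits = []
    for lineno, line in enumerate(source.splitlines(), start=1):
        if HARDCODED_DEVICE.search(line):
            hits.append((lineno, line.strip()))
    return hits


example = '''
b = torch.randn(1).to(a.device)
mask = torch.ones(4).cuda()
pos = torch.arange(10, device="cuda:0")
dev = torch.device("cuda")
'''

# The .to(a.device) line is not reported; the other three are.
for lineno, line in find_hardcoded_devices(example):
    print(lineno, line)
```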
Is there a better solution?