I have a pretrained torch.nn.Module that my LightningModule uses during training. For the sake of example, assume it is a pretrained, frozen ResNet image model that I use for feature extraction. What is the best way to use such a module from my LightningModule?
Option 1: store it as a child module

```python
self.resnet = ResNet()
```

This registers the ResNet's parameters as part of the LightningModule, which inflates the checkpoint size. It also prevents a single ResNet instance from being shared by multiple modules, which is a serious problem for huge models.
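To make the checkpoint-bloat point concrete, here is a minimal torch-only sketch (`FakeResNet` is a hypothetical tiny stand-in for the real pretrained backbone):

```python
from torch import nn

# Hypothetical stand-in for the pretrained backbone, kept tiny for illustration.
class FakeResNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

# Option 1: assigning the backbone to an attribute registers it as a child
# module, so its parameters appear in state_dict() and hence in every checkpoint.
class WrapperOption1(nn.Module):
    def __init__(self):
        super().__init__()
        self.resnet = FakeResNet()
        self.head = nn.Linear(2, 1)

keys = list(WrapperOption1().state_dict().keys())
# keys contain 'resnet.fc.weight' and 'resnet.fc.bias' alongside the head's own weights
```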
Option 2: pass the pretrained model as a constructor argument

```python
class MyModel(LightningModule):
    def __init__(self, resnet: ResNet):
        super().__init__()
        self._resnet = [resnet]  # wrapped in a list so it is not registered as a submodule
```

With this approach the ResNet is not really owned by the LightningModule; it is only stored as a reference. That allows model sharing and keeps the weights out of the checkpoint. But it breaks device management: I have to move the ResNet manually via .cuda(), and the problem is even worse when training on multiple GPUs.
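Here is a self-contained sketch of that failure mode (`FakeResNet` is a hypothetical stand-in; a dtype cast is used in place of a GPU move so it runs without a GPU):

```python
import torch
from torch import nn

# Hypothetical stand-in for the pretrained backbone.
class FakeResNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

# Option 2: keep only a reference. Wrapping the module in a plain list hides
# it from nn.Module's attribute registration, so it is neither checkpointed
# nor moved by .to()/.cuda().
class WrapperOption2(nn.Module):
    def __init__(self, resnet: nn.Module):
        super().__init__()
        self._resnet = [resnet]
        self.head = nn.Linear(2, 1)

backbone = FakeResNet()
wrapper = WrapperOption2(backbone)
wrapper.to(torch.float64)  # stands in for a device move

# The backbone stays out of the checkpoint...
in_checkpoint = any(k.startswith("_resnet") for k in wrapper.state_dict())
# ...but .to() left it behind: the head is float64, the backbone is still float32.
```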
Better Option?
Is there a better option: one that stores the model as a regular attribute, so devices are managed automatically, but that does not track its weights in the checkpoint?
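For concreteness, here is a torch-only sketch of the behavior I am after (`FakeResNet` is again a hypothetical stand-in): register the backbone as a child so that .to()/.cuda() move it, but filter its keys out of state_dict() so they never reach the checkpoint.

```python
import torch
from torch import nn

# Hypothetical stand-in for the pretrained backbone.
class FakeResNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

class WrapperOption3(nn.Module):
    def __init__(self, resnet: nn.Module):
        super().__init__()
        # Registered as a child module, so .to()/.cuda() move it automatically.
        self.resnet = resnet.eval().requires_grad_(False)
        self.head = nn.Linear(2, 1)

    def state_dict(self, *args, **kwargs):
        # Drop the frozen backbone's weights before they reach a checkpoint.
        sd = super().state_dict(*args, **kwargs)
        return {k: v for k, v in sd.items() if not k.startswith("resnet.")}

wrapper = WrapperOption3(FakeResNet())
wrapper.to(torch.float64)  # the registered backbone follows the cast/move
# state_dict() now contains only the head's weights.
```

Of course, loading such a checkpoint would then need load_state_dict(..., strict=False), and I am not sure how well this plays with Lightning's own checkpointing hooks, which is exactly why I am asking whether there is a cleaner, supported way.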
Thank you!