To give you a more concrete example, consider a LightningModule implemented with a setup block:
import pytorch_lightning as ptl
import torch.nn as nn


class MyNet(ptl.LightningModule):
    def __init__(self, conv1_width=6, conv2_width=16,
                 fc1_width=120, fc2_width=84,
                 dropout1=0.5, dropout2=0.5,
                 learning_rate=1e-3, **kwargs):
        super().__init__()
        self.conv1_width = conv1_width
        self.conv2_width = conv2_width
        self.fc1_width = fc1_width
        self.fc2_width = fc2_width
        self.dropout1 = dropout1
        self.dropout2 = dropout2
        self.learning_rate = learning_rate
        self.unused_kwargs = kwargs
        self.save_hyperparameters()

    def setup(self, step):
        # layers are created here rather than in __init__
        self.conv1 = nn.Conv2d(3, self.conv1_width, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(self.conv1_width, self.conv2_width, 5)
        self.fc1 = nn.Linear(self.conv2_width * 5 * 5, self.fc1_width)
        self.drop1 = nn.Dropout(p=self.dropout1)
        self.fc2 = nn.Linear(self.fc1_width, self.fc2_width)
        self.drop2 = nn.Dropout(p=self.dropout2)
        self.fc3 = nn.Linear(self.fc2_width, 10)
        self.criterion = nn.CrossEntropyLoss()
Then, running:
model = MyNet()
data = MyDataModule(data_dir='/tmp/pytorch-example/cifar-10-data')
trainer = Trainer(gpus=1, max_epochs=1)
trainer.fit(model, data)
will train a model. However, when I try to save and load the model like this:
trainer.save_checkpoint('/tmp/pytorch-example/model.ckpt')
new_model = cifar10.MyNet.load_from_checkpoint(checkpoint_path='/tmp/pytorch-example/model.ckpt')
I get a RuntimeError because of unexpected keys in the state_dict, obviously because layer creation relies on the setup block.
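As far as I can tell, load_from_checkpoint only calls __init__, so the module has no layers registered yet when the checkpoint's state_dict is applied. A quick sanity check using the class above shows this:

fresh = MyNet()
print(len(fresh.state_dict()))  # 0 -- no parameters registered before setup() runs
fresh.setup('train')
print(len(fresh.state_dict()))  # now contains conv1.weight, fc1.bias, etc.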
This is pretty easy to solve by doing:
my_checkpoint = torch.load('/tmp/pytorch-example/model.ckpt')
new_model = MyNet(**my_checkpoint['hyper_parameters'])  # unpack the saved hyperparameters
new_model.setup('train')
new_model.load_state_dict(my_checkpoint['state_dict'])
However, running trainer.fit(new_model, data) resets the model weights by making a fresh call to setup.
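For example (a rough check; fc3 is just one layer picked from the class above):

loaded_weight = new_model.fc3.weight.detach().clone()
new_model.setup('train')  # what Trainer.fit ends up doing again
print(torch.equal(loaded_weight, new_model.fc3.weight))  # False: the layer was re-created with fresh init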
My work-around is to modify the __init__ and setup methods in my LightningModule as follows, adding an is_built flag to self:
class MyNet(ptl.LightningModule):
    def __init__(self, ..., **kwargs):
        super().__init__()
        ...
        self.is_built = False

    def setup(self, step):
        # only build the layers the first time setup() is called
        if not self.is_built:
            ...
            self.is_built = True
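With that guard in place, the extra setup call triggered by trainer.fit should be a no-op, so the manually loaded weights survive (a sketch reusing my_checkpoint from above, with fc3 again just as an example layer):

guarded = MyNet(**my_checkpoint['hyper_parameters'])
guarded.setup('train')                      # builds the layers once
guarded.load_state_dict(my_checkpoint['state_dict'])
guarded.setup('train')                      # no-op because is_built is now True
# guarded.fc3.weight still matches the checkpoint at this point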
Still, I’m not really sure how this is going to behave down the line when I want to start using some of the cooler Lightning features, like parallelization.