According to "Saving and loading checkpoints (basic)" in the PyTorch Lightning 2.1.3 documentation, suppose there is a model like this:
import lightning as L

class Encoder(L.LightningModule):
    ...

class Decoder(L.LightningModule):
    ...

class Autoencoder(L.LightningModule):
    def __init__(self, encoder, decoder, *args, **kwargs):
        super().__init__()
        # exclude the submodules from hparams: they are modules, not hyperparameters
        self.save_hyperparameters(ignore=['encoder', 'decoder'])
        self.encoder = encoder
        self.encoder.freeze()  # keep the pretrained encoder weights fixed
        self.decoder = decoder
    ...

# training code
encoder = Encoder.load_from_checkpoint("encoder.ckpt")
decoder = Decoder(...)  # some hyperparameters
autoencoder = Autoencoder(encoder, decoder)
trainer.fit(autoencoder, datamodule)
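For completeness, this is roughly how the checkpoint in question gets written; trainer.save_checkpoint is the documented Lightning API, but calling it manually here (rather than relying on automatic checkpointing via a ModelCheckpoint callback) is just my assumption about the workflow:

# after training, persist the whole autoencoder (manual checkpointing)
trainer.save_checkpoint("autoencoder.ckpt")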
We assume that the autoencoder has been stored in the autoencoder.ckpt
file. There are three key points I am curious about:
- Does the autoencoder.ckpt file include both the encoder and decoder weights?
- If autoencoder.ckpt contains the encoder weights, how can I import the weights from encoder.ckpt into the autoencoder without them being overwritten? (See the sketch after this list.)
- If autoencoder.ckpt does not include the decoder weights, what is the procedure to save the decoder weights separately? (Again, see the sketch below.)
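To make the second and third questions concrete, here is a sketch of what I have in mind. Passing encoder and decoder through load_from_checkpoint relies on them having been ignored by save_hyperparameters; the re-loading order and the decoder_weights.pt filename are my own guesses, not something taken from the docs:

import torch

# Q2: after restoring the full autoencoder, re-load the encoder weights
# from encoder.ckpt so they take precedence over whatever autoencoder.ckpt holds
autoencoder = Autoencoder.load_from_checkpoint(
    "autoencoder.ckpt", encoder=encoder, decoder=decoder
)
enc_state = torch.load("encoder.ckpt", map_location="cpu")["state_dict"]
autoencoder.encoder.load_state_dict(enc_state)

# Q3: if the decoder were not in the checkpoint, persist it on its own
torch.save(autoencoder.decoder.state_dict(), "decoder_weights.pt")

I am not sure this is the intended approach, which is the point of the question.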