Confusion about load_from_checkpoint() and save_hyperparameters()

According to Saving and loading checkpoints (basic) — PyTorch Lightning 2.1.3 documentation, there is a model like this:

class Encoder(L.LightningModule):
    ...

class Decoder(L.LightningModule):
    ...

class Autoencoder(L.LightningModule):
    def __init__(self, encoder, decoder, *args, **kwargs):
        super().__init__()
        self.save_hyperparameters(ignore=['encoder', 'decoder'])
        self.encoder = encoder
        self.encoder.freeze()
        self.decoder = decoder
        ...

# training code
encoder = Encoder.load_from_checkpoint("encoder.ckpt")
decoder = Decoder(...)  # constructed with some hyperparameters
autoencoder = Autoencoder(encoder, decoder)
trainer.fit(autoencoder, datamodule)

Assume the trained autoencoder has been stored in the autoencoder.ckpt file. There are three things I am curious about:

  1. Does the autoencoder.ckpt file include both the encoder and decoder weights?
  2. If autoencoder.ckpt contains the encoder weights, how can I import the weights from encoder.ckpt into the autoencoder without them being overwritten?
  3. If autoencoder.ckpt does not include the decoder weights, what is the procedure to save the decoder weights separately?

@sznflash The docs don’t show what you have written down here. You wouldn’t have three LightningModules, only one!

The encoder and decoder are regular torch nn.Modules, and only the top-level system, here called Autoencoder, is a LightningModule that manages both the encoder and decoder training.
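
Roughly, that structure looks like the minimal sketch below. The layer sizes and the training_step body are placeholders I made up for illustration, not something from the docs or from your post:

import lightning as L
from torch import nn, optim


class Encoder(nn.Module):  # a plain nn.Module, not a LightningModule
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))

    def forward(self, x):
        return self.net(x)


class Decoder(nn.Module):  # a plain nn.Module, not a LightningModule
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

    def forward(self, x):
        return self.net(x)


class Autoencoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        # the module arguments are excluded from the saved hyperparameters, but
        # they are still registered as submodules, so their weights belong to
        # this LightningModule's state dict and end up in its checkpoint
        self.save_hyperparameters(ignore=['encoder', 'decoder'])
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        x, _ = batch
        x = x.view(x.size(0), -1)
        x_hat = self.decoder(self.encoder(x))
        return nn.functional.mse_loss(x_hat, x)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=1e-3)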

The checkpoint contains both, and you would call Autoencoder.load_from_checkpoint, not Encoder.load_from_checkpoint.
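
Saving and loading would then look something like this sketch. The checkpoint path, max_epochs, and datamodule are placeholders; and because encoder and decoder are excluded from save_hyperparameters(), fresh instances are passed to load_from_checkpoint and their weights are overwritten by what is stored in the checkpoint:

# training writes one checkpoint that holds the encoder *and* decoder weights
autoencoder = Autoencoder(Encoder(), Decoder())
trainer = L.Trainer(max_epochs=10)
trainer.fit(autoencoder, datamodule)

# restore the whole system from that single file
restored = Autoencoder.load_from_checkpoint(
    "path/to/autoencoder.ckpt", encoder=Encoder(), decoder=Decoder()
)
trained_encoder = restored.encoder  # weights as they were at checkpoint time
trained_decoder = restored.decoder  # same for the decoder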

I hope this makes it clearer.
