Checkpointing saves the wrong model weights - no matter if Lightning or bare PyTorch

TL;DR below!

Greetings Lightning Community,

Background:
I am a researcher working on Generative Adversarial Networks for neuronal data (EEG).
I first developed a GAN in bare PyTorch, then decided to switch to Lightning 3-4 months ago, and I've been super happy so far. However, for the last 3 weeks I've been hitting a wall because I'm not able to save my models correctly.

Question:
I have a GAN model consisting of two networks, a generator and a critic. I'm using the Lightning framework with the Weights & Biases logger and the ModelCheckpoint callback to train, log and checkpoint. However, when I checkpoint my models, a forward pass through the checkpointed generator gives dramatically different results from what I see at the end of training. So I started debugging to find out what's going on: it turns out the weights that get saved are very different from the weights the model actually has during training. Below I show forward-pass results during training and after loading the last checkpoint:

Results during Training:
[screenshot: spectrum at the last training epoch]

Results when loading Checkpoint
… I can only embed one picture, so I'll show the others further below …

Ignore the epoch number; I just put 500 there because it needed a number. The upper picture shows the spectrum of the last epoch during training, the lower one shows the loaded checkpoint.
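
For reference, this is roughly how I load a checkpoint for the comparison (the filename is just a placeholder for whatever ModelCheckpoint wrote, and model is the trained module from the run):

import torch

# Restore the LightningModule from the last checkpoint:
restored = GAN.load_from_checkpoint("last.ckpt", **GAN_PARAMS)
restored.eval()

# Compare one generator kernel of the restored model with the trained model:
trained_w = model.generator.blocks[0].intermediate_sequence[4].conv2.weight
restored_w = restored.generator.blocks[0].intermediate_sequence[4].conv2.weight
print(torch.allclose(trained_w.cpu(), restored_w.cpu()))  # False for me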

After days of debugging I tried saving checkpoints simply by calling torch.save(module.state_dict()) inside my custom logging callback (a sketch of that is shown after the CSV snippet below), but I get the same result. When comparing some randomly chosen kernels I see the following:

Picture of weights:… (also below)

The torch label shows the weights saved with torch.save, w&b are the weights from the checkpoint written by the ModelCheckpoint (and uploaded to Weights & Biases), and csv is what I get if I just dump all weights to a CSV at the end of training (probably the weights I actually want)… The .csv is saved using:

import numpy as np

# Grab one generator kernel at the end of training and dump it to CSV:
filter_block1 = module.generator.blocks[0].intermediate_sequence[4].conv2.weight.data.clone().flatten()
np.savetxt(".../EEG-GAN/block1_conv2.csv", filter_block1.detach().cpu().numpy())
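
And this is the gist of the bare-torch saving I tried from inside the custom callback (the class name, hook and output path here are just illustrative; in my code it lives inside the LoggingHandler):

import torch
from lightning.pytorch.callbacks import Callback  # or pytorch_lightning.callbacks

class StateDictDump(Callback):
    # Minimal sketch: dump the raw state_dict at the end of training.
    def on_train_end(self, trainer, pl_module):
        if trainer.is_global_zero:  # only write from the main DDP process
            torch.save(pl_module.state_dict(), "state_dict_end_of_training.pt")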

Now for what I use to log and checkpoint; this is what I think is the important part of the train.py file:

# Imports (or from pytorch_lightning, depending on the install;
# project-specific classes like GAN, LoggingHandler, Scheduler, Spectrum,
# SWD and BinStats come from our own package):
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger

# Init Logger
# Here we set the project, name and folder for the run:
logger = WandbLogger(log_model="all",
                     name='debugging run',
                     project='EEGGAN',
                     save_dir=results_path)

# Init Checkpoint
# Here we set the rules for saving model checkpoints:
checkpoint_callback = ModelCheckpoint(every_n_epochs=500,
                                      filename='checkpoint_{epoch}',
                                      save_top_k=-1,
                                      save_last=True)

# Init logging handler
# Custom Callback for logging metrics and plots to wandb:
logging_handler = LoggingHandler()
logging_handler.attach_metrics([Spectrum(250),
                                SWD(1),
                                BinStats(channels, mapping, every_n_epochs = 0),
                                ],)

def main():
    model = GAN(**GAN_PARAMS)

    trainer = Trainer(
            max_epochs=2000,
            reload_dataloaders_every_n_epochs=500,
            callbacks=[Scheduler(), logging_handler, checkpoint_callback],
            default_root_dir=results_path,
            strategy='ddp_find_unused_parameters_true',
            logger=logger,
    )

    logger.watch(model)

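    # dm is our datamodule with the EEG data; it is created elsewhere in train.py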
    trainer.fit(model, dm)

if __name__ == '__main__':
    main()

MAYBE IMPORTANT:
I’m training on a SLURM cluster using 4 GPUs in parallel.

I would be so thankful if anybody could help, I'm absolutely lost with this…

TL;DR: When saving my weights to a .csv at the end of training, they are very different from what torch.save or the ModelCheckpoint saves. Does anyone know how to fix this?

If you need any further information, package versions, code snippets or anything else, let me know.

Thankful greetings in advance
Samuel :slight_smile:

Results when loading Checkpoint:
[screenshot: spectrum of the loaded checkpoint]

Graph of weights (csv vs torch vs lightning callback):

[plot comparing the weights from the csv dump, the torch save and the Lightning callback]

After more research I found this post: How can I make PyTorch save all the weights from all the sub-layers the model is composed of? - PyTorch Forums

It turned out I had the same problem. My model is split into multiple blocks, and these blocks were stored in a plain Python list. The model can be trained as intended and shows the wanted results. But submodules kept in a plain Python list are not registered with the parent nn.Module, so state_dict() never picks up their current parameters; what gets saved is only whatever happens to be registered (probably some weights from a previous step), not the current model state (weird behavior if you ask me).

To fix this, the model blocks need to be stored in an nn.ModuleList so that they are registered and actually end up in the checkpoint.
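
A minimal standalone sketch of the difference (the block structure here is made up, not my actual generator):

import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, use_modulelist: bool):
        super().__init__()
        blocks = [nn.Conv1d(1, 8, 3), nn.Conv1d(8, 8, 3)]
        # A plain Python list does NOT register the convs as submodules,
        # so their weights never show up in state_dict() and are not checkpointed.
        # nn.ModuleList registers them properly.
        self.blocks = nn.ModuleList(blocks) if use_modulelist else blocks

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

print(list(Generator(use_modulelist=False).state_dict().keys()))  # []
print(list(Generator(use_modulelist=True).state_dict().keys()))
# ['blocks.0.weight', 'blocks.0.bias', 'blocks.1.weight', 'blocks.1.bias']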