TLDR below!!
Greetings Lightning Community,
Background:
I am a researcher working on Generative Adversarial Networks for neuronal data (EEG).
I first developed a GAN in plain PyTorch, then switched to Lightning 3-4 months ago and I'm super happy so far. However, for the last three weeks I've been hitting a wall because I'm not able to save my models correctly.
Question:
I have a GAN model consisting of two networks: a generator and a critic. I'm using the Lightning framework with the Weights & Biases logger and the ModelCheckpoint callback
to train, log and checkpoint. However, when I checkpoint my models, a forward pass through the generator of the checkpointed model leads to dramatically different results compared to what I see at the end of training. Therefore I started debugging to find out what's going on: it turns out the weights that get saved for the model differ substantially from the weights present during training. Below I show forward-pass results during training and after loading the last checkpoint:
Results during Training:
Results when loading the checkpoint:
… I can only embed one picture, so I'll show the others below …
(Ignore the epoch number; I just put 500 there because it needed a number.) The upper picture shows the spectrum of the last epoch during training, the lower one shows the spectrum produced from the loaded checkpoint.
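For reference, this is roughly how I load the checkpoint to produce the second plot (a minimal sketch; the checkpoint path, the latent shape and the import are placeholders, not my exact code):

import torch
from eeggan.model import GAN   # placeholder import, the real module path is project-specific

ckpt_path = "results/last.ckpt"               # placeholder, in reality the path written by ModelCheckpoint
model = GAN.load_from_checkpoint(ckpt_path)   # assumes the hparams are stored in the checkpoint
model.eval()

with torch.no_grad():
    latent_dim = 128                          # placeholder for my actual latent size
    z = torch.randn(8, latent_dim)
    fake = model.generator(z)                 # forward pass that produces the spectrum shown above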
After days of debugging I tried saving checkpoints directly with torch.save(module.state_dict())
inside my custom logging callback (a sketch of that call is further below), but the result is the same. When comparing some randomly chosen kernels I see the following:
Picture of weights:… (also below)
The 'torch' label shows the weights saved via torch.save, 'w and b' are the weights logged by the ModelCheckpoint (and uploaded to Weights & Biases), and 'csv' is what I get if I just dump all weights into a CSV at the end of training (probably the weights I actually want)… The .csv is saved using:
# grab one generator kernel and dump it to disk at the end of training
filter_block1 = module.generator.blocks[0].intermediate_sequence[4].conv2.weight.data.clone().flatten()
np.savetxt(".../EEG-GAN/block1_conv2.csv", filter_block1.detach().cpu().numpy())
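The 'torch' weights come from a torch.save call inside a small custom callback, roughly like this (a minimal sketch; the class name, the hook choice and the output path are simplified placeholders):

import torch
from lightning.pytorch.callbacks import Callback   # Lightning 2.x import path

class RawStateDictSaver(Callback):                  # placeholder name for illustration
    def on_train_end(self, trainer, pl_module):
        # dump the raw state_dict directly, bypassing ModelCheckpoint
        torch.save(pl_module.state_dict(), "raw_state_dict.pt")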
Now for what I use to log; this is what I think is the important part of the train.py file:
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
# (project-specific imports such as GAN, GAN_PARAMS, LoggingHandler, Spectrum,
#  SWD, BinStats, Scheduler, dm, results_path, channels and mapping are omitted here)

# Init logger
# Here we set the project, name and folder for the run:
logger = WandbLogger(log_model="all",
                     name='debugging run',
                     project='EEGGAN',
                     save_dir=results_path)

# Init checkpoint callback
# Here we set the rules for saving model checkpoints:
checkpoint_callback = ModelCheckpoint(every_n_epochs=500,
                                      filename='checkpoint_{epoch}',
                                      save_top_k=-1,
                                      save_last=True)

# Init logging handler
# Custom callback for logging metrics and plots to wandb:
logging_handler = LoggingHandler()
logging_handler.attach_metrics([Spectrum(250),
                                SWD(1),
                                BinStats(channels, mapping, every_n_epochs=0)])


def main():
    model = GAN(**GAN_PARAMS)

    trainer = Trainer(
        max_epochs=2000,
        reload_dataloaders_every_n_epochs=500,
        callbacks=[Scheduler(), logging_handler, checkpoint_callback],
        default_root_dir=results_path,
        strategy='ddp_find_unused_parameters_true',
        logger=logger,
    )

    logger.watch(model)
    trainer.fit(model, dm)


if __name__ == '__main__':
    main()
MAYBE IMPORTANT:
I’m training on a SLURM cluster using 4 GPUs in parallel.
I would be so thankful if anybody could help; I'm absolutely lost with this…
TLDR: The weights I dump into a .csv at the end of training are very different from the ones saved via torch.save or the ModelCheckpoint callback. Does anyone know how to fix this?
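For completeness, this is roughly how I compare the checkpointed kernel against the CSV dump (the checkpoint path and the exact state_dict key are illustrative guesses derived from the module path above):

import numpy as np
import torch

ckpt = torch.load("results/last.ckpt", map_location="cpu")        # placeholder checkpoint path
key = "generator.blocks.0.intermediate_sequence.4.conv2.weight"   # guessed from the module path above
w_ckpt = ckpt["state_dict"][key].flatten()
w_csv = torch.from_numpy(np.loadtxt(".../EEG-GAN/block1_conv2.csv")).float()

print((w_ckpt - w_csv).abs().max())   # much larger than I would expect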
If you need any further information, package versions, code snippets or anything else, let me know.
Thankful greetings in advance
Samuel