Input shape issue with pl_bolts.callbacks.variational.LatentDimInterpolator

Hello everybody,
I am training a variational autoencoder and want to use pl_bolts.callbacks.variational.LatentDimInterpolator. I wrote the following code:

import os 
import argparse
from vae_options import options
from torch.nn import functional as F 
from pytorch_lightning import Trainer 
from data_module import ChromoDataModule
from pl_bolts.models.autoencoders import VAE 
from pl_bolts.callbacks.variational import LatentDimInterpolator 

# option parsing 

parser = argparse.ArgumentParser()
option_parser = options()
option_parser.initialize(parser)
opt = parser.parse_args()
print(opt)

# training 

savepath = os.path.join(opt.save_path, opt.experiment_name)
dataset = ChromoDataModule(train_dir=opt.train_path, test_dir=opt.test_path, batch_size=opt.batch_size, num_workers=opt.num_workers)
model = VAE(input_height=64, enc_type="resnet18")
trainer = Trainer(gpus=1, max_epochs=opt.n_epochs, callbacks=[LatentDimInterpolator()], default_root_dir=savepath)
trainer.fit(model, dataset)

The code trains fine, but when the callback runs, I get the following error:

RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 3, 3, 3], but got 2-dimensional input of size [2, 256] instead
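A minimal reproduction of the shape mismatch, assuming only PyTorch is installed: a Conv2d layer whose weight has the same [64, 3, 3, 3] shape as in the message accepts a 4-D batch of images but rejects a 2-D batch of latent vectors. The 64x64 image size is an illustrative choice. The exact wording of the error differs between PyTorch versions, but it is a RuntimeError either way.

```python
import torch
from torch import nn

# 64 output channels, 3 input channels, 3x3 kernel: weight shape [64, 3, 3, 3]
conv = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3)
print(conv.weight.shape)  # torch.Size([64, 3, 3, 3])

images = torch.randn(2, 3, 64, 64)  # (batch, channels, height, width): what an encoder expects
latents = torch.randn(2, 256)       # (num_samples, latent_dim): a batch of latent vectors

print(conv(images).shape)  # torch.Size([2, 64, 62, 62])

try:
    conv(latents)  # 2-D input cannot be convolved with a 4-D weight
except RuntimeError as err:
    print("RuntimeError:", err)
```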

It is raised here:

  File "lib/python3.7/site-packages/pl_bolts/callbacks/variational.py", line 91, in interpolate_latent_space
    img = pl_module(z)
  File "lib/python3.7/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "lib/python3.7/site-packages/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py", line 109, in forward
    x = self.encoder(x)

(I shortened the install paths.) The encoder fails because it expects an image, not a latent vector. Shouldn't the callback code be img = pl_module.decoder(z) instead? Or am I understanding this all wrong (which is very likely)?
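A self-contained sketch of the suggested fix, assuming only PyTorch. ToyVAE is a hypothetical stand-in whose forward() runs encoder, sampling, then decoder (only the encoder call is visible in the traceback; the rest is an assumption). interpolate_latent_space here is a hypothetical re-implementation of the callback's loop as I understand it (sweep the first two latent dimensions over a grid, keep the others at zero), not the pl_bolts code itself. The change that matters is calling model.decoder(z) instead of model(z). The defaults (range -5 to 5, 11 steps, 2 samples) are assumptions; num_samples=2 with latent_dim=256 gives the same [2, 256] shape as in the error.

```python
import torch
from torch import nn


class ToyVAE(nn.Module):
    """Hypothetical stand-in: image -> encoder -> (mu, log_var) -> z -> decoder -> image."""

    def __init__(self, latent_dim=256, img_shape=(3, 8, 8)):
        super().__init__()
        c, h, w = img_shape
        self.encoder = nn.Sequential(nn.Conv2d(c, 4, 3, padding=1), nn.Flatten())
        enc_out = 4 * h * w
        self.fc_mu = nn.Linear(enc_out, latent_dim)
        self.fc_var = nn.Linear(enc_out, latent_dim)
        self.decoder = nn.Sequential(nn.Linear(latent_dim, c * h * w), nn.Unflatten(1, img_shape))

    def forward(self, x):
        # forward() starts from an image, so passing a latent vector here fails in the encoder
        hidden = self.encoder(x)
        mu, log_var = self.fc_mu(hidden), self.fc_var(hidden)
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * log_var)  # reparameterization trick
        return self.decoder(z)


@torch.no_grad()
def interpolate_latent_space(model, latent_dim, range_start=-5.0, range_end=5.0, steps=11, num_samples=2):
    """Sweep latent dims 0 and 1 over a grid and decode each point (hypothetical helper)."""
    was_training = model.training
    model.eval()
    device = next(model.parameters()).device
    images = []
    for z1 in torch.linspace(range_start, range_end, steps):
        for z2 in torch.linspace(range_start, range_end, steps):
            z = torch.zeros(num_samples, latent_dim, device=device)  # all other dims stay at 0
            z[:, 0] = float(z1)
            z[:, 1] = float(z2)
            img = model.decoder(z)               # decode directly, skipping the encoder
            images.append(img[0].unsqueeze(0))   # keep one sample per grid point: (1, C, H, W)
    model.train(was_training)  # restore the previous mode
    return images
```

Calling interpolate_latent_space(ToyVAE(), latent_dim=256, steps=3) returns 9 tensors of shape (1, 3, 8, 8), whereas ToyVAE()(torch.zeros(2, 256)) raises the same kind of RuntimeError as in the post. This only works for modules that expose a decoder attribute.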
In any case, help would be much appreciated 🙂!

EDIT:
With this modification, the code runs to completion without issues. That leads to another (probably very basic) question: the callback produces a list of generated images. How should I retrieve this list to save it somewhere? Should I subclass the callback and save the images directly?
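One way to keep the images, sketched with plain PyTorch: images_to_grid and save_interpolation are hypothetical helpers (not pl_bolts API) that tile a list of (1, C, H, W) tensors, one per grid point, into a single (C, rows*H, cols*W) tensor and write it with torch.save. They could be called from a subclass of the callback that overrides the method producing the list; which method to override depends on the pl_bolts version and is an assumption here.

```python
import torch


def images_to_grid(images, images_per_row):
    """Tile a list of (1, C, H, W) tensors into one (C, rows*H, cols*W) tensor."""
    batch = torch.cat(images, dim=0)  # (N, C, H, W)
    n, c, h, w = batch.shape
    if n % images_per_row != 0:
        raise ValueError("number of images must be a multiple of images_per_row")
    rows = n // images_per_row
    grid = batch.view(rows, images_per_row, c, h, w)
    # reorder to (C, rows, H, cols, W) so each row of images becomes consecutive pixel rows
    return grid.permute(2, 0, 3, 1, 4).reshape(c, rows * h, images_per_row * w)


def save_interpolation(images, path, images_per_row):
    """Build the grid and save it; reload later with torch.load(path)."""
    grid = images_to_grid(images, images_per_row)
    torch.save(grid, path)
    return grid
```

With an 11 x 11 sweep, images_per_row=11 reproduces the latent grid layout. If torchvision is installed, torchvision.utils.make_grid and torchvision.utils.save_image do the same tiling (with padding) and write a PNG instead. If I read the pl_bolts source correctly, the stock callback already tiles the images and logs them through the trainer's logger (e.g. TensorBoard), so checking the logger output may be enough.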

Hello, apologies for the late reply. We are gradually deprecating this forum in favor of the built-in GitHub version. Could we kindly ask you to recreate your question there: Lightning Discussions