How to fix: RuntimeError: mat1 and mat2 shapes cannot be multiplied (256x4096 and 1024x4)?

Hi all,

I’m trying to run the following code from the Lightning documentation. It works fine on CPU, but when I run it on GPU I get the following error:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (256x4096 and 1024x4)
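
My reading of the message, which may well be wrong, is that mat1 is a batch of 256 inputs with 4096 features each, while mat2 belongs to a layer that expects 1024 input features and produces 4 outputs. The sketch below only restates the matrix-multiplication shape rule in plain Python using the numbers from the error; matmul_shape_check is a hypothetical helper written for this post, not a PyTorch or Lightning function.

```python
def matmul_shape_check(mat1_shape, mat2_shape):
    """Return (ok, message) for a 2-D product mat1 @ mat2.

    The product is defined only when the number of columns of mat1
    equals the number of rows of mat2.
    """
    rows1, cols1 = mat1_shape
    rows2, cols2 = mat2_shape
    if cols1 == rows2:
        return True, f"ok: result shape is ({rows1}, {cols2})"
    return False, f"mismatch: mat1 has {cols1} columns but mat2 has {rows2} rows"


# Shapes copied from the error message above.
ok, message = matmul_shape_check((256, 4096), (1024, 4))
print(ok, message)

# How much larger the incoming feature size is than the expected one.
ratio = 4096 // 1024
print("input features are", ratio, "times larger than expected")
```

If that reading is right, the flattened input is four times larger than the layer expects (for example 64*64 = 4096 versus 32*32 = 1024), but I have not confirmed where the difference comes from.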

I also tried setting torch.set_float32_matmul_precision('high') and creating the trainer with trainer = Trainer(logger=wandb_logger, accelerator='gpu', devices=1), as suggested, but neither resolved the issue. How can I fix this problem?

I have installed PyTorch, Lightning, and other required packages with conda as follows:

conda create -n AE python=3.11
conda activate AE

conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia

conda install -c anaconda jupyter
conda install -c conda-forge pytorch-lightning
conda install -c conda-forge matplotlib
conda install -c anaconda numpy
conda install -c conda-forge tensorboard
conda install -c anaconda seaborn
conda install -c conda-forge wandb
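
In case the exact versions matter, here is a small sketch for listing what actually got installed in the environment. It uses only the standard library; the distribution names in PACKAGES are my assumption about how these packages register themselves and may need adjusting.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names are assumptions; adjust them to match the environment.
PACKAGES = ["torch", "torchvision", "pytorch-lightning", "numpy", "wandb"]


def installed_versions(names):
    """Map each distribution name to its version string, or None if not found."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


report = installed_versions(PACKAGES)
for name, ver in report.items():
    print(f"{name}: {ver or 'not installed'}")
```

Running this inside the activated AE environment should print one line per package, which I can paste here if it helps.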

Thanks!