Mac M2 MPS: failed assertion `destination kernel width and filter kernel width mismatch'

Hi,

I am training an adversarial autoencoder with PyTorch 2.0.0 and PyTorch Lightning 2.0.1.post0 on an Apple M2 Mac (macOS Ventura 13.1), using conda 23.1.0 as the package manager.

I encountered this error:

/AppleInternal/Library/BuildRoots/5b8a32f9-5db2-11ed-8aeb-7ef33c48bc85/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSNDArray/Kernels/MPSNDArrayConvolutionA14.mm:3967: failed assertion `destination kernel width and filter kernel width mismatch'
/Users/vk/miniconda3/envs/betavae/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown

As far as I can tell, the crash happens at the self.manual_backward(loss["g_loss"]) call in this block:

g_opt.zero_grad()
self.manual_backward(loss["g_loss"])
g_opt.step()
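To check whether the crash comes from Lightning or from the MPS backend itself, the same zero_grad/backward/step sequence can be reproduced in plain PyTorch and run on both CPU and MPS. This is only a sketch: the small convolutional generator, its layer sizes, the input shapes and the MSE loss are hypothetical stand-ins for the real model and g_loss, and plain loss.backward() takes the place of Lightning's self.manual_backward().

```python
import torch
import torch.nn as nn


def pick_device() -> torch.device:
    # Use MPS only if this PyTorch build supports it and the machine has it.
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def generator_step(device: torch.device, seed: int = 0) -> float:
    """Run one manual generator update on `device` and return the loss value."""
    torch.manual_seed(seed)
    # Hypothetical stand-in for the real generator: a small conv stack.
    gen = nn.Sequential(
        nn.Conv2d(1, 8, kernel_size=3, padding=1),  # (N,1,16,16) -> (N,8,16,16)
        nn.ReLU(),
        nn.ConvTranspose2d(8, 1, kernel_size=4, stride=2, padding=1),  # -> (N,1,32,32)
    ).to(device)
    g_opt = torch.optim.Adam(gen.parameters(), lr=1e-3)

    x = torch.randn(4, 1, 16, 16, device=device)
    target = torch.randn(4, 1, 32, 32, device=device)

    # Same three steps as the manual-optimization block in the post.
    g_opt.zero_grad()
    g_loss = nn.functional.mse_loss(gen(x), target)
    g_loss.backward()  # plain-PyTorch counterpart of self.manual_backward(...)
    g_opt.step()
    return g_loss.item()


if __name__ == "__main__":
    print("cpu:", generator_step(torch.device("cpu")))
    device = pick_device()
    if device.type == "mps":
        # If the assertion is an MPS convolution-backward issue, it should fire here.
        print("mps:", generator_step(device))
```

If the CPU run succeeds and the MPS run aborts with the same assertion, swapping in the real layers one at a time should narrow down which convolution configuration triggers it.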

The same code runs without problems on Linux.

Any thoughts on how to fix it are highly appreciated!