Bug in the trainer.predict()

I believe I found a bug in the trainer.predict() function.

As written in the documentation, " By default, the predict_step() method runs the forward() method. In order to customize this behaviour, simply override the predict_step() method."

It sounds logical and when I try to exactly that (i.e. reuse the forward() method), I get an error because the whole sample [x, y] is passed in the forward(), instead of just x, as it should be.

Below is the simplified code to reproduce the problem:

import lightning as L
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader

class SimpleModel(L.LightningModule):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.layer1 = nn.LazyLinear(16)
        self.layer2 = nn.Linear(16, 1)
    def forward(self, x):
        x = torch.relu(self.layer1(x))
        x = self.layer2(x)
        return x
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.MSELoss()(y_hat, y)
        return loss
    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-3)

class SimpleDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx]
        y = self.labels[idx]
        sample = (x, y)
        return sample

X = torch.rand((1000, 16))
Y = torch.rand((1000, 1))

dataset = SimpleDataset(X, Y)
dataloader = DataLoader(dataset, batch_size=32)

model = SimpleModel()

trainer = L.Trainer(max_epochs=2)
trainer.fit(model=model, train_dataloaders=dataloader) # Ok
trainer.predict(model=model, dataloaders=dataloader)   # Error

The lightning version is the latest stable: 2.2.5