I believe I found a bug in the trainer.predict()
function.
As written in the documentation, " By default, the predict_step()
method runs the forward()
method. In order to customize this behaviour, simply override the predict_step()
method."
It sounds logical and when I try to exactly that (i.e. reuse the forward()
method), I get an error because the whole sample [x, y] is passed in the forward()
, instead of just x, as it should be.
Below is the simplified code to reproduce the problem:
import lightning as L
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
class SimpleModel(L.LightningModule):
def __init__(self):
super(SimpleModel, self).__init__()
self.layer1 = nn.LazyLinear(16)
self.layer2 = nn.Linear(16, 1)
def forward(self, x):
x = torch.relu(self.layer1(x))
x = self.layer2(x)
return x
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = nn.MSELoss()(y_hat, y)
return loss
def configure_optimizers(self):
return Adam(self.parameters(), lr=1e-3)
class SimpleDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
x = self.data[idx]
y = self.labels[idx]
sample = (x, y)
return sample
X = torch.rand((1000, 16))
Y = torch.rand((1000, 1))
dataset = SimpleDataset(X, Y)
dataloader = DataLoader(dataset, batch_size=32)
model = SimpleModel()
trainer = L.Trainer(max_epochs=2)
trainer.fit(model=model, train_dataloaders=dataloader) # Ok
trainer.predict(model=model, dataloaders=dataloader) # Error
The lightning version is the latest stable: 2.2.5