Why is Lightning almost 3x slower than plain PyTorch?

Why does the following code execute about 3x slower than plain PyTorch training?

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl

class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(331, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )

    def forward(self, x):
        log_val = self.linear_relu_stack(x)
        return log_val


train_dataloader = DataLoader(TensorDataset(train_features, train_target), batch_size=1, shuffle=True)
test_dataloader = DataLoader(TensorDataset(test_features, test_target), batch_size=1, shuffle=False)


class HousePriceModule(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.loss = nn.MSELoss()

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = self.loss(self.model(x), y)
        #self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = self.loss(self.model(x), y)
        #self.log("val_loss", loss)
        return loss

    def configure_optimizers(self):
        #optimizer = optim.Adam(self.parameters(), lr=1e-3)
        optimizer = optim.SGD(self.parameters(), lr=1e-3)
        return optimizer


housePriceModule = HousePriceModule(NeuralNetwork())

trainer = pl.Trainer(max_epochs=10, log_every_n_steps=100, enable_progress_bar=False)
trainer.fit(model=housePriceModule, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader)

Hey, can you show the corresponding plain PyTorch training loop? This shouldn't be the case; we actually test every single commit to make sure we match plain PyTorch speed.

One difference is that, by default, you have checkpointing enabled, which on its own can take quite some time. To actually spot the differences, I'd need to see the corresponding raw PyTorch loop.
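
In the meantime, a minimal sketch of switching off that extra machinery on the Trainer (these are standard Trainer flags; adjust as needed for your setup):

trainer = pl.Trainer(
    max_epochs=10,
    enable_checkpointing=False,   # don't write a checkpoint every epoch
    logger=False,                 # disable the default TensorBoard logger
    enable_progress_bar=False,    # already off in your snippet
    enable_model_summary=False,   # skip the model summary printed at the start of fit
)
trainer.fit(model=housePriceModule, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader)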

Best,
Justus

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def test(dataloader, model, loss_fn):
    num_batches = len(dataloader)
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
    test_loss /= num_batches
    print(f"Avg loss: {test_loss:>8f} \n")

device = "cpu"  # training is done on CPU here
model = NeuralNetwork().to(device)

loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

epochs = 10
for t in range(epochs):
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)

Setting enable_checkpointing=False improved speed by ~10%, but training still takes almost 3x as long. Everything is trained on CPU.

batch_size=1, train_features size = 1360, test_features size = 100
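
A rough way to time the two runs back-to-back, reusing the model, dataloaders and helper functions defined above (just a sketch, not necessarily how the numbers above were produced; fresh model instances for each run so the comparison is fair):

import time

# Lightning run, starting from a fresh module
lit_module = HousePriceModule(NeuralNetwork())
start = time.perf_counter()
trainer = pl.Trainer(max_epochs=10, log_every_n_steps=100,
                     enable_progress_bar=False, enable_checkpointing=False)
trainer.fit(model=lit_module, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader)
print(f"Lightning:     {time.perf_counter() - start:.1f}s")

# Plain PyTorch run with an equivalent fresh model and optimizer
model = NeuralNetwork().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
start = time.perf_counter()
for _ in range(10):
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print(f"plain PyTorch: {time.perf_counter() - start:.1f}s")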

I am facing the same issue with pretty similar code. Does anyone have a solution?


I am experiencing the same thing, any insights?

I am experiencing the same issue, but ~6x slower with Lightning (see this notebook). I thought it might be because I used Lightning's CombinedLoader class (in both implementations). Any new insight?

To compare against PyTorch, you can set Trainer(barebones=True).
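
For example (barebones mode strips out logging, checkpointing, the progress bar and other per-step bookkeeping, so don't combine it with those flags):

trainer = pl.Trainer(max_epochs=10, barebones=True)
trainer.fit(model=housePriceModule, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader)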

Well, that helped a lot… With barebones=True, the training time is now even lower than plain PyTorch (~0.7x). Thank you for the reply!