Why is Lightning almost 3x slower than plain PyTorch?

Why does the following code run 3x slower than plain PyTorch training?

class NeuralNetwork(nn.Module):
    """Fully connected network: 331 inputs -> 512 hidden units (ReLU) -> 1 output."""

    def __init__(self):
        super().__init__()
        # Layers are created in the same order they are applied.
        hidden_layer = nn.Linear(331, 512)
        output_layer = nn.Linear(512, 1)
        self.linear_relu_stack = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def forward(self, x):
        """Pass ``x`` through the layer stack and return the result."""
        return self.linear_relu_stack(x)
# Training loader: one sample per batch, shuffling enabled.
train_dataloader = DataLoader(
    TensorDataset(train_features, train_target),
    batch_size=1,
    shuffle=True,
)
# Evaluation loader: one sample per batch, shuffling disabled.
test_dataloader = DataLoader(
    TensorDataset(test_features, test_target),
    batch_size=1,
    shuffle=False,
)
class HousePriceModule(pl.LightningModule):
    """Lightning wrapper that trains the given model with a mean-squared-error loss."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.loss = nn.MSELoss()

    def _batch_loss(self, batch):
        """Return the MSE between the model output and the target of ``batch``."""
        features, target = batch
        prediction = self.model(features)
        return self.loss(prediction, target)

    def training_step(self, batch, batch_idx):
        """Compute and return the loss for one training batch."""
        # Logging of "train_loss" is left disabled here.
        return self._batch_loss(batch)

    def validation_step(self, batch, batch_idx):
        """Compute and return the loss for one validation batch."""
        # Logging of "val_loss" is left disabled here.
        return self._batch_loss(batch)

    def configure_optimizers(self):
        """Return an SGD optimizer over all module parameters (lr=1e-3)."""
        # Adam with the same learning rate was the previously tried alternative.
        return optim.SGD(self.parameters(), lr=1e-3)
# Wrap the plain network in the Lightning module.
housePriceModule = HousePriceModule(NeuralNetwork())

# Ten epochs, logging interval of 100 steps, progress bar switched off.
trainer = pl.Trainer(
    max_epochs=10,
    log_every_n_steps=100,
    enable_progress_bar=False,
)
trainer.fit(
    model=housePriceModule,
    train_dataloaders=train_dataloader,
    val_dataloaders=test_dataloader,
)

Hey, can you show the corresponding plain PyTorch training loop? This shouldn’t be the case; we actually test every single commit to make sure we match PyTorch speed.

One difference is that checkpointing is enabled by default, which on its own can take quite some time. To actually spot differences, I’d need to see the corresponding raw PyTorch loop.

Best,
Justus

1 Like
def train(dataloader, model, loss_fn, optimizer):
    """Run one optimization pass over every batch in ``dataloader``.

    Args:
        dataloader: Iterable yielding ``(X, y)`` batches.
        model: Module called on each ``X``; put into training mode first.
        loss_fn: Callable taking ``(pred, y)`` and returning the loss that
            is backpropagated.
        optimizer: Optimizer whose gradients are zeroed and which is stepped
            once per batch.

    Each batch is moved to the global ``device`` (defined outside this
    function) before the forward pass.
    """
    model.train()
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation: clear stale gradients before accumulating new ones.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and print the mean per-batch loss.

    The model is switched to eval mode and gradients are disabled while the
    batches are processed. Each batch is moved to the global ``device``
    (defined outside this function). The average divides by
    ``len(dataloader)``, so an empty dataloader raises ``ZeroDivisionError``.
    """
    num_batches = len(dataloader)
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            batch_losses.append(loss_fn(model(inputs), targets).item())
    test_loss = sum(batch_losses) / num_batches
    print(f"Avg loss: {test_loss:>8f} \n")

# Number of full passes over the training data.
epochs = 10

# Same loss and optimizer settings as the Lightning module: MSE + SGD(lr=1e-3).
loss_fn = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Each epoch trains on the full training loader, then evaluates on the test loader.
for epoch in range(epochs):
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)

Setting enable_checkpointing=False improved speed by ~10%, but it is still almost 3x slower. Training runs on CPU.

batch_size=1, train_features size = 1360, test_features size = 100