
Deploy models into production (basic)

Audience: All users.


Load a checkpoint and predict

The easiest way to use a model for predictions is to load its weights with the load_from_checkpoint method of the LightningModule.

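import torch

# LitModel is your own LightningModule subclass
model = LitModel.load_from_checkpoint("best_model.ckpt")
model.eval()  # disable dropout and use running batch-norm statistics
x = torch.randn(1, 64)  # dummy input; the shape must match what the model expects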

with torch.no_grad():
    y_hat = model(x)

Predict step with your LightningModule

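Loading a checkpoint and predicting by hand still leaves you with boilerplate around the predict epoch: iterating over the dataloader, moving batches to the right device, switching to evaluation mode, and disabling gradients. The predict_step hook in the LightningModule removes this boilerplate.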

class MyModel(LightningModule):
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        return self(batch)

Then pass any dataloader to the Lightning Trainer:

data_loader = DataLoader(...)
model = MyModel()
trainer = Trainer()
predictions = trainer.predict(model, data_loader)

Enable complicated predict logic

When you need to add complicated pre-processing or post-processing logic to your data, use the predict step. For example, here we use Monte Carlo Dropout for predictions:

class LitMCdropoutModel(pl.LightningModule):
    def __init__(self, model, mc_iteration):
        super().__init__()
        self.model = model
        self.dropout = nn.Dropout()
        self.mc_iteration = mc_iteration

    def predict_step(self, batch, batch_idx):
        # the Trainer puts the module in eval mode for prediction,
        # so re-enable dropout explicitly for Monte Carlo Dropout
        self.dropout.train()

        # take the average of `self.mc_iteration` stochastic forward passes
        pred = [self.dropout(self.model(batch)).unsqueeze(0) for _ in range(self.mc_iteration)]
        pred = torch.vstack(pred).mean(dim=0)
        return pred

Enable distributed inference

Because the predict step runs through the Lightning Trainer, you get distributed inference without changing your model code. For example, to predict on 8 GPUs:

trainer = Trainer(devices=8, accelerator="gpu")
predictions = trainer.predict(model, data_loader)