
Inference in Production

Once a model is trained, the next task is deploying it to production and running inference. The sections below cover the approaches you can use to deploy your models and run inference with them.


With Lightning API

The following are some ways you can use Lightning to run inference in production. Note that PyTorch Lightning brings some extra dependencies, so using raw PyTorch in your production environment might be advantageous.


Prediction API

Lightning provides a prediction API that can be accessed via Trainer.predict(). By default, predict_step() calls the forward() method of your LightningModule. To customize this behaviour, simply override predict_step(), which can be useful for adding pre-processing or post-processing logic to your data.

For this example, let's override predict_step() and try out Monte Carlo Dropout:

import torch
from torch import nn
import pytorch_lightning as pl


class LitMCdropoutModel(pl.LightningModule):
    def __init__(self, model, mc_iteration):
        super().__init__()
        self.model = model
        self.dropout = nn.Dropout()
        self.mc_iteration = mc_iteration

    def predict_step(self, batch, batch_idx):
        # enable Monte Carlo Dropout
        self.dropout.train()

        # take average of `self.mc_iteration` iterations
        pred = [self.dropout(self.model(batch)).unsqueeze(0) for _ in range(self.mc_iteration)]
        pred = torch.vstack(pred).mean(dim=0)
        return pred
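
With predict_step() in place, you run inference by passing the module and a dataloader to Trainer.predict(). Below is a minimal sketch; base_model and predict_dataloader are hypothetical placeholders for a trained torch.nn.Module and a DataLoader that yields input tensors:

model = LitMCdropoutModel(base_model, mc_iteration=10)
trainer = pl.Trainer()

# returns a list of predictions, one entry per batch in the dataloader
predictions = trainer.predict(model, dataloaders=predict_dataloader)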

PyTorch Runtime

You can also load the saved checkpoint and use it as a regular torch.nn.Module.

class SimpleModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(in_features=64, out_features=4)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))


# create the model
model = SimpleModel()

# train it
trainer = Trainer(accelerator="gpu", devices=2)
trainer.fit(model, train_dataloader, val_dataloader)
trainer.save_checkpoint("best_model.ckpt", weights_only=True)

# use model after training or load weights and drop into the production system
model = SimpleModel.load_from_checkpoint("best_model.ckpt")
model.eval()
x = torch.randn(1, 64)

with torch.no_grad():
    y_hat = model(x)

Without Lightning API

Since the LightningModule is simply a torch.nn.Module, the common techniques for exporting PyTorch models to production apply here too. In addition, the LightningModule provides helper methods to make this easier.


Convert to ONNX

Lightning provides a handy function to quickly export your model to the ONNX format, which makes the model independent of PyTorch and allows it to run with ONNX Runtime.

To export your model to the ONNX format, call to_onnx() on your LightningModule, passing the filepath and input_sample.

class SimpleModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(in_features=64, out_features=4)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))


# create the model
model = SimpleModel()
filepath = "model.onnx"
input_sample = torch.randn((1, 64))
model.to_onnx(filepath, input_sample, export_params=True)

You can also skip passing the input sample if the example_input_array property is specified in your LightningModule.

class SimpleModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(in_features=64, out_features=4)
        self.example_input_array = torch.randn(7, 64)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))


# create the model
model = SimpleModel()
filepath = "model.onnx"
model.to_onnx(filepath, export_params=True)

Once you have the exported model, you can run it with ONNX Runtime as follows:

import numpy as np
import onnxruntime

ort_session = onnxruntime.InferenceSession(filepath)
input_name = ort_session.get_inputs()[0].name

# ONNX Runtime expects float32 inputs for this model
ort_inputs = {input_name: np.random.randn(1, 64).astype(np.float32)}
ort_outs = ort_session.run(None, ort_inputs)
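
To sanity-check the export, you can run the same input through both the original PyTorch model and the ONNX Runtime session and compare the outputs. A minimal sketch, reusing model, ort_session, and input_name from above:

x = torch.randn(1, 64)
with torch.no_grad():
    torch_out = model(x)

# feed the identical input to ONNX Runtime and compare the results
ort_out = ort_session.run(None, {input_name: x.numpy()})[0]
np.testing.assert_allclose(torch_out.numpy(), ort_out, rtol=1e-03, atol=1e-05)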

Convert to TorchScript

TorchScript allows you to serialize your models so that they can be loaded in non-Python environments. The LightningModule has a handy method, to_torchscript(), that returns a scripted module which you can save or use directly.

class SimpleModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(in_features=64, out_features=4)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))


# create the model
model = SimpleModel()
script = model.to_torchscript()

# save for use in production environment
torch.jit.save(script, "model.pt")

It is recommended that you install the latest supported version of PyTorch to use this feature without limitations.

Once you have the exported model, you can run it in a PyTorch or C++ runtime:

inp = torch.rand(1, 64)
scripted_module = torch.jit.load("model.pt")
output = scripted_module(inp)

If you want to script a different method, you can decorate it with @torch.jit.export:

class LitMCdropoutModel(pl.LightningModule):
    def __init__(self, model, mc_iteration):
        super().__init__()
        self.model = model
        self.dropout = nn.Dropout()
        self.mc_iteration = mc_iteration

    @torch.jit.export
    def predict_step(self, batch, batch_idx):
        # enable Monte Carlo Dropout
        self.dropout.train()

        # take average of `self.mc_iteration` iterations
        pred = [self.dropout(self.model(batch)).unsqueeze(0) for _ in range(self.mc_iteration)]
        pred = torch.vstack(pred).mean(dim=0)
        return pred


model = LitMCdropoutModel(...)
script = model.to_torchscript(file_path="model.pt", method="script")
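
Because predict_step() is marked with @torch.jit.export, it is preserved on the scripted module and can be called directly after loading. A small sketch, assuming the wrapped model accepts (batch_size, 64) inputs like the earlier examples:

scripted_module = torch.jit.load("model.pt")

# exported methods are available on the loaded module; TorchScript treats
# unannotated arguments as tensors, so batch_idx is passed as a tensor here
batch = torch.rand(4, 64)
pred = scripted_module.predict_step(batch, torch.tensor(0))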

PyTorch Runtime

You can also load the saved checkpoint and use it as a regular torch.nn.Module. After training, you can extract your plain torch.nn.Module and load its weights from the checkpoint saved by the LightningModule. For this, we recommend copying the exact implementation from your LightningModule's init and forward methods.

class Encoder(nn.Module):
    ...


class Decoder(nn.Module):
    ...


class AutoEncoderProd(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        return self.encoder(x)


class AutoEncoderSystem(LightningModule):
    def __init__(self):
        super().__init__()
        self.auto_encoder = AutoEncoderProd()

    def forward(self, x):
        return self.auto_encoder.encoder(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.auto_encoder.encoder(x)
        y_hat = self.auto_encoder.decoder(y_hat)
        loss = ...
        return loss


# train it
trainer = Trainer(devices=2, accelerator="gpu", strategy="ddp")
model = AutoEncoderSystem()
trainer.fit(model, train_dataloader, val_dataloader)
trainer.save_checkpoint("best_model.ckpt")


# create the PyTorch model and load the checkpoint weights
model = AutoEncoderProd()
checkpoint = torch.load("best_model.ckpt")
hyper_parameters = checkpoint["hyper_parameters"]

# if you want to restore any hyperparameters, you can pass them too
model = AutoEncoderProd(**hyper_parameters)

model_weights = checkpoint["state_dict"]

# update keys by dropping `auto_encoder.`
for key in list(model_weights):
    model_weights[key.replace("auto_encoder.", "")] = model_weights.pop(key)

model.load_state_dict(model_weights)
model.eval()
x = torch.randn(1, 64)

with torch.no_grad():
    y_hat = model(x)