Inference in Production¶
Once a model is trained, deploying to production and running inference is the next task. To help you with it, here are the possible approaches you can use to deploy and make inferences with your models.
With Lightning API¶
The following are some possible ways you can use Lightning to run inference in production. Note that PyTorch Lightning has some extra dependencies and using raw PyTorch might be advantageous. in your production environment.
Prediction API¶
Lightning provides you with a prediction API that can be accessed using predict()
.
To configure this with your LightningModule, you would need to override the predict_step()
method.
By default predict_step()
calls the forward()
method. In order to customize this behaviour, simply override the predict_step()
method. This can be useful to add some pre-processing or post-processing logic to your data.
For the example let’s override predict_step
and try out Monte Carlo Dropout:
class LitMCdropoutModel(pl.LightningModule):
def __init__(self, model, mc_iteration):
super().__init__()
self.model = model
self.dropout = nn.Dropout()
self.mc_iteration = mc_iteration
def predict_step(self, batch, batch_idx):
# enable Monte Carlo Dropout
self.dropout.train()
# take average of `self.mc_iteration` iterations
pred = [self.dropout(self.model(x)).unsqueeze(0) for _ in range(self.mc_iteration)]
pred = torch.vstack(pred).mean(dim=0)
return pred
PyTorch Runtime¶
You can also load the saved checkpoint and use it as a regular torch.nn.Module
.
class SimpleModel(LightningModule):
def __init__(self):
super().__init__()
self.l1 = torch.nn.Linear(in_features=64, out_features=4)
def forward(self, x):
return torch.relu(self.l1(x.view(x.size(0), -1)))
# create the model
model = SimpleModel()
# train it
trainer = Trainer(accelerator="gpu", devices=2)
trainer.fit(model, train_dataloader, val_dataloader)
trainer.save_checkpoint("best_model.ckpt", weights_only=True)
# use model after training or load weights and drop into the production system
model = SimpleModel.load_from_checkpoint("best_model.ckpt")
model.eval()
x = torch.randn(1, 64)
with torch.no_grad():
y_hat = model(x)
Without Lightning API¶
As the LightningModule
is simply a torch.nn.Module
, common techniques to export PyTorch models
to production apply here too. However, the LightningModule
provides helper methods to help you out with it.
Convert to ONNX¶
Lightning provides a handy function to quickly export your model to ONNX format which allows the model to be independent of PyTorch and run on an ONNX Runtime.
To export your model to ONNX format call the to_onnx()
function on your LightningModule
with the filepath
and input_sample
.
class SimpleModel(LightningModule):
def __init__(self):
super().__init__()
self.l1 = torch.nn.Linear(in_features=64, out_features=4)
def forward(self, x):
return torch.relu(self.l1(x.view(x.size(0), -1)))
# create the model
model = SimpleModel()
filepath = "model.onnx"
input_sample = torch.randn((1, 64))
model.to_onnx(filepath, input_sample, export_params=True)
You can also skip passing the input sample if the example_input_array
property is specified in your LightningModule
.
class SimpleModel(LightningModule):
def __init__(self):
super().__init__()
self.l1 = torch.nn.Linear(in_features=64, out_features=4)
self.example_input_array = torch.randn(7, 64)
def forward(self, x):
return torch.relu(self.l1(x.view(x.size(0), -1)))
# create the model
model = SimpleModel()
filepath = "model.onnx"
model.to_onnx(filepath, export_params=True)
Once you have the exported model, you can run it on your ONNX runtime in the following way:
import onnxruntime
ort_session = onnxruntime.InferenceSession(filepath)
input_name = ort_session.get_inputs()[0].name
ort_inputs = {input_name: np.random.randn(1, 64)}
ort_outs = ort_session.run(None, ort_inputs)
Convert to TorchScript¶
TorchScript allows you to serialize your models in a way that it can be loaded in non-Python environments.
The LightningModule
has a handy method to_torchscript()
that returns a scripted module which you
can save or directly use.
class SimpleModel(LightningModule):
def __init__(self):
super().__init__()
self.l1 = torch.nn.Linear(in_features=64, out_features=4)
def forward(self, x):
return torch.relu(self.l1(x.view(x.size(0), -1)))
# create the model
model = SimpleModel()
script = model.to_torchscript()
# save for use in production environment
torch.jit.save(script, "model.pt")
It is recommended that you install the latest supported version of PyTorch to use this feature without limitations.
Once you have the exported model, you can run it in Pytorch or C++ runtime:
inp = torch.rand(1, 64)
scripted_module = torch.jit.load("model.pt")
output = scripted_module(inp)
If you want to script a different method, you can decorate the method with torch.jit.export()
:
class LitMCdropoutModel(pl.LightningModule):
def __init__(self, model, mc_iteration):
super().__init__()
self.model = model
self.dropout = nn.Dropout()
self.mc_iteration = mc_iteration
@torch.jit.export
def predict_step(self, batch, batch_idx):
# enable Monte Carlo Dropout
self.dropout.train()
# take average of `self.mc_iteration` iterations
pred = [self.dropout(self.model(x)).unsqueeze(0) for _ in range(self.mc_iteration)]
pred = torch.vstack(pred).mean(dim=0)
return pred
model = LitMCdropoutModel(...)
script = model.to_torchscript(file_path="model.pt", method="script")
PyTorch Runtime¶
You can also load the saved checkpoint and use it as a regular torch.nn.Module
. You can extract all your torch.nn.Module
and load the weights using the checkpoint saved using LightningModule after training. For this, we recommend copying the exact implementation
from your LightningModule init
and forward
method.
class Encoder(nn.Module):
...
class Decoder(nn.Module):
...
class AutoEncoderProd(nn.Module):
def __init__(self):
super().__init__()
self.encoder = Encoder()
self.decoder = Decoder()
def forward(self, x):
return self.encoder(x)
class AutoEncoderSystem(LightningModule):
def __init__(self):
super().__init__()
self.auto_encoder = AutoEncoderProd()
def forward(self, x):
return self.auto_encoder.encoder(x)
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self.auto_encoder.encoder(x)
y_hat = self.auto_encoder.decoder(y_hat)
loss = ...
return loss
# train it
trainer = Trainer(devices=2, accelerator="gpu", strategy="ddp")
model = AutoEncoderSystem()
trainer.fit(model, train_dataloader, val_dataloader)
trainer.save_checkpoint("best_model.ckpt")
# create the PyTorch model and load the checkpoint weights
model = AutoEncoderProd()
checkpoint = torch.load("best_model.ckpt")
hyper_parameters = checkpoint["hyper_parameters"]
# if you want to restore any hyperparameters, you can pass them too
model = AutoEncoderProd(**hyper_parameters)
state_dict = checkpoint["state_dict"]
# update keys by dropping `auto_encoder.`
for key in list(model_weights):
model_weights[key.replace("auto_encoder.", "")] = model_weights.pop(key)
model.load_state_dict(model_weights)
model.eval()
x = torch.randn(1, 64)
with torch.no_grad():
y_hat = model(x)