Deploy models into production (advanced)¶
Audience: Machine learning engineers optimizing models for enterprise-scale production environments.
Compile your model to ONNX¶
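ONNX (Open Neural Network Exchange) is an open format for representing machine learning models, and ONNX Runtime is an inference engine developed by Microsoft to optimize inference. Exporting to ONNX makes the model independent of PyTorch, so it can run on any ONNX-compatible runtime.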
To export your model to ONNX format, call the to_onnx() function on your LightningModule with the filepath and input_sample.
class SimpleModel(LightningModule):
    """Minimal single-layer module used to demonstrate ONNX export."""

    def __init__(self):
        super().__init__()
        # One fully connected layer: 64 input features -> 4 outputs.
        self.l1 = torch.nn.Linear(64, 4)

    def forward(self, x):
        """Flatten each sample, apply the linear layer, then ReLU."""
        batch_size = x.size(0)
        flattened = x.view(batch_size, -1)
        return torch.relu(self.l1(flattened))
# Build the model, then export it to ONNX using an explicit sample input.
model = SimpleModel()
input_sample = torch.randn(1, 64)
filepath = "model.onnx"
model.to_onnx(filepath, input_sample, export_params=True)
You can also skip passing the input sample if the example_input_array property is specified in your LightningModule.
class SimpleModel(LightningModule):
    """Single-layer module that carries its own example input for ONNX export."""

    def __init__(self):
        super().__init__()
        # One fully connected layer: 64 input features -> 4 outputs.
        self.l1 = torch.nn.Linear(64, 4)
        # With this attribute set, to_onnx() can be called without an input sample.
        self.example_input_array = torch.randn(7, 64)

    def forward(self, x):
        """Flatten each sample, apply the linear layer, then ReLU."""
        batch_size = x.size(0)
        flattened = x.view(batch_size, -1)
        return torch.relu(self.l1(flattened))
# Export without an explicit sample: the module's example_input_array is used.
filepath = "model.onnx"
model = SimpleModel()
model.to_onnx(filepath, export_params=True)
Once you have the exported model, you can run it with ONNX Runtime in the following way:
import numpy as np
import onnxruntime

# Load the exported graph and look up the name of its (single) input tensor.
ort_session = onnxruntime.InferenceSession(filepath)
input_name = ort_session.get_inputs()[0].name

# The model was exported from float32 torch tensors, so the feed must be
# float32 as well; np.random.randn returns float64, which ONNX Runtime rejects
# as an input type mismatch.
# NOTE(review): the batch size presumably has to match the sample used at
# export time unless the model was exported with dynamic axes - confirm.
ort_inputs = {input_name: np.random.randn(1, 64).astype(np.float32)}

# Passing None as the output list requests all model outputs.
ort_outs = ort_session.run(None, ort_inputs)
Validate a Model Is Servable¶
Warning
This is an experimental feature.
Production ML Engineers would argue that a model shouldn’t be trained if it can’t be deployed reliably and in a fully automated manner.
To ease the transition from training to production, PyTorch Lightning provides a way for you to validate that a model can be served even before starting training.
To do so, your LightningModule needs to subclass ServableModule and implement its hooks, and you need to pass a ServableModuleValidator callback to the Trainer.
Below you can find an example of how the serving of a resnet18 can be validated.
import base64
from dataclasses import dataclass
from io import BytesIO
from os import path
from typing import Optional
import numpy as np
import torch
import torchvision
import torchvision.transforms as T
from PIL import Image as PILImage
from lightning.pytorch import LightningDataModule, LightningModule, cli_lightning_logo
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.serve import ServableModule, ServableModuleValidator
from lightning.pytorch.utilities.model_helpers import get_torchvision_model
# Dataset root: a "Datasets" folder two levels above this file's directory.
DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets")
class LitModule(LightningModule):
    """Classifier wrapping a torchvision model whose head is replaced by a 10-way layer."""

    def __init__(self, name: str = "resnet18"):
        super().__init__()
        backbone = get_torchvision_model(name, weights="DEFAULT")
        # Swap the final fully connected layer for a 10-output head.
        backbone.fc = torch.nn.Linear(backbone.fc.in_features, 10)
        self.model = backbone
        self.criterion = torch.nn.CrossEntropyLoss()

    def _batch_loss(self, batch):
        """Run the model on an (inputs, labels) batch and return the criterion loss."""
        inputs, labels = batch
        return self.criterion(self.model(inputs), labels)

    def training_step(self, batch, batch_idx):
        """Log and return the training loss for one batch."""
        loss = self._batch_loss(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss for one batch."""
        self.log("val_loss", self._batch_loss(batch))

    def configure_optimizers(self):
        """Use SGD with momentum over all module parameters."""
        return torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
class CIFAR10DataModule(LightningDataModule):
    """Data module serving CIFAR10 train/validation loaders from DATASETS_PATH."""

    # Resize to 256, center-crop to 224x224, then convert to a tensor.
    transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])

    def train_dataloader(self, *args, **kwargs):
        """Return a shuffled loader over the CIFAR10 training split (downloaded if missing)."""
        trainset = torchvision.datasets.CIFAR10(
            root=DATASETS_PATH, train=True, download=True, transform=self.transform
        )
        return torch.utils.data.DataLoader(trainset, batch_size=2, shuffle=True, num_workers=0)

    def val_dataloader(self, *args, **kwargs):
        """Return an ordered loader over the CIFAR10 test split (downloaded if missing)."""
        valset = torchvision.datasets.CIFAR10(
            root=DATASETS_PATH, train=False, download=True, transform=self.transform
        )
        # Validation data is not shuffled: Lightning warns when a val/test loader
        # shuffles, and a fixed order keeps the evaluated subset stable when the
        # number of validation batches is limited.
        return torch.utils.data.DataLoader(valset, batch_size=2, shuffle=False, num_workers=0)
@dataclass(unsafe_hash=True)
class Image:
    """Deserializer that turns a base64-encoded image string into a batched tensor."""

    height: Optional[int] = None
    width: Optional[int] = None
    extension: str = "JPEG"
    mode: str = "RGB"
    channel_first: bool = False

    def deserialize(self, data: str) -> torch.Tensor:
        """Decode ``data`` from base64, optionally resize it, and return it with a leading batch axis."""
        # Extra "=" characters are appended before decoding, presumably so that
        # payloads whose padding was stripped still decode.
        padded = f"{data}===".encode("UTF-8")
        raw_bytes = base64.b64decode(padded)
        picture = PILImage.open(BytesIO(raw_bytes), mode="r")
        # Resize only when both target dimensions are set.
        if self.height and self.width:
            picture = picture.resize((self.width, self.height))
        pixels = np.array(picture)
        return T.ToTensor()(pixels).unsqueeze(0)
class Top1:
    """Serializer that reduces model scores to the index of the highest-scoring class."""

    def serialize(self, tensor: torch.Tensor) -> int:
        """Return the index of the largest softmax probability in ``tensor``.

        The softmax is taken over the last dimension; ``dim`` is passed
        explicitly because calling ``softmax`` without it is deprecated and
        emits a warning. The argmax is computed over the flattened tensor, so
        this is meant for the scores of a single sample.
        """
        probabilities = torch.nn.functional.softmax(tensor, dim=-1)
        return probabilities.argmax().item()
class ProductionReadyModel(LitModule, ServableModule):
    """LitModule combined with the ServableModule hooks used for serving validation."""

    def configure_payload(self):
        """Build a request payload from the first training sample."""
        # Take the first sample of the train dataset and drop its label.
        sample, _ = self.trainer.train_dataloader.dataset[0]
        # Encode it as JPEG bytes, then as a base64 UTF-8 string.
        buffer = BytesIO()
        T.ToPILImage()(sample).save(buffer, format="JPEG")
        encoded = base64.b64encode(buffer.getvalue()).decode("UTF-8")
        return {"body": {"x": encoded}}

    def configure_serialization(self):
        """Return the (deserializers, serializers) mappings keyed by field name."""
        deserializers = {"x": Image(224, 224).deserialize}
        serializers = {"output": Top1().serialize}
        return deserializers, serializers

    def serve_step(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Run the wrapped model on ``x`` and return its output under "output"."""
        scores = self.model(x)
        return {"output": scores}

    def configure_response(self):
        """Return the response expected for the configured payload."""
        # NOTE(review): 7 is presumably the class index expected for the payload
        # sample - confirm against the dataset and seed.
        return {"output": 7}
def cli_main():
    """Build the CLI for the servable model and CIFAR10 data, then fit on CPU."""
    # A short CPU run; the validator callback is attached through the trainer defaults.
    trainer_defaults = {
        "accelerator": "cpu",
        "callbacks": [ServableModuleValidator()],
        "max_epochs": 1,
        "limit_train_batches": 5,
        "limit_val_batches": 5,
    }
    cli = LightningCLI(
        ProductionReadyModel,
        CIFAR10DataModule,
        seed_everything_default=42,
        save_config_kwargs={"overwrite": True},
        run=False,  # do not run a subcommand; fit is called explicitly below
        trainer_defaults=trainer_defaults,
    )
    trainer = cli.trainer
    trainer.fit(cli.model, cli.datamodule)


if __name__ == "__main__":
    # Print the Lightning logo, then run the example.
    cli_lightning_logo()
    cli_main()