Deploy models into production (intermediate)
Audience: Researchers and MLEs looking to use their models for predictions without Lightning dependencies.
Use PyTorch as normal
If you prefer to use PyTorch directly, you can use any Lightning checkpoint without Lightning installed: the checkpoint is a plain dictionary whose model weights live under the "state_dict" key.
import torch
from torch import nn

class MyModel(nn.Module):
    ...

model = MyModel()

# a Lightning checkpoint is a plain dict; the weights live under "state_dict"
checkpoint = torch.load("path/to/lightning/checkpoint.ckpt")
model.load_state_dict(checkpoint["state_dict"])
model.eval()
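If the checkpoint was saved on a GPU machine and you want to run inference on CPU, you can remap the storages while loading with the standard map_location argument of torch.load. A minimal sketch:

# load a GPU-trained checkpoint onto the CPU for inference
checkpoint = torch.load("path/to/lightning/checkpoint.ckpt", map_location="cpu")
model.load_state_dict(checkpoint["state_dict"])
model.eval()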
Extract nn.Module from Lightning checkpoints
You can also load the saved checkpoint and use it as a regular torch.nn.Module. Extract your torch.nn.Module submodules and load their weights from the checkpoint saved by the LightningModule after training. For this, we recommend copying the exact implementation of your LightningModule's __init__ and forward methods.
import torch
from torch import nn
from lightning.pytorch import LightningModule, Trainer


class Encoder(nn.Module):
    ...


class Decoder(nn.Module):
    ...


class AutoEncoderProd(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        return self.encoder(x)


class AutoEncoderSystem(LightningModule):
    def __init__(self):
        super().__init__()
        self.auto_encoder = AutoEncoderProd()

    def forward(self, x):
        return self.auto_encoder.encoder(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.auto_encoder.encoder(x)
        y_hat = self.auto_encoder.decoder(y_hat)
        loss = ...
        return loss


# train it
trainer = Trainer(devices=2, accelerator="gpu", strategy="ddp")
model = AutoEncoderSystem()
trainer.fit(model, train_dataloader, val_dataloader)
trainer.save_checkpoint("best_model.ckpt")


# create the PyTorch model and load the checkpoint weights
model = AutoEncoderProd()
checkpoint = torch.load("best_model.ckpt")

# if you saved hyperparameters with `save_hyperparameters()`, you can restore them too
hyper_parameters = checkpoint["hyper_parameters"]
model = AutoEncoderProd(**hyper_parameters)

model_weights = checkpoint["state_dict"]

# update keys by dropping the `auto_encoder.` prefix added by the LightningModule attribute
for key in list(model_weights):
    model_weights[key.replace("auto_encoder.", "")] = model_weights.pop(key)

model.load_state_dict(model_weights)
model.eval()

x = torch.randn(1, 64)

with torch.no_grad():
    y_hat = model(x)
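Once the weights are loaded into the plain nn.Module, you can re-save them as a regular PyTorch state dict so that production code only needs torch. The helper below is a minimal sketch; the function name and prefix argument are illustrative, not part of the Lightning API.

def extract_state_dict(ckpt_path: str, prefix: str = "auto_encoder.") -> dict:
    # illustrative helper (not a Lightning API): strip the LightningModule
    # attribute prefix from every key in the checkpoint's state_dict
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    return {k.removeprefix(prefix): v for k, v in checkpoint["state_dict"].items()}


model = AutoEncoderProd()
model.load_state_dict(extract_state_dict("best_model.ckpt"))

# save a pure-PyTorch copy of the weights for deployment without Lightning
torch.save(model.state_dict(), "auto_encoder_prod.pt")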