############################################
Deploy models into production (intermediate)
############################################
**Audience**: Researchers and MLEs looking to use their models for predictions without Lightning dependencies.

----

*********************
Use PyTorch as normal
*********************
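If you prefer to use PyTorch directly, you can use any Lightning checkpoint without Lightning. Load it with :func:`torch.load` and copy the weights stored under its ``"state_dict"`` key into a plain :class:`torch.nn.Module`.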

.. code-block:: python

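    import torch
    import torch.nn as nn


    class MyModel(nn.Module):
        ...


    model = MyModel()
    # map_location="cpu" lets you load a checkpoint saved on GPU on a CPU-only machine
    checkpoint = torch.load("path/to/lightning/checkpoint.ckpt", map_location="cpu")
    # the keys in "state_dict" follow the attribute names of your LightningModule,
    # so MyModel must define its layers under the same names
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()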
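Loading only succeeds when the key names in ``checkpoint["state_dict"]`` match the parameter names of ``MyModel``. Because Lightning saves the weights under the attribute names of your LightningModule, a plain module that nests its layers differently fails the strict check in ``load_state_dict``. The following is a minimal sketch for spotting such a mismatch before loading. The helper ``compare_state_dict_keys`` and the key names are hypothetical; in practice you would pass ``checkpoint["state_dict"].keys()`` and ``model.state_dict().keys()``.

.. code-block:: python

    from typing import Iterable, List, Tuple


    def compare_state_dict_keys(
        checkpoint_keys: Iterable[str], model_keys: Iterable[str]
    ) -> Tuple[List[str], List[str]]:
        """Return sorted (missing, unexpected) key names.

        missing: keys the model expects but the checkpoint lacks.
        unexpected: keys the checkpoint contains but the model does not expect.
        """
        checkpoint_set = set(checkpoint_keys)
        model_set = set(model_keys)
        missing = sorted(model_set - checkpoint_set)
        unexpected = sorted(checkpoint_set - model_set)
        return missing, unexpected


    # Hypothetical key names: the LightningModule stored its network as `self.net`,
    # while the plain MyModel defines its layer directly as `self.layer`.
    ckpt_keys = ["net.layer.weight", "net.layer.bias"]
    model_keys = ["layer.weight", "layer.bias"]
    missing, unexpected = compare_state_dict_keys(ckpt_keys, model_keys)
    print(missing)  # ['layer.bias', 'layer.weight']
    print(unexpected)  # ['net.layer.bias', 'net.layer.weight']

PyTorch reports the same information itself: ``model.load_state_dict(state_dict, strict=False)`` returns a named tuple with ``missing_keys`` and ``unexpected_keys`` fields. If the two lists differ only by a common prefix, renaming the checkpoint keys before loading usually resolves the mismatch.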

----

********************************************
Extract nn.Module from Lightning checkpoints
********************************************
You can also extract the :class:`torch.nn.Module` submodules from your LightningModule and load their weights from the checkpoint saved after training, so that inference needs only PyTorch. To do this, we recommend copying the exact implementation of your LightningModule's ``__init__`` and ``forward`` methods into a plain :class:`torch.nn.Module`.

.. code-block:: python

    import torch
    import torch.nn as nn
    from lightning.pytorch import LightningModule, Trainer


    class Encoder(nn.Module):
        ...


    class Decoder(nn.Module):
        ...


    class AutoEncoderProd(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = Encoder()
            self.decoder = Decoder()

        def forward(self, x):
            return self.encoder(x)


    class AutoEncoderSystem(LightningModule):
        def __init__(self):
            super().__init__()
            self.auto_encoder = AutoEncoderProd()

        def forward(self, x):
            return self.auto_encoder.encoder(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.auto_encoder.encoder(x)
            y_hat = self.auto_encoder.decoder(y_hat)
            loss = ...
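            return loss

        def configure_optimizers(self):
            # the Trainer requires an optimizer; Adam with default settings is used here
            return torch.optim.Adam(self.parameters())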


    # train it (`train_dataloader` and `val_dataloader` are assumed to be defined elsewhere)
    trainer = Trainer(devices=2, accelerator="gpu", strategy="ddp")
    model = AutoEncoderSystem()
    trainer.fit(model, train_dataloader, val_dataloader)
    trainer.save_checkpoint("best_model.ckpt")


    # create the PyTorch model and load the checkpoint weights
    checkpoint = torch.load("best_model.ckpt", map_location="cpu")

    # if you called `self.save_hyperparameters()` in your LightningModule, the
    # hyperparameters are stored in the checkpoint, and you can pass them too
    # (assuming AutoEncoderProd accepts the same arguments):
    # hyper_parameters = checkpoint["hyper_parameters"]
    # model = AutoEncoderProd(**hyper_parameters)
    model = AutoEncoderProd()

    model_weights = checkpoint["state_dict"]

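    # the LightningModule stores the network as `self.auto_encoder`, so every key
    # starts with `auto_encoder.`; drop that prefix to match AutoEncoderProd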
    for key in list(model_weights):
        model_weights[key.replace("auto_encoder.", "")] = model_weights.pop(key)

    model.load_state_dict(model_weights)
    model.eval()
    # dummy input; its shape must match what your Encoder expects
    x = torch.randn(1, 64)

    with torch.no_grad():
        y_hat = model(x)
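Sometimes you need only part of the trained system, for example the encoder on its own. Since a state dict is a mapping from parameter names to tensors, you can select the entries under one prefix and strip that prefix in a single pass. The sketch below assumes keys follow the ``auto_encoder.encoder.`` layout of the example above; the helper ``extract_submodule_state_dict`` and the ``l1`` layer name are hypothetical. It removes the prefix only at the start of each key, rather than wherever it appears.

.. code-block:: python

    from collections import OrderedDict
    from typing import Mapping, TypeVar

    V = TypeVar("V")


    def extract_submodule_state_dict(state_dict: Mapping[str, V], prefix: str) -> "OrderedDict[str, V]":
        """Keep only the entries whose key starts with `prefix`, with the prefix removed.

        The original order of the entries is preserved.
        """
        return OrderedDict(
            (key[len(prefix):], value)
            for key, value in state_dict.items()
            if key.startswith(prefix)
        )


    # Stand-in values; in practice these would be the tensors from checkpoint["state_dict"].
    weights = OrderedDict(
        [
            ("auto_encoder.encoder.l1.weight", "W_enc"),
            ("auto_encoder.encoder.l1.bias", "b_enc"),
            ("auto_encoder.decoder.l1.weight", "W_dec"),
        ]
    )
    encoder_weights = extract_submodule_state_dict(weights, "auto_encoder.encoder.")
    print(list(encoder_weights))  # ['l1.weight', 'l1.bias']

With real tensors, you could then pass ``encoder_weights`` to ``Encoder().load_state_dict`` and ship the encoder without the decoder.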