.. _converting:

######################################
How to Organize PyTorch Into Lightning
######################################

To make your PyTorch code work with Lightning, reorganize it with the following steps.

--------

*******************************
1. Keep Your Computational Code
*******************************

Keep your regular :class:`~torch.nn.Module` architecture.

.. testcode::

    import lightning.pytorch as pl
    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class LitModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.layer_1 = nn.Linear(28 * 28, 128)
            self.layer_2 = nn.Linear(128, 10)

        def forward(self, x):
            x = x.view(x.size(0), -1)
            x = self.layer_1(x)
            x = F.relu(x)
            x = self.layer_2(x)
            return x
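
Any :class:`~torch.nn.Module` can be kept as-is, and it remains plain PyTorch. As a sketch, the same two-layer network could also be written as an :class:`~torch.nn.Sequential` and checked on a random batch. The ``(batch, 1, 28, 28)`` input shape is an assumption inferred from the ``28 * 28`` input features of ``layer_1``. A module like this is what the ``encoder`` argument of the ``LightningModule`` in the next step receives.

.. code-block:: python

    import torch
    import torch.nn as nn

    # Same architecture as above: flatten -> Linear(784, 128) -> ReLU -> Linear(128, 10)
    encoder = nn.Sequential(
        nn.Flatten(),  # flattens every dim except the batch dim, like x.view(x.size(0), -1)
        nn.Linear(28 * 28, 128),
        nn.ReLU(),
        nn.Linear(128, 10),
    )

    # Plain PyTorch still works: run a fake batch of 4 single-channel 28x28 images
    x = torch.randn(4, 1, 28, 28)
    logits = encoder(x)
    print(logits.shape)  # torch.Size([4, 10])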

--------

***************************
2. Configure Training Logic
***************************
In the :meth:`~lightning.pytorch.core.module.LightningModule.training_step` of your ``LightningModule``, define how a single batch of training data is processed and return the loss:

.. testcode::

    class LitModel(pl.LightningModule):
        def __init__(self, encoder):
            super().__init__()
            self.encoder = encoder

        def training_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.encoder(x)
            loss = F.cross_entropy(y_hat, y)
            return loss

.. note:: If you need to fully own the training loop for complicated legacy projects, check out :doc:`Own your loop <../model/own_your_loop>`.
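
For comparison, the following is a simplified sketch of the manual PyTorch loop that ``training_step`` lets you delete. Lightning's real loop also handles devices, gradient accumulation, logging, and checkpointing. The synthetic data, the single ``nn.Linear`` encoder, and the epoch count are placeholders, not part of this guide.

.. code-block:: python

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, TensorDataset

    torch.manual_seed(0)

    # Placeholder data and model: 64 flat samples, 10 classes
    dataset = TensorDataset(torch.randn(64, 28 * 28), torch.randint(0, 10, (64,)))
    loader = DataLoader(dataset, batch_size=16, shuffle=True)
    encoder = nn.Linear(28 * 28, 10)
    optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)

    losses = []
    for epoch in range(2):
        encoder.train()
        for batch_idx, batch in enumerate(loader):
            # --- this part is what goes into training_step ---
            x, y = batch
            y_hat = encoder(x)
            loss = F.cross_entropy(y_hat, y)
            # --- Lightning runs the rest for you ---
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

    print(len(losses))  # 8 (2 epochs x 4 batches)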

----

****************************************
3. Move Optimizer(s) and LR Scheduler(s)
****************************************
Move your optimizers and learning rate schedulers to the :meth:`~lightning.pytorch.core.module.LightningModule.configure_optimizers` hook.

.. testcode::

    class LitModel(pl.LightningModule):
        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.encoder.parameters(), lr=1e-3)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
            return [optimizer], [lr_scheduler]
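
With ``step_size=1`` and the default ``gamma=0.1``, :class:`~torch.optim.lr_scheduler.StepLR` divides the learning rate by 10 every time the scheduler steps, and Lightning steps schedulers returned this way once per epoch by default. The plain PyTorch sketch below shows the resulting learning rates; the small ``nn.Linear`` is a placeholder standing in for ``self.encoder``.

.. code-block:: python

    import torch
    import torch.nn as nn

    encoder = nn.Linear(4, 2)  # placeholder for self.encoder
    optimizer = torch.optim.Adam(encoder.parameters(), lr=1e-3)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)  # gamma defaults to 0.1

    lrs = []
    for epoch in range(3):
        lrs.append(lr_scheduler.get_last_lr()[0])  # LR used during this epoch
        # ... one epoch of training would run here ...
        optimizer.step()     # step the optimizer before the scheduler to avoid a PyTorch warning
        lr_scheduler.step()  # Lightning does this at the end of each epoch by default

    print(lrs)  # approximately [1e-3, 1e-4, 1e-5]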

--------

***************************************
4. Organize Validation Logic (optional)
***************************************
If you need a validation loop, configure how your validation routine behaves with a batch of validation data:

.. testcode::

    class LitModel(pl.LightningModule):
        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.encoder(x)
            val_loss = F.cross_entropy(y_hat, y)
            self.log("val_loss", val_loss)

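.. tip:: When called without a model, ``trainer.validate()`` loads the best checkpoint from the previous ``trainer.fit()`` call by default, provided checkpointing was enabled during fitting. To load that checkpoint while also passing a model, set ``ckpt_path="best"``.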
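
By default, a value logged with ``self.log`` inside ``validation_step`` is accumulated over all validation batches and logged once per epoch as a batch-size-weighted mean. The plain PyTorch sketch below roughly mimics that aggregation; the data (50 samples, so the last batch is smaller) and the ``nn.Linear`` encoder are placeholders.

.. code-block:: python

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, TensorDataset

    torch.manual_seed(0)

    # Placeholder validation data and encoder
    val_loader = DataLoader(
        TensorDataset(torch.randn(50, 28 * 28), torch.randint(0, 10, (50,))), batch_size=16
    )
    encoder = nn.Linear(28 * 28, 10)

    encoder.eval()         # Lightning switches the model to eval mode for you
    total, count = 0.0, 0
    with torch.no_grad():  # ...and disables gradient tracking
        for batch_idx, batch in enumerate(val_loader):
            x, y = batch
            val_loss = F.cross_entropy(encoder(x), y)  # body of validation_step
            total += val_loss.item() * x.size(0)       # weight each batch by its size
            count += x.size(0)

    epoch_val_loss = total / count  # roughly what ends up logged as "val_loss"
    print(count)  # 50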

--------

************************************
5. Organize Testing Logic (optional)
************************************
If you need a test loop, configure how your testing routine behaves with a batch of test data:

.. testcode::

    class LitModel(pl.LightningModule):
        def test_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.encoder(x)
            test_loss = F.cross_entropy(y_hat, y)
            self.log("test_loss", test_loss)

--------

****************************************
6. Configure Prediction Logic (optional)
****************************************
If you need a prediction loop, configure how your prediction routine behaves with a batch of prediction data:

.. testcode::

    class LitModel(pl.LightningModule):
        def predict_step(self, batch, batch_idx):
            x, y = batch
            pred = self.encoder(x)
            return pred
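
For a single dataloader, ``trainer.predict()`` collects the value returned by ``predict_step`` for every batch and returns them as a list. The plain PyTorch sketch below mimics that and then concatenates the per-batch outputs into class predictions; the data and the ``nn.Linear`` encoder are placeholders, and the labels are included only to match ``x, y = batch``.

.. code-block:: python

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    # Placeholder data (40 samples -> batches of 16, 16, 8) and encoder
    loader = DataLoader(
        TensorDataset(torch.randn(40, 28 * 28), torch.randint(0, 10, (40,))), batch_size=16
    )
    encoder = nn.Linear(28 * 28, 10)

    encoder.eval()
    predictions = []  # one entry per batch, like the list returned by trainer.predict()
    with torch.no_grad():
        for batch_idx, batch in enumerate(loader):
            x, y = batch
            predictions.append(encoder(x))  # body of predict_step

    logits = torch.cat(predictions)  # shape (40, 10)
    classes = logits.argmax(dim=1)   # predicted class index per sample
    print(len(predictions), logits.shape)  # 3 torch.Size([40, 10])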

--------

******************************************
7. Remove any .cuda() or .to(device) Calls
******************************************

Lightning handles device placement for your :doc:`LightningModule <../common/lightning_module>`, so the same code runs on CPUs, GPUs, and other supported accelerators without changes.

If you have explicit calls to ``.cuda()`` or ``.to(device)``, you can remove them: Lightning automatically moves the batches coming from your :class:`~torch.utils.data.DataLoader`
and all :class:`~torch.nn.Module` instances initialized inside ``LightningModule.__init__`` to the correct device.
If you still need to access the current device, use ``self.device`` anywhere in your ``LightningModule`` except in the ``__init__`` and ``setup`` methods.

.. testcode::

    class LitModel(pl.LightningModule):
        def training_step(self, batch, batch_idx):
            z = torch.randn(4, 5, device=self.device)
            ...

.. tip:: If you initialize a :class:`~torch.Tensor` in ``LightningModule.__init__`` and want it moved to the device automatically,
    register it as a buffer with :meth:`~torch.nn.Module.register_buffer`.

.. testcode::

    class LitModel(pl.LightningModule):
        def __init__(self, num_features=128):
            super().__init__()
            # Buffers move to the model's device together with it but are not trained
            self.register_buffer("running_mean", torch.zeros(num_features))
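
A buffer is saved in the ``state_dict`` and follows the module through ``.to()`` calls, but it is not returned by ``parameters()``, so the optimizer does not update it. The plain PyTorch sketch below checks these properties; the ``WithBuffer`` class and ``num_features=4`` are hypothetical, and a cast to ``torch.float64`` stands in for a device move so the example runs without a GPU.

.. code-block:: python

    import torch
    import torch.nn as nn


    class WithBuffer(nn.Module):  # hypothetical module for illustration
        def __init__(self, num_features=4):
            super().__init__()
            self.layer = nn.Linear(num_features, num_features)
            self.register_buffer("running_mean", torch.zeros(num_features))


    module = WithBuffer()
    print("running_mean" in module.state_dict())  # True: saved in checkpoints
    print(len(list(module.parameters())))         # 2: only the Linear weight and bias

    # .to() converts buffers along with parameters; a dtype change stands in for a device move
    module = module.to(torch.float64)
    print(module.running_mean.dtype)  # torch.float64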

--------

********************
8. Use Your Own Data
********************
Regular PyTorch DataLoaders work with Lightning. For more modular and scalable datasets, check out :doc:`LightningDataModule <../data/datamodule>`.
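
Any :class:`~torch.utils.data.DataLoader` whose batches unpack as ``x, y = batch`` fits the steps above. A minimal sketch, with random tensors standing in for a real dataset (the sizes and split are placeholders):

.. code-block:: python

    import torch
    from torch.utils.data import DataLoader, TensorDataset, random_split

    torch.manual_seed(0)

    # Random stand-in for a real dataset: 100 MNIST-shaped images with integer labels
    images = torch.randn(100, 1, 28, 28)
    labels = torch.randint(0, 10, (100,))
    train_set, val_set = random_split(TensorDataset(images, labels), [80, 20])

    train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=32)

    x, y = next(iter(train_loader))  # same unpacking as in training_step
    print(x.shape, y.shape)  # torch.Size([32, 1, 28, 28]) torch.Size([32])

Loaders like these can be passed directly to the trainer, for example ``trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)``.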

----

************
Good to know
************

Additionally, you can run only the validation loop using the :meth:`~lightning.pytorch.trainer.trainer.Trainer.validate` method.

.. code-block:: python

    model = LitModel()
    trainer = pl.Trainer()
    trainer.validate(model)

.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for validation.


The test loop isn't run by :meth:`~lightning.pytorch.trainer.trainer.Trainer.fit`, so you need to call :meth:`~lightning.pytorch.trainer.trainer.Trainer.test` explicitly.

.. code-block:: python

    model = LitModel()
    trainer = pl.Trainer()
    trainer.test(model)

.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for testing.

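.. tip:: When called without a model, ``trainer.test()`` loads the best checkpoint from the previous ``trainer.fit()`` call by default, provided checkpointing is enabled. To load that checkpoint while also passing a model, set ``ckpt_path="best"``.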


The predict loop will not be used until you call :meth:`~lightning.pytorch.trainer.trainer.Trainer.predict`.

.. code-block:: python

    model = LitModel()
    trainer = pl.Trainer()
    trainer.predict(model)

.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for prediction.

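.. tip:: When called without a model, ``trainer.predict()`` loads the best checkpoint from the previous ``trainer.fit()`` call by default, provided checkpointing is enabled. To load that checkpoint while also passing a model, set ``ckpt_path="best"``.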