##############################################
Truncated Backpropagation Through Time (TBPTT)
##############################################

Truncated Backpropagation Through Time (TBPTT) performs backpropagation every
*k* steps of a much longer sequence. The training batch is split along the time
dimension into chunks of *k* steps, and each chunk is passed to
``training_step`` in turn. To preserve the forward-propagation behavior of the
full sequence, the hidden state must be carried over between chunks.
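
For example, a batch laid out as ``(batch, time, features)`` can be split into
chunks of *k* time steps with ``torch.split``. The helper below is a minimal
sketch used by the example that follows; treating a batch as an ``(x, y)``
tuple with the time dimension at ``dim=1`` is an illustrative assumption:

.. code-block:: python

    import torch


    def split_batch(batch, split_size):
        # Split inputs and targets along the time dimension (dim=1)
        # into chunks of at most ``split_size`` steps each.
        x, y = batch
        x_splits = torch.split(x, split_size, dim=1)
        y_splits = torch.split(y, split_size, dim=1)
        return list(zip(x_splits, y_splits))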


.. code-block:: python

    import torch
    import torch.optim as optim
    import pytorch_lightning as pl
    from pytorch_lightning import LightningModule

    class LitModel(LightningModule):

        def __init__(self):
            super().__init__()

            # 1. Switch to manual optimization
            self.automatic_optimization = False

            self.truncated_bptt_steps = 10
            self.my_rnn = ParityModuleRNN()  # placeholder RNN returning (loss, hiddens); see the sketch below

        # 2. Remove the `hiddens` argument
        def training_step(self, batch, batch_idx):

            # 3. Split the batch into chunks along the time dimension
            # (using the split_batch helper sketched above)
            split_batches = split_batch(batch, self.truncated_bptt_steps)

            batch_size = 10
            hidden_dim = 20
            # initial hidden state, shaped for a single-layer RNN
            hiddens = torch.zeros(1, batch_size, hidden_dim, device=self.device)

            opt = self.optimizers()
            for split in split_batches:
                # 4. Perform the optimization in a loop
                loss, hiddens = self.my_rnn(split, hiddens)
                self.manual_backward(loss)
                opt.step()
                opt.zero_grad()

                # 5. "Truncate"
                hiddens = hiddens.detach()

            # 6. Remove the return of `hiddens`
            # Returning the loss is not required in manual optimization
            return None

        def configure_optimizers(self):
            return optim.Adam(self.my_rnn.parameters(), lr=0.001)

    if __name__ == "__main__":
        model = LitModel()
        trainer = pl.Trainer(max_epochs=5)
        trainer.fit(model, train_dataloader)  # supply your own DataLoader (see the sketch below)
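
The example above leaves ``ParityModuleRNN`` (a helper module from Lightning's
test suite) and ``train_dataloader`` undefined. A self-contained substitute
might look like the following sketch; the name ``SimpleSeqRNN``, its
dimensions, and the random dataset are illustrative assumptions, and the only
contract that matters is that the module returns a ``(loss, hiddens)`` tuple:

.. code-block:: python

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import DataLoader, TensorDataset


    class SimpleSeqRNN(nn.Module):
        def __init__(self, input_dim=5, hidden_dim=20):
            super().__init__()
            self.rnn = nn.RNN(input_dim, hidden_dim, batch_first=True)
            self.head = nn.Linear(hidden_dim, input_dim)

        def forward(self, split, hiddens):
            # ``split`` is one (x, y) chunk produced by ``split_batch``
            x, y = split
            out, hiddens = self.rnn(x, hiddens)
            loss = F.mse_loss(self.head(out), y)
            return loss, hiddens


    # Random sequence data: 64 sequences of 200 time steps with 5 features,
    # batched to match the hard-coded batch_size=10 and hidden_dim=20 above.
    x = torch.randn(64, 200, 5)
    y = torch.randn(64, 200, 5)
    train_dataloader = DataLoader(TensorDataset(x, y), batch_size=10, drop_last=True)

With these definitions in place of the placeholders, each 200-step sequence is
processed as twenty 10-step chunks, with one optimizer step per chunk.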