Truncated Backpropagation Through Time (TBPTT)

Truncated Backpropagation Through Time (TBPTT) performs backpropagation every k steps of a much longer sequence. Instead of backpropagating through the entire sequence at once, each training batch is split along the time dimension into chunks of size k, and the optimization step runs on one chunk at a time inside training_step. To keep the forward pass equivalent to running over the full sequence, the hidden state is carried over from one chunk to the next (and detached, so gradients do not flow across chunk boundaries).
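For illustration, splitting along the time dimension can be done with torch.split. The (batch, time, features) layout and the chunk size below are assumptions made for this sketch, not requirements of Lightning:

import torch

x = torch.randn(10, 100, 5)          # assumed layout: (batch=10, time=100, features=5)
chunks = torch.split(x, 10, dim=1)   # chunks of length k=10 along the time dimension
print(len(chunks), chunks[0].shape)  # 10 torch.Size([10, 10, 5])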

import torch
import torch.optim as optim
import pytorch_lightning as pl
from pytorch_lightning import LightningModule

class LitModel(LightningModule):

    def __init__(self):
        super().__init__()

        # 1. Switch to manual optimization
        self.automatic_optimization = False

        self.truncated_bptt_steps = 10
        self.my_rnn = ParityModuleRNN()  # Define an RNN model that returns (loss, hiddens)

    # 2. Remove the `hiddens` argument
    def training_step(self, batch, batch_idx):

        # 3. Split the batch in chunks along the time dimension
        split_batches = split_batch(batch, self.truncated_bptt_steps)

        batch_size = 10
        hidden_dim = 20
        hiddens = torch.zeros(1, batch_size, hidden_dim, device=self.device)
        opt = self.optimizers()
        for split in split_batches:
            # 4. Perform the optimization in a loop
            loss, hiddens = self.my_rnn(split, hiddens)
            self.manual_backward(loss)
            opt.step()
            opt.zero_grad()

            # 5. "Truncate"
            hiddens = hiddens.detach()

        # 6. Remove the return of `hiddens`
        # Returning loss in manual optimization is not needed
        return None

    def configure_optimizers(self):
        return optim.Adam(self.my_rnn.parameters(), lr=0.001)

if __name__ == "__main__":
    model = LitModel()
    trainer = pl.Trainer(max_epochs=5)
    trainer.fit(model, train_dataloader) # Define your own dataloader
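
The example above assumes that ParityModuleRNN, split_batch, and train_dataloader are defined elsewhere. The following is only a minimal, hypothetical sketch of what they could look like (the class body, tensor shapes, and the (x, y) batch layout are illustrative assumptions, not part of the Lightning API); define them before LitModel so the script runs end to end:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

class ParityModuleRNN(nn.Module):
    """Hypothetical stand-in: a toy RNN that returns (loss, hiddens) as LitModel expects."""

    def __init__(self, input_dim=5, hidden_dim=20):
        super().__init__()
        self.rnn = nn.GRU(input_dim, hidden_dim, batch_first=True)
        self.head = nn.Linear(hidden_dim, input_dim)

    def forward(self, split, hiddens):
        x, y = split                           # each chunk: (batch, k, input_dim)
        out, hiddens = self.rnn(x, hiddens)    # hiddens: (1, batch, hidden_dim)
        loss = nn.functional.mse_loss(self.head(out), y)
        return loss, hiddens

def split_batch(batch, k):
    """Split an (x, y) batch into chunks of length k along the time dimension (dim=1)."""
    x, y = batch
    return list(zip(torch.split(x, k, dim=1), torch.split(y, k, dim=1)))

# Synthetic data: 10 samples, 100 time steps, 5 features; targets share the shape
x = torch.randn(10, 100, 5)
y = torch.randn(10, 100, 5)
train_dataloader = DataLoader(TensorDataset(x, y), batch_size=10)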