Sequential Data
Truncated Backpropagation Through Time
Some training schemes need multiple backward passes per batch. For example, Truncated Backpropagation Through Time (TBPTT) can save memory when training RNNs, because gradients are propagated through only a limited number of time steps at a time.
Lightning can handle TBPTT automatically when the truncated_bptt_steps property is set on the LightningModule.
from pytorch_lightning import LightningModule


class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        # Important: This property activates truncated backpropagation through time.
        # Setting it to 2 splits each batch along the time dimension into chunks of 2 steps.
        self.truncated_bptt_steps = 2

    # Truncated back-propagation through time
    def training_step(self, batch, batch_idx, hiddens):
        # The training step must be updated to accept a ``hiddens`` argument.
        # ``hiddens`` holds the hidden states from the previous truncated backprop step.
        # Assumption: ``self.lstm`` is defined in ``__init__`` and the batch chunk is its input.
        out, hiddens = self.lstm(batch, hiddens)
        return {"loss": ..., "hiddens": hiddens}
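To make the role of the ``hiddens`` argument concrete, here is a minimal sketch of the kind of loop that runs once truncated backpropagation is active. It is not Lightning's actual implementation: ``split_time``, ``toy_step`` and ``run_batch`` are hypothetical names, and a running sum stands in for the RNN state so the example needs no deep learning library.

```python
from typing import List, Optional, Tuple


def split_time(sequence: List[float], split_size: int) -> List[List[float]]:
    """Cut one sequence into consecutive chunks of at most ``split_size`` steps."""
    return [sequence[t:t + split_size] for t in range(0, len(sequence), split_size)]


def toy_step(chunk: List[float], hidden: Optional[float]) -> Tuple[float, float]:
    """Hypothetical stand-in for ``training_step``: a running sum acts as the RNN state."""
    state = 0.0 if hidden is None else hidden
    for x in chunk:
        state += x
    loss = state ** 2  # arbitrary toy loss, only here to have something to return
    return loss, state


def run_batch(sequence: List[float], split_size: int) -> Tuple[List[float], Optional[float]]:
    hidden = None  # the first chunk starts without a hidden state
    losses = []
    for chunk in split_time(sequence, split_size):
        loss, hidden = toy_step(chunk, hidden)
        # In a real TBPTT loop, the backward pass would run here on this chunk's
        # loss, and the hidden state would be detached before the next chunk.
        losses.append(loss)
    return losses, hidden


losses, final_hidden = run_batch([1.0, 2.0, 3.0, 4.0, 5.0], split_size=2)
print(losses)        # [9.0, 100.0, 225.0]
print(final_hidden)  # 15.0
```

Each chunk produces its own loss, which is why one batch leads to several backward passes, while the state returned by one step is fed into the next so the model still sees the sequence as continuous.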
Note
If you need to modify how the batch is split, override pytorch_lightning.core.LightningModule.tbptt_split_batch().
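As an illustration of a custom split, the sketch below cuts a batch stored time-first (one inner list per time step) into chunks. ``split_time_major`` is a hypothetical helper, nested lists stand in for tensors, and the ``(self, batch, split_size)`` signature shown in the trailing comment is an assumption about the hook that should be checked against the installed Lightning version.

```python
from typing import List, Sequence


def split_time_major(batch: Sequence[Sequence[float]], split_size: int) -> List[list]:
    """Split a (time, batch) nested list into chunks of ``split_size`` time steps."""
    if split_size < 1:
        raise ValueError("split_size must be a positive integer")
    return [list(batch[t:t + split_size]) for t in range(0, len(batch), split_size)]


# Three time steps, two sequences per step.
batch = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]]
chunks = split_time_major(batch, split_size=2)
print(chunks)  # [[[1.0, 10.0], [2.0, 20.0]], [[3.0, 30.0]]]

# Inside a LightningModule, an override could delegate to the helper
# (assumed signature, not verified here):
#
#     def tbptt_split_batch(self, batch, split_size):
#         return split_time_major(batch, split_size)
```

The last chunk may be shorter than ``split_size`` when the sequence length is not a multiple of it, so the training step should not assume a fixed chunk length.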