Truncated Backpropagation Through Time (TBPTT)
Truncated Backpropagation Through Time (TBPTT) performs backpropagation every k steps of
a much longer sequence. This is done by splitting each training batch along the
time dimension into chunks of size k inside
`training_step`
and running one optimization step per chunk. In order to keep the same forward propagation behavior, the
hidden states must be carried over from each time-dimension split to the next (detached from the autograd graph, which is what "truncates" the backpropagation).
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import lightning as L
class AverageDataset(Dataset):
    """Synthetic sequence-regression dataset built from random noise.

    Every sample is a ``(sequence_len, 10)`` tensor drawn from a standard
    normal distribution. Its target splits the 10 features into two halves of
    5. The second half is circularly shifted by one position along the feature
    axis and then added to the first half, giving a ``(sequence_len, 5)``
    target. Despite the class name, the halves are summed, not averaged.

    Args:
        dataset_len: Number of samples in the dataset.
        sequence_len: Number of time steps per sample.
    """

    def __init__(self, dataset_len=300, sequence_len=100):
        self.dataset_len = dataset_len
        self.sequence_len = sequence_len
        features = torch.randn(dataset_len, sequence_len, 10)
        self.input_seq = features
        first_half, second_half = features.chunk(2, -1)
        shifted_half = second_half.roll(shifts=1, dims=-1)
        self.output_seq = first_half + shifted_half

    def __len__(self):
        return self.dataset_len

    def __getitem__(self, item):
        sample = self.input_seq[item]
        target = self.output_seq[item]
        return sample, target
class LitModel(L.LightningModule):
    """LSTM sequence regressor trained with truncated backpropagation through time.

    Manual optimization is used so that every training batch can be cut along
    the time axis into windows of ``truncated_bptt_steps`` steps, with one
    optimizer step per window. The LSTM state is carried from one window to the
    next but detached, so gradients never flow back across a window boundary.

    All arguments default to the values this example was written with, so
    ``LitModel()`` keeps its original configuration.

    Args:
        batch_size: Batch size used by ``train_dataloader``.
        in_features: Number of input features per time step.
        out_features: Number of predicted features per time step.
        hidden_dim: LSTM hidden size.
        truncated_bptt_steps: Length, in time steps, of each truncated window.
        lr: Adam learning rate.

    Raises:
        ValueError: If ``truncated_bptt_steps`` is smaller than 1.
    """

    def __init__(
        self,
        *,
        batch_size=10,
        in_features=10,
        out_features=5,
        hidden_dim=20,
        truncated_bptt_steps=10,
        lr=0.001,
    ):
        super().__init__()
        if truncated_bptt_steps < 1:
            raise ValueError(
                f"truncated_bptt_steps must be >= 1, got {truncated_bptt_steps}"
            )
        self.batch_size = batch_size
        self.in_features = in_features
        self.out_features = out_features
        self.hidden_dim = hidden_dim
        # Several optimizer steps per batch (one per time window) cannot be
        # expressed with automatic optimization, so the loop is driven manually.
        self.automatic_optimization = False
        self.truncated_bptt_steps = truncated_bptt_steps
        self.lr = lr
        self.rnn = nn.LSTM(self.in_features, self.hidden_dim, batch_first=True)
        self.linear_out = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features)

    def forward(self, x, hs):
        """Run the LSTM over ``x`` starting from state ``hs``.

        Args:
            x: Input of shape ``(batch, time, in_features)``; the LSTM is
                built with ``batch_first=True``.
            hs: ``(h, c)`` LSTM state, or ``None`` to start from zeros.

        Returns:
            Tuple ``(prediction, new_state)`` where ``prediction`` has shape
            ``(batch, time, out_features)``.
        """
        seq, hs = self.rnn(x, hs)
        return self.linear_out(seq), hs

    def training_step(self, batch, batch_idx):
        """Optimize over one batch window-by-window and log the mean window loss.

        Args:
            batch: ``(x, y)`` pair. Both are split along dim 1, the time axis
                of the ``batch_first`` LSTM. The last window is shorter when
                the sequence length is not a multiple of
                ``truncated_bptt_steps``.
            batch_idx: Index of the batch (unused).

        Returns:
            ``None``. Under manual optimization Lightning needs no returned loss.
        """
        x, y = batch
        # ``split`` with an int yields chunks of that *size*. ``tensor_split``
        # with an int would instead yield that *number* of chunks, which is
        # not what TBPTT means.
        x_windows = x.split(self.truncated_bptt_steps, dim=1)
        y_windows = y.split(self.truncated_bptt_steps, dim=1)

        optimizer = self.optimizers()
        hiddens = None
        losses = []
        for x_win, y_win in zip(x_windows, y_windows):
            y_pred, hiddens = self(x_win, hiddens)
            loss = F.mse_loss(y_pred, y_win)
            optimizer.zero_grad()
            self.manual_backward(loss)
            optimizer.step()
            # Truncate: keep the state values for the next window but cut the
            # autograd graph so backpropagation stops at this boundary. Kept as
            # a tuple, the ``(h, c)`` form nn.LSTM documents for ``hx``.
            hiddens = tuple(h.detach() for h in hiddens)
            losses.append(loss.detach())

        self.log("train_loss", torch.stack(losses).mean(), prog_bar=True)
        return None

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate ``self.lr``."""
        return optim.Adam(self.parameters(), lr=self.lr)

    def train_dataloader(self):
        """Return a loader over a default ``AverageDataset`` with ``self.batch_size``."""
        return DataLoader(AverageDataset(), batch_size=self.batch_size)
if __name__ == "__main__":
    # Script entry point: build the model, then fit it for five epochs with
    # otherwise-default Trainer settings.
    lit_model = LitModel()
    L.Trainer(max_epochs=5).fit(lit_model)