How to continue training for more epochs?


I’ve trained the model and want to train it for more epochs. I didn’t save any checkpoints, but from my understanding PyTorch Lightning keeps the model state, so training should continue where it left off. I haven’t closed the kernel yet.

I trained for 3 epochs and then set max_epochs to 5, but it still trains for 5 more epochs. Can someone tell me whether my reasoning is correct? Ty

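import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)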

class SentimentClassifier(pl.LightningModule):
    def __init__(self, model):
        super().__init__()  # required before assigning submodules
        self.model = model
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.model(input_ids, attention_mask=attention_mask, labels=labels)
        return outputs

    def predict_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        with torch.no_grad():
            outputs = self(input_ids, attention_mask, labels=None)
            preds = outputs.logits.argmax(-1)
            return preds.tolist()

# here I define the trainer
trainer = pl.Trainer(
    logger=pl.loggers.TensorBoardLogger('logs/', name='sentiment_classifier'),
    max_epochs=3,  # the first run was 3 epochs, as described above
)

sentiment_classifier = SentimentClassifier(model)
trainer.fit(sentiment_classifier,
            DataLoader(train_dataset, batch_size=16),
            DataLoader(val_dataset, batch_size=16))

# once fitted, I train for more without saving checkpoint
# it's going to train 5 more times
trainer2 = pl.Trainer(max_epochs=5)
trainer2.fit(sentiment_classifier,
             DataLoader(train_dataset, batch_size=16),
             DataLoader(val_dataset, batch_size=16))

Hey, the number of epochs is not stored in the model. The model only holds a reference to the trainer, from which some properties are forwarded. The trainer owns both the maximum number of epochs and the current epoch, and a newly created trainer starts counting from zero. Therefore the behavior you’re seeing is correct.

Continuing would work if you restored a checkpoint from the first trainer into trainer2, because a checkpoint stores not only the model state but also (among other things) the trainer state, which contains the current epoch.

