What's the best practice for continual learning?

For continual learning, we have several tasks, each with its own dataset, and we need to train a model on these tasks sequentially. There are some requirements:

  1. During training, we train the model on the current dataset only.
  2. During validation, we check the model's performance on all previous datasets.
  3. During testing, we test the model on all of the datasets.

Any suggestions would be appreciated. Thanks!

Clarification

This is a great question! Do you think you could share some details about the training/validation/test process?

Is it something like this (in plain PyTorch)?:

datasets = [DatasetOne(), DatasetTwo(), DatasetThree()]
model = Model()

for current_idx in range(len(datasets)):
    for epoch in range(epochs_per_dataset):

        # train on the current dataset
        train_step(model, datasets[current_idx])

        # validate on the previous datasets
        validation_step(model, datasets[:current_idx])

# test on all datasets
test_step(model, datasets)

Potential Solution

I believe you would want to handle this logic in the *_dataloader methods of your LightningModule:

from pytorch_lightning import LightningModule
from torch.utils.data import DataLoader, ChainDataset

class ContinualLearner(LightningModule):
    def __init__(self, datasets, epochs_per_dataset):
        super().__init__()

        # datasets is a list of torch.utils.data.Dataset
        # (note: ChainDataset expects IterableDatasets; use ConcatDataset for map-style datasets)
        self.datasets = datasets
        self.curr_index = 0
        self.epochs_per_dataset = epochs_per_dataset

    def train_dataloader(self):
        return DataLoader(self.datasets[self.curr_index])

    def val_dataloader(self):
        return DataLoader(ChainDataset(self.datasets[:self.curr_index]))

    def test_dataloader(self):
        return DataLoader(ChainDataset(self.datasets))

    def on_epoch_end(self):
        # move on to the next dataset every epochs_per_dataset epochs
        if self.trainer.current_epoch % self.epochs_per_dataset == 0:
            self.curr_index += 1

To make sure you get the new dataloader every epoch, you will need to use the reload_dataloaders_every_epoch flag:

trainer = Trainer(reload_dataloaders_every_epoch=True)
trainer.fit(model)
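
For completeness, a minimal sketch of the trainer setup, assuming the total number of epochs should be enough for the epoch-end rollover to visit every dataset:

# assumes datasets, epochs_per_dataset, and model from the snippets above
trainer = Trainer(
    max_epochs=len(datasets) * epochs_per_dataset,  # enough epochs to cover all datasets
    reload_dataloaders_every_epoch=True,
)
trainer.fit(model)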

Hope this sets you off in the right direction! Please do not hesitate to ask any other questions 🙂.

Thanks a lot for replying! That's almost what I want, but there is one more thing: I want to check the performance on each dataset separately, not on the concatenated one. That is, I want the evaluation process to look something like this:

def valid_or_test(datasets: List[Dataset]):
    results = SomeCollection()
    for dataset in datasets:
        result = evaluate_model_on_dataset(dataset)
        results.append(result)
    # return or do something with the result collection
    return results

In short, I need the evaluation to run on each dataset separately and give me the result for each one. Can I do this elegantly with pytorch-lightning? Looking forward to your suggestions!

This is handled by Lightning's multiple-dataloader functionality. With it, your results will be indexed by dataset:

    def val_dataloader(self):
        return [DataLoader(ds) for ds in self.datasets[:self.curr_index]]

    def test_dataloader(self):
        return [DataLoader(ds) for ds in self.datasets]

In this case validation_epoch_end and test_epoch_end will be passed a List[List[Any]], where the outer list contains one list of step outputs per dataset. Anything you return from validation_step/test_step will be accessible here.
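
For illustration, here is a minimal sketch of aggregating those per-dataset outputs in validation_epoch_end (the acc metric and the _compute_accuracy helper are hypothetical placeholders, not part of the solution above):

import torch
from pytorch_lightning import LightningModule

class ContinualLearner(LightningModule):
    # ... __init__ and the dataloader methods as above ...

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        # hypothetical per-batch metric; replace with your own computation
        acc = self._compute_accuracy(batch)
        return {"acc": acc}

    def validation_epoch_end(self, outputs):
        # with multiple val dataloaders, outputs holds one list of step outputs per dataset
        for dataset_idx, dataset_outputs in enumerate(outputs):
            mean_acc = torch.stack([out["acc"] for out in dataset_outputs]).mean()
            self.log(f"val_acc/dataset_{dataset_idx}", mean_acc)

One caveat: if only a single validation dataloader is returned, the outputs may come as a flat list rather than a nested one, so you might want to guard for that case.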

Oh, I had overlooked this functionality in the docs. Thanks a lot for reminding me!

No worries! Happy to help

@teddy, will this reset the optimizer and scheduler before training on the new dataset?

I don’t believe so; you will have to do that manually.

So, for continual learning, do the optimizer and scheduler need to be reset after training on each dataset?

To summarize, my requirement is to train on N datasets sequentially, i.e., first train on dataset_1, then on dataset_2, and so on. But whenever I start training on a new dataset, the learning rate/optimizer should be reset.
@teddy, waiting for your help on this.
Thanks

This should work!

datamodules = [datamodule1, datamodule2, ...]  # create datamodules as per your requirement

def get_trainer(...):
    return Trainer(...)

model = Yourmodel()  # initialize only once

for dm in datamodules:
    # a fresh Trainer per dataset; configure_optimizers runs again on each .fit
    trainer = get_trainer(...)
    trainer.fit(model, datamodule=dm)

the learning rate/optimizer should be reset.

The learning rate will be reset along with the optimizer state, but the parameters the optimizer works on carry over between runs, since you keep using the same model that was already trained on the previous datasets.

learning_rate is passed as an argument in the Yourmodel() call and is used in the configure_optimizers(self) method. Hence, the optimizer is initialized only once. So, while training on the second datamodule, how will the optimizer and learning rate get reset?
@goku

configure_optimizers is called every time you call .fit, so yes, they will be reset. Optimizers are defined within the LightningModule but are handled by the trainer internally. You can verify this with a debugger, or simply print something inside configure_optimizers and call .fit multiple times; you'll see it is executed on each .fit call.

The only caveat here: since you pass lr to Yourmodel and presumably assign it to some attribute, make sure it isn't modified anywhere internally.
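
A quick way to see this for yourself, as a rough sketch (the lr attribute and the placeholder layer are assumptions about how Yourmodel might look):

import torch
from torch import nn
from pytorch_lightning import LightningModule

class Yourmodel(LightningModule):
    def __init__(self, lr=1e-3):
        super().__init__()
        self.lr = lr                    # assumed hyperparameter name
        self.layer = nn.Linear(10, 2)   # placeholder network

    def configure_optimizers(self):
        # this prints once per .fit() call, confirming a fresh optimizer is built each time
        print(f"configure_optimizers called with lr={self.lr}")
        return torch.optim.Adam(self.parameters(), lr=self.lr)

Running the for dm in datamodules loop above should print that line once per .fit call.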

I am implementing a continual learning model similar to what has been described above. However, I want a separate Lightning datamodule (similar to what you suggested above) and a model class that train and validate on tasks sequentially. How should I connect my model and datamodule so that, when training and validation end for one task, the next task's data can be retrieved from the datamodule and that task can be learned?
In addition, I would like to evaluate some metrics, e.g. average accuracy and backward transfer, on all the previous tasks including the current one. Is that feasible with the solution you mentioned, or do I need a different setup?

There is one bug in this code.

    def val_dataloader(self):
        return DataLoader(ChainDataset(self.datasets[:self.curr_index]))

It should be [:self.curr_index + 1]; the +1 is needed to avoid an error, because when curr_index = 0 the slice returns an empty list.
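
For clarity, a sketch of the corrected method (following the ChainDataset approach from the reply above; with the +1, the current dataset is included in validation as well):

    def val_dataloader(self):
        # include datasets up to and including the current one;
        # without the +1, curr_index == 0 yields an empty slice
        return DataLoader(ChainDataset(self.datasets[:self.curr_index + 1]))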