Hi,
I want to implement the following pseudocode using PyTorch Lightning:
```python
model = MyModel()
loaders = {'train': train_loader, 'finetune': finetune_loader}

def fit():
    for epoch in epochs:
        train_model(model, loaders['train'])
        finetune_model(model, loaders['finetune'])
```
However, since PL automatically calls the `training_step()` method for me, I suppose this is not trivial to implement. Are there any workarounds for my problem?
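For reference, the naive translation I can think of is to call `fit()` repeatedly from a plain Python loop. This is only a sketch of the idea (`MyModel`, the loaders, and `num_epochs` are placeholders from the pseudocode above), and it restarts the Trainer's loop state on every call, which is exactly why I'm looking for something better:

```python
import pytorch_lightning as pl

model = MyModel()
for epoch in range(num_epochs):
    # A fresh Trainer per phase: the model's weights carry over,
    # but epoch counters, schedulers, and logger state reset each time.
    pl.Trainer(max_epochs=1).fit(model, train_loader)
    pl.Trainer(max_epochs=1).fit(model, finetune_loader)
```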
Additional Context
1- It seems possible to alternate between `train_model` and `finetune_model` every other epoch: call `train_model` at epoch 1, `finetune_model` at epoch 2, and so on. But then all the metrics I report to the logger have to account for the epoch number, which makes the code very messy (see the first sketch after this list).
2- In general, is it possible to control the behavior of the Trainer's train loop (see the second sketch below)?
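To make idea 1 concrete, here is a minimal sketch of what I mean, assuming `reload_dataloaders_every_n_epochs=1` makes the Trainer call `train_dataloader()` again at every epoch, and prefixing metric names by phase to keep the logs readable (`train_loader`, `finetune_loader`, `num_epochs`, and `compute_loss` are placeholders):

```python
import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    def train_dataloader(self):
        # Alternate loaders by epoch parity; needs
        # Trainer(reload_dataloaders_every_n_epochs=1) so this hook
        # runs again at the start of every epoch.
        self.phase = "train" if self.current_epoch % 2 == 0 else "finetune"
        return train_loader if self.phase == "train" else finetune_loader

    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # hypothetical loss helper
        # Prefixing by phase keeps "train" and "finetune" metrics apart,
        # but every log call now has to thread the phase through.
        self.log(f"{self.phase}/loss", loss)
        return loss

trainer = pl.Trainer(max_epochs=2 * num_epochs,
                     reload_dataloaders_every_n_epochs=1)
trainer.fit(MyModel())
```

This works, but every "logical" epoch is now two Trainer epochs, which is the messiness I mentioned.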
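And for question 2, the closest documented knob I have found is manual optimization: setting `self.automatic_optimization = False` lets me control the backward/step logic inside `training_step`, though as far as I can tell it does not let me control which dataloader the loop iterates over. A minimal sketch (`ManualModel` is just a hypothetical example):

```python
import torch
import pytorch_lightning as pl

class ManualModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # take over the optimization step
        self.layer = torch.nn.Linear(32, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        opt = self.optimizers()
        loss = torch.nn.functional.mse_loss(self.layer(x), y)
        opt.zero_grad()
        self.manual_backward(loss)  # handles precision/accelerator details
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=1e-2)
```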
Thanks