Hey folks!
I have a model which consists of 3 submodules (A, B, and C) stacked sequentially, so the pipeline looks like A -> B -> C. For the first few epochs I train A -> B and compute the loss; thereafter, I freeze the A and B submodules and train only C. The forward pass looks like this:
def training_step(self, batch, batch_idx):
    z = self.B(self.A(batch))
    # after epoch 30, also pass the A -> B output through C
    if self.current_epoch > 30:
        z = self.C(z)
    # compute loss based on z
    return loss
class FreezeCallback(pl.Callback):
    ...
    def on_train_epoch_end(self, trainer, pl_module):
        if trainer.current_epoch > 30 and pl_module.freeze:
            # freeze A and B so that only C receives gradient updates
            for param in pl_module.A.parameters():
                param.requires_grad = False
            for param in pl_module.B.parameters():
                param.requires_grad = False
            # swap in a fresh Adam that only sees C's parameters
            trainer.optimizers[0] = torch.optim.Adam(pl_module.C.parameters(), lr=pl_module.lr)
            pl_module.freeze = False
    ...
Note that I also re-create the Adam optimizer so that it only optimizes C's parameters.
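For completeness, here is roughly how the rest of the LightningModule is set up (a simplified sketch; training_step is as shown above, and the initial configure_optimizers covers only A and B):

import torch
import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    # training_step is as shown above
    def __init__(self, A, B, C, lr=1e-3):
        super().__init__()
        self.A, self.B, self.C = A, B, C
        self.lr = lr
        self.freeze = True  # flipped to False by the callback after the swap

    def configure_optimizers(self):
        # initially optimizes A and B only; the callback later replaces
        # this with an Adam over C's parameters
        params = list(self.A.parameters()) + list(self.B.parameters())
        return torch.optim.Adam(params, lr=self.lr)

The callback is registered via pl.Trainer(callbacks=[FreezeCallback()], ...).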
I did not get the desired results with this pipeline. However, if I train a model which consists of only A -> B, store its outputs (z = B(A(batch))) externally as a database after training, and then use this database to train a model which contains only C, it works for me. Basically, I have to break the pipeline into two separate ones to get results rather than having a single pipeline (rough sketch of the two-stage version below). Any suggestions on where I might be going wrong?
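In case it clarifies what I mean by breaking the pipeline into two, the working two-stage version is roughly this (a simplified sketch; model_ab, train_loader, and the file name are just illustrative):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Stage 1: after training the A -> B model to convergence, cache its
# outputs for the whole training set (this is the "database").
model_ab.eval()  # model_ab: the trained A -> B model (illustrative name)
outputs = []
with torch.no_grad():
    for batch in train_loader:
        outputs.append(model_ab.B(model_ab.A(batch)).cpu())
torch.save(torch.cat(outputs), "ab_outputs.pt")

# Stage 2: train a separate model containing only C on the stored outputs.
z_db = torch.load("ab_outputs.pt")
c_loader = DataLoader(TensorDataset(z_db), batch_size=64, shuffle=True)
# ... then run the usual training loop for C over c_loader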