Freezing and training submodules of a model

Hey folks!

I have a model which consists of 3 submodules (A, B, and C) stacked sequentially, so the pipeline looks like A–>B–>C. For the first few epochs, I train A–>B and compute the loss; thereafter, I freeze the A and B submodules and train only C. The training step looks like this:

def training_step(self, batch, batch_idx):
    z = self.B(self.A(batch))
    if self.current_epoch > 30:
        z = self.C(z)
    # compute loss based on z
    return loss

class Callback(pl.Callback):
    ....

    def on_train_epoch_end(self, trainer, pl_module):
        if trainer.current_epoch > 30 and pl_module.freeze:
            for param in pl_module.A.parameters():
                param.requires_grad = True
            for param in pl_module.B.parameters():
                param.requires_grad = True
            # reset the optimizer so it only updates C's parameters
            trainer.optimizers[0] = torch.optim.Adam(pl_module.C.parameters(), lr=pl_module.lr)
            pl_module.freeze = False
    ....

Note that I also reset the Adam optimizer so that it only updates the parameters of C.

I did not get the desired results with this pipeline. However, if I train a model consisting only of A–>B, store its outputs (z = B(A(batch))) externally as a database after training, and then use this database to train a separate model containing only C, it works for me. Basically, I have to break the pipeline in two to get results rather than having a single pipeline. Any suggestions on where I might be going wrong?
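
Roughly, the two-stage setup that works for me looks like this (simplified sketch; train_loader, the batch structure, and the batch size are just placeholders for my actual data pipeline):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Stage 1: after training A->B, run it once over the data and cache the outputs.
A.eval()
B.eval()
cached = []
with torch.no_grad():
    for batch in train_loader:
        cached.append(B(A(batch)).cpu())
z_dataset = TensorDataset(torch.cat(cached))

# Stage 2: train a separate model that contains only C on the cached outputs.
z_loader = DataLoader(z_dataset, batch_size=64, shuffle=True)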

The code you posted does the opposite of what you explained. It unfreezes A and B after 30 epochs:

for param in A.parameters():
    param.requires_grad = True
for param in B.parameters():
    param.requires_grad = True

Based on your description, you want this:

for param in A.parameters():
    param.requires_grad = False
for param in B.parameters():
    param.requires_grad = False
for param in C.parameters():
    param.requires_grad = True
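
As a side note, you can freeze or unfreeze a whole submodule in one call instead of looping over its parameters, and it is worth double-checking what is still trainable after epoch 30. A small sketch, assuming A, B, and C are attributes of your LightningModule (accessed as pl_module inside the callback):

# Freeze/unfreeze entire submodules in one call.
pl_module.A.requires_grad_(False)
pl_module.B.requires_grad_(False)
pl_module.C.requires_grad_(True)  # only needed if C was ever frozen

# Sanity check: after epoch 30 this should list only C's parameters.
trainable = [name for name, p in pl_module.named_parameters() if p.requires_grad]
print(trainable)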

Sorry, I mistakenly wrote it in the post.

Based on your answer, may I know why I must set param.requires_grad = True for C? I do not freeze its parameters right from the beginning, because I am not using C to compute z while I am training A and B; I compute the loss based on the output of B. I assume that since C is not in the forward pass (computation graph) at that point, it would not matter whether I freeze C's parameters initially and unfreeze them later for training.

It’s only after the 30th epoch that I pass z through C and compute the loss based on that output.

Also, just to let you know, I don’t use batch norm in A and B.
After the 30th epoch, I simply want A–>B to behave as a generator of samples that can be fed into C for its training.
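
In other words, after the 30th epoch I expect the single pipeline to be equivalent to something like this (rough sketch; torch.no_grad() is only there to make A–>B act as a pure sample generator, and compute_loss is just a stand-in for my actual loss computation):

def training_step(self, batch, batch_idx):
    if self.current_epoch > 30:
        # A->B is frozen and only generates samples for C.
        with torch.no_grad():
            z = self.B(self.A(batch))
        loss = self.compute_loss(self.C(z))
    else:
        z = self.B(self.A(batch))
        loss = self.compute_loss(z)
    return loss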