Using a custom pretrained model in a LightningModule

Hey folks!

I have a pre-trained model (class A) that I use inside another model (class B) while training B. Below is the code snippet I use to train model B. I am a bit confused about whether to use forward() or predict_step() in class A to generate model A's output for use in model B. Please let me know if the snippet below is correct.

```python
import torch
import pytorch_lightning as pl


class A(pl.LightningModule):
  def __init__(self, model):  # ... plus other arguments
    super().__init__()
    self.model = model
    # ...

  def forward(self, x):
    return self.model(x)

  def shared_step(self, batch):
    z, loss = self(batch)
    # ...

  def training_step(self, batch, batch_idx):
    return self.shared_step(batch)

  def validation_step(self, batch, batch_idx):
    return self.shared_step(batch)

  # ... other methods
```

```python
model_A = A.load_from_checkpoint("path.ckpt")


class B(pl.LightningModule):
  def __init__(self, model_A, model_B):  # ... plus other arguments
    super().__init__()
    self.model_A = model_A
    self.model_B = model_B
    self.model_A.freeze()  # requires_grad=False on all params, then eval()
    # ...

  def shared_step(self, batch):
    self.model_A.eval()
    with torch.no_grad():
      z0, _ = self.model_A(batch)
    z, loss = self.model_B(z0)
    # ...

  def training_step(self, batch, batch_idx):
    return self.shared_step(batch)

  def validation_step(self, batch, batch_idx):
    return self.shared_step(batch)

  def configure_optimizers(self):
    # self.lr must be defined somewhere, e.g. in the elided part of __init__
    optimizer = torch.optim.Adam(self.model_B.parameters(), lr=self.lr)
    return {"optimizer": optimizer}

  # ... other methods
```
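
One mechanism worth checking in this setup, shown as a minimal sketch with plain torch modules and hypothetical stand-ins (`FrozenPair`, `make`, and the tiny `nn.Linear`/`nn.Dropout` layers are not from the post). `LightningModule.freeze()` sets `requires_grad=False` on every parameter and then calls `eval()`, but PyTorch's `train()`/`eval()` are recursive: calling `.train()` on the parent module puts every child, including the frozen model_A, back into training mode, so dropout or batch-norm inside model_A would behave differently than when generating a dataset offline. Depending on the Lightning version, the Trainer may call `.train()` on the whole LightningModule (for example when switching back from validation to training). The per-step `self.model_A.eval()` in the snippet guards against this; overriding `train()` as below is an alternative pattern (an assumption, not something the post uses).

```python
import torch
from torch import nn


class FrozenPair(nn.Module):
    """Hypothetical stand-in for class B: a frozen model_A feeding a trainable model_B."""

    def __init__(self, model_A: nn.Module, model_B: nn.Module, keep_A_in_eval: bool):
        super().__init__()
        self.model_A = model_A
        self.model_B = model_B
        self.keep_A_in_eval = keep_A_in_eval
        # Same effect as LightningModule.freeze(): no grads, then eval mode.
        for p in self.model_A.parameters():
            p.requires_grad = False
        self.model_A.eval()

    def train(self, mode: bool = True):
        # nn.Module.train() recurses into all children, including model_A.
        super().train(mode)
        if self.keep_A_in_eval:
            # Put the frozen model back into eval mode after the recursive call.
            self.model_A.eval()
        return self

    def forward(self, x):
        with torch.no_grad():
            z0 = self.model_A(x)
        return self.model_B(z0)


def make(keep_A_in_eval: bool) -> FrozenPair:
    model_A = nn.Sequential(nn.Linear(4, 8), nn.Dropout(0.5))
    model_B = nn.Linear(8, 1)
    return FrozenPair(model_A, model_B, keep_A_in_eval)


naive = make(keep_A_in_eval=False).train()
guarded = make(keep_A_in_eval=True).train()
print(naive.model_A.training)    # True: .train() undid the freeze's eval()
print(guarded.model_A.training)  # False: model_A stays in eval mode
```

Because `LightningModule` subclasses `nn.Module`, the same `train()` override should carry over to class B unchanged.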

I am asking this because I don't get the desired results with this pipeline, whereas I do when I first generate a dataset explicitly with model_A and then train model_B on it, i.e., without passing model_A to class B. I would prefer the pipeline above, where model_A serves only as a sample generator whose output is fed to model_B during training.
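
For comparison, the "explicit dataset" route can be sketched as below. `ToyA`, `precompute_dataset`, and the tensor shapes are hypothetical; the only thing carried over from the snippet is that model_A's forward returns a `(z, loss)` pair. Comparing the stored outputs with the `z0` computed per batch (as `B.shared_step` does) for the same inputs is one way to check whether both pipelines actually feed model_B the same data.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class ToyA(nn.Module):
    """Hypothetical stand-in for model_A: forward returns (z, loss) like in the post."""

    def __init__(self):
        super().__init__()
        self.enc = nn.Linear(4, 2)

    def forward(self, x):
        z = self.enc(x)
        return z, z.pow(2).mean()


@torch.no_grad()
def precompute_dataset(model_A: nn.Module, loader: DataLoader) -> TensorDataset:
    """Run model_A once over the data and store its outputs as a dataset for model_B."""
    model_A.eval()  # same mode as the frozen model inside class B
    outputs = [model_A(batch)[0] for batch in loader]
    return TensorDataset(torch.cat(outputs))


torch.manual_seed(0)
model_A = ToyA()
inputs = torch.randn(10, 4)
dataset = precompute_dataset(model_A, DataLoader(inputs, batch_size=4))

# Online check: z0 computed on the first batch, as in B.shared_step,
# should match the first stored outputs.
with torch.no_grad():
    z0, _ = model_A(inputs[:4])
print(torch.allclose(z0, dataset.tensors[0][:4]))  # True
```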

Also, what changes would I need to make to shared_step in class B if I used predict_step() instead of forward() in class A?