Hey folks!
I have a pre-trained model (class A) that I use inside another model (class B) during B's training. Below is the code snippet I use to train model B. I am a bit confused about whether to use `forward()` or `predict_step()` in class A to generate model A's output for model B. Please let me know whether the snippet below is correct.
```python
class A(pl.LightningModule):
    def __init__(self, model, ...):
        super().__init__()
        self.model = model
        ....

    def forward(self, x):
        return self.model(x)

    def shared_step(self, batch):
        z, loss = self(batch)
        ....

    def training_step(self, batch, batch_idx):
        return self.shared_step(batch)

    def validation_step(self, batch, batch_idx):
        return self.shared_step(batch)

    .... # other methods

#################################

model_A = A.load_from_checkpoint("path.ckpt")

class B(pl.LightningModule):
    def __init__(self, model_A, model_B, ...):
        super().__init__()
        self.model_A = model_A
        self.model_B = model_B
        self.model_A.freeze()
        ....

    def shared_step(self, batch):
        self.model_A.eval()
        with torch.no_grad():
            z0, _ = self.model_A(batch)
        z, loss = self.model_B(z0)
        ....

    def training_step(self, batch, batch_idx):
        return self.shared_step(batch)

    def validation_step(self, batch, batch_idx):
        return self.shared_step(batch)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.model_B.parameters(), lr=self.lr)
        return {"optimizer": optimizer}

    .... # other methods
```
I am asking because I don't get the desired results with this pipeline, whereas I do when I explicitly generate a dataset with `model_A` and train `model_B` on it, i.e., when I don't pass `model_A` to `class B` at all (roughly the workflow sketched below). I mainly want to use the pipeline above, where `model_A` serves only as a sample generator whose output is fed into `model_B` for its training.
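For context, the explicit pipeline I am comparing against looks roughly like this (a minimal sketch; `source_loader`, `z0_loader`, and the batch size are just placeholders, not my exact code):

```python
# Sketch of the explicit pipeline: pre-compute model_A outputs once,
# then train model_B on them. source_loader / z0_loader are placeholder names.
import torch
from torch.utils.data import DataLoader, TensorDataset

model_A = A.load_from_checkpoint("path.ckpt")
model_A.freeze()
model_A.eval()

# Run model_A over the raw data and collect its outputs
outputs = []
with torch.no_grad():
    for batch in source_loader:
        z0, _ = model_A(batch)
        outputs.append(z0.cpu())

# Build a dataset/dataloader of model_A outputs
z0_dataset = TensorDataset(torch.cat(outputs))
z0_loader = DataLoader(z0_dataset, batch_size=64, shuffle=True)

# model_B is then trained on z0_loader and never sees model_A directly
```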
Also, I would like to know what changes I need to make in the `shared_step` of `class B` if I use `predict_step` instead of `forward()` in `class A`.
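For example, would the change just be something like the sketch below? This is only my guess, assuming `predict_step` in class A is a thin wrapper around `forward()`:

```python
# In class A -- assuming predict_step is just a thin wrapper around forward()
def predict_step(self, batch, batch_idx, dataloader_idx=0):
    return self(batch)

# In class B's shared_step, the only change I imagine is the call itself
def shared_step(self, batch):
    self.model_A.eval()
    with torch.no_grad():
        # predict_step expects a batch_idx argument; passing 0 since it is unused here
        z0, _ = self.model_A.predict_step(batch, 0)
    z, loss = self.model_B(z0)
    # ... rest of the step unchanged (logging, return, etc.)
```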