The pretrained model seems to occupy too much memory during training, which makes me suspect that I am inadvertently training both models simultaneously.
Did you disable autograd for the embedding model? Maybe try something like this:
import lightning as L
import torch

class LitModel(L.LightningModule):
    def forward(self, X):
        # Compute embeddings without tracking gradients or building a graph
        with torch.no_grad():
            o = self.embedding_model(X)
        y = self.classification_model(o)
        return y
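
On top of no_grad, it can also help to freeze the embedding model's parameters and exclude them from the optimizer, so the optimizer never allocates state for them. Here is a minimal sketch, assuming the two sub-models are passed into __init__ (the constructor arguments and the Adam settings are just illustrative):

import lightning as L
import torch

class LitModel(L.LightningModule):
    def __init__(self, embedding_model, classification_model):
        super().__init__()
        self.embedding_model = embedding_model
        self.classification_model = classification_model
        # Freeze all pretrained weights so they never receive gradients
        self.embedding_model.requires_grad_(False)

    def forward(self, X):
        # Embeddings are computed without tracking gradients
        with torch.no_grad():
            o = self.embedding_model(X)
        return self.classification_model(o)

    def configure_optimizers(self):
        # Only the classifier's parameters get optimizer state
        return torch.optim.Adam(self.classification_model.parameters(), lr=1e-3)

If the embedding model contains dropout or batch-norm layers, you may also want to put it in eval mode during training (e.g. call self.embedding_model.eval() in on_train_start), since Lightning switches the whole module back to training mode for the train loop.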