In PyTorch Lightning, how can I extract embeddings from a pretrained model and feed them to a second model inside training_step?

The pretrained model seems to occupy too much memory during training, which makes me suspect that I am inadvertently training both models simultaneously.

Did you disable autograd for the embedding model? Maybe try something like this:

import torch
import lightning as L

class LitModel(L.LightningModule):
    def forward(self, X):
        # Run the pretrained embedding model without building an autograd graph.
        with torch.no_grad():
            o = self.embedding_model(X)
        y = self.classification_model(o)
        return y
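To be safe, also freeze the embedding model's parameters and put it in eval mode in __init__, so it stays out of the optimizer and layers like dropout or batchnorm behave deterministically. Here is a minimal, framework-free sketch using plain torch with two hypothetical nn.Linear stand-ins for your models, showing that only the classifier receives gradients:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the pretrained embedding model.
embedding_model = nn.Linear(16, 8)

# Freeze its weights so they never receive gradients, and switch to
# eval mode so dropout/batchnorm layers run in inference behavior.
embedding_model.requires_grad_(False)
embedding_model.eval()

# Hypothetical stand-in for the downstream classifier.
classification_model = nn.Linear(8, 2)

x = torch.randn(4, 16)
with torch.no_grad():  # no autograd graph for the frozen forward pass
    emb = embedding_model(x)

logits = classification_model(emb)
logits.sum().backward()

# Only the classifier accumulates gradients; the embedding model does not.
assert all(p.grad is None for p in embedding_model.parameters())
assert all(p.grad is not None for p in classification_model.parameters())
```

When you build the optimizer in configure_optimizers, pass only self.classification_model.parameters() so the frozen model's weights are excluded entirely.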