That's a specific use case of yours, and you can use manual optimization for that purpose. Also, in `training_step` you get `optimizer_idx`, so you can call `embeddings = model(x)` when `optimizer_idx=0`, save it as state, and reuse it when `optimizer_idx` is 1 or 2, assuming you have 3 optimizers.
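
To make the caching idea concrete, here is a minimal, framework-free sketch of it. It is not a real `LightningModule`: the class `EmbeddingCacheSketch`, the attribute `cached_embeddings`, and the driver `run_batch` are hypothetical names, and `run_batch` only imitates a loop that calls `training_step` once per optimizer for each batch. As far as I know, `optimizer_idx` is only passed to `training_step` in Lightning versions before 2.0.

```python
class EmbeddingCacheSketch:
    """Stand-in for a LightningModule; all names here are hypothetical."""

    def __init__(self):
        self.cached_embeddings = None  # state shared across the optimizer steps
        self.forward_calls = 0         # counts how often the forward pass runs

    def model(self, x):
        # Placeholder for the expensive forward pass `model(x)`.
        self.forward_calls += 1
        return [2 * v for v in x]

    def training_step(self, batch, batch_idx, optimizer_idx):
        if optimizer_idx == 0:
            # Compute the embeddings once per batch and keep them as state.
            self.cached_embeddings = self.model(batch)
        # Optimizers 1 and 2 reuse what optimizer 0 stored.
        embeddings = self.cached_embeddings
        # Placeholder "loss" so each optimizer step returns a distinct value.
        return sum(embeddings) + optimizer_idx


def run_batch(module, batch, batch_idx, num_optimizers=3):
    # Imitates one training_step call per optimizer for a single batch.
    return [module.training_step(batch, batch_idx, i) for i in range(num_optimizers)]


module = EmbeddingCacheSketch()
losses = run_batch(module, [1, 2, 3], batch_idx=0)
```

With real tensors you would probably also need to detach the cached embeddings, or retain the graph, if more than one loss backpropagates through them.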