This section of the documentation uses the following pseudocode to explain what happens when configure_optimizers
returns multiple optimizers:
for epoch in epochs:
    for batch in data:
        for opt in optimizers:
            disable_grads_for_other_optimizers()
            train_step(opt)
            opt.step()
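
To make the repetition concrete, here is a plain-Python expansion of that pseudocode. It is only a sketch: fit, model, and the MSE loss are placeholders of my own, not Lightning's actual internals.

import torch

def fit(model, optimizers, data, epochs):
    # Expanded form of the docs pseudocode: the forward pass lives inside
    # the per-optimizer step, so model(x) runs once per optimizer per batch.
    for _ in range(epochs):
        for x, y in data:
            for opt in optimizers:
                # Lightning also toggles requires_grad off for parameters owned
                # by the other optimizers (disable_grads_for_other_optimizers
                # in the pseudocode); omitted here for brevity.
                opt.zero_grad()
                embeddings = model(x)  # repeated for every optimizer
                loss = torch.nn.functional.mse_loss(embeddings, y)
                loss.backward()
                opt.step()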
This seems very inefficient. For example, what if I have 3 optimizers that all need the embeddings for the current batch? Then embeddings = model(x)
is going to be called 3 times, when it only needs to be called once.
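
One way to run the shared forward pass only once is to switch to Lightning's manual optimization. Below is a minimal sketch, assuming a shared encoder plus three hypothetical heads (head_a, head_b, head_c), each with its own optimizer; the module split and loss are placeholders, not the documented pattern.

import torch
import pytorch_lightning as pl

class MultiOptModule(pl.LightningModule):
    def __init__(self, encoder, head_a, head_b, head_c):
        super().__init__()
        # Take control of the optimization loop ourselves so the shared
        # forward pass runs once per batch instead of once per optimizer.
        self.automatic_optimization = False
        self.encoder = encoder
        self.head_a, self.head_b, self.head_c = head_a, head_b, head_c

    def training_step(self, batch, batch_idx):
        x, y = batch
        # Shared forward pass, computed a single time for all three optimizers.
        embeddings = self.encoder(x)
        opt_a, opt_b, opt_c = self.optimizers()
        for head, opt in ((self.head_a, opt_a),
                          (self.head_b, opt_b),
                          (self.head_c, opt_c)):
            opt.zero_grad()
            loss = torch.nn.functional.mse_loss(head(embeddings), y)
            # retain_graph keeps the shared embeddings graph alive so the
            # later backward calls can reuse it.
            self.manual_backward(loss, retain_graph=True)
            opt.step()

    def configure_optimizers(self):
        # One optimizer per head in this hypothetical split; only the heads
        # are updated here, the encoder is left frozen for simplicity.
        return (
            torch.optim.Adam(self.head_a.parameters()),
            torch.optim.Adam(self.head_b.parameters()),
            torch.optim.Adam(self.head_c.parameters()),
        )

The trade-off is that manual optimization gives up the automatic loop shown in the pseudocode (including the per-optimizer grad toggling) in exchange for deciding yourself when the forward pass, backward passes, and optimizer steps happen.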