I want to implement a coarse-to-fine strategy.
Assume an input and a label.
Step 1: pred_1 = model(input), loss = L_1(pred_1, label), gradient update.
Step 2: pred_2 = model(pred_1), loss = L_1(pred_2, label), gradient update.
Step 3: pred_3 = model(pred_2), loss = L_1(pred_3, label), gradient update.
…
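A minimal plain-PyTorch sketch of the loop described above. It rests on a few assumptions. `L_1` is taken to mean the L1 loss (`F.l1_loss`). The model is a hypothetical toy whose output shape equals its input shape, so that a prediction can be fed back in. Each stage's loss is computed on that stage's own prediction against the same label. Each stage's input is detached from the previous stage's graph. Without that detach, the second `backward()` would try to go through a graph that the first `backward()` already freed, and that graph's parameters have since been changed in place by `optimizer.step()`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def coarse_to_fine_step(model, optimizer, x, label, num_stages=3):
    """One coarse-to-fine iteration: each stage refines the previous
    prediction and gets its own L1 loss and optimizer update.
    Returns the per-stage loss values as Python floats."""
    losses = []
    current = x
    for _ in range(num_stages):
        # Detach so this stage's backward pass stops at its own input
        # instead of reaching into the previous stage's (freed) graph.
        pred = model(current.detach())
        loss = F.l1_loss(pred, label)  # same label at every stage
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        current = pred  # the next stage refines this prediction
    return losses


torch.manual_seed(0)
# Hypothetical toy model: output shape == input shape (8 features).
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
x = torch.randn(4, 8)
label = torch.randn(4, 8)
stage_losses = coarse_to_fine_step(model, optimizer, x, label)
print(stage_losses)
```

If you instead want gradients to flow back through all stages, drop the `detach()`. Then sum the stage losses and do a single `backward()` and `step()` per iteration.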
How are you changing the label here? I’d encourage you to take a look at Lightning Fabric, which lets you write a regular PyTorch training loop and accelerate it with Lightning.