Hello,
I am wondering what the correct way is to implement the EMA (exponential moving average) update step with Lightning.
Here is my example attempt:
```python
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch_ema import ExponentialMovingAverage


class SomeModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.criterion = SomeLoss()
        self.encoder = encoder()
        self.head = nn.Sequential(...)
        # EMA tracks the encoder parameters only.
        self.ema = ExponentialMovingAverage(self.encoder.parameters(), decay=0.995)

    def forward(self, x):
        [...]
        return logit

    def training_step(self, batch, batch_idx):
        [...]
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=(1e-3) * 3)
        scheduler = {
            'scheduler': torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, T_0=len(train_loader), T_mult=1, eta_min=0,
                last_epoch=-1, verbose=False),
            'interval': 'step',
        }
        return [optimizer], [scheduler]

    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure,
                       on_tpu=False, using_native_amp=False, using_lbfgs=False):
        # Update the EMA of the encoder weights, then run the actual optimizer step.
        self.ema.update(self.encoder.parameters())
        optimizer.step(closure=closure)
```
- Does overriding `optimizer_step` conflict with `configure_optimizers` and the intended schedules?
- What is the best way to modify the optimizer step, and does my implementation make sense? (A variant I was also considering is sketched below.)
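For concreteness, here is the variant I was also considering, as a rough sketch only: it reuses the same `torch_ema` usage and Lightning hook signature as above, and simply moves the EMA update to after the optimizer step so the average tracks the freshly updated weights.

```python
def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure,
                   on_tpu=False, using_native_amp=False, using_lbfgs=False):
    # Run the actual optimizer step first...
    optimizer.step(closure=closure)
    # ...then fold the freshly updated encoder weights into the EMA.
    self.ema.update(self.encoder.parameters())
```

If I understand the hook order correctly, the same update could also go in the `on_before_zero_grad(self, optimizer)` hook, which Lightning calls after the optimizer step, so `optimizer_step` would not need to be overridden at all.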
Thank you in advance.