Adopting an exponential moving average (EMA) in a PL pipeline

Hello,

I'm wondering what the correct way is to implement the EMA step with Lightning.

Here is my attempt:

import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch_ema import ExponentialMovingAverage


class SomeModel(pl.LightningModule):
    def __init__(self):
        super().__init__()

        self.criterion = SomeLoss()
        self.encoder = encoder()
        self.head = nn.Sequential(...)
        # shadow (EMA) copy of the encoder weights, decay = 0.995
        self.ema = ExponentialMovingAverage(self.encoder.parameters(), decay=0.995)

    def forward(self, x):
        [...]
        return logit

    def training_step(self, batch, batch_idx):
        [...]
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=(1e-3) * 3)
        # train_loader is assumed to be defined elsewhere; T_0 = one epoch of steps
        scheduler = {
            'scheduler': torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer, T_0=len(train_loader), T_mult=1, eta_min=0, last_epoch=-1, verbose=False),
            'interval': 'step',
        }
        return [optimizer], [scheduler]

    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, closure,
                       on_tpu=False, using_native_amp=False, using_lbfgs=False):
        self.ema.update(self.encoder.parameters())  # EMA update, then the optimizer step
        optimizer.step(closure=closure)

  1. Does overriding optimizer_step interfere with configure_optimizers and the intended LR schedule?
  2. What is the best way to hook into optimizer.step? Does my implementation make sense?

Thank you in advance.

I checked the repo: the EMA update should be done after optimizer.step(), not before it.

So there are two ways to do this:

def optimizer_step(self, *args, **kwargs):
    super().optimizer_step(*args, **kwargs)  # runs the usual optimizer.step(closure)
    self.ema.update(self.encoder.parameters())  # same parameters the EMA was built with

Or just override the on_before_zero_grad hook; there's no need to touch optimizer_step 🙂:

def on_before_zero_grad(self, *args, **kwargs):
    self.ema.update(self.encoder.parameters())  # same parameters the EMA was built with

Hi, why not do it at the end of training_step?

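Hello @YazanGhafir! I just stumbled upon this thread and figured I would try to answer your question. My understanding is that the EMA should be updated after optimizer.step(), because the update needs the weights that the step has just produced: $\bar{\theta}_{t+1} = \lambda \bar{\theta}_t + (1 - \lambda)\,\theta_{t+1}$, where $\bar{\theta}$ are the EMA (shadow) weights, $\theta_{t+1}$ are the model weights after the step, and $\lambda$ is the decay (0.995 in the question).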
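To make the update rule concrete, here is a minimal pure-Python sketch on plain floats. `ScalarEMA` is a hypothetical helper, not part of torch_ema. As far as I know, torch_ema's `ExponentialMovingAverage` applies the same rule to tensors, and by default it also uses a smaller effective decay for the first few updates (its `use_num_updates` option), which this sketch leaves out.

```python
class ScalarEMA:
    """Hypothetical helper: an EMA over a list of floats (illustration only)."""

    def __init__(self, params, decay=0.995):
        self.decay = decay
        # the shadow (EMA) weights start as a copy of the initial weights
        self.shadow = list(params)

    def update(self, params):
        # shadow_{t+1} = decay * shadow_t + (1 - decay) * theta_{t+1}
        self.shadow = [
            self.decay * s + (1.0 - self.decay) * p
            for s, p in zip(self.shadow, params)
        ]


theta = [1.0, -2.0]
ema = ScalarEMA(theta, decay=0.9)
theta = [2.0, 0.0]  # pretend optimizer.step() has just produced new weights
ema.update(theta)   # update *after* the step, using theta_{t+1}
print(ema.shadow)
```

With a decay of 0.9, the first weight moves only a tenth of the way from 1.0 toward 2.0 (to about 1.1); with the question's decay of 0.995 the EMA weights move even more slowly.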

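However, optimizer.step() is not called inside training_step. With automatic optimization, Lightning runs training_step (and the backward pass) inside the closure it passes to optimizer.step(), so the weights are only updated after training_step has returned. Hence the need to apply the EMA update in optimizer_step or on_before_zero_grad.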
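The ordering argument can be checked with a toy, pure-Python model. `ToySGD` and the `training_step` function below are hypothetical stand-ins, not Lightning's real classes; they only mimic the fact that a PyTorch optimizer's `step(closure)` calls the closure first and then updates the weights.

```python
class ToySGD:
    """Toy gradient-descent optimizer (hypothetical, for illustration only).

    Like torch.optim optimizers, step(closure) calls the closure first
    (forward + backward) and only then updates the weights.
    """

    def __init__(self, weights, lr=0.1):
        self.weights = weights  # shared list, updated in place
        self.lr = lr
        self.grads = [0.0] * len(weights)

    def step(self, closure):
        closure()  # "training_step" + backward run here
        for i, g in enumerate(self.grads):
            self.weights[i] -= self.lr * g


weights = [1.0]
opt = ToySGD(weights, lr=0.5)
seen_in_training_step = []


def training_step():
    # toy loss = w**2, so the gradient is 2 * w
    opt.grads = [2.0 * w for w in weights]
    # an EMA update placed here would only see the pre-step weights
    seen_in_training_step.append(list(weights))


opt.step(training_step)
seen_after_step = list(weights)
print("inside training_step:", seen_in_training_step[-1])
print("after optimizer.step:", seen_after_step)
```

Inside `training_step` the weight is still 1.0; only after `step` returns does it become 0.0. An EMA update at the end of `training_step` would therefore average in the weights from before the current step.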
