Hi, I’m trying to implement adversarial training using PyTorch Lightning.
The problem is that adversarial training in plain PyTorch looks like this:
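```python
for i, (data, target) in enumerate(train_dataloader):
    ...
    # step 1: ordinary update on the clean batch
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    optimizer.step()
    ...
    # step 2: second update on the perturbed batch
    adv_data = perturb(data)
    optimizer.zero_grad()  # drop the clean-loss gradients first
    adv_loss = loss_fn(model(adv_data), target)
    adv_loss.backward()
    optimizer.step()
```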
That is, the optimizer steps twice per iteration, with two different losses. I could use some help implementing this in PyTorch Lightning.