Hi, I’m trying to implement adversarial training using PyTorch Lightning.
The problem is that adversarial training code in plain PyTorch looks like this:
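```python
for i, (data, target) in enumerate(train_dataloader):
    ...
    # First update: loss on the clean batch
    optimizer.zero_grad()
    loss = loss_fn(model(data), target)
    loss.backward()
    optimizer.step()
    ...
    # Second update: loss on the adversarially perturbed batch
    adv_data = perturb(data)
    optimizer.zero_grad()
    adv_loss = loss_fn(model(adv_data), target)
    adv_loss.backward()
    optimizer.step()
```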
That is, the optimizer steps twice, with two different losses, in a single iteration of the loop. I could use some help implementing this in PyTorch Lightning.
Thank you!