Here’s a minimal example of the code in standard PyTorch/Lightning:
```python
def compute_jd_loss(self, demo_obs: Tensor):
    demo_obs.requires_grad_(True)
    demo_dict = self.jd_forward(obs=demo_obs, return_norm_obs=True)
    demo_logits = demo_dict["outs"]
    demo_norm_obs = demo_dict["norm_obs"]
    # gradient penalty
    disc_demo_grad = torch.autograd.grad(  # <--- fails here
        demo_logits,
        demo_norm_obs,
        grad_outputs=torch.ones_like(demo_logits),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    disc_demo_grad_norm = torch.sum(torch.square(disc_demo_grad), dim=-1)
    disc_grad_penalty = torch.mean(disc_demo_grad_norm)
    grad_loss: Tensor = self.config.jd_grad_penalty * disc_grad_penalty
```
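For reference, here is a standalone sketch of the same penalty in plain PyTorch, outside of any Fabric/Lightning wrapping, where `torch.autograd.grad` works as expected (the toy linear discriminator and tensor shapes below are stand-ins, not the real `jd_forward`):

```python
import torch
import torch.nn as nn

disc = nn.Linear(8, 1)          # stand-in for the real discriminator
demo_obs = torch.randn(4, 8)    # stand-in demo observations

# take the gradient w.r.t. the (normalized) observations
demo_norm_obs = demo_obs.clone().requires_grad_(True)
demo_logits = disc(demo_norm_obs)

disc_demo_grad = torch.autograd.grad(
    demo_logits,
    demo_norm_obs,
    grad_outputs=torch.ones_like(demo_logits),
    create_graph=True,   # keeps the penalty differentiable for the main loss
    retain_graph=True,
)[0]

# mean squared L2 norm of the per-sample gradients
disc_grad_penalty = torch.sum(torch.square(disc_demo_grad), dim=-1).mean()
```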
Under Fabric this crashes, since Fabric requires backward passes to go through `fabric.backward(...)`. What would be the "Fabric" way of implementing a gradient penalty loss?

Thanks!