How would you implement a gradient penalty loss in Fabric, similar to WGAN-GP?

Here’s a minimal example of the code in standard PyTorch/Lightning:

    def compute_jd_loss(self, demo_obs: Tensor):
        demo_obs.requires_grad_(True)
        demo_dict = self.jd_forward(obs=demo_obs, return_norm_obs=True)
        demo_logits = demo_dict["outs"]
        demo_norm_obs = demo_dict["norm_obs"]

        # grad penalty
        disc_demo_grad = torch.autograd.grad(  # <--- Fails here
            demo_logits,
            demo_norm_obs,
            grad_outputs=torch.ones_like(demo_logits),
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]
        disc_demo_grad_norm = torch.sum(torch.square(disc_demo_grad), dim=-1)
        disc_grad_penalty = torch.mean(disc_demo_grad_norm)
        grad_loss: Tensor = self.config.jd_grad_penalty * disc_grad_penalty
        return grad_loss

Fabric raises an error at the marked line, since it expects all backward passes to go through `fabric.backward(...)`. What would be the Fabric way of implementing a gradient penalty loss?
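For context, here is a self-contained sketch of the same pattern outside of Fabric, where it works fine in plain PyTorch. `disc` and `compute_grad_penalty` are hypothetical stand-ins for `self.jd_forward` and `compute_jd_loss`; the final scalar would then go into the total loss that gets passed to `fabric.backward(...)` in my training loop:

```python
import torch
from torch import nn

# Hypothetical stand-in discriminator for self.jd_forward
disc = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

def compute_grad_penalty(disc: nn.Module, obs: torch.Tensor) -> torch.Tensor:
    # Inputs must require grad so torch.autograd.grad can differentiate w.r.t. them
    obs = obs.detach().requires_grad_(True)
    logits = disc(obs)
    # create_graph=True keeps the penalty itself differentiable,
    # so its gradient can flow back into the discriminator weights
    grad = torch.autograd.grad(
        logits,
        obs,
        grad_outputs=torch.ones_like(logits),
        create_graph=True,
    )[0]
    # Mean squared gradient norm over the batch (WGAN-GP-style penalty term)
    return grad.square().sum(dim=-1).mean()

obs = torch.randn(16, 4)
penalty = compute_grad_penalty(disc, obs)
# In a Fabric loop, the equivalent final step would be something like:
#   fabric.backward(main_loss + lambda_gp * penalty)
```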

Thanks!