Computing gradients wrt inputs within training_step

I am trying to implement WGAN-GP (Wasserstein GAN with gradient penalty) by defining the following function, which is called inside the training_step() method.

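```python
import torch


def wgan_gradient_penalty(
        real: torch.Tensor,
        fake: torch.Tensor,
        discriminator: torch.nn.Module) -> torch.Tensor:

    alpha = torch.rand(real.size(0), 1, 1, 1).type_as(real)

    x_hat = alpha * real + (1 - alpha) * fake.detach()
    x_hat.requires_grad = True

    # calc. d_hat: discriminator output on x_hat
    d_hat = discriminator(x_hat)

    # calc. gradients of d_hat vs. x_hat
    # (the original snippet is cut off here; the arguments below are an
    # assumed, standard WGAN-GP completion)
    grads = torch.autograd.grad(
        outputs=d_hat,
        inputs=x_hat,
        grad_outputs=torch.ones_like(d_hat),
        create_graph=True,
        retain_graph=True)[0]

    # penalty: mean squared deviation of the per-sample gradient norm from 1
    grads = grads.view(grads.size(0), -1)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()
```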
But the output of the network seems to be detached from the graph: when I check d_hat.grad_fn, it is None.

```
(Pdb) print(d_hat.grad_fn)
None
```

Because grad_fn is not defined, the gradient computation fails with the following error:

```
*** RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```
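One way to reproduce this symptom is to run the function while gradient recording is disabled, for example inside a torch.no_grad() block. This is an assumption about the cause, not a confirmed diagnosis of the setup above. In no-grad mode, setting requires_grad on x_hat is not enough, because the discriminator's operations record no graph, so d_hat.grad_fn stays None. The sketch below uses a hypothetical toy discriminator and random inputs. It shows the same error, and shows that wrapping the computation in torch.enable_grad() brings grad_fn back.

```python
import torch
import torch.nn as nn

# Hypothetical toy discriminator standing in for the real one.
disc = nn.Sequential(nn.Flatten(), nn.Linear(3 * 4 * 4, 1))
real = torch.randn(2, 3, 4, 4)
fake = torch.randn(2, 3, 4, 4)


def interpolate_and_score(real, fake, disc):
    # Same interpolation as in wgan_gradient_penalty above.
    alpha = torch.rand(real.size(0), 1, 1, 1).type_as(real)
    x_hat = alpha * real + (1 - alpha) * fake.detach()
    x_hat.requires_grad = True
    return x_hat, disc(x_hat)


# Gradient recording disabled: d_hat has no grad_fn even though x_hat requires grad.
with torch.no_grad():
    x_hat, d_hat = interpolate_and_score(real, fake, disc)
    grad_fn_off = d_hat.grad_fn
    try:
        torch.autograd.grad(d_hat, x_hat, grad_outputs=torch.ones_like(d_hat))
        err_msg = ""
    except RuntimeError as err:
        err_msg = str(err)

    # Locally re-enabling gradient recording restores the graph.
    with torch.enable_grad():
        x_hat, d_hat = interpolate_and_score(real, fake, disc)
        grad_fn_on = d_hat.grad_fn

print(grad_fn_off)             # None
print(err_msg)
print(grad_fn_on is not None)  # True
```

If the function is being called from a context where gradients are disabled, checking torch.is_grad_enabled() at the top of the function is a quick way to confirm it.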

I think you are looking for x_hat.grad, which holds the gradient of d_hat w.r.t. x_hat once a backward pass has been run.
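To make the difference concrete: x_hat.grad is only filled in by a backward pass, whereas torch.autograd.grad returns the gradient directly and leaves .grad untouched. The sketch below assumes grad mode is enabled and uses a hypothetical toy discriminator. It computes the same input gradient both ways. For a gradient penalty, torch.autograd.grad with create_graph=True is the usual choice, because the penalty built from the gradient must itself be differentiable. Calling backward() on d_hat also accumulates gradients into the discriminator's parameters.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy discriminator standing in for the real one.
disc = nn.Sequential(nn.Flatten(), nn.Linear(3 * 4 * 4, 1))
x = torch.randn(2, 3, 4, 4)

# Option 1: torch.autograd.grad returns d(d_hat)/d(x_hat) directly.
x_hat = x.clone().requires_grad_(True)
d_hat = disc(x_hat)
grads = torch.autograd.grad(
    outputs=d_hat,
    inputs=x_hat,
    grad_outputs=torch.ones_like(d_hat),
    create_graph=True,  # keep the graph so a penalty built from grads is differentiable
)[0]

# Option 2: x_hat.grad is only populated after a backward pass.
x_hat2 = x.clone().requires_grad_(True)
grad_before = x_hat2.grad  # None: no backward pass has run yet
d_hat2 = disc(x_hat2)
d_hat2.backward(torch.ones_like(d_hat2))  # also accumulates into disc's parameter .grad
grad_from_backward = x_hat2.grad

print(grad_before)                                # None
print(torch.allclose(grads, grad_from_backward))  # True
```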