I am trying to implement WGAN-GP (Wasserstein GAN with gradient penalty) using the following function, which is called inside the `training_step()` method:
```python
import torch


def wgan_gradient_penalty(
    real: torch.Tensor,
    fake: torch.Tensor,
    discriminator: torch.nn.Module,
) -> torch.Tensor:
    # random per-sample interpolation weight, broadcast over (C, H, W)
    alpha = torch.rand(real.size(0), 1, 1, 1).type_as(real)
    x_hat = alpha * real + (1 - alpha) * fake.detach()
    x_hat.requires_grad = True
    # d_hat: discriminator output on x_hat
    d_hat = discriminator(x_hat)
    # gradients of d_hat w.r.t. x_hat
    grads = torch.autograd.grad(
        outputs=d_hat,
        inputs=x_hat,
        grad_outputs=torch.ones(d_hat.size()).type_as(real),
        create_graph=True,
        retain_graph=True,
    )[0]
    # penalty: mean of (||grad||_2 - 1)^2 over the batch
    grads = grads.view(grads.size(0), -1)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()
```
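For comparison, here is a minimal self-contained sketch of the same computation in plain PyTorch with gradient tracking enabled (the default). The `ToyCritic` class, the tensor shapes, and returning `d_hat` alongside the penalty are hypothetical choices for illustration only; the penalty follows the usual WGAN-GP form, the batch mean of (‖∇D(x̂)‖₂ − 1)².

```python
import torch
import torch.nn as nn


class ToyCritic(nn.Module):
    """Hypothetical stand-in discriminator: flatten the image, then one linear layer."""

    def __init__(self, in_features: int) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Flatten(), nn.Linear(in_features, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


def gradient_penalty(real: torch.Tensor, fake: torch.Tensor, critic: nn.Module):
    # Per-sample mixing coefficient, broadcast over (C, H, W).
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device, dtype=real.dtype)
    # x_hat is a leaf here (real does not require grad, fake is detached),
    # so gradient tracking can be switched on in place.
    x_hat = (alpha * real + (1 - alpha) * fake.detach()).requires_grad_(True)
    d_hat = critic(x_hat)
    grads = torch.autograd.grad(
        outputs=d_hat,
        inputs=x_hat,
        grad_outputs=torch.ones_like(d_hat),
        create_graph=True,  # keep the penalty itself differentiable
        retain_graph=True,
    )[0]
    grads = grads.view(grads.size(0), -1)
    penalty = ((grads.norm(2, dim=1) - 1) ** 2).mean()
    return penalty, d_hat  # d_hat returned only so grad_fn can be inspected


torch.manual_seed(0)
critic = ToyCritic(3 * 8 * 8)
real = torch.randn(4, 3, 8, 8)
fake = torch.randn(4, 3, 8, 8)
penalty, d_hat = gradient_penalty(real, fake, critic)
print(d_hat.grad_fn is not None)  # grad mode is on, so the forward pass is recorded
print(penalty.requires_grad)      # create_graph=True lets the penalty backprop into the critic
```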
However, inside `training_step()` the discriminator output appears to be detached from the autograd graph. When I inspect `d_hat.grad_fn`, it is `None`:

```
(Pdb) print(d_hat.grad_fn)
None
```
Because `d_hat` has no `grad_fn`, the call to `torch.autograd.grad` fails with the following error:

```
*** RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```
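One way this exact symptom can arise, offered as a hypothesis rather than a confirmed diagnosis of this setup, is that the code runs with autograd disabled, for example under `torch.no_grad()` (some training frameworks wrap certain steps this way). The sketch below reproduces a `None` `grad_fn` and the same `RuntimeError` under `torch.no_grad()`, and shows that re-enabling gradients locally with `torch.enable_grad()` restores the graph. The `critic_grads` helper, the linear critic, and the shapes are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical critic and data; shapes are arbitrary.
critic = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))
real = torch.randn(4, 3, 8, 8)
fake = torch.randn(4, 3, 8, 8)


def critic_grads(real, fake, critic):
    """Hypothetical helper: interpolate, run the critic, and take d(critic)/d(x_hat)."""
    alpha = torch.rand(real.size(0), 1, 1, 1, dtype=real.dtype, device=real.device)
    x_hat = (alpha * real + (1 - alpha) * fake.detach()).requires_grad_(True)
    d_hat = critic(x_hat)
    grads = torch.autograd.grad(
        d_hat, x_hat, grad_outputs=torch.ones_like(d_hat), create_graph=True
    )[0]
    return d_hat, grads


# 1) Under no_grad the forward pass records no graph, so d_hat has no grad_fn
#    and autograd.grad raises a RuntimeError.
error_message = None
with torch.no_grad():
    try:
        critic_grads(real, fake, critic)
    except RuntimeError as exc:
        error_message = str(exc)
print(error_message)

# 2) enable_grad() re-enables recording even inside an outer no_grad block.
with torch.no_grad():
    with torch.enable_grad():
        d_hat, grads = critic_grads(real, fake, critic)
print(d_hat.grad_fn is not None, tuple(grads.shape))
```

If this is the cause, wrapping only the penalty computation in `torch.enable_grad()` is a local fix that leaves the surrounding context's grad mode untouched.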