I am training a domain-adaptation (DA) network that uses a gradient reversal layer (GRL) inside the discriminator to train the encoder.
Is the GRL implementation in PyTorch Lightning (PL) the same as the plain PyTorch one below?
from torch.autograd import Function
class GradReverse(Function):
    """Gradient reversal layer (GRL) as a custom autograd function.

    The forward pass is the identity. The backward pass multiplies the
    incoming gradient by ``-alpha``, so parameters upstream of this layer
    receive the negated (and scaled) gradient of whatever loss is computed
    downstream of it.

    Call it as ``GradReverse.apply(x, alpha)``. Do not instantiate it; both
    methods must be static for new-style autograd functions.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        """Return ``x`` unchanged and stash ``alpha`` for the backward pass.

        Args:
            ctx: Autograd context object supplied by ``Function.apply``.
            x: Input tensor.
            alpha: Reversal coefficient. It is stored on ``ctx`` and is not
                itself differentiated.

        Returns:
            A view of ``x`` holding the same data.
        """
        ctx.alpha = alpha
        # Return a fresh view rather than ``x`` itself, which is the usual GRL
        # idiom. It gives autograd a distinct output tensor to attach this
        # Function's backward to.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        """Reverse and scale the upstream gradient.

        Args:
            ctx: Autograd context holding ``alpha`` from ``forward``.
            grad_output: Gradient of the loss w.r.t. the forward output.

        Returns:
            ``(grad_output * -alpha, None)``, with one entry per ``forward``
            input. ``alpha`` gets ``None`` because it needs no gradient.
        """
        return grad_output * -ctx.alpha, None
def grad_reverse(x, alpha):
    """Apply gradient reversal with strength ``alpha`` to ``x``.

    This is a functional wrapper around ``GradReverse.apply``. It is the
    identity in the forward pass and multiplies the gradient by ``-alpha``
    in the backward pass.
    """
    reversed_x = GradReverse.apply(x, alpha)
    return reversed_x
# Inside the discriminator's forward pass:
def forward(self, y, alpha=None):
    """Pass ``y`` through the gradient reversal layer.

    Args:
        y: Input features, presumably the encoder output tensor (confirm
            against the caller).
        alpha: Reversal strength for this call. When ``None``, the method
            falls back to ``self.alpha``.

    Returns:
        ``y`` unchanged in the forward pass. During backprop its gradient is
        multiplied by ``-alpha``.
    """
    # Prefer the explicitly passed value over the stored attribute, so that a
    # per-step alpha schedule supplied by the caller actually takes effect.
    if alpha is None:
        alpha = self.alpha
    return grad_reverse(y, alpha)