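I am training a domain-adaptation (DA) network in which a gradient reversal layer (GRL) placed in front of the discriminator is used to train the encoder.
Does a GRL implemented in PyTorch Lightning (PL) work the same way as the plain PyTorch implementation below?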
from torch.autograd import Function
class GradReverse(Function):
    """Gradient reversal layer (GRL) implemented as a custom autograd Function.

    Forward pass: identity on ``x``.
    Backward pass: the incoming gradient is multiplied by ``-alpha``, so every
    layer before this one receives a reversed (and scaled) gradient.
    ``alpha`` is treated as a non-differentiable constant.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        """Return ``x`` unchanged and remember ``alpha`` for the backward pass.

        Args:
            ctx: autograd context used to pass ``alpha`` to ``backward``.
            x: input tensor.
            alpha: scale applied to the reversed gradient.

        Returns:
            A view of ``x`` with identical values.
        """
        ctx.alpha = alpha
        # Return a view rather than the input object itself (the usual idiom
        # for identity-style custom Functions). The former `print(alpha)`
        # after this return was unreachable and has been removed.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        """Reverse and scale the gradient flowing back into ``x``.

        Returns:
            One gradient per forward input: ``-alpha * grad_output`` for ``x``
            and ``None`` for ``alpha`` (no gradient is computed for it).
        """
        return grad_output * -ctx.alpha, None
def grad_reverse(x, alpha):
    """Pass ``x`` through the ``GradReverse`` autograd Function.

    Args:
        x: tensor to route through the gradient reversal layer.
        alpha: scale handed to ``GradReverse`` for the reversed gradient.

    Returns:
        The output of ``GradReverse.apply(x, alpha)``.
    """
    reversed_x = GradReverse.apply(x, alpha)
    return reversed_x
####
#In the discriminator forward pass
####
def forward(self, y, alpha):
    """Discriminator forward pass with gradient reversal before the classifier.

    Args:
        y: input features (presumably encoder output — confirm with caller).
        alpha: GRL scale; stored on ``self.alpha`` and passed to ``grad_reverse``.

    Returns:
        The result of ``self.classifier3`` applied to the gradient-reversed input.
    """
    self.alpha = alpha
    # Route through the GRL so gradients reaching upstream layers are reversed.
    reversed_features = grad_reverse(y, self.alpha)
    logits = self.classifier3(reversed_features)
    return logits