@andresrosso the issue seems to be related to the numerical stability of nn.BCELoss in PyTorch (not a Lightning issue). I was able to get consistent results by removing the final sigmoid operation and using nn.BCEWithLogitsLoss instead (docs). This is usually more stable, as the log-sum-exp trick is used.
That being said, operations should ideally be consistent across devices, so we may want to file an issue with PyTorch, as this behavior seems to have been introduced in 1.6.
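To illustrate the stability point, here is a minimal pure-Python sketch (no PyTorch required) comparing the two formulations for a single logit. The function names are hypothetical, and the stable formula is one common way of writing the log-sum-exp trick, not necessarily the exact code PyTorch runs. The clamp at -100 follows what the nn.BCELoss docs describe. Note this uses float64; in float32 the sigmoid saturates at much smaller logits.

```python
import math

LOG_CLAMP = -100.0  # nn.BCELoss docs: log outputs are clamped to >= -100


def clamped_log(v):
    """log(v), clamped below at LOG_CLAMP so that log(0) stays finite."""
    if v <= 0.0:
        return LOG_CLAMP
    return max(math.log(v), LOG_CLAMP)


def sigmoid(x):
    """Sigmoid written to avoid overflow in exp for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def bce_via_sigmoid(logit, target):
    """Sigmoid first, then BCE on the probability (the nn.BCELoss-style route)."""
    p = sigmoid(logit)
    return -(target * clamped_log(p) + (1.0 - target) * clamped_log(1.0 - p))


def bce_with_logits(logit, target):
    """BCE computed directly from the logit, never forming the probability."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


# Moderate logit: both routes agree.
print(bce_via_sigmoid(0.5, 1.0), bce_with_logits(0.5, 1.0))

# Large logit with the wrong label: the sigmoid rounds to exactly 1.0,
# so the first route hits the clamp, while the second stays accurate.
print(bce_via_sigmoid(40.0, 0.0))  # 100.0 (clamped)
print(bce_with_logits(40.0, 0.0))  # 40.0
```

Once the sigmoid saturates, the loss (and its gradient) depends on rounding and clamping details, which is where small differences between devices can show up. Passing raw logits to the loss avoids that regime.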