Suppose you have:

class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.some_tensor = torch.rand(1, 2, 3)

How do you make sure the tensor gets moved to the right device when training on a GPU?
Answer:

Use register_buffer, a PyTorch method available on any nn.Module. Unlike a plain tensor attribute, a registered buffer is moved along with the module by .to() / .cuda() and is included in the module's state_dict:

class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.register_buffer("some_tensor", torch.rand(1, 2, 3))
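The difference can be demonstrated with a plain nn.Module (LightningModule subclasses nn.Module, so register_buffer behaves the same way; the class name Model below is just for illustration):

```python
import torch
from torch import nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Plain attribute: NOT moved by .to()/.cuda(), not in state_dict.
        self.plain_tensor = torch.rand(1, 2, 3)
        # Registered buffer: moved with the module and saved in state_dict.
        self.register_buffer("some_tensor", torch.rand(1, 2, 3))

model = Model()
print("some_tensor" in model.state_dict())   # True  -- buffer is persisted
print("plain_tensor" in model.state_dict())  # False -- plain attribute is not
```

On a machine with a GPU, `model.to("cuda")` would move `model.some_tensor` to the GPU but leave `model.plain_tensor` on the CPU, which is exactly the bug register_buffer avoids.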