Am I right in assuming that the LightningDataModule methods prepare_data() and setup() should not create torch.Tensors (or, more generally, perform any action that would assign data to a device)?
For example, if I wanted to apply a transform that included a torch.as_tensor() call, then this should happen in the train/val/test_dataloader() methods, not in the prepare_data() or setup() methods, correct?
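To make this concrete, here is a minimal sketch of the placement I have in mind. ToyDataModule and ArrayDataset are hypothetical names, and ToyDataModule is a plain class that only mimics the LightningDataModule hook names so the snippet runs without Lightning installed; it is not meant to show how Lightning itself behaves. As far as I know, torch.as_tensor() on a NumPy array gives a CPU tensor unless a device is passed explicitly.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class ArrayDataset(Dataset):
    """Hypothetical dataset: stores NumPy arrays, applies `transform` per item."""

    def __init__(self, features, labels, transform=None):
        self.features = features
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        x = self.features[idx]
        if self.transform is not None:
            x = self.transform(x)  # the tensor is created here, lazily, per item
        return x, int(self.labels[idx])


class ToyDataModule:
    """Stand-in that only mimics the LightningDataModule hook names."""

    def prepare_data(self):
        # Downloading / writing to disk would go here; no state is assigned.
        pass

    def setup(self, stage=None):
        # State assigned here is plain NumPy; no torch.Tensor is created yet.
        rng = np.random.default_rng(0)
        self.features = rng.random((8, 4), dtype=np.float32)
        self.labels = rng.integers(0, 2, size=8)

    def train_dataloader(self):
        # The torch.as_tensor() transform is only attached at this point.
        ds = ArrayDataset(self.features, self.labels, transform=torch.as_tensor)
        return DataLoader(ds, batch_size=4)


dm = ToyDataModule()
dm.prepare_data()
dm.setup("fit")
xb, yb = next(iter(dm.train_dataloader()))
print(type(dm.features).__name__, tuple(xb.shape), xb.device)
```

So the question is whether deferring tensor creation to the dataloader like this is required, or whether building the tensors directly in setup() would be equally fine.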
I ask this because my understanding is that the code doesn’t (and shouldn’t) know anything about the engineering/hardware side unless it is run by a Trainer. Yet the LightningDataModule doc page suggests that, “when information about the dataset is needed to build the model”, prepare_data() and setup() can be called outside the Trainer:
dm = MNISTDataModule()
dm.prepare_data()
dm.setup('fit')
model = Model(num_classes=dm.num_classes, width=dm.width, vocab=dm.vocab)
trainer.fit(model, dm)
dm.setup('test')
trainer.test(datamodule=dm)
Finally, if this is right, is there a rule of thumb to make sure I don’t accidentally call a method that uses engineering/hardware knowledge behind the scenes?
P.S. The docs state:
prepare_data is called from a single GPU. Do not use it to assign state (self.x = y)
and
setup is called from every GPU. Setting state here is okay
but these comments weren’t enough for me to feel confident in what is/isn’t allowed.