Hi!
I would like to initialize some of my model weights based on statistics of the data (more specifically, on clusters found in the data).
IIUC, to do this I need to implement a `setup` method in my `LightningModule`.
I got this idea from the docs, which illustrate a `setup` implementation as:
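```python
from torch import nn

def setup(self, stage):
    # load_data(...) is the docs' placeholder for your own data-loading code
    data = load_data(...)
    self.l1 = nn.Linear(28, data.num_classes)
```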
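One possible way to turn "clusters from the data" into initial weights, sketched without any framework: plain k-means (Lloyd's algorithm) with a fixed seed, returning `k` centroids. The choice of k-means and the names `kmeans` and `squared_distance` are assumptions for illustration, not something taken from the docs:

```python
import random

def squared_distance(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))

def kmeans(points, k, seed=0, iters=20):
    """Plain Lloyd's k-means; returns k centroids as lists of floats."""
    rng = random.Random(seed)  # fixed seed -> same centroids for the same input
    centroids = [list(p) for p in rng.sample(points, k)]
    for _ in range(iters):
        # Assignment step: attach each point to its nearest centroid.
        clusters = [[] for _ in range(k)]
        for p in points:
            nearest = min(range(k), key=lambda j: squared_distance(p, centroids[j]))
            clusters[nearest].append(p)
        # Update step: move each centroid to the mean of its cluster
        # (an empty cluster keeps its previous centroid).
        for j, members in enumerate(clusters):
            if members:
                centroids[j] = [sum(col) / len(members) for col in zip(*members)]
    return centroids

points = [[0.0, 0.0], [0.0, 1.0], [1.0, 0.0], [10.0, 10.0], [10.0, 11.0], [11.0, 10.0]]
centroids = sorted(kmeans(points, k=2, seed=0))
```

With `k` centroids of dimension `d`, the result has the same `(k, d)` shape as the weight of an `nn.Linear(d, k)`, so inside `setup` it could presumably be copied in with something like `self.l1.weight.copy_(torch.tensor(centroids))` under `torch.no_grad()`.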
However, two things are not clear to me:
1/ How do I actually load the data? I would like to reuse the datamodule that has just been set up in the trainer's `_call_setup_hook` method, but I don't know how to access it.
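Regarding the first question, my understanding (worth double-checking against the Lightning version in use) is that inside the `LightningModule`'s `setup`, `self.trainer.datamodule` refers to the datamodule passed to the trainer, so `self.trainer.datamodule.train_dataloader()` could be iterated there. A minimal helper that flattens such a loader into a list of feature rows, assuming each batch is an `(inputs, targets)` pair; `collect_features` is a hypothetical name:

```python
def collect_features(loader):
    """Concatenate the feature part of every batch into one list of rows.

    Assumes each batch is an (inputs, targets) pair where inputs is a
    sequence of feature rows; adapt the unpacking to your own batches.
    """
    rows = []
    for inputs, _targets in loader:
        rows.extend(list(row) for row in inputs)
    return rows

# A plain list of batches stands in for a real DataLoader here.
fake_loader = [([[0.0, 0.0], [0.0, 1.0]], [0, 0]), ([[10.0, 10.0]], [1])]
features = collect_features(fake_loader)
```

The resulting rows have the shape a clustering step such as k-means would expect as input.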
2/ How will the weights be reduced (synchronized) across processes in a DDP setting? Basically, how do I make sure that all my processes get the same initialization? If I were to use the training dataloaders (if that is even possible), it seems I could end up with very different initializations, because each process would compute different data statistics.
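A toy, framework-free illustration of the concern in the second question, assuming each rank sees a round-robin shard of the training data (roughly what a `DistributedSampler` does, ignoring shuffling and padding). `shard` and `init_value` are hypothetical helpers, and the "statistic" is just a mean plus seeded noise:

```python
import random
from statistics import fmean

def shard(data, rank, world_size):
    # Round-robin split: rank r gets every world_size-th sample starting at r
    # (shuffling and padding are left out of this sketch).
    return data[rank::world_size]

def init_value(data, seed):
    # Hypothetical stand-in for "compute data statistics, then initialize":
    # a data-dependent mean plus seeded random noise.
    rng = random.Random(seed)
    return fmean(data) + rng.gauss(0.0, 0.01)

data = [float(i) for i in range(10)]
world_size = 2

# Each rank initializes from its own shard: the values differ across ranks.
per_shard = [init_value(shard(data, r, world_size), seed=0) for r in range(world_size)]

# Each rank initializes from the full dataset with the same seed: the values agree.
per_full = [init_value(data, seed=0) for _ in range(world_size)]
```

Under these assumptions, one way to keep ranks in agreement is to compute the statistics on the full, unsharded data with the same seed on every rank (for example after Lightning's `seed_everything`); another is to compute them on rank 0 only and share the resulting tensor with `torch.distributed.broadcast`. The first costs a full pass over the data on every rank, the second requires the process group to be initialized at that point.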