Where should code to compute dataset-level stats go?

Hi all, I'm new to PL (PyTorch Lightning) and trying to learn it by coding. I'm using a LightningDataModule in a computer vision project and want to compute the per-channel MEAN and STD over all images in the training set.

In a multi-GPU setup, the docs say that prepare_data() is executed on ONE GPU only (once per node if there are multiple nodes). So where should the dataset-level mean and std be computed? It seems it could be done in prepare_data(), but then how would the stats be shared across the GPUs on a node? Can we store them as LightningDataModule instance attributes, e.g. self.mean = ..., and use them in other methods where needed, e.g. transform = transforms.Normalize(self.mean, self.std) in setup()? I don't think they can be computed in setup(), because (I think) that method sees only a portion of the data on each GPU, which is not what we want. The last option I can think of is computing the stats in __init__(), but I'm not sure what happens in that case.

My main goal is to understand Lightning rather than to solve this particular problem, so details, links to reading material, or examples are greatly appreciated.
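For reference, computing dataset-level per-channel stats can be done in a single streaming pass. A minimal sketch, assuming each image is a NumPy array shaped (C, H, W) (the channels-first layout PyTorch uses); `channel_mean_std` is a hypothetical helper, not a Lightning API. It accumulates per-channel sums and sums of squares in float64, so the whole dataset never has to be held in memory at once.

```python
import numpy as np


def channel_mean_std(images):
    """Per-channel mean and (population) std over every pixel of every image.

    Assumes each image is an array shaped (C, H, W); H and W may differ between images.
    """
    count = 0          # total number of pixels seen per channel
    total = None       # running per-channel sum
    total_sq = None    # running per-channel sum of squares
    for img in images:
        arr = np.asarray(img, dtype=np.float64)
        pixels = arr.reshape(arr.shape[0], -1)  # (C, H*W)
        if total is None:
            total = np.zeros(pixels.shape[0])
            total_sq = np.zeros(pixels.shape[0])
        total += pixels.sum(axis=1)
        total_sq += np.square(pixels).sum(axis=1)
        count += pixels.shape[1]
    if count == 0:
        raise ValueError("images is empty")
    mean = total / count
    # Var = E[x^2] - E[x]^2; clamp tiny negative values caused by rounding.
    var = np.maximum(total_sq / count - mean**2, 0.0)
    return mean, np.sqrt(var)
```

The sum-of-squares formula can lose precision on very large datasets or large pixel values; Welford's online algorithm is the more numerically robust alternative.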

Hi @Quazi, if you intend to use the mean and std to normalise the dataset during training, then prepare_data is the better place to compute them.
prepare_data is guaranteed to be called from a single process. Since the dataset mean and std won't change unless you change your Dataset, I'd suggest saving them offline for reuse (in other words, cache the mean and std).
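The Lightning docs advise against assigning state such as self.mean = ... in prepare_data(), because it runs in only one process and the attribute would not exist in the others, whereas setup() runs in every process. So one way to follow the caching suggestion is to write the stats to disk in prepare_data() and read them back in setup(). A minimal sketch, assuming images are (C, H, W) NumPy arrays; `compute_and_cache_stats`, `load_cached_stats`, `load_training_images` and the `dataset_stats.json` path are hypothetical names for illustration, not part of Lightning.

```python
import json
from pathlib import Path

import numpy as np

# Hypothetical cache location; it must be readable by every process on the node.
STATS_FILE = Path("dataset_stats.json")


def compute_and_cache_stats(images, path=STATS_FILE):
    """Intended for prepare_data(): compute per-channel stats once and write them to disk.

    Does nothing if the cache file already exists. For brevity this loads all
    pixels into memory; a streaming pass would scale better.
    """
    path = Path(path)
    if path.exists():
        return
    arrays = [np.asarray(img, dtype=np.float64) for img in images]
    pixels = np.concatenate([a.reshape(a.shape[0], -1) for a in arrays], axis=1)
    stats = {"mean": pixels.mean(axis=1).tolist(), "std": pixels.std(axis=1).tolist()}
    path.write_text(json.dumps(stats))


def load_cached_stats(path=STATS_FILE):
    """Intended for setup(): every process reads the same cached values."""
    stats = json.loads(Path(path).read_text())
    return stats["mean"], stats["std"]


# Where the calls would go in a LightningDataModule (comment only, not run here):
#
# class MyDataModule(pl.LightningDataModule):
#     def prepare_data(self):
#         compute_and_cache_stats(load_training_images())  # hypothetical loader
#
#     def setup(self, stage=None):
#         mean, std = load_cached_stats()
#         self.transform = transforms.Normalize(mean, std)
```

Because setup() only reads the file, each process ends up with identical mean and std without any cross-GPU communication. On a multi-node cluster without a shared filesystem, prepare_data() runs once per node by default, so each node writes its own copy of the file.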