I’m currently working on a text-to-image diffusion model. It has a small trainable UNet and a very large text encoder. I have a cluster with 32 GPUs (4 nodes). If I deploy both the UNet and the text encoder on every GPU, it consumes an enormous amount of memory. As a workaround, I have been pre-encoding the texts in the dataset into text embeddings and saving them to disk for training. However, this leads to unacceptable storage usage.
Is it possible to run the UNet training on 28 GPUs (7 per node) and use the remaining 4 GPUs (1 per node) to run the text encoder, so that I can encode the text prompts on the fly and pass the resulting text embeddings to the training devices? If so, could you please provide an implementation guide, e.g., what to implement in the Trainer, what to implement in the DataModule, and so on?
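For context, below is a rough sketch of the kind of split I have been imagining, using torch.distributed.rpc to let trainer processes request embeddings from an encoder process on the same node. All names here (GPUS_PER_NODE, ENCODER_LOCAL_RANK, encode_prompts, load_text_encoder) are placeholders I made up, not working code from any existing project:

```python
import os
import torch
import torch.distributed.rpc as rpc

# Assumed layout: 8 processes per node; local rank 7 on each node is an
# "encoder" process, local ranks 0-6 are "trainer" processes.
GPUS_PER_NODE = 8
ENCODER_LOCAL_RANK = 7

_text_encoder = None  # lazily loaded on the encoder process only


def encode_prompts(prompts):
    """Runs on an encoder process: turn a list of prompt strings into embeddings."""
    global _text_encoder
    if _text_encoder is None:
        _text_encoder = load_text_encoder().cuda().eval()  # placeholder loader
    with torch.no_grad():
        emb = _text_encoder(prompts)  # tokenization omitted; shape [B, seq_len, dim]
    return emb.cpu()  # return a CPU tensor so RPC can ship it between hosts


def main():
    # MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE assumed to be set by the launcher.
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = rank % GPUS_PER_NODE
    node_id = rank // GPUS_PER_NODE

    rpc.init_rpc(name=f"worker{rank}", rank=rank, world_size=world_size)

    if local_rank == ENCODER_LOCAL_RANK:
        # Encoder process: hold the text encoder and serve encode_prompts() calls.
        pass
    else:
        # Trainer process: ask the encoder on the same node for embeddings,
        # then run the UNet forward/backward with them as conditioning.
        # (DDP would be initialized over a subgroup of only the 28 trainer ranks.)
        encoder_name = f"worker{node_id * GPUS_PER_NODE + ENCODER_LOCAL_RANK}"
        prompts = ["a photo of a cat"]  # in practice: raw strings from the DataLoader batch
        emb = rpc.rpc_sync(encoder_name, encode_prompts, args=(prompts,))
        emb = emb.cuda(local_rank)
        # ... UNet training step using emb ...

    rpc.shutdown()


if __name__ == "__main__":
    main()
```

What I’m unsure about is how to fit something like this cleanly into Lightning: whether the DataModule should simply keep the raw prompt strings in the batch and the embedding call should happen in training_step (or a hook like on_after_batch_transfer), and how to make the Trainer build its process group over only the 28 training ranks.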
Looking forward to your reply.