FSDP not reducing memory for non-trainable submodule

Hi, I have a LightningModule with two models (a simplified sketch follows the list):

  1. Teacher with requires_grad=False: a Hugging Face pretrained GPT-2 XL model, placed in .eval() mode and run under torch.no_grad() in the forward pass (~1.5B params).
  2. Trainable student (~124M params).
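
Here is a simplified sketch of the module. The class name, the KL-divergence placeholder loss, and the optimizer settings are illustrative; my real code differs in the details:

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from transformers import AutoModelForCausalLM


class DistillationModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # Frozen teacher (~1.5B params): kept in eval mode, never updated
        self.teacher = AutoModelForCausalLM.from_pretrained("gpt2-xl")
        self.teacher.eval()
        for p in self.teacher.parameters():
            p.requires_grad = False

        # Trainable student (~124M params)
        self.student = AutoModelForCausalLM.from_pretrained("gpt2")

    def training_step(self, batch, batch_idx):
        # Teacher forward pass without building a graph
        with torch.no_grad():
            teacher_logits = self.teacher(**batch).logits
        student_logits = self.student(**batch).logits

        # Placeholder distillation loss (my actual loss is more involved)
        loss = F.kl_div(
            F.log_softmax(student_logits, dim=-1),
            F.softmax(teacher_logits, dim=-1),
            reduction="batchmean",
        )
        return loss

    def configure_optimizers(self):
        # Only the student's parameters are optimized
        return torch.optim.AdamW(self.student.parameters(), lr=1e-4)
```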

I am trying to leverage FSDP to partition both the teacher and the student weights across 6 GPUs. I am particularly interested in reducing the memory footprint of the teacher, but I am targeting both.
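
The Trainer is configured roughly like this (the strategy string and precision value shown are just what I have been experimenting with, not my exact config):

```python
import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="gpu",
    devices=6,
    strategy="fsdp",        # also tried the DeepSpeed ZeRO-3 strategy
    precision="16-mixed",
)
# trainer.fit(DistillationModule(), train_dataloaders=train_loader)  # train_loader assumed defined
```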

However, my runs with FSDP/deepspeed_zero_3 don't show any improvement in memory usage compared to ddp_sharded/deepspeed_zero_2.

Is there something fundamentally wrong with my approach? Why doesn't FSDP reduce the memory usage of my teacher?