FullyShardedDataParallel no memory decrease

Thanks for the suggestion.

I needed to bump the torch version (now 1.13.0) to access native FSDP. I’m now seeing the following error during the forward pass:

```
RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.FloatTensor) should be the same
```

Reading the error, the input has already been cast to FP16 and moved to the GPU (torch.cuda.HalfTensor), but the model weights are still FP32 (torch.FloatTensor). So it seems like the weights haven’t been moved to the GPU or converted to 16-bit precision, which is what I expected the precision=16 flag to handle. Any thoughts on this?
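For context, here is a minimal standalone sketch (no trainer involved) of what I expected precision=16 to set up under the hood: native FSDP with an explicit MixedPrecision policy so the wrapped weights match the fp16 inputs. The Linear model and sizes are placeholders, and it assumes a torchrun-style launch with NCCL available.

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision

# Assumes a distributed launch (e.g. torchrun) so the env vars are set.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(1024, 1024)  # placeholder for my actual model

# Cast parameters and buffers to fp16 so they match the fp16 inputs,
# mirroring what I expected precision=16 to do.
policy = MixedPrecision(
    param_dtype=torch.float16,
    reduce_dtype=torch.float16,
    buffer_dtype=torch.float16,
)

# device_id moves the wrapped module onto the local GPU during sharding.
model = FSDP(model, mixed_precision=policy, device_id=torch.cuda.current_device())

x = torch.randn(8, 1024, device="cuda", dtype=torch.float16)
out = model(x)  # fp16 CUDA weights now match the fp16 CUDA input
```

When I wrap things manually like this, with param_dtype=torch.float16 and device_id set, the forward pass sees fp16 CUDA weights and the HalfTensor/FloatTensor mismatch above doesn’t occur, so I suspect the trainer isn’t applying an equivalent policy in my setup.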

Brett