How to keep specific layers in float32 when training with float16 mixed precision

Here is the problem: I want to train SDXL with PyTorch Lightning. SDXL is so big that I have to train it in float16 to save GPU memory, but the encoder of SDXL is not stable and is very easily corrupted in float16. The encoder is part of SDXL, so I need to set the encoder of SDXL to float32 and the rest of the network to float16. Unfortunately, I can't find any docs that explain how to do this. Can you give me some advice?

If you’re using mixed precision training then the weights of all layers stay in float32, only the operations carried out by PyTorch are cast to float16 when appropriate. If you need certain operations to be carried out without autocast, you can disable it in a particular section in your forward:

# torch.autocast requires the device type as its first argument; "cuda" is used
# here because the question is about GPU training.
with torch.autocast(device_type="cuda", enabled=False):
    # your code to exclude from autocast here
    # (tensors produced in the surrounding autocast region may be float16;
    # cast them with .float() before using them in this block)
    ...

My problem is the opposite: the weights of almost all layers are float16, and only one or two modules are float32. I have tried the method you mentioned above, but it failed. I can post my code below:

class StableDiffusionXL(pl.LightningModule):
    """Lightning module that trains a ControlNet next to a frozen UNet and VAE.

    ``__init__`` disables gradients on the VAE and the UNet; the ControlNet is
    the only sub-module this class leaves trainable. The VAE is intended to run
    in float32 even when the rest of the model is trained in float16.
    """

    def __init__(self, unet, vae, controlnet, noise_scheduler, lr, use_ema=True):
        """Store the sub-modules and freeze the ones that are not trained.

        Args:
            unet: Denoising network; frozen here, gradient checkpointing and
                xformers attention are enabled on it.
            vae: Autoencoder used to turn images into latents; frozen and cast
                to float32.
            controlnet: Network whose residuals are fed into ``unet``.
            noise_scheduler: Scheduler providing ``alphas_cumprod``,
                ``add_noise`` and ``config.num_train_timesteps``.
            lr: Learning rate, stored as ``self.learning_rate``.
            use_ema: Stored as ``self.use_ema``; not read by the methods of
                this class.
        """
        super().__init__()
        self.unet = unet
        self.vae = vae
        self.controlnet = controlnet
        self.noise_scheduler = noise_scheduler
        self.learning_rate = lr
        self.use_ema = use_ema

        # The VAE is frozen and kept in float32.
        self.vae.requires_grad_(False)
        self.vae = self.vae.to(dtype=torch.float32)

        # The UNet is frozen as well; only the memory-saving switches are set.
        self.unet.requires_grad_(False)
        self.unet.enable_gradient_checkpointing()
        self.unet.enable_xformers_memory_efficient_attention()

        self.controlnet.enable_gradient_checkpointing()
        self.controlnet.enable_xformers_memory_efficient_attention()

    def compute_time_ids(self, original_size, crops_coords_top_left, target_size):
        """Build the ``time_ids`` row for one sample.

        Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids.

        Args:
            original_size, crops_coords_top_left, target_size: Sequences that
                are joined with ``+``. This assumes tuples/lists (sequence
                concatenation), not tensors — TODO confirm against the
                dataloader.

        Returns:
            A tensor with one leading row, on ``self.device``.
        """
        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        add_time_ids = torch.tensor([add_time_ids]).to(self.device)
        return add_time_ids

    @staticmethod
    def _gather_at_timesteps(values, timesteps):
        """Index ``values`` by ``timesteps`` and return it as float32 in that shape."""
        # Move the lookup table to the index tensor's device first: indexing a
        # CPU table with GPU timesteps raises a device-mismatch error.
        picked = values.to(device=timesteps.device)[timesteps].float()
        while picked.dim() < timesteps.dim():
            picked = picked[..., None]
        return picked.expand(timesteps.shape)

    def compute_snr(self, timesteps):
        """Return the signal-to-noise ratio for each entry of ``timesteps``.

        Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849

        Args:
            timesteps: Integer tensor of indices into the scheduler's
                ``alphas_cumprod``.

        Returns:
            A float32 tensor shaped like ``timesteps`` holding
            ``(sqrt(alphas_cumprod) / sqrt(1 - alphas_cumprod)) ** 2``.
        """
        alphas_cumprod = self.noise_scheduler.alphas_cumprod
        alpha = self._gather_at_timesteps(alphas_cumprod ** 0.5, timesteps)
        sigma = self._gather_at_timesteps((1.0 - alphas_cumprod) ** 0.5, timesteps)
        return (alpha / sigma) ** 2

    def _encode_to_latents(self, pixel_values):
        """Encode images with the VAE in float32 and sample from the posterior."""
        # NOTE(review): the float32 cast done in __init__ may not survive
        # trainer setup — DeepSpeed with fp16 is understood to convert the
        # whole wrapped module to half. Re-assert float32 here; this is a no-op
        # when the weights are already float32. Confirm against the strategy
        # actually in use.
        vae_param = next(self.vae.parameters(), None)
        if vae_param is not None and vae_param.dtype != torch.float32:
            self.vae.to(dtype=torch.float32)

        # Standard PyTorch way to switch autocast off for a region; the input
        # is cast explicitly because autocast no longer does it in here.
        with torch.autocast(device_type=self.device.type, enabled=False):
            posterior = self.vae.encode(pixel_values.to(torch.float32)).latent_dist
            return posterior.sample()

    def training_step(self, batch, batch_idx):
        """Run one noise-prediction training step and return the MSE loss.

        Args:
            batch: Mapping with the keys ``pixel_values``,
                ``conditioning_pixel_values``, ``prompt_embeds``,
                ``pooled_prompt_embeds``, ``original_sizes``,
                ``crop_top_lefts`` and ``target_sizes``.
            batch_idx: Index of the batch; unused.

        Returns:
            Mean squared error between the UNet prediction and the sampled
            noise, computed in float32.
        """
        model_input = self._encode_to_latents(batch["pixel_values"])
        if self.trainer.is_global_zero:
            print("before model_input = ", torch.mean(model_input))
        model_input = model_input * self.vae.config.scaling_factor
        # NOTE(review): the latents are float32 from here on. If the
        # ControlNet/UNet weights are float16 and autocast is not active, they
        # may need casting to that dtype — confirm.

        noise = torch.randn_like(model_input).to(self.device)
        timesteps = torch.randint(
            0,
            self.noise_scheduler.config.num_train_timesteps,
            (model_input.shape[0],),
        ).long().to(self.device)
        noisy_model_input = self.noise_scheduler.add_noise(model_input, noise, timesteps)

        add_time_ids = torch.cat(
            [
                self.compute_time_ids(size, crop, target)
                for size, crop, target in zip(
                    batch["original_sizes"],
                    batch["crop_top_lefts"],
                    batch["target_sizes"],
                )
            ]
        ).to(self.device)
        added_cond_kwargs = {
            "time_ids": add_time_ids,
            "text_embeds": batch["pooled_prompt_embeds"],
        }

        down_block_res_samples, mid_block_res_sample = self.controlnet(
            noisy_model_input,
            timesteps,
            encoder_hidden_states=batch["prompt_embeds"],
            added_cond_kwargs=added_cond_kwargs,
            controlnet_cond=batch["conditioning_pixel_values"],
            return_dict=False,
        )

        model_pred = self.unet(
            noisy_model_input,
            timesteps,
            batch["prompt_embeds"],
            added_cond_kwargs=added_cond_kwargs,
            down_block_additional_residuals=down_block_res_samples,
            mid_block_additional_residual=mid_block_res_sample,
        ).sample

        # The loss is taken in float32 regardless of the networks' dtype.
        loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
        self.log("loss", loss, prog_bar=True)
        return loss

I just want to set the VAE layers to float32, but the method didn't work. Do you have any other questions?
My training setting is:

# Trainer configuration used for the run: 4 devices on one node, 16-bit
# precision, DeepSpeed ZeRO stage 2, gradient accumulation over 8 batches.
trainer = pl.Trainer(
    profiler="simple",
    accumulate_grad_batches=8,
    gradient_clip_val=1.0,
    max_epochs=20,
    logger=tensorboard,
    devices=4,
    precision=16,
    callbacks=[checkpoint_callback],
    num_nodes=1,
    strategy="deepspeed_stage_2",
)