Freezing portions of the model during training

How can I freeze a portion of the model during training? I set requires_grad = False for the parameters in the part of the model I want to freeze, but I keep running into DDP unused-parameter errors. Can someone let me know if I am missing a step?

Here is an example with a dummy model:

import torch
from torch.utils.data import DataLoader, Dataset

from lightning.pytorch import LightningModule, Trainer


class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Linear(32, 32)
        self.backbone.requires_grad_(False)  # freeze backbone
        self.head = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.head(self.backbone(x))

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        # select trainable parameters:
        return torch.optim.SGD(self.head.parameters(), lr=0.1)


train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
model = BoringModel()
trainer = Trainer(max_epochs=1, accelerator="cpu", devices=2)  # two devices -> runs with DDP
trainer.fit(model, train_data)

This shows how you can freeze a submodule and train only the rest of the model (with DDP).
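As a quick sanity check (a standalone sketch reusing the RandomDataset and BoringModel defined above; a single device keeps the check simple, and the variable names are just illustrative), you can snapshot the backbone weights and confirm they do not change during training:

check_model = BoringModel()
before = {k: v.clone() for k, v in check_model.backbone.state_dict().items()}

check_trainer = Trainer(max_epochs=1, accelerator="cpu", devices=1)
check_trainer.fit(check_model, DataLoader(RandomDataset(32, 64), batch_size=2))

# the frozen backbone should be bit-for-bit identical after training
after = check_model.backbone.state_dict()
assert all(torch.equal(before[k], after[k]) for k in before)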

Thanks for the response. What I had in mind was more along the lines of freezing after some number of epochs, something like:

class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Linear(32, 32)
        self.head = torch.nn.Linear(32, 2)
        self.freeze_after_n_epochs = 5
        self.backbone_is_frozen = False

    def forward(self, x):
        return self.head(self.backbone(x))

    def on_train_epoch_end(self) -> None:
        if not self.backbone_is_frozen:
            if self.freeze_after_n_epochs == self.current_epoch:
                self.backbone.requires_grad_(False)
                self.backbone_is_frozen = True
        return super().on_train_epoch_end()

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        # pick up whatever is trainable when fit starts (everything here,
        # since the backbone is only frozen later, during training):
        return torch.optim.SGD(
            (p for p in self.parameters() if p.requires_grad), lr=0.1
        )

The main issue I have with this code is that it runs into the unused-parameter error with DDP.

Oh, if you change it during training, then you'll have to set

trainer = Trainer(..., strategy="ddp_find_unused_parameters_true")
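If you prefer the explicit form, the same thing can be expressed with a strategy object (a sketch assuming Lightning 2.x, where DDPStrategy forwards find_unused_parameters to torch's DistributedDataParallel):

from lightning.pytorch.strategies import DDPStrategy

trainer = Trainer(
    accelerator="cpu",
    devices=2,
    strategy=DDPStrategy(find_unused_parameters=True),
)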

Enabling find_unused_parameters normally comes with a performance hit, though. The alternative is to run multiple trainer.fit() calls sequentially as you freeze more of your network:

trainer = Trainer(max_epochs=10)
trainer.fit(...)
# freeze portions of the model and train for more epochs
trainer.fit_loop.max_epochs = 20
trainer.fit(...)
# etc.
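For concreteness, here is a minimal sketch of that staged pattern (illustrative names and epoch counts; it assumes the RandomDataset above, an optimizer that filters parameters by requires_grad, and that bumping trainer.fit_loop.max_epochs between fit calls works as suggested). Because the freezing happens between fit calls, DDP re-wraps the model each time and the default strategy works without find_unused_parameters:

class StagedModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Linear(32, 32)
        self.head = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.head(self.backbone(x))

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        # re-run on every fit call, so it only picks up what is still trainable
        return torch.optim.SGD(
            (p for p in self.parameters() if p.requires_grad), lr=0.1
        )


train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
model = StagedModel()
trainer = Trainer(max_epochs=5, accelerator="cpu", devices=2)

# Stage 1: everything trainable.
trainer.fit(model, train_data)

# Stage 2: freeze the backbone, extend the epoch budget, and fit again.
model.backbone.requires_grad_(False)
trainer.fit_loop.max_epochs = 10
trainer.fit(model, train_data)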

Ah okay. I was hoping to avoid setting find_unused_parameters to true, but that's fine. Thanks for confirming this!