How can I freeze a portion of the model during training? I set `requires_grad = False` for the parameters in the portion of the model I want to freeze, but I keep running into DDP "unused parameters" errors. Can someone let me know if I am missing a step?
Here is an example with a dummy model:
import torch
from torch.utils.data import DataLoader, Dataset
from lightning.pytorch import LightningModule, Trainer
class RandomDataset(Dataset):
    """Fixed in-memory dataset of ``length`` random vectors with ``size`` features each.

    All samples are drawn once at construction, so the same index always
    returns the same tensor for a given instance.
    """

    def __init__(self, size, length):
        # One (length, size) tensor holds every sample; rows are the items.
        self.data = torch.randn(length, size)
        self.len = length

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        # Row ``index`` of the pre-drawn tensor.
        return self.data[index]
class BoringModel(LightningModule):
    """Toy model whose backbone is frozen from construction; only the head trains."""

    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Linear(32, 32)
        # Frozen before training starts, so when the model is wrapped for DDP
        # these parameters are not expected to produce gradients.
        self.backbone.requires_grad_(False)
        self.head = torch.nn.Linear(32, 2)

    def forward(self, x):
        features = self.backbone(x)
        return self.head(features)

    def training_step(self, batch, batch_idx):
        outputs = self(batch)
        loss = outputs.sum()
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        # Hand the optimizer only the trainable (head) parameters.
        trainable_params = self.head.parameters()
        return torch.optim.SGD(trainable_params, lr=0.1)
if __name__ == "__main__":
    # Guard the entry point: spawn-based multi-process launchers re-import this
    # module in each worker process, and without the guard every import would
    # start its own training run.
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    model = BoringModel()
    # Two CPU processes, so Lightning trains with DDP.
    trainer = Trainer(max_epochs=1, accelerator="cpu", devices=2)
    trainer.fit(model, train_data)
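This shows how to freeze a submodule and train only the rest of the model with DDP. The key point is that the backbone is frozen in `__init__`, before training starts, so DDP never expects gradients for those parameters.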
Thanks for the response. What I had in mind is freezing part of the model after a certain number of epochs, something along the lines of:
class BoringModel(LightningModule):
    """Toy model that trains backbone and head, then freezes the backbone.

    After ``freeze_after_n_epochs`` completed epochs, the backbone's parameters
    stop requiring gradients and only the head keeps training.

    NOTE: toggling ``requires_grad`` in the middle of ``fit()`` under DDP needs
    ``Trainer(strategy="ddp_find_unused_parameters_true")``. Otherwise DDP still
    expects gradients for the backbone parameters it saw when it wrapped the
    model, and it errors once they stop receiving them.
    """

    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Linear(32, 32)
        self.head = torch.nn.Linear(32, 2)
        # Number of *completed* epochs after which the backbone is frozen.
        self.freeze_after_n_epochs = 5
        self.backbone_is_frozen = False

    def forward(self, x):
        return self.head(self.backbone(x))

    def _maybe_freeze_backbone(self, completed_epochs):
        """Freeze the backbone once ``completed_epochs`` reaches the threshold."""
        if self.backbone_is_frozen or completed_epochs < self.freeze_after_n_epochs:
            return
        self.backbone.requires_grad_(False)
        self.backbone_is_frozen = True

    def on_train_epoch_start(self) -> None:
        # When an epoch starts, ``current_epoch`` epochs have been completed.
        # Checking here re-applies the freeze after resuming from a checkpoint
        # past the threshold: neither ``requires_grad`` nor ``backbone_is_frozen``
        # is restored from the saved state dict.
        self._maybe_freeze_backbone(self.current_epoch)
        return super().on_train_epoch_start()

    def on_train_epoch_end(self) -> None:
        # ``current_epoch`` is 0-based, so ``current_epoch + 1`` epochs are done.
        self._maybe_freeze_backbone(self.current_epoch + 1)
        return super().on_train_epoch_end()

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        # Include the backbone. An optimizer over only the head would never
        # update the backbone, which would make the scheduled freeze a no-op.
        # With plain SGD (no momentum or weight decay), frozen parameters get no
        # update afterwards: their ``.grad`` is either None (skipped by SGD) or
        # zeroed by ``zero_grad``.
        return torch.optim.SGD(self.parameters(), lr=0.1)
The main issue I have with this code is that, once the backbone is frozen, it runs into DDP's unused-parameters error.
Oh, if you change it during training, then you’ll have to set
# Needed when requires_grad is changed in the middle of fit(). DDP then searches
# for parameters that received no gradient on every step, which costs extra time.
trainer = Trainer(..., strategy="ddp_find_unused_parameters_true")
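This normally results in a performance hit, though: on every iteration DDP has to traverse the autograd graph to find parameters that received no gradient. The alternative is to make several sequential `trainer.fit()` calls, freezing more of your network before each one: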
# ``model`` / ``train_data`` are the LightningModule and DataLoader being trained;
# ``model.backbone`` stands in for whichever submodule you want to freeze.

# Stage 1: train the whole model for the first 10 epochs.
trainer = Trainer(max_epochs=10)
trainer.fit(model, train_data)

# Stage 2: freeze part of the model, then continue training.
# The freeze happens between fit() calls, and each fit() call sets the model up
# again and re-runs configure_optimizers(). So, as suggested in this thread,
# the default DDP strategy (no unused-parameter search) can be kept.
# NOTE(review): confirm against the Lightning version in use.
model.backbone.requires_grad_(False)
# ``Trainer.max_epochs`` is a read-only property; the epoch limit lives on the
# fit loop, and the epoch counter carries over, so this trains epochs 10-19.
trainer.fit_loop.max_epochs = 20
trainer.fit(model, train_data)
etc.
Ah, okay. I was hoping to avoid setting `find_unused_parameters` to `True`, but that’s fine. Thanks for confirming!