How to implement Linear Probing for the first N epochs and then switch to fine-tuning?

Hello, I’m wondering how I should implement the training technique from the paper Fine-Tuning can Distort Pretrained Features and Underperform Out-of-Distribution. Essentially, the authors describe freezing all model weights except the softmax layer at the beginning of training and then switching to full fine-tuning. I’m working with BERT-like models from transformers. Also, how could I make this switch to fine-tuning gradual (say, unfreeze one top transformer layer every epoch)?

Hi @Konrad, you can use the BaseFinetuning callback to achieve this.

You will need to override the freeze_before_training and finetune_function methods: freeze everything except the classification head in freeze_before_training, and in finetune_function unfreeze one top layer at the start of each epoch once the linear-probing phase is over. Let me know if you face any issues while implementing it.
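Here is a minimal sketch of such a callback, under a few assumptions: the LightningModule exposes the Hugging Face model as `pl_module.model`, with the backbone at `model.bert` (so the encoder layers are `model.bert.encoder.layer` and the head is `model.classifier`), and you are on Lightning 2.x, where finetune_function receives `(pl_module, epoch, optimizer)`; on 1.x it also receives an `optimizer_idx` argument. Adjust the attribute names and the learning rate to your setup.

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import BaseFinetuning


class LinearProbeThenFinetune(BaseFinetuning):
    """Linear-probe for `probe_epochs`, then unfreeze one top encoder layer per epoch."""

    def __init__(self, probe_epochs: int = 2, unfreeze_lr: float = 1e-5):
        super().__init__()
        self.probe_epochs = probe_epochs
        self.unfreeze_lr = unfreeze_lr

    def freeze_before_training(self, pl_module):
        # Assumption: the HF backbone lives at pl_module.model.bert and the
        # classification head at pl_module.model.classifier. Freezing only the
        # backbone leaves the head trainable (linear probing).
        self.freeze(pl_module.model.bert)

    def finetune_function(self, pl_module, epoch, optimizer):
        # During the linear-probing phase, keep the backbone frozen.
        if epoch < self.probe_epochs:
            return
        # Afterwards, unfreeze one encoder layer per epoch, starting from the top,
        # and add its parameters to the optimizer as a new param group.
        layers = pl_module.model.bert.encoder.layer
        layer_idx = len(layers) - 1 - (epoch - self.probe_epochs)
        if layer_idx >= 0:
            self.unfreeze_and_add_param_group(
                modules=layers[layer_idx],
                optimizer=optimizer,
                lr=self.unfreeze_lr,
            )
```

You would then pass it to the Trainer like any other callback, e.g. `pl.Trainer(callbacks=[LinearProbeThenFinetune(probe_epochs=2)], ...)`.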

Also, we are moving support and community discussion from this forum to GitHub Discussions, as it makes questions more discoverable and keeps all the knowledge in one place!