Difference between BertForSequenceClassification and Bert + nn.Linear

I was trying to fine-tune BERT for a continuous outcome (ranging from 0 to 400). I noticed a big difference in validation loss during training between loading the pre-trained BERT with BertForSequenceClassification and loading it with BertModel plus writing the nn.Linear head, dropout, and loss myself.

Specifically, using BertForSequenceClassification seems to work fine, with validation loss decreasing in each epoch. But if I load the same pre-trained BERT using BertModel.from_pretrained and write the linear layer, dropout, and loss myself, the validation loss quickly stagnates.

The data, seed, hardware, and most of the code are the same; the only difference is in the snippets below:

    # inside __init__:
    self.config.num_labels = self.hparams.num_labels
    self.model = BertForSequenceClassification.from_pretrained(self.hparams.bert_path, config=self.config)
    self.tokenizer = BertTokenizer.from_pretrained(self.hparams.bert_path)

def forward(self, **inputs):
    return self.model(**inputs)

def training_step(self, batch, batch_idx):
    outputs = self(**batch)
    # BertForSequenceClassification computes the loss internally when the batch
    # contains "labels"; the first element of the output is that loss
    loss = outputs[0]
    return loss

And

    # inside __init__:
    self.config.num_labels = self.hparams.num_labels
    self.model = BertModel.from_pretrained(self.hparams.bert_path, config=self.config)
    self.tokenizer = BertTokenizer.from_pretrained(self.hparams.bert_path)
    self.drop = nn.Dropout(p=self.hparams.dropout)
    self.out = nn.Linear(self.model.config.hidden_size, self.hparams.num_labels)
    self.loss = nn.MSELoss()

def forward(self, input_ids, att_mask):
    res = self.model(input_ids=input_ids, attention_mask=att_mask)
    # pooler_output: the [CLS] representation passed through BERT's pooler
    dropout_output = self.drop(res.pooler_output)
    out = self.out(dropout_output)
    return out

def training_step(self, batch, batch_idx):
    outputs = self(input_ids=batch["input_ids"], att_mask=batch["attention_mask"])
    # BertModel returns no loss, so MSE is computed manually here
    loss = self.loss(outputs.view(-1), batch['labels'].view(-1))
    return loss

I’m lost as to why this is the case. I especially want the second approach to work so I can build on this further. Any advice is greatly appreciated!

The first thing I’d suggest is to check whether the backbone has the same weights in both cases. It looks correct to me, but verify it once anyway. Second, check whether the dropout rate is the same in both setups. Also, maybe you need to check this too.
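
Here’s a rough sketch of how you could run those first two checks. It assumes a standard BERT checkpoint (replace "bert-base-uncased" with whatever you pass as bert_path) and uses the fact that BertForSequenceClassification keeps its encoder under the .bert attribute:

    import torch
    from transformers import BertForSequenceClassification, BertModel

    bert_path = "bert-base-uncased"  # placeholder: use the same checkpoint you fine-tune from

    seq_clf = BertForSequenceClassification.from_pretrained(bert_path, num_labels=1)
    backbone = BertModel.from_pretrained(bert_path)

    # Check 1: the classification model keeps its encoder under `.bert`,
    # so its state dict should match the bare BertModel key for key.
    clf_state = seq_clf.bert.state_dict()
    bare_state = backbone.state_dict()
    mismatched = [
        name for name, tensor in bare_state.items()
        if name not in clf_state or not torch.equal(tensor, clf_state[name])
    ]
    print("mismatched backbone parameters:", mismatched or "none")

    # Check 2: the dropout rate actually used by the built-in classification head
    print("head dropout:", seq_clf.dropout.p)
    print("hidden_dropout_prob:", seq_clf.config.hidden_dropout_prob)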

I’ve been working on this recently and compared the performance of BertForSequenceClassification vs my own classifier head.

Indeed, dropout and weight initialization seem to be the only major differences.
In Hugging Face Transformers, unless a classifier-specific dropout rate is specified, the backbone’s hidden_dropout_prob is used (0.1 for the standard BERT configs), so I apply that. Weight initialization also differs from the default PyTorch initialization, so I take the initialization function implemented in the pretrained model class (BertPreTrainedModel) and run it over my classifier as well.

Here’s how I do it:

        # Load model and add classification head
        self.model = AutoModel.from_pretrained(huggingface_model)
        self.classifier = nn.Linear(self.model.config.hidden_size, num_labels)

        # Init classifier weights according to initialization rules of model
        self.model._init_weights(self.classifier)

        # Apply dropout rate of model
        dropout_prob = self.model.config.hidden_dropout_prob
        log.info(f"Dropout probability of classifier set to {dropout_prob}.")
        self.dropout = nn.Dropout(dropout_prob)
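
For completeness, here’s a minimal sketch of the matching forward pass (the pooler_output field assumes a BERT-style backbone; exact names in my module may differ):

    def forward(self, input_ids, attention_mask):
        # Pooled [CLS] representation from the backbone, then dropout + classifier head
        res = self.model(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.dropout(res.pooler_output)
        return self.classifier(pooled)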

Those two things helped me get performance parity between the two models!
These are my quick conclusions from testing on the MNLI task, but I may have missed stuff. More comments are welcome, for the sake of future search results leading to this page.

More code on implementing Hugging Face models in PyTorch Lightning can be found in my GitHub template repository!