In the __init__() method of your LightningModule, instead of
self.bert = BertModel.from_pretrained(BERT_MODEL_NAME, return_dict=True)
self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes)
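use BertForSequenceClassification, which already puts a linear classification head on top of BERT, so the separate self.classifier layer is no longer needed: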
self.bert = BertForSequenceClassification.from_pretrained(BERT_MODEL_NAME, num_labels=YOUR_NUM_OF_CLASSES)
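Keep in mind that this change affects other methods of your pl.LightningModule class, so you will need to update them yourself. The model's output now carries the logits directly (plus a loss when you pass labels), so forward() and the training/validation steps should use those instead of calling self.classifier.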
Good luck!