Validation loss does not decrease while training

Hey everyone!
I am quite new to deep learning and PyTorch Lightning, and I am having issues with my loss values while trying to pre-train a BERT model for recommendation from scratch.

I followed the tutorial "Build Your Own Movie Recommender System Using BERT4Rec" by Youness Mansar (Towards Data Science) and used its GitHub code as my starting point for BERT4Rec.

Here is the relevant snippet from my module implementation:

def masked_ce(y_pred, y_true, mask):
    """Average cross-entropy over the positions selected by ``mask``.

    ``y_pred`` and ``y_true`` go straight into ``F.cross_entropy`` with
    ``reduction="none"``, so every position keeps its own loss value.
    ``mask`` weights those values: a boolean mask keeps or drops them. The
    weighted total is divided by ``mask.sum()`` plus a 1e-8 epsilon. The
    epsilon prevents division by zero and makes the result 0 when the mask
    selects nothing.
    """
    token_losses = F.cross_entropy(y_pred, y_true, reduction="none")
    weighted_total = (token_losses * mask).sum()
    selected_count = mask.sum() + 1e-8
    return weighted_total / selected_count

def masked_accuracy(y_pred: torch.Tensor, y_true: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Return the fraction of masked positions whose arg-max prediction is correct.

    Args:
        y_pred: Scores with the class dimension at dim 1; the arg-max over dim 1
            is taken as the predicted class.
        y_true: Target class indices, same shape as ``y_pred`` without dim 1.
        mask: Which positions to score. Non-zero / ``True`` entries are
            selected; bool, integer and float masks are all accepted.

    Returns:
        A 0-d float64 tensor. It is NaN when ``mask`` selects no positions,
        because the mean of an empty tensor is NaN.
    """
    _, predicted = torch.max(y_pred, dim=1)
    # torch.masked_select requires a boolean mask; casting also lets 0/1 masks work.
    keep = mask.bool()
    y_true = torch.masked_select(y_true, keep)
    predicted = torch.masked_select(predicted, keep)
    acc = (y_true == predicted).double().mean()
    # The original returned the undefined name ``ACC`` (NameError at call time).
    return acc

class Recommender(pl.LightningModule):
    def __init__(self, vocabulary_size, features=128,
                 mask=1, dropout=0.4, lr=5e-5, iterations=None):
        """Build a transformer encoder over item-id sequences (BERT4Rec-style).

        Args:
            vocabulary_size: Number of item ids. It sizes the item embedding
                table and the output projection.
            features: Model width (``d_model``). It must be divisible by the
                4 attention heads used below.
            mask: Item id compared against inputs to find masked positions
                (presumably stored as ``self.mask`` in the elided setup).
            dropout: Dropout probability for the transformer layers
                (presumably stored as ``self.dropout`` in the elided setup).
            lr: Adam learning rate (presumably stored as ``self.lr``).
            iterations: Optional list. ``None`` (the default) becomes a fresh
                empty list, so instances never share one default list object.
        """
        super().__init__()
        # A literal ``[]`` default would be one list shared by every instance.
        if iterations is None:
            iterations = []
        # NOTE: the original snippet elides setup code here. It must set at
        # least self.vocabulary_size and self.dropout (both read below); other
        # methods also read self.mask and self.lr.
        ...
        # The default lr is written 5e-5; the original ``5-e5`` parses as ``5 - e5``.
        self.item_embeddings = torch.nn.Embedding(self.vocabulary_size, embedding_dim=features)

        # Learned absolute position embeddings: position ids >= 512 are out of range.
        self.input_pos_embedding = torch.nn.Embedding(512, embedding_dim=features)

        encoder_layer = nn.TransformerEncoderLayer(d_model=features, nhead=4, dropout=self.dropout)

        self.encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=6)
        # Projects each position's encoding to one score per vocabulary item.
        self.linear_out = Linear(features, self.vocabulary_size)

def encode_src(self, src_items):
    """Embed item ids, add learned position embeddings, and run the encoder.

    ``src_items`` is read as (batch, sequence): dims 0 and 1 of its embedding
    are the batch size and sequence length. Position ids ``0..seq_len-1`` are
    looked up once in ``self.input_pos_embedding`` and broadcast across the
    batch. The encoder is fed sequence-first, (seq, batch, features), so the
    tensor is transposed before the call and transposed back afterwards.
    """
    token_vectors = self.item_embeddings(src_items)

    seq_len = token_vectors.size(1)
    positions = torch.arange(seq_len, device=token_vectors.device)
    # Shape (1, seq, features): one table shared by every row of the batch.
    position_vectors = self.input_pos_embedding(positions).unsqueeze(0)

    summed = token_vectors + position_vectors

    encoded = self.encoder(summed.transpose(0, 1))
    return encoded.transpose(0, 1)


    def forward(self, src_items):
        """Return per-position output scores for a batch of item-id sequences.

        The sequences are encoded with ``self.encode_src``, and every position
        is projected through ``self.linear_out``. In this file that layer maps
        ``features`` to ``vocabulary_size``.
        """
        return self.linear_out(self.encode_src(src_items))

    def training_step(self, batch, batch_idx):
        """Compute and log masked loss/accuracy for one training batch.

        ``batch`` is unpacked as ``(src_items, y_true)``. Logits and targets are
        flattened so that each sequence position becomes one row. Only positions
        whose input id equals ``self.mask`` are scored. The loss is returned,
        as Lightning's ``training_step`` contract expects.

        NOTE(review): if a batch contains no position equal to ``self.mask``,
        ``masked_ce`` divides a zero sum by its 1e-8 epsilon and yields 0. That
        may explain epochs logged as exactly 0.0; worth checking the masking rate.
        """
        inputs, targets = batch

        logits = self(inputs)
        vocab_size = logits.size(2)

        flat_logits = logits.view(-1, vocab_size)
        flat_targets = targets.view(-1)
        is_masked = inputs.view(-1) == self.mask

        loss = masked_ce(y_pred=flat_logits, y_true=flat_targets, mask=is_masked)
        accuracy = masked_accuracy(y_pred=flat_logits, y_true=flat_targets, mask=is_masked)

        for metric_name, metric_value in (("train_loss", loss), ("train_accuracy", accuracy)):
            self.log(metric_name, metric_value)
        return loss

	….
    def configure_optimizers(self):
        """Create Adam over all parameters plus a ReduceLROnPlateau scheduler.

        The scheduler multiplies the learning rate by 0.1 after 10 scheduler
        steps without improvement in the monitored ``"valid_loss"`` metric.
        NOTE(review): that key must be logged somewhere (e.g. a validation step
        not shown in this snippet) for the scheduler to be able to step.
        """
        adam = torch.optim.Adam(self.parameters(), lr=self.lr)
        plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(adam, patience=10, factor=0.1)
        return dict(optimizer=adam, lr_scheduler=plateau, monitor="valid_loss")

For some reason, after a while my loss stays at roughly the same value (around 2.1) for the rest of training. Here is the average loss for each of the 300 epochs:

[5.215897862021033, 4.307419357834397, 3.9978328323937036, 3.8327397668922507, 3.687536907625628, 3.5862434299381167, 3.504597392764774, 3.4267588714221575, 3.3451359359113066, 3.2565978139012426, 3.2273226185245916, 3.140159515527872, 3.1139863249775885, 3.072347399947402, 3.0406936243609026, 3.001059137903773, 2.9714541030717685, 2.932119946102719, 2.92338597309124, 2.887609561403712, 2.89176990379681, 2.8512023840103304, 2.800608117897828, 2.7806637058028945, 2.790072974857983, 2.7700434235839158, 2.7709330624646253, 2.737586958272321, 2.7506634561387866, 2.7047642879896574, 2.7033177772919097, 2.6592260257856504, 2.6874104410081774, 2.666762097342475, 2.6342785286831782, 2.629130595797175, 2.6112309293346003, 2.6169400425167293, 2.5992956018304683, 2.5875380239687167, 2.598613882446671, 2.5834985335429272, 2.5814251636002994, 2.55975524483023, 2.563504737597686, 2.5425207213954524, 2.5441495048868523, 2.527069270789802, 2.519392321894954, 2.4823296286083676, 2.498437447471542, 2.502717134472844, 2.497532544909297, 2.489024817943573, 2.481112331062466, 2.4673362061664745, 2.4719012540739937, 2.460711818736595, 2.462394949731168, 2.44435056694993, 2.448073567630531, 2.4373490108742013, 2.41340720516306, 2.40083835719226, 2.430739034045566, 2.404747154738929, 2.4160519766735957, 2.4129087116505885, 2.4081945191990504, 2.395370520807959, 2.3460567462909685, 2.4075533818076917, 2.3823164269492194, 2.36776337197593, 2.3636932604544394, 2.380185321167305, 2.371870071501345, 2.375798768348045, 2.372823623147932, 2.3598847922620116, 2.3406515125636465, 2.355634041764476, 2.3676569249417567, 2.345680000128092, 2.3391598698732494, 2.315167697998616, 2.333602963267146, 2.303073968316938, 2.327711789338319, 2.3321030771171487, 2.346811938691545, 2.3228134136419514, 2.324048904088644, 2.2958460228221194, 2.2882123300143786, 2.2898601495229207, 2.2825498343468666, 2.2881406057584988, 2.282753063095463, 2.2780174417896673, 2.3061474098696246, 2.2788862180900766, 
2.276809148244314, 2.2573344752237245, 2.267140531444454, 2.2677663033073014, 2.2408642875062332, 2.260455063811771, 2.2410880836161287, 2.257640224617642, 2.2716933290163674, 2.2446437362913376, 2.2394013909606247, 2.247728812205302, 2.2538036140235693, 2.237980705541414, 2.2370640824626276, 2.243911793819061, 2.235663570918598, 2.2143911284966036, 2.2189845510789223, 2.218917491796377, 2.2362850300363593, 2.22802360112722, 2.2069116861851246, 2.2263007872216813, 2.1958916097789913, 2.2124622908082454, 2.215156395156104, 2.218396941343466, 2.2063985850001955, 2.2218682895432242, 2.188224109980437, 2.186440182997061, 2.2213237939534842, 2.1583791249268525, 2.1540346879500882, 2.1667329162448734, 2.162871088888552, 2.1509411728179253, 2.1487952094894274, 2.1423775988059477, 2.134084624929113, 2.122743466236928, 2.1457578765976058, 2.153763327870641, 2.1373268459532953, 2.123572857888253, 2.1082677820900657, 2.1465463696776688, 2.121045095008892, 2.114945897170612, 2.1025079012633086, 2.143361220488677, 2.1197627566240214, 2.1106052048153825, 2.145451051217538, 2.127855628997356, 2.118844061881095, 2.133004295157718, 2.123980550615637, 2.107732644548884, 2.1369195053646632, 2.1227938454072395, 2.107776931217602, 2.106330737039014, 2.1286828901555324, 2.1234672562376753, 2.1153388228025043, 2.1260185034067423, 2.126335740387738, 2.103722700664589, 2.103477485067732, 2.087210615356644, 2.1357034185269215, 2.120789026116227, 2.097256830385378, 2.081185018753743, 2.0951963311081774, 2.1229392157063947, 2.112820274419374, 2.0973021252615913, 2.1015769997993865, 2.078238507648846, 2.1254136346362613, 2.122731619530373, 2.106998563767434, 2.088252352612155, 2.0864079854151867, 2.1210617853237226, 2.1171441963604383, 2.078529192401363, 2.0930584185832255, 2.0759896338284314, 2.116819376224751, 2.081939611170027, 2.119148053206481, 2.1262888721636943, 2.108645480614644, 2.119510283400943, 2.0977860467211977, 2.1132550431682064, 2.0850224250907057, 2.09263404950246, 
2.1319653992061025, 2.095875710815758, 2.107226548252163, 2.1082355374091857, 2.1068308096867545, 2.1049049644618183, 2.082865047383237, 2.104338508170169, 2.0929228865706526, 2.080949502127307, 2.085981863874334, 2.1092890691351487, 2.1200141368566214, 2.102557414048188, 2.114553865608391, 2.0953348426966816, 2.1051940550436608, 2.0884465265262113, 2.095959615898323, 2.109497363144929, 2.10603609481254, 2.0999864392929726, 2.107128706660953, 2.084229455099211, 2.097175305849081, 2.1018672287762463, 2.100449399964826, 2.077514022439569, 2.1022495352350794, 2.069558752907647, 2.0984696185027993, 2.0981300791223965, 2.0977061387893556, 2.123796395771019, 2.092895509721758, 2.0983876811491475, 2.137945756539926, 2.1082705245003686, 2.0934668199436084, 2.095478994471652, 2.0837738592106776, 2.0994501587626218, 2.093280213820684, 2.0901179890911856, 2.0806633436882698, 2.105752343052739, 2.12362516451407, 2.0749484598338306, 2.089124313584558, 2.10382049810421, 2.085073526616808, 2.0793647033435567, 2.093597667711275, 2.111878449196095, 2.1113540356581635, 2.0885868385150745, 2.09748313329122, 2.099132583425329, 2.113959271091599, 2.091684911642466, 2.0930197149425656, 2.1119412023264603, 2.109747784572082, 2.1128484507700107, 2.1090998791598223, 2.1068240908889084, 2.085279980161646, 2.092537217133038, 2.098047769045806, 2.1106533031086543, 2.1191082527568272, 2.0863734186590612, 2.081073121086613, 2.1127054432968237, 2.106303176543376, 2.116069818223203, 2.099261497830724, 2.0956528060906403, 2.10383954378697, 2.0963755933730095, 2.105407002750221, 2.1021434911736496, 2.0790102963333017, 2.1066373411241592, 2.095192236823959, 2.0882216465246453, 2.1060060731641523, 2.0814747804397338, 2.0748307990120933, 2.0975484999569805, 2.0896632927316086, 2.122882520174002, 2.1060680718393296, 2.10889205035266, 2.090368301302821]

300 epochs

{'test_accuracy': 0.49178826808929443, 'test_loss': 2.17501163482666}

What could be the reason for it?

I am working with around 32k sequences, each at least 7 items long. Could it be that this is simply not enough data to train from scratch?

Some of the responses for similar issues suggest playing around with the hyperparameters, so I did try a few things:

  • Since the model could be stuck in a local minimum, I thought the learning rate might need to change so that it can escape (is that right?). I tried changing the learning rate from 1e-4 to 5e-5, but it didn’t help much.
  • To check whether training works as expected, I tried to overfit the model on a very small dataset (10 examples); the average loss for the first 20 epochs looks as follows:
    [0.0, 0.0, 0.0, 1.2672607898712158, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.9458177089691162, 0.0, 0.0, 0.0, 1.6752853393554688, 0.0, 0.0]

Any ideas or suggestions would be much appreciated.