Hey everyone!
I am quite new to deep learning and PyTorch Lightning, and I am having trouble with my loss values while trying to pre-train a BERT model for recommendation from scratch.
I followed this BERT4Rec tutorial, https://towardsdatascience.com/build-your-own-movie-recommender-system-using-bert4rec-92e4e34938c5, and used its GitHub code as my starting point.
Here is the relevant snippet from my module implementation:
class Recommender(pl.LightningModule):
    """BERT4Rec-style masked-item model.

    Item ids are embedded, learned absolute positional embeddings are added,
    the result goes through a 6-layer Transformer encoder, and a linear layer
    projects every position back to logits over the item vocabulary.
    """

    def __init__(self, vocabulary_size, features=128,
                 mask=1, dropout=0.4, lr=5e-5, iterations=None):
        """Build the embeddings, encoder and output projection.

        Args:
            vocabulary_size: number of item ids; size of the embedding table
                and of the output layer.
            features: embedding / model width (``d_model`` of the encoder).
            mask: token id marking masked input positions; only those
                positions are scored in ``training_step``.
            dropout: dropout probability of each encoder layer.
            lr: Adam learning rate. Written as ``5e-5``; ``5-e5`` would parse
                as ``5 - e5`` and fail with NameError.
            iterations: optional list; ``None`` gives each instance its own
                fresh empty list (a ``[]`` default would be shared).
        """
        super().__init__()
        # NOTE(review): this part was elided ("...") in the original excerpt;
        # these are the attributes the methods below actually read.
        self.vocabulary_size = vocabulary_size
        self.mask = mask
        self.dropout = dropout
        self.lr = lr
        self.iterations = [] if iterations is None else iterations

        self.item_embeddings = torch.nn.Embedding(self.vocabulary_size, embedding_dim=features)
        # Learned absolute positions: at most 512 positions are supported.
        self.input_pos_embedding = torch.nn.Embedding(512, embedding_dim=features)
        encoder_layer = torch.nn.TransformerEncoderLayer(
            d_model=features, nhead=4, dropout=self.dropout
        )
        self.encoder = torch.nn.TransformerEncoder(encoder_layer, num_layers=6)
        self.linear_out = torch.nn.Linear(features, self.vocabulary_size)

    def encode_src(self, src_items):
        """Embed item ids, add positional embeddings and run the encoder.

        Args:
            src_items: integer tensor of item ids indexed as (batch, seq).

        Returns:
            Encoded tensor of shape (batch, seq, features).

        Raises:
            ValueError: if seq is longer than the number of positional
                embeddings.
        """
        seq_len = src_items.size(1)
        max_len = self.input_pos_embedding.num_embeddings
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the {max_len} positional embeddings"
            )
        embedded = self.item_embeddings(src_items)
        positions = torch.arange(seq_len, device=src_items.device)
        # (seq, features) broadcasts over the batch dimension, so no .repeat needed.
        embedded = embedded + self.input_pos_embedding(positions)
        # The encoder layer uses the default batch_first=False: (seq, batch, features).
        encoded = self.encoder(embedded.permute(1, 0, 2))
        return encoded.permute(1, 0, 2)

    def forward(self, src_items):
        """Return per-position item logits of shape (batch, seq, vocabulary_size)."""
        return self.linear_out(self.encode_src(src_items))

    def training_step(self, batch, batch_idx):
        """Masked-item training step; loss and accuracy use only masked positions.

        Args:
            batch: ``(src_items, y_true)`` pair of (batch, seq) id tensors.
            batch_idx: index of the batch (unused).

        Returns:
            The masked cross-entropy loss.
        """
        src_items, y_true = batch
        logits = self(src_items)
        logits = logits.reshape(-1, logits.size(-1))
        y_true = y_true.reshape(-1)
        # Score only the positions whose *input* token is the mask id.
        mask = src_items.reshape(-1) == self.mask
        # NOTE(review): what masked_ce returns for a batch with no masked
        # positions depends on its implementation (not shown) - if it returns
        # 0, that would explain 0.0 epoch losses on very small datasets.
        loss = masked_ce(y_pred=logits, y_true=y_true, mask=mask)
        accuracy = masked_accuracy(y_pred=logits, y_true=y_true, mask=mask)
        self.log("train_loss", loss)
        self.log("train_accuracy", accuracy)
        return loss

    # ... (further methods omitted from this excerpt; presumably including a
    # validation step that logs "valid_loss", which the scheduler monitors)

    def configure_optimizers(self):
        """Adam plus ReduceLROnPlateau (lr x0.1 after 10 checks without improvement)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, patience=10, factor=0.1
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": scheduler,
            # This metric must be logged somewhere (e.g. a validation step),
            # otherwise Lightning cannot step the plateau scheduler.
            "monitor": "valid_loss",
        }
For some reason, after a while my loss stays at roughly the same value (around 2.1) for the rest of training. Here are the average losses per epoch over 300 epochs:
[5.215897862021033, 4.307419357834397, 3.9978328323937036, 3.8327397668922507, 3.687536907625628, 3.5862434299381167, 3.504597392764774, 3.4267588714221575, 3.3451359359113066, 3.2565978139012426, 3.2273226185245916, 3.140159515527872, 3.1139863249775885, 3.072347399947402, 3.0406936243609026, 3.001059137903773, 2.9714541030717685, 2.932119946102719, 2.92338597309124, 2.887609561403712, 2.89176990379681, 2.8512023840103304, 2.800608117897828, 2.7806637058028945, 2.790072974857983, 2.7700434235839158, 2.7709330624646253, 2.737586958272321, 2.7506634561387866, 2.7047642879896574, 2.7033177772919097, 2.6592260257856504, 2.6874104410081774, 2.666762097342475, 2.6342785286831782, 2.629130595797175, 2.6112309293346003, 2.6169400425167293, 2.5992956018304683, 2.5875380239687167, 2.598613882446671, 2.5834985335429272, 2.5814251636002994, 2.55975524483023, 2.563504737597686, 2.5425207213954524, 2.5441495048868523, 2.527069270789802, 2.519392321894954, 2.4823296286083676, 2.498437447471542, 2.502717134472844, 2.497532544909297, 2.489024817943573, 2.481112331062466, 2.4673362061664745, 2.4719012540739937, 2.460711818736595, 2.462394949731168, 2.44435056694993, 2.448073567630531, 2.4373490108742013, 2.41340720516306, 2.40083835719226, 2.430739034045566, 2.404747154738929, 2.4160519766735957, 2.4129087116505885, 2.4081945191990504, 2.395370520807959, 2.3460567462909685, 2.4075533818076917, 2.3823164269492194, 2.36776337197593, 2.3636932604544394, 2.380185321167305, 2.371870071501345, 2.375798768348045, 2.372823623147932, 2.3598847922620116, 2.3406515125636465, 2.355634041764476, 2.3676569249417567, 2.345680000128092, 2.3391598698732494, 2.315167697998616, 2.333602963267146, 2.303073968316938, 2.327711789338319, 2.3321030771171487, 2.346811938691545, 2.3228134136419514, 2.324048904088644, 2.2958460228221194, 2.2882123300143786, 2.2898601495229207, 2.2825498343468666, 2.2881406057584988, 2.282753063095463, 2.2780174417896673, 2.3061474098696246, 2.2788862180900766, 
2.276809148244314, 2.2573344752237245, 2.267140531444454, 2.2677663033073014, 2.2408642875062332, 2.260455063811771, 2.2410880836161287, 2.257640224617642, 2.2716933290163674, 2.2446437362913376, 2.2394013909606247, 2.247728812205302, 2.2538036140235693, 2.237980705541414, 2.2370640824626276, 2.243911793819061, 2.235663570918598, 2.2143911284966036, 2.2189845510789223, 2.218917491796377, 2.2362850300363593, 2.22802360112722, 2.2069116861851246, 2.2263007872216813, 2.1958916097789913, 2.2124622908082454, 2.215156395156104, 2.218396941343466, 2.2063985850001955, 2.2218682895432242, 2.188224109980437, 2.186440182997061, 2.2213237939534842, 2.1583791249268525, 2.1540346879500882, 2.1667329162448734, 2.162871088888552, 2.1509411728179253, 2.1487952094894274, 2.1423775988059477, 2.134084624929113, 2.122743466236928, 2.1457578765976058, 2.153763327870641, 2.1373268459532953, 2.123572857888253, 2.1082677820900657, 2.1465463696776688, 2.121045095008892, 2.114945897170612, 2.1025079012633086, 2.143361220488677, 2.1197627566240214, 2.1106052048153825, 2.145451051217538, 2.127855628997356, 2.118844061881095, 2.133004295157718, 2.123980550615637, 2.107732644548884, 2.1369195053646632, 2.1227938454072395, 2.107776931217602, 2.106330737039014, 2.1286828901555324, 2.1234672562376753, 2.1153388228025043, 2.1260185034067423, 2.126335740387738, 2.103722700664589, 2.103477485067732, 2.087210615356644, 2.1357034185269215, 2.120789026116227, 2.097256830385378, 2.081185018753743, 2.0951963311081774, 2.1229392157063947, 2.112820274419374, 2.0973021252615913, 2.1015769997993865, 2.078238507648846, 2.1254136346362613, 2.122731619530373, 2.106998563767434, 2.088252352612155, 2.0864079854151867, 2.1210617853237226, 2.1171441963604383, 2.078529192401363, 2.0930584185832255, 2.0759896338284314, 2.116819376224751, 2.081939611170027, 2.119148053206481, 2.1262888721636943, 2.108645480614644, 2.119510283400943, 2.0977860467211977, 2.1132550431682064, 2.0850224250907057, 2.09263404950246, 
2.1319653992061025, 2.095875710815758, 2.107226548252163, 2.1082355374091857, 2.1068308096867545, 2.1049049644618183, 2.082865047383237, 2.104338508170169, 2.0929228865706526, 2.080949502127307, 2.085981863874334, 2.1092890691351487, 2.1200141368566214, 2.102557414048188, 2.114553865608391, 2.0953348426966816, 2.1051940550436608, 2.0884465265262113, 2.095959615898323, 2.109497363144929, 2.10603609481254, 2.0999864392929726, 2.107128706660953, 2.084229455099211, 2.097175305849081, 2.1018672287762463, 2.100449399964826, 2.077514022439569, 2.1022495352350794, 2.069558752907647, 2.0984696185027993, 2.0981300791223965, 2.0977061387893556, 2.123796395771019, 2.092895509721758, 2.0983876811491475, 2.137945756539926, 2.1082705245003686, 2.0934668199436084, 2.095478994471652, 2.0837738592106776, 2.0994501587626218, 2.093280213820684, 2.0901179890911856, 2.0806633436882698, 2.105752343052739, 2.12362516451407, 2.0749484598338306, 2.089124313584558, 2.10382049810421, 2.085073526616808, 2.0793647033435567, 2.093597667711275, 2.111878449196095, 2.1113540356581635, 2.0885868385150745, 2.09748313329122, 2.099132583425329, 2.113959271091599, 2.091684911642466, 2.0930197149425656, 2.1119412023264603, 2.109747784572082, 2.1128484507700107, 2.1090998791598223, 2.1068240908889084, 2.085279980161646, 2.092537217133038, 2.098047769045806, 2.1106533031086543, 2.1191082527568272, 2.0863734186590612, 2.081073121086613, 2.1127054432968237, 2.106303176543376, 2.116069818223203, 2.099261497830724, 2.0956528060906403, 2.10383954378697, 2.0963755933730095, 2.105407002750221, 2.1021434911736496, 2.0790102963333017, 2.1066373411241592, 2.095192236823959, 2.0882216465246453, 2.1060060731641523, 2.0814747804397338, 2.0748307990120933, 2.0975484999569805, 2.0896632927316086, 2.122882520174002, 2.1060680718393296, 2.10889205035266, 2.090368301302821]
#015 #033[A[5.215897862021033, 4.307419357834397, 3.9978328323937036, 3.8327397668922507, 3.687536907625628, 3.5862434299381167, 3.504597392764774, 3.4267588714221575, 3.3451359359113066, 3.2565978139012426, 3.2273226185245916, 3.140159515527872, 3.1139863249775885, 3.072347399947402, 3.0406936243609026, 3.001059137903773, 2.9714541030717685, 2.932119946102719, 2.92338597309124, 2.887609561403712, 2.89176990379681, 2.8512023840103304, 2.800608117897828, 2.7806637058028945, 2.790072974857983, 2.7700434235839158, 2.7709330624646253, 2.737586958272321, 2.7506634561387866, 2.7047642879896574, 2.7033177772919097, 2.6592260257856504, 2.6874104410081774, 2.666762097342475, 2.6342785286831782, 2.629130595797175, 2.6112309293346003, 2.6169400425167293, 2.5992956018304683, 2.5875380239687167, 2.598613882446671, 2.5834985335429272, 2.5814251636002994, 2.55975524483023, 2.563504737597686, 2.5425207213954524, 2.5441495048868523, 2.527069270789802, 2.519392321894954, 2.4823296286083676, 2.498437447471542, 2.502717134472844, 2.497532544909297, 2.489024817943573, 2.481112331062466, 2.4673362061664745, 2.4719012540739937, 2.460711818736595, 2.462394949731168, 2.44435056694993, 2.448073567630531, 2.4373490108742013, 2.41340720516306, 2.40083835719226, 2.430739034045566, 2.404747154738929, 2.4160519766735957, 2.4129087116505885, 2.4081945191990504, 2.395370520807959, 2.3460567462909685, 2.4075533818076917, 2.3823164269492194, 2.36776337197593, 2.3636932604544394, 2.380185321167305, 2.371870071501345, 2.375798768348045, 2.372823623147932, 2.3598847922620116, 2.3406515125636465, 2.355634041764476, 2.3676569249417567, 2.345680000128092, 2.3391598698732494, 2.315167697998616, 2.333602963267146, 2.303073968316938, 2.327711789338319, 2.3321030771171487, 2.346811938691545, 2.3228134136419514, 2.324048904088644, 2.2958460228221194, 2.2882123300143786, 2.2898601495229207, 2.2825498343468666, 2.2881406057584988, 2.282753063095463, 2.2780174417896673, 2.3061474098696246, 2.2788862180900766, 
2.276809148244314, 2.2573344752237245, 2.267140531444454, 2.2677663033073014, 2.2408642875062332, 2.260455063811771, 2.2410880836161287, 2.257640224617642, 2.2716933290163674, 2.2446437362913376, 2.2394013909606247, 2.247728812205302, 2.2538036140235693, 2.237980705541414, 2.2370640824626276, 2.243911793819061, 2.235663570918598, 2.2143911284966036, 2.2189845510789223, 2.218917491796377, 2.2362850300363593, 2.22802360112722, 2.2069116861851246, 2.2263007872216813, 2.1958916097789913, 2.2124622908082454, 2.215156395156104, 2.218396941343466, 2.2063985850001955, 2.2218682895432242, 2.188224109980437, 2.186440182997061, 2.2213237939534842, 2.1583791249268525, 2.1540346879500882, 2.1667329162448734, 2.162871088888552, 2.1509411728179253, 2.1487952094894274, 2.1423775988059477, 2.134084624929113, 2.122743466236928, 2.1457578765976058, 2.153763327870641, 2.1373268459532953, 2.123572857888253, 2.1082677820900657, 2.1465463696776688, 2.121045095008892, 2.114945897170612, 2.1025079012633086, 2.143361220488677, 2.1197627566240214, 2.1106052048153825, 2.145451051217538, 2.127855628997356, 2.118844061881095, 2.133004295157718, 2.123980550615637, 2.107732644548884, 2.1369195053646632, 2.1227938454072395, 2.107776931217602, 2.106330737039014, 2.1286828901555324, 2.1234672562376753, 2.1153388228025043, 2.1260185034067423, 2.126335740387738, 2.103722700664589, 2.103477485067732, 2.087210615356644, 2.1357034185269215, 2.120789026116227, 2.097256830385378, 2.081185018753743, 2.0951963311081774, 2.1229392157063947, 2.112820274419374, 2.0973021252615913, 2.1015769997993865, 2.078238507648846, 2.1254136346362613, 2.122731619530373, 2.106998563767434, 2.088252352612155, 2.0864079854151867, 2.1210617853237226, 2.1171441963604383, 2.078529192401363, 2.0930584185832255, 2.0759896338284314, 2.116819376224751, 2.081939611170027, 2.119148053206481, 2.1262888721636943, 2.108645480614644, 2.119510283400943, 2.0977860467211977, 2.1132550431682064, 2.0850224250907057, 2.09263404950246, 
2.1319653992061025, 2.095875710815758, 2.107226548252163, 2.1082355374091857, 2.1068308096867545, 2.1049049644618183, 2.082865047383237, 2.104338508170169, 2.0929228865706526, 2.080949502127307, 2.085981863874334, 2.1092890691351487, 2.1200141368566214, 2.102557414048188, 2.114553865608391, 2.0953348426966816, 2.1051940550436608, 2.0884465265262113, 2.095959615898323, 2.109497363144929, 2.10603609481254, 2.0999864392929726, 2.107128706660953, 2.084229455099211, 2.097175305849081, 2.1018672287762463, 2.100449399964826, 2.077514022439569, 2.1022495352350794, 2.069558752907647, 2.0984696185027993, 2.0981300791223965, 2.0977061387893556, 2.123796395771019, 2.092895509721758, 2.0983876811491475, 2.137945756539926, 2.1082705245003686, 2.0934668199436084, 2.095478994471652, 2.0837738592106776, 2.0994501587626218, 2.093280213820684, 2.0901179890911856, 2.0806633436882698, 2.105752343052739, 2.12362516451407, 2.0749484598338306, 2.089124313584558, 2.10382049810421, 2.085073526616808, 2.0793647033435567, 2.093597667711275, 2.111878449196095, 2.1113540356581635, 2.0885868385150745, 2.09748313329122, 2.099132583425329, 2.113959271091599, 2.091684911642466, 2.0930197149425656, 2.1119412023264603, 2.109747784572082, 2.1128484507700107, 2.1090998791598223, 2.1068240908889084, 2.085279980161646, 2.092537217133038, 2.098047769045806, 2.1106533031086543, 2.1191082527568272, 2.0863734186590612, 2.081073121086613, 2.1127054432968237, 2.106303176543376, 2.116069818223203, 2.099261497830724, 2.0956528060906403, 2.10383954378697, 2.0963755933730095, 2.105407002750221, 2.1021434911736496, 2.0790102963333017, 2.1066373411241592, 2.095192236823959, 2.0882216465246453, 2.1060060731641523, 2.0814747804397338, 2.0748307990120933, 2.0975484999569805, 2.0896632927316086, 2.122882520174002, 2.1060680718393296, 2.10889205035266, 2.090368301302821]
What could be causing this?
I am working with around 32k sequences, each at least 7 items long. Could it be that this is simply not enough data to train from scratch?
Some answers to similar questions suggest tuning the hyperparameters, so I tried a few things:
- Since the model could be stuck in a local minimum, maybe the learning rate needs to be changed for it to escape (is that right?). I tried changing my learning rate from 1e-4 to 5e-5, but it didn’t help much.
- To check whether training works as expected, I tried to overfit my model on a very small dataset (10 samples). The average loss for the first 20 epochs looks as follows:
[0.0, 0.0, 0.0, 1.2672607898712158, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.9458177089691162, 0.0, 0.0, 0.0, 1.6752853393554688, 0.0, 0.0]
Any ideas or suggestions would be greatly appreciated.