I am trying to get early stopping to work in my PyTorch Lightning code.
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
import pytorch_lightning.callbacks as plc
import pytorch_lightning.metrics.functional as plf

class FFNPL(pl.LightningModule):
    def __init__(self, prm):
        super().__init__()
        self.model = FFN(prm["model"])   # FFN is my own feed-forward network class
        self.lr = prm["lr"]

    def forward(self, x):
        return self.model(x)

    def step(self, batch):
        # Shared logic for train/valid/test: forward pass + BCE-with-logits loss
        X, y = batch
        yhat = self(X)
        loss = F.binary_cross_entropy_with_logits(yhat, y)
        return y, yhat, loss

    def training_step(self, batch, batch_idx):
        y, yhat, loss = self.step(batch)
        return dict(loss=loss)

    def validation_step(self, batch, batch_idx):
        y, yhat, loss = self.step(batch)
        return dict(val_loss=loss, y=y, yhat=yhat)

    def validation_epoch_end(self, outputs):
        # Aggregate over the whole validation set, then log val_ap so the
        # EarlyStopping callback has something to monitor.
        y = torch.cat([x["y"] for x in outputs])
        yhat = torch.sigmoid(torch.cat([x["yhat"] for x in outputs]))
        auc = plf.classification.auroc(yhat, y, pos_label=1)
        ap = plf.average_precision(yhat, y, pos_label=1)
        self.log("val_ap", ap)
        print(f" Epoch {self.current_epoch} val_auc: {auc:.2%}, val_ap: {ap:.2%}")

    def test_step(self, batch, batch_idx):
        y, yhat, loss = self.step(batch)
        return dict(y=y, yhat=yhat)

    def test_epoch_end(self, outputs):
        y = torch.cat([x["y"] for x in outputs])
        yhat = torch.sigmoid(torch.cat([x["yhat"] for x in outputs]))
        auc = plf.classification.auroc(yhat, y, pos_label=1)
        ap = plf.average_precision(yhat, y, pos_label=1)
        print(f"Test: auc: {auc:.2%}, ap: {ap:.2%}")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)
pl.seed_everything(42)

prm = {"model": {"input_dim": smat_train.shape[1],
                 "hidden_dim": (256, 256, 256, 64),
                 "dropout_prob": (.5, .5, .5, .5)},
       "lr": 1e-4}
model = FFNPL(prm)

trainer = pl.Trainer(auto_scale_batch_size="power", gpus=1, deterministic=True,
                     max_epochs=20, progress_bar_refresh_rate=0,
                     callbacks=[plc.early_stopping.EarlyStopping(monitor="val_ap", patience=3)])
trainer.fit(model=model, train_dataloader=trainDL, val_dataloaders=validDL)
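The callback above only sets monitor and patience; everything else is left at the defaults. Spelled out with what I understand those defaults to be (this is my reading of the EarlyStopping docs for my version, so the mode/min_delta values below are assumptions, not something I set):

    early_stop = plc.early_stopping.EarlyStopping(
        monitor="val_ap",   # the metric I log in validation_epoch_end
        patience=3,         # epochs without improvement before stopping
        min_delta=0.0,      # assumed default: minimum change that counts as improvement
        mode="min",         # assumed default: a *decrease* counts as improvement
    )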
When I comment out EarlyStopping, I get the following output (the first Epoch 0 line comes from the validation sanity check):
| Name | Type | Params
-------------------------------
0 | model | FFN | 5.4 M
-------------------------------
5.4 M Trainable params
0 Non-trainable params
5.4 M Total params
Epoch 0 val_auc: 47.87%, val_ap: 48.39%
Epoch 0 val_auc: 74.66%, val_ap: 74.06%
Epoch 1 val_auc: 94.57%, val_ap: 94.72%
Epoch 2 val_auc: 96.22%, val_ap: 96.32%
Epoch 3 val_auc: 96.77%, val_ap: 96.89% <==
Epoch 4 val_auc: 96.97%, val_ap: 97.12%
Epoch 5 val_auc: 97.19%, val_ap: 97.35%
Epoch 6 val_auc: 97.26%, val_ap: 97.42%
Epoch 7 val_auc: 97.33%, val_ap: 97.49%
Epoch 8 val_auc: 97.35%, val_ap: 97.51%
Epoch 9 val_auc: 97.36%, val_ap: 97.52%
Epoch 10 val_auc: 97.41%, val_ap: 97.58%
Epoch 11 val_auc: 97.41%, val_ap: 97.57%
Epoch 12 val_auc: 97.40%, val_ap: 97.58%
Epoch 13 val_auc: 97.40%, val_ap: 97.59%
Epoch 14 val_auc: 97.44%, val_ap: 97.62%
Epoch 15 val_auc: 97.42%, val_ap: 97.61%
Epoch 16 val_auc: 97.40%, val_ap: 97.61%
Epoch 17 val_auc: 97.43%, val_ap: 97.63%
Epoch 18 val_auc: 97.44%, val_ap: 97.64%
Epoch 19 val_auc: 97.39%, val_ap: 97.61%
When EarlyStopping is turned on (as shown above), training stops at epoch 3 (the line marked <== in the log). What am I messing up?
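While writing this up, I tried to reproduce the stopping point by hand. This is my mental model of the patience bookkeeping, not Lightning's actual implementation; the per-epoch val_ap values are copied from the log above (sanity-check line excluded):

    # Hand simulation of the early-stopping counter, assuming mode="min"
    scores = [0.7406, 0.9472, 0.9632, 0.9689]   # val_ap for epochs 0-3
    best, wait, patience = float("inf"), 0, 3
    for epoch, ap in enumerate(scores):
        if ap < best:    # under "min", only a decrease resets the counter
            best, wait = ap, 0
        else:
            wait += 1    # my rising val_ap would count as "no improvement"
        if wait >= patience:
            print(f"early stop triggered at epoch {epoch}")   # prints: epoch 3
            break

If that mental model is right, it would stop exactly where my run stops, but I am not sure this is actually what is happening.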