I am optimizing with the loss, but I want to use AUROC from torchmetrics as the monitored quantity for early stopping. The code I have written is as follows:
def validation_step(self, batch, batch_idx):
    x = batch['src']
    y = batch['label']
    mask = batch['mask']
    x = self.base_model(x, mask)
    # Project to one logit per position, then pool over the sequence dimension
    x = self.linear(x).mean(dim=1).squeeze(1)
    loss = F.binary_cross_entropy_with_logits(input=x, target=y)
    return {'loss': loss, 'preds': x, 'target': y}
def validation_step_end(self, outputs):
    # Update the metric state with this batch's probabilities and integer targets
    self.valid_auc(torch.sigmoid(outputs['preds']), outputs['target'].int())
    self.log('valid_auc', self.valid_auc)
    self.log('valid_loss', outputs['loss'])
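For context, the metric is registered in __init__ as a module attribute. A minimal sketch of that part of my module (the class name and constructor arguments are placeholders; the task='binary' argument assumes torchmetrics >= 0.11, older versions take no such argument):

import torch
import torch.nn.functional as F
import torchmetrics
import pytorch_lightning as pl

class Classifier(pl.LightningModule):  # placeholder name
    def __init__(self, base_model, d_model):
        super().__init__()
        self.base_model = base_model
        self.linear = torch.nn.Linear(d_model, 1)
        # Registering the metric as an attribute lets Lightning move it to the
        # right device; passing the metric object itself to self.log is what
        # lets Lightning handle the epoch-level compute and reset.
        self.valid_auc = torchmetrics.AUROC(task='binary')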
The early stopping callback looks like:
early_stopping_cb = EarlyStopping(
    monitor='valid_auc',
    min_delta=args.min_delta,
    patience=args.patience,
    mode='max',
    strict=True)
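and it is passed to the Trainer in the usual way (other Trainer arguments omitted; model stands for my LightningModule instance):

trainer = pl.Trainer(callbacks=[early_stopping_cb])
trainer.fit(model)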
My question is whether AUROC is being aggregated correctly across mini-batches. Is there anything else I need to do? And is there a good way to verify that AUROC is being computed over the entire validation set, rather than as an average of per-batch AUROC values?
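The only check I could come up with is to buffer the raw predictions myself and compare a single AUROC computed over the whole set against the logged value. A rough sketch of the hooks I would add to my LightningModule (the functional auroc call again assumes torchmetrics >= 0.11):

from torchmetrics.functional import auroc

def on_validation_epoch_start(self):
    # Fresh buffers each epoch for a manual, whole-set AUROC
    self._all_preds, self._all_targets = [], []

def validation_step_end(self, outputs):
    # ... existing metric update and logging ...
    self._all_preds.append(torch.sigmoid(outputs['preds']).detach().cpu())
    self._all_targets.append(outputs['target'].int().detach().cpu())

def on_validation_epoch_end(self):
    preds = torch.cat(self._all_preds)
    targets = torch.cat(self._all_targets)
    # If this matches the logged 'valid_auc', the metric is being computed
    # over the entire validation set rather than averaged per batch.
    manual_auc = auroc(preds, targets, task='binary')
    print('manual AUROC over full validation set:', manual_auc.item())

If the two numbers agree, I would take that as confirmation, but I am not sure whether this comparison is sound or whether there is a more idiomatic way.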