Discrepancy between val and test metrics

Hello there!

I’m using PyTorch Lightning v1.1.0 and I am getting inconsistent results from a custom Accuracy metric on the val and test splits, even though both splits contain the same data.

The code boils down to 3 parts:

  1. The custom Accuracy metric:
import torch
from typing import Any, Callable, Optional

from pytorch_lightning.metrics import Metric
# Note: _input_format_classification is an internal PyTorch Lightning helper;
# its import path differs between Lightning versions.


class MyAccuracy(Metric):
    def __init__(
            self,
            threshold: float = 0.5,
            compute_on_step: bool = True,
            dist_sync_on_step=False,
            process_group: Optional[Any] = None,
            dist_sync_fn: Callable = None,
    ):
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )

        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

        self.threshold = threshold

    def update(self, logits: torch.Tensor, target: torch.Tensor):
        preds, target = _input_format_classification(logits, target, self.threshold)
        assert preds.shape == target.shape

        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self):
        return 100. * self.correct / self.total
  2. The Lightning Module class:
class LitClassifier(pl.LightningModule):
    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.cfg = cfg
        self.best_val_acc = torch.tensor(0.)
        self.train_accuracy = MyAccuracy()
        self.val_accuracy = MyAccuracy()
        self.test_accuracy = MyAccuracy()


    def loss(self, outputs: torch.Tensor, targets: torch.Tensor):
        return F.cross_entropy(outputs, targets)

    def training_step(self, batch, batch_idx):
        x, y = batch

        out = self(x)

        loss = self.loss(out, y)

        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc_s", self.train_accuracy(out, y))

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch

        out = self(x)

        val_loss = self.loss(out, y)
        self.log("val_loss", val_loss)

        results = {"val_acc": self.val_accuracy(out, y)}
        return results

    def test_step(self, batch, batch_idx):
        x, y = batch

        out = self(x)

        results = {"test_acc": self.test_accuracy(out, y)}

        return results

    def training_epoch_end(self, outputs):
        self.log("train_acc_e", self.train_accuracy.compute(), prog_bar=True)


    def validation_epoch_end(self, outputs):

        val_acc = self.val_accuracy.compute()

        if self.best_val_acc < val_acc:
            self.best_val_acc = val_acc
            logger.debug(f"New best val acc: {self.best_val_acc:.2f}")

        self.log("val_acc", val_acc, prog_bar=True)
        self.log("best_val_acc", self.best_val_acc, prog_bar=True)

    def test_epoch_end(self, outputs):
        self.log("test_acc_all", self.test_accuracy.compute())
  3. The Lightning Data Module:
class DataModule(pl.LightningDataModule):
    def __init__(self, cfg: DictConfig, trfs=None):
        super().__init__()
        self.name = cfg.datasets.name
        self.class_name = cfg.datasets.class_name
        self.root = cfg.datasets.path
        self.loader_params = cfg.data.loader_params
        # Transforms
        means, stds = MEANS[self.name], STDS[self.name]
        logger.debug(f"hard coded means: {means}, stds: {stds}")
        if trfs is not None:
            self.train_transforms = trfs
            self.test_transforms = trfs
        else:
            self.train_transforms = transforms.Compose([
                transforms.RandomCrop(**cfg.datasets.trfs_params.random_crop), # standard DA
                transforms.RandomHorizontalFlip(), # standard DA
                transforms.ToTensor(),
                transforms.Normalize(means, stds),
            ])
            self.test_transforms = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(means, stds),
            ])

    def __repr__(self):
        msg = f"Dataset: {self.name} ({self.class_name}) @ {self.root}"
        return msg

    def prepare_data(self):
        # download data if needed
        if self.name in ["CIFAR10", "CIFAR100"]:
            load_obj(self.class_name)(
                root=self.root,
                train=True,
                download=True
            )
            load_obj(self.class_name)(
                root=self.root,
                train=False,
                download=True
            )
        elif self.name in ["STL10"]:
            load_obj(self.class_name)(
                root=self.root,
                split="train",
                download=True
            )
            load_obj(self.class_name)(
                root=self.root,
                split="test",
                download=True
            )

    def setup(self, stage=None):
        # Assign train/val/test for dataloaders
        if self.name in ["CIFAR10", "CIFAR100"]:
            train_data = load_obj(self.class_name)(
                root=self.root,
                train=True,
                download=False,
                transform=self.train_transforms,
            )
            test_data = load_obj(self.class_name)(
                root=self.root,
                train=False,
                download=False,
                transform=self.test_transforms,
            )
        elif self.name in ["STL10"]:

            train_data = load_obj(self.class_name)(
                root=self.root,
                split="train",
                download=False,
                transform=self.train_transforms,
            )
            test_data = load_obj(self.class_name)(
                root=self.root,
                split="test",
                download=False,
                transform=self.test_transforms,
            )

        self.train_data = train_data
        self.val_data = test_data
        self.test_data = test_data

    def train_dataloader(self):
        return DataLoader(self.train_data, shuffle=True, **self.loader_params)

    def val_dataloader(self):
        return DataLoader(self.val_data, shuffle=False, **self.loader_params)

    def test_dataloader(self):
        return DataLoader(self.test_data, shuffle=False, **self.loader_params)

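For completeness, the three pieces are wired together roughly like this (a simplified sketch, not my exact script; the ModelCheckpoint settings and Trainer arguments shown here are illustrative):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

def run(cfg):
    dm = DataModule(cfg)
    model = LitClassifier(cfg)

    # "best" checkpoint selected on validation accuracy (illustrative settings)
    checkpoint_cb = ModelCheckpoint(monitor="val_acc", mode="max")

    trainer = pl.Trainer(
        max_epochs=30,  # illustrative value
        callbacks=[checkpoint_cb],
    )
    trainer.fit(model, datamodule=dm)
    trainer.test(ckpt_path="best")
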
As you can see in the DataModule's setup(), the validation and test data are the same, so I expect the accuracy metric to output the same value when calling:

trainer.test(ckpt_path="best")

However I get the following results:

[{'val_loss': 0.4769740104675293, 'val_acc': 84.53125, 'best_val_acc': 84.53125, 'test_acc_all': 85.625}]                                                                                     
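
For reference, the metric itself only accumulates a count of correct predictions and a total count, so two instances fed identical batches have to agree; here is a quick standalone check with dummy 10-class data (assuming MyAccuracy as defined above is importable):

import torch

val_metric, test_metric = MyAccuracy(), MyAccuracy()
for _ in range(10):
    logits = torch.randn(100, 10)          # fake batch of 10-class logits
    target = torch.randint(0, 10, (100,))  # fake labels
    val_metric(logits, target)             # __call__ updates the running counts
    test_metric(logits, target)
assert val_metric.compute() == test_metric.compute()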

Questions:

  • What did I miss here? Is my validation accuracy metric computed over the whole validation set, as I expect it to be?
  • Does test_acc_all reflect the accuracy over the whole test set (equal to the validation set here), as I expect it to be?

Yes, since you are using the same dataset for validation and testing, the accuracy for test and validation should be identical.

I wonder if extra data is being added to the metric during the validation sanity check step. Is this still the case over more than one epoch?
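
One quick way to rule that out (just a sketch, I have not tried it against your code): skip the sanity check altogether, or clear the metric's state when each validation epoch starts.

import pytorch_lightning as pl

# Option 1: no sanity-check batches at all
trainer = pl.Trainer(num_sanity_val_steps=0)  # plus your other Trainer flags

# Option 2: reset the metric at the start of every validation epoch (inside LitClassifier)
def on_validation_epoch_start(self):
    self.val_accuracy.reset()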

Hey @teddy, thanks for your answer.
Actually, yes, it is still the case over more than one epoch.
Note that this issue does not occur when I use fast_dev_run=True, in which case both metrics are identical. However, the discrepancy does appear when I use overfit_batches=N together with min_epochs=M where M > 1.
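
For concreteness, the two Trainer configurations behave differently for me (the N and M values below are placeholders, not the ones I actually used):

import pytorch_lightning as pl

N, M = 10, 5  # placeholders only

# Metrics agree with a single-batch smoke test
trainer = pl.Trainer(fast_dev_run=True)

# Metrics diverge when overfitting on N batches for at least M > 1 epochs
trainer = pl.Trainer(overfit_batches=N, min_epochs=M)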

Do you have a colab I could use to reproduce this? Would be great so I can debug :slight_smile:

Hey @teddy, here it is: Google Colab.
I put together a minimal working example that reproduces the issue.
I enabled viewing in the sharing settings. Let me know whether you can run the notebook on your own after making a copy of it.
Thanks for your kind help!

Note that I had been running the code with batch_size=128, which does not divide the number of test samples (10,000) evenly. I tried the same code with batch_size=100 instead, but the issue remains.
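
In principle the ragged last batch should not matter for MyAccuracy, since it sums raw counts instead of averaging per-batch accuracies; a small worked example with toy numbers (not taken from my runs) shows the difference between the two aggregation strategies:

# Toy numbers: one full 128-sample batch plus a 16-sample remainder batch
correct_per_batch = [90, 5]
size_per_batch = [128, 16]

# Count-based aggregation (what MyAccuracy does): unaffected by the ragged batch
count_based = 100.0 * sum(correct_per_batch) / sum(size_per_batch)   # ~65.97

# Naive mean of per-batch accuracies: over-weights the small remainder batch
mean_of_batch_accs = sum(
    100.0 * c / n for c, n in zip(correct_per_batch, size_per_batch)
) / len(size_per_batch)                                              # ~50.78
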
Hope that helps!

I downgraded the PyTorch Lightning requirement to 1.0.7 and tweaked my custom Accuracy metric a bit to fit the older API, but the issue still occurs.
@teddy I will be spending some time on it today. Let me know if you have time to take a look at it :slight_smile:

Solved!
I was misled by the fact that, when overfit_batches=N is used, val_dataloader() and test_dataloader() are simply not the same: val_dataloader() yields the N training batches, while test_dataloader() yields the whole test (= validation) set.
I have now trained and tested on the full dataset, and the results are consistent.
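
In hindsight, a quick way to spot this kind of mismatch is to compare how many samples each metric actually accumulated; here is a sketch of the check I would now add inside LitClassifier:

def validation_epoch_end(self, outputs):
    logger.debug(f"val samples seen: {int(self.val_accuracy.total)}")
    # ... existing val_acc logging as above ...

def test_epoch_end(self, outputs):
    logger.debug(f"test samples seen: {int(self.test_accuracy.total)}")
    # ... existing test_acc_all logging as above ...

With overfit_batches, the two counts differ (N x batch_size for validation vs. the full 10,000 for test), which would have pointed at the dataloader mismatch right away.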

Thanks for following up, @azouaoui! I was trying to reproduce this but was having trouble, so I’m glad you figured it out :). Trustworthy, well-tested metrics have been the first priority since the inception of the metrics package, and they always will be.