Hello there!
I’m using PyTorch Lightning v1.1.0 and I can’t get consistent results from a custom Accuracy metric on the val and test splits, even though both splits contain exactly the same data.
The code boils down to three parts:
- The custom Accuracy metric:
class MyAccuracy(Metric):
    def __init__(
        self,
        threshold: float = 0.5,
        compute_on_step: bool = True,
        dist_sync_on_step=False,
        process_group: Optional[Any] = None,
        dist_sync_fn: Callable = None,
    ):
        super().__init__(
            compute_on_step=compute_on_step,
            dist_sync_on_step=dist_sync_on_step,
            process_group=process_group,
            dist_sync_fn=dist_sync_fn,
        )
        # Accumulated across batches; summed across processes when running distributed
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
        self.threshold = threshold

    def update(self, logits: torch.Tensor, target: torch.Tensor):
        # Turn logits into hard predictions, then count matches
        preds, target = _input_format_classification(logits, target, self.threshold)
        assert preds.shape == target.shape
        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self):
        # Accuracy in percent over everything accumulated so far
        return 100. * self.correct / self.total
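For reference, this is how I understand the metric behaving in isolation (a quick sketch with dummy tensors; shapes and class count are arbitrary):

metric = MyAccuracy()
logits = torch.randn(8, 10)           # 8 samples, 10 classes
targets = torch.randint(0, 10, (8,))  # integer class labels
batch_acc = metric(logits, targets)   # forward: accuracy of this batch only, but it also accumulates correct/total
epoch_acc = metric.compute()          # accuracy over everything accumulated since the last reset()
metric.reset()                        # clear the state before reusing the metric on another epoch/split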
- The Lightning Module class:
class LitClassifier(pl.LightningModule):
    def __init__(self, cfg: DictConfig):
        super().__init__()
        self.cfg = cfg
        self.best_val_acc = torch.tensor(0.)
        # One metric instance per split so their internal states never mix
        self.train_accuracy = MyAccuracy()
        self.val_accuracy = MyAccuracy()
        self.test_accuracy = MyAccuracy()

    def loss(self, outputs: torch.Tensor, targets: torch.Tensor):
        return F.cross_entropy(outputs, targets)

    def training_step(self, batch, batch_idx):
        x, y = batch
        out = self(x)
        loss = self.loss(out, y)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc_s", self.train_accuracy(out, y))
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        out = self(x)
        val_loss = self.loss(out, y)
        self.log("val_loss", val_loss)
        results = {"val_acc": self.val_accuracy(out, y)}
        return results

    def test_step(self, batch, batch_idx):
        x, y = batch
        out = self(x)
        results = {"test_acc": self.test_accuracy(out, y)}
        return results

    def training_epoch_end(self, outputs):
        self.log("train_acc_e", self.train_accuracy.compute(), prog_bar=True)

    def validation_epoch_end(self, outputs):
        val_acc = self.val_accuracy.compute()
        if self.best_val_acc < val_acc:
            self.best_val_acc = val_acc
            logger.debug(f"New best val acc: {self.best_val_acc:.2f}")
        self.log("val_acc", val_acc, prog_bar=True)
        self.log("best_val_acc", self.best_val_acc, prog_bar=True)

    def test_epoch_end(self, outputs):
        self.log("test_acc_all", self.test_accuracy.compute())
- The Lightning DataModule:
class DataModule(pl.LightningDataModule):
    def __init__(self, cfg: DictConfig, trfs=None):
        super().__init__()
        self.name = cfg.datasets.name
        self.class_name = cfg.datasets.class_name
        self.root = cfg.datasets.path
        self.loader_params = cfg.data.loader_params
        # Transforms
        means, stds = MEANS[self.name], STDS[self.name]
        logger.debug(f"hard coded means: {means}, stds: {stds}")
        if trfs is not None:
            self.train_transforms = trfs
            self.test_transforms = trfs
        else:
            self.train_transforms = transforms.Compose([
                transforms.RandomCrop(**cfg.datasets.trfs_params.random_crop),  # standard DA
                transforms.RandomHorizontalFlip(),  # standard DA
                transforms.ToTensor(),
                transforms.Normalize(means, stds),
            ])
            self.test_transforms = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(means, stds),
            ])

    def __repr__(self):
        msg = f"Dataset: {self.name} ({self.class_name}) @ {self.root}"
        return msg

    def prepare_data(self):
        # Download data if needed
        if self.name in ["CIFAR10", "CIFAR100"]:
            load_obj(self.class_name)(
                root=self.root,
                train=True,
                download=True,
            )
            load_obj(self.class_name)(
                root=self.root,
                train=False,
                download=True,
            )
        elif self.name in ["STL10"]:
            load_obj(self.class_name)(
                root=self.root,
                split="train",
                download=True,
            )
            load_obj(self.class_name)(
                root=self.root,
                split="test",
                download=True,
            )

    def setup(self, stage=None):
        # Assign train/val/test datasets for the dataloaders
        if self.name in ["CIFAR10", "CIFAR100"]:
            train_data = load_obj(self.class_name)(
                root=self.root,
                train=True,
                download=False,
                transform=self.train_transforms,
            )
            test_data = load_obj(self.class_name)(
                root=self.root,
                train=False,
                download=False,
                transform=self.test_transforms,
            )
        elif self.name in ["STL10"]:
            train_data = load_obj(self.class_name)(
                root=self.root,
                split="train",
                download=False,
                transform=self.train_transforms,
            )
            test_data = load_obj(self.class_name)(
                root=self.root,
                split="test",
                download=False,
                transform=self.test_transforms,
            )
        self.train_data = train_data
        # val and test deliberately share the same dataset
        self.val_data = test_data
        self.test_data = test_data

    def train_dataloader(self):
        return DataLoader(self.train_data, shuffle=True, **self.loader_params)

    def val_dataloader(self):
        return DataLoader(self.val_data, shuffle=False, **self.loader_params)

    def test_dataloader(self):
        return DataLoader(self.test_data, shuffle=False, **self.loader_params)
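For completeness, the training/evaluation driver is essentially the following (a simplified sketch; the ModelCheckpoint settings and Trainer flags here stand in for my actual Hydra-driven setup):

from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_cb = ModelCheckpoint(monitor="val_acc", mode="max", save_top_k=1)
datamodule = DataModule(cfg)
model = LitClassifier(cfg)
trainer = pl.Trainer(max_epochs=cfg.trainer.max_epochs, callbacks=[checkpoint_cb])  # config key is illustrative
trainer.fit(model, datamodule=datamodule)
trainer.test(ckpt_path="best")  # reloads the best checkpoint and runs test_step / test_epoch_end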
As you can see, the validation and test data are identical, so I expect the accuracy metric to report the same value when calling:
trainer.test(ckpt_path="best")
However, I get the following results:
[{'val_loss': 0.4769740104675293, 'val_acc': 84.53125, 'best_val_acc': 84.53125, 'test_acc_all': 85.625}]
Questions:
- What am I missing here? Is my validation accuracy computed over the whole validation set, as I expect it to be?
- Does test_acc_all reflect the accuracy over the whole test set (which here equals the validation set), as I expect it to?