I'm new to PyTorch Lightning (pl). As a study project, I wrote a demo LightningModule with a small feed-forward network (FFN) that classifies FashionMNIST images. Here is my LightningModule:
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
from torchvision import transforms


class DemoLightningModule(pl.LightningModule):
    def __init__(self, train_batch_size: int, eval_test_batch_size: int, optimizer_type: str) -> None:
        super().__init__()
        self.train_batch_size = train_batch_size
        self.eval_test_batch_size = eval_test_batch_size
        self.optimizer_type = optimizer_type
        # simple FFN: flatten the 28x28 image, one hidden layer, 10 output classes
        self.network = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10))
        self.dataset = torchvision.datasets.FashionMNIST('../Others/ComputerVision/data/FashionMNIST',
                                                         transform=transforms.Compose([transforms.ToTensor()]))

    def training_step(self, batch, batch_idx):
        net_results = self.network(batch[0])
        train_loss = nn.functional.cross_entropy(net_results, batch[1])
        # log both the per-step and the epoch-aggregated loss to the progress bar
        self.log('train_loss', train_loss, prog_bar=True, on_epoch=True, on_step=True)
        return train_loss

    def validation_step(self, batch, batch_idx):
        net_preds = self.network(batch[0])
        net_preds = torch.max(net_preds, dim=1).indices
        # per-batch accuracy in percent
        this_accuracy = (net_preds == batch[1]).sum() / batch[1].size(0) * 100
        self.log('val_acc', this_accuracy, prog_bar=True, on_epoch=True, on_step=True)
        return this_accuracy
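For completeness: the optimizer and dataloader hooks are omitted above because they aren't central to the question. If you want to reproduce the setup, the remaining hooks of the class look roughly like this (a simplified sketch, not my exact code; in particular the optimizer_type handling is reduced to plain Adam):

    def configure_optimizers(self):
        # simplified here to plain Adam; imagine this branches on self.optimizer_type
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.dataset, batch_size=self.train_batch_size, shuffle=True)

    def val_dataloader(self):
        # demo only: validation reuses the same FashionMNIST data
        return torch.utils.data.DataLoader(self.dataset, batch_size=self.eval_test_batch_size)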
I expected the validation progress bar to display val_acc_step after each validation batch and val_acc_epoch once validation finishes, but these metrics only show up in the main (training) progress bar, as follows:
Epoch 2: 100%|████████████████████████████████████████████| 1875/1875 [00:38<00:00, 48.25it/s, train_loss_step=0.291, val_acc_step=81.20, val_acc_epoch=86.70, train_loss_epoch=0.383]
Validation DataLoader 0: 39%|█████████████████████████████████████████████▉ | 736/1875 [00:10<00:15, 72.52it/s]
Since this is just a demo project, validation is fast here. In practical use, however (e.g. training a large NLP model, where validation involves generation), the validation pass often takes a long time, and being able to track validation metrics in real time really matters. Is there any way to make those validation metrics also appear in the validation progress bar?
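Just to illustrate the kind of thing I'm after (not something I claim is correct): I imagine a custom progress bar along these lines could push the logged metrics onto the validation bar itself. This is an untested sketch assuming recent Lightning hook signatures, and I'd rather know the intended way:

from pytorch_lightning.callbacks import TQDMProgressBar

class ValMetricsProgressBar(TQDMProgressBar):
    # untested idea: mirror the logged prog_bar metrics onto the validation bar after every val batch
    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
        self.val_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))

# hypothetical usage: pass the custom bar to the Trainer
trainer = pl.Trainer(callbacks=[ValMetricsProgressBar()])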