Hello, I have the following callback I wrote to log metrics:
```python
from __future__ import annotations  # keeps the GAN annotations below from raising NameError

from abc import ABC, abstractmethod
from typing import Type

from pytorch_lightning import Callback
from pytorch_lightning.metrics import Metric  # torchmetrics.Metric in later releases
from torch import Tensor
# GAN is my own LightningModule subclass, defined elsewhere (not shown)


class MetricCallback(Callback, ABC):
    def __init__(self, metric: Type[Metric], metric_args: dict, name: str):
        self._name = name
        self._train_metric = metric(**metric_args)
        self._val_metric = metric(**metric_args)

    @abstractmethod
    def _update_batch(self, pl_module: GAN, batch: Tensor, metric: Metric):
        pass

    @classmethod
    def log(cls, pl_module, label, metric):
        pl_module.log(label, metric)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        self._update_batch(pl_module, batch, self._train_metric)

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        self.log(pl_module, f'train_{self._name}', self._train_metric.compute())

    def on_validation_batch_end(self, trainer, pl_module: GAN, outputs, batch: Tensor, batch_idx, dataloader_idx):
        self._update_batch(pl_module, batch, self._val_metric)

    def on_validation_epoch_end(self, trainer, pl_module):
        self.log(pl_module, f'val_{self._name}', self._val_metric.compute())
```
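To make the data flow of the callback above easier to follow, here is a framework-free sketch of the same lifecycle: batch hooks update a metric, epoch-end hooks compute it and hand the value to `pl_module.log()`. Everything here is hypothetical: `RunningMean` stands in for a Lightning `Metric`, `FakeModule` stands in for the LightningModule and only records what `log()` receives, and the hook signatures are trimmed to the arguments the callback actually uses. Unlike the original, the sketch resets each metric after logging so the next epoch starts from zero (Lightning's `Metric` objects provide a `reset()` method for this). That is a choice made in the sketch, not a claim about why the values are missing from TensorBoard.

```python
from abc import ABC, abstractmethod
from typing import Dict, List


class RunningMean:
    """Hypothetical stand-in for a Metric: accumulates values, then averages them."""

    def __init__(self) -> None:
        self.reset()

    def update(self, values: List[float]) -> None:
        self.total += sum(values)
        self.count += len(values)

    def compute(self) -> float:
        return self.total / self.count

    def reset(self) -> None:
        self.total = 0.0
        self.count = 0


class FakeModule:
    """Hypothetical stand-in for pl_module: records the last value logged per name."""

    def __init__(self) -> None:
        self.logged: Dict[str, float] = {}

    def log(self, name: str, value: float) -> None:
        self.logged[name] = value


class MetricCallbackSketch(ABC):
    def __init__(self, name: str) -> None:
        self._name = name
        self._train_metric = RunningMean()
        self._val_metric = RunningMean()

    @abstractmethod
    def _update_batch(self, pl_module, batch, metric) -> None: ...

    @classmethod
    def log(cls, pl_module, label: str, value: float) -> None:
        pl_module.log(label, value)

    def on_train_batch_end(self, pl_module, batch) -> None:
        self._update_batch(pl_module, batch, self._train_metric)

    def on_train_epoch_end(self, pl_module) -> None:
        self.log(pl_module, f"train_{self._name}", self._train_metric.compute())
        self._train_metric.reset()  # next epoch starts from an empty state

    def on_validation_batch_end(self, pl_module, batch) -> None:
        self._update_batch(pl_module, batch, self._val_metric)

    def on_validation_epoch_end(self, pl_module) -> None:
        self.log(pl_module, f"val_{self._name}", self._val_metric.compute())
        self._val_metric.reset()


class MeanCallback(MetricCallbackSketch):
    def _update_batch(self, pl_module, batch, metric) -> None:
        metric.update(batch)  # in this sketch a "batch" is just a list of floats


module = FakeModule()
cb = MeanCallback("mean")
for batch in ([1.0, 2.0], [3.0]):
    cb.on_train_batch_end(module, batch)
cb.on_train_epoch_end(module)
cb.on_validation_batch_end(module, [10.0, 20.0])
cb.on_validation_epoch_end(module)
print(module.logged)  # {'train_mean': 2.0, 'val_mean': 15.0}
```

In the real callback, Lightning calls these hooks itself and `pl_module.log()` decides where the value goes; the sketch only shows which values reach that call.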
Note that logging is done in the `log()` method, which itself calls `pl_module.log()`. The documentation says: “Lightning auto-determines the correct logging mode for you. But of course you can override the default behavior by manually setting the log() parameters.” This is why I don't pass any extra parameters. When I use this class, my metrics are being evaluated, but they don't show up in TensorBoard. What might be causing this?
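For reference, overriding the defaults the documentation mentions would mean forwarding explicit keyword arguments through the `log()` helper. `on_step`, `on_epoch` and `logger` are real keyword arguments of `LightningModule.log`; `RecordingModule` and `LoggingHelper` are hypothetical stand-ins so the snippet runs without Lightning installed. It only shows the plumbing. Whether forcing these values changes what TensorBoard displays depends on the Lightning version and is not verified here.

```python
from typing import Any, Dict, Tuple


class RecordingModule:
    """Hypothetical stand-in for pl_module that records every log() call."""

    def __init__(self) -> None:
        self.calls: Dict[str, Tuple[float, Dict[str, Any]]] = {}

    def log(self, name: str, value: float, **kwargs: Any) -> None:
        self.calls[name] = (value, kwargs)


class LoggingHelper:
    @classmethod
    def log(cls, pl_module, label: str, value: float, **log_kwargs: Any) -> None:
        # Forward any explicit settings unchanged; with no kwargs this
        # reproduces the original call pl_module.log(label, value).
        pl_module.log(label, value, **log_kwargs)


module = RecordingModule()
LoggingHelper.log(module, "train_acc", 0.9)
LoggingHelper.log(module, "val_acc", 0.8, on_step=False, on_epoch=True, logger=True)
print(module.calls["val_acc"])  # (0.8, {'on_step': False, 'on_epoch': True, 'logger': True})
```

With no extra arguments the helper behaves exactly like the original `log()` classmethod, so it can be dropped in without changing the default behavior.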