# Example: wrap BinaryAccuracy in a MetricTracker, feed it several "epochs"
# of random binary data, then plot the tracked value for every epoch.
import torch
from torchmetrics.classification import BinaryAccuracy
from torchmetrics.wrappers import MetricTracker

tracker = MetricTracker(BinaryAccuracy())
for _epoch in range(5):
    # Start a new tracked step once per epoch; the updates below accumulate
    # into that step.
    tracker.increment()
    for _batch in range(5):
        # Dummy data: two tensors of 10 random 0/1 values each, passed in
        # torchmetrics' (preds, target) order.
        tracker.update(torch.randint(2, (10,)), torch.randint(2, (10,)))
fig_, ax_ = tracker.plot()  # plot all epochs