from torch import rand, randint

from torchmetrics.classification import BinaryFairness

# Example: update a BinaryFairness metric with random data and plot its value.
# NOTE(review): torchmetrics' plotting presumably needs matplotlib installed — confirm.

num_samples = 20

# Fairness metric computed across two groups (group ids 0 and 1 below).
metric = BinaryFairness(num_groups=2)

# Three equally sized random inputs:
#   preds  -- floats in [0, 1) from ``rand``; presumably thresholded by the
#             metric into binary predictions (TODO confirm threshold handling)
#   target -- 0/1 integer ground-truth labels
#   groups -- 0/1 integer group id for each sample
preds = rand(num_samples)
target = randint(2, (num_samples,))
groups = randint(2, (num_samples,))
metric.update(preds, target, groups)

# plot() yields two values (figure and axes). The trailing underscores mark
# them as intentionally unused in this example.
fig_, ax_ = metric.plot()