# Example: feed one prediction/reference pair to the SQuAD metric and plot it.
from torchmetrics.text import SQuAD

metric = SQuAD()

# A single prediction; its "id" is the same string used in `target` below.
preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]

# The matching reference entry, carrying the answer text and its start offset.
target = [
    {
        "answers": {"answer_start": [97], "text": ["1976"]},
        "id": "56e10a3be3433e1400422b22",
    }
]

# Accumulate the pair into the metric's state before plotting.
metric.update(preds, target)

# `plot()` is called without arguments; its result is unpacked into two
# values (a figure and an axes object according to the torchmetrics docs).
fig_, ax_ = metric.plot()