SQuAD
Module Interface
- class torchmetrics.text.SQuAD(**kwargs)
Calculate the SQuAD metric, which evaluates question answering models.
This metric corresponds to the scoring script for version 1 of the Stanford Question Answering Dataset (SQuAD).
As input to forward and update, the metric accepts the following input:
- preds (Dict): A dictionary or list of dictionaries that map id and prediction_text to the respective values.
Example prediction:
{"prediction_text": "TorchMetrics is awesome", "id": "123"}
- target (Dict): A dictionary or list of dictionaries that contain the answers and id in the SQuAD format.
Example target:
{'answers': {'answer_start': [1], 'text': ['This is a test answer']}, 'id': '1'}
Reference SQuAD Format:
{
    'answers': {'answer_start': [1], 'text': ['This is a test text']},
    'context': 'This is a test context.',
    'id': '1',
    'question': 'Is this a test?',
    'title': 'train test'
}
As output of forward and compute, the metric returns the following output:
- squad (Dict): A dictionary containing the F1 score (key: "f1") and the exact match score (key: "exact_match") for the batch, both expressed as percentages.
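To make the two scores concrete, the following plain-Python sketch is modelled on the official SQuAD v1.1 evaluation script. It is illustrative only and is not TorchMetrics' own implementation; the function names are chosen here for the example.

```python
import re
import string
from collections import Counter

PUNCTUATION = set(string.punctuation)


def normalize_answer(text: str) -> str:
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match_score(prediction: str, ground_truth: str) -> float:
    """1.0 if the normalized strings are identical, else 0.0."""
    return float(normalize_answer(prediction) == normalize_answer(ground_truth))


def f1_score(prediction: str, ground_truth: str) -> float:
    """Token-level F1 between the normalized prediction and reference."""
    pred_tokens = normalize_answer(prediction).split()
    truth_tokens = normalize_answer(ground_truth).split()
    # Multiset intersection counts each shared token at most as often
    # as it appears in both answers.
    num_same = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(truth_tokens)
    return 2 * precision * recall / (precision + recall)


def best_over_references(score_fn, prediction, references):
    """A question may have several reference answers; keep the best score."""
    return max(score_fn(prediction, ref) for ref in references)


em = exact_match_score("The Eiffel Tower!", "eiffel tower")
f1 = f1_score("the cat sat", "cat sat down")
best = best_over_references(f1_score, "Broncos", ["Denver Broncos", "The Broncos"])
print(em, round(f1, 2), best)
```

Per-question scores are averaged over all questions and scaled by 100, which is why a fully correct batch reports 100 for both keys.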
- Parameters:
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Example
>>> from torchmetrics.text import SQuAD
>>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
>>> target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}]
>>> squad = SQuAD()
>>> squad(preds, target)
{'exact_match': tensor(100.), 'f1': tensor(100.)}
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torchmetrics.text import SQuAD
>>> metric = SQuAD()
>>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
>>> target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}]
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torchmetrics.text import SQuAD
>>> metric = SQuAD()
>>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
>>> target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}]
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.text.squad(preds, target)
Calculate the SQuAD metric [1].
- Parameters:
preds (Union[Dict[str, str], List[Dict[str, str]]]) – A dictionary or list of dictionaries that map id and prediction_text to the respective values.
Example prediction:
{"prediction_text": "TorchMetrics is awesome", "id": "123"}
target (Union[Dict[str, Union[str, Dict[str, Union[List[str], List[int]]]]], List[Dict[str, Union[str, Dict[str, Union[List[str], List[int]]]]]]]) – A dictionary or list of dictionaries that contain the answers and id in the SQuAD format.
Example target:
{'answers': {'answer_start': [1], 'text': ['This is a test answer']}, 'id': '1'}
Reference SQuAD Format:
{
    'answers': {'answer_start': [1], 'text': ['This is a test text']},
    'context': 'This is a test context.',
    'id': '1',
    'question': 'Is this a test?',
    'title': 'train test'
}
- Return type:
- Returns:
Dictionary containing the F1 score (key: "f1") and the exact match score (key: "exact_match") for the batch.
Example
>>> from torchmetrics.functional.text.squad import squad
>>> preds = [{"prediction_text": "1976", "id": "56e10a3be3433e1400422b22"}]
>>> target = [{"answers": {"answer_start": [97], "text": ["1976"]}, "id": "56e10a3be3433e1400422b22"}]
>>> squad(preds, target)
{'exact_match': tensor(100.), 'f1': tensor(100.)}
- Raises:
KeyError – If the required keys are missing in either predictions or targets.
References
[1] Pranav Rajpurkar, Jian Zhang, Konstantin Lopyrev, Percy Liang. SQuAD: 100,000+ Questions for Machine Comprehension of Text.