BERT Score
Module Interface
- class torchmetrics.text.bert.BERTScore(model_name_or_path=None, num_layers=None, all_layers=False, model=None, user_tokenizer=None, user_forward_fn=None, verbose=False, idf=False, device=None, max_length=512, batch_size=64, num_threads=0, return_hash=False, lang='en', rescale_with_baseline=False, baseline_path=None, baseline_url=None, truncation=False, **kwargs)
BERTScore ("BERTScore: Evaluating Text Generation with BERT") for measuring text similarity.
BERTScore leverages the pre-trained contextual embeddings from BERT and matches words in candidate and reference sentences by cosine similarity. It has been shown to correlate with human judgment on sentence-level and system-level evaluation. Moreover, BERTScore computes precision, recall, and F1 measure, which can be useful for evaluating different language generation tasks. This implementation follows the original implementation from BERT_score.
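The matching step described above can be illustrated with a minimal sketch. It assumes the greedy formulation in which each token is paired with its most similar counterpart, uses tiny hand-written vectors in place of real BERT embeddings, and ignores tokenization, layer selection and idf weighting. The helper names `cosine` and `greedy_bert_score` are hypothetical and are not part of torchmetrics.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def greedy_bert_score(pred_emb, target_emb):
    """Toy precision/recall/F1 from per-token embeddings (hypothetical helper)."""
    # Pairwise similarity: rows are predicted tokens, columns are reference tokens.
    sim = [[cosine(p, t) for t in target_emb] for p in pred_emb]
    # Precision: each predicted token is matched to its best reference token.
    precision = sum(max(row) for row in sim) / len(pred_emb)
    # Recall: each reference token is matched to its best predicted token.
    recall = sum(
        max(sim[i][j] for i in range(len(pred_emb))) for j in range(len(target_emb))
    ) / len(target_emb)
    # Guard against division by zero when both values are zero.
    total = precision + recall
    f1 = 2 * precision * recall / total if total else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}


# Made-up 2-dimensional "embeddings" for two predicted and two reference tokens.
pred_emb = [[1.0, 0.0], [0.0, 1.0]]
target_emb = [[1.0, 0.0], [1.0, 1.0]]
scores = greedy_bert_score(pred_emb, target_emb)
print(scores)
```

Identical token embeddings on both sides give a score of 1 for all three measures, which mirrors the first sentence pair in the example further down.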
As input to forward and update the metric accepts the following input:
- preds (List): An iterable of predicted sentences
- target (List): An iterable of reference sentences
As output of forward and compute the metric returns the following output:
- score (Dict): A dictionary containing the keys precision, recall and f1 with corresponding values
- Parameters:
preds – An iterable of predicted sentences.
target – An iterable of target sentences.
model_name_or_path (Optional[str]) – A name or a model path used to load a transformers pretrained model.
num_layers (Optional[int]) – A layer of representation to use.
all_layers (bool) – An indication of whether the representation from all model’s layers should be used. If all_layers=True, the argument num_layers is ignored.
model (Optional[Module]) – A user’s own model. Must be a torch.nn.Module instance.
user_tokenizer (Optional[Any]) – A user’s own tokenizer used with the user’s own model. This must be an instance with the __call__ method. This method must take an iterable of sentences (List[str]) and must return a python dictionary containing "input_ids" and "attention_mask" represented by Tensor. It is up to the user’s model whether "input_ids" is a Tensor of input ids or embedding vectors. This tokenizer must prepend an equivalent of the [CLS] token and append an equivalent of the [SEP] token, as a transformers tokenizer does.
user_forward_fn (Optional[Callable[[Module, Dict[str, Tensor]], Tensor]]) – A user’s own forward function used in combination with model. This function must take model and a python dictionary containing "input_ids" and "attention_mask" represented by Tensor as an input and return the model’s output represented by a single Tensor.
verbose (bool) – An indication of whether a progress bar should be displayed during the embeddings’ calculation.
idf (bool) – An indication of whether normalization using inverse document frequencies should be used.
device (Union[str, device, None]) – A device to be used for calculation.
max_length (int) – A maximum length of input sequences. Sequences longer than max_length are to be trimmed.
batch_size (int) – A batch size used for model processing.
num_threads (int) – A number of threads to use for a dataloader.
return_hash (bool) – An indication of whether the corresponding hash_code should be returned.
lang (str) – A language of input sentences. It is used when the scores are rescaled with a baseline.
rescale_with_baseline (bool) – An indication of whether bertscore should be rescaled with a pre-computed baseline. When a pretrained model from transformers is used, the corresponding baseline is downloaded from the original bert-score package from BERT_score if available. In other cases, please specify a path to the baseline csv/tsv file, which must follow the formatting of the files from BERT_score.
baseline_path (Optional[str]) – A path to the user’s own local csv/tsv file with the baseline scale.
baseline_url (Optional[str]) – A url path to the user’s own csv/tsv file with the baseline scale.
truncation (bool) – An indication of whether the input sequences should be truncated to max_length.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
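To make the contract for model, user_tokenizer and user_forward_fn more concrete, here is a minimal sketch of objects that satisfy the description above. Everything in it is hypothetical: `WhitespaceTokenizer`, `ToyEncoder` and `toy_forward_fn` are illustrative names, the vocabulary is built on the fly, and the optional `max_length` argument of `__call__` as well as the `(batch, sequence length, embedding dim)` output shape are assumptions of this sketch rather than documented requirements.

```python
from typing import Dict, List

import torch
from torch import Tensor, nn


class WhitespaceTokenizer:
    """Hypothetical tokenizer: splits on whitespace and maps words to integer ids."""

    def __init__(self) -> None:
        # Ids 0-2 are reserved for padding and the [CLS]/[SEP] equivalents.
        self.vocab = {"<pad>": 0, "<cls>": 1, "<sep>": 2}

    def __call__(self, sentences: List[str], max_length: int = 16) -> Dict[str, Tensor]:
        input_ids = torch.zeros(len(sentences), max_length, dtype=torch.long)
        attention_mask = torch.zeros(len(sentences), max_length, dtype=torch.long)
        for row, sentence in enumerate(sentences):
            # Leave room for the two special tokens.
            words = sentence.lower().split()[: max_length - 2]
            # Prepend the [CLS] equivalent (1) and append the [SEP] equivalent (2).
            ids = [1] + [self.vocab.setdefault(w, len(self.vocab)) for w in words] + [2]
            input_ids[row, : len(ids)] = torch.tensor(ids)
            attention_mask[row, : len(ids)] = 1
        return {"input_ids": input_ids, "attention_mask": attention_mask}


class ToyEncoder(nn.Module):
    """Hypothetical model: a plain embedding lookup standing in for a real encoder."""

    def __init__(self, vocab_size: int = 1000, dim: int = 8) -> None:
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim, padding_idx=0)

    def forward(self, input_ids: Tensor) -> Tensor:
        return self.embedding(input_ids)


def toy_forward_fn(model: nn.Module, batch: Dict[str, Tensor]) -> Tensor:
    """Takes the model and the tokenizer output, returns a single Tensor."""
    with torch.no_grad():
        return model(batch["input_ids"])


tokenizer = WhitespaceTokenizer()
model = ToyEncoder()
batch = tokenizer(["hello there", "general kenobi"])
embeddings = toy_forward_fn(model, batch)
print(batch["input_ids"].shape, embeddings.shape)
```

Objects like these would be passed as `BERTScore(model=model, user_tokenizer=tokenizer, user_forward_fn=toy_forward_fn)`; that call is not exercised in this sketch.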
Example
>>> from pprint import pprint
>>> from torchmetrics.text.bert import BERTScore
>>> preds = ["hello there", "general kenobi"]
>>> target = ["hello there", "master kenobi"]
>>> bertscore = BERTScore()
>>> pprint(bertscore(preds, target))
{'f1': tensor([1.0000, 0.9961]),
 'precision': tensor([1.0000, 0.9961]),
 'recall': tensor([1.0000, 0.9961])}
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> from torchmetrics.text.bert import BERTScore
>>> preds = ["hello there", "general kenobi"]
>>> target = ["hello there", "master kenobi"]
>>> metric = BERTScore()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()

>>> # Example plotting multiple values
>>> from torch import tensor
>>> from torchmetrics.text.bert import BERTScore
>>> preds = ["hello there", "general kenobi"]
>>> target = ["hello there", "master kenobi"]
>>> metric = BERTScore()
>>> values = []
>>> for _ in range(10):
...     val = metric(preds, target)
...     val = {k: tensor(v).mean() for k, v in val.items()}  # convert into single value per key
...     values.append(val)
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.text.bert.bert_score(preds, target, model_name_or_path=None, num_layers=None, all_layers=False, model=None, user_tokenizer=None, user_forward_fn=None, verbose=False, idf=False, device=None, max_length=512, batch_size=64, num_threads=0, return_hash=False, lang='en', rescale_with_baseline=False, baseline_path=None, baseline_url=None, truncation=False)
BERTScore ("BERTScore: Evaluating Text Generation with BERT") for text similarity matching.
This metric leverages the pre-trained contextual embeddings from BERT and matches words in candidate and reference sentences by cosine similarity. It has been shown to correlate with human judgment on sentence-level and system-level evaluation. Moreover, BERTScore computes precision, recall, and F1 measure, which can be useful for evaluating different language generation tasks.
This implementation follows the original implementation from BERT_score.
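The idf option listed among the parameters weights tokens by inverse document frequency, so rare tokens count more than frequent ones when similarities are averaged. The sketch below is an illustration only: it assumes plus-one smoothed idf statistics collected over the reference sentences, and the helpers `idf_weights` and `weighted_average` are hypothetical, not torchmetrics functions.

```python
import math
from collections import Counter


def idf_weights(tokenized_refs):
    """Plus-one smoothed inverse document frequency per token (assumed formula)."""
    num_sentences = len(tokenized_refs)
    doc_freq = Counter()
    for tokens in tokenized_refs:
        # Count each token once per sentence, however often it repeats.
        doc_freq.update(set(tokens))
    return {
        tok: math.log((num_sentences + 1) / (count + 1))
        for tok, count in doc_freq.items()
    }


def weighted_average(similarities, tokens, idf):
    """Average per-token similarities, weighting each token by its idf value."""
    weights = [idf[tok] for tok in tokens]
    return sum(w * s for w, s in zip(weights, similarities)) / sum(weights)


# Made-up tokenized reference sentences.
refs = [["hello", "there"], ["master", "kenobi"], ["hello", "kenobi"]]
idf = idf_weights(refs)

# Made-up best-match similarities for the tokens of the first reference.
weighted = weighted_average([1.0, 0.5], ["hello", "there"], idf)
print(idf, weighted)
```

In this toy corpus "there" appears in one sentence and "hello" in two, so the weaker match on "there" pulls the weighted average below the plain mean of 0.75.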
- Parameters:
preds (Union[str, Sequence[str], Dict[str, Tensor]]) – Either an iterable of predicted sentences or a Dict[input_ids, attention_mask].
target (Union[str, Sequence[str], Dict[str, Tensor]]) – Either an iterable of target sentences or a Dict[input_ids, attention_mask].
model_name_or_path (Optional[str]) – A name or a model path used to load a transformers pretrained model.
num_layers (Optional[int]) – A layer of representation to use.
all_layers (bool) – An indication of whether the representation from all model’s layers should be used. If all_layers=True, the argument num_layers is ignored.
model (Optional[Module]) – A user’s own model. Must be a torch.nn.Module instance.
user_tokenizer (Optional[Any]) – A user’s own tokenizer used with the user’s own model. This must be an instance with the __call__ method. This method must take an iterable of sentences (List[str]) and must return a python dictionary containing "input_ids" and "attention_mask" represented by Tensor. It is up to the user’s model whether "input_ids" is a Tensor of input ids or embedding vectors. This tokenizer must prepend an equivalent of the [CLS] token and append an equivalent of the [SEP] token, as a transformers tokenizer does.
user_forward_fn (Optional[Callable[[Module, Dict[str, Tensor]], Tensor]]) – A user’s own forward function used in combination with model. This function must take model and a python dictionary containing "input_ids" and "attention_mask" represented by Tensor as an input and return the model’s output represented by a single Tensor.
verbose (bool) – An indication of whether a progress bar should be displayed during the embeddings’ calculation.
idf (bool) – An indication of whether normalization using inverse document frequencies should be used.
device (Union[str, device, None]) – A device to be used for calculation.
max_length (int) – A maximum length of input sequences. Sequences longer than max_length are to be trimmed.
batch_size (int) – A batch size used for model processing.
num_threads (int) – A number of threads to use for a dataloader.
return_hash (bool) – An indication of whether the corresponding hash_code should be returned.
lang (str) – A language of input sentences. It is used when the scores are rescaled with a baseline.
rescale_with_baseline (bool) – An indication of whether bertscore should be rescaled with a pre-computed baseline. When a pretrained model from transformers is used, the corresponding baseline is downloaded from the original bert-score package from BERT_score if available. In other cases, please specify a path to the baseline csv/tsv file, which must follow the formatting of the files from BERT_score.
baseline_path (Optional[str]) – A path to the user’s own local csv/tsv file with the baseline scale.
baseline_url (Optional[str]) – A url path to the user’s own csv/tsv file with the baseline scale.
truncation (bool) – An indication of whether the input sequences should be truncated to the maximum length.
- Return type:
- Returns:
Python dictionary containing the keys precision, recall and f1 with corresponding values.
- Raises:
ValueError – If len(preds) != len(target).
ModuleNotFoundError – If tqdm package is required and not installed.
ModuleNotFoundError – If transformers package is required and not installed.
ValueError – If num_layers is larger than the number of the model layers.
ValueError – If invalid input is provided.
Example
>>> from pprint import pprint
>>> from torchmetrics.functional.text.bert import bert_score
>>> preds = ["hello there", "general kenobi"]
>>> target = ["hello there", "master kenobi"]
>>> pprint(bert_score(preds, target))
{'f1': tensor([1.0000, 0.9961]),
 'precision': tensor([1.0000, 0.9961]),
 'recall': tensor([1.0000, 0.9961])}
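Raw scores such as the 0.9961 above tend to sit in a narrow band near 1, which is what the rescale_with_baseline option is meant to address. As a rough sketch, and assuming the linear rescaling used by the original bert-score package, a raw score x with baseline b becomes (x - b) / (1 - b). The baseline values below are made up for illustration and are not taken from any real baseline file; `rescale` is a hypothetical helper.

```python
def rescale(score, baseline):
    """Linearly rescale a raw score so the baseline maps to 0 and 1 stays at 1."""
    return (score - baseline) / (1.0 - baseline)


# Hypothetical raw scores and hypothetical baselines for one sentence pair.
raw = {"precision": 0.93, "recall": 0.91, "f1": 0.92}
baseline = {"precision": 0.85, "recall": 0.85, "f1": 0.85}

rescaled = {key: rescale(raw[key], baseline[key]) for key in raw}
print(rescaled)
```

A score equal to the baseline maps to 0 and a perfect score stays at 1, so differences between systems become easier to read, while scores below the baseline turn negative.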