Limit the vocabulary of an auto-regressive decoder (such as BART or GPT) in next-token prediction?

Hi all, I am currently training a model for seq-2-seq relation extraction. The output follows the pattern “my subject my object relation”, interleaved with special marker tokens. “my subject” and “my object” appear in the input text, while the marker tokens and all 20 options for “relation” come from a fixed list. I would like to speed up training by restricting the vocabulary during classification to only the marker tokens, my 20 relations and the tokens of the input text, instead of the full vocabulary.
Any ideas on how to do this?
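A minimal sketch of the core idea in plain PyTorch, assuming the marker tokens and the 20 relation labels have already been mapped to token ids (the function names `build_allowed_mask` and `restrict_logits` are hypothetical): build a boolean mask over the vocabulary for each example and set the logits of every other token to `-inf`, so the softmax only covers allowed tokens.

```python
from typing import Sequence

import torch


def build_allowed_mask(input_ids: torch.Tensor,
                       extra_ids: Sequence[int],
                       vocab_size: int) -> torch.Tensor:
    """Return a (batch, vocab_size) bool mask, True where a token is allowed.

    Allowed = every token id occurring in that example's input text
    (including padding, which is harmless here) plus the shared
    `extra_ids` (hypothetical ids of marker tokens and relation labels).
    """
    batch_size = input_ids.size(0)
    mask = torch.zeros(batch_size, vocab_size, dtype=torch.bool,
                       device=input_ids.device)
    # Row index broadcast against input_ids: marks each input token per example.
    rows = torch.arange(batch_size, device=input_ids.device).unsqueeze(1)
    mask[rows, input_ids] = True
    # Tokens shared by all examples.
    mask[:, list(extra_ids)] = True
    return mask


def restrict_logits(logits: torch.Tensor, allowed: torch.Tensor) -> torch.Tensor:
    """Set the logits of disallowed tokens to -inf.

    logits: (batch, seq_len, vocab); allowed: (batch, vocab), broadcast
    over the sequence dimension so the softmax only covers allowed tokens.
    """
    return logits.masked_fill(~allowed.unsqueeze(1), float("-inf"))
```

One thing to watch with byte-level BPE tokenizers such as BART’s or GPT-2’s: the same word can get a different id with and without a leading space, so the allowed set may need both variants of the input words. Every target token also has to be in the allowed set.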

Hey,

I don’t quite get what you want to achieve. Can you show me how you would do it in raw Python/PyTorch code? I could then help you achieve the same with Lightning (it shouldn’t be much different).

Best,
Justus


Hey Justus, thanks for your answer.
As I have basically no knowledge of PyTorch, I cannot show you an implementation.

What I am trying to achieve is best described in this paper: https://arxiv.org/pdf/2204.01098.pdf
They describe the model as “hallucinating” early in training: the task only requires the model to output tokens from the input plus special tokens, but because the vocabulary has 50k+ words, the model outputs other words instead.
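At training time, one way to attack this kind of hallucination is to normalise the softmax only over the allowed tokens. A minimal sketch, assuming a per-example boolean mask `allowed` of shape (batch, vocab) has already been built from the input ids, relation ids and special-token ids; the name `restricted_cross_entropy` is hypothetical, and this is not claimed to be exactly what the paper does:

```python
import torch
import torch.nn.functional as F


def restricted_cross_entropy(logits: torch.Tensor,
                             labels: torch.Tensor,
                             allowed: torch.Tensor,
                             ignore_index: int = -100) -> torch.Tensor:
    """Cross-entropy where the softmax only runs over allowed tokens.

    logits:  (batch, seq_len, vocab) decoder outputs
    labels:  (batch, seq_len) target ids, `ignore_index` marks padding
    allowed: (batch, vocab) bool mask of permitted tokens per example
    """
    # Disallowed tokens get -inf, so they receive zero probability and
    # no probability mass is spent on words outside the allowed set.
    masked = logits.masked_fill(~allowed.unsqueeze(1), float("-inf"))
    # F.cross_entropy expects the class dimension second: (batch, vocab, seq_len).
    return F.cross_entropy(masked.transpose(1, 2), labels,
                           ignore_index=ignore_index)
```

Every target token must be in the allowed set, otherwise its probability is zero and the loss becomes infinite. In Lightning this would simply replace the loss computation inside `training_step`.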

As far as I know, a text-generation transformer always chooses the next token from its full vocabulary. I would therefore like to restrict the vocabulary to the words of the input, my list of relations and the special tokens. For example, if the input is “German chancellor Olaf Scholz decided to meet Joe Biden in November, the weather was nice”, the output may be “German chancellor Olaf Scholz Joe Biden Consult”, and the model should never be allowed to predict something like “German chancellor Angela Merkel Joe Biden Consult”.
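At inference time, the same restriction can be sketched as a greedy decoding loop that masks the logits at every step, so a token like “Angela” can never be emitted unless it occurs in the input. This is a minimal single-example sketch: `next_token_logits` is a hypothetical callable standing in for the model’s forward pass, and all token ids are made up. With a pretrained Hugging Face BART, `generate()` also accepts a `prefix_allowed_tokens_fn` argument for this kind of constraint, which may be more convenient.

```python
from typing import Callable, List, Sequence

import torch


def constrained_greedy_decode(
    next_token_logits: Callable[[torch.Tensor], torch.Tensor],
    input_ids: torch.Tensor,
    extra_ids: Sequence[int],
    bos_id: int,
    eos_id: int,
    vocab_size: int,
    max_len: int = 20,
) -> List[int]:
    """Greedy decoding for one example that only emits allowed tokens.

    `next_token_logits` (hypothetical) maps the decoder prefix, shape (t,),
    to next-token logits, shape (vocab_size,); with BART it would wrap a
    forward pass on the encoded input plus this prefix.
    """
    # Allowed = tokens of the input text + markers/relations + EOS.
    allowed = torch.zeros(vocab_size, dtype=torch.bool)
    allowed[input_ids] = True
    allowed[list(extra_ids)] = True
    allowed[eos_id] = True

    prefix = [bos_id]
    for _ in range(max_len):
        logits = next_token_logits(torch.tensor(prefix))
        # Disallowed tokens can never win the argmax.
        logits = logits.masked_fill(~allowed, float("-inf"))
        token = int(logits.argmax())
        prefix.append(token)
        if token == eos_id:
            break
    return prefix[1:]  # drop BOS
```

Greedy search keeps the sketch short; the same per-step masking applies unchanged to beam search.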

Does this describe my problem better? Although my model only has this problem in the early epochs, I’m confident that restricting the vocabulary would speed up overall training, and the paper linked above found that it improves performance.

Best,
Valentin

Hey Valentin, sorry for the late reply.
Unfortunately, I simply don’t have the time to fully reimplement a paper.

Speaking from experience, just give it a try and post here if you get stuck on concrete implementation details. Feel free to ping me on those questions too!

Best,
Justus


Hey Justus, thanks for your reply. I understand that you don’t have the time; I just hoped that somebody knew the problem or had a four-line solution for it 🙂 I will look into it again after finishing my thesis, when I am under less time pressure, and will post my results here.

(If anybody reads this in the future and I forgot to post, feel free to message me.)