Hi.
I’m currently implementing a toy Seq2Seq model for German → English translation (data: Multi30k). The code works, but I’m having trouble defining an inference function that works from a loaded model checkpoint.
During training (at the end of each epoch) I use a function, translate(), that translates a test sentence and logs the result to W&B. This function uses the tokenizers and vocabularies that are generated and stored in the DataModule. That works in this context because the trainer is still alive and a reference to the DataModule exists at that point.
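For context, the call looks roughly like this (a simplified sketch; the hook placement and the example sentence are just illustrative):

```python
# inside my LightningModule; simplified sketch of the per-epoch logging
def on_train_epoch_end(self):
    translation = translate(self, "ein kleines mädchen klettert in ein spielhaus.")
    # self.logger is the WandbLogger, so .experiment is the underlying wandb run
    self.logger.experiment.log({"example_translation": translation})
```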
This is the setup() method of my DataModule:
```python
import spacy
# torchtext >= 0.9 moved these to torchtext.legacy.{data,datasets};
# on older versions drop the ".legacy"
from torchtext.legacy.data import Field, BucketIterator
from torchtext.legacy.datasets import Multi30k


def setup(self, stage=None):
    # spaCy 3.x dropped the "de"/"en" shortcuts, so use the full model names
    self.spacy_de = spacy.load("de_core_news_sm")
    self.spacy_en = spacy.load("en_core_web_sm")
    self.german = Field(
        tokenize=self._tokenize_de,
        lower=True,
        init_token="<sos>",
        eos_token="<eos>",
    )
    self.english = Field(
        tokenize=self._tokenize_en,
        lower=True,
        init_token="<sos>",
        eos_token="<eos>",
    )
    multi30k_train, multi30k_val, multi30k_test = Multi30k.splits(
        exts=(".de", ".en"), fields=(self.german, self.english)
    )
    self.german.build_vocab(multi30k_train, max_size=10000, min_freq=2)
    self.english.build_vocab(multi30k_train, max_size=10000, min_freq=2)
    self.train_it, self.val_it, self.test_it = BucketIterator.splits(
        (multi30k_train, multi30k_val, multi30k_test),
        batch_size=self.batch_size,
        sort_within_batch=True,
        sort_key=lambda x: len(x.src),
    )
```
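(For completeness, the _tokenize_de/_tokenize_en helpers are thin spaCy wrappers, roughly:)

```python
def _tokenize_de(self, text):
    # tokenize a German sentence into a list of token strings
    return [tok.text for tok in self.spacy_de.tokenizer(text)]

def _tokenize_en(self, text):
    # tokenize an English sentence into a list of token strings
    return [tok.text for tok in self.spacy_en.tokenizer(text)]
```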
This is a chunk of the translate() function that uses pieces of the DataModule (there are more references like this):

```python
# example
tokens = [token.text.lower() for token in model.trainer.datamodule.spacy_de(sentence)]
```
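At inference time, with a model loaded from a checkpoint, the failure looks like this (Seq2SeqModel and the checkpoint path are placeholder names):

```python
# hypothetical reproduction; Seq2SeqModel and the path are placeholders
model = Seq2SeqModel.load_from_checkpoint("checkpoints/best.ckpt")
translate(model, "eine gruppe von menschen steht vor einem gebäude.")
# fails here: a freshly loaded module has no attached trainer,
# so model.trainer.datamodule can't be resolved
```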
I receive an error because model.trainer.datamodule (which I use to query the tokenizers) is None. Since I suppose this is a weak link to the DataModule that isn’t part of my checkpoint, I’m wondering how people usually split this workload. I guess it’s messy to pull things out of the DataModule like this?
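For example, would something like this be the usual pattern? A sketch, assuming the torchtext Vocab objects pickle cleanly (they hold plain dicts/lists, unlike the Field objects, whose tokenize callbacks point back into the DataModule); vocab.pkl, Seq2SeqModel, and translate_standalone are placeholder names:

```python
import pickle

import spacy

# in DataModule.setup(), right after build_vocab(), persist only the vocabs:
#     with open("vocab.pkl", "wb") as f:
#         pickle.dump({"de": self.german.vocab, "en": self.english.vocab}, f)


def translate_standalone(model, sentence, vocab_path="vocab.pkl"):
    """Translate without a live trainer/datamodule reference."""
    with open(vocab_path, "rb") as f:
        vocab = pickle.load(f)
    # the tokenizer is cheap to re-create at inference time
    spacy_de = spacy.load("de_core_news_sm")
    tokens = ["<sos>"] + [t.text.lower() for t in spacy_de(sentence)] + ["<eos>"]
    # Vocab.stoi is a defaultdict, so unknown tokens fall back to <unk>
    src_indices = [vocab["de"].stoi[token] for token in tokens]
    # ... run src_indices through the model exactly as translate() does ...
    return src_indices
```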
Or are you defining tokenizers and vocabularies outside of the DataModule altogether?
Cheers,
C