Hi,
I have a tf.data.Dataset (an iterable dataset), and I wrap it in a PyTorch IterableDataset as follows:
class WMTDataset(torch.utils.data.IterableDataset):
    def __init__(self, dataset):
        super().__init__()
        self.dataset = iter(dataset)
        self.dataset_size = 5

    def __len__(self):
        return self.dataset_size

    def __iter__(self):
        return self.dataset
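For reference, here is a minimal standalone snippet (with range(5) standing in for my real tf.data.Dataset) showing how the wrapper behaves when iterated twice; the second pass over the DataLoader comes back empty, which I suspect is related to the error below:

import torch
from torch.utils.data import DataLoader

class ToyDataset(torch.utils.data.IterableDataset):
    def __init__(self, dataset):
        super().__init__()
        # Storing the iterator itself means it can only be consumed once.
        self.dataset = iter(dataset)

    def __iter__(self):
        return self.dataset

loader = DataLoader(ToyDataset(range(5)), batch_size=2)
print(list(loader))  # first pass: [tensor([0, 1]), tensor([2, 3]), tensor([4])]
print(list(loader))  # second pass: [] -- the stored iterator is already exhausted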
Then I wrote a PyTorch Lightning module to train the model; the relevant methods are:
def validation_step(self, batch, batch_idx):
    loss = self._step(batch)
    return {"val_loss": loss.detach().cpu()}

def validation_epoch_end(self, outputs):
    avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean().detach().cpu().item()
    tensorboard_logs = {"val_loss": avg_loss}
    return {"avg_val_loss": avg_loss, "log": tensorboard_logs,
            "progress_bar": tensorboard_logs}

def val_dataloader(self):
    val_dataset = get_dataset(tokenizer=self.tokenizer, split="validation", args=self.hparams)
    return DataLoader(val_dataset, batch_size=self.hparams.eval_batch_size, num_workers=self.hparams.num_workers)
I get the following error when running my code: when validation_epoch_end is called, the input outputs is empty, which triggers the failure. My assumption is that the DataLoader reaches the end of the iterable dataset during the first pass, so there are no elements left to feed into validation_step. Could you please help me build a DataLoader properly from an iterable dataset? Any insight into this error would also be truly appreciated; I have put my own untested guess at a fix after the traceback. Thanks!
Traceback (most recent call last):
File "main.py", line 127, in <module>
model = main()
File "main.py", line 88, in main
trainer = generic_train(model, args)
File "main.py", line 76, in generic_train
trainer.fit(model)
File "/opt/conda/envs/test/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 440, in fit
results = self.accelerator_backend.train()
File "/opt/conda/envs/test/lib/python3.7/site-packages/pytorch_lightning/accelerators/gpu_accelerator.py", line 54, in train
results = self.train_or_test()
File "/opt/conda/envs/test/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 68, in train_or_test
results = self.trainer.train()
File "/opt/conda/envs/test/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 485, in train
self.train_loop.run_training_epoch()
File "/opt/conda/envs/test/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 572, in run_training_epoch
self.trainer.run_evaluation(test_mode=False)
File "/opt/conda/envs/test/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 597, in run_evaluation
num_dataloaders=len(dataloaders)
File "/opt/conda/envs/test/lib/python3.7/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 196, in evaluation_epoch_end
deprecated_results = self.__run_eval_epoch_end(num_dataloaders, using_eval_result)
File "/opt/conda/envs/test/lib/python3.7/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 247, in __run_eval_epoch_end
eval_results = model.validation_epoch_end(eval_results)
File "/home/rabeeh/universal_sentence_encoder/pl_codes/models.py", line 121, in validation_epoch_end
avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean().detach().cpu().item()
RuntimeError: stack expects a non-empty TensorList
Exception ignored in: <function tqdm.__del__ at 0x7fccae7d9200>
Traceback (most recent call last):
File "/opt/conda/envs/test/lib/python3.7/site-packages/tqdm/std.py", line 1128, in __del__
File "/opt/conda/envs/test/lib/python3.7/site-packages/tqdm/std.py", line 1341, in close
File "/opt/conda/envs/test/lib/python3.7/site-packages/tqdm/std.py", line 1520, in display
File "/opt/conda/envs/test/lib/python3.7/site-packages/tqdm/std.py", line 1131, in __repr__
File "/opt/conda/envs/test/lib/python3.7/site-packages/tqdm/std.py", line 1481, in format_dict
TypeError: cannot unpack non-iterable NoneType object
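In case it helps the discussion, here is my untested guess at a fix: keep the underlying tf.data.Dataset itself and build a fresh iterator on every __iter__ call, something like the sketch below. Is this the right approach?

class WMTDataset(torch.utils.data.IterableDataset):
    def __init__(self, dataset, dataset_size):
        super().__init__()
        # Keep the underlying dataset, not an iterator over it.
        self.dataset = dataset
        self.dataset_size = dataset_size

    def __len__(self):
        return self.dataset_size

    def __iter__(self):
        # Return a fresh iterator each time the DataLoader starts a new pass.
        return iter(self.dataset)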