I have a PyTorch Geometric test DataLoader that yields three batches of 64 graphs:
DataBatch(x=[1585, 5], edge_index=[2, 3042], y=[64], batch=[1585], ptr=[65])
DataBatch(x=[1311, 5], edge_index=[2, 2494], y=[64], batch=[1311], ptr=[65])
DataBatch(x=[1963, 5], edge_index=[2, 3798], y=[64], batch=[1963], ptr=[65])
There are actually 200 samples in this dataset, but I set drop_last=True
to avoid issues with incomplete batches (which I have been told can be error-prone).
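For context, my understanding of drop_last is that it discards the trailing partial batch, so with 200 samples and a batch size of 64 I expect 3 full batches and 8 discarded samples (a plain-Python sketch of that arithmetic, not the actual loader):

```python
num_samples = 200   # total graphs in the test set
batch_size = 64

# Batches kept when drop_last=True, and the size of the discarded remainder
full_batches = num_samples // batch_size
dropped = num_samples % batch_size

print(full_batches, dropped)  # 3 8
```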
But when I predict on these batches like this:
model.eval()
trainer = pl.Trainer(accelerator='gpu', devices=-1)
predictions = trainer.predict(model, graph_test_loader)  # graph_test_loader is the 3-batch loader above
The output is:
(tensor(0.6912), tensor(0.5312), tensor(0.6939), tensor(0.5312), tensor(1.), tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.int32))
(tensor(0.7148), tensor(0.3594), tensor(0.5287), tensor(0.3594), tensor(1.), tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.int32))
(tensor(0.6912), tensor(0.5312), tensor(0.6939), tensor(0.5312), tensor(1.), tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.int32))
(tensor(0.7127), tensor(0.3750), tensor(0.5455), tensor(0.3750), tensor(1.), tensor([1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.int32))
So it has predicted on all 200 samples: the 3 × 64 batches above plus the final batch of 8 that I expected drop_last=True to discard. But when I iterate over the loader:
for each_data_list in graph_test_loader:
    print(each_data_list)
it only prints the three batches I have described above.
How can I also see the DataBatch for the incomplete batch, since trainer.predict clearly ran on it?
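Is rebuilding the loader with drop_last=False the only way? My mental model of the batching is the following plain-Python sketch (indices stand in for graphs; this is not the PyG loader itself):

```python
def batch_indices(n, batch_size, drop_last):
    """Mimic DataLoader batching over sample indices 0..n-1."""
    out = [list(range(i, min(i + batch_size, n)))
           for i in range(0, n, batch_size)]
    if drop_last and out and len(out[-1]) < batch_size:
        out.pop()  # discard the trailing partial batch
    return out

print([len(b) for b in batch_indices(200, 64, drop_last=True)])   # [64, 64, 64]
print([len(b) for b in batch_indices(200, 64, drop_last=False)])  # [64, 64, 64, 8]
```

If that model is right, iterating a loader built with drop_last=False should show the fourth DataBatch of 8, but I don't understand why trainer.predict already sees it when my loader was built with drop_last=True.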