Using TFRecords from Meta-Dataset to train a model

Hi,

I have a dataset stored in TFRecord format and I was wondering how to use it to train a PL module. The TFRecords are actually from Google's Meta-Dataset.

The usual code to read these TFRecords looks something like the snippet below. As you can see, we do not use DataLoaders but rather a TF session to fetch dataset batches.

Any help on how to handle this scenario in PyTorch Lightning would be greatly appreciated.

Thanks!

import tensorflow as tf

from data.meta_dataset_reader import MetaDatasetBatchReader

trainsets = ['ilsvrc_2012', 'omniglot', 'aircraft', 'cu_birds', 'dtd', 'quickdraw', 'fungi', 'vgg_flower']
valsets = ['ilsvrc_2012', 'omniglot', 'aircraft', 'cu_birds', 'dtd', 'quickdraw', 'fungi', 'vgg_flower']
testsets = ['ilsvrc_2012', 'omniglot', 'aircraft', 'cu_birds', 'dtd', 'quickdraw', 'fungi', 'vgg_flower']

# build one batch reader per training source
train_loaders = []
for trainset in trainsets:
    dataset = MetaDatasetBatchReader('train', [trainset], valsets, testsets,
                                     batch_size=5)
    train_loaders.append(dataset)


config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = False
with tf.compat.v1.Session(config=config) as session:
    for name, train_loader in zip(trainsets, train_loaders):
        # pull one batch from each source's reader while the session is open
        sample = train_loader.get_train_batch()
        print(name, sample['images'].shape)
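
For what it's worth, below is a rough sketch of the direction I was considering: wrapping the session-based reader in a torch IterableDataset so that Lightning's Trainer can consume it through a regular DataLoader. This is only my own guess, not something from the meta-dataset or Lightning docs; the class name TFBatchIterable and the batches_per_epoch argument are placeholders I made up, and I am assuming get_train_batch() returns a dict of numpy arrays (including a 'labels' key) when called inside an open session, as in the snippet above.

import tensorflow as tf
import torch
from torch.utils.data import DataLoader, IterableDataset


class TFBatchIterable(IterableDataset):
    # Hypothetical wrapper (my own naming): yields pre-batched samples
    # pulled from a MetaDatasetBatchReader inside an open TF session.
    def __init__(self, reader, batches_per_epoch):
        self.reader = reader
        self.batches_per_epoch = batches_per_epoch

    def __iter__(self):
        config = tf.compat.v1.ConfigProto()
        config.gpu_options.allow_growth = False
        # keep one TF session open for the whole epoch, as in the snippet above
        with tf.compat.v1.Session(config=config) as session:
            for _ in range(self.batches_per_epoch):
                # assuming get_train_batch() returns a dict of numpy arrays
                sample = self.reader.get_train_batch()
                images = torch.from_numpy(sample['images'])
                labels = torch.from_numpy(sample['labels'])  # 'labels' key is my guess
                yield images, labels


# batch_size=None so the DataLoader does not re-batch the already-batched samples
train_dataloader = DataLoader(TFBatchIterable(train_loaders[0], batches_per_epoch=100),
                              batch_size=None)

My thinking is that an IterableDataset avoids the map-style indexing a sequential TFRecord reader cannot provide, and batch_size=None stops the DataLoader from re-batching what the reader has already batched, but I am not sure how this interacts with num_workers or multi-GPU training. Is something like this the right approach, or is there a cleaner way?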