Why does checkpointing store the Trainer instance inside the DataModule's state_dict?
I call trainer.fit in the standard way:
```python
trainer = Trainer(...)
trainer.fit(lightning_model, lightning_data_module)
```
In my project, the DataLoader objects cannot be pickled because the dataset I am using is loaded from an HDF5 file with h5py.
I therefore remove the Dataset objects from the state_dict of my DataModule by overriding its ‘state_dict’ method:
```python
from typing import Any, Dict

from lightning.pytorch import LightningDataModule  # or: from pytorch_lightning import ...


class MyDataModule(LightningDataModule):
    ...

    def state_dict(self) -> Dict[str, Any]:
        # Copy all instance attributes; `self.__dict__` is the live attribute
        # dict, so deleting keys from it would delete the datasets themselves
        state = dict(self.__dict__)
        # Remove the attributes that are not serializable
        for name in ("train_dataset", "val_dataset", "test_dataset", "predict_dataset"):
            state.pop(name, None)
        return state
```
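Note that `self.__dict__` is the instance's live attribute dictionary, not a copy: deleting keys from it removes the attributes from the object itself, so the datasets would disappear from the DataModule after the first checkpoint. The sketch below shows this in plain Python; `Holder` and its two methods are hypothetical stand-ins, not Lightning code.

```python
from typing import Any, Dict


class Holder:
    """Hypothetical stand-in for a DataModule with a non-serializable attribute."""

    NON_SERIALIZABLE = ("train_dataset", "val_dataset", "test_dataset", "predict_dataset")

    def __init__(self) -> None:
        self.batch_size = 32
        self.train_dataset = object()  # placeholder for an h5py-backed dataset

    def state_dict_in_place(self) -> Dict[str, Any]:
        # Mirrors deleting from self.__dict__: `state` IS the attribute dict
        state = self.__dict__
        state.pop("train_dataset", None)
        return state

    def state_dict_copy(self) -> Dict[str, Any]:
        # Builds a new dict, so the instance keeps all of its attributes
        return {k: v for k, v in self.__dict__.items() if k not in self.NON_SERIALIZABLE}


a = Holder()
a.state_dict_copy()
print(hasattr(a, "train_dataset"))  # True

b = Holder()
b.state_dict_in_place()
print(hasattr(b, "train_dataset"))  # False: the attribute was deleted from the object
```

Filtering a copy keeps the checkpoint payload small without side effects on the running DataModule.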
I noticed that the Trainer itself then appears as an entry in the DataModule's state_dict stored in the checkpoint.
This raises an error when the checkpoint is saved, because the Trainer holds the dataloaders (e.g., val_dataloader) as instance attributes.
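The failure chain can be reproduced without Lightning. The sketch below assumes (worth verifying against the installed version) that `LightningDataModule` keeps a `trainer` attribute that the Trainer sets when it attaches the DataModule, so any state_dict built from `self.__dict__` carries that back-reference, and with it everything the Trainer references. All classes here are hypothetical stand-ins, and a `threading.Lock` plays the role of the unpicklable h5py handle.

```python
import pickle
import threading
from typing import Any, Dict, Optional


class FakeDataset:
    """Hypothetical dataset holding an unpicklable handle (stand-in for h5py)."""

    def __init__(self) -> None:
        self.handle = threading.Lock()


class FakeTrainer:
    """Hypothetical trainer that keeps a reference to a dataloader-like object."""

    def __init__(self, dataset: FakeDataset) -> None:
        self.val_dataloader = [dataset]  # stands in for a DataLoader over the dataset


class FakeDataModule:
    def __init__(self) -> None:
        self.batch_size = 32
        self.trainer: Optional[FakeTrainer] = None  # back-reference set from outside

    def state_dict_all(self) -> Dict[str, Any]:
        # Copies every instance attribute, including the trainer back-reference
        return dict(self.__dict__)

    def state_dict_allow_list(self) -> Dict[str, Any]:
        # Serializes only explicitly chosen, plain values
        return {"batch_size": self.batch_size}


dm = FakeDataModule()
dm.trainer = FakeTrainer(FakeDataset())  # mimics the framework attaching a trainer

try:
    pickle.dumps(dm.state_dict_all())
except TypeError as err:
    print("fails:", err)  # the lock is reached via dm.trainer -> dataloader -> dataset

print(pickle.loads(pickle.dumps(dm.state_dict_allow_list())))
```

An allow-list is more robust than a deny-list here, because it does not depend on knowing every attribute the framework attaches to the object.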
My questions are:
- Why does the Trainer end up inside the DataModule's state_dict?
- Is there a way to prevent the Trainer from holding the dataloaders as instance attributes?
(I tried overriding the ‘state_dict’ method in a custom Trainer subclass, shown below, but that does not solve the issue.)
```python
from typing import Any, Dict

from lightning.pytorch import Trainer  # or: from pytorch_lightning import Trainer


class MyTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def state_dict(self) -> Dict[str, Any]:
        # Copy the attributes instead of mutating the live `self.__dict__`
        state = dict(self.__dict__)
        # Remove the attributes that are not serializable
        for name in ("train_dataloader", "val_dataloader", "test_dataloader", "predict_dataloader"):
            state.pop(name, None)
        return state
```
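Independently of where the Trainer reference comes from, a common way to make a file-backed dataset picklable is to store only the file path and open the file lazily, dropping the handle in `__getstate__`. The sketch below uses a plain binary file of int32 records read with `struct` as a stand-in for HDF5 (with h5py, `_file` would open `h5py.File(self.path, "r")` instead); `LazyFileDataset` and `RECORD` are hypothetical names, and the sketch assumes samples are read by index.

```python
import os
import pickle
import struct
import tempfile
from typing import Any, BinaryIO, Dict, Optional

RECORD = struct.Struct("<i")  # one little-endian int32 per sample


class LazyFileDataset:
    """Hypothetical dataset that keeps a path and opens the file on first access."""

    def __init__(self, path: str) -> None:
        self.path = path
        self._handle: Optional[BinaryIO] = None  # opened lazily, never pickled

    def _file(self) -> BinaryIO:
        if self._handle is None:
            self._handle = open(self.path, "rb")
        return self._handle

    def __len__(self) -> int:
        return os.path.getsize(self.path) // RECORD.size

    def __getitem__(self, index: int) -> int:
        f = self._file()
        f.seek(index * RECORD.size)
        (value,) = RECORD.unpack(f.read(RECORD.size))
        return value

    def __getstate__(self) -> Dict[str, Any]:
        # Drop the open handle; the unpickled copy reopens the file from self.path
        state = self.__dict__.copy()
        state["_handle"] = None
        return state


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "data.bin")
    with open(path, "wb") as out:
        for v in (10, 20, 30):
            out.write(RECORD.pack(v))

    ds = LazyFileDataset(path)
    print(ds[1])  # 20 (opens the file)
    clone = pickle.loads(pickle.dumps(ds))  # works even though ds has an open handle
    print(clone[2], len(clone))  # 30 3
    ds._handle.close()
    clone._handle.close()
```

Because the handle is reopened per copy, the same pattern also tends to behave well when DataLoader worker processes receive a pickled dataset.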