Why is the Trainer instance saved inside the DataModule during checkpoint save?

Why does checkpointing dump the Trainer object instance within the state_dict of the DataModule?

When using the trainer.fit function in the standard way:

trainer = Trainer(...)
trainer.fit(lightning_model, lightning_data_module)

In my project, the DataLoader objects cannot be pickled because the dataset I am using is loaded from an HDF5 file with h5py.
I am therefore removing the Dataset objects from the state_dict of my DataModule by overriding its ‘state_dict’ method:

from typing import Any, Dict

from pytorch_lightning import LightningDataModule

class MyDataModule(LightningDataModule):
    ...

    def state_dict(self) -> Dict[str, Any]:
        # Get all attributes of the instance
        state = self.__dict__
        # Remove the attributes that are not serializable
        if 'train_dataset' in state.keys():
            del state['train_dataset']
        if 'test_dataset' in state.keys():
            del state['test_dataset']
        if 'val_dataset' in state.keys():
            del state['val_dataset']
        if 'predict_dataset' in state.keys():
            del state['predict_dataset']
        return state

I noticed that the trainer itself then ends up inside the DataModule's state_dict in the checkpoint object.
This raises an error when the checkpoint is dumped, because the trainer holds the different dataloaders (e.g., val_dataloader) as instance attributes.
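The failure chain described above can be reproduced in plain Python. This is a simplified model, not Lightning's actual checkpointing code: `FakeLoader`, `FakeTrainer`, `FakeDataModule`, and `try_save_checkpoint` are hypothetical stand-ins, and a `threading.Lock` plays the role of the unpicklable h5py file handle. Because `state_dict()` returns the whole `__dict__`, the attached trainer lands in the checkpoint dict, and pickling the trainer drags in its loader.

```python
import pickle
import threading


class FakeLoader:
    """Hypothetical stand-in for a DataLoader wrapping an h5py-backed dataset."""

    def __init__(self):
        # threading.Lock objects cannot be pickled, much like an open h5py file handle
        self.handle = threading.Lock()


class FakeTrainer:
    """Hypothetical stand-in for the Trainer, which keeps references to its loaders."""

    def __init__(self):
        self.val_dataloader = FakeLoader()


class FakeDataModule:
    """Hypothetical stand-in for a DataModule using the __dict__-based state_dict."""

    def __init__(self):
        self.batch_size = 32
        self.trainer = None  # the trainer attaches itself here during fit

    def state_dict(self):
        # Returns every attribute, including self.trainer
        return self.__dict__


def try_save_checkpoint(datamodule):
    """Mimic pickling a checkpoint dict that embeds the data module's state."""
    checkpoint = {"datamodule_state": datamodule.state_dict()}
    try:
        pickle.dumps(checkpoint)
        return "ok"
    except (TypeError, pickle.PicklingError) as err:
        # Pickling the lock raises TypeError
        return f"failed: {err}"


dm = FakeDataModule()
dm.trainer = FakeTrainer()
print(try_save_checkpoint(dm))
```

Without an attached trainer the same data module state pickles fine; the trainer reference is what pulls the loader into the checkpoint.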

My questions are:

  • Why does the trainer end up inside the state_dict of the DataModule?
  • Is there a way to prevent the trainer from holding the different dataloaders as instance attributes?
    (I tried to override the ‘state_dict’ method of a custom trainer, but that does not solve the issue.)

from typing import Any, Dict

from pytorch_lightning import Trainer

class MyTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def state_dict(self) -> Dict[str, Any]:
        # Get all attributes of the instance
        state = self.__dict__

        # Remove the attributes that are not serializable
        if 'train_dataloader' in state.keys():
            del state['train_dataloader']
        if 'test_dataloader' in state.keys():
            del state['test_dataloader']
        if 'val_dataloader' in state.keys():
            del state['val_dataloader']
        if 'predict_dataloader' in state.keys():
            del state['predict_dataloader']

        return state
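Overriding `state_dict` on a Trainer subclass does not change this, because the checkpoint is serialized with pickle (Lightning writes checkpoints via `torch.save`, which uses pickle), and pickle serializes an ordinary object from its attribute dictionary; it never calls a method named `state_dict`. A minimal sketch with a hypothetical `FakeTrainerWithStateDict` class and a `threading.Lock` standing in for the unpicklable loader:

```python
import pickle
import threading


class FakeLoader:
    """Hypothetical stand-in for an unpicklable DataLoader."""

    def __init__(self):
        self.handle = threading.Lock()  # cannot be pickled


class FakeTrainerWithStateDict:
    """Hypothetical trainer that overrides state_dict, like MyTrainer."""

    def __init__(self):
        self.val_dataloader = FakeLoader()
        self.state_dict_called = False

    def state_dict(self):
        # pickle never calls this method, so the override has no effect
        self.state_dict_called = True
        return {}


trainer = FakeTrainerWithStateDict()
# What the __dict__-based DataModule.state_dict produces once a trainer is attached
datamodule_state = {"batch_size": 32, "trainer": trainer}

try:
    pickle.dumps({"datamodule_state": datamodule_state})
    saved = True
except (TypeError, pickle.PicklingError):
    saved = False

print(saved, trainer.state_dict_called)  # prints: False False
```

Saving still fails, and the override is never invoked.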

The trainer ends up in the DataModule's state_dict because of this line in your code:

state = self.__dict__

This is not how state_dict is typically defined. You don’t want to literally dump every attribute of the DataModule instance, since the trainer is attached to it as well. I recommend returning only what you actually want to be saved:

def state_dict(self):
    return {"x": self.x, "y": self.y}  # etc.
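Returning an explicit dictionary also avoids a second pitfall of `state = self.__dict__`: that expression is the object's live attribute dictionary, not a copy, so `del state['train_dataset']` deletes the dataset from the DataModule itself, while the trainer entry stays in the returned state. A pure-Python illustration with a hypothetical `Holder` class and placeholder string values:

```python
class Holder:
    """Hypothetical stand-in for a DataModule with a dataset and an attached trainer."""

    def __init__(self):
        self.train_dataset = "h5-backed dataset"  # placeholder value
        self.trainer = "attached trainer"  # placeholder value

    def state_dict_aliasing(self):
        state = self.__dict__  # the same dict object, not a copy
        if "train_dataset" in state:
            del state["train_dataset"]  # also deletes self.train_dataset
        return state


h = Holder()
state = h.state_dict_aliasing()
print(state is h.__dict__)          # True: the "state" is the live attribute dict
print(hasattr(h, "train_dataset"))  # False: the dataset was removed from the object itself
print("trainer" in state)           # True: the trainer is still in the returned state
```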

You also want to implement the corresponding load_state_dict method, which should do the reverse.
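A minimal sketch of a matching `state_dict` / `load_state_dict` pair, assuming the only things worth saving are a few small settings. The attribute names (`h5_path`, `batch_size`, `seed`) are hypothetical, and the class is plain Python so it runs on its own; in a real project it would subclass `LightningDataModule`, and the h5py-backed datasets would be rebuilt in `setup()` rather than saved.

```python
import pickle
from typing import Any, Dict, Optional


class MyDataModuleSketch:
    """Plain-Python sketch of the state_dict / load_state_dict pattern.

    In real code this would subclass LightningDataModule; the attributes
    below are hypothetical examples of small, picklable settings.
    """

    def __init__(self, h5_path: str, batch_size: int = 32, seed: int = 0):
        self.h5_path = h5_path
        self.batch_size = batch_size
        self.seed = seed
        self.train_dataset: Optional[Any] = None  # would wrap an open h5py file
        self.trainer: Optional[Any] = None  # attached by the trainer during fit

    def state_dict(self) -> Dict[str, Any]:
        # Return only explicit, picklable values, never self.__dict__
        return {"h5_path": self.h5_path, "batch_size": self.batch_size, "seed": self.seed}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # The reverse of state_dict: restore exactly the keys that were saved
        self.h5_path = state_dict["h5_path"]
        self.batch_size = state_dict["batch_size"]
        self.seed = state_dict["seed"]


dm = MyDataModuleSketch("data/train.h5", batch_size=64, seed=7)
dm.trainer = object()  # stand-in for an attached trainer; not part of the saved state
blob = pickle.dumps({"datamodule_state": dm.state_dict()})

restored = MyDataModuleSketch("placeholder.h5")
restored.load_state_dict(pickle.loads(blob)["datamodule_state"])
print(restored.h5_path, restored.batch_size, restored.seed)  # data/train.h5 64 7
```

Keeping the saved state to plain values (strings, numbers) means the checkpoint stays picklable regardless of what the trainer or the datasets hold.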


Thanks a lot for this fast and detailed answer.
