Loops (Advanced)
Persisting the state of loops
Note
This is an experimental feature and is not enabled by default. Set the environment variable PL_FAULT_TOLERANT_TRAINING=1 to enable saving the progress of loops. Read more about fault-tolerant training.
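For example, the variable can be set from Python before the Trainer is created (a minimal sketch; exporting it in the shell before launching the training script works just as well):

import os

# Must be set before the Trainer starts; "1" enables fault-tolerant training
# and, with it, saving the progress of loops.
os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"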
A powerful property of the class-based loop interface is that each loop can own internal state. Loop instances can save their state to the checkpoint through corresponding hooks and, if implemented accordingly, resume execution at the appropriate place. This design is particularly interesting for fault-tolerant training, an experimental feature released in Lightning v1.5.
The two hooks on_save_checkpoint() and on_load_checkpoint() function very similarly to how LightningModules and Callbacks save and load their state.
def on_save_checkpoint(self):
    # return the loop state to be stored in the Trainer checkpoint
    state_dict = {"iteration": self.iteration}
    return state_dict
def on_load_checkpoint(self, state_dict):
    # restore the loop state from the Trainer checkpoint
    self.iteration = state_dict["iteration"]
When the Trainer is restarting from a checkpoint (e.g., through trainer.fit(ckpt_path=...)), the loop exposes a boolean attribute restarting. Based on the value of this attribute, the user can write the loop in such a way that it can restart from an arbitrary point given the state loaded from the checkpoint.
For example, given our previous example, the implementation of the reset() method could look like this:
def reset(self):
    # only reset the counter when starting fresh, not when restarting from a checkpoint
    if not self.restarting:
        self.iteration = 0
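Putting these pieces together, a minimal sketch of a loop that persists its progress could look like the following. The class name, the iteration limit, and the body of advance() are illustrative assumptions and not part of the examples above; the Loop base class comes from pytorch_lightning.loops.

from pytorch_lightning.loops import Loop


class MyIterationLoop(Loop):
    """Illustrative loop that counts iterations and can resume the count after a restart."""

    def __init__(self, max_iterations: int = 10):
        super().__init__()
        self.max_iterations = max_iterations
        self.iteration = 0

    @property
    def done(self) -> bool:
        # stop once the (arbitrary) iteration budget is exhausted
        return self.iteration >= self.max_iterations

    def reset(self) -> None:
        # only reset when starting fresh, not when restarting from a checkpoint
        if not self.restarting:
            self.iteration = 0

    def advance(self) -> None:
        # one unit of work per call; the real body is application-specific
        self.iteration += 1

    def on_save_checkpoint(self) -> dict:
        # persisted into the Trainer checkpoint when fault-tolerant training is enabled
        return {"iteration": self.iteration}

    def on_load_checkpoint(self, state_dict: dict) -> None:
        # restored before the loop resumes execution
        self.iteration = state_dict["iteration"]

After resuming through trainer.fit(ckpt_path=...), restarting is True, so reset() leaves the counter intact and advance() picks up where the previous run left off.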