Loop
- class pytorch_lightning.loops.base.Loop
Bases: abc.ABC, Generic[pytorch_lightning.loops.base.T]
Basic Loops interface. All classes derived from this must implement the following properties and methods: the done property and the reset and advance methods.
This class implements the following loop structure:

    on_run_start()

    while not done:
        on_advance_start()
        advance()
        on_advance_end()

    on_run_end()

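To make the hook order concrete, here is a minimal sketch in plain Python. It does not import pytorch_lightning: `MiniLoop` is a hypothetical stand-in that reproduces only the structure described above (with the state reset happening first, before `on_run_start`), and `TracingLoop` is an invented subclass that records each hook call.

```python
from abc import ABC, abstractmethod
from typing import Any, List


class MiniLoop(ABC):
    """Hypothetical stand-in mirroring the hook order above (not the library class)."""

    @property
    @abstractmethod
    def done(self) -> bool: ...

    @abstractmethod
    def reset(self) -> None: ...

    @abstractmethod
    def advance(self, *args: Any, **kwargs: Any) -> None: ...

    def on_run_start(self, *args: Any, **kwargs: Any) -> None: ...
    def on_advance_start(self, *args: Any, **kwargs: Any) -> None: ...
    def on_advance_end(self) -> None: ...
    def on_run_end(self) -> Any: ...

    def run(self, *args: Any, **kwargs: Any) -> Any:
        self.reset()  # state reset comes before on_run_start
        self.on_run_start(*args, **kwargs)
        while not self.done:
            self.on_advance_start(*args, **kwargs)
            self.advance(*args, **kwargs)
            self.on_advance_end()
        return self.on_run_end()


class TracingLoop(MiniLoop):
    """Records every hook call so the execution order is visible."""

    def __init__(self, num_steps: int) -> None:
        self.num_steps = num_steps
        self.trace: List[str] = []
        self.step = 0

    @property
    def done(self) -> bool:
        return self.step >= self.num_steps

    def reset(self) -> None:
        self.step = 0
        self.trace = ["reset"]

    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
        self.trace.append("on_run_start")

    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
        self.trace.append("on_advance_start")

    def advance(self, *args: Any, **kwargs: Any) -> None:
        self.trace.append("advance")
        self.step += 1

    def on_advance_end(self) -> None:
        self.trace.append("on_advance_end")

    def on_run_end(self) -> List[str]:
        self.trace.append("on_run_end")
        return self.trace


print(TracingLoop(num_steps=2).run())
```

With `num_steps=2` the advance triple appears twice between `on_run_start` and `on_run_end`; with `num_steps=0` the body never executes.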
- abstract advance(*args, **kwargs)
Performs a single step.
Accepts all arguments passed to run.
Example:

    def advance(self, iterator):
        # assumes the iterator yields (batch_idx, batch) pairs, e.g. enumerate(dataloader)
        batch_idx, batch = next(iterator)
        loss = self.trainer.lightning_module.training_step(batch, batch_idx)
        ...

- Return type
None
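The sketch below shows one way to write a single-step `advance` that tracks `batch_idx` on the loop itself instead of relying on a name from an outer scope. `MiniLoop`, `BatchLoop` and `step_fn` are hypothetical; `step_fn` stands in for `training_step` so the example runs without a Trainer.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Iterator, List


class MiniLoop(ABC):
    """Hypothetical stand-in for the Loop interface (not the library class)."""

    @property
    @abstractmethod
    def done(self) -> bool: ...

    @abstractmethod
    def reset(self) -> None: ...

    @abstractmethod
    def advance(self, *args: Any, **kwargs: Any) -> None: ...

    def on_run_end(self) -> Any:
        return None

    def run(self, *args: Any, **kwargs: Any) -> Any:
        self.reset()
        while not self.done:
            # every argument given to run() is forwarded to advance()
            self.advance(*args, **kwargs)
        return self.on_run_end()


class BatchLoop(MiniLoop):
    """Consumes one batch per advance() call and tracks batch_idx itself."""

    def __init__(self, step_fn: Callable[[Any, int], Any], num_batches: int) -> None:
        self.step_fn = step_fn  # hypothetical stand-in for training_step
        self.num_batches = num_batches

    @property
    def done(self) -> bool:
        return self.batch_idx >= self.num_batches

    def reset(self) -> None:
        self.batch_idx = 0
        self.outputs: List[Any] = []

    def advance(self, iterator: Iterator[Any]) -> None:
        batch = next(iterator)  # a single step: fetch exactly one batch
        self.outputs.append(self.step_fn(batch, self.batch_idx))
        self.batch_idx += 1

    def on_run_end(self) -> List[Any]:
        return self.outputs


loop = BatchLoop(step_fn=lambda batch, idx: batch * 10 + idx, num_batches=3)
print(loop.run(iter([1, 2, 3, 4])))  # [10, 21, 32]
```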
- connect(**kwargs)
Optionally connect one or more loops to this one.
Linked loops should form a tree.
- Return type
None
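As an illustration of a loop tree, the following sketch assumes that `connect` simply stores each keyword argument as a child attribute. That is a simplification chosen for the example, not a description of how the library's own subclasses implement it, and all class names are hypothetical. The parent's `advance` runs the connected child loop to completion once per step.

```python
from typing import Any, Dict, List


class MiniLoop:
    """Hypothetical stand-in base class (not pytorch_lightning's Loop).

    Subclasses supply reset(), done, advance() and on_run_end().
    """

    def connect(self, **loops: "MiniLoop") -> None:
        # Assumption for illustration: each keyword becomes a child attribute.
        for name, loop in loops.items():
            setattr(self, name, loop)

    def child_loops(self) -> Dict[str, "MiniLoop"]:
        # Children are whatever MiniLoop instances hang off this loop.
        return {name: v for name, v in vars(self).items() if isinstance(v, MiniLoop)}

    def run(self, *args: Any) -> Any:
        self.reset()
        while not self.done:
            self.advance(*args)
        return self.on_run_end()


class BatchLoop(MiniLoop):
    """Leaf loop: processes two batches for the given epoch."""

    def reset(self) -> None:
        self.batch_idx, self.seen = 0, []

    @property
    def done(self) -> bool:
        return self.batch_idx >= 2

    def advance(self, epoch: int) -> None:
        self.seen.append((epoch, self.batch_idx))
        self.batch_idx += 1

    def on_run_end(self) -> List[Any]:
        return self.seen


class EpochLoop(MiniLoop):
    """Parent loop: each advance() runs the connected child loop once."""

    def __init__(self, max_epochs: int) -> None:
        self.max_epochs = max_epochs

    def reset(self) -> None:
        self.epoch, self.history = 0, []

    @property
    def done(self) -> bool:
        return self.epoch >= self.max_epochs

    def advance(self) -> None:
        self.history.extend(self.batch_loop.run(self.epoch))
        self.epoch += 1

    def on_run_end(self) -> List[Any]:
        return self.history


parent = EpochLoop(max_epochs=2)
parent.connect(batch_loop=BatchLoop())
print(list(parent.child_loops()))  # ['batch_loop']
print(parent.run())  # [(0, 0), (0, 1), (1, 0), (1, 1)]
```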
- load_state_dict(state_dict, prefix='', metrics=None)
Loads the state of this loop and all its children.
- Return type
None
- on_advance_start(*args, **kwargs)
Hook to be called each time before advance is called.
Accepts all arguments passed to run.
- Return type
None
- on_load_checkpoint(state_dict)
Called when loading a model checkpoint; use it to reload the loop state.
- Return type
None
- on_run_end()
Hook to be called at the end of the run.
Its return value is returned from run.
- Return type
TypeVar(T)
- on_run_start(*args, **kwargs)
Hook to be called as the first thing after entering run (except the state reset).
Accepts all arguments passed to run.
- Return type
None
- on_save_checkpoint()
Called when saving a model checkpoint; use it to persist the loop state.
- Return type
Dict
- Returns
The current loop state.
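A minimal sketch of a matching on_save_checkpoint/on_load_checkpoint pair for a hypothetical `CounterLoop`. The JSON round trip only stands in for writing a checkpoint to disk and reading it back; it assumes the returned state holds plain serializable values.

```python
import json
from typing import Any, Dict


class CounterLoop:
    """Hypothetical loop that persists how many steps it has completed."""

    def __init__(self) -> None:
        self.completed_steps = 0

    def advance(self) -> None:
        self.completed_steps += 1

    def on_save_checkpoint(self) -> Dict[str, Any]:
        # Return only plain, serializable values describing the loop state.
        return {"completed_steps": self.completed_steps}

    def on_load_checkpoint(self, state_dict: Dict[str, Any]) -> None:
        # Restore exactly what on_save_checkpoint() produced.
        self.completed_steps = state_dict["completed_steps"]


original = CounterLoop()
for _ in range(3):
    original.advance()

# Simulate saving to and loading from a checkpoint file.
checkpoint = json.loads(json.dumps(original.on_save_checkpoint()))
restored = CounterLoop()
restored.on_load_checkpoint(checkpoint)
print(restored.completed_steps)  # 3
```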
- on_skip()
The function to run when run() should be skipped, as determined by the condition in skip.
- Return type
Optional[Any]
- Returns
The default output value of on_run_end().
- abstract reset()
Resets the internal state of the loop at the beginning of each call to run.
Example:

    def reset(self):
        # reset your internal state or add custom logic
        # if you expect run() to be called multiple times
        self.current_iteration = 0
        self.outputs = []

- Return type
None
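The hypothetical `SquaresLoop` below illustrates why reset matters when run is called more than once: each call starts from a fresh counter and a fresh `outputs` list. Its `run` is a trimmed version of the loop structure described at the top of this page, without skip or the optional hooks.

```python
from typing import List


class SquaresLoop:
    """Hypothetical loop; run() follows the reset -> advance* -> on_run_end cycle."""

    def __init__(self, n: int) -> None:
        self.n = n

    @property
    def done(self) -> bool:
        return self.current_iteration >= self.n

    def reset(self) -> None:
        # Without this, a second run() would start with done == True
        # and stale outputs from the previous call.
        self.current_iteration = 0
        self.outputs: List[int] = []

    def advance(self) -> None:
        self.outputs.append(self.current_iteration ** 2)
        self.current_iteration += 1

    def on_run_end(self) -> List[int]:
        return self.outputs

    def run(self) -> List[int]:
        self.reset()
        while not self.done:
            self.advance()
        return self.on_run_end()


loop = SquaresLoop(n=4)
first = loop.run()
second = loop.run()
print(first, second)  # [0, 1, 4, 9] [0, 1, 4, 9]
```

Because reset creates a new list, the result of the first call is not mutated by the second.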
- run(*args, **kwargs)
The main entry point to the loop.
Repeatedly checks the done condition and calls advance until done evaluates to True.
Override this if you wish to change the default behavior. The default implementation is essentially:

    def run(self, *args, **kwargs):
        if self.skip:
            return self.on_skip()

        self.reset()
        self.on_run_start(*args, **kwargs)

        while not self.done:
            self.on_advance_start(*args, **kwargs)
            self.advance(*args, **kwargs)
            self.on_advance_end()

        output = self.on_run_end()
        return output

- Return type
TypeVar(T)
- Returns
The output of on_run_end (often the outputs collected from each step of the loop).
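If you do override run, one possible variation keeps the default structure but treats an exhausted iterator as the end of the loop. The sketch below is a hypothetical, self-contained class (not a subclass of the library's Loop); catching `StopIteration` around the advance step is its only behavioral change from the structure described above.

```python
from typing import Any, Iterator, List


class DrainLoop:
    """Hypothetical loop whose run() override also stops when data runs out."""

    def __init__(self, max_steps: int) -> None:
        self.max_steps = max_steps

    @property
    def skip(self) -> bool:
        return False

    @property
    def done(self) -> bool:
        return self.step >= self.max_steps

    def reset(self) -> None:
        self.step = 0
        self.outputs: List[Any] = []

    def on_run_start(self, iterator: Iterator[Any]) -> None: ...
    def on_advance_start(self, iterator: Iterator[Any]) -> None: ...
    def on_advance_end(self) -> None: ...

    def advance(self, iterator: Iterator[Any]) -> None:
        self.outputs.append(next(iterator))  # raises StopIteration when exhausted
        self.step += 1

    def on_skip(self) -> List[Any]:
        return []

    def on_run_end(self) -> List[Any]:
        return self.outputs

    def run(self, *args: Any, **kwargs: Any) -> List[Any]:
        # Same structure as the default, plus one change: an exhausted
        # iterator ends the loop early instead of propagating StopIteration.
        if self.skip:
            return self.on_skip()
        self.reset()
        self.on_run_start(*args, **kwargs)
        while not self.done:
            try:
                self.on_advance_start(*args, **kwargs)
                self.advance(*args, **kwargs)
                self.on_advance_end()
            except StopIteration:
                break
        return self.on_run_end()


print(DrainLoop(max_steps=5).run(iter("abc")))  # ['a', 'b', 'c']
print(DrainLoop(max_steps=2).run(iter("abc")))  # ['a', 'b']
```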
- state_dict(destination=None, prefix='')
The state dict is determined by the state and progress of this loop and all its children.
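One plausible way to realize "this loop and all its children" is to recurse through child loops and namespace each child's state under a dotted prefix, with load_state_dict walking the same tree. The sketch below is hypothetical: the `"state_dict"` key suffix, the attribute-based child discovery and the omission of the `metrics` argument are assumptions made for illustration, and the library's exact key layout may differ.

```python
from typing import Any, Dict, Optional


class MiniLoop:
    """Hypothetical stand-in that nests child loop state under dotted prefixes."""

    def __init__(self) -> None:
        self.counter = 0

    def on_save_checkpoint(self) -> Dict[str, Any]:
        return {"counter": self.counter}

    def on_load_checkpoint(self, state_dict: Dict[str, Any]) -> None:
        self.counter = state_dict["counter"]

    def children(self) -> Dict[str, "MiniLoop"]:
        return {k: v for k, v in vars(self).items() if isinstance(v, MiniLoop)}

    def state_dict(self, destination: Optional[Dict[str, Any]] = None, prefix: str = "") -> Dict[str, Any]:
        if destination is None:
            destination = {}
        # This loop's own state, then each child's state under "<prefix><name>."
        destination[prefix + "state_dict"] = self.on_save_checkpoint()
        for name, child in self.children().items():
            child.state_dict(destination, prefix + name + ".")
        return destination

    def load_state_dict(self, state_dict: Dict[str, Any], prefix: str = "") -> None:
        self.on_load_checkpoint(state_dict[prefix + "state_dict"])
        for name, child in self.children().items():
            child.load_state_dict(state_dict, prefix + name + ".")


def build() -> MiniLoop:
    # a three-level tree: fit -> epoch_loop -> batch_loop
    fit = MiniLoop()
    fit.epoch_loop = MiniLoop()
    fit.epoch_loop.batch_loop = MiniLoop()
    return fit


src = build()
src.counter, src.epoch_loop.counter, src.epoch_loop.batch_loop.counter = 1, 2, 3
state = src.state_dict()
print(sorted(state))  # ['epoch_loop.batch_loop.state_dict', 'epoch_loop.state_dict', 'state_dict']

dst = build()
dst.load_state_dict(state)
print(dst.epoch_loop.batch_loop.counter)  # 3
```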
- abstract property done: bool
Property indicating when the loop is finished.
Example:

    @property
    def done(self):
        return self.trainer.global_step >= self.trainer.max_steps

- property skip: bool
Determine whether to return immediately from the call to run().
Example:

    @property
    def skip(self):
        return len(self.trainer.train_dataloader) == 0

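The hypothetical `EpochLoop` below sketches how skip and on_skip interact: when the data is empty, run returns the on_skip value immediately and never calls reset or advance. It uses a plain Python sequence in place of a real Trainer dataloader.

```python
from typing import Any, List, Sequence


class EpochLoop:
    """Hypothetical loop that skips itself when there is no data."""

    def __init__(self, dataloader: Sequence[Any]) -> None:
        self.dataloader = dataloader
        self.reset_calls = 0

    @property
    def skip(self) -> bool:
        # mirrors the documented example: nothing to iterate over
        return len(self.dataloader) == 0

    @property
    def done(self) -> bool:
        return self.batch_idx >= len(self.dataloader)

    def reset(self) -> None:
        self.reset_calls += 1
        self.batch_idx = 0
        self.outputs: List[Any] = []

    def advance(self) -> None:
        self.outputs.append(self.dataloader[self.batch_idx])
        self.batch_idx += 1

    def on_run_end(self) -> List[Any]:
        return self.outputs

    def on_skip(self) -> List[Any]:
        # the default result a caller gets when the run is skipped
        return []

    def run(self) -> List[Any]:
        if self.skip:
            return self.on_skip()  # reset() and all other hooks are bypassed
        self.reset()
        while not self.done:
            self.advance()
        return self.on_run_end()


empty = EpochLoop(dataloader=[])
print(empty.run(), empty.reset_calls)  # [] 0
full = EpochLoop(dataloader=[7, 8])
print(full.run(), full.reset_calls)  # [7, 8] 1
```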