
Loop

class pytorch_lightning.loops.loop.Loop

Bases: abc.ABC, Generic[pytorch_lightning.loops.loop.T]

Basic loop interface. All classes derived from this must implement the following properties and methods:

  • done (property): Condition to break the loop

  • reset (method): Resets the internal state between multiple calls to run

  • advance (method): Implements one step of the loop

This class implements the following loop structure:

on_run_start()

while not done:
    on_advance_start()
    advance()
    on_advance_end()

on_run_end()
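The structure above can be sketched in plain Python. This is a minimal, self-contained illustration, not Lightning's implementation: `MiniLoop` and `CountingLoop` are hypothetical names, skip handling is omitted, and nothing from `pytorch_lightning` is imported. `CountingLoop` records the order in which the hooks fire.

```python
from abc import ABC, abstractmethod
from typing import Any, List


class MiniLoop(ABC):
    """Hypothetical stand-in for the Loop interface; not Lightning's class."""

    @property
    @abstractmethod
    def done(self) -> bool:
        """Condition to break the loop."""

    @abstractmethod
    def reset(self) -> None:
        """Reset internal state at the start of every run."""

    @abstractmethod
    def advance(self, *args: Any, **kwargs: Any) -> None:
        """Perform a single step."""

    # Hooks default to no-ops; subclasses override the ones they need.
    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
        pass

    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
        pass

    def on_advance_end(self) -> None:
        pass

    def on_run_end(self) -> Any:
        return None

    def run(self, *args: Any, **kwargs: Any) -> Any:
        # Mirrors the structure above, preceded by the state reset.
        self.reset()
        self.on_run_start(*args, **kwargs)
        while not self.done:
            self.on_advance_start(*args, **kwargs)
            self.advance(*args, **kwargs)
            self.on_advance_end()
        return self.on_run_end()


class CountingLoop(MiniLoop):
    """Runs n steps and records the order in which the hooks fire."""

    def __init__(self, n: int) -> None:
        self.n = n
        self.step = 0
        self.calls: List[str] = []

    @property
    def done(self) -> bool:
        return self.step >= self.n

    def reset(self) -> None:
        self.step = 0
        self.calls = ["reset"]

    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
        self.calls.append("on_run_start")

    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
        self.calls.append("on_advance_start")

    def advance(self, *args: Any, **kwargs: Any) -> None:
        self.calls.append("advance")
        self.step += 1

    def on_advance_end(self) -> None:
        self.calls.append("on_advance_end")

    def on_run_end(self) -> int:
        self.calls.append("on_run_end")
        return self.step
```

Calling `CountingLoop(2).run()` returns 2, and `calls` then holds `reset`, `on_run_start`, two `on_advance_start`/`advance`/`on_advance_end` triples, and finally `on_run_end`.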
abstract advance(*args, **kwargs)

Performs a single step.

Accepts all arguments passed to run.

Example:

def advance(self, iterator):
    # assumes the iterator yields (batch_idx, batch) pairs, e.g. enumerate(dataloader)
    batch_idx, batch = next(iterator)
    loss = self.trainer.lightning_module.training_step(batch, batch_idx)
    ...
Return type:

None

connect(**kwargs)

Optionally connect one or more loops to this one.

Linked loops should form a tree.

Return type:

None
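A parent loop typically drives a child loop from within its own steps, and connect() is what wires the two together. The sketch below is plain Python with hypothetical `EpochLoop` and `BatchLoop` classes; it assumes, for illustration only, that connect() simply stores the child as an attribute, and the child's run() is a stand-in that sums integer "batches".

```python
from typing import List, Optional


class BatchLoop:
    """Hypothetical child loop: processes every batch of one epoch."""

    def run(self, batches: List[int]) -> int:
        total = 0
        for batch in batches:
            total += batch  # stand-in for a real training step
        return total


class EpochLoop:
    """Hypothetical parent loop that delegates each epoch to its child."""

    def __init__(self, max_epochs: int) -> None:
        self.max_epochs = max_epochs
        self.batch_loop: Optional[BatchLoop] = None

    def connect(self, batch_loop: BatchLoop) -> None:
        # Attach the child; each loop has one parent, so the links form a tree.
        self.batch_loop = batch_loop

    def run(self, batches: List[int]) -> List[int]:
        if self.batch_loop is None:
            raise RuntimeError("connect() a batch loop before calling run()")
        return [self.batch_loop.run(batches) for _ in range(self.max_epochs)]
```

For example, `EpochLoop(3)` connected to a `BatchLoop()` returns `[6, 6, 6]` for the batches `[1, 2, 3]`.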

load_state_dict(state_dict, prefix='', metrics=None)

Loads the state of this loop and all its children.

Return type:

None

on_advance_end()

Hook to be called each time after advance is called.

Return type:

None

on_advance_start(*args, **kwargs)

Hook to be called each time before advance is called.

Accepts all arguments passed to run.

Return type:

None

on_load_checkpoint(state_dict)

Called when loading a model checkpoint; use it to reload the loop state.

Return type:

None

on_run_end()

Hook to be called at the end of the run.

Its return value is returned from run.

Return type:

TypeVar(T)

on_run_start(*args, **kwargs)

Hook to be called as the first thing after entering run; only the state reset happens before it.

Accepts all arguments passed to run.

Return type:

None

on_save_checkpoint()

Called when saving a model checkpoint; use it to persist the loop state.

Return type:

Dict

Returns:

The current loop state.
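The two checkpoint hooks are meant to be inverses: whatever on_save_checkpoint() returns should be enough for on_load_checkpoint() to restore the loop. A minimal plain-Python sketch, assuming a hypothetical `ResumableLoop` whose only state is a step counter:

```python
from typing import Any, Dict


class ResumableLoop:
    """Hypothetical loop that persists how many steps it has completed."""

    def __init__(self) -> None:
        self.completed_steps = 0

    def on_save_checkpoint(self) -> Dict[str, Any]:
        # Return only plain, serializable values describing the loop state.
        return {"completed_steps": self.completed_steps}

    def on_load_checkpoint(self, state_dict: Dict[str, Any]) -> None:
        # Restore exactly what on_save_checkpoint produced.
        self.completed_steps = state_dict["completed_steps"]
```

Keeping the saved dictionary to plain values means it can be stored alongside the rest of the checkpoint without special serialization.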

on_skip()

Called in place of the rest of run() when the loop should be skipped, as determined by the skip property.

Return type:

TypeVar(T)

Returns:

The default output value of on_run_end().

replace(**loops)

Optionally replace one or more of this loop’s sub-loops.

This method takes care of instantiating the class (if necessary) with all existing arguments, connecting all sub-loops of the old loop to the new instance, setting the Trainer reference, and connecting the new loop to the parent.

Parameters:

**loops (Union[Loop, Type[Loop]]) – Loop subclasses or instances. The name used should match the loop attribute name you want to replace.

Raises:

MisconfigurationException – When passing a Loop class, if the __init__ arguments do not match those of the Loop class it replaces.

Return type:

None
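The replacement behaviour described above can be approximated as follows. This is a simplified, hypothetical sketch, not Lightning's code: `MisconfigurationException` is a local stand-in, the trainer reference is a placeholder string, re-instantiation assumes each `__init__` argument is stored under an attribute of the same name, and re-connecting the old loop's own sub-loops is omitted.

```python
import inspect
from typing import Type, Union


class MisconfigurationException(Exception):
    """Local stand-in for Lightning's exception of the same name."""


class ChildLoop:
    """Hypothetical sub-loop; its __init__ argument is kept as an attribute."""

    def __init__(self, max_steps: int = 10) -> None:
        self.max_steps = max_steps
        self.trainer = None


class FastChildLoop(ChildLoop):
    """Hypothetical drop-in replacement sharing ChildLoop's __init__ signature."""


class ParentLoop:
    def __init__(self) -> None:
        self.trainer = "trainer-placeholder"  # stand-in for a Trainer reference
        self.child = ChildLoop(max_steps=3)

    def replace(self, **loops: Union[ChildLoop, Type[ChildLoop]]) -> None:
        for name, new in loops.items():
            old = getattr(self, name)  # name must match an existing attribute
            if isinstance(new, type):
                old_params = list(inspect.signature(type(old).__init__).parameters)
                new_params = list(inspect.signature(new.__init__).parameters)
                if old_params != new_params:
                    raise MisconfigurationException(
                        f"{new.__name__}.__init__{new_params} does not match {old_params}"
                    )
                # Re-create the loop with the old instance's argument values
                # (skipping `self`), assuming attributes mirror argument names.
                new = new(**{p: getattr(old, p) for p in old_params[1:]})
            new.trainer = self.trainer  # hand over the trainer reference
            setattr(self, name, new)
```

Passing a class reuses the existing loop's configuration (`parent.replace(child=FastChildLoop)` keeps `max_steps=3`), while passing an instance keeps that instance's own arguments.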

abstract reset()

Resets the internal state of the loop at the beginning of each call to run.

Example:

def reset(self):
    # reset your internal state or add custom logic
    # if you expect run() to be called multiple times
    self.current_iteration = 0
    self.outputs = []
Return type:

None
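Because reset() runs at the start of every run(), the same loop instance can be run several times without state leaking between runs. A plain-Python sketch with a hypothetical `SquaresLoop` that follows the reset example above (it does not import `pytorch_lightning`):

```python
from typing import List


class SquaresLoop:
    """Hypothetical loop that collects i**2 for i in range(n) on each run."""

    def __init__(self) -> None:
        self.n = 0
        self.current_iteration = 0
        self.outputs: List[int] = []

    @property
    def done(self) -> bool:
        return self.current_iteration >= self.n

    def reset(self) -> None:
        # Clear per-run state so earlier runs cannot leak into this one.
        self.current_iteration = 0
        self.outputs = []

    def on_run_start(self, n: int) -> None:
        self.n = n

    def advance(self, n: int) -> None:
        self.outputs.append(self.current_iteration ** 2)
        self.current_iteration += 1

    def run(self, n: int) -> List[int]:
        self.reset()
        self.on_run_start(n)
        while not self.done:
            self.advance(n)
        return self.outputs
```

On one instance, `run(3)` returns `[0, 1, 4]` and a following `run(2)` returns `[0, 1]`. Without the two assignments in reset(), the second call would find `current_iteration` already at 3, treat the loop as done immediately, and hand back the previous run's `[0, 1, 4]`.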

run(*args, **kwargs)

The main entry point to the loop.

Repeatedly checks the done condition and calls advance until done evaluates to True.

Override this if you wish to change the default behavior. The default implementation is:

Example:

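def run(self, *args, **kwargs):
    if self.skip:
        return self.on_skip()

    self.reset()
    self.on_run_start(*args, **kwargs)

    while not self.done:
        try:
            self.on_advance_start(*args, **kwargs)
            self.advance(*args, **kwargs)
            self.on_advance_end()
        except StopIteration:
            # an exhausted data iterator inside advance() ends the loop
            break

    output = self.on_run_end()
    return output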
Return type:

TypeVar(T)

Returns:

The output of on_run_end (often outputs collected from each step of the loop).
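When advance() pulls batches from an iterator, the end of the data can be signalled by `StopIteration` rather than by done. The sketch below assumes a run() that catches `StopIteration` around each step and stops looping, an assumption modelled on the fuller run() in some Lightning releases; `DrainLoop` is hypothetical and nothing from `pytorch_lightning` is imported.

```python
from typing import Iterator, List


class DrainLoop:
    """Hypothetical loop that consumes an iterator until it is exhausted."""

    def __init__(self) -> None:
        self.outputs: List[int] = []

    @property
    def done(self) -> bool:
        return False  # termination comes from StopIteration instead

    def reset(self) -> None:
        self.outputs = []

    def advance(self, iterator: Iterator[int]) -> None:
        batch = next(iterator)  # raises StopIteration once the data runs out
        self.outputs.append(batch * 10)

    def on_run_end(self) -> List[int]:
        return self.outputs

    def run(self, iterator: Iterator[int]) -> List[int]:
        self.reset()
        while not self.done:
            try:
                self.advance(iterator)
            except StopIteration:
                break
        return self.on_run_end()
```

`DrainLoop().run(iter([1, 2, 3]))` returns `[10, 20, 30]`, the outputs collected from each step, and an empty iterator yields `[]`.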

state_dict(destination=None, prefix='')

The state dict is determined by the state and progress of this loop and all its children.

Parameters:
  • destination (Optional[Dict]) – An existing dictionary to update with this loop’s state. By default a new dictionary is returned.

  • prefix (str) – A prefix for each key in the state dictionary.

Return type:

Dict
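The prefix parameter is what lets a parent collect its children's state into one flat dictionary. A minimal plain-Python sketch of that recursion, using a hypothetical `StatefulLoop` whose only state is a `progress` counter; the real load_state_dict() also accepts a `metrics` argument, which is omitted here.

```python
from typing import Any, Dict, Optional


class StatefulLoop:
    """Hypothetical loop whose state dict includes its children's state."""

    def __init__(self, **children: "StatefulLoop") -> None:
        self.children = children
        self.progress = 0

    def state_dict(
        self, destination: Optional[Dict[str, Any]] = None, prefix: str = ""
    ) -> Dict[str, Any]:
        if destination is None:
            destination = {}  # by default a new dictionary is returned
        destination[prefix + "progress"] = self.progress
        for name, child in self.children.items():
            # Children write into the same dict under a dotted prefix.
            child.state_dict(destination, prefix + name + ".")
        return destination

    def load_state_dict(self, state_dict: Dict[str, Any], prefix: str = "") -> None:
        self.progress = state_dict[prefix + "progress"]
        for name, child in self.children.items():
            child.load_state_dict(state_dict, prefix + name + ".")
```

For a three-level tree, state_dict() yields the keys `progress`, `epoch_loop.progress`, and `epoch_loop.batch_loop.progress`, and load_state_dict() on an identically shaped tree restores all three values.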

teardown()

Use it to release memory and any other resources held by the loop.

Return type:

None
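A hypothetical sketch of what teardown() might release: references to a data iterator and to collected outputs, so they can be garbage-collected once the loop is finished. `CachingLoop` and its attributes are illustrative, not part of Lightning.

```python
from typing import Iterator, List, Optional


class CachingLoop:
    """Hypothetical loop that holds on to an iterator and collected outputs."""

    def __init__(self, data: List[int]) -> None:
        self._iterator: Optional[Iterator[int]] = iter(data)
        self.outputs: List[int] = []

    def teardown(self) -> None:
        # Drop references so the iterator and outputs can be garbage-collected.
        self._iterator = None
        self.outputs = []
```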

abstract property done: bool

Property indicating when the loop is finished.

Example:

@property
def done(self):
    return self.trainer.global_step >= self.trainer.max_steps

property restarting: bool

Whether the state of this loop was reloaded and it needs to restart.
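One way a loop can use restarting is to keep reloaded progress instead of zeroing it in reset(). The sketch below is hypothetical: `RestartableLoop` is illustrative, restarting is a plain attribute set by a simplified load_state_dict(), and it is cleared once run() completes.

```python
from typing import Any, Dict


class RestartableLoop:
    """Hypothetical loop that resumes mid-run after its state is reloaded."""

    def __init__(self, n: int) -> None:
        self.n = n
        self.step = 0
        self.restarting = False

    @property
    def done(self) -> bool:
        return self.step >= self.n

    def reset(self) -> None:
        # Keep reloaded progress on a restart; otherwise start from scratch.
        if not self.restarting:
            self.step = 0

    def advance(self) -> None:
        self.step += 1

    def run(self) -> int:
        self.reset()
        steps_this_run = 0
        while not self.done:
            self.advance()
            steps_this_run += 1
        self.restarting = False  # the restart has been consumed
        return steps_this_run

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.step = state_dict["step"]
        self.restarting = True
```

After `load_state_dict({"step": 3})` on `RestartableLoop(5)`, run() performs only the remaining 2 steps; a second run() starts from scratch and performs 5.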

property skip: bool

Determine whether to return immediately from the call to run().

Example:

@property
def skip(self):
    return len(self.trainer.train_dataloader) == 0
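A plain-Python sketch of skip and on_skip() working together: when there are no batches, run() returns on_skip()'s value, which has the same type as on_run_end()'s output so callers need no special case. `EvalLoop` is hypothetical, and its skip condition mirrors the dataloader-length check above using a plain list instead of a trainer.

```python
from typing import List


class EvalLoop:
    """Hypothetical loop that is skipped when there is nothing to evaluate."""

    def __init__(self, batches: List[float]) -> None:
        self.batches = batches
        self.index = 0
        self.outputs: List[float] = []

    @property
    def skip(self) -> bool:
        return len(self.batches) == 0

    @property
    def done(self) -> bool:
        return self.index >= len(self.batches)

    def reset(self) -> None:
        self.index = 0
        self.outputs = []

    def advance(self) -> None:
        self.outputs.append(self.batches[self.index] * 2)
        self.index += 1

    def on_run_end(self) -> List[float]:
        return self.outputs

    def on_skip(self) -> List[float]:
        # Default output with the same type as on_run_end's result.
        return []

    def run(self) -> List[float]:
        if self.skip:
            return self.on_skip()
        self.reset()
        while not self.done:
            self.advance()
        return self.on_run_end()
```

`EvalLoop([]).run()` returns `[]` without calling reset() or advance(), while `EvalLoop([1.0, 2.5]).run()` returns `[2.0, 5.0]`.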