
Loop

class pytorch_lightning.loops.base.Loop

Bases: abc.ABC, Generic[pytorch_lightning.loops.base.T]

Basic loop interface. All classes derived from it must implement the following properties and methods:

  • done (property): Condition to break the loop

  • reset (method): Resets the internal state between multiple calls of run

  • advance (method): Implements one step of the loop

This class implements the following loop structure:

on_run_start()

while not done:
    on_advance_start()
    advance()
    on_advance_end()

on_run_end()
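
The structure above can be sketched in plain Python to make the hook order concrete. MiniLoop and TracingLoop are hypothetical names. MiniLoop is a stand-in that mirrors only the hook order described here and is not the real pytorch_lightning.loops.base.Loop. TracingLoop records each hook call as it happens.

```python
from abc import ABC, abstractmethod
from typing import Any, List


class MiniLoop(ABC):
    """Hypothetical stand-in that only mirrors the hook order of Loop.run."""

    @property
    @abstractmethod
    def done(self) -> bool:
        """Condition to break the loop."""

    @abstractmethod
    def reset(self) -> None:
        """Reset internal state at the start of every run."""

    @abstractmethod
    def advance(self, *args: Any, **kwargs: Any) -> None:
        """Perform a single step."""

    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
        pass

    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
        pass

    def on_advance_end(self) -> None:
        pass

    def on_run_end(self) -> Any:
        return None

    def run(self, *args: Any, **kwargs: Any) -> Any:
        self.reset()
        self.on_run_start(*args, **kwargs)
        while not self.done:
            # run's arguments are forwarded to on_advance_start and advance
            self.on_advance_start(*args, **kwargs)
            self.advance(*args, **kwargs)
            self.on_advance_end()
        return self.on_run_end()


class TracingLoop(MiniLoop):
    """Takes n_steps steps and records every hook call in order."""

    def __init__(self, n_steps: int) -> None:
        self.n_steps = n_steps
        self.step = 0
        self.calls: List[str] = []

    @property
    def done(self) -> bool:
        return self.step >= self.n_steps

    def reset(self) -> None:
        self.step = 0
        self.calls = ["reset"]

    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
        self.calls.append("on_run_start")

    def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
        self.calls.append("on_advance_start")

    def advance(self, *args: Any, **kwargs: Any) -> None:
        self.calls.append("advance")
        self.step += 1

    def on_advance_end(self) -> None:
        self.calls.append("on_advance_end")

    def on_run_end(self) -> List[str]:
        self.calls.append("on_run_end")
        return self.calls


print(TracingLoop(1).run())
# ['reset', 'on_run_start', 'on_advance_start', 'advance', 'on_advance_end', 'on_run_end']
```

Because MiniLoop inherits from ABC, a subclass cannot be instantiated until it provides done, reset and advance. This matches the requirement listed above.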

abstract advance(*args, **kwargs)

Performs a single step.

Accepts all arguments passed to run.

Example:

def advance(self, iterator):
    # iterator is assumed to yield (batch_idx, batch) pairs, e.g. enumerate(dataloader)
    batch_idx, batch = next(iterator)
    loss = self.trainer.lightning_module.training_step(batch, batch_idx)
    ...

Return type

None
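
The following is a minimal sketch of an advance that consumes one batch per call, paired with an on_run_end whose return value becomes the output of run. training_step is a hypothetical stand-in for LightningModule.training_step. The loop is assumed to receive an iterator of (batch_idx, batch) pairs, such as enumerate(dataloader). num_batches must not exceed the number of available batches, because next() would then raise StopIteration.

```python
from typing import Iterator, List, Tuple

Batch = List[float]


def training_step(batch: Batch, batch_idx: int) -> float:
    """Hypothetical stand-in for LightningModule.training_step: mean of the batch."""
    return sum(batch) / len(batch)


class BatchLoop:
    """Sketch of a loop whose advance() consumes exactly one batch per call."""

    def __init__(self, num_batches: int) -> None:
        self.num_batches = num_batches
        self.losses: List[float] = []

    @property
    def done(self) -> bool:
        return len(self.losses) >= self.num_batches

    def reset(self) -> None:
        self.losses = []

    def advance(self, iterator: Iterator[Tuple[int, Batch]]) -> None:
        # The iterator yields (batch_idx, batch) pairs, so batch_idx is defined.
        batch_idx, batch = next(iterator)
        self.losses.append(training_step(batch, batch_idx))

    def on_run_end(self) -> List[float]:
        # Whatever on_run_end returns becomes the return value of run().
        return self.losses

    def run(self, iterator: Iterator[Tuple[int, Batch]]) -> List[float]:
        self.reset()
        while not self.done:
            # advance receives the same arguments that were passed to run
            self.advance(iterator)
        return self.on_run_end()


print(BatchLoop(num_batches=2).run(enumerate([[1.0, 3.0], [2.0, 4.0]])))  # [2.0, 3.0]
```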

connect(**kwargs)

Optionally connect one or multiple loops to this one.

Linked loops should form a tree.

Return type

None
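
The next sketch shows how connected loops can form a tree. For illustration only, it assumes that each keyword passed to connect names a child loop. TreeLoop and its paths helper are hypothetical. The child names epoch_loop, batch_loop and val_loop are illustrative, and concrete Lightning loops may define their own connect signature.

```python
from typing import Dict, Iterator, Optional


class TreeLoop:
    """Hypothetical loop that stores connected child loops by keyword name."""

    def __init__(self) -> None:
        self.parent: Optional["TreeLoop"] = None
        self.children: Dict[str, "TreeLoop"] = {}

    def connect(self, **kwargs: "TreeLoop") -> None:
        for name, child in kwargs.items():
            # Allowing only one parent per child keeps the links tree-shaped
            # (this sketch does not check for cycles).
            if child.parent is not None and child.parent is not self:
                raise ValueError(f"{name!r} is already connected to another loop")
            child.parent = self
            self.children[name] = child

    def paths(self, prefix: str = "") -> Iterator[str]:
        """Yield the dotted path of every descendant, depth first."""
        for name, child in self.children.items():
            yield prefix + name
            yield from child.paths(prefix + name + ".")


fit_loop, epoch_loop = TreeLoop(), TreeLoop()
epoch_loop.connect(batch_loop=TreeLoop(), val_loop=TreeLoop())
fit_loop.connect(epoch_loop=epoch_loop)
print(list(fit_loop.paths()))
# ['epoch_loop', 'epoch_loop.batch_loop', 'epoch_loop.val_loop']
```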

load_state_dict(state_dict, prefix='', metrics=None)

Loads the state of this loop and all its children.

Return type

None

on_advance_end()

Hook to be called each time after advance is called.

Return type

None

on_advance_start(*args, **kwargs)

Hook to be called each time before advance is called.

Accepts all arguments passed to run.

Return type

None

on_load_checkpoint(state_dict)

Called when loading a model checkpoint. Use it to reload the loop state.

Return type

None

on_run_end()

Hook to be called at the end of the run.

Its return value is returned from run.

Return type

T

on_run_start(*args, **kwargs)

Hook to be called as the first thing after entering run, apart from the state reset, which happens before it.

Accepts all arguments passed to run.

Return type

None

on_save_checkpoint()

Called when saving a model checkpoint. Use it to persist the loop state.

Return type

Dict

Returns

The current loop state.
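
Here is a minimal sketch of the on_save_checkpoint / on_load_checkpoint pair. It assumes the loop's only state is a step counter. CountingLoop and completed_steps are hypothetical names. The point of the sketch is that on_load_checkpoint should restore exactly what on_save_checkpoint returned.

```python
import json
from typing import Any, Dict


class CountingLoop:
    """Hypothetical loop whose only persistent state is a step counter."""

    def __init__(self) -> None:
        self.completed_steps = 0

    def on_save_checkpoint(self) -> Dict[str, Any]:
        # Return plain, serialisable values describing the current loop state.
        return {"completed_steps": self.completed_steps}

    def on_load_checkpoint(self, state_dict: Dict[str, Any]) -> None:
        # Restore exactly what on_save_checkpoint produced.
        self.completed_steps = state_dict["completed_steps"]


# Save from one loop and restore into a fresh one; the JSON round trip shows
# that the saved state is plain data.
loop = CountingLoop()
loop.completed_steps = 7
saved = json.loads(json.dumps(loop.on_save_checkpoint()))
restored = CountingLoop()
restored.on_load_checkpoint(saved)
print(restored.completed_steps)  # 7
```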

on_skip()

Called by run() instead of the normal loop when skip evaluates to True. Its return value is returned from run().

Return type

Optional[Any]

Returns

The default output value of on_run_end().

abstract reset()

Resets the internal state of the loop at the beginning of each call to run.

Example:

def reset(self):
    # reset your internal state or add custom logic
    # if you expect run() to be called multiple times
    self.current_iteration = 0
    self.outputs = []

Return type

None
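
reset runs at the beginning of every call to run because a single loop object may be run more than once. In this hypothetical sketch (ReusableLoop is not a Lightning class), skipping reset would leave current_iteration at its final value. A second run would then finish immediately without doing any work.

```python
from typing import List


class ReusableLoop:
    """Sketch: because reset() runs at the start of every run(), runs are repeatable."""

    def __init__(self, num_steps: int) -> None:
        self.num_steps = num_steps
        self.current_iteration = 0
        self.outputs: List[int] = []

    @property
    def done(self) -> bool:
        return self.current_iteration >= self.num_steps

    def reset(self) -> None:
        self.current_iteration = 0
        self.outputs = []

    def advance(self) -> None:
        self.outputs.append(self.current_iteration * 10)
        self.current_iteration += 1

    def run(self) -> List[int]:
        self.reset()  # without this, a second run() would see done == True at once
        while not self.done:
            self.advance()
        return list(self.outputs)


loop = ReusableLoop(num_steps=3)
print(loop.run())  # [0, 10, 20]
print(loop.run())  # [0, 10, 20] again, thanks to reset()
```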

run(*args, **kwargs)

The main entry point to the loop.

Repeatedly checks the done condition and calls advance until done evaluates to True.

Override this if you wish to change the default behavior. The default implementation is roughly:

Example:

def run(self, *args, **kwargs):
    if self.skip:
        return self.on_skip()

    self.reset()
    self.on_run_start(*args, **kwargs)

    while not self.done:
        self.on_advance_start(*args, **kwargs)
        self.advance(*args, **kwargs)
        self.on_advance_end()

    output = self.on_run_end()
    return output

Return type

T

Returns

The output of on_run_end (often outputs collected from each step of the loop)

state_dict(destination=None, prefix='')

The state dict is determined by the state and progress of this loop and all its children.

Parameters
  • destination (Optional[Dict]) – An existing dictionary to update with this loop’s state. By default a new dictionary is returned.

  • prefix (Optional[str]) – A prefix for each key in the state dictionary.

Return type

Dict
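
The sketch below builds a state dict recursively from a loop and its children, namespacing each child's keys with a dotted prefix. StatefulLoop is hypothetical. It keeps its children in an explicit dictionary and stores its own state under the key "<prefix>state_dict". That key layout is an assumption, and the sketch's load_state_dict omits the metrics argument of the real method.

```python
from typing import Any, Dict, Optional


class StatefulLoop:
    """Hypothetical loop with its own counter and named child loops."""

    def __init__(self, **children: "StatefulLoop") -> None:
        self.counter = 0
        self.children: Dict[str, "StatefulLoop"] = dict(children)

    def on_save_checkpoint(self) -> Dict[str, Any]:
        return {"counter": self.counter}

    def on_load_checkpoint(self, state_dict: Dict[str, Any]) -> None:
        self.counter = state_dict["counter"]

    def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Dict:
        if destination is None:
            destination = {}  # a fresh dict unless the caller passes one to update
        destination[prefix + "state_dict"] = self.on_save_checkpoint()
        for name, child in self.children.items():
            # children write into the same flat dict under "<prefix><name>."
            child.state_dict(destination, prefix + name + ".")
        return destination

    def load_state_dict(self, state_dict: Dict, prefix: str = "") -> None:
        self.on_load_checkpoint(state_dict[prefix + "state_dict"])
        for name, child in self.children.items():
            child.load_state_dict(state_dict, prefix + name + ".")


loop = StatefulLoop(epoch_loop=StatefulLoop(batch_loop=StatefulLoop()))
loop.counter = 1
loop.children["epoch_loop"].counter = 2
loop.children["epoch_loop"].children["batch_loop"].counter = 3
print(loop.state_dict())
# {'state_dict': {'counter': 1}, 'epoch_loop.state_dict': {'counter': 2},
#  'epoch_loop.batch_loop.state_dict': {'counter': 3}}
```

Loading walks the same tree with the same prefixes, so a freshly constructed loop with an identical structure can be restored from this flat dictionary.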

teardown()

Use it to release memory and other resources held by the loop.

Return type

None

abstract property done: bool

Property indicating when the loop is finished.

Example:

@property
def done(self):
    return self.trainer.global_step >= self.trainer.max_steps

Return type

bool

property skip: bool

Determine whether to return immediately from the call to run().

Example:

@property
def skip(self):
    return len(self.trainer.train_dataloader) == 0

Return type

bool
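
This sketch shows skip and on_skip working together, following the default run shown earlier. When skip is True, run returns on_skip() straight away and no other hook is called, including reset. EvalLoop and its doubling advance step are hypothetical.

```python
from typing import List, Sequence


class EvalLoop:
    """Sketch of skip/on_skip: an empty dataloader short-circuits run()."""

    def __init__(self, dataloader: Sequence[float]) -> None:
        self.dataloader = dataloader
        self.batch_idx = 0
        self.outputs: List[float] = []

    @property
    def skip(self) -> bool:
        return len(self.dataloader) == 0

    @property
    def done(self) -> bool:
        return self.batch_idx >= len(self.dataloader)

    def on_skip(self) -> List[float]:
        # Return the same "empty" value that on_run_end would give for no data.
        return []

    def reset(self) -> None:
        self.batch_idx = 0
        self.outputs = []

    def advance(self) -> None:
        self.outputs.append(self.dataloader[self.batch_idx] * 2)
        self.batch_idx += 1

    def on_run_end(self) -> List[float]:
        return self.outputs

    def run(self) -> List[float]:
        if self.skip:
            return self.on_skip()  # reset() and every other hook are bypassed
        self.reset()
        while not self.done:
            self.advance()
        return self.on_run_end()


print(EvalLoop([]).run())  # []
print(EvalLoop([1.0, 2.5]).run())  # [2.0, 5.0]
```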