Strategy

class pytorch_lightning.strategies.Strategy(accelerator=None, checkpoint_io=None, precision_plugin=None)[source]

Bases: abc.ABC

Base class for all strategies that change the behaviour of the training, validation and test loops.
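
Example (a minimal, hedged sketch of a custom single-process strategy that implements only the abstract methods and properties documented on this page; the class itself is illustrative, not a built-in strategy):

    import torch
    from pytorch_lightning.strategies import Strategy

    class SingleProcessStrategy(Strategy):
        def __init__(self, device="cpu", **kwargs):
            super().__init__(**kwargs)
            self._device = torch.device(device)

        @property
        def root_device(self):
            return self._device

        @property
        def is_global_zero(self):
            return True   # single process, so always the global rank-zero process

        def model_to_device(self):
            # the wrapped model is available through the ``model`` property
            self.model.to(self.root_device)

        def all_gather(self, tensor, group=None, sync_grads=False):
            return tensor   # nothing to gather from other processes

        def barrier(self, name=None):
            pass            # no other processes to wait for

        def broadcast(self, obj, src=0):
            return obj

        def reduce(self, tensor, group=None, reduce_op="mean"):
            return tensor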

abstract all_gather(tensor, group=None, sync_grads=False)[source]

Perform an all_gather on all processes.

Parameters:
  • tensor (Tensor) – the tensor to all_gather

  • group (Optional[Any]) – the process group to gather results from

  • sync_grads (bool) – flag that allows users to synchronize gradients for all_gather op

Return type:

Tensor
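
Example (a hedged sketch from inside a LightningModule hook; the trainer gives access to the active strategy, and the gathered value here is illustrative):

    import torch
    from pytorch_lightning import LightningModule

    class GatherExample(LightningModule):
        # the usual training_step/configure_optimizers are assumed to be defined elsewhere
        def on_validation_epoch_end(self):
            local = torch.tensor(1.0, device=self.device)        # one value per process
            gathered = self.trainer.strategy.all_gather(local)   # stacked across all processes
            if self.trainer.is_global_zero:
                print(gathered.sum().item())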

backward(closure_loss, optimizer, optimizer_idx, *args, **kwargs)[source]

Forwards backward-calls to the precision plugin.

Parameters:
  • closure_loss (Tensor) – a tensor holding the loss value to backpropagate

  • optimizer (Optional[Optimizer]) – An optional optimizer that gets passed down to the precision plugin’s backward

  • optimizer_idx (Optional[int]) – An optional optimizer index that gets passed down to the precision plugin’s backward

  • *args (Any) – Positional arguments that get passed down to the precision plugin’s backward, intended as arguments for the actual function that performs the backward, like backward().

  • **kwargs (Any) – Keyword arguments for the same purpose as *args.

Return type:

Tensor

abstract barrier(name=None)[source]

Synchronizes all processes, blocking each process until the whole group has entered this function.

Parameters:

name (Optional[str]) – an optional name to pass into barrier.

Return type:

None
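
Example (a hedged sketch: only rank zero prepares data and the other processes wait at the barrier; download_dataset is a hypothetical helper):

    def prepare_data_once(trainer, data_dir):
        if trainer.is_global_zero:
            download_dataset(data_dir)   # hypothetical helper that writes files to data_dir
        # block every process until rank zero has finished writing
        trainer.strategy.barrier(name="download_dataset")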

batch_to_device(batch, device=None, dataloader_idx=0)[source]

Moves the batch to the correct device.

The returned batch is of the same type as the input batch, just having all tensors on the correct device.

Parameters:
  • batch (Any) – The batch of samples to move to the correct device

  • device (Optional[device]) – The target device

  • dataloader_idx (int) – The index of the dataloader to which the batch belongs.

Return type:

Any
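
Example (a hedged sketch of moving a batch manually, e.g. when calling the model outside the regular loops; trainer and dataloader are assumed to exist already):

    batch = next(iter(dataloader))
    batch = trainer.strategy.batch_to_device(batch, device=trainer.strategy.root_device)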

abstract broadcast(obj, src=0)[source]

Broadcasts an object to all processes.

Parameters:
  • obj (TypeVar(TBroadcast)) – the object to broadcast

  • src (int) – source rank

Return type:

TypeVar(TBroadcast)
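
Example (a hedged sketch: decide a value on rank zero and share it with every process; create_run_dir is a hypothetical helper and trainer is assumed to exist):

    run_dir = create_run_dir() if trainer.is_global_zero else None
    run_dir = trainer.strategy.broadcast(run_dir, src=0)   # every process now has the same path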

connect(model)[source]

Called by the accelerator to connect the accelerator and the model with this plugin.

Return type:

None

dispatch(trainer)[source]

Hook to do something before the training/evaluation/prediction starts.

Return type:

None

lightning_module_state_dict()[source]

Returns model state.

Return type:

Dict[str, Any]

model_sharded_context()[source]

Provides a hook to create modules in a distributed-aware context. This allows the model to be sharded as soon as it is instantiated, which can save memory and initialization time for extremely large models.

Returns: Model parallel context.

Return type:

Generator
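
Example (a hedged sketch; whether sharded instantiation actually happens depends on the strategy in use, and MyHugeModule is a hypothetical LightningModule):

    with trainer.strategy.model_sharded_context():
        model = MyHugeModule()   # parameters can be created directly on the appropriate shards/devices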

abstract model_to_device()[source]

Moves the model to the correct device.

Return type:

None

on_predict_end()[source]

Called when predict ends.

Return type:

None

on_predict_start()[source]

Called when predict begins.

Return type:

None

on_test_end()[source]

Called when the test ends.

Return type:

None

on_test_start()[source]

Called when test begins.

Return type:

None

on_train_batch_start(batch, batch_idx)[source]

Called in the training loop before anything happens for that batch.

Return type:

None

on_train_end()[source]

Called when train ends.

Return type:

None

on_train_start()[source]

Called when train begins.

Return type:

None

on_validation_end()[source]

Called when validation ends.

Return type:

None

on_validation_start()[source]

Called when validation begins.

Return type:

None

optimizer_state(optimizer)[source]

Returns state of an optimizer.

Allows for syncing/collating optimizer state from processes in custom plugins.

Return type:

Dict[str, Tensor]

optimizer_step(optimizer, opt_idx, closure, model=None, **kwargs)[source]

Performs the actual optimizer step.

Parameters:
  • optimizer (Optimizer) – the optimizer performing the step

  • opt_idx (int) – index of the current optimizer

  • closure (Callable[[], Any]) – closure calculating the loss value

  • model (Union[LightningModule, Module, None]) – reference to the model, optionally defining optimizer step related hooks

  • **kwargs (Any) – Keyword arguments forwarded to optimizer.step

Return type:

Any
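
Example (a hedged sketch of wrapping the step in a custom strategy; DDPStrategy is used as a base purely for illustration, and the super() call performs the actual step through the precision plugin):

    from pytorch_lightning.strategies import DDPStrategy
    from pytorch_lightning.utilities import rank_zero_info

    class LoggingStrategy(DDPStrategy):
        def optimizer_step(self, optimizer, opt_idx, closure, model=None, **kwargs):
            out = super().optimizer_step(optimizer, opt_idx, closure, model=model, **kwargs)
            rank_zero_info(f"optimizer {opt_idx} stepped")
            return out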

post_backward(closure_loss)[source]

Run after precision plugin executes backward.

Return type:

None

pre_backward(closure_loss)[source]

Run before precision plugin executes backward.

Return type:

None

predict_step(*args, **kwargs)[source]

The actual predict step.

See predict_step() for more details.

Return type:

Union[Tensor, Dict[str, Any]]

process_dataloader(dataloader)[source]

Wraps the dataloader if necessary.

Parameters:

dataloader (DataLoader) – the iterable to wrap, ideally of type torch.utils.data.DataLoader

Return type:

DataLoader
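
Example (a hedged sketch of a pass-through override in a custom strategy; DDPStrategy is used as a base purely for illustration):

    from pytorch_lightning.strategies import DDPStrategy

    class InspectingStrategy(DDPStrategy):
        def process_dataloader(self, dataloader):
            if self.is_global_zero:
                print(f"about to iterate a dataloader with {len(dataloader)} batches")
            return dataloader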

abstract reduce(tensor, group=None, reduce_op='mean')[source]

Reduces the given tensor (e.g. across GPUs/processes).

Parameters:
  • tensor (Union[Tensor, Any]) – the tensor to sync and reduce

  • group (Optional[Any]) – the process group to reduce

  • reduce_op (Union[ReduceOp, str, None]) – the reduction operation. Defaults to ‘mean’. Can also be a string ‘sum’ or ReduceOp.

Return type:

Union[Tensor, Any]
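
Example (a hedged sketch from a LightningModule hook; the attribute holding the local value is illustrative):

    import torch
    from pytorch_lightning import LightningModule

    class ReduceExample(LightningModule):
        def on_train_epoch_end(self):
            local_loss = torch.tensor(self.last_loss, device=self.device)           # hypothetical attribute
            mean_loss = self.trainer.strategy.reduce(local_loss, reduce_op="mean")  # averaged across processes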

reduce_boolean_decision(decision, all=True)[source]

Reduce a boolean decision across all processes.

Return type:

bool

remove_checkpoint(filepath)[source]

Remove checkpoint filepath from the filesystem.

Parameters:

filepath (Union[str, Path]) – Path to checkpoint

Return type:

None

save_checkpoint(checkpoint, filepath, storage_options=None)[source]

Save model/training states as a checkpoint file through state-dump and file-write.

Parameters:
  • checkpoint (Dict[str, Any]) – dict containing model and trainer state

  • filepath (Union[str, Path]) – write-target file’s path

  • storage_options (Optional[Any]) – parameter for how to save to storage, passed to CheckpointIO plugin

Return type:

None
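
Example (a hedged sketch of filtering the checkpoint dict in a custom strategy before delegating to the parent, which writes the file through the configured CheckpointIO plugin; DDPStrategy and the dropped key are illustrative):

    from pytorch_lightning.strategies import DDPStrategy

    class SlimCheckpointStrategy(DDPStrategy):
        def save_checkpoint(self, checkpoint, filepath, storage_options=None):
            checkpoint = {k: v for k, v in checkpoint.items() if k != "big_debug_blob"}
            super().save_checkpoint(checkpoint, filepath, storage_options=storage_options)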

setup(trainer)[source]

Sets up plugins for trainer fit and creates optimizers.

Parameters:

trainer (Trainer) – the trainer instance

Return type:

None

setup_environment()[source]

Sets up any processes or distributed connections.

This is called before the LightningModule/DataModule setup hook, which allows the user to access the accelerator environment before setup is complete.

Return type:

None

setup_optimizers(trainer)[source]

Creates optimizers and schedulers.

Parameters:

trainer (Trainer) – the Trainer these optimizers should be connected to

Return type:

None

setup_precision_plugin()[source]

Attaches the precision plugin to the accelerator.

Return type:

None

teardown()[source]

This method is called to tear down the training process.

It is the right place to release memory and free other resources.

Return type:

None

test_step(*args, **kwargs)[source]

The actual test step.

See test_step() for more details.

Return type:

Union[Tensor, Dict[str, Any], None]

training_step(*args, **kwargs)[source]

The actual training step.

See training_step() for more details.

Return type:

Union[Tensor, Dict[str, Any]]

validation_step(*args, **kwargs)[source]

The actual validation step.

See validation_step() for more details.

Return type:

Union[Tensor, Dict[str, Any], None]

property handles_gradient_accumulation: bool

Whether the plugin handles gradient accumulation internally.

abstract property is_global_zero: bool

Whether the current process is the rank zero process not only on the local node, but for all nodes.

property lightning_module: Optional[pytorch_lightning.core.module.LightningModule]

Returns the pure LightningModule without potential wrappers.

property lightning_restore_optimizer: bool

Override to disable Lightning restoring optimizers/schedulers.

This is useful for plugins which manage restoring optimizers/schedulers.

property model: Optional[torch.nn.modules.module.Module]

Returns the potentially wrapped LightningModule.

property restore_checkpoint_after_setup: bool

Override to delay restoring from checkpoint until after pre-dispatch. This is useful when the plugin requires all the setup hooks to run before loading the checkpoint.

Returns:

If True, restore the checkpoint after pre_dispatch.
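
Example (a hedged sketch of overriding the property in a custom strategy; DDPStrategy is used as a base purely for illustration):

    from pytorch_lightning.strategies import DDPStrategy

    class LateRestoreStrategy(DDPStrategy):
        @property
        def restore_checkpoint_after_setup(self) -> bool:
            return True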

abstract property root_device: torch.device

Returns the root device.