
Strategy

class pytorch_lightning.strategies.Strategy(accelerator=None, checkpoint_io=None, precision_plugin=None)

Bases: abc.ABC

Base class for all strategies that change the behaviour of the training, validation and test loops.
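Because the base class is an abc.ABC, a concrete strategy has to implement every abstract member (all_gather, barrier, broadcast, reduce, model_to_device, is_global_zero, root_device) before it can be instantiated. The sketch below illustrates that contract in plain Python. MiniStrategy is a hypothetical, heavily reduced stand-in that mirrors only a few of those signatures, and it neither imports nor subclasses the real pytorch_lightning class. When there is only one process, every collective degenerates to an identity operation.

```python
from abc import ABC, abstractmethod
from typing import Any, Optional


class MiniStrategy(ABC):
    """Hypothetical, reduced stand-in for the Strategy contract (not the real API)."""

    @abstractmethod
    def all_gather(self, tensor: Any, group: Optional[Any] = None, sync_grads: bool = False) -> Any:
        ...

    @abstractmethod
    def barrier(self, name: Optional[str] = None) -> None:
        ...

    @abstractmethod
    def broadcast(self, obj: Any, src: int = 0) -> Any:
        ...

    @property
    @abstractmethod
    def is_global_zero(self) -> bool:
        ...


class SingleProcessStrategy(MiniStrategy):
    """With a single process, every collective is an identity operation."""

    def all_gather(self, tensor: Any, group: Optional[Any] = None, sync_grads: bool = False) -> Any:
        return tensor  # nothing to gather, so the input comes back unchanged

    def barrier(self, name: Optional[str] = None) -> None:
        return None  # there is no other process to wait for

    def broadcast(self, obj: Any, src: int = 0) -> Any:
        return obj  # the only process is already the source

    @property
    def is_global_zero(self) -> bool:
        return True  # the single process is rank zero on every node
```

Trying to instantiate MiniStrategy directly raises a TypeError, which is how the ABC enforces that subclasses fill in the abstract members.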

abstract all_gather(tensor, group=None, sync_grads=False)

Perform an all_gather on all processes.

Parameters
  • tensor (Tensor) – the tensor to all_gather

  • group (Optional[Any]) – the process group to gather results from

  • sync_grads (bool) – flag that allows users to synchronize gradients for the all_gather operation

Return type

Tensor
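Conceptually, all_gather leaves every process holding the contributions of all processes, ordered by rank. The plain-Python sketch below simulates ranks as list entries instead of running real processes. simulated_all_gather is a hypothetical helper, and the group and sync_grads arguments are not modeled. With real tensors, the gathered result typically gains a leading dimension of size world_size.

```python
from typing import List, Sequence


def simulated_all_gather(per_rank_values: Sequence[List[float]]) -> List[List[List[float]]]:
    """Return what each rank would hold after an all_gather.

    per_rank_values[r] is the 'tensor' contributed by rank r. After the
    collective, every rank holds all contributions, ordered by rank.
    """
    gathered = [list(v) for v in per_rank_values]
    # Each rank receives its own copy of the same gathered result.
    return [[list(v) for v in gathered] for _ in per_rank_values]


result = simulated_all_gather([[1.0], [2.0], [3.0]])
```

Every entry of result (one per simulated rank) is the same list [[1.0], [2.0], [3.0]].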

backward(closure_loss, *args, **kwargs)

Forwards backward calls to the precision plugin.

Parameters

closure_loss (Tensor) – a tensor holding the loss value to backpropagate

Return type

Tensor
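The pre_backward and post_backward hooks documented on this class bracket the call that is forwarded to the precision plugin. The sketch below shows that ordering with hypothetical SketchStrategy and RecordingPrecisionPlugin classes that only log their calls. The real implementation passes more context to the precision plugin and may invoke additional hooks, and both of those are omitted here.

```python
from typing import Any, List


class RecordingPrecisionPlugin:
    """Hypothetical precision plugin that only records what it was asked to do."""

    def __init__(self, log: List[str]) -> None:
        self.log = log

    def backward(self, loss: Any) -> None:
        self.log.append(f"precision.backward({loss})")


class SketchStrategy:
    """Hypothetical strategy mirroring the pre_backward -> backward -> post_backward order."""

    def __init__(self) -> None:
        self.log: List[str] = []
        self.precision_plugin = RecordingPrecisionPlugin(self.log)

    def pre_backward(self, closure_loss: Any) -> None:
        self.log.append("strategy.pre_backward")

    def post_backward(self, closure_loss: Any) -> None:
        self.log.append("strategy.post_backward")

    def backward(self, closure_loss: Any) -> Any:
        self.pre_backward(closure_loss)
        self.precision_plugin.backward(closure_loss)  # the actual work is delegated
        self.post_backward(closure_loss)
        return closure_loss


strategy = SketchStrategy()
returned = strategy.backward(0.5)
```

After the call, strategy.log reads pre_backward, then the precision plugin's backward, then post_backward.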

abstract barrier(name=None)

Synchronizes all processes, blocking each process until the whole group has entered this function.

Parameters

name (Optional[str]) – an optional name to pass into barrier.

Return type

None
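A barrier is a rendezvous point: no participant continues past it until every participant has reached it. The sketch below demonstrates these semantics with Python's threading.Barrier, using threads as stand-ins for processes. Distributed strategies would typically rely on a process-group primitive such as torch.distributed.barrier instead.

```python
import threading
from typing import List

NUM_WORKERS = 3
barrier = threading.Barrier(NUM_WORKERS)
events: List[str] = []
lock = threading.Lock()  # protects the shared events list


def worker(rank: int) -> None:
    with lock:
        events.append(f"before-{rank}")
    barrier.wait()  # no worker passes this line until all workers have arrived
    with lock:
        events.append(f"after-{rank}")


threads = [threading.Thread(target=worker, args=(r,)) for r in range(NUM_WORKERS)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Whatever order the threads are scheduled in, all three "before" entries appear in events ahead of any "after" entry.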

batch_to_device(batch, device=None, dataloader_idx=0)

Moves the batch to the correct device.

The returned batch is of the same type as the input batch, just having all tensors on the correct device.

Parameters
  • batch (Any) – The batch of samples to move to the correct device

  • device (Optional[device]) – The target device

  • dataloader_idx (int) – The index of the dataloader to which the batch belongs.

Return type

Any
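Since the returned batch keeps the structure of the input, moving a batch amounts to walking the collection recursively and moving only the tensor leaves. The sketch below shows the idea in plain Python. FakeTensor and move_to_device are hypothetical stand-ins (with real tensors one would call Tensor.to(device) on each leaf). Lists, tuples and namedtuples keep their type, while dicts come back as plain dicts.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor: a value plus a device label."""
    value: Any
    device: str = "cpu"

    def to(self, device: str) -> "FakeTensor":
        return FakeTensor(self.value, device)  # returns a copy, like Tensor.to


def move_to_device(batch: Any, device: str) -> Any:
    """Recursively move every FakeTensor leaf, keeping the surrounding structure."""
    if isinstance(batch, FakeTensor):
        return batch.to(device)
    if isinstance(batch, dict):
        return {k: move_to_device(v, device) for k, v in batch.items()}
    if isinstance(batch, tuple) and hasattr(batch, "_fields"):  # namedtuple
        return type(batch)(*(move_to_device(v, device) for v in batch))
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_to_device(v, device) for v in batch)
    return batch  # non-tensor leaves (ints, strings, ...) are returned unchanged


batch = {"x": FakeTensor(1), "meta": ("id", [FakeTensor(2)])}
moved = move_to_device(batch, "cuda:0")
```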

abstract broadcast(obj, src=0)

Broadcasts an object to all processes.

Parameters
  • obj (~TBroadcast) – the object to broadcast

  • src (int) – source rank

Return type

~TBroadcast
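After a broadcast, every process holds a copy of the object that the src rank passed in, and whatever the other ranks passed is discarded. A typical use is letting one rank decide a value (for example, a randomly generated name) and sharing it with everyone. The sketch below simulates ranks as list entries; simulated_broadcast is a hypothetical helper, not part of the library.

```python
import copy
from typing import List, TypeVar

T = TypeVar("T")


def simulated_broadcast(per_rank_objects: List[T], src: int = 0) -> List[T]:
    """Return what each rank holds after broadcasting rank `src`'s object."""
    source_obj = per_rank_objects[src]
    # Every rank ends up with its own copy of the source rank's object;
    # the objects the other ranks passed in are discarded.
    return [copy.deepcopy(source_obj) for _ in per_rank_objects]


received = simulated_broadcast([{"seed": 1}, {"seed": 2}, None], src=1)
```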

connect(model)

Called by the accelerator to connect the accelerator and the model with this plugin.

Return type

None

dispatch(trainer)

Hook to do something before the training/evaluation/prediction starts.

Return type

None

lightning_module_state_dict()

Returns model state.

Return type

Dict[str, Union[Any, Tensor]]

model_sharded_context()

Provides a hook to create modules in a distributed-aware context. This is useful when we’d like to shard the model instantly, which, for extremely large models, can save memory and initialization time.

Returns: Model parallel context.

Return type

Generator
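The Generator return type suggests a generator-based context manager of the kind produced by contextlib.contextmanager. The sketch below assumes that shape: BaseSketch is a hypothetical base whose context does nothing, and ShardingSketch is a hypothetical override that wraps model construction. A real sharding strategy would enter its sharding library's initialization context around the yield; the event log here only stands in for that.

```python
import contextlib
from typing import Generator, List


class BaseSketch:
    """Hypothetical base: the default context adds no behaviour."""

    @contextlib.contextmanager
    def model_sharded_context(self) -> Generator[None, None, None]:
        yield


class ShardingSketch(BaseSketch):
    """Hypothetical override that brackets module creation with setup/cleanup."""

    def __init__(self) -> None:
        self.events: List[str] = []

    @contextlib.contextmanager
    def model_sharded_context(self) -> Generator[None, None, None]:
        self.events.append("enter")  # e.g. activate sharded initialization here
        try:
            yield  # modules built inside the `with` block run at this point
        finally:
            self.events.append("exit")  # always undone, even if construction fails


strategy = ShardingSketch()
with strategy.model_sharded_context():
    strategy.events.append("build model")
```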

abstract model_to_device()

Moves the model to the correct device.

Return type

None

on_predict_end()

Called when predict ends.

on_predict_start()

Called when predict begins.

Return type

None

on_test_end()

Called when test ends.

Return type

None

on_test_start()

Called when test begins.

Return type

None

on_train_batch_start(batch, batch_idx, dataloader_idx=0)

Called in the training loop before anything happens for that batch.

Return type

None

on_train_end()

Called when train ends.

Return type

None

on_train_start()

Called when train begins.

Return type

None

on_validation_end()

Called when validation ends.

Return type

None

on_validation_start()

Called when validation begins.

Return type

None

optimizer_state(optimizer)

Returns state of an optimizer.

Allows for syncing/collating optimizer state from processes in custom plugins.

Return type

Dict[str, Tensor]

optimizer_step(optimizer, opt_idx, closure, model=None, **kwargs)

Performs the actual optimizer step.

Parameters
  • optimizer (Optimizer) – the optimizer performing the step

  • opt_idx (int) – index of the current optimizer

  • closure (Callable[[], Any]) – closure calculating the loss value

  • model (Union[LightningModule, Module, None]) – reference to the model, optionally defining optimizer step related hooks

  • **kwargs – Any extra arguments to optimizer.step

Return type

Any
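The closure parameter follows the PyTorch convention in which optimizer.step(closure) calls the closure to recompute the loss (and gradients) before updating the parameters. The sketch below shows a strategy-style optimizer_step that simply forwards the closure. ScalarSGD is a hypothetical one-parameter optimizer, and the opt_idx, model and precision-related handling of the real method are omitted.

```python
from typing import Callable


class ScalarSGD:
    """Hypothetical one-parameter SGD optimizer whose step() takes a closure,
    following the closure convention of torch.optim optimizers."""

    def __init__(self, param: float, lr: float) -> None:
        self.param = param
        self.lr = lr
        self.grad = 0.0

    def step(self, closure: Callable[[], float]) -> float:
        loss = closure()  # the closure computes the loss and fills in self.grad
        self.param -= self.lr * self.grad
        return loss


def optimizer_step(optimizer: ScalarSGD, closure: Callable[[], float]) -> float:
    """Sketch of a strategy forwarding the step, closure included, to the optimizer."""
    return optimizer.step(closure=closure)


opt = ScalarSGD(param=3.0, lr=0.25)


def closure() -> float:
    # loss = param ** 2, so d(loss)/d(param) = 2 * param
    opt.grad = 2 * opt.param
    return opt.param ** 2


first_loss = optimizer_step(opt, closure)
```

Here the closure reports a loss of 9.0, and the update moves the parameter from 3.0 to 1.5 (3.0 - 0.25 * 6.0).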

post_backward(closure_loss)

Runs after the precision plugin executes backward.

Return type

None

post_dispatch(trainer)

Deprecated since version v1.6: This method has been deprecated in v1.6 and will be removed in v1.7. Use teardown() instead.

Hook to do something after the training/evaluation/prediction finishes.

Return type

None

pre_backward(closure_loss)

Runs before the precision plugin executes backward.

Return type

None

predict_step(*args, **kwargs)

The actual predict step.

See LightningModule.predict_step() for more details.

Return type

Union[Tensor, Dict[str, Any]]

process_dataloader(dataloader)

Wraps the dataloader if necessary.

Parameters

dataloader (DataLoader) – an iterable, ideally of type torch.utils.data.DataLoader

Return type

DataLoader

abstract reduce(tensor, group=None, reduce_op='mean')

Reduces the given tensor (e.g. across GPUs/processes).

Parameters
  • tensor (Union[Tensor, Any]) – the tensor to sync and reduce

  • group (Optional[Any]) – the process group to reduce

  • reduce_op (Union[ReduceOp, str, None]) – the reduction operation. Defaults to ‘mean’. Can also be the string ‘sum’ or a ReduceOp.

Return type

Union[Tensor, Any]
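Reducing combines one value per process into a single value that every process receives. The sketch below simulates this with one float per rank and the two string options named in the parameter description ('mean', the default, and 'sum'). simulated_reduce is hypothetical and handles neither ReduceOp objects nor process groups.

```python
from statistics import fmean
from typing import List


def simulated_reduce(per_rank_values: List[float], reduce_op: str = "mean") -> float:
    """Combine one scalar per rank into the single value every rank would receive."""
    if reduce_op == "mean":
        return fmean(per_rank_values)
    if reduce_op == "sum":
        return float(sum(per_rank_values))
    raise ValueError(f"unsupported reduce_op: {reduce_op!r}")


mean_loss = simulated_reduce([1.0, 2.0, 3.0])
total_loss = simulated_reduce([1.0, 2.0, 3.0], reduce_op="sum")
```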

reduce_boolean_decision(decision)

Reduce the early stopping decision across all processes.

Return type

bool
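Whatever the policy, every process has to reach the same decision: if one rank stops while the others continue, the others can hang waiting in the next collective. The sketch below assumes a unanimity policy built on a sum reduction (stop only if all ranks voted to stop). simulated_reduce_boolean_decision is hypothetical, and a particular strategy may combine the votes differently.

```python
from typing import List


def simulated_reduce_boolean_decision(per_rank_decisions: List[bool]) -> bool:
    """Sum the per-rank votes and stop only if every rank voted to stop."""
    votes = sum(int(d) for d in per_rank_decisions)  # what a SUM all-reduce would yield
    return votes == len(per_rank_decisions)


split_vote = simulated_reduce_boolean_decision([True, False, True])
unanimous = simulated_reduce_boolean_decision([True, True, True, True])
```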

remove_checkpoint(filepath)

Remove checkpoint filepath from the filesystem.

Parameters

filepath (Union[str, Path]) – Path to checkpoint

Return type

None

save_checkpoint(checkpoint, filepath, storage_options=None)

Save model/training states as a checkpoint file through state-dump and file-write.

Parameters
  • checkpoint (Dict[str, Any]) – dict containing model and trainer state

  • filepath (Union[str, Path]) – write-target file’s path

  • storage_options (Optional[Any]) – parameter for how to save to storage, passed to CheckpointIO plugin

Return type

None
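In data-parallel training every process usually holds identical state, so a common pattern, assumed here, is that only the global-zero process touches the filesystem. The sketch below applies that guard to both saving and removing. save_checkpoint_sketch and remove_checkpoint_sketch are hypothetical, and JSON stands in for the serialization the CheckpointIO plugin would actually perform.

```python
import json
import os
import tempfile
from pathlib import Path
from typing import Any, Dict, Union


def save_checkpoint_sketch(
    checkpoint: Dict[str, Any], filepath: Union[str, Path], is_global_zero: bool
) -> None:
    """Only the global-zero process writes, so ranks never race on one file."""
    if not is_global_zero:
        return
    with open(filepath, "w", encoding="utf-8") as f:
        json.dump(checkpoint, f)  # stand-in for the CheckpointIO plugin


def remove_checkpoint_sketch(filepath: Union[str, Path], is_global_zero: bool) -> None:
    """Deletion is restricted to the same single process."""
    if is_global_zero and os.path.exists(filepath):
        os.remove(filepath)


tmpdir = tempfile.mkdtemp()
path = os.path.join(tmpdir, "epoch=0.ckpt")
state = {"epoch": 0, "global_step": 10}

save_checkpoint_sketch(state, path, is_global_zero=False)  # non-zero rank: no-op
skipped = not os.path.exists(path)
save_checkpoint_sketch(state, path, is_global_zero=True)
with open(path, encoding="utf-8") as f:
    reloaded = json.load(f)
remove_checkpoint_sketch(path, is_global_zero=True)
```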

setup(trainer)

Sets up plugins for the trainer fit and creates optimizers.

Parameters

trainer (Trainer) – the trainer instance

Return type

None

setup_environment()

Sets up any processes or distributed connections.

This is called before the LightningModule/DataModule setup hook which allows the user to access the accelerator environment before setup is complete.

Return type

None

setup_optimizers(trainer)

Creates optimizers and schedulers.

Parameters

trainer (Trainer) – the Trainer these optimizers should be connected to

Return type

None

setup_precision_plugin()

Attaches the precision plugin to the accelerator.

Return type

None

teardown()

This method is called to tear down the training process.

It is the right place to release memory and free other resources.

Return type

None

test_step(*args, **kwargs)

The actual test step.

See LightningModule.test_step() for more details.

Return type

Union[Tensor, Dict[str, Any], None]

training_step(*args, **kwargs)

The actual training step.

See LightningModule.training_step() for more details.

Return type

Union[Tensor, Dict[str, Any]]

validation_step(*args, **kwargs)

The actual validation step.

See LightningModule.validation_step() for more details.

Return type

Union[Tensor, Dict[str, Any], None]

property handles_gradient_accumulation: bool

Whether the plugin handles gradient accumulation internally.

Return type

bool

abstract property is_global_zero: bool

Whether the current process is the rank zero process not only on the local node, but for all nodes.

Return type

bool

property lightning_module: Optional[pytorch_lightning.core.lightning.LightningModule]

Returns the pure LightningModule without potential wrappers.

Return type

Optional[LightningModule]

property lightning_restore_optimizer: bool

Override to disable Lightning restoring optimizers/schedulers.

This is useful for plugins which manage restoring optimizers/schedulers.

Return type

bool

property model: Optional[torch.nn.modules.module.Module]

Returns the potentially wrapped LightningModule.

Return type

Optional[Module]
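The difference between model and lightning_module is the wrapper layer. model may return the LightningModule wrapped, for example by a data-parallel wrapper such as torch.nn.parallel.DistributedDataParallel, which keeps the wrapped module in its .module attribute, while lightning_module returns the inner object. The sketch below shows the unwrapping idea with hypothetical ToyWrapper and ToyLightningModule classes; Lightning's own unwrapping logic is not reproduced here.

```python
from typing import Any


class ToyLightningModule:
    """Hypothetical stand-in for the user's LightningModule."""


class ToyWrapper:
    """Hypothetical wrapper that, like DistributedDataParallel,
    keeps the wrapped module in a `.module` attribute."""

    def __init__(self, module: Any) -> None:
        self.module = module


def unwrap(model: Any) -> Any:
    """Peel off wrapper layers until the inner module is reached."""
    while isinstance(model, ToyWrapper):
        model = model.module
    return model


inner = ToyLightningModule()
wrapped = ToyWrapper(ToyWrapper(inner))  # what `model` might return
```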

property restore_checkpoint_after_setup: bool

Override to delay restoring from the checkpoint until after pre-dispatch. This is useful when the plugin requires all the setup hooks to run before loading the checkpoint.

Return type

bool

Returns

If true, restore checkpoint after pre_dispatch.

abstract property root_device: torch.device

Returns the root device.

Return type

device