Accelerator

class pytorch_lightning.accelerators.Accelerator(precision_plugin, training_type_plugin)[source]

Bases: object

The Accelerator base class. An Accelerator is meant to deal with one type of hardware.

Currently there are accelerators for:

  • CPU

  • GPU

  • TPU

  • IPU

Each Accelerator gets two plugins upon initialization: one to handle differences in the training routine and one to handle different precisions.

Parameters
  • precision_plugin (PrecisionPlugin) – the plugin to handle precision-specific parts

  • training_type_plugin (TrainingTypePlugin) – the plugin to handle different training routines
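
For illustration, a minimal sketch of wiring these two plugins together, assuming the concrete CPUAccelerator subclass and the single-device plugins shipped with pytorch_lightning v1.5:

    import torch
    from pytorch_lightning.accelerators import CPUAccelerator
    from pytorch_lightning.plugins import PrecisionPlugin, SingleDevicePlugin

    # CPUAccelerator shares the base-class signature shown above.
    accelerator = CPUAccelerator(
        precision_plugin=PrecisionPlugin(),  # plain full-precision plugin
        training_type_plugin=SingleDevicePlugin(torch.device("cpu")),
    )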

all_gather(tensor, group=None, sync_grads=False)[source]

Function to gather a tensor from several distributed processes.

Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.all_gather directly.

Parameters
  • tensor (Tensor) – tensor of shape (batch, …)

  • group (Optional[Any]) – the process group to gather results from. Defaults to all processes (world)

  • sync_grads (bool) – flag that allows users to synchronize gradients for all_gather op

Return type

Tensor

Returns

A tensor of shape (world_size, batch, …)
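
As the deprecation notice recommends, new code should call the plugin directly. A hedged sketch, where accelerator and local_tensor (a hypothetical per-process tensor of shape (batch, …)) are assumed to exist:

    # Gather the per-process tensor across all processes (the "world").
    gathered = accelerator.training_type_plugin.all_gather(
        local_tensor,
        group=None,        # None gathers from every process
        sync_grads=False,  # True lets gradients flow through the gather
    )
    # `gathered` now has shape (world_size, batch, ...)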

abstract static auto_device_count()[source]

Get the number of devices when the device setting is auto.

Return type

int
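
Since this is an abstract static method, concrete accelerators must override it. An illustrative, purely hypothetical override for a GPU-style accelerator:

    import torch
    from pytorch_lightning.accelerators import Accelerator

    class MyGPUAccelerator(Accelerator):  # hypothetical subclass
        @staticmethod
        def auto_device_count() -> int:
            # Report how many devices "auto" should expand to.
            return torch.cuda.device_count()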

backward(closure_loss, *args, **kwargs)[source]

Forwards backward-calls to the precision plugin.

Parameters

closure_loss (Tensor) – a tensor holding the loss value to backpropagate

Return type

Tensor
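
A minimal sketch of the call, assuming accelerator and a scalar loss tensor produced by the training step already exist; the precision plugin may, for example, scale the loss under AMP before running the backward pass:

    # The accelerator delegates to the precision plugin, which runs
    # loss.backward() with any precision-specific handling applied.
    loss = accelerator.backward(loss)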

barrier(name=None)[source]

Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.barrier directly.

Return type

None

batch_to_device(batch, device=None, dataloader_idx=0)[source]

Moves the batch to the correct device. The returned batch is of the same type as the input batch, just having all tensors on the correct device.

Parameters
  • batch (Any) – The batch of samples to move to the correct device

  • device (Optional[device]) – The target device

  • dataloader_idx (int) – The index of the dataloader to which the batch belongs.

Return type

Any
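
The method works on arbitrary collections. A hedged sketch (accelerator is assumed to exist) moving a hypothetical dict batch onto a CUDA device while preserving its structure:

    import torch

    batch = {"x": torch.randn(4, 3), "y": torch.zeros(4, dtype=torch.long)}
    moved = accelerator.batch_to_device(batch, device=torch.device("cuda:0"))
    # Same dict structure; every tensor inside now lives on cuda:0.
    assert moved["x"].device.type == "cuda"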

broadcast(obj, src=0)[source]

Broadcasts an object to all processes, so that the object from the src rank becomes available on every other rank.

Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.broadcast directly.

Parameters
  • obj (object) – Object to broadcast to all processes, usually a tensor or collection of tensors.

  • src (int) – The source rank from which the object will be broadcast

Return type

object
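
Per the deprecation notice, prefer the plugin call. A sketch where payload is a hypothetical object prepared on rank 0:

    # Every rank receives rank 0's version of the object.
    payload = accelerator.training_type_plugin.broadcast(payload, src=0)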

connect(model)[source]

Transfers ownership of the model to this plugin.

Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.connect directly.

Return type

None

dispatch(trainer)[source]

Hook to do something before the training/evaluation/prediction starts.

Return type

None

get_device_stats(device)[source]

Gets stats for a given device.

Parameters

device (Union[str, device]) – device for which to get stats

Return type

Dict[str, Any]

Returns

Dictionary of device stats
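
A small sketch querying stats for the accelerator's own root device (see the root_device property below); the keys in the returned dict are backend-specific:

    stats = accelerator.get_device_stats(accelerator.root_device)
    for name, value in stats.items():
        print(f"{name}: {value}")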

lightning_module_state_dict()[source]

Returns the state dict of the model.

Allows for syncing/collating model state from processes in custom plugins.

Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.lightning_module_state_dict directly.

Return type

Dict[str, Union[Any, Tensor]]

model_sharded_context()[source]

Provides a hook to create modules in a distributed-aware context. This is useful when we'd like to shard the model instantly, which can save memory and initialization time for extremely large models.

Return type

Generator[None, None, None]

Returns

Model parallel context.
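
Since this returns a context manager, a typical use is to instantiate a very large module inside the context so a sharding-aware plugin can place parameters as they are created; MyHugeLightningModule is hypothetical:

    with accelerator.model_sharded_context():
        model = MyHugeLightningModule()  # parameters created shard-aware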

on_predict_end()[source]

Called when predict ends.

Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.on_predict_end directly.

Return type

None

on_predict_start()[source]

Called when predict begins.

Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.on_predict_start directly.

Return type

None

on_test_end()[source]

Called when the test ends.

Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.on_test_end directly.

Return type

None

on_test_start()[source]

Called when test begins.

Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.on_test_start directly.

Return type

None

on_train_batch_start(batch, batch_idx, dataloader_idx=0)[source]

Called in the training loop before anything happens for that batch.

Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.on_train_batch_start directly.

Return type

None

on_train_end()[source]

Called when train ends.

Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.on_train_end directly.

Return type

None

on_train_start()[source]

Called when train begins.

Return type

None

on_validation_end()[source]

Called when validation ends.

Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.on_validation_end directly.

Return type

None

on_validation_start()[source]

Called when validation begins.

Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.on_validation_start directly.

Return type

None

optimizer_state(optimizer)[source]

Returns the state of an optimizer.

Allows for syncing/collating optimizer state from processes in custom plugins.

Return type

Dict[str, Tensor]

optimizer_step(optimizer, opt_idx, closure, model=None, **kwargs)[source]

Performs the actual optimizer step.

Parameters
  • optimizer (Optimizer) – the optimizer performing the step

  • opt_idx (int) – index of the current optimizer

  • closure (Callable[[], Any]) – closure calculating the loss value

  • model (Union[LightningModule, Module, None]) – reference to the model, optionally defining optimizer step related hooks

  • **kwargs (Any) – Any extra arguments to optimizer.step

Return type

None
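
A hedged sketch of how a loop might drive this; compute_loss is a hypothetical closure that runs the forward/backward passes and returns the loss, and accelerator, lightning_module, batch, batch_idx, and optimizer are assumed to exist:

    def compute_loss():
        loss = lightning_module.training_step(batch, batch_idx)
        accelerator.backward(loss)
        return loss

    # Runs the closure and then optimizer.step() via the precision plugin.
    accelerator.optimizer_step(optimizer, opt_idx=0, closure=compute_loss)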

optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)[source]

Zeros all model parameters' gradients.

Return type

None

post_dispatch(trainer)[source]

Hook to do something after the training/evaluation/prediction finishes.

Return type

None

post_training_step()[source]

Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.post_training_step directly.

Return type

None

pre_dispatch(trainer)[source]

Hook to do something before the training/evaluation/prediction starts.

Return type

None

predict_step(step_kwargs)[source]

The actual predict step.

See LightningModule.predict_step() for more details.

Return type

Union[Tensor, Dict[str, Any]]

process_dataloader(dataloader)[source]

Wraps the dataloader if necessary.

Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.process_dataloader directly.

Parameters

dataloader (Union[Iterable, DataLoader]) – the dataloader to process, ideally a torch.utils.data.DataLoader

Return type

Union[Iterable, DataLoader]
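
Per the deprecation notice, prefer the plugin call. A distributed plugin may, for example, inject a DistributedSampler at this point; train_loader and accelerator are assumed to exist:

    train_loader = accelerator.training_type_plugin.process_dataloader(train_loader)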

save_checkpoint(checkpoint, filepath)[source]

Saves model/training states as a checkpoint file by dumping the state and writing it to the given path.

Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.save_checkpoint directly.

Parameters
  • checkpoint (Dict[str, Any]) – dict containing model and trainer state

  • filepath (Union[str, Path]) – write-target file’s path

Return type

None
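
Per the deprecation notice, prefer the plugin call. A hedged sketch, where checkpoint is a hypothetical dict holding model and trainer state (typically assembled by the trainer's checkpoint connector):

    accelerator.training_type_plugin.save_checkpoint(checkpoint, "model.ckpt")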

setup(trainer)[source]

Sets up plugins for the trainer fit and creates optimizers.

Parameters

trainer (Trainer) – the trainer instance

Return type

None

setup_environment()[source]

Sets up any processes or distributed connections.

This is called before the LightningModule/DataModule setup hook, which allows the user to access the accelerator environment before setup is complete.

Return type

None

setup_optimizers(trainer)[source]

Creates optimizers and schedulers.

Parameters

trainer (Trainer) – the Trainer to which these optimizers should be connected

Return type

None

setup_precision_plugin()[source]

Attaches the precision plugin to the accelerator.

Return type

None

setup_training_type_plugin()[source]

Attaches the training type plugin to the accelerator.

Return type

None

start_evaluating(trainer)[source]

Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.start_evaluating directly.

Return type

None

start_predicting(trainer)[source]

Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.start_predicting directly.

Return type

None

start_training(trainer)[source]

Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.start_training directly.

Return type

None

teardown()[source]

This method is called to tear down the training process.

It is the right place to release memory and free other resources.

Return type

None

test_step(step_kwargs)[source]

The actual test step.

See LightningModule.test_step() for more details.

Return type

Union[Tensor, Dict[str, Any], None]

test_step_end(output)[source]

A hook to do something at the end of the test step.

Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.test_step_end directly.

Parameters

output (Union[Tensor, Dict[str, Any], None]) – the output of the test step

Return type

Union[Tensor, Dict[str, Any], None]

training_step(step_kwargs)[source]

The actual training step.

See LightningModule.training_step() for more details.

Return type

Union[Tensor, Dict[str, Any]]

training_step_end(output)[source]

A hook to do something at the end of the training step.

Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.training_step_end directly.

Parameters

output (Union[Tensor, Dict[str, Any]]) – the output of the training step

Return type

Union[Tensor, Dict[str, Any]]

validation_step(step_kwargs)[source]

The actual validation step.

See LightningModule.validation_step() for more details.

Return type

Union[Tensor, Dict[str, Any], None]

validation_step_end(output)[source]

A hook to do something at the end of the validation step.

Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.validation_step_end directly.

Parameters

output (Union[Tensor, Dict[str, Any], None]) – the output of the validation step

Return type

Union[Tensor, Dict[str, Any], None]

property lightning_module: pytorch_lightning.core.lightning.LightningModule

Returns the pure LightningModule.

To get the potentially wrapped model, use Accelerator.model.

property model: torch.nn.modules.module.Module

Returns the model.

This can also be a wrapped LightningModule. To retrieve the pure LightningModule, use Accelerator.lightning_module.
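
The two properties diverge once a wrapping training type plugin (for example DDP) is active; a small sketch of the distinction, assuming accelerator exists:

    wrapped = accelerator.model            # may be e.g. a DistributedDataParallel wrapper
    pure = accelerator.lightning_module    # always the underlying LightningModule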

property restore_checkpoint_after_pre_dispatch: bool

Override to delay restoring from checkpoint until after pre-dispatch. This is useful when the plugin requires all the setup hooks to run before loading the checkpoint.

Deprecated since version v1.5: This property is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.restore_checkpoint_after_pre_dispatch directly.

Returns

If True, restores the checkpoint after pre_dispatch.

property results: Any

The results of the last run will be cached within the training type plugin. In distributed training, we make sure to transfer the results to the appropriate master process.

Deprecated since version v1.5: This property is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.results directly.

property root_device: torch.device

Returns the root device.

property setup_optimizers_in_pre_dispatch: bool

Override to delay setting up optimizers and schedulers until after dispatch. This is useful when the TrainingTypePlugin requires operating on the wrapped accelerator model. However, this may break certain precision plugins such as APEX, which require the optimizers to be set.

Deprecated since version v1.5: This property is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.setup_optimizers_in_pre_dispatch directly.

Returns

If True, delays optimizer setup until pre_dispatch; otherwise, sets them up within setup.
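
Following the deprecation notice above, this flag now lives on the training type plugin. An illustrative override on a hypothetical custom plugin:

    from pytorch_lightning.plugins import SingleDevicePlugin

    class MyWrappingPlugin(SingleDevicePlugin):  # hypothetical plugin
        @property
        def setup_optimizers_in_pre_dispatch(self) -> bool:
            # Create optimizers only after the model has been wrapped.
            return True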