Accelerator
- class pytorch_lightning.accelerators.Accelerator(precision_plugin, training_type_plugin)
Bases: object
The Accelerator base class. An Accelerator is meant to deal with one type of hardware.
Currently there are accelerators for CPU, GPU, TPU and IPU.
Each Accelerator receives two plugins upon initialization: one to handle differences in the training routine and one to handle different precisions.
- Parameters
precision_plugin (PrecisionPlugin) – the plugin to handle precision-specific parts
training_type_plugin (TrainingTypePlugin) – the plugin to handle different training routines
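To make the two-plugin design concrete, here is a minimal sketch in plain Python. The classes ToyPrecisionPlugin, ToyTrainingTypePlugin and ToyAccelerator are hypothetical stand-ins, not Lightning's classes; the sketch only shows an accelerator that owns one precision plugin and one training type plugin and lets each handle its own concern.

```python
from contextlib import contextmanager


class ToyPrecisionPlugin:
    """Hypothetical precision plugin: wraps a step in a precision context."""

    def __init__(self):
        self.active = False

    @contextmanager
    def train_step_context(self):
        # A real plugin might enable mixed precision here; this toy only sets a flag.
        self.active = True
        try:
            yield
        finally:
            self.active = False


class ToyTrainingTypePlugin:
    """Hypothetical training type plugin: decides how a step is executed."""

    def __init__(self, model):
        self.model = model

    def training_step(self, *args):
        return self.model(*args)


class ToyAccelerator:
    """Hypothetical accelerator composed of the two plugins."""

    def __init__(self, precision_plugin, training_type_plugin):
        self.precision_plugin = precision_plugin
        self.training_type_plugin = training_type_plugin

    def training_step(self, *args):
        # Precision handling wraps the call; the training routine executes it.
        with self.precision_plugin.train_step_context():
            return self.training_type_plugin.training_step(*args)


precision = ToyPrecisionPlugin()
seen = []


def model(x):
    # Record whether the precision context was active during the forward call.
    seen.append(precision.active)
    return x * 2


accelerator = ToyAccelerator(precision, ToyTrainingTypePlugin(model))
print(accelerator.training_step(21), seen)
```

Keeping the two concerns in separate objects means any precision plugin can, in principle, be paired with any training routine.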
- all_gather(tensor, group=None, sync_grads=False)
Function to gather a tensor from several distributed processes.
Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.all_gather directly.
- Parameters
- Return type
- Returns
A tensor of shape (world_size, batch, …)
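The returned shape can be illustrated without any distributed setup. The sketch below simulates all_gather in a single process: each simulated rank contributes a tensor of shape (batch, …), modelled as a nested list, and the result gains a leading world_size dimension. The helpers all_gather_simulated and shape_of are hypothetical; a real plugin performs this as a collective across processes.

```python
def shape_of(tensor):
    """Return the shape of a rectangular nested list as a tuple."""
    shape = []
    while isinstance(tensor, list):
        shape.append(len(tensor))
        tensor = tensor[0] if tensor else None
    return tuple(shape)


def all_gather_simulated(per_rank_tensors):
    """Stack one tensor per rank along a new leading world_size dimension.

    Every rank must contribute a tensor of the same shape.
    """
    shapes = {shape_of(t) for t in per_rank_tensors}
    if len(shapes) != 1:
        raise ValueError(f"all ranks must send the same shape, got {shapes}")
    # Index r of the result holds the tensor contributed by rank r.
    return list(per_rank_tensors)


world_size, batch, features = 4, 2, 3
# Rank r fills its (batch, features) tensor with the value r.
per_rank = [[[rank] * features for _ in range(batch)] for rank in range(world_size)]
gathered = all_gather_simulated(per_rank)
print(shape_of(gathered))
```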
- barrier(name=None)
Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.barrier directly.
- Return type
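A barrier blocks every participant until all of them reach the same point. The sketch below shows that behaviour with Python threads and threading.Barrier standing in for processes. It is only an analogy: a distributed training type plugin synchronizes separate processes, for example through a collective such as torch.distributed.barrier.

```python
import threading

WORLD_SIZE = 3
barrier = threading.Barrier(WORLD_SIZE)
events = []
events_lock = threading.Lock()


def worker(rank):
    with events_lock:
        events.append(("before", rank))
    # No thread passes this line until all WORLD_SIZE threads have reached it.
    barrier.wait()
    with events_lock:
        events.append(("after", rank))


threads = [threading.Thread(target=worker, args=(r,)) for r in range(WORLD_SIZE)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# Every "before" event is recorded ahead of every "after" event.
phases = [phase for phase, _ in events]
print(phases)
```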
- batch_to_device(batch, device=None, dataloader_idx=0)
Moves the batch to the correct device. The returned batch has the same type as the input batch, with all tensors moved to the correct device.
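The type-preserving behaviour described above amounts to a recursive walk over the batch. Below is a minimal sketch under stated assumptions: FakeTensor and move_to_device are hypothetical names, FakeTensor only mimics the .to(device) method of a tensor, and only dicts, lists, tuples and namedtuples are handled. It is not Lightning's implementation.

```python
from collections import namedtuple


class FakeTensor:
    """Hypothetical stand-in for a tensor that only tracks its device."""

    def __init__(self, value, device="cpu"):
        self.value = value
        self.device = device

    def to(self, device):
        # Return a new object rather than mutating the original.
        return FakeTensor(self.value, device)


def move_to_device(batch, device):
    """Recursively move every FakeTensor in batch to device, keeping container types."""
    if isinstance(batch, FakeTensor):
        return batch.to(device)
    if isinstance(batch, dict):
        return type(batch)((k, move_to_device(v, device)) for k, v in batch.items())
    if isinstance(batch, tuple) and hasattr(batch, "_fields"):  # namedtuple
        return type(batch)(*(move_to_device(v, device) for v in batch))
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_to_device(v, device) for v in batch)
    # Non-tensor leaves (ints, strings, ...) are returned unchanged.
    return batch


Sample = namedtuple("Sample", ["x", "y"])
batch = {"inputs": [FakeTensor(1), FakeTensor(2)], "meta": Sample(FakeTensor(3), "label")}
moved = move_to_device(batch, "cuda:0")
print(type(moved["meta"]).__name__, moved["meta"].x.device, moved["meta"].y)
```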
- broadcast(obj, src=0)
Broadcasts an object to all processes, such that the object on rank src is sent to all other ranks if needed.
Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.broadcast directly.
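In a single process, the effect of broadcast on the values held by each rank can be simulated as below. The helper broadcast_simulated is hypothetical; a real training type plugin exchanges the object between processes. A typical use is sharing a value that only rank 0 computed, such as a random seed.

```python
import copy


def broadcast_simulated(objects_by_rank, src=0):
    """Return what each rank holds after broadcasting from src.

    objects_by_rank[r] is the object rank r passes in; afterwards every rank
    holds its own copy of the object that rank src passed in.
    """
    source_obj = objects_by_rank[src]
    # Each rank receives an independent copy, as separate processes would.
    return [copy.deepcopy(source_obj) for _ in objects_by_rank]


before = [{"seed": 123}, None, None, None]  # only rank 0 has the value
after = broadcast_simulated(before, src=0)
print(after)
```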
- connect(model)
Transfers ownership of the model to this plugin.
See deprecation warning below.
Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.connect directly.
- Return type
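Most deprecated methods on this page follow the same pattern: the accelerator keeps the old method for a release, emits a deprecation warning and forwards the call to the training type plugin. Below is a minimal sketch of such a forwarding shim with hypothetical ToyAccelerator and ToyTrainingTypePlugin classes; it illustrates the pattern and does not reproduce Lightning's code or warning text.

```python
import warnings


class ToyTrainingTypePlugin:
    """Hypothetical plugin that now owns the behaviour."""

    def barrier(self, name=None):
        return f"barrier:{name}"


class ToyAccelerator:
    """Hypothetical accelerator keeping a deprecated forwarding method."""

    def __init__(self, training_type_plugin):
        self.training_type_plugin = training_type_plugin

    def barrier(self, name=None):
        # Emit a deprecation warning, then delegate to the plugin unchanged.
        warnings.warn(
            "`Accelerator.barrier` is deprecated in v1.5 and will be removed in v1.6."
            " Call `training_type_plugin.barrier` directly.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.training_type_plugin.barrier(name=name)


accelerator = ToyAccelerator(ToyTrainingTypePlugin())
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = accelerator.barrier("sync")
print(result, caught[0].category.__name__)
```

Migrating away from the deprecated method means calling the plugin's method directly; the shim only keeps old call sites working in the meantime.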
- dispatch(trainer)
Hook to do something before the training/evaluation/prediction starts.
- Return type
- lightning_module_state_dict()
Returns the state of the model. Allows for syncing/collating model state from processes in custom plugins.
Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.lightning_module_state_dict directly.
- model_sharded_context()
Provides a hook to create modules in a distributed-aware context. This is useful when we’d like to shard the model instantly, which helps with extremely large models and can save memory and initialization time.
- on_predict_end()
Called when predict ends.
See deprecation warning below.
Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.on_predict_end directly.
- Return type
- on_predict_start()
Called when predict begins.
See deprecation warning below.
Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.on_predict_start directly.
- Return type
- on_test_end()
Called when test ends.
See deprecation warning below.
Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.on_test_end directly.
- Return type
- on_test_start()
Called when test begins.
See deprecation warning below.
Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.on_test_start directly.
- Return type
- on_train_batch_start(batch, batch_idx, dataloader_idx=0)
Called in the training loop before anything happens for that batch.
See deprecation warning below.
Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.on_train_batch_start directly.
- Return type
- on_train_end()
Called when train ends.
See deprecation warning below.
Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.on_train_end directly.
- Return type
- on_validation_end()
Called when validation ends.
See deprecation warning below.
Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.on_validation_end directly.
- Return type
- on_validation_start()
Called when validation begins.
See deprecation warning below.
Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.on_validation_start directly.
- Return type
- optimizer_state(optimizer)
Returns the state of an optimizer.
Allows for syncing/collating optimizer state from processes in custom plugins.
- optimizer_step(optimizer, opt_idx, closure, model=None, **kwargs)
Performs the actual optimizer step.
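The closure argument follows the PyTorch convention in which optimizer.step(closure) may call a function that recomputes the loss and gradients. Below is a dependency-free sketch of that contract: ToySGD is a hypothetical stand-in for a torch.optim optimizer, and the closure minimizes (w - 3)^2. Here the closure is supplied by the caller; the sketch only illustrates why the step receives it.

```python
class ToySGD:
    """Hypothetical optimizer whose step() accepts a closure."""

    def __init__(self, params, lr):
        self.params = params  # dict mapping parameter name to value
        self.lr = lr
        self.grads = {}

    def step(self, closure):
        # The closure recomputes the loss and fills self.grads (the "backward" pass).
        loss = closure()
        for name, grad in self.grads.items():
            self.params[name] -= self.lr * grad
        return loss


params = {"w": 0.0}
optimizer = ToySGD(params, lr=0.1)


def closure():
    """Forward and backward for loss = (w - 3) ** 2."""
    w = params["w"]
    optimizer.grads["w"] = 2 * (w - 3)
    return (w - 3) ** 2


for _ in range(100):
    optimizer.step(closure)
print(round(params["w"], 4))
```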
- optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)
Zeros the gradients of all model parameters.
- Return type
- post_dispatch(trainer)
Hook to do something after the training/evaluation/prediction finishes.
- Return type
- post_training_step()
Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.post_training_step directly.
- Return type
- pre_dispatch(trainer)
Hook to do something before the training/evaluation/prediction starts.
- Return type
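Read together, the descriptions of pre_dispatch, dispatch and post_dispatch imply the ordering sketched below. ToyAccelerator and run_stage are hypothetical; the sketch only records call order and is not the Trainer's actual control flow, which performs further work between these hooks.

```python
class ToyAccelerator:
    """Hypothetical accelerator that records the order of its dispatch hooks."""

    def __init__(self):
        self.calls = []

    def pre_dispatch(self, trainer):
        self.calls.append("pre_dispatch")

    def dispatch(self, trainer):
        self.calls.append("dispatch")

    def post_dispatch(self, trainer):
        self.calls.append("post_dispatch")


def run_stage(accelerator, trainer, stage_fn):
    """Hypothetical driver: prepare, dispatch, run the stage, then finalize."""
    accelerator.pre_dispatch(trainer)
    accelerator.dispatch(trainer)
    try:
        stage_fn()
    finally:
        # post_dispatch runs once the stage has finished, even on error.
        accelerator.post_dispatch(trainer)


acc = ToyAccelerator()
run_stage(acc, trainer=None, stage_fn=lambda: acc.calls.append("stage"))
print(acc.calls)
```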
- predict_step(step_kwargs)
The actual predict step.
See LightningModule.predict_step() for more details.
- process_dataloader(dataloader)
Wraps the dataloader if necessary.
Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.process_dataloader directly.
- Parameters
dataloader (Union[Iterable, DataLoader]) – iterable. Ideally of type: torch.utils.data.DataLoader
- Return type
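Since the parameter accepts any iterable, wrapping typically means returning another iterable that yields (possibly transformed) batches. Below is a minimal sketch with a hypothetical DeviceLoader wrapper and a plain list standing in for a DataLoader; the tagging transform is a placeholder for real work such as moving batches to a device.

```python
class DeviceLoader:
    """Hypothetical wrapper that transforms each batch as it is yielded."""

    def __init__(self, dataloader, transform):
        self.dataloader = dataloader
        self.transform = transform

    def __iter__(self):
        for batch in self.dataloader:
            yield self.transform(batch)

    def __len__(self):
        # Delegate to the wrapped loader so callers that query len() still work.
        return len(self.dataloader)


def process_dataloader(dataloader):
    """Hypothetical plugin hook: wrap the loader, tagging each batch with a device."""
    return DeviceLoader(dataloader, lambda batch: ("device:0", batch))


wrapped = process_dataloader([[1, 2], [3, 4]])
print(len(wrapped), list(wrapped))
```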
- save_checkpoint(checkpoint, filepath)
Save model/training states as a checkpoint file through state-dump and file-write.
Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.save_checkpoint directly.
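The two stages named above, dumping the state and writing the file, can be sketched with the standard library. Assumptions: pickle stands in for torch.save, and only global rank 0 writes, a common convention in distributed training that this sketch hard-codes through a hypothetical global_rank argument.

```python
import os
import pickle
import tempfile


def save_checkpoint(checkpoint, filepath, global_rank=0):
    """Hypothetical sketch: dump state to bytes, then write the file on rank 0 only."""
    if global_rank != 0:
        # In this sketch other ranks hold identical state, so they skip the write.
        return
    payload = pickle.dumps(checkpoint)  # state-dump
    tmp_path = filepath + ".tmp"
    with open(tmp_path, "wb") as f:  # file-write
        f.write(payload)
    # Rename into place so an interrupted write does not clobber an existing checkpoint.
    os.replace(tmp_path, filepath)


state = {"epoch": 3, "state_dict": {"w": [0.1, 0.2]}}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "last.ckpt")
    save_checkpoint(state, path, global_rank=1)
    exists_after_rank1 = os.path.exists(path)
    save_checkpoint(state, path, global_rank=0)
    with open(path, "rb") as f:
        restored = pickle.load(f)
print(exists_after_rank1, restored == state)
```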
- setup_environment()
Sets up any processes or distributed connections.
This is called before the LightningModule/DataModule setup hook, which allows the user to access the accelerator environment before setup is complete.
- Return type
- setup_training_type_plugin()
Attaches the training type plugin to the accelerator.
- Return type
- start_evaluating(trainer)
Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.start_evaluating directly.
- Return type
- start_predicting(trainer)
Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.start_predicting directly.
- Return type
- start_training(trainer)
Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.start_training directly.
- Return type
- teardown()
This method is called to tear down the training process.
It is the right place to release memory and free other resources.
- Return type
- test_step(step_kwargs)
The actual test step.
See LightningModule.test_step() for more details.
- test_step_end(output)
A hook to do something at the end of the test step.
Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.test_step_end directly.
- training_step(step_kwargs)
The actual training step.
See LightningModule.training_step() for more details.
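Here, step_kwargs bundles the arguments that ultimately reach LightningModule.training_step (the batch, its index, and so on). The sketch below assumes step_kwargs maps parameter names to values and simply unpacks it; ToyModule and accelerator_training_step are hypothetical, and the real accelerator additionally routes the call through its plugins.

```python
from collections import OrderedDict


class ToyModule:
    """Hypothetical LightningModule-like object with a training_step."""

    def training_step(self, batch, batch_idx):
        return {"loss": sum(batch) / len(batch), "batch_idx": batch_idx}


def accelerator_training_step(module, step_kwargs):
    """Forward the packed arguments to the module's training_step.

    Assumption: step_kwargs maps the module's parameter names to values.
    """
    return module.training_step(**step_kwargs)


step_kwargs = OrderedDict(batch=[1.0, 2.0, 3.0], batch_idx=7)
print(accelerator_training_step(ToyModule(), step_kwargs))
```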
- training_step_end(output)
A hook to do something at the end of the training step.
Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.training_step_end directly.
- validation_step(step_kwargs)
The actual validation step.
See LightningModule.validation_step() for more details.
- validation_step_end(output)
A hook to do something at the end of the validation step.
Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.validation_step_end directly.
- property lightning_module: pytorch_lightning.core.lightning.LightningModule
Returns the pure LightningModule.
To get the potentially wrapped model, use Accelerator.model.
- Return type
- property model: torch.nn.modules.module.Module
Returns the model.
This can also be a wrapped LightningModule. To retrieve the pure LightningModule, use Accelerator.lightning_module.
- Return type
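The relationship between model and lightning_module can be pictured as wrappers nested around the user's module, each exposing the inner one through a .module attribute, as torch.nn.parallel.DistributedDataParallel does. Below is a minimal sketch with hypothetical classes: model corresponds to the outermost object, lightning_module to the result of unwrapping.

```python
class ToyLightningModule:
    """Hypothetical stand-in for a user's LightningModule."""


class ToyWrapper:
    """Hypothetical stand-in for a wrapper such as a data-parallel module."""

    def __init__(self, module):
        self.module = module


def unwrap(model):
    """Follow .module attributes until the innermost (pure) module is reached."""
    while isinstance(model, ToyWrapper):
        model = model.module
    return model


pure = ToyLightningModule()
wrapped = ToyWrapper(ToyWrapper(pure))
print(unwrap(wrapped) is pure)
```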
- property restore_checkpoint_after_pre_dispatch: bool
Override to delay restoring from a checkpoint until after pre-dispatch. This is useful when the plugin requires all the setup hooks to run before loading the checkpoint.
Deprecated since version v1.5: This property is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.restore_checkpoint_after_pre_dispatch directly.
- Return type
- Returns
If True, restore the checkpoint after pre_dispatch.
- property results: Any
The results of the last run will be cached within the training type plugin. In distributed training, we make sure to transfer the results to the appropriate master process.
Deprecated since version v1.5: This property is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.results directly.
- Return type
- property setup_optimizers_in_pre_dispatch: bool
Override to delay setting up optimizers and schedulers until pre-dispatch. This is useful when the TrainingTypePlugin requires operating on the wrapped accelerator model. However, this may break certain precision plugins, such as APEX, which require optimizers to be set.
Deprecated since version v1.5: This property is deprecated in v1.5 and will be removed in v1.6. Please call training_type_plugin.setup_optimizers_in_pre_dispatch directly.
- Return type
- Returns
If True, delay setting up optimizers until pre_dispatch; otherwise, set them up within setup.
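Both properties are boolean switches that a plugin overrides to move work later. The sketch below, using hypothetical plugin classes and a hypothetical schedule helper, only tabulates where each step happens according to the descriptions above; it does not model the Trainer itself.

```python
class ToyTrainingTypePlugin:
    """Hypothetical base plugin: keeps the default (early) behaviour."""

    @property
    def restore_checkpoint_after_pre_dispatch(self):
        return False

    @property
    def setup_optimizers_in_pre_dispatch(self):
        return False


class DelayingToyPlugin(ToyTrainingTypePlugin):
    """Hypothetical plugin that needs the wrapped model before both steps."""

    @property
    def restore_checkpoint_after_pre_dispatch(self):
        return True

    @property
    def setup_optimizers_in_pre_dispatch(self):
        return True


def schedule(plugin):
    """Hypothetical helper: report when each step runs relative to pre_dispatch."""
    return {
        "setup_optimizers": (
            "in pre_dispatch" if plugin.setup_optimizers_in_pre_dispatch else "in setup"
        ),
        "restore_checkpoint": (
            "after pre_dispatch"
            if plugin.restore_checkpoint_after_pre_dispatch
            else "before pre_dispatch"
        ),
    }


print(schedule(ToyTrainingTypePlugin()))
print(schedule(DelayingToyPlugin()))
```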