Accelerator
- class pytorch_lightning.accelerators.Accelerator(precision_plugin, training_type_plugin)[source]
Bases: object
The Accelerator base class. An Accelerator is meant to deal with one type of hardware.
Currently there are accelerators for:
- CPU
- GPU
- TPU
Each Accelerator gets two plugins upon initialization: one to handle differences in the training routine and one to handle different precisions (see the construction sketch below).
- Parameters
precision_plugin (PrecisionPlugin) – the plugin to handle precision-specific parts
training_type_plugin (TrainingTypePlugin) – the plugin to handle different training routines
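A minimal construction sketch, assuming the 1.x-era plugin API described above; CPUAccelerator, PrecisionPlugin and SingleDevicePlugin are concrete classes from the same package family, though exact import paths can vary between versions:

```python
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.plugins import PrecisionPlugin, SingleDevicePlugin

# One plugin drives the training routine (plain single-process,
# single-device training here), the other handles precision.
accelerator = CPUAccelerator(
    precision_plugin=PrecisionPlugin(),
    training_type_plugin=SingleDevicePlugin(torch.device("cpu")),
)

# The assembled accelerator can then be handed to the Trainer.
trainer = Trainer(accelerator=accelerator)
```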
- all_gather(tensor, group=None, sync_grads=False)[source]
Function to gather a tensor from several distributed processes.
- Parameters
tensor (Tensor) – the tensor to gather, of shape (batch, …)
group (Optional[Any]) – the process group to gather results from; defaults to all processes (world)
sync_grads (bool) – flag that allows users to synchronize gradients for the all_gather operation
- Return type
Tensor
- Returns
A tensor of shape (world_size, batch, …)
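A minimal sketch of the gather semantics using plain torch.distributed, assuming a process group is already initialized (the sync_grads path, which needs an autograd-aware gather, is omitted); the stacked result matches the documented (world_size, batch, …) shape:

```python
import torch
import torch.distributed as dist

def all_gather_stack(tensor: torch.Tensor, group=None) -> torch.Tensor:
    # Collect one copy of `tensor` from every process in the group ...
    world_size = dist.get_world_size(group=group)
    gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(gathered, tensor, group=group)  # collective: all ranks must call
    # ... and stack them into a single (world_size, *tensor.shape) tensor.
    return torch.stack(gathered)
```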
- batch_to_device(batch, device=None, dataloader_idx=None)[source]
Moves the batch to the correct device. The returned batch is of the same type as the input batch, with all tensors moved to the given device.
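For illustration, Lightning ships a move_data_to_device utility with the same contract; a short usage sketch (the dict batch here is made up):

```python
import torch
from pytorch_lightning.utilities import move_data_to_device

batch = {"x": torch.randn(4, 3), "y": torch.tensor([0, 1, 0, 1])}
# The container structure survives the move: the result is again a dict
# with the same keys, but every tensor now lives on the target device.
moved = move_data_to_device(batch, torch.device("cpu"))
```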
- broadcast(obj, src=0)[source]
Broadcasts an object to all processes, so that every rank ends up with the object contributed by the src rank.
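A hedged sketch of equivalent semantics in plain torch.distributed, assuming a process group is already initialized (the helper name is illustrative):

```python
import torch.distributed as dist

def broadcast_object(obj, src: int = 0):
    # Every rank calls in with its own `obj`; after the collective call,
    # every rank holds the object that rank `src` contributed.
    buf = [obj]
    dist.broadcast_object_list(buf, src=src)
    return buf[0]
```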
- clip_gradients(optimizer, clip_val, gradient_clip_algorithm=<GradClipAlgorithmType.NORM: 'norm'>)[source]
Clips the gradients of all the optimizer's parameters to the given value.
- Return type
None
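A minimal sketch of the two clipping algorithms using standard torch utilities (the helper and its string dispatch are illustrative, not the accelerator's actual implementation):

```python
import torch
from torch.nn.utils import clip_grad_norm_, clip_grad_value_

def clip_gradients(optimizer: torch.optim.Optimizer, clip_val: float,
                   algorithm: str = "norm") -> None:
    # Gather every parameter the optimizer manages ...
    params = [p for group in optimizer.param_groups for p in group["params"]]
    if algorithm == "norm":
        # ... and rescale gradients so their total norm is at most clip_val ...
        clip_grad_norm_(params, clip_val)
    else:
        # ... or clamp each gradient element into [-clip_val, clip_val].
        clip_grad_value_(params, clip_val)
```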
- connect_precision_plugin(plugin)[source]
Attaches the precision plugin to the accelerator.
- Return type
None
- connect_training_type_plugin(plugin, model)[source]
Attaches the training type plugin to the accelerator. Also transfers ownership of the model to this plugin.
- Return type
None
- dispatch(trainer)[source]
Hook to kick off (dispatch) the training/evaluation/prediction run.
- Return type
None
- lightning_module_state_dict()[source]
Returns the state dict of the model. Allows for syncing/collating model state from processes in custom plugins.
- model_sharded_context()[source]
Provides a hook to create modules in a distributed-aware context. This is useful when we'd like to shard the model instantly, e.g. for extremely large models, since it can save memory and initialization time.
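A conceptual sketch of such a context manager (the class and the sharding hooks are hypothetical placeholders; a real plugin would talk to its sharding backend here):

```python
from contextlib import contextmanager

class ShardedPluginSketch:
    @contextmanager
    def model_sharded_context(self):
        # Enter a distributed-aware construction mode here, e.g. tell the
        # sharding backend that newly created modules should be sharded
        # immediately instead of materialized in full on every rank.
        yield
        # Leave the construction mode here.

plugin = ShardedPluginSketch()
with plugin.model_sharded_context():
    ...  # instantiate the (very large) model inside the context
```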
- on_reset_predict_dataloader(dataloader)[source]
Called before resetting the predict dataloader.
- Return type
None
- on_reset_test_dataloader(dataloader)[source]
Called before resetting the test dataloader.
- Return type
None
- on_reset_train_dataloader(dataloader)[source]
Called before resetting the train dataloader.
- Return type
None
- on_reset_val_dataloader(dataloader)[source]
Called before resetting the val dataloader.
- Return type
None
- on_train_batch_start(batch, batch_idx, dataloader_idx)[source]
Called in the training loop before anything happens for that batch.
- Return type
None
- optimizer_state(optimizer)[source]
Returns the state of an optimizer. Allows for syncing/collating optimizer state from processes in custom plugins.
- optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)[source]
Performs the actual optimizer step.
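For intuition, the closure-based calling convention mirrors plain PyTorch, as in this self-contained sketch (model, data and loss are made up):

```python
import torch

model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

def lambda_closure():
    # The closure recomputes loss and gradients, which keeps the backward
    # pass coupled to the step and supports closure-based optimizers (LBFGS).
    optimizer.zero_grad()
    loss = model(torch.randn(8, 2)).pow(2).mean()
    loss.backward()
    return loss

optimizer.step(closure=lambda_closure)
```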
- optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx)[source]
Zeros all model parameters' gradients.
- Return type
None
- post_dispatch(trainer)[source]
Hook to do something after the training/evaluation/prediction run finishes.
- Return type
None
- pre_dispatch(trainer)[source]
Hook to do something before the training/evaluation/prediction starts.
- Return type
None
- predict_step(step_kwargs)[source]
The actual predict step.
- Parameters
step_kwargs (Dict[str, Union[Any, int]]) – the arguments for the model's predict step. Can consist of the following:
  - batch (Tensor | (Tensor, …) | [Tensor, …]): The output of your DataLoader. A tensor, tuple or list.
  - batch_idx (int): The index of this batch.
  - dataloader_idx (int): The index of the dataloader that produced this batch (only if multiple predict dataloaders used).
- process_dataloader(dataloader)[source]
Wraps the dataloader if necessary.
- Parameters
dataloader (Union[Iterable, DataLoader]) – iterable. Ideally of type: torch.utils.data.DataLoader
- Return type
Union[Iterable, DataLoader]
- save_checkpoint(checkpoint, filepath)[source]
Save model/training states as a checkpoint file through state-dump and file-write.
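A minimal sketch of the state-dump-and-write idea, assuming a plain torch.save serialization (custom plugins may route this through other checkpoint IO):

```python
import torch

def save_checkpoint(checkpoint: dict, filepath: str) -> None:
    # State-dump and file-write: the checkpoint dict (model weights,
    # optimizer states, loop progress, ...) is serialized in one call.
    torch.save(checkpoint, filepath)
```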
- setup(trainer, model)[source]
Sets up the plugins for the trainer fit and creates the optimizers.
- Parameters
model (LightningModule) – the LightningModule
- Return type
None
- setup_environment()[source]
Sets up any processes or distributed connections. This is called before the LightningModule/DataModule setup hook, which allows the user to access the accelerator environment before setup is complete.
- Return type
None
- setup_training_type_plugin(model)[source]
Attaches the training type plugin to the accelerator.
- Return type
None
- teardown()[source]
This method is called to tear down the training process. It is the right place to release memory and free other resources.
- Return type
None
- test_step(step_kwargs)[source]
The actual test step.
- Parameters
step_kwargs (Dict[str, Union[Any, int]]) – the arguments for the model's test step. Can consist of the following:
  - batch (Tensor | (Tensor, …) | [Tensor, …]): The output of your DataLoader. A tensor, tuple or list.
  - batch_idx (int): The index of this batch.
  - dataloader_idx (int): The index of the dataloader that produced this batch (only if multiple test dataloaders used).
- training_step(step_kwargs)[source]
The actual training step (see the unpacking sketch after this entry).
- Parameters
step_kwargs (Dict[str, Union[Any, int]]) – the arguments for the model's training step. Can consist of the following:
  - batch (Tensor | (Tensor, …) | [Tensor, …]): The output of your DataLoader. A tensor, tuple or list.
  - batch_idx (int): The index of this batch.
  - optimizer_idx (int): When using multiple optimizers, this argument will also be present.
  - hiddens (Tensor): Passed in if truncated_bptt_steps > 0.
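A hedged sketch of how such a packed step_kwargs dict is typically unpacked (the function body and the placeholder loss are illustrative, not Lightning's implementation):

```python
from typing import Any, Dict, Union
import torch

def training_step(step_kwargs: Dict[str, Union[Any, int]]):
    batch = step_kwargs["batch"]          # always present
    batch_idx = step_kwargs["batch_idx"]  # always present
    optimizer_idx = step_kwargs.get("optimizer_idx")  # multiple optimizers only
    hiddens = step_kwargs.get("hiddens")  # truncated_bptt_steps > 0 only
    # Placeholder loss; a real step would compute it from `batch`.
    return torch.as_tensor(0.0)
```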
- validation_step(step_kwargs)[source]
The actual validation step.
- Parameters
step_kwargs (Dict[str, Union[Any, int]]) – the arguments for the model's validation step. Can consist of the following:
  - batch (Tensor | (Tensor, …) | [Tensor, …]): The output of your DataLoader. A tensor, tuple or list.
  - batch_idx (int): The index of this batch.
  - dataloader_idx (int): The index of the dataloader that produced this batch (only if multiple val dataloaders used).
- property call_configure_sharded_model_hook: bool
Allow the model parallel hook to be called in suitable environments, as determined by the training type plugin. This is useful when we want to shard the model once within fit.
- Returns
True if we want to call the model parallel setup hook.
- property lightning_module: pytorch_lightning.core.lightning.LightningModule
Returns the pure LightningModule. To get the potentially wrapped model, use Accelerator.model.
- property model: torch.nn.modules.module.Module
Returns the model. This can also be a wrapped LightningModule. To retrieve the pure LightningModule, use Accelerator.lightning_module.
- property results: Any
The results of the last run will be cached within the training type plugin. In distributed training, we make sure to transfer the results to the appropriate master process.
- property root_device: torch.device
Returns the root device.
- property setup_optimizers_in_pre_dispatch: bool
Override to delay setting up optimizers and schedulers until after dispatch. This is useful when the TrainingTypePlugin requires operating on the wrapped accelerator model. However, this may break certain precision plugins such as APEX, which require optimizers to be set up beforehand.
- Returns
If True, delay setting up optimizers until pre_dispatch; otherwise set them up within setup.
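A minimal sketch of opting into delayed optimizer setup from a custom subclass (the class name is hypothetical):

```python
from pytorch_lightning.accelerators import Accelerator

class DelayedOptimizerAccelerator(Accelerator):
    @property
    def setup_optimizers_in_pre_dispatch(self) -> bool:
        # Delay optimizer creation until pre_dispatch, e.g. because the
        # training type plugin wraps or shards the model first.
        return True
```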