TrainingTypePlugin¶
- class pytorch_lightning.plugins.training_type.TrainingTypePlugin[source]¶
Bases: pytorch_lightning.plugins.base_plugin.Plugin, abc.ABC
Base class for all training type plugins that change the behaviour of the training, validation and test-loop.
- abstract all_gather(tensor, group=None, sync_grads=False)[source]¶
Performs an all_gather on all processes.
- abstract barrier(name=None)[source]¶
Forces all possibly joined processes to wait for each other.
- abstract broadcast(obj, src=0)[source]¶
Broadcasts an object to all processes.
- Return type
TypeVar(T)
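For a single-device plugin, the collective operations above degenerate to no-ops. A minimal, dependency-free sketch of that special case (the `SingleDeviceSketch` class and its list-based "tensors" are illustrative, not the real `SingleDevicePlugin`):

```python
from typing import Any, List, Optional


class SingleDeviceSketch:
    """Illustrative stand-in for a single-device training type plugin."""

    def all_gather(
        self, tensor: List[float], group: Optional[Any] = None, sync_grads: bool = False
    ) -> List[List[float]]:
        # With a world size of one, gathering across all processes
        # yields a collection containing only this process's tensor.
        return [tensor]

    def barrier(self, name: Optional[str] = None) -> None:
        # No peer processes to wait for; return immediately.
        return None

    def broadcast(self, obj: Any, src: int = 0) -> Any:
        # The sole process already holds the object, so broadcasting
        # hands it back unchanged.
        return obj


plugin = SingleDeviceSketch()
assert plugin.all_gather([1.0, 2.0]) == [[1.0, 2.0]]
assert plugin.broadcast({"lr": 0.1}) == {"lr": 0.1}
assert plugin.barrier() is None
```

Distributed plugins replace each of these bodies with the corresponding collective over the process group.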
- connect(model)[source]¶
Called by the accelerator to connect the accelerator and the model with this plugin.
- model_sharded_context()[source]¶
Provides a hook to create modules in a distributed-aware context. This is useful when we want to shard the model immediately, which saves memory and initialization time for extremely large models.
Returns: Model parallel context.
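The default behaviour can be pictured as an empty context manager; sharding plugins override it to enter their partitioning context so parameters are allocated already sharded. A hedged sketch of that default (the function below is illustrative, not the library implementation):

```python
from contextlib import contextmanager
from typing import Iterator


@contextmanager
def model_sharded_context() -> Iterator[None]:
    # Default: do nothing special; modules created inside the block are
    # ordinary modules. A sharded plugin would set up and tear down its
    # distributed-aware allocation context around the yield instead.
    yield


with model_sharded_context():
    weights = [0.0] * 4  # stands in for creating a large module
assert weights == [0.0, 0.0, 0.0, 0.0]
```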
- on_reset_predict_dataloader(dataloader)[source]¶
Called before resetting the predict dataloader.
- on_reset_test_dataloader(dataloader)[source]¶
Called before resetting the test dataloader.
- on_reset_train_dataloader(dataloader)[source]¶
Called before resetting the train dataloader.
- on_reset_val_dataloader(dataloader)[source]¶
Called before resetting the val dataloader.
- on_train_batch_start(batch, batch_idx, dataloader_idx)[source]¶
Called in the training loop before anything happens for that batch.
- post_optimizer_step(optimizer, optimizer_idx, **kwargs)[source]¶
Hook to do something after each optimizer step.
- process_dataloader(dataloader)[source]¶
Wraps the dataloader if necessary.
- Parameters
dataloader (Union[Iterable, DataLoader]) – iterable, ideally of type torch.utils.data.DataLoader
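As an illustration of such wrapping, the sketch below substitutes a simple per-batch transform where a real plugin might, for example, move each batch to its root device (all names here are hypothetical):

```python
from typing import Iterable, Iterator


class WrappedLoader:
    """Wraps an iterable dataloader, applying a per-batch transform."""

    def __init__(self, dataloader: Iterable) -> None:
        self.dataloader = dataloader

    def __iter__(self) -> Iterator:
        for batch in self.dataloader:
            # A real plugin might move the batch to its root device here;
            # we double each element as a stand-in transform.
            yield [x * 2 for x in batch]


def process_dataloader(dataloader: Iterable) -> Iterable:
    # Wrap only if necessary; for illustration we always wrap.
    return WrappedLoader(dataloader)


batches = list(process_dataloader([[1, 2], [3]]))
assert batches == [[2, 4], [6]]
```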
- abstract reduce(tensor, *args, **kwargs)[source]¶
Reduces the given tensor (e.g. across GPUs/processes).
- reduce_boolean_decision(decision)[source]¶
Reduces the early stopping decision across all processes.
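Conceptually, the reduction is unanimous: every process must vote to stop before training stops. With the per-rank decisions already gathered into a list, that reduces to a sum-and-compare (a simplified sketch; the real plugin reduces over the distributed process group):

```python
from typing import List


def reduce_boolean_decision(decisions: List[bool]) -> bool:
    # Distributed plugins sum the per-rank decisions and require the sum
    # to equal the world size; with the values gathered locally, that is
    # equivalent to requiring every rank to have decided True.
    return sum(decisions) == len(decisions)


assert reduce_boolean_decision([True, True]) is True
assert reduce_boolean_decision([True, False]) is False
```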
- save_checkpoint(checkpoint, filepath)[source]¶
Save model/training states as a checkpoint file through state-dump and file-write.
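The state-dump-and-file-write pattern can be sketched with the standard library; a real plugin serializes tensors (e.g. via torch.save) and typically writes only on the rank that should save, both of which are omitted in this illustration:

```python
import os
import pickle
import tempfile
from typing import Any, Dict


def save_checkpoint(checkpoint: Dict[str, Any], filepath: str) -> None:
    # Dump the state dict to disk; rank-zero-only gating is omitted here.
    with open(filepath, "wb") as f:
        pickle.dump(checkpoint, f)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "ckpt.pkl")
    save_checkpoint({"epoch": 3, "global_step": 1200}, path)
    with open(path, "rb") as f:
        restored = pickle.load(f)
assert restored == {"epoch": 3, "global_step": 1200}
```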
- setup_environment()[source]¶
Sets up any processes or distributed connections. This is called before the LightningModule/DataModule setup hook, which allows the user to access the accelerator environment before setup is complete.
- abstract teardown()[source]¶
This method is called to tear down the training process. It is the right place to release memory and free other resources.
- update_global_step(total_batch_idx, current_global_step)[source]¶
Provides a hook to count optimizer step calls.
Returns: New optimizer step calls.
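The default bookkeeping presumably advances the counter by one optimizer step per call; plugins that accumulate gradients or skip steps would override this. A sketch of that assumed default:

```python
def update_global_step(total_batch_idx: int, current_global_step: int) -> int:
    # Assumed default: one optimizer step per invocation, independent of
    # the running batch index.
    return current_global_step + 1


assert update_global_step(total_batch_idx=10, current_global_step=4) == 5
```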
- property call_configure_sharded_model_hook: bool¶
Allows the model parallel hook to be called in suitable environments determined by the training type plugin. This is useful when we want to shard the model once within fit. Returns: True if we want to call the model parallel setup hook.
- abstract property is_global_zero: bool¶
Whether the current process is the rank zero process not only on the local node, but for all nodes.
- property lightning_module: pytorch_lightning.core.lightning.LightningModule¶
Returns the pure LightningModule without potential wrappers.
- property model: Optional[torch.nn.modules.module.Module]¶
Returns the potentially wrapped LightningModule.
- abstract property on_gpu: bool¶
Returns whether the current process runs on GPU.
- abstract property on_tpu: bool¶
Returns whether the current process runs on TPU.
- property results: Optional[Union[List[Dict[str, float]], List[Any], List[List[Any]]]]¶
Enables plugin-agnostic access to the result returned by the training/evaluation/prediction run. The result is cached instead of returned directly, because some plugins require transmitting the results from one multiprocessing context to another in a separate step. For example, the plugins that use the “spawn” start-method send the result to the master process through a multiprocessing queue (shared memory).
- abstract property root_device: torch.device¶
Returns the root device.
- property setup_optimizers_in_pre_dispatch: bool¶
Override to delay setting up optimizers and schedulers until after dispatch. This is useful when the TrainingTypePlugin requires operating on the wrapped accelerator model. However, this may break certain precision plugins such as APEX, which require optimizers to be set up earlier. Returns: If True, delay optimizer setup until pre_dispatch; otherwise call it within setup.
- property should_rank_save_checkpoint: bool¶
Returns whether the checkpoint should be saved (rank-based).