Strategy
- class pytorch_lightning.strategies.Strategy(accelerator=None, checkpoint_io=None, precision_plugin=None)[source]
Bases: abc.ABC
Base class for all strategies that change the behaviour of the training, validation and test loop.
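To make the contract concrete, here is a minimal sketch of a single-process subclass implementing the abstract members listed on this page. The class name and no-op bodies are illustrative only, and a given release may declare further abstract members:

```python
import torch
from pytorch_lightning.strategies import Strategy


class SingleProcessSketch(Strategy):
    """Illustrative strategy for one process: all collectives are no-ops."""

    def __init__(self, device="cpu", **kwargs):
        super().__init__(**kwargs)
        self._root_device = torch.device(device)

    @property
    def root_device(self) -> torch.device:
        return self._root_device

    @property
    def is_global_zero(self) -> bool:
        return True  # a single process is always rank zero

    def all_gather(self, tensor, group=None, sync_grads=False):
        return tensor  # nothing to gather from other processes

    def barrier(self, name=None):
        pass  # no other processes to wait for

    def broadcast(self, obj, src=0):
        return obj  # this process already holds the object

    def reduce(self, tensor, group=None, reduce_op="mean"):
        return tensor  # nothing to reduce across one process

    def model_to_device(self):
        # some releases also declare this abstract; moving the module suffices here
        self.model.to(self.root_device)
```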
- abstract all_gather(tensor, group=None, sync_grads=False)[source]
Perform an all_gather on all processes.
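For a distributed implementation, a hedged sketch on top of torch.distributed might look as follows; sync_grads, which would need an autograd-aware gather, is ignored here:

```python
import torch
import torch.distributed as dist


def all_gather(self, tensor, group=None, sync_grads=False):
    # fall back to a world of size one when no process group is initialized
    if not dist.is_available() or not dist.is_initialized():
        return tensor.unsqueeze(0)
    gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size(group))]
    dist.all_gather(gathered, tensor, group=group)  # every rank receives all tensors
    return torch.stack(gathered)
```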
- backward(closure_loss, *args, **kwargs)[source]
Forwards backward calls to the precision plugin.
- abstract barrier(name=None)[source]
Synchronizes all processes, blocking them until the whole group enters this function.
- batch_to_device(batch, device=None, dataloader_idx=0)[source]
Moves the batch to the correct device.
The returned batch is of the same type as the input batch, just having all tensors on the correct device.
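For plain tensor collections, this behaves like Lightning's public move_data_to_device utility (an assumption about the fallback path when no module is attached; with a module attached, the module's transfer hooks apply):

```python
import torch
from pytorch_lightning.utilities import move_data_to_device

batch = {"x": torch.randn(4, 3), "y": [torch.tensor(0), torch.tensor(1)]}
moved = move_data_to_device(batch, torch.device("cpu"))
assert isinstance(moved, dict) and isinstance(moved["y"], list)  # structure is preserved
```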
- abstract broadcast(obj, src=0)[source]
Broadcasts an object to all processes.
- Parameters
src (int) – source rank
- Return type
TBroadcast
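A hedged sketch of how a distributed strategy could implement this on top of torch.distributed; the built-in strategies' internals may differ:

```python
import torch.distributed as dist


def broadcast(self, obj, src=0):
    if not dist.is_available() or not dist.is_initialized():
        return obj  # single-process fallback
    payload = [obj]
    # the object held by rank `src` overwrites the list entry on every rank
    dist.broadcast_object_list(payload, src=src)
    return payload[0]
```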
- connect(model)[source]
Called by the accelerator to connect the accelerator and the model with this plugin.
- Return type
None
- dispatch(trainer)[source]
Hook to do something before the training/evaluation/prediction starts.
- Return type
None
- lightning_module_state_dict()[source]
Returns model state.
- model_sharded_context()[source]
Provide hook to create modules in a distributed aware context. This is useful for when we’d like to shard the model instantly, which is useful for extremely large models which can save memory and initialization time.
- Returns
Model parallel context.
- Return type
Generator
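A hedged sketch of the shape of such a hook: a context manager that a sharding-aware subclass would use to intercept parameter creation (the body here is a placeholder, not the built-in implementation):

```python
from contextlib import contextmanager


@contextmanager
def model_sharded_context(self):
    # a real implementation would enter a sharding or meta-device context here,
    # so that `with strategy.model_sharded_context(): model = MyModule()`
    # creates parameters already sharded
    yield
```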
- on_predict_end()[source]
Called when predict ends.
- on_train_batch_start(batch, batch_idx, dataloader_idx=0)[source]
Called in the training loop before anything happens for that batch.
- Return type
None
- optimizer_state(optimizer)[source]
Returns state of an optimizer.
Allows for syncing/collating optimizer state from processes in custom plugins.
- optimizer_step(optimizer, opt_idx, closure, model=None, **kwargs)[source]
Performs the actual optimizer step.
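A hedged sketch of the delegation pattern this method typically follows, handing the step to the precision plugin (exact signatures vary across releases):

```python
def optimizer_step(self, optimizer, opt_idx, closure, model=None, **kwargs):
    # the precision plugin may scale the loss / unscale gradients around the step
    model = model or self.lightning_module
    return self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
```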
- post_dispatch(trainer)[source]
Deprecated since v1.6 and will be removed in v1.7: use teardown() instead.
Hook to do something after the training/evaluation/prediction finishes.
- Return type
None
- predict_step(*args, **kwargs)[source]
The actual predict step.
See predict_step() for more details.
- process_dataloader(dataloader)[source]
Wraps the dataloader if necessary.
- Parameters
dataloader (DataLoader) – iterable, ideally of type torch.utils.data.DataLoader
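One hedged example of what an override might do, re-creating the loader with a DistributedSampler so each rank sees its own shard; built-in strategies may instead wrap device-specific loaders:

```python
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def process_dataloader(self, dataloader):
    if not isinstance(dataloader, DataLoader):
        return dataloader  # leave unknown iterables untouched
    # requires an initialized process group at runtime
    sampler = DistributedSampler(dataloader.dataset)
    return DataLoader(dataloader.dataset, batch_size=dataloader.batch_size, sampler=sampler)
```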
- abstract reduce(tensor, group=None, reduce_op='mean')[source]
Reduces the given tensor (e.g. across GPUs/processes).
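A hedged sketch of a distributed implementation: all-reduce a sum across ranks, then divide by the world size for the "mean" op:

```python
import torch.distributed as dist


def reduce(self, tensor, group=None, reduce_op="mean"):
    if not dist.is_available() or not dist.is_initialized():
        return tensor  # single-process fallback
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
    if reduce_op == "mean":
        tensor = tensor / dist.get_world_size(group)
    return tensor
```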
- reduce_boolean_decision(decision)[source]
Reduce the early stopping decision across all processes.
- Return type
bool
- remove_checkpoint(filepath)[source]
Remove checkpoint filepath from the filesystem.
- save_checkpoint(checkpoint, filepath, storage_options=None)[source]
Save model/training states as a checkpoint file through state-dump and file-write.
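Both checkpoint methods typically delegate to the strategy's checkpoint_io plugin, guarded so that only rank zero touches the filesystem. This is a sketch of that pattern, not the exact built-in code:

```python
def save_checkpoint(self, checkpoint, filepath, storage_options=None):
    if self.is_global_zero:  # avoid concurrent writes from other ranks
        self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options)


def remove_checkpoint(self, filepath):
    if self.is_global_zero:
        self.checkpoint_io.remove_checkpoint(filepath)
```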
- setup(trainer)[source]
Sets up plugins for the trainer fit and creates optimizers.
- setup_environment()[source]
Sets up any processes or distributed connections.
This is called before the LightningModule/DataModule setup hook, which allows the user to access the accelerator environment before setup is complete.
- Return type
None
- setup_optimizers(trainer)[source]
Creates optimizers and schedulers.
- teardown()[source]
This method is called to teardown the training process.
It is the right place to release memory and free other resources.
- Return type
None
- test_step(*args, **kwargs)[source]
The actual test step.
See test_step() for more details.
- training_step(*args, **kwargs)[source]
The actual training step.
See training_step() for more details.
- validation_step(*args, **kwargs)[source]
The actual validation step.
See validation_step() for more details.
- property handles_gradient_accumulation: bool
Whether the plugin handles gradient accumulation internally.
- Return type
bool
- abstract property is_global_zero: bool
Whether the current process is the rank zero process not only on the local node, but for all nodes.
- Return type
bool
- property lightning_module: Optional[pytorch_lightning.core.lightning.LightningModule]
Returns the pure LightningModule without potential wrappers.
- Return type
Optional[LightningModule]
- property lightning_restore_optimizer: bool
Override to disable Lightning restoring optimizers/schedulers.
This is useful for plugins which manage restoring optimizers/schedulers.
- Return type
bool
- property model: Optional[torch.nn.modules.module.Module]
Returns the potentially wrapped LightningModule.
- property restore_checkpoint_after_setup: bool
Override to delay restoring from checkpoint until after pre-dispatch. This is useful when the plugin requires all the setup hooks to run before loading a checkpoint.
- Returns
If True, restore checkpoint after pre_dispatch.
- Return type
bool
- abstract property root_device: torch.device
Returns the root device.
- Return type
torch.device