Tuner
- class lightning.pytorch.tuner.tuning.Tuner(trainer)[source]
Bases: object
Tuner class to tune your model.
- lr_find(model, train_dataloaders=None, val_dataloaders=None, dataloaders=None, datamodule=None, method='fit', min_lr=1e-08, max_lr=1, num_training=100, mode='exponential', early_stop_threshold=4.0, update_attr=True, attr_name='')[source]
Runs a learning rate range test to reduce the guesswork in picking a good initial learning rate.
- Parameters
model (LightningModule) – Model to tune.
train_dataloaders (Union[Any, LightningDataModule, None]) – A collection of torch.utils.data.DataLoader or a LightningDataModule specifying training samples. In the case of multiple dataloaders, please see this section.
val_dataloaders (Optional[Any]) – A torch.utils.data.DataLoader or a sequence of them specifying validation samples.
dataloaders (Optional[Any]) – A torch.utils.data.DataLoader or a sequence of them specifying val/test/predict samples used for running tuner on validation/testing/prediction.
datamodule (Optional[LightningDataModule]) – An instance of LightningDataModule.
method (Literal['fit', 'validate', 'test', 'predict']) – Method to run tuner on. It can be any of ("fit", "validate", "test", "predict").
min_lr (float) – Minimum learning rate to investigate.
max_lr (float) – Maximum learning rate to investigate.
num_training (int) – Number of learning rates to test.
mode (str) – Search strategy to update learning rate after each batch:
'exponential': Increases the learning rate exponentially.
'linear': Increases the learning rate linearly.
early_stop_threshold (float) – Threshold for stopping the search. If the loss at any point is larger than early_stop_threshold*best_loss then the search is stopped. To disable, set to None.
update_attr (bool) – Whether to update the learning rate attribute or not.
attr_name (str) – Name of the attribute which stores the learning rate. The names 'learning_rate' or 'lr' get automatically detected. Otherwise, set the name here.
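To make the two mode values concrete, the following is a minimal sketch of the learning rates a range test could visit between min_lr and max_lr. It is not Lightning's internal scheduler: the helper name lr_sweep and the exact interpolation formulas are assumptions chosen for illustration, and the library's own step indexing may differ.

```python
def lr_sweep(min_lr=1e-8, max_lr=1.0, num_training=100, mode="exponential"):
    """Return the learning rates an illustrative range test would visit.

    Hypothetical helper: the formulas below are an assumption, not a copy of
    Lightning's scheduler.
    """
    if num_training < 2:
        raise ValueError("num_training must be at least 2")
    if mode not in ("exponential", "linear"):
        raise ValueError(f"unknown mode: {mode!r}")

    rates = []
    for step in range(num_training):
        frac = step / (num_training - 1)  # progress through the sweep, 0.0 to 1.0
        if mode == "exponential":
            # Constant ratio between consecutive learning rates.
            lr = min_lr * (max_lr / min_lr) ** frac
        else:
            # Constant difference between consecutive learning rates.
            lr = min_lr + (max_lr - min_lr) * frac
        rates.append(lr)
    return rates


exp_rates = lr_sweep(mode="exponential")
lin_rates = lr_sweep(mode="linear")

# Compare how quickly each mode leaves the very small learning rates.
print("exponential, first three:", exp_rates[:3])
print("linear, first three:     ", lin_rates[:3])
```

With the default range spanning eight orders of magnitude, the exponential sweep spends roughly equal numbers of steps in each decade, while the linear sweep moves past the small learning rates within its first few steps.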
- Raises
MisconfigurationException – If learning rate/lr in model or model.hparams isn't overridden, or if you are using more than one optimizer.
- Return type
Optional[_LRFinder]
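The early_stop_threshold rule described for lr_find (stop once the loss exceeds early_stop_threshold*best_loss) can be sketched in plain Python as follows. The function name steps_before_stop and the loss values are hypothetical, and the sketch compares raw losses; whether the library smooths the loss before comparing is not stated here, so treat this as an illustration of the rule only.

```python
def steps_before_stop(losses, early_stop_threshold=4.0):
    """Return how many steps run before the stopping rule triggers.

    Hypothetical helper illustrating the documented rule on a fixed list of
    positive losses; passing None disables early stopping.
    """
    best_loss = float("inf")
    for step, loss in enumerate(losses):
        # Stop as soon as the loss is larger than threshold * best loss so far.
        if early_stop_threshold is not None and loss > early_stop_threshold * best_loss:
            return step
        best_loss = min(best_loss, loss)
    return len(losses)


# Made-up losses: they improve, then diverge as the learning rate grows.
losses = [2.0, 1.5, 1.0, 0.8, 1.2, 3.0, 3.5, 9.0]

print(steps_before_stop(losses))                             # default threshold of 4.0
print(steps_before_stop(losses, early_stop_threshold=None))  # early stopping disabled
```

A lower threshold ends the search sooner after the loss starts rising, and None lets the sweep run for all num_training steps regardless of divergence.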
- scale_batch_size(model, train_dataloaders=None, val_dataloaders=None, dataloaders=None, datamodule=None, method='fit', mode='power', steps_per_trial=3, init_val=2, max_trials=25, batch_arg_name='batch_size')[source]
Iteratively try to find the largest batch size for a given model that does not give an out of memory (OOM) error.
- Parameters
model (LightningModule) – Model to tune.
train_dataloaders (Union[Any, LightningDataModule, None]) – A collection of torch.utils.data.DataLoader or a LightningDataModule specifying training samples. In the case of multiple dataloaders, please see this section.
val_dataloaders (Optional[Any]) – A torch.utils.data.DataLoader or a sequence of them specifying validation samples.
dataloaders (Optional[Any]) – A torch.utils.data.DataLoader or a sequence of them specifying val/test/predict samples used for running tuner on validation/testing/prediction.
datamodule (Optional[LightningDataModule]) – An instance of LightningDataModule.
method (Literal['fit', 'validate', 'test', 'predict']) – Method to run tuner on. It can be any of ("fit", "validate", "test", "predict").
mode (str) – Search strategy to update the batch size:
'power': Keep multiplying the batch size by 2, until we get an OOM error.
'binsearch': Initially keep multiplying by 2 and after encountering an OOM error do a binary search between the last successful batch size and the batch size that failed.
steps_per_trial (int) – Number of steps to run with a given batch size. Ideally 1 should be enough to test if an OOM error occurs, however in practice a few are needed.
init_val (int) – Initial batch size to start the search with.
max_trials (int) – Maximum number of batch size increases before the algorithm terminates.
batch_arg_name (str) – Name of the attribute that stores the batch size. It is expected that the user has provided a model or datamodule that has a hyperparameter with that name. We will look for this attribute name in the following places:
model
model.hparams
trainer.datamodule (the datamodule passed to the tune method)
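The 'power' and 'binsearch' strategies can be illustrated without a GPU by replacing the real training trial with a stand-in predicate. The sketch below is an assumption-laden simplification, not Lightning's implementation: find_batch_size and fits are hypothetical names, fits(batch_size) stands in for "the trial ran without an OOM error", max_trials here counts every trial, and the sketch returns None if even init_val fails.

```python
def find_batch_size(fits, mode="power", init_val=2, max_trials=25):
    """Search for the largest batch size accepted by `fits` (illustrative only).

    `fits(batch_size) -> bool` is a hypothetical stand-in for running a few
    training steps and reporting whether they completed without an OOM error.
    """
    if mode not in ("power", "binsearch"):
        raise ValueError(f"unknown mode: {mode!r}")

    size = init_val
    low = None   # largest batch size that succeeded so far
    high = None  # smallest batch size that failed so far

    for _ in range(max_trials):
        if fits(size):
            low = size
        else:
            high = size
            # 'power' stops at the first failure; 'binsearch' has nothing to
            # bisect if the very first trial already failed.
            if mode == "power" or low is None:
                break

        if high is None:
            size *= 2                  # no failure seen yet: keep doubling
        elif high - low <= 1:
            break                      # search interval is exhausted
        else:
            size = (low + high) // 2   # bisect between success and failure

    return low


def fits(batch_size):
    # Pretend that anything above 100 samples per batch runs out of memory.
    return batch_size <= 100


print(find_batch_size(fits, mode="power"))
print(find_batch_size(fits, mode="binsearch"))
```

The trade-off is that 'power' only ever returns init_val times a power of two, while 'binsearch' spends extra trials to close the gap between the last successful size and the first failing one.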
- Return type