BatchSizeFinder
- class lightning.pytorch.callbacks.BatchSizeFinder(mode='power', steps_per_trial=3, init_val=2, max_trials=25, batch_arg_name='batch_size')
Bases: Callback
The BatchSizeFinder callback tries to find the largest batch size for a given model that does not give an out-of-memory (OOM) error. All you need to do is add it as a callback inside Trainer and call trainer.{fit,validate,test,predict}. Internally it calls the respective step function steps_per_trial times for each batch size until one of the batch sizes generates an OOM error.
Warning
This is an experimental feature.
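The trial loop described above can be sketched in plain Python. This is a simplified simulation, not Lightning's implementation: the names `run_trial`, `find_batch_size`, and `memory_limit` are hypothetical, an OOM error is imitated by raising `MemoryError` whenever the batch size exceeds `memory_limit`, and `max_trials` is assumed to cap the total number of trials.

```python
def run_trial(batch_size, steps_per_trial, memory_limit):
    """Simulate one trial: run `steps_per_trial` steps at `batch_size`.

    Returns True if every step fits, False if a simulated OOM occurs.
    """
    try:
        for _ in range(steps_per_trial):
            if batch_size > memory_limit:  # stand-in for a real step running out of memory
                raise MemoryError(f"simulated OOM at batch size {batch_size}")
        return True
    except MemoryError:
        return False


def find_batch_size(memory_limit, mode="power", steps_per_trial=3, init_val=2, max_trials=25):
    """Return the largest batch size that passed a trial, or None if none did."""
    if mode not in ("power", "binsearch"):
        raise ValueError(f"unknown mode: {mode!r}")
    largest_ok = None  # last batch size that completed a trial
    failed = None      # smallest batch size that hit the simulated OOM
    size = init_val
    for _ in range(max_trials):
        if run_trial(size, steps_per_trial, memory_limit):
            largest_ok = size
            if failed is None:
                size *= 2  # no failure seen yet: keep doubling
            elif failed - largest_ok <= 1:
                break      # binary search has converged
            else:
                size = (largest_ok + failed) // 2
        else:
            failed = size
            if mode == "power" or largest_ok is None:
                break      # 'power' stops at the first failure
            if failed - largest_ok <= 1:
                break
            size = (largest_ok + failed) // 2
    return largest_ok


print(find_batch_size(100, mode="power"))      # 64
print(find_batch_size(100, mode="binsearch"))  # 100
```

With a simulated limit of 100, doubling alone stops at 64, while the binary-search refinement narrows the gap between 64 and 128 and settles on 100.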
- Parameters:
mode (str) – search strategy to update the batch size:
- 'power': Keep multiplying the batch size by 2 until we get an OOM error.
- 'binsearch': Initially keep multiplying by 2 and, after encountering an OOM error, do a binary search between the last successful batch size and the batch size that failed.
steps_per_trial (int) – number of steps to run with a given batch size. Ideally 1 should be enough to test if an OOM error occurs; however, in practice a few are needed.
init_val (int) – initial batch size to start the search with.
max_trials (int) – max number of increases in batch size done before the algorithm is terminated.
batch_arg_name (str) – name of the attribute that stores the batch size. It is expected that the user has provided a model or datamodule that has a hyperparameter with that name. We will look for this attribute name in the following places:
- model
- model.hparams
- trainer.datamodule (the datamodule passed to the tune method)
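The lookup for `batch_arg_name` can be pictured with a small stand-alone sketch. It is an illustration only: `locate_batch_size` is a hypothetical helper, `SimpleNamespace` objects stand in for a real model and datamodule, and checking the three places in the listed order is an assumption, since the text above does not state a precedence.

```python
from types import SimpleNamespace


def locate_batch_size(model, datamodule=None, batch_arg_name="batch_size"):
    """Return (place, value) for the first object exposing `batch_arg_name`."""
    candidates = (
        ("model", model),
        ("model.hparams", getattr(model, "hparams", None)),
        ("trainer.datamodule", datamodule),
    )
    for label, holder in candidates:
        if holder is not None and hasattr(holder, batch_arg_name):
            return label, getattr(holder, batch_arg_name)
    raise AttributeError(f"no attribute named {batch_arg_name!r} found")


# Stand-ins for the three documented places (all hypothetical).
plain_model = SimpleNamespace(batch_size=32)
hparams_model = SimpleNamespace(hparams=SimpleNamespace(batch_size=16))
bare_model = SimpleNamespace()
datamodule = SimpleNamespace(batch_size=8)

print(locate_batch_size(plain_model))             # ('model', 32)
print(locate_batch_size(hparams_model))           # ('model.hparams', 16)
print(locate_batch_size(bare_model, datamodule))  # ('trainer.datamodule', 8)
```

In practice this means the dataloaders should read the batch size from the same attribute (for example `self.batch_size`), so that a value updated by the finder takes effect.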
Example:
# 1. Customize the BatchSizeFinder callback to run at different epochs. This feature is
# useful while fine-tuning models since you can't always use the same batch size after
# unfreezing the backbone.
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import BatchSizeFinder


class FineTuneBatchSizeFinder(BatchSizeFinder):
    def __init__(self, milestones, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.milestones = milestones

    def on_fit_start(self, *args, **kwargs):
        # Override so the search does not run at the start of fit.
        return

    def on_train_epoch_start(self, trainer, pl_module):
        # Re-run the search at epoch 0 and at each milestone epoch.
        if trainer.current_epoch in self.milestones or trainer.current_epoch == 0:
            self.scale_batch_size(trainer, pl_module)


trainer = Trainer(callbacks=[FineTuneBatchSizeFinder(milestones=(5, 10))])
trainer.fit(...)
Example:
# 2. Run batch size finder for validate/test/predict.
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import BatchSizeFinder


class EvalBatchSizeFinder(BatchSizeFinder):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def on_fit_start(self, *args, **kwargs):
        # Override so the search does not run at the start of fit.
        return

    def on_test_start(self, trainer, pl_module):
        # Run the search when testing starts instead.
        self.scale_batch_size(trainer, pl_module)


trainer = Trainer(callbacks=[EvalBatchSizeFinder()])
trainer.test(...)