Changing batch size during training

Hello, I managed to implement gradual unfreezing with callbacks. However, since I had to set the batch size small enough to fit all parameters after unfreezing the last layer, the GPU memory has very low utilization for a long time. Is it possible to dynamically change the batch size before every epoch? I tried the naive approach of changing the batch size in the BaseFinetuning callback, which obviously fails due to a validation check that says I indeed shouldn't modify batch_size during training.

When I consulted the source code, the callbacks for on_train_epoch_start are called right after the dataloaders have been set up (and they aren't even reloaded unless specified in the Trainer).

My question is: has anyone managed to successfully modify the batch_size during training? If so, how did you do it? The only thing that comes to my mind, without heavily modifying the Lightning source, is to run a fit for one epoch → change dataloaders → run for another.

I believe this should work (with a modification to the optimizer, since I use a scheduler that depends on the total number of training steps). Is there a better way, though?
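
Roughly, what I have in mind is something like this untested sketch (MyModel and make_dataset are just placeholders, not actual code I have):

import pytorch_lightning as pl
from torch.utils.data import DataLoader

# Fit one stage, save a checkpoint, rebuild the dataloader with a larger
# batch size, and resume from the checkpoint for the next stage.
model = MyModel()
stages = [(2, 8), (5, 16), (10, 32)]  # (cumulative max_epochs, batch_size)

ckpt_path = None
for max_epochs, batch_size in stages:
    trainer = pl.Trainer(max_epochs=max_epochs)
    loader = DataLoader(make_dataset(), batch_size=batch_size, shuffle=True)
    trainer.fit(model, train_dataloaders=loader, ckpt_path=ckpt_path)
    ckpt_path = f"stage_{max_epochs}.ckpt"
    trainer.save_checkpoint(ckpt_path)  # resume the next stage from here
    # The LR scheduler would still need adjusting, since it depends on the
    # total number of training steps, which changes with the batch size.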

Hi

You can do this by enabling

Trainer(reload_dataloaders_every_n_epochs=1)

This will trigger the train_dataloader() etc. hooks every epoch and fetch a new dataloader instance. This means you can put conditional logic inside your x_dataloader() hooks that returns a dataloader with a dynamically chosen batch size. For example:

def train_dataloader(self):
    if self.trainer.current_epoch > 2:
        batch_size = 10
    else:
        batch_size = 5

    return DataLoader(..., batch_size=batch_size)

Just a dummy example, but I hope it demonstrates the idea.
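
To put it all together, here is a rough sketch of how this could be wired up in a DataModule (the model and dataset here are just placeholders):

import pytorch_lightning as pl
from torch.utils.data import DataLoader

class MyDataModule(pl.LightningDataModule):
    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    def train_dataloader(self):
        # The DataModule has access to the Trainer, so the current epoch
        # can decide which batch size to use.
        batch_size = 10 if self.trainer.current_epoch > 2 else 5
        return DataLoader(self.dataset, batch_size=batch_size, shuffle=True)

# Reloading every epoch makes Lightning call train_dataloader() again,
# so the batch size can change between epochs.
trainer = pl.Trainer(reload_dataloaders_every_n_epochs=1)
trainer.fit(model, datamodule=MyDataModule(my_dataset))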

EDIT: typo fixes, thanks @hynky

That’s great, I had no idea that the DataModule also has access to the Trainer!
Thank you so much :), this means that I can also change gradient accumulation to keep the effective batch_size consistent. Great!
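
For reference, something like this is what I mean, using the built-in GradientAccumulationScheduler callback (the numbers are made up; with the batch size doubling from 5 to 10 at epoch 3, halving the accumulation factor keeps the effective batch size at 20):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import GradientAccumulationScheduler

# Accumulate 4 batches for epochs 0-2 (batch_size 5 -> effective 20),
# then 2 batches from epoch 3 on (batch_size 10 -> effective 20).
accumulator = GradientAccumulationScheduler(scheduling={0: 4, 3: 2})

trainer = pl.Trainer(
    reload_dataloaders_every_n_epochs=1,
    callbacks=[accumulator],
)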

Just to clarify, in 2.0.0 the function would look like this:

def train_dataloader(self):
    if self.trainer.current_epoch > 2:
        batch_size = 10
    else:
        batch_size = 5

    return DataLoader(..., batch_size=batch_size)

Also, the parameter in the Trainer is

reload_dataloaders_every_n_epochs

Thank you so much again 🙂


Thanks, I fixed the typos in my original reply.
Glad this worked out!