The actual meaning of len(batch) in a Lightning callback

I tried to collect the batch size using a Lightning hook, e.g.:

def on_train_batch_start(  # type: ignore[override]
    self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int
) -> None:
    batch_size = len(batch)

However, batch_size is not the actual batch size used by the dataloader. What does this value mean? Is it a “row size” or a “feature vector size”? Thank you.


The “batch” is whatever your dataloader returns. This could be many things and is really specific to the model/task you are training.

Sometimes batch is just a pure tensor, in which case len(batch) is the size of the first dimension of that tensor (typically referred to as the batch size).

Sometimes, and probably very often, batch is actually a container like a dict, a tuple, or a list. In that case, len(batch) is simply the number of elements in that container, not the number of samples. Example: in supervised training, the batch will contain an input tensor and a label tensor, so len(batch) == 2 regardless of how many samples are in the batch.
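A quick sketch of the two cases (the shapes here are made up for illustration):

```python
import torch

# Case 1: the batch is a pure tensor. len() gives the first dimension,
# which is the actual batch size.
tensor_batch = torch.randn(32, 10)  # 32 samples, 10 features each
print(len(tensor_batch))  # 32

# Case 2: a supervised dataloader typically yields an (inputs, labels)
# container. len() counts the container's elements, not the samples.
inputs = torch.randn(32, 10)
labels = torch.randint(0, 2, (32,))
tuple_batch = (inputs, labels)
print(len(tuple_batch))   # 2  -- number of elements in the tuple
print(inputs.size(0))     # 32 -- the real batch size
```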

So to answer your question: the meaning of len(batch) really depends on your task and the dataloader you are using. If you are not sure, just print it and inspect what’s inside!
