I tried to collect the batch size using a Lightning hook, e.g.:

```python
def on_train_batch_start(  # type: ignore[override]
    self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", batch: Any, batch_idx: int
) -> None:
    batch_size = len(batch)
```

However, this `batch_size` is not the actual batch size used by the dataloader. What is the meaning of this value? Is it a "row size" or a "feature vector size"? Thank you.
@Qihan_Wang
The “batch” is what is being returned by your dataloader. This could be many things and is really specific to the model/task you are training with.
Sometimes `batch` is just a pure tensor, in which case `len(batch)` is the first dimension of that tensor (typically referred to as the batch size).
Sometimes, and probably very often, it is actually a container like a dict, a tuple, or a list. In this case, `len(batch)` is simply the length of that container, not the number of samples. Example: in supervised training, the batch will contain an input tensor and a label tensor, so `len(batch)` is 2.
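A small sketch of the two cases above (assuming PyTorch is installed; the tensor shapes here are made up for illustration):

```python
import torch

# Case 1: the dataloader yields a plain tensor.
# len() on a tensor returns its first dimension, i.e. the batch size.
tensor_batch = torch.zeros(32, 3, 224, 224)  # 32 RGB images of size 224x224
print(len(tensor_batch))  # 32 -> the batch size

# Case 2 (supervised training): the dataloader yields an (inputs, labels) tuple.
# len() now counts the container's elements, not the samples.
inputs = torch.zeros(32, 10)
labels = torch.zeros(32, dtype=torch.long)
tuple_batch = (inputs, labels)
print(len(tuple_batch))  # 2 -> number of elements in the tuple, NOT 32

# The actual batch size is the first dimension of the input tensor:
print(tuple_batch[0].shape[0])  # 32
```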
So to answer your question: the meaning of `len(batch)` really depends on your task and the dataloader you are using. If you are not sure, just print the batch and inspect what's inside of it!
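If you do need the true batch size inside the hook, one option is a small helper that unwraps common batch layouts. `infer_batch_size` below is a hypothetical name, not a Lightning API, and it only covers tensors, tuples/lists, and dicts:

```python
from typing import Any

import torch


def infer_batch_size(batch: Any) -> int:
    """Best-effort guess of the batch size for common batch layouts."""
    if isinstance(batch, torch.Tensor):
        # Convention: the first dimension of a batched tensor is the batch size.
        return batch.shape[0]
    if isinstance(batch, (tuple, list)):
        # e.g. (inputs, labels) -- recurse into the first element.
        return infer_batch_size(batch[0])
    if isinstance(batch, dict):
        # e.g. {"input_ids": ..., "labels": ...} -- recurse into any value.
        return infer_batch_size(next(iter(batch.values())))
    raise TypeError(f"Cannot infer batch size from {type(batch).__name__}")
```

Inside the hook from the question you would then call `batch_size = infer_batch_size(batch)` instead of `len(batch)`.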