Code review/suggestions: Progress bar for dataset with no `len()`

Hi,

I’m trying to use Lightning with a Lhotse sampler that doesn’t implement `len()`. I see that Lightning has some support for such datasets via the Rich progress bar, though, understandably, it doesn’t try to infer a total length.

I wanted to implement a progress bar that behaves like Hugging Face’s default one. In particular, the total for the training progress bar is based on the total number of updates (which I can get from `trainer.max_steps`).

For the validation progress bar, I had the idea of using the sanity check to count the number of steps, since a) the validation dataset should typically be much smaller than the training dataset (so a full pass shouldn’t be too costly) and b) the validation dataset isn’t shuffled (so the number of steps will be the same every time). For this to work, the sanity check has to run over the entire validation set, i.e. `num_sanity_val_steps=-1`.

The following is a minimal working example (also in a Colab here). I was wondering whether anyone has better suggestions or foresees any pitfalls in implementing a progress bar this way.

Thanks!

MWE

# Setup: install the dependencies first (shell command, not Python):
# pip install lightning==2.0.3 lhotse==1.15.0
import time
import torch
import lightning.pytorch as pl

from lightning.pytorch.callbacks import ProgressBar

from lhotse import CutSet
from lhotse.recipes import download_librispeech, prepare_librispeech
from lhotse.dataset.sampling import DynamicBucketingSampler

from torch.utils.data import DataLoader

from tqdm import tqdm

class GlobalProgressBar(ProgressBar):
    """Progress bar for dataloaders that do not implement ``len()``.

    * The training bar counts optimizer updates against ``trainer.max_steps``.
      It is open-ended when ``max_steps`` is negative (no limit configured).
    * The validation bar learns its length by counting the batches of a
      complete validation pass. Normally that pass is the sanity check, which
      must then run over the whole set (``num_sanity_val_steps=-1``). Every
      later pass refreshes the count, so a skipped or partial sanity check
      self-corrects after one real validation run.

    Assumes the validation set is not shuffled, so every pass yields the same
    number of batches.
    """

    def __init__(self):
        super().__init__()  # don't forget this :)
        # When False, bars are created with ``disable=True`` and print nothing.
        self.enable = True

        # True once at least one validation pass has been counted.
        self.sanity_val_check_done = False
        # Batch count of the most recent complete validation pass.
        self.sanity_val_check_steps = 0
        # Batches seen so far in the validation pass currently running.
        self._val_batches_seen = 0

    def disable(self):
        # NOTE(review): Lightning's base ProgressBar appears to call this on
        # non-zero ranks; honouring ``self.enable`` in the tqdm calls below
        # keeps those ranks silent.
        self.enable = False

    def on_sanity_check_end(self, trainer, pl_module):
        self.sanity_val_check_done = True

    def on_train_start(self, trainer, pl_module):
        # max_steps is negative when no step limit is set; show a plain counter
        # instead of passing a negative total to tqdm.
        total = trainer.max_steps if trainer.max_steps >= 0 else None
        # Start at global_step so a run resumed from a checkpoint does not
        # restart the bar at zero.
        self.train_pbar = tqdm(
            total=total,
            initial=trainer.global_step,
            disable=not self.enable,
        )

    def on_train_epoch_start(self, trainer, pl_module):
        self.train_pbar.set_description_str(f"Epoch: {trainer.current_epoch}")

    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        super().on_before_optimizer_step(trainer, pl_module, optimizer)  # don't forget this :)
        # One tick per optimizer update (not per batch), so gradient
        # accumulation does not inflate the count.
        self.train_pbar.update(1)

    def on_train_end(self, trainer, pl_module):
        self.train_pbar.close()

    def on_validation_start(self, trainer, pl_module):
        self._val_batches_seen = 0
        if self.sanity_val_check_done and self.sanity_val_check_steps > 0:
            self.val_pbar = tqdm(
                desc="Running validation",
                total=self.sanity_val_check_steps,
                disable=not self.enable,
            )
        else:
            # Length still unknown: open-ended bar that just counts batches.
            self.val_pbar = tqdm(
                desc="Running full epoch to estimate number of validation batches...",
                disable=not self.enable,
            )

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        # dataloader_idx mirrors the Callback hook signature so the override
        # still works when several validation dataloaders are used.
        super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)  # don't forget this :)
        self._val_batches_seen += 1
        # Advance in both modes so the counting pass also shows progress.
        self.val_pbar.update(1)

    def on_validation_end(self, trainer, pl_module):
        self.val_pbar.close()
        if self._val_batches_seen > 0:
            # The latest complete pass gives the length to use next time.
            # Refreshing it on every pass also fixes a skipped or partial
            # sanity check, and stops the count from accumulating.
            self.sanity_val_check_steps = self._val_batches_seen
            self.sanity_val_check_done = True

class MinimalASRDataset(torch.utils.data.Dataset):
    """Map-style dataset whose "index" is an entire mini-batch of cuts.

    It is used with a Lhotse sampler and ``batch_size=None``. Each sampled
    item is therefore a ``CutSet``, which ``__getitem__`` returns as the batch.
    """

    def __getitem__(self, cuts: CutSet) -> CutSet:
        # The batch is the CutSet itself, reordered by duration.
        return cuts.sort_by_duration()

class LibrisDataModule(pl.LightningDataModule):
    """Mini LibriSpeech data module built on Lhotse's ``DynamicBucketingSampler``.

    The samplers yield whole batches of cuts and need no ``len()``, so the
    DataLoaders use ``batch_size=None``.
    """

    def prepare_data(self) -> None:
        # Download the mini_librispeech parts of the corpus.
        download_librispeech(dataset_parts="mini_librispeech")

    def setup(self, stage=None):
        libri = prepare_librispeech(corpus_dir="LibriSpeech", output_dir="data/")
        self.cuts_train = CutSet.from_manifests(**libri["train-clean-5"])
        self.cuts_valid = CutSet.from_manifests(**libri["dev-clean-2"])

    @staticmethod
    def _make_loader(sampler):
        # Build a DataLoader around a batch-yielding Lhotse sampler.
        # batch_size=None: the sampler already produces complete batches.
        return DataLoader(MinimalASRDataset(), sampler=sampler, batch_size=None, num_workers=1)

    def train_dataloader(self):
        train_sampler = DynamicBucketingSampler(
            self.cuts_train, max_duration=100, shuffle=True, drop_last=True
        )
        return self._make_loader(train_sampler)

    def val_dataloader(self):
        # drop_last=False: Lhotse documents drop_last as discarding incomplete
        # batches, which would silently skip validation cuts.
        # shuffle=False keeps the batches identical on every pass, which the
        # batch-counting validation progress bar relies on.
        valid_sampler = DynamicBucketingSampler(
            self.cuts_valid, max_duration=100, shuffle=False, drop_last=False
        )
        return self._make_loader(valid_sampler)

class DummyModel(pl.LightningModule):
    """No-op LightningModule used only to drive the progress-bar demo."""

    def __init__(self):
        super().__init__()
        # A tiny parameterised layer so the optimizer has parameters to hold.
        self.linear = torch.nn.Linear(2, 2)

    def configure_optimizers(self):
        # Required by the Trainer; the learning rate is irrelevant for this demo.
        return torch.optim.AdamW(self.parameters(), lr=1)

    def training_step(self, batch, batch_idx):
        # Nothing is computed: no loss is produced for the batch.
        return None

    def validation_step(self, batch, batch_idx):
        # Simulate a little work per batch so the validation bar visibly moves.
        time.sleep(0.01)
        return None

MAX_UPDATES = 500
GRAD_ACC = 1
VAL_EVERY_N_UPDATES = 100

# Guard the entry point: DataLoader workers (num_workers=1 above) started with
# the "spawn" method re-import the main module and must not re-run training.
if __name__ == "__main__":
    pl.seed_everything(42)

    trainer = pl.Trainer(
        accelerator="cpu",
        max_steps=MAX_UPDATES,
        accumulate_grad_batches=GRAD_ACC,
        # Prevent Lightning from replacing Lhotse's DDP-compatible sampler
        use_distributed_sampler=False,
        callbacks=[GlobalProgressBar()],
        # Run through the entire validation set to count validation steps
        num_sanity_val_steps=-1,
        # val_check_interval counts dataloader batches, not optimizer updates,
        # so scale the update interval by the accumulation factor.
        check_val_every_n_epoch=None,
        val_check_interval=VAL_EVERY_N_UPDATES * GRAD_ACC,
        # Disable for demo. logger=False is the documented "logging off"
        # value; None is merely the parameter's default.
        enable_checkpointing=False,
        enable_model_summary=False,
        logger=False,
    )

    trainer.fit(DummyModel(), LibrisDataModule())