Hi,
I’m trying to use Lightning with a Lhotse sampler that doesn’t implement __len__(). I see that there’s some support for such datasets using Rich progress bars, though understandably that approach doesn’t try to infer a total length.
I had the thought of implementing a progress bar that behaves like HuggingFace’s default one. In particular, the total for the training progress bar is based on the total number of optimizer updates (which I can get from trainer.max_steps).
And for the validation progress bar, I had the idea of using the sanity check to count the number of validation batches, since a) the validation set is typically much smaller than the training set (a full pass shouldn’t be too costly) and b) it doesn’t get shuffled (so the count will be the same every time).
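To spell out the arithmetic behind these choices (the training bar total tracks optimizer updates, while val_check_interval in the MWE below is counted in dataloader batches and therefore gets scaled by the accumulation factor), here’s a tiny sketch. progress_totals is just an illustrative helper of mine, not a Lightning API:

def progress_totals(max_steps: int, accumulate_grad_batches: int, val_every_n_updates: int):
    # The training bar tracks optimizer updates, so its total is simply max_steps.
    train_bar_total = max_steps
    # val_check_interval is counted in dataloader batches, not updates,
    # so it has to be multiplied by the gradient accumulation factor.
    val_check_interval = val_every_n_updates * accumulate_grad_batches
    return train_bar_total, val_check_interval

print(progress_totals(500, 1, 100))  # -> (500, 100)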
The following is a minimal working example (also in a Colab here). I was wondering if anyone has better suggestions or foresees any pitfalls in implementing a progress bar this way.
Thanks!
MWE
# Setup
pip install lightning==2.0.3 lhotse==1.15.0
import time
import torch
import lightning.pytorch as pl
from lightning.pytorch.callbacks import ProgressBar
from lhotse import CutSet
from lhotse.recipes import download_librispeech, prepare_librispeech
from lhotse.dataset.sampling import DynamicBucketingSampler
from torch.utils.data import DataLoader
from tqdm import tqdm
class GlobalProgressBar(ProgressBar):
    def __init__(self):
        super().__init__()  # don't forget this :)
        self.enable = True
        self.sanity_val_check_done = False
        self.sanity_val_check_steps = 0

    def disable(self):
        self.enable = False

    def on_sanity_check_end(self, trainer, pl_module):
        self.sanity_val_check_done = True

    def on_train_start(self, trainer, pl_module):
        self.train_pbar = tqdm(total=trainer.max_steps)

    def on_train_epoch_start(self, trainer, pl_module):
        self.train_pbar.set_description_str(f"Epoch: {trainer.current_epoch}")

    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        super().on_before_optimizer_step(trainer, pl_module, optimizer)  # don't forget this :)
        self.train_pbar.update(1)

    def on_train_end(self, trainer, pl_module):
        self.train_pbar.close()

    def on_validation_start(self, trainer, pl_module):
        if not self.sanity_val_check_done:
            self.val_pbar = tqdm(desc="Running full epoch to estimate number of validation batches...")
        else:
            self.val_pbar = tqdm(desc="Running validation", total=self.sanity_val_check_steps)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx)  # don't forget this :)
        if not self.sanity_val_check_done:
            self.sanity_val_check_steps += 1
        else:
            self.val_pbar.update(1)

    def on_validation_end(self, trainer, pl_module):
        self.val_pbar.close()
class MinimalASRDataset(torch.utils.data.Dataset):
    def __getitem__(self, cuts: CutSet) -> CutSet:
        cuts = cuts.sort_by_duration()
        return cuts
class LibrisDataModule(pl.LightningDataModule):
    def prepare_data(self) -> None:
        download_librispeech(dataset_parts="mini_librispeech")

    def setup(self, stage=None):
        libri = prepare_librispeech(corpus_dir="LibriSpeech", output_dir="data/")
        self.cuts_train = CutSet.from_manifests(**libri["train-clean-5"])
        self.cuts_valid = CutSet.from_manifests(**libri["dev-clean-2"])

    def train_dataloader(self):
        train_sampler = DynamicBucketingSampler(self.cuts_train, max_duration=100, shuffle=True, drop_last=True)
        return DataLoader(MinimalASRDataset(), sampler=train_sampler, batch_size=None, num_workers=1)

    def val_dataloader(self):
        valid_sampler = DynamicBucketingSampler(self.cuts_valid, max_duration=100, shuffle=False, drop_last=True)
        return DataLoader(MinimalASRDataset(), sampler=valid_sampler, batch_size=None, num_workers=1)
class DummyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)

    def configure_optimizers(self):
        # required by Trainer, but not relevant for this test
        optimizer = torch.optim.AdamW(self.parameters(), lr=1)
        return optimizer

    def training_step(self, batch, batch_idx):
        return None

    def validation_step(self, batch, batch_idx):
        time.sleep(0.01)
        return None
MAX_UPDATES = 500
GRAD_ACC = 1
VAL_EVERY_N_UPDATES = 100

pl.seed_everything(42)

trainer = pl.Trainer(
    accelerator="cpu",
    max_steps=MAX_UPDATES,
    accumulate_grad_batches=GRAD_ACC,
    # Prevent Lightning from replacing Lhotse's DDP-compatible sampler
    use_distributed_sampler=False,
    callbacks=[GlobalProgressBar()],
    # Run through the entire validation set to get the number of validation steps
    num_sanity_val_steps=-1,
    # Turns out val_check_interval is based on dataloader batch steps, not update steps
    check_val_every_n_epoch=None,
    val_check_interval=VAL_EVERY_N_UPDATES * GRAD_ACC,
    # Disable for demo
    enable_checkpointing=False,
    enable_model_summary=False,
    logger=None,
)
trainer.fit(DummyModel(), LibrisDataModule())