I have a script, written with PyTorch Lightning, that fine-tunes a HuggingFace model. When I call `trainer.fit(model, train_loader, val_loader)`, the number of batches shown for training is the number of batches in `train_loader` plus the number in `val_loader`. That makes me think my validation data is being used for both training and validation, and I'm not sure why this is happening. Here's a snippet of my code:
```python
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from ray_lightning import RayStrategy
from torch.utils.data import DataLoader

# TLDataset, T5FineTuner, train, val, tokenizer, args, logger and num_gpus are defined elsewhere in my script
train_data = TLDataset(train, tokenizer)
logger.info(f"Successfully loaded SRC training data: {len(train_data)} examples")  # 10000
val_data = TLDataset(val, tokenizer)
logger.info(f"Successfully loaded SRC validation data: {len(val_data)} examples")  # 1200
train_loader = DataLoader(train_data, batch_size=8, drop_last=True)
val_loader = DataLoader(val_data, batch_size=8)  # , num_workers=num_cpus//num_gpus

tb_logger = pl_loggers.TensorBoardLogger(save_dir=f"{args.output_dir}logs/{args.file_name}_logs/")
strategy = RayStrategy(num_workers=num_gpus, use_gpu=num_gpus > 0, find_unused_parameters=False)
es = EarlyStopping(monitor="val_loss", mode="min", patience=args.src_es_patience)
checkpoint_callback = ModelCheckpoint(monitor="val_loss", dirpath=args.output_dir, filename=args.file_name, mode="min")
val_check_interval = args.val_check_interval
model = T5FineTuner(args)
trainer = pl.Trainer(
    max_steps=args.src_num_train_steps,
    strategy=strategy,
    callbacks=[es, checkpoint_callback],
    val_check_interval=val_check_interval,
    logger=tb_logger,
    replace_sampler_ddp=False,
)
logger.info("Successfully loaded model and trainer...")
# print(f'TRAINING DATA LENGTH: {len(train_data)}')  # 10000 examples
# print(f"BATCH SIZE: {args.train_bsz}")  # 8
# print(f'NUMBER OF BATCHES: {len(train_data)//args.train_bsz}')  # 1250 batches
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
```
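As a sanity check on the numbers: a map-style `DataLoader` with the default sampler has `floor(n / batch_size)` batches when `drop_last=True` and `ceil(n / batch_size)` otherwise. The small sketch below (the helper `num_batches` is my own, not a library function) reproduces the counts from the dataset sizes above, which suggests the training loader on its own has 1250 batches and the extra 150 match the validation loader.

```python
import math


def num_batches(num_examples: int, batch_size: int, drop_last: bool) -> int:
    """Hypothetical helper: length of a map-style DataLoader with the default sampler."""
    if drop_last:
        # The last incomplete batch is discarded.
        return num_examples // batch_size
    # The last incomplete batch is kept.
    return math.ceil(num_examples / batch_size)


train_batches = num_batches(10000, 8, drop_last=True)
val_batches = num_batches(1200, 8, drop_last=False)
print(train_batches, val_batches, train_batches + val_batches)  # 1250 150 1400
```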
During training, the progress bar shows 1250 + 150 = 1400 batches, and when it switches to validation it shows 150 batches. Is this expected behavior (i.e., the main progress bar counts the training batches plus the validation batches, and then a separate bar shows only the validation batches while the validation loop runs)? Or am I doing something wrong?
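My current guess, modeled below, is that the main bar's total is the training batches plus the validation batches multiplied by the number of validation runs scheduled in the epoch (controlled by `val_check_interval`: a float is a fraction of the epoch, an int is a number of training batches). This is an assumption about Lightning's progress bar that I have not confirmed in its source; `val_check_batch` and `main_bar_total` are hypothetical helpers written only to express that guess.

```python
def val_check_batch(num_train_batches: int, val_check_interval) -> int:
    """Hypothetical: how many training batches run between validation checks."""
    if isinstance(val_check_interval, float):
        # Float interval: fraction of the training epoch (1.0 = once per epoch).
        return max(1, int(num_train_batches * val_check_interval))
    # Int interval: validate every N training batches.
    return val_check_interval


def main_bar_total(num_train_batches: int, num_val_batches: int, val_check_interval=1.0) -> int:
    """Hypothetical model of the main progress bar total for one epoch."""
    checks = num_train_batches // val_check_batch(num_train_batches, val_check_interval)
    return num_train_batches + num_val_batches * checks


print(main_bar_total(1250, 150))        # 1400 (one validation run per epoch)
print(main_bar_total(1250, 150, 0.25))  # 1850 (1250 + 4 * 150)
print(main_bar_total(1250, 150, 500))   # 1550 (1250 + 2 * 150)
```

If this model is right, lowering `val_check_interval` should make the total grow, which would be an easy way to tell whether the extra batches are validation runs rather than validation data mixed into training.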