Hello, I found that the validation interval drifts. For example, I set val_check_interval to 15 and accumulate_grad_batches to 3 (I want to validate after every 5 effective batches). I have 67 batches per epoch, so after grouping them into accumulation groups of 3 there is one leftover batch at the end of each epoch; it is not dropped, but it still steps the optimizer and increments global_step. As a result, the global_step at which validation runs drifts away from the intended multiples of 5. In the output below, each progress-bar line is followed by a call to ModelCheckpoint's on_validation_end, which I intercepted to print trainer.global_step and the step I logged in training_step.
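For reference, here is a minimal sketch of the configuration that produces the printout below. MyModel, train_loader, and val_loader are placeholders for my actual module and data, and instead of reaching into the checkpoint callback's internal monitor_candidates helper, the sketch just prints trainer.callback_metrics, which holds the same logged values.

```python
# Minimal sketch of the setup described above (not the exact script).
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint


class PrintingCheckpoint(ModelCheckpoint):
    """ModelCheckpoint that also prints the trainer's global_step whenever validation ends."""

    def on_validation_end(self, trainer, pl_module):
        # callback_metrics holds the values logged with self.log(), e.g. train_loss, step, val_loss.
        print(f"global_step: {trainer.global_step} "
              f"callback_metrics: {dict(trainer.callback_metrics)}")
        super().on_validation_end(trainer, pl_module)


trainer = pl.Trainer(
    max_epochs=4,
    accumulate_grad_batches=3,  # 3 raw batches per optimizer step
    val_check_interval=15,      # validate every 15 training batches (intended: every 5 optimizer steps)
    callbacks=[PrintingCheckpoint(monitor="val_loss")],
)
# trainer.fit(MyModel(), train_loader, val_loader)
```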
```
Epoch 0: 22% 15/67 [00:00<00:01, 33.66it/s, v_num=8, train_loss=6.66e+3, step=4.000, val_loss=9.9e+3
global_step: 5 monitor_candidates: {'train_loss': tensor(6658.2930), 'step': tensor(4, dtype=torch.int32), 'val_loss': tensor(9903.5596), 'epoch': tensor(0)}
Epoch 0: 45% 30/67 [00:00<00:01, 33.94it/s, v_num=8, train_loss=1.04e+4, step=9.000, val_loss=9.86e+3
global_step: 10 monitor_candidates: {'train_loss': tensor(10375.1748), 'step': tensor(9, dtype=torch.int32), 'val_loss': tensor(9856.7959), 'epoch': tensor(0)}
Epoch 0: 67% 45/67 [00:01<00:00, 34.05it/s, v_num=8, train_loss=1.24e+4, step=14.00, val_loss=9.81e+3
global_step: 15 monitor_candidates: {'train_loss': tensor(12416.3926), 'step': tensor(14, dtype=torch.int32), 'val_loss': tensor(9808.8018), 'epoch': tensor(0)}
Epoch 0: 90% 60/67 [00:01<00:00, 34.10it/s, v_num=8, train_loss=7.05e+3, step=19.00, val_loss=9.76e+3
global_step: 20 monitor_candidates: {'train_loss': tensor(7053.8413), 'step': tensor(19, dtype=torch.int32), 'val_loss': tensor(9761.5410), 'epoch': tensor(0)}
Epoch 1: 12% 8/67 [00:00<00:02, 29.09it/s, v_num=8, train_loss=6.08e+3, step=25.00, val_loss=9.72e+3
global_step: 25 monitor_candidates: {'train_loss': tensor(6082.7915), 'step': tensor(25, dtype=torch.int32), 'val_loss': tensor(9715.5293), 'epoch': tensor(1)}
Epoch 1: 34% 23/67 [00:00<00:01, 32.37it/s, v_num=8, train_loss=8.69e+3, step=30.00, val_loss=9.67e+3
global_step: 30 monitor_candidates: {'train_loss': tensor(8689.9600), 'step': tensor(30, dtype=torch.int32), 'val_loss': tensor(9671.3135), 'epoch': tensor(1)}
Epoch 1: 57% 38/67 [00:01<00:00, 33.18it/s, v_num=8, train_loss=1.04e+4, step=35.00, val_loss=9.63e+3
global_step: 35 monitor_candidates: {'train_loss': tensor(10446.6523), 'step': tensor(35, dtype=torch.int32), 'val_loss': tensor(9628.3018), 'epoch': tensor(1)}
Epoch 1: 79% 53/67 [00:01<00:00, 33.51it/s, v_num=8, train_loss=6.14e+3, step=40.00, val_loss=9.59e+3
global_step: 40 monitor_candidates: {'train_loss': tensor(6136.0024), 'step': tensor(40, dtype=torch.int32), 'val_loss': tensor(9585.5215), 'epoch': tensor(1)}
Epoch 2: 1% 1/67 [00:00<00:07, 8.82it/s, v_num=8, train_loss=6.51e+3, step=46.00, val_loss=9.54e+3
global_step: 46 monitor_candidates: {'train_loss': tensor(6509.6055), 'step': tensor(46, dtype=torch.int32), 'val_loss': tensor(9535.0029), 'epoch': tensor(2)}
Epoch 2: 24% 16/67 [00:00<00:01, 29.17it/s, v_num=8, train_loss=9.43e+3, step=51.00, val_loss=9.49e+3
global_step: 51 monitor_candidates: {'train_loss': tensor(9433.3418), 'step': tensor(51, dtype=torch.int32), 'val_loss': tensor(9493.0371), 'epoch': tensor(2)}
Epoch 2: 46% 31/67 [00:00<00:01, 31.55it/s, v_num=8, train_loss=8.94e+3, step=56.00, val_loss=9.45e+3
global_step: 56 monitor_candidates: {'train_loss': tensor(8937.6436), 'step': tensor(56, dtype=torch.int32), 'val_loss': tensor(9449.7881), 'epoch': tensor(2)}
Epoch 2: 69% 46/67 [00:01<00:00, 32.45it/s, v_num=8, train_loss=8.57e+3, step=61.00, val_loss=9.41e+3
global_step: 61 monitor_candidates: {'train_loss': tensor(8565.3506), 'step': tensor(61, dtype=torch.int32), 'val_loss': tensor(9406.0615), 'epoch': tensor(2)}
Epoch 2: 91% 61/67 [00:01<00:00, 32.92it/s, v_num=8, train_loss=8.15e+3, step=66.00, val_loss=9.36e+3
global_step: 66 monitor_candidates: {'train_loss': tensor(8152.2539), 'step': tensor(66, dtype=torch.int32), 'val_loss': tensor(9362.8193), 'epoch': tensor(2)}
Epoch 3: 13% 9/67 [00:00<00:01, 30.31it/s, v_num=8, train_loss=5.61e+3, step=71.00, val_loss=9.31e+3
global_step: 72 monitor_candidates: {'train_loss': tensor(5610.6045), 'step': tensor(71, dtype=torch.int32), 'val_loss': tensor(9311.8271), 'epoch': tensor(3)}
```
Both the step logged from training_step and the trainer.global_step read directly in the ModelCheckpoint callback drift: in epochs 0 and 1 validation runs at global_step 5, 10, 15, 20, 25, 30, 35, 40, but from epoch 2 onward it runs at 46, 51, 56, 61, 66, and then 72, no longer at multiples of 5.
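For what it's worth, the drift can be reproduced with plain arithmetic. The sketch below is not Lightning code; it only assumes (based on the logs above, not on Lightning internals) that global_step increments every accumulate_grad_batches raw batches plus once for the leftover batch at epoch end, while the val_check_interval counter runs over raw batches across epochs. Under those assumptions it prints exactly the global_step values shown above.

```python
# Arithmetic sketch of the drift (assumptions inferred from the logs above):
#   * global_step increments once per optimizer step, i.e. every
#     accumulate_grad_batches raw batches, plus once at epoch end for the
#     leftover batch (67 % 3 == 1).
#   * the val_check_interval counter runs over raw batches across epochs.
BATCHES_PER_EPOCH = 67
ACCUMULATE = 3
VAL_CHECK_INTERVAL = 15

global_step = 0
raw_batches = 0
for epoch in range(4):
    accumulated = 0
    for batch in range(1, BATCHES_PER_EPOCH + 1):
        raw_batches += 1
        accumulated += 1
        if accumulated == ACCUMULATE or batch == BATCHES_PER_EPOCH:
            global_step += 1  # optimizer step (including the leftover batch at epoch end)
            accumulated = 0
        if raw_batches % VAL_CHECK_INTERVAL == 0:
            print(f"epoch {epoch}, batch {batch}: validation at global_step {global_step}")
# Prints global_step 5, 10, 15, 20, 25, 30, 35, 40, then 46, 51, 56, 61, 66, 72.
```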