Number of steps drifts for `val_check_interval` when gradient accumulation is turned on

Hello, I found that the validation interval drifts when gradient accumulation is enabled. For example, I set `val_check_interval` to 15 and gradient accumulation to 3 batches, because I want to validate after every 5 effective batches (optimizer steps). I have 67 batches per epoch, so one batch is left over at the end of each epoch. It is not dropped: it still steps the optimizer and increments `global_step`. As a result, the `global_step` at which validation happens drifts away from multiples of 5.

In the output below, each progress-bar line is followed by a call to `ModelCheckpoint`'s `on_validation_end`, which I intercepted to print `trainer.global_step` and the `step` value I logged in `training_step`.

```text
Epoch 0:  22%  15/67 [00:00<00:01, 33.66it/s, v_num=8, train_loss=6.66e+3, step=4.000, val_loss=9.9e+3
global_step:  5 monitor_candidates: {'train_loss': tensor(6658.2930), 'step': tensor(4, dtype=torch.int32), 'val_loss': tensor(9903.5596), 'epoch': tensor(0)}
Epoch 0:  45%  30/67 [00:00<00:01, 33.94it/s, v_num=8, train_loss=1.04e+4, step=9.000, val_loss=9.86e+3
global_step:  10 monitor_candidates: {'train_loss': tensor(10375.1748), 'step': tensor(9, dtype=torch.int32), 'val_loss': tensor(9856.7959), 'epoch': tensor(0)}
Epoch 0:  67%  45/67 [00:01<00:00, 34.05it/s, v_num=8, train_loss=1.24e+4, step=14.00, val_loss=9.81e+3
global_step:  15 monitor_candidates: {'train_loss': tensor(12416.3926), 'step': tensor(14, dtype=torch.int32), 'val_loss': tensor(9808.8018), 'epoch': tensor(0)}
Epoch 0:  90%  60/67 [00:01<00:00, 34.10it/s, v_num=8, train_loss=7.05e+3, step=19.00, val_loss=9.76e+3
global_step:  20 monitor_candidates: {'train_loss': tensor(7053.8413), 'step': tensor(19, dtype=torch.int32), 'val_loss': tensor(9761.5410), 'epoch': tensor(0)}

Epoch 1:  12%  8/67 [00:00<00:02, 29.09it/s, v_num=8, train_loss=6.08e+3, step=25.00, val_loss=9.72e+3
global_step:  25 monitor_candidates: {'train_loss': tensor(6082.7915), 'step': tensor(25, dtype=torch.int32), 'val_loss': tensor(9715.5293), 'epoch': tensor(1)}
Epoch 1:  34%  23/67 [00:00<00:01, 32.37it/s, v_num=8, train_loss=8.69e+3, step=30.00, val_loss=9.67e+3
global_step:  30 monitor_candidates: {'train_loss': tensor(8689.9600), 'step': tensor(30, dtype=torch.int32), 'val_loss': tensor(9671.3135), 'epoch': tensor(1)}
Epoch 1:  57%  38/67 [00:01<00:00, 33.18it/s, v_num=8, train_loss=1.04e+4, step=35.00, val_loss=9.63e+3
global_step:  35 monitor_candidates: {'train_loss': tensor(10446.6523), 'step': tensor(35, dtype=torch.int32), 'val_loss': tensor(9628.3018), 'epoch': tensor(1)}
Epoch 1:  79%  53/67 [00:01<00:00, 33.51it/s, v_num=8, train_loss=6.14e+3, step=40.00, val_loss=9.59e+3
global_step:  40 monitor_candidates: {'train_loss': tensor(6136.0024), 'step': tensor(40, dtype=torch.int32), 'val_loss': tensor(9585.5215), 'epoch': tensor(1)}

Epoch 2:   1%  1/67 [00:00<00:07,  8.82it/s, v_num=8, train_loss=6.51e+3, step=46.00, val_loss=9.54e+3
global_step:  46 monitor_candidates: {'train_loss': tensor(6509.6055), 'step': tensor(46, dtype=torch.int32), 'val_loss': tensor(9535.0029), 'epoch': tensor(2)}
Epoch 2:  24%  16/67 [00:00<00:01, 29.17it/s, v_num=8, train_loss=9.43e+3, step=51.00, val_loss=9.49e+3
global_step:  51 monitor_candidates: {'train_loss': tensor(9433.3418), 'step': tensor(51, dtype=torch.int32), 'val_loss': tensor(9493.0371), 'epoch': tensor(2)}
Epoch 2:  46%  31/67 [00:00<00:01, 31.55it/s, v_num=8, train_loss=8.94e+3, step=56.00, val_loss=9.45e+3
global_step:  56 monitor_candidates: {'train_loss': tensor(8937.6436), 'step': tensor(56, dtype=torch.int32), 'val_loss': tensor(9449.7881), 'epoch': tensor(2)}
Epoch 2:  69%  46/67 [00:01<00:00, 32.45it/s, v_num=8, train_loss=8.57e+3, step=61.00, val_loss=9.41e+3
global_step:  61 monitor_candidates: {'train_loss': tensor(8565.3506), 'step': tensor(61, dtype=torch.int32), 'val_loss': tensor(9406.0615), 'epoch': tensor(2)}
Epoch 2:  91%  61/67 [00:01<00:00, 32.92it/s, v_num=8, train_loss=8.15e+3, step=66.00, val_loss=9.36e+3
global_step:  66 monitor_candidates: {'train_loss': tensor(8152.2539), 'step': tensor(66, dtype=torch.int32), 'val_loss': tensor(9362.8193), 'epoch': tensor(2)}

Epoch 3:  13%  9/67 [00:00<00:01, 30.31it/s, v_num=8, train_loss=5.61e+3, step=71.00, val_loss=9.31e+3
global_step:  72 monitor_candidates: {'train_loss': tensor(5610.6045), 'step': tensor(71, dtype=torch.int32), 'val_loss': tensor(9311.8271), 'epoch': tensor(3)}
```

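Both the `step` I logged in `training_step` and `trainer.global_step` read directly in the `ModelCheckpoint` callback drift: `trainer.global_step` is a multiple of 5 through epoch 1 (5 through 40), but from epoch 2 on validation runs at 46, 51, 56, 61, 66 and then 72. The leftover batch appears to be the cause: 67 batches with an accumulation of 3 give 23 optimizer steps per epoch rather than 67 / 3 ≈ 22.33, so the step counter gains 2/3 of a step per epoch relative to the batch count, and an interval of 15 batches no longer corresponds to multiples of 5 optimizer steps.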
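To make the drift easier to see, here is a small pure-Python simulation of the two counters (no Lightning involved). It rests on assumptions inferred from the log above rather than from Lightning's source: (1) the batch counter behind `val_check_interval` keeps running across epochs (in epoch 1 validation fires at batch 8, and 67 + 8 = 75 = 5 × 15); (2) the optimizer steps every 3 batches and also on the last batch of an epoch; (3) the `step` value logged in `training_step` is read before that batch's optimizer step. `step_based_triggers` is a hypothetical helper showing the behavior I expected; it is not a Lightning option.

```python
def simulate_batch_interval(num_batches=67, accumulate=3, val_every_batches=15, epochs=4):
    """Return (epoch, batch, logged_step, global_step) for every validation run.

    Assumption (inferred from the log): the validation batch counter runs across
    epochs, and the leftover batch at the end of an epoch still steps the optimizer.
    """
    global_step = 0
    total_batches = 0
    records = []
    for epoch in range(epochs):
        for batch in range(1, num_batches + 1):  # 1-based, like the progress bar
            logged_step = global_step  # value training_step sees for this batch
            total_batches += 1
            # Optimizer step at each accumulation boundary and on the leftover batch.
            if batch % accumulate == 0 or batch == num_batches:
                global_step += 1
            # Validation is triggered by the batch count, not the step count.
            if total_batches % val_every_batches == 0:
                records.append((epoch, batch, logged_step, global_step))
    return records


def step_based_triggers(num_batches=67, accumulate=3, val_every_steps=5, epochs=4):
    """Hypothetical alternative: validate whenever global_step reaches a multiple
    of val_every_steps (the behavior I expected, not an existing Lightning setting)."""
    global_step = 0
    triggers = []
    for _ in range(epochs):
        for batch in range(1, num_batches + 1):
            if batch % accumulate == 0 or batch == num_batches:
                global_step += 1
                if global_step % val_every_steps == 0:
                    triggers.append(global_step)
    return triggers


if __name__ == "__main__":
    for epoch, batch, logged, gstep in simulate_batch_interval()[:14]:
        print(f"epoch {epoch} batch {batch:2d}/67: logged step={logged}, global_step={gstep}")
    print("step-based:", step_based_triggers()[:14])
```

Under these assumptions the first 14 rows match the log (global_step 5 through 40, then 46, 51, 56, 61, 66, 72; the logged `step` trails it by one whenever validation lands on an accumulation boundary, e.g. 4 vs 5 and 71 vs 72), while the step-based trigger stays on multiples of 5. The simulation also suggests the drift vanishes when the number of batches per epoch is a multiple of the accumulation factor (e.g. 66 instead of 67).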