To specify the validation frequency we have the val_check_interval parameter in the Trainer class.
I’d like to have a dynamic validation frequency based on the current metric score.
What I’ve tried is to create a callback overriding the on_train_batch_start or on_train_batch_end methods, like this:
class MyCallback(Callback):
    def __init__(self, val_dl):
        self.val_dl = val_dl

    def on_train_batch_end(self, trainer, pl_module, batch, batch_idx):
        ...
but that doesn’t work.
Is there a preferable way to achieve my goal in PL?
Hi @Algernone, could you try updating the
trainer.val_check_interval value using the callback?
def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
    trainer.val_check_interval = NEW_FREQUENCY
I guess the smallest value you can set NEW_FREQUENCY to is 1.
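A self-contained sketch of that suggestion might look like the following. The class name, the interval values, and the threshold are all illustrative (not from this thread), and whether the Trainer picks up the new value mid-epoch depends on your Lightning version:

```python
try:
    from pytorch_lightning.callbacks import Callback
except ImportError:  # fallback so the sketch stays readable without Lightning installed
    Callback = object


class ScoreBasedValFrequency(Callback):
    """Retune trainer.val_check_interval after each validation run."""

    def __init__(self, metric_name, fast_interval=100, slow_interval=1000, threshold=0.8):
        self.metric_name = metric_name
        self.fast_interval = fast_interval
        self.slow_interval = slow_interval
        self.threshold = threshold

    def on_validation_end(self, trainer, pl_module):
        score = trainer.callback_metrics.get(self.metric_name)
        if score is None:
            return
        # validate more often once the metric approaches the threshold
        trainer.val_check_interval = (
            self.fast_interval if score >= self.threshold else self.slow_interval
        )
```

You would then pass it as `Trainer(callbacks=[ScoreBasedValFrequency("val_acc")])`.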
But how could I validate my model right away, without any delay, when the condition is true?
I guess to solve the problem I need to override the condition used to decide when to validate in fit_loop, but
I’m not quite sure where to look.
The condition is evaluated here in the epoch loop (accessed via
Note: If you want to run validation immediately, this might not be optimal, as it can leave states (especially training state, which may still live on the GPU) not properly cleaned up.
Do you see any solution to the problem I’ve described, or is the best option just to switch to plain PyTorch?
There are a few options to explore here:
1.) With a custom loop, you could theoretically always call validation, and batches should be cleaned up, but I am not 100% sure of this (more like 90%).
2.) When still wanting to use the “backbone” of lightning, you could give
lightning_lite a try (look here). It is essentially pure pytorch but handles hardware communication, multiprocessing etc. It is not as strict as the Trainer is, but allows a few opt-in conveniences.
Thanks! Following your first piece of advice, I’ve made the following class using the Callback functionality.
In case someone is interested, here it is:
from functools import partial

from pytorch_lightning.callbacks import Callback


class DynamicValidationCallback(Callback):
    def __init__(self, val_schedule, metric_name, warmup_epochs_n=0):
        # val_schedule: (score_threshold, interval) pairs, sorted by ascending threshold
        self.schedule = val_schedule
        self.interval = self.schedule[-1][1]
        self.last_step = 0
        self.warmup_epochs_n = warmup_epochs_n
        self.metric_name = metric_name
        self.trainer = None

    def on_fit_start(self, trainer, pl_module):
        self.trainer = trainer
        # replace the epoch loop's "should we validate now?" check with our own
        self.trainer.fit_loop.epoch_loop._should_check_val_fx = partial(self.step)

    def step(self, *args, **kwargs):
        step = self.trainer.global_step
        epoch = self.trainer.current_epoch
        if epoch >= self.warmup_epochs_n and step >= self.interval + self.last_step:
            self.last_step = step
            return True
        return False

    def on_validation_end(self, trainer, pl_module):
        score = trainer.callback_metrics[self.metric_name]
        self.update_interval(score)

    def update_interval(self, cur_score):
        for score, interval in self.schedule:
            if cur_score < score:
                self.interval = interval
                break
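The schedule lookup at the end can be sanity-checked in isolation. Here is a standalone version of the same threshold logic (the function name and the explicit `default` argument are mine, for illustration), assuming the schedule is a list of (score_threshold, interval) pairs sorted by ascending threshold:

```python
def pick_interval(schedule, cur_score, default):
    """Return the interval paired with the first threshold the score is still below."""
    for score, interval in schedule:
        if cur_score < score:
            return interval
    return default  # score cleared every threshold
```

For example, with a schedule of [(0.5, 1000), (0.8, 500)], a score of 0.3 maps to an interval of 1000, 0.6 to 500, and anything above 0.8 falls through to the default.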