Custom validation frequency

Hello.

To specify the validation frequency, there is the val_check_interval parameter of the Trainer class.
I’d like to have a dynamic validation frequency based on the current metric score.
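
For reference, a minimal sketch of the static setting I mean (model, train_dl and val_dl are placeholders):

import pytorch_lightning as pl

# Static frequency: run validation every 500 training batches.
# val_check_interval also accepts a float, meaning a fraction of the epoch.
trainer = pl.Trainer(max_epochs=10, val_check_interval=500)
trainer.fit(model, train_dl, val_dl)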

What I’ve tried is to create a callback overriding the on_train_batch_start or on_train_batch_end methods, like this:

from pytorch_lightning.callbacks import Callback


class ScheduleCallback(Callback):

    def __init__(self, val_dl):
        self.val_dl = val_dl

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        print('CALLBACK on_train_batch_end')
        if some_condition:  # placeholder for the metric-based condition
            trainer.validate(pl_module, self.val_dl)

but that doesn’t work.

Is there a preferred way to achieve this in PL?
Thank you

Hi @Algernone, could you try updating the trainer.val_check_interval value using the callback?

class ScheduleCallback(pl.Callback):

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if CONDITION:  # placeholder for the metric-based condition
            trainer.val_check_interval = NEW_FREQUENCY  # placeholder new interval
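
And attach it when building the Trainer, something like this (model, the dataloaders and INITIAL_FREQUENCY are placeholders):

trainer = pl.Trainer(
    max_epochs=10,
    val_check_interval=INITIAL_FREQUENCY,  # placeholder starting frequency
    callbacks=[ScheduleCallback()],
)
trainer.fit(model, train_dl, val_dl)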

Hey @aniketmaurya
I guess the smallest value we can set NEW_FREQUENCY to is 1.
But how could I validate my model immediately when the condition becomes true, without any delay?

I guess that to solve the problem I need to override the condition the fit_loop uses to decide when to validate, but
I’m not quite sure where to look.

The condition is evaluated here in the epoch loop (accessed via trainer.fit_loop.epoch_loop).

Note: If you want to run validation immediately, this might not be optimal as this might result in states (especially from training which might also lie on GPU) not being properly cleaned up.
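
As a rough pointer (this is private API, so the exact name and signature depend on your pytorch_lightning version; a recent 1.x release is assumed here):

# The training epoch loop owns the "should we validate now?" check.
epoch_loop = trainer.fit_loop.epoch_loop
print(epoch_loop._should_check_val_fx)  # the method deciding whether validation runs after a batch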


I see.
Do you then see any solution to the problem I’ve described, or is the best option just to switch to plain PyTorch?

There are a few options to explore here:

1.) With a custom loop, you could theoretically always call validation, and batches should be cleaned up, but I am not 100% sure of this (more like 90%).

2.) If you still want to use the “backbone” of Lightning, you could give lightning_lite a try (look here). It is essentially pure PyTorch but handles hardware communication, multiprocessing, etc. It is not as strict as the Trainer, but offers a few opt-in conveniences; a rough sketch follows below.
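
A rough sketch of what that loop could look like (assumes a release where pytorch_lightning.lite.LightningLite exists; the model, the dataloaders and should_validate are placeholders). Since you own the loop, validation can run the moment the condition fires:

import torch
from pytorch_lightning.lite import LightningLite


class LiteLoop(LightningLite):
    def run(self, model, train_dl, val_dl, should_validate):
        optimizer = torch.optim.Adam(model.parameters())
        # setup() moves the model/optimizer to the selected device and wraps them
        model, optimizer = self.setup(model, optimizer)
        train_dl, val_dl = self.setup_dataloaders(train_dl, val_dl)

        for batch in train_dl:
            optimizer.zero_grad()
            loss = model(batch)  # placeholder: forward returns the loss
            self.backward(loss)  # replaces loss.backward()
            optimizer.step()

            # We own the loop, so validation can run immediately, no scheduling tricks needed.
            if should_validate(loss):
                model.eval()
                with torch.no_grad():
                    for val_batch in val_dl:
                        model(val_batch)  # placeholder validation step
                model.train()


LiteLoop().run(model, train_dl, val_dl, should_validate)  # all arguments are placeholders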


Thanks, following your first advice I’ve made the following class using the Callback functionality.
In case someone is interested, here it is:


from functools import partial

from pytorch_lightning.callbacks import Callback


class ValidationScheduler(Callback):
    """Adjusts the validation interval dynamically based on a logged metric."""

    def __init__(self, val_schedule, metric_name, warmup_epochs_n=0):
        # val_schedule: list of (metric_threshold, interval_in_steps) pairs
        self.schedule = val_schedule
        self.interval = self.schedule[-1][1]  # start from the last entry's interval
        self.last_step = 0
        self.warmup_epochs_n = warmup_epochs_n
        self.metric_name = metric_name
        self.trainer = None

    def on_fit_start(self, trainer, pl_module):
        self.trainer = trainer
        # Replace the epoch loop's internal "should we validate now?" check
        # with our own schedule-aware check.
        self.trainer.fit_loop.epoch_loop._should_check_val_fx = partial(self.step)

    def step(self, *args, **kwargs):
        # Depending on the Lightning version, the hook may be called with
        # arguments (e.g. batch_idx), so accept and ignore them.
        step = self.trainer.global_step
        epoch = self.trainer.current_epoch

        if epoch >= self.warmup_epochs_n and step >= self.interval + self.last_step:
            self.last_step = step
            return True
        return False

    def on_validation_end(self, trainer, pl_module):
        score = trainer.callback_metrics[self.metric_name]
        self.update_interval(score)

    def update_interval(self, cur_score):
        # Use the interval of the first threshold the current score is still below.
        for score, interval in self.schedule:
            if cur_score < score:
                self.interval = interval
                break
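
In case it helps, a hedged usage sketch (the thresholds, metric name, model and dataloaders are made up; the schedule format is what update_interval() expects, i.e. (metric_threshold, interval_in_global_steps) pairs sorted by threshold, with the last interval also used as the initial default):

import pytorch_lightning as pl

# Validate every 500 steps while the logged metric is below 0.5,
# every 1000 steps while below 0.8, and every 2000 steps otherwise.
scheduler = ValidationScheduler(
    val_schedule=[(0.5, 500), (0.8, 1000), (1.0, 2000)],
    metric_name="val_acc",  # must be logged so it shows up in trainer.callback_metrics
    warmup_epochs_n=1,
)

trainer = pl.Trainer(max_epochs=20, callbacks=[scheduler])
trainer.fit(model, train_dl, val_dl)  # model and dataloaders are placeholders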