torch.no_grad() calls

Hi,

My question might be naive, but I was wondering about this scenario:

Let’s say we have the following class:

from torch import nn
from pytorch_lightning import LightningModule

class ProtoNet(LightningModule):
    def __init__(self, model): 
        super().__init__()
        self.model = model 
        self.save_hyperparameters()
 
    def training_step(self, batch, batch_idx): 
        imgs, targets = batch 
        ... 
        loss = nn.CrossEntropyLoss(...)(...)
        
        acc = self._calculate_accuracy(...)

        return loss

    def validation_step(self, batch, batch_idx): 
        imgs, targets = batch 
        ... 
        loss = nn.CrossEntropyLoss(...)(...)
        
        acc = self._calculate_accuracy(...)

    def _calculate_accuracy(self, ...): 
        ... 

In the function _calculate_accuracy(), do I understand it correctly that when it is called from within validation_step(), the model would be in eval mode, whereas when it is called from within training_step(), it would be in train mode?

If this is correct and I put the decorator @torch.no_grad() on _calculate_accuracy(), would the model be in training mode again once the function returns and training continues?

Best,
Imahn

@awaelchli could you please help out?

Hi @ImahnShekhzadeh

In the function _calculate_accuracy(), do I understand it correctly that when it is called from within validation_step(), the model would be in eval mode, whereas when it is called from within training_step(), it would be in train mode?

This is correct. The Trainer calls model.train() before entering the training loop. Right before entering the validation/test loop, it calls model.eval(), and once validation/testing finishes, it calls model.train() again.
In addition, the Trainer runs the validation loop under torch.no_grad(), so everything in your validation_step already runs with gradients disabled (there is no need to add the decorator).
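To make that sequence concrete, here is a simplified sketch in plain PyTorch of the mode switching described above. It is not Lightning's actual implementation: the toy model, the random data, and the stand-in training_step/validation_step functions are made up for illustration, and each step simply records whether the model is in training mode and whether gradients are enabled.

```python
import torch
from torch import nn

# Toy model (illustrative only) with a Dropout layer so that train/eval mode matters.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(8, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

# (phase, model.training, torch.is_grad_enabled()) as seen inside each step.
observed = []

def training_step(x, y):
    observed.append(("train", model.training, torch.is_grad_enabled()))
    return loss_fn(model(x), y)

def validation_step(x, y):
    observed.append(("val", model.training, torch.is_grad_enabled()))
    return loss_fn(model(x), y)

torch.manual_seed(0)
x, y = torch.randn(16, 4), torch.randint(0, 2, (16,))

for epoch in range(2):
    model.train()                  # training mode before the training loop
    loss = training_step(x, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    model.eval()                   # evaluation mode right before validation
    with torch.no_grad():          # gradients disabled for the whole validation loop
        val_loss = validation_step(x, y)
    model.train()                  # back to training mode after validation

print(observed)
```

Inside the validation step both flags are off; in the next training step both are on again.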

Please note that model.eval() has nothing to do with torch.no_grad(). You seem to be confused about this given your next question:

If this is correct and I put the decorator @torch.no_grad() on _calculate_accuracy(), would the model be in training mode again once the function returns and training continues?

torch.no_grad() disables gradient computation; model.eval() does not. model.eval() only switches the module and its submodules to evaluation mode, which changes how certain layers behave. For example, nn.Dropout randomly zeroes elements of its input during training but passes the input through unchanged in evaluation mode, and BatchNorm layers use their running statistics instead of per-batch statistics.
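A small sketch showing that the two switches are independent, using a made-up toy model (a Linear layer followed by Dropout): eval() alone turns dropout off but autograd still records the graph, while torch.no_grad() alone stops the graph from being recorded but leaves dropout active.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Illustrative toy model: Linear (trainable weights) followed by Dropout.
model = nn.Sequential(nn.Linear(1000, 1000), nn.Dropout(p=0.5))
x = torch.randn(1, 1000)

# Case 1: model.eval() WITHOUT torch.no_grad()
model.eval()
out_eval = model(x)
eval_dropout_active = bool((out_eval == 0).any())  # False: Dropout is a no-op in eval mode
eval_requires_grad = out_eval.requires_grad        # True: autograd still records the graph

# Case 2: torch.no_grad() WITHOUT model.eval()
model.train()
with torch.no_grad():
    out_nograd = model(x)
nograd_dropout_active = bool((out_nograd == 0).any())  # True: Dropout still zeroes roughly half the outputs
nograd_requires_grad = out_nograd.requires_grad         # False: no graph is recorded

print("eval() only:    dropout active =", eval_dropout_active, "| requires_grad =", eval_requires_grad)
print("no_grad() only: dropout active =", nograd_dropout_active, "| requires_grad =", nograd_requires_grad)
```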

So, in summary:
There is no need to add model.eval() or torch.no_grad() to your validation loop; the Lightning Trainer handles both. If you want to convince yourself, simply print the following in your training_step or validation_step:

print(self.training)  # True in training mode, False in eval mode
print(torch.is_grad_enabled())  # whether gradients are enabled here

Hi @awaelchli,

Thanks for the reply.

Please note that model.eval() has nothing to do with torch.no_grad(). You seem to be confused about this given your next question: […]

You are right, I did indeed confuse .eval() and torch.no_grad(). So, all in all, I understand that I can (and should) put the decorator @torch.no_grad() in front of _calculate_accuracy(), since I do not need the gradients there, and that it is probably safest to call model.eval() at the beginning of the function and model.train() at the end, since any dropout or BatchNorm layers should be in evaluation mode when computing accuracy.

You’re welcome. Sounds good.

You can do that if you want, yes. It is not necessary in Lightning, though, if you call that _calculate_accuracy function from within the LightningModule.
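One caveat with calling model.train() unconditionally at the end: if _calculate_accuracy() were called during validation, it would leave the model in training mode for the rest of the validation step. A safer variant remembers the current mode and restores it afterwards. The sketch below is only one way to do that, assuming you want the function to force eval mode itself; the decorator name eval_mode_no_grad and the Classifier model are hypothetical, and it uses a plain nn.Module so it runs without Lightning.

```python
import functools

import torch
from torch import nn


def eval_mode_no_grad(method):
    """Hypothetical helper: run a module method in eval mode with gradients disabled,
    then restore whatever train/eval mode the module was in before the call."""
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        was_training = self.training      # remember the caller's mode
        self.eval()
        try:
            with torch.no_grad():
                return method(self, *args, **kwargs)
        finally:
            self.train(was_training)      # train(False) is equivalent to eval()
    return wrapper


class Classifier(nn.Module):  # hypothetical stand-in for the LightningModule
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(16, 3))

    def forward(self, x):
        return self.net(x)

    @eval_mode_no_grad
    def _calculate_accuracy(self, x, targets):
        # Dropout is disabled and no autograd graph is built in here.
        preds = self(x).argmax(dim=1)
        return (preds == targets).float().mean().item()


torch.manual_seed(0)
model = Classifier()
x, targets = torch.randn(32, 4), torch.randint(0, 3, (32,))

model.train()
acc_from_train = model._calculate_accuracy(x, targets)
mode_after_train_call = model.training   # training mode is restored

model.eval()
acc_from_eval = model._calculate_accuracy(x, targets)
mode_after_eval_call = model.training    # stays in eval mode
```

Because a LightningModule is itself an nn.Module, the same decorator should apply to a method there unchanged; as noted in the reply, though, it is usually unnecessary when the Trainer already manages the mode.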