How to use rank_zero_only inside a function

Hi, I understand I can use @rank_zero_only to decorate a whole function, but is it possible to run only part of a function on rank 0 and the rest on all ranks? For example:

    def on_train_batch_start(self, batch, batch_idx):
        # something like this, run only on rank 0
        with rank_zero_only:
            ...
        # on all ranks
        self.examples = batch
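A minimal sketch of two common ways to get this behavior, assuming PyTorch Lightning-style semantics. So that it runs without Lightning installed, rank_zero_only below is a hypothetical stand-in that mimics the library decorator (it calls the wrapped function only when rank_zero_only.rank == 0, with the rank read from the RANK environment variable, as set by launchers such as torchrun). MyModule is likewise a hypothetical plain class with a global_rank attribute, not a real LightningModule. Option 1 is a plain if-check on the rank; inside a real LightningModule, self.global_rank == 0 or self.trainer.is_global_zero should serve the same purpose. Option 2 decorates a small nested helper so that only that part is rank-0-only.

```python
import functools
import os
from typing import Any, Callable, List, Optional


def _current_rank() -> int:
    # Assumption: the launcher exports the global rank as RANK
    # (torchrun does this); default to 0 for a single process.
    return int(os.environ.get("RANK", "0"))


def rank_zero_only(fn: Callable) -> Callable:
    """Hypothetical stand-in for Lightning's decorator: call fn only on rank 0."""

    @functools.wraps(fn)
    def wrapped(*args: Any, **kwargs: Any) -> Optional[Any]:
        # The rank is looked up at call time, so it can be changed later.
        if rank_zero_only.rank == 0:
            return fn(*args, **kwargs)
        return None

    return wrapped


rank_zero_only.rank = _current_rank()


class MyModule:
    """Hypothetical module; in Lightning this would subclass LightningModule."""

    def __init__(self, global_rank: int) -> None:
        self.global_rank = global_rank
        self.log_lines: List[str] = []
        self.examples = None

    def on_train_batch_start(self, batch, batch_idx):
        # Option 1: guard just this block with an explicit rank check.
        if self.global_rank == 0:
            self.log_lines.append(f"batch {batch_idx} seen on rank 0")

        # Option 2: put the rank-0 part in a nested helper and decorate it.
        @rank_zero_only
        def _rank_zero_part():
            self.log_lines.append(f"helper ran for batch {batch_idx}")

        _rank_zero_part()

        # This part runs on every rank.
        self.examples = batch
```

The explicit if-check is usually the simpler choice; the nested helper mainly helps when the rank-0 part is already written as a function. In either case, keep collective operations (such as an all-reduce) out of the rank-0-only part: the other ranks would never join them, and training would hang.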