`self.all_gather` used in `on_train_epoch_end` raises `RuntimeError` (mismatched collectives)

Description

In my LightningModule, I use a sklearn.preprocessing.StandardScaler to incrementally record statistics of the training data.

When I fit the model on 2 GPU ranks with the deepspeed_stage_2 strategy, the ranks end up with different statistics because they see different input batches. I want to reduce them into one result at the end of each epoch so that both ranks share it, so I called self.all_gather() in on_train_epoch_end(), but it raises a RuntimeError.
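For context, each rank updates its own scaler incrementally from its local batches, roughly like this (simplified; the batch unpacking below is only illustrative, not my actual code):

    def training_step(self, batch, batch_idx):
        x, y = batch  # illustrative batch layout
        # Each rank only sees its own shard of the data, so the scaler
        # statistics (mean_, var_, n_samples_seen_) diverge across ranks.
        self.train_data_scaler.partial_fit(x.detach().cpu().numpy())
        ...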

At first it reports Rank 0 is running collective: CollectiveFingerPrint(OpType=BARRIER..., but Rank 1 is running collective: CollectiveFingerPrint(OpType=ALLGATHER_COALESCED), followed by Rank 1: OpType=ALLGATHER, but Rank 0: OpType=ALLREDUCE_COALESCED.

Then I added self.trainer.strategy.barrier() to synchronize the processes, but now it reports Rank 1: OpType=ALLGATHER, but Rank 0: OpType=_ALLGATHER_BASE, and Rank 0: OpType=ALLREDUCE, but Rank 1: OpType=ALLGATHER_COALESCED (full tracebacks below).

What version are you seeing the problem on?

v2.1

How to reproduce the bug

import numpy as np
import torch
import lightning.pytorch as pl
from sklearn.preprocessing import StandardScaler


class MyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.train_data_scaler = StandardScaler()
        self.training_mean = np.array([0.])

    ...

    def on_train_epoch_end(self) -> None:
        self.trainer.strategy.barrier()  # for synchronization
        print("Start gathering scaler...")

        # Error at this step!
        mean_all = self.all_gather(torch.from_numpy(self.train_data_scaler.mean_))
        ... (some calculation)
        self.training_mean = torch.sum(mean_all)  # just an example, actually more complex
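For reference, a workaround I am considering (but have not verified under deepspeed_stage_2) is to skip self.all_gather and reduce the per-rank means directly with torch.distributed, building the tensor with an explicit dtype and device; rough sketch:

    def on_train_epoch_end(self) -> None:
        import torch.distributed as dist

        # Build the tensor explicitly in float32 on this rank's device.
        mean = torch.as_tensor(self.train_data_scaler.mean_, dtype=torch.float32, device=self.device)
        if dist.is_available() and dist.is_initialized():
            # Sum the per-rank means, then average over the world size
            # (assumes each rank saw roughly the same number of samples).
            dist.all_reduce(mean, op=dist.ReduceOp.SUM)
            mean /= dist.get_world_size()
        self.training_mean = mean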

Error messages and logs

With self.trainer.strategy.barrier()

Traceback (most recent call last):
  File "/home/qianrt/project/training/train.py", line 632, in <module>
    main()
  File "/home/qianrt/project/training/train.py", line 628, in main
    train_model(args)
  File "/home/qianrt/project/training/train.py", line 596, in train_model
    trainer.fit(model, ckpt_path=weight_ckpt)
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 544, in fit
    call._call_and_handle_interrupt(
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 102, in launch
    return function(*args, **kwargs)
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 989, in _run
    results = self._run_stage()
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1035, in _run_stage
    self.fit_loop.run()
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 203, in run
    self.on_advance_end()
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 373, in on_advance_end
    call._call_lightning_module_hook(trainer, "on_train_epoch_end")
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 157, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/home/qianrt/project/training/train.py", line 338, in on_train_epoch_end
    mean_all = self.all_gather(torch.from_numpy(self.train_data_scaler.mean_))
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/core/module.py", line 664, in all_gather
    return apply_to_collection(data, Tensor, all_gather, group=group, sync_grads=sync_grads)
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning_utilities/core/apply_func.py", line 64, in apply_to_collection
    return function(data, *args, **kwargs)
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/strategies/parallel.py", line 87, in all_gather
    return _all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/fabric/utilities/distributed.py", line 254, in _all_gather_ddp_if_available
    gathered_tensors = all_gather(tensor, group)
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/torch/distributed/nn/functional.py", line 116, in all_gather
    return _AllGather.apply(group, tensor)
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/torch/distributed/nn/functional.py", line 325, in forward
    dist.all_gather(out_tensor_list, tensor, group=group)
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1451, in wrapper
    return func(*args, **kwargs)
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2450, in all_gather
    work = group.allgather([tensor_list], [tensor])
RuntimeError: Detected mismatch between collectives on ranks. Rank 1 is running collective: CollectiveFingerPrint(OpType=ALLGATHER, TensorShape=[6], TensorDtypes=Double, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))), but Rank 0 is running collective: CollectiveFingerPrint(OpType=_ALLGATHER_BASE).

Traceback (most recent call last):
  File "/home/qianrt/project/training/train.py", line 632, in <module>
    main()
  File "/home/qianrt/project/training/train.py", line 628, in main
    train_model(args)
  File "/home/qianrt/project/training/train.py", line 596, in train_model
    trainer.fit(model, ckpt_path=weight_ckpt)
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 544, in fit
    call._call_and_handle_interrupt(
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 102, in launch
    return function(*args, **kwargs)
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 989, in _run
    results = self._run_stage()
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1035, in _run_stage
    self.fit_loop.run()
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 203, in run
    self.on_advance_end()
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 372, in on_advance_end
    call._call_callback_hooks(trainer, "on_train_epoch_end", monitoring_callbacks=False)
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 208, in _call_callback_hooks
    fn(trainer, trainer.lightning_module, *args, **kwargs)
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/callbacks/progress/tqdm_progress.py", line 271, in on_train_epoch_end
    self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/callbacks/progress/progress_bar.py", line 195, in get_metrics
    pbar_metrics = trainer.progress_bar_metrics
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1653, in progress_bar_metrics
    return self._logger_connector.progress_bar_metrics
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py", line 245, in progress_bar_metrics
    metrics = self.metrics["pbar"]
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py", line 226, in metrics
    return self.trainer._results.metrics(on_step)
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 475, in metrics
    value = self._get_cache(result_metric, on_step)
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 439, in _get_cache
    result_metric.compute()
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 284, in wrapped_func
    self._computed = compute(*args, **kwargs)
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 247, in compute
    value = self.meta.sync(self.value.clone())  # `clone` because `sync` is in-place
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/pytorch/strategies/ddp.py", line 332, in reduce
    return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/fabric/utilities/distributed.py", line 171, in _sync_ddp_if_available
    return _sync_ddp(result, group=group, reduce_op=reduce_op)
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/lightning/fabric/utilities/distributed.py", line 221, in _sync_ddp
    torch.distributed.all_reduce(result, op=op, group=group, async_op=False)
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1451, in wrapper
    return func(*args, **kwargs)
  File "/home/qianrt/soft/anaconda/build/envs/md_ml/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1702, in all_reduce
    work = group.allreduce([tensor], opts)
RuntimeError: Detected mismatch between collectives on ranks. Rank 0 is running collective: CollectiveFingerPrint(OpType=ALLREDUCE, TensorShape=[], TensorDtypes=Float, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))), but Rank 1 is running collective: CollectiveFingerPrint(OpType=ALLGATHER_COALESCED).

Environment

- Lightning Component: LightningModule
- PyTorch Lightning Version: 2.1.3
- PyTorch Version: 2.0.1
- Python version: 3.10.12
- OS: Linux
- CUDA/cuDNN version: CUDA 11.7 + cuDNN 8.5
- GPU models and configuration: NVIDIA A40
- How you installed Lightning: conda