Collective mismatch at end of training epoch

I’m facing an issue where training a LightningModule with DDP on >4 GPUs gets stuck at the end of the first training epoch (I made sure there is no validation epoch). This does not occur with 2 GPUs.

I made sure that the dataset is balanced and that the total batch size equals the number of GPUs, so all ranks should run the same number of steps per epoch.

Unused-parameter detection (find_unused_parameters=True) is enabled, and the model does have unused parameters; that is intentional.
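
For reference, here is roughly what the setup looks like (MyLightningModule and MyDataModule are placeholders and the real script has more arguments; the ModelCheckpoint callback is the default one with no monitor, which is where rank 7 ends up in the trace below):

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import ModelCheckpoint
    from pytorch_lightning.strategies import DDPStrategy

    model = MyLightningModule()    # placeholder; configure_optimizers() returns RMSprop
    datamodule = MyDataModule()    # placeholder; per-GPU batch size of 1

    trainer = pl.Trainer(
        accelerator="gpu",
        devices=8,                                           # hangs with >4 GPUs, works with 2
        strategy=DDPStrategy(find_unused_parameters=True),   # unused parameters are intentional
        limit_val_batches=0,                                  # no validation epoch
        callbacks=[ModelCheckpoint()],                        # no monitor
    )
    trainer.fit(model, datamodule=datamodule)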

I obtained stack traces by running with TORCH_CPP_LOG_LEVEL=INFO and TORCH_DISTRIBUTED_DEBUG=DETAIL.
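
The debug variables were set in the environment of every rank before training starts (shown here in Python for illustration; exporting them in the shell before launch works the same way, as long as it happens before the process group is initialized):

    import os

    os.environ["TORCH_CPP_LOG_LEVEL"] = "INFO"
    os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"  # enables collective fingerprinting / mismatch detection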

I’m having difficulty interpreting these stack traces: they go through more than 10 layers of PyTorch Lightning calls, and I don’t know Lightning’s internals well enough. Could someone glance at them and suggest the most likely causes?

Stack trace from rank 7:

Traceback (most recent call last):
  File "/workspace/solr/app/main_lightning.py", line 164, in <module>
    trainer.fit(model, datamodule=datamodule)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py", line 770, in fit
    self._call_and_handle_interrupt(
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py", line 723, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py", line 811, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py", line 1236, in _run
    results = self._run_stage()
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py", line 1323, in _run_stage
    return self._run_train()
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py", line 1353, in _run_train
    self.fit_loop.run()
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/loops/base.py", line 205, in run
    self.on_advance_end()
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/loops/fit_loop.py", line 294, in on_advance_end
    self.trainer._call_callback_hooks("on_train_epoch_end")
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py", line 1636, in _call_callback_hooks
    fn(self, self.lightning_module, *args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 308, in on_train_epoch_end
    self._save_topk_checkpoint(trainer, monitor_candidates)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 381, in _save_topk_checkpoint
    self._save_none_monitor_checkpoint(trainer, monitor_candidates)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 658, in _save_none_monitor_checkpoint
    filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 619, in _get_metric_interpolated_filepath_name
    while self.file_exists(filepath, trainer) and filepath != del_filepath:
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py", line 720, in file_exists
    return trainer.strategy.broadcast(exists)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/strategies/ddp.py", line 319, in broadcast
    torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
  File "/usr/local/lib/python3.9/dist-packages/torch/distributed/distributed_c10d.py", line 1869, in broadcast_object_list
    broadcast(object_sizes_tensor, src=src, group=group)
  File "/usr/local/lib/python3.9/dist-packages/torch/distributed/distributed_c10d.py", line 1187, in broadcast
    work = default_pg.broadcast([tensor], opts)
RuntimeError: Detected mismatch between collectives on ranks. Rank 7 is running inconsistent collective: CollectiveFingerPrint(OpType=BROADCAST, TensorShape=[1], TensorDtypes=Long, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))

Stack trace from rank 2 (ranks 0, 1, 3, 4, 5, and 6 are similar):

Traceback (most recent call last):
  File "/workspace/solr/app/main_lightning.py", line 164, in <module>
    trainer.fit(model, datamodule=datamodule)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py", line 770, in fit
    self._call_and_handle_interrupt(
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py", line 723, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py", line 811, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py", line 1236, in _run
    results = self._run_stage()
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py", line 1323, in _run_stage
    return self._run_train()
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py", line 1353, in _run_train
    self.fit_loop.run()
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/loops/base.py", line 204, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/loops/fit_loop.py", line 266, in advance
    self._outputs = self.epoch_loop.run(self._data_fetcher)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/loops/base.py", line 204, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 208, in advance
    batch_output = self.batch_loop.run(batch, batch_idx)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/loops/base.py", line 204, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 88, in advance
    outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/loops/base.py", line 204, in run
    self.advance(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 203, in advance
    result = self._run_optimization(
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 256, in _run_optimization
    self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 369, in _optimizer_step
    self.trainer._call_lightning_module_hook(
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py", line 1595, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/core/lightning.py", line 1646, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/core/optimizer.py", line 168, in step
    step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/strategies/ddp.py", line 286, in optimizer_step
    optimizer_output = super().optimizer_step(optimizer, opt_idx, closure, model, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/strategies/strategy.py", line 193, in optimizer_step
    return self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 155, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/optim/optimizer.py", line 88, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/optim/rmsprop.py", line 96, in step
    loss = closure()
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 140, in _wrap_closure
    closure_result = closure()
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 148, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 143, in closure
    self._backward_fn(step_output.closure_loss)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 311, in backward_fn
    self.trainer._call_strategy_hook("backward", loss, optimizer, opt_idx)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/trainer/trainer.py", line 1765, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/strategies/strategy.py", line 168, in backward
    self.precision_plugin.backward(self.lightning_module, closure_loss, *args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/plugins/precision/precision_plugin.py", line 80, in backward
    model.backward(closure_loss, optimizer, *args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/pytorch_lightning/core/lightning.py", line 1391, in backward
    loss.backward(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_tensor.py", line 363, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/usr/local/lib/python3.9/dist-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Detected mismatch between collectives on ranks. Rank 2 is running inconsistent collective: CollectiveFingerPrint(OpType=ALLREDUCE, TensorShape=[283125], TensorDtypes=Float, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))