FID score in validation_epoch_end

I want to compute the FID score between two sets of images stored on disk under different paths. I call calculate_fid_given_paths from pytorch-fid inside my validation_epoch_end method.

While calculate_fid_given_paths is running, I get the following error.

Traceback (most recent call last):
  File "/cw/liir/NoCsBack/testliir/rubenc/miniconda3/envs/tsenv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 561, in train
    self.train_loop.run_training_epoch()
  File "/cw/liir/NoCsBack/testliir/rubenc/miniconda3/envs/tsenv/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 625, in run_training_epoch
    self.trainer.run_evaluation(on_epoch=True)
  File "/cw/liir/NoCsBack/testliir/rubenc/miniconda3/envs/tsenv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 659, in run_evaluation
    deprecated_eval_results = self.evaluation_loop.evaluation_epoch_end()
  File "/cw/liir/NoCsBack/testliir/rubenc/miniconda3/envs/tsenv/lib/python3.8/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 201, in evaluation_epoch_end
    deprecated_results = self.__run_eval_epoch_end(self.num_dataloaders, using_eval_result)
  File "/cw/liir/NoCsBack/testliir/rubenc/miniconda3/envs/tsenv/lib/python3.8/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 239, in __run_eval_epoch_end
    eval_results = model.validation_epoch_end(eval_results)
  File "/cw/liir/NoCsBack/testliir/rubenc/reteco/src/modules.py", line 441, in validation_epoch_end
    fid = fid_score.calculate_fid_given_paths(
  File "/cw/liir/NoCsBack/testliir/rubenc/miniconda3/envs/tsenv/lib/python3.8/site-packages/pytorch_fid/fid_score.py", line 250, in calculate_fid_given_paths
    m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
  File "/cw/liir/NoCsBack/testliir/rubenc/miniconda3/envs/tsenv/lib/python3.8/site-packages/pytorch_fid/fid_score.py", line 234, in compute_statistics_of_path
    m, s = calculate_activation_statistics(files, model, batch_size,
  File "/cw/liir/NoCsBack/testliir/rubenc/miniconda3/envs/tsenv/lib/python3.8/site-packages/pytorch_fid/fid_score.py", line 220, in calculate_activation_statistics
    act = get_activations(files, model, batch_size, dims, device)
  File "/cw/liir/NoCsBack/testliir/rubenc/miniconda3/envs/tsenv/lib/python3.8/site-packages/pytorch_fid/fid_score.py", line 125, in get_activations
    for batch in tqdm(dataloader):
  File "/cw/liir/NoCsBack/testliir/rubenc/miniconda3/envs/tsenv/lib/python3.8/site-packages/tqdm/std.py", line 1166, in __iter__
    for obj in iterable:
  File "/cw/liir/NoCsBack/testliir/rubenc/miniconda3/envs/tsenv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 435, in __next__
    data = self._next_data()
  File "/cw/liir/NoCsBack/testliir/rubenc/miniconda3/envs/tsenv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1057, in _next_data
    self._shutdown_workers()
  File "/cw/liir/NoCsBack/testliir/rubenc/miniconda3/envs/tsenv/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1177, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/cw/liir/NoCsBack/testliir/rubenc/miniconda3/envs/tsenv/lib/python3.8/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/cw/liir/NoCsBack/testliir/rubenc/miniconda3/envs/tsenv/lib/python3.8/multiprocessing/popen_fork.py", line 44, in wait
    if not wait([self.sentinel], timeout):
  File "/cw/liir/NoCsBack/testliir/rubenc/miniconda3/envs/tsenv/lib/python3.8/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/cw/liir/NoCsBack/testliir/rubenc/miniconda3/envs/tsenv/lib/python3.8/selectors.py", line 415, in select
    fd_event_list = self._selector.poll(timeout)
  File "/cw/liir/NoCsBack/testliir/rubenc/miniconda3/envs/tsenv/lib/python3.8/site-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
    _error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 111913) is killed by signal: Terminated. 

calculate_fid_given_paths launches a number of DataLoader workers internally. If I change the call to use 0 workers, the error doesn't occur, but loading all the images in the main process is not ideal. Any suggestions on how to fix this? Maybe I should launch calculate_fid_given_paths from a newly spawned process?
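To make the "spawn a new process" idea concrete, here is the kind of helper I have in mind: run pytorch-fid via its command-line entry point (python -m pytorch_fid path1 path2) in a fresh subprocess and parse the score out of its stdout, so the worker processes it creates never belong to the Lightning process. This is only a sketch using the standard library; the python -c command below is a stand-in for the real pytorch-fid invocation, and I'm assuming the CLI prints a line starting with "FID:".

```python
import re
import subprocess
import sys


def fid_from_subprocess(cmd):
    """Run a command in a fresh process and parse an 'FID: <value>' line
    from its stdout, so any DataLoader workers it spawns are isolated
    from the calling (Lightning) process."""
    result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    match = re.search(r"FID:\s*([0-9.]+)", result.stdout)
    if match is None:
        raise RuntimeError(f"no FID value found in output: {result.stdout!r}")
    return float(match.group(1))


# The real call would be something like:
#   fid_from_subprocess([sys.executable, "-m", "pytorch_fid", path_a, path_b])
# Stand-in command that just prints a fixed score, to demonstrate the helper:
fid = fid_from_subprocess([sys.executable, "-c", "print('FID:  12.5')"])
print(fid)  # 12.5
```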

The issue occurs both when I'm using DDP and when not. For DDP, how do I make sure the FID isn't calculated once per process? Would wrapping the contents of validation_epoch_end in an if self.local_rank == 0: block make sense?
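For reference, this toy sketch shows the guard pattern I mean, using a minimal stand-in class instead of a real LightningModule. I'm guarding on global_rank rather than local_rank here, since with multi-node DDP local_rank is 0 on every node and the FID would still be computed once per node; the compute_fid method is a hypothetical placeholder for the calculate_fid_given_paths call.

```python
class FakeModule:
    """Minimal stand-in for a LightningModule; only global_rank matters here."""

    def __init__(self, global_rank):
        self.global_rank = global_rank

    def validation_epoch_end(self, outputs):
        # With DDP, every rank runs validation_epoch_end, so without this
        # guard the FID would be computed once per process.
        if self.global_rank != 0:
            return None
        return self.compute_fid()

    def compute_fid(self):
        # Placeholder for fid_score.calculate_fid_given_paths([...], ...)
        return 12.5


print(FakeModule(0).validation_epoch_end([]))  # 12.5 (only rank 0 computes)
print(FakeModule(1).validation_epoch_end([]))  # None
```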