Why is CUDA still not initialized after calling trainer.fit() with ddp_fork?

I created a Trainer with ddp_fork as the strategy. I printed torch.cuda.is_initialized() before and after calling trainer.fit(), and both calls return False. I am curious what mechanism keeps CUDA uninitialized after trainer.fit() completes. Do you clean up all the CUDA processes at the end of ddp_fork, or do you use some trick to uninitialize CUDA? I want to understand the details because I am implementing something that must guarantee CUDA is uninitialized after every inference.
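
For context, this is roughly the setup (a minimal sketch; MyLightningModule and the device count are placeholders for my actual code):

import torch
from pytorch_lightning import Trainer  # `from lightning import Trainer` in newer releases

model = MyLightningModule()  # hypothetical LightningModule, not shown here
trainer = Trainer(accelerator="gpu", devices=2, strategy="ddp_fork")

print(torch.cuda.is_initialized())  # False
trainer.fit(model)
print(torch.cuda.is_initialized())  # still False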

Thanks,

@zhiqiangdon It is pretty simple: CUDA only gets initialized in the worker processes, and since those are separate processes, they don't affect the state of the main process, where CUDA remains uninitialized. Inside trainer.fit() we use a simple call to torch.multiprocessing.start_processes to launch as many worker processes as there are GPUs.
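
As an illustration, here is a minimal standalone sketch (assuming a machine with at least one CUDA device) of why the fork keeps the parent process clean:

import torch
import torch.multiprocessing as mp

def worker(rank):
    torch.cuda.init()  # initialize CUDA only inside the forked child
    print(f"child {rank}: cuda initialized = {torch.cuda.is_initialized()}")  # True

if __name__ == "__main__":
    print("parent before:", torch.cuda.is_initialized())  # False
    mp.start_processes(worker, nprocs=2, start_method="fork")  # blocks until the children exit
    print("parent after:", torch.cuda.is_initialized())  # still False: the children's CUDA state dies with them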

Thanks @awaelchli for the quick response! So, does that mean the main process just waits there, doing nothing, during training? Is this (https://github.com/Lightning-AI/lightning/blob/6b7abda5a3be5ee3204c8e1ca5a8cd644e768ba4/src/lightning/fabric/strategies/launchers/multiprocessing.py#L113) the source code you are referring to?

Yes, that’s correct.

@awaelchli, I ran a toy experiment based on that code:

import torch
import torch.multiprocessing as mp

def toy_func(rank, world_size):
    return torch.tensor(2) * rank
    # return rank

def launch(start_method, function, nprocs):
    context = mp.get_context(start_method)
    return_queue = context.SimpleQueue()
    kwargs = dict(world_size=nprocs)

    process_args = [function, kwargs, return_queue]

    # Launch nprocs workers and block until they have all exited (join=True).
    mp.start_processes(
        wrapping_function,
        args=process_args,
        nprocs=nprocs,
        start_method=start_method,
        join=True,
    )
    # Collect one result per worker from the queue.
    results = []
    for _ in range(nprocs):
        results.append(return_queue.get())
    return results


def wrapping_function(
    process_idx: int,
    function,
    kwargs,
    return_queue,
) -> None:
    # Runs inside each worker: call the target function and send its
    # result back to the parent process through the queue.
    kwargs.update({"rank": process_idx})
    results = function(**kwargs)
    return_queue.put(results)

def main():
    results = launch(start_method="fork", function=toy_func, nprocs=2)
    print(f"results: {results}")

if __name__ == '__main__':
    main()

However, I encountered the following error:

Traceback (most recent call last):
  File "test_1.py", line 46, in <module>
    main()
  File "test_1.py", line 42, in main
    results = launch(start_method="fork", function=toy_func, nprocs=2)
  File "test_1.py", line 27, in launch
    results.append(return_queue.get())
  File "/home/ubuntu/anaconda3/envs/ag-dev-3/lib/python3.8/multiprocessing/queues.py", line 358, in get
    return _ForkingPickler.loads(res)
  File "/home/ubuntu/anaconda3/envs/ag-dev-3/lib/python3.8/site-packages/torch/multiprocessing/reductions.py", line 307, in rebuild_storage_fd
    fd = df.detach()
  File "/home/ubuntu/anaconda3/envs/ag-dev-3/lib/python3.8/multiprocessing/resource_sharer.py", line 57, in detach
    with _resource_sharer.get_connection(self._id) as conn:
  File "/home/ubuntu/anaconda3/envs/ag-dev-3/lib/python3.8/multiprocessing/resource_sharer.py", line 87, in get_connection
    c = Client(address, authkey=process.current_process().authkey)
  File "/home/ubuntu/anaconda3/envs/ag-dev-3/lib/python3.8/multiprocessing/connection.py", line 502, in Client
    c = SocketClient(address)
  File "/home/ubuntu/anaconda3/envs/ag-dev-3/lib/python3.8/multiprocessing/connection.py", line 630, in SocketClient
    s.connect(address)
FileNotFoundError: [Errno 2] No such file or directory

If toy_func just returns the rank instead of a torch tensor, the error disappears, so I think the error is related to transferring torch tensors through the queue. Any hint on how to resolve it? Thanks!
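
(One possible explanation and workaround, sketched under the assumption that the failure comes from PyTorch's default "file_descriptor" sharing strategy: tensors put on a multiprocessing queue are moved to shared memory, and the receiver fetches the file descriptor over a socket served by the sending process. Because each child exits right after its put, that socket is gone by the time the parent calls return_queue.get(). Converting the result to plain Python data before queueing makes it pickle by value and sidesteps the handoff. A drop-in replacement for wrapping_function in the script above:)

def wrapping_function(process_idx, function, kwargs, return_queue):
    kwargs.update({"rank": process_idx})
    results = function(**kwargs)
    if isinstance(results, torch.Tensor):
        # Hypothetical workaround: plain Python data is pickled by value,
        # so no shared-memory file descriptor needs to outlive the child.
        results = results.tolist()
    return_queue.put(results)

Another option sometimes suggested for short-lived sender processes is torch.multiprocessing.set_sharing_strategy("file_system"), which passes file names instead of descriptors, though it comes with its own resource-cleanup caveats.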