How can I train a model using DDP on two GPUs, but only test on one GPU?

Within one function, how can I train a model using DDP on two GPUs but test on only one GPU? (As suggested by the documentation, it is not recommended to run testing on two GPUs that are used with a DDP sampler.)

import os.path as osp
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

datamodule = MyDataModule()
model = GSDG()

monitor = 'val_acc'
logdir = osp.join(args.dump_path, f'seed{args.seed}')
checkpoint_callback = ModelCheckpoint(monitor=monitor, dirpath=logdir, filename='{epoch:02d}-{' + f'{monitor}' + ':.3f}', mode='max', save_last=True)

trainer = pl.Trainer(
    logger=TensorBoardLogger(save_dir=osp.join(logdir, f'tflog'), name=f'ep_n:{args.epochs}'),
    max_epochs=args.epochs,
    accelerator='gpu', 
    devices=2,
    sync_batchnorm=True, 
    strategy='ddp_find_unused_parameters_true',
    callbacks=[checkpoint_callback],
    default_root_dir=logdir,
)
trainer.fit(model, datamodule)  

best_ckpt_path = trainer.checkpoint_callback.best_model_path
trainer_tst = pl.Trainer(
    logger=TensorBoardLogger(save_dir=osp.join(logdir, f'tflog'), name=f'ep_n:{args.epochs}'),
    max_epochs=args.epochs,
    accelerator='gpu',  # cpu gpu
    devices=1,
    sync_batchnorm=True,  # False, 
    strategy='ddp_find_unused_parameters_true',
    default_root_dir=logdir,
)
model_tst = GSDG.load_from_checkpoint(best_ckpt_path)
trainer_tst.test(model_tst, datamodule)

In my multiple runs, this script seems to start TWO single-GPU test processes, one on each GPU: the time consumption is much longer than training and testing on a single GPU, and 'nvidia-smi' shows duplicated Python GPU tasks.

Then I tried explicitly training on 2 GPUs and testing only on rank 0:

import os.path as osp
import torch.distributed
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

datamodule = MyDataModule()
model = GSDG()

monitor = 'val_acc'
logdir = osp.join(args.dump_path, f'seed{args.seed}')
checkpoint_callback = ModelCheckpoint(monitor=monitor, dirpath=logdir, filename='{epoch:02d}-{' + f'{monitor}' + ':.3f}', mode='max', save_last=True)

trainer = pl.Trainer(
    logger=TensorBoardLogger(save_dir=osp.join(logdir, f'tflog'), name=f'ep_n:{args.epochs}'),
    max_epochs=args.epochs,
    accelerator='gpu', 
    devices=2,
    sync_batchnorm=True, 
    strategy='ddp_find_unused_parameters_true',
    callbacks=[checkpoint_callback],
    default_root_dir=logdir,
)
trainer.fit(model, datamodule)  

best_ckpt_path = trainer.checkpoint_callback.best_model_path  

if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
    trainer_tst = pl.Trainer(
        logger=TensorBoardLogger(save_dir=osp.join(logdir, f'tflog'), name=f'ep_n:{args.epochs}'),
        max_epochs=args.epochs,
        accelerator='gpu',  # cpu gpu
        devices=1,
        sync_batchnorm=True,  # False,
        strategy='ddp_find_unused_parameters_true',
        default_root_dir=logdir,
    )
    model_tst = GSDG.load_from_checkpoint(best_ckpt_path)
    trainer_tst.test(model_tst, datamodule)

However, the code gets stuck when the test progress bar reaches 100%.

For practical reasons, if it is possible for you, I would suggest separating training and testing into two separate scripts. The training would run in a multi-GPU distributed fashion and save checkpoints, and then, independently of that, after or even during training, you could kick off the testing script on a single GPU.

This is what I personally would do and found useful in the past.
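
Roughly, a sketch of what that split could look like (the script names, the logdir variable, and the way the checkpoint path gets passed to the test script are just placeholders, not something Lightning prescribes):

# train.py -- launched once; Lightning starts the 2 DDP processes and only saves checkpoints
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

datamodule = MyDataModule()
model = GSDG()
checkpoint_callback = ModelCheckpoint(monitor='val_acc', dirpath=logdir, mode='max', save_last=True)
trainer = pl.Trainer(
    accelerator='gpu',
    devices=2,
    strategy='ddp_find_unused_parameters_true',
    sync_batchnorm=True,
    callbacks=[checkpoint_callback],
)
trainer.fit(model, datamodule)

# test.py -- a completely separate run: single process, single GPU, no DDP strategy
import pytorch_lightning as pl

datamodule = MyDataModule()
model = GSDG.load_from_checkpoint(best_ckpt_path)  # checkpoint path passed in, e.g. via argparse
trainer = pl.Trainer(accelerator='gpu', devices=1)
trainer.test(model, datamodule)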

Note. This pattern of

trainer.fit()
if trainer.global_rank == 0:
    new_trainer = Trainer(devices=1)
    new_trainer.test()

will not work correctly, I think, because the new trainer will probably still see the previous process group with world size N. You would either have to destroy the process group before that, or remove the strategy argument from the new Trainer.
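
Concretely, the workaround would look something like the following (an untested sketch; I have not verified that tearing down the process group mid-script interacts cleanly with the old trainer):

import torch.distributed

trainer.fit(model, datamodule)
best_ckpt_path = trainer.checkpoint_callback.best_model_path

if trainer.global_rank == 0:
    # tear down the 2-GPU process group so the new trainer doesn't inherit world_size=2
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()

    # fresh single-device trainer: no DDP strategy, no sync_batchnorm
    trainer_tst = pl.Trainer(accelerator='gpu', devices=1, default_root_dir=logdir)
    model_tst = GSDG.load_from_checkpoint(best_ckpt_path)
    trainer_tst.test(model_tst, datamodule)

The non-zero ranks should simply fall through this block and exit when the script ends.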

Thanks for your response.


Hi Adrian. Can you please elaborate on your statement about removing the strategy (I think this could be easier than the first option)? I fitted my model on 2 GPUs and would like to immediately use its predictions (on 1 GPU) to feed another model. I have seen some weird behaviors and am exhausted from dealing with them.

I have exactly the same problem! My code gets stuck when it reaches 100% during testing, just like yours. Did you find an alternative way to do this?