Hi everyone,
I’m trying to train a model on my university’s HPC cluster, which has plenty of GPUs (each with 32 GB of memory). I ran the job on 2 GPUs, but I’m still getting the dreaded CUDA out-of-memory error (after waiting in the queue for quite a while, annoyingly).
My model is a 3D UNet that takes a 4×128×128×128 input, and my batch size is already 1. The problem is that I’m replacing the conv layers with tensor networks to reduce the number of calculations, but this (somewhat ironically) blows up my memory demand because of the unfold operations I use to achieve that.
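To put a number on it, here’s a back-of-the-envelope sketch (the kernel size, stride and padding are illustrative assumptions, not my exact configuration):

# Back-of-the-envelope estimate (illustrative: 3x3x3 kernel, stride 1,
# 'same' padding, fp32). Unfolding a (B, C, D, H, W) volume materialises
# every voxel's neighbourhood as a (B, C*k**3, D*H*W) tensor.
B, C, D, H, W, k = 1, 4, 128, 128, 128, 3
patch_elems = B * C * k**3 * D * H * W
print(f"unfolded patches: {patch_elems * 4 / 2**30:.2f} GiB")  # ~0.84 GiB
# That is only the unfolded input of the very first encoder block; the
# intermediates of the einsum contraction itself can be much larger still,
# as the traceback below shows (a 39.87 GiB allocation in that same block).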
These are the parameters I’m passing to the Trainer:
# Initialize trainer
log("Initializing trainer")
trainer = Trainer(
    max_epochs=200,
    logger=tb_logger,
    gpus=-1,
    deterministic=True,
    distributed_backend='ddp',
    callbacks=[
        LearningRateMonitor(logging_interval="step"),
        PrintTableMetricsCallback(),
    ],
)
My question: how can I make better use of the GPU memory? Between the two GPUs there should be 64 GB in total, but the output below gives me the impression that the memory demand isn’t being distributed across them properly.
PS: I just now added a plugins='ddp_sharded' parameter (having installed fairscale in my venv as well), but I fear that won’t be enough. The job is still in the queue; I’ll update once it runs.
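For completeness, this is roughly what the Trainer call looks like with that change (the only difference from the snippet above is the plugins argument, and fairscale has to be installed for it):

trainer = Trainer(
    max_epochs=200,
    logger=tb_logger,
    gpus=-1,
    deterministic=True,
    distributed_backend='ddp',
    plugins='ddp_sharded',  # requires fairscale
    callbacks=[
        LearningRateMonitor(logging_interval="step"),
        PrintTableMetricsCallback(),
    ],
)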
Global seed set to 616
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Global seed set to 616
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/2
Global seed set to 616
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Global seed set to 616
initializing ddp: GLOBAL_RANK: 1, MEMBER: 2/2
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Set SLURM handle signals.
Set SLURM handle signals.
| Name | Type | Params
-------------------------------------
0 | net | LowRankUNet | 16.7 M
-------------------------------------
16.7 M Trainable params
0 Non-trainable params
16.7 M Total params
66.900 Total estimated model params size (MB)
Default upsampling behavior when mode=trilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.
The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
Default upsampling behavior when mode=trilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.
Traceback (most recent call last):
File "/kyukon/data/gent/417/vsc41768/airhead/train_lightweight.py", line 130, in <module>
trainer.fit(model=model,
File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 458, in fit
self._run(model)
File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 756, in _run
self.dispatch()
File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 797, in dispatch
self.accelerator.start_training(self)
File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 96, in start_training
self.training_type_plugin.start_training(trainer)
File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 144, in start_training
self._results = trainer.run_stage()
File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 807, in run_stage
return self.run_train()
File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 842, in run_train
self.run_sanity_check(self.lightning_module)
File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1107, in run_sanity_check
self.run_evaluation()
File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 962, in run_evaluation
output = self.evaluation_loop.evaluation_step(batch, batch_idx, dataloader_idx)
File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 174, in evaluation_step
output = self.trainer.accelerator.validation_step(args)
File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 226, in validation_step
return self.training_type_plugin.validation_step(*args)
File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ddp.py", line 322, in validation_step
return self.model(*args, **kwargs)
File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 705, in forward
output = self.module(*inputs[0], **kwargs[0])
File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 57, in forward
output = self.module.validation_step(*inputs, **kwargs)
File "/kyukon/data/gent/417/vsc41768/airhead/training/lightning.py", line 160, in validation_step
y_hat = self.inference(x, self, **self.inference_params)
File "/kyukon/data/gent/417/vsc41768/airhead/training/inference.py", line 42, in val_inference
output = model(input)
File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/kyukon/data/gent/417/vsc41768/airhead/training/lightning.py", line 98, in forward
return self.net(x)
File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/kyukon/data/gent/417/vsc41768/airhead/models/lightweight_unet.py", line 130, in forward
enc_1 = self.enc_1(input)
File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/kyukon/data/gent/417/vsc41768/airhead/layers/lightweight_conv.py", line 589, in forward
return self.block(input)
File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/torch/nn/modules/container.py", line 119, in forward
input = module(input)
File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/kyukon/data/gent/417/vsc41768/airhead/layers/lightweight_conv.py", line 517, in forward
output = self.einsum_expression(patches, *weights)
File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/opt_einsum/contract.py", line 763, in __call__
return self._contract(ops, out, backend, evaluate_constants=evaluate_constants)
File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/opt_einsum/contract.py", line 693, in _contract
return _core_contract(list(arrays),
File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/opt_einsum/contract.py", line 573, in _core_contract
new_view = _tensordot(*tmp_operands, axes=(tuple(left_pos), tuple(right_pos)), backend=backend)
File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/opt_einsum/sharing.py", line 131, in cached_tensordot
return tensordot(x, y, axes, backend=backend)
File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/opt_einsum/contract.py", line 374, in _tensordot
return fn(x, y, axes=axes)
File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/opt_einsum/backends/torch.py", line 54, in tensordot
return torch.tensordot(x, y, dims=axes)
File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/torch/functional.py", line 1002, in tensordot
return _VF.tensordot(a, b, dims_a, dims_b) # type: ignore
RuntimeError: CUDA out of memory. Tried to allocate 39.87 GiB (GPU 1; 31.75 GiB total capacity; 7.44 GiB already allocated; 22.61 GiB free; 7.78 GiB reserved in total by PyTorch)
(The other rank fails with an identical traceback on GPU 0, ending in:)
RuntimeError: CUDA out of memory. Tried to allocate 39.87 GiB (GPU 0; 31.75 GiB total capacity; 7.44 GiB already allocated; 22.61 GiB free; 7.78 GiB reserved in total by PyTorch)
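For what it’s worth, here’s a minimal sketch of how the offending contraction could be inspected offline with opt_einsum’s contract_path; the einsum expression and shapes are purely illustrative, not my actual layer definition:

import opt_einsum as oe

# Illustrative shapes only: (batch, channels * k**3, voxels) patches and a
# single weight factor; the real expression has several low-rank factors.
patch_shape = (1, 4 * 3**3, 128**3)
weight_shape = (32, 4 * 3**3)

# shapes=True lets contract_path work from shapes without allocating arrays.
path, info = oe.contract_path('bpn,op->bon', patch_shape, weight_shape, shapes=True)
print(info)  # reports the largest intermediate, i.e. the temporary that blows up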