CUDA out of memory error for tensorized network

Hi everyone,

I’m trying to train a model on my university’s HPC cluster, which has plenty of GPUs (32 GB of memory each). I ran the job on 2 GPUs, but I’m still getting the dreaded CUDA out of memory error (after sitting in the queue for quite a while, annoyingly).

My model is a 3D UNet that takes a 4x128x128x128 input, and my batch size is already 1. The problem is that I’m replacing the conv layers with tensor networks to reduce the number of calculations, but this (somewhat ironically) blows up my memory demand because of the unfold operations I use to achieve that.
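
To give an idea of the scale, here is a minimal, hypothetical sketch (not my actual layer code, and assuming standard 3x3x3 kernels) of what the unfold step does to a single input volume. torch.Tensor.unfold itself only returns a strided view, but the contraction that follows typically has to materialize a contiguous copy of it, and that copy is roughly 27x the size of the input:

import torch

x = torch.randn(1, 4, 128, 128, 128)       # batch, channels, D, H, W: ~32 MB in fp32
patches = (
    x.unfold(2, 3, 1)                       # slide a size-3 window over depth
     .unfold(3, 3, 1)                       # ... over height
     .unfold(4, 3, 1)                       # ... over width
)
print(patches.shape)                        # torch.Size([1, 4, 126, 126, 126, 3, 3, 3])
print(patches.numel() * 4 / 2**30)          # ~0.8 GiB once materialized, per layer, per forward pass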

These are the parameters I’m using with the trainer.

# Initialize trainer
log("Initializing trainer")
trainer = Trainer(
    max_epochs=200,
    logger=tb_logger,
    gpus=-1,
    deterministic=True,
    distributed_backend='ddp',
    callbacks=[
        LearningRateMonitor(logging_interval="step"),
        PrintTableMetricsCallback(),
    ],
)

My question: how can I make better use of the GPU RAM? The two GPUs should give me 64 GB combined, but the output below gives me the impression that the memory demand is not being distributed appropriately.

PS: I just now added a plugins='ddp_sharded' parameter (and installed fairscale in my venv), but I fear that won’t be enough. Still in the queue, though; I’ll update once it runs.
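
For reference, this is roughly what the updated Trainer call looks like; the only change relative to the snippet above is the plugins argument, and I’m not sure yet whether it will make a difference here:

# Initialize trainer with sharded DDP (requires fairscale)
trainer = Trainer(
    max_epochs=200,
    logger=tb_logger,
    gpus=-1,
    deterministic=True,
    distributed_backend='ddp',
    plugins='ddp_sharded',      # shards optimizer state and gradients across the two GPUs
    callbacks=[
        LearningRateMonitor(logging_interval="step"),
        PrintTableMetricsCallback(),
    ],
)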

Global seed set to 616
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Global seed set to 616
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/2
Global seed set to 616
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Global seed set to 616
initializing ddp: GLOBAL_RANK: 1, MEMBER: 2/2
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
Set SLURM handle signals.
Set SLURM handle signals.

  | Name | Type        | Params
-------------------------------------
0 | net  | LowRankUNet | 16.7 M
-------------------------------------
16.7 M    Trainable params
0         Non-trainable params
16.7 M    Total params
66.900    Total estimated model params size (MB)
Default upsampling behavior when mode=trilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.
The dataloader, val dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 32 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
Default upsampling behavior when mode=trilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.
Traceback (most recent call last):
  File "/kyukon/data/gent/417/vsc41768/airhead/train_lightweight.py", line 130, in <module>
    trainer.fit(model=model,
  File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 458, in fit
    self._run(model)
  File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 756, in _run
    self.dispatch()
  File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 797, in dispatch
    self.accelerator.start_training(self)
  File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 96, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 144, in start_training
    self._results = trainer.run_stage()
  File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 807, in run_stage
    return self.run_train()
  File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 842, in run_train
    self.run_sanity_check(self.lightning_module)
  File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1107, in run_sanity_check
    self.run_evaluation()
  File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 962, in run_evaluation
    output = self.evaluation_loop.evaluation_step(batch, batch_idx, dataloader_idx)
  File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 174, in evaluation_step
    output = self.trainer.accelerator.validation_step(args)
  File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 226, in validation_step
    return self.training_type_plugin.validation_step(*args)
  File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ddp.py", line 322, in validation_step
    return self.model(*args, **kwargs)
  File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 705, in forward
    output = self.module(*inputs[0], **kwargs[0])
  File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 57, in forward
    output = self.module.validation_step(*inputs, **kwargs)
  File "/kyukon/data/gent/417/vsc41768/airhead/training/lightning.py", line 160, in validation_step
    y_hat = self.inference(x, self, **self.inference_params)
  File "/kyukon/data/gent/417/vsc41768/airhead/training/inference.py", line 42, in val_inference
    output = model(input)
  File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/kyukon/data/gent/417/vsc41768/airhead/training/lightning.py", line 98, in forward
    return self.net(x)
  File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/kyukon/data/gent/417/vsc41768/airhead/models/lightweight_unet.py", line 130, in forward
    enc_1 = self.enc_1(input)
  File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/kyukon/data/gent/417/vsc41768/airhead/layers/lightweight_conv.py", line 589, in forward
    return self.block(input)
  File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/torch/nn/modules/container.py", line 119, in forward
    input = module(input)
  File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/kyukon/data/gent/417/vsc41768/airhead/layers/lightweight_conv.py", line 517, in forward
    output = self.einsum_expression(patches, *weights)
  File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/opt_einsum/contract.py", line 763, in __call__
    return self._contract(ops, out, backend, evaluate_constants=evaluate_constants)
  File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/opt_einsum/contract.py", line 693, in _contract
    return _core_contract(list(arrays),
  File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/opt_einsum/contract.py", line 573, in _core_contract
    new_view = _tensordot(*tmp_operands, axes=(tuple(left_pos), tuple(right_pos)), backend=backend)
  File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/opt_einsum/sharing.py", line 131, in cached_tensordot
    return tensordot(x, y, axes, backend=backend)
  File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/opt_einsum/contract.py", line 374, in _tensordot
    return fn(x, y, axes=axes)
  File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/opt_einsum/backends/torch.py", line 54, in tensordot
    return torch.tensordot(x, y, dims=axes)
  File "/data/gent/417/vsc41768/miniconda3/envs/airenv/lib/python3.8/site-packages/torch/functional.py", line 1002, in tensordot
    return _VF.tensordot(a, b, dims_a, dims_b)  # type: ignore
RuntimeError: CUDA out of memory. Tried to allocate 39.87 GiB (GPU 1; 31.75 GiB total capacity; 7.44 GiB already allocated; 22.61 GiB free; 7.78 GiB reserved in total by PyTorch)
(The second rank fails with a practically identical traceback, ending in:)
RuntimeError: CUDA out of memory. Tried to allocate 39.87 GiB (GPU 0; 31.75 GiB total capacity; 7.44 GiB already allocated; 22.61 GiB free; 7.78 GiB reserved in total by PyTorch)

Update: it looks as though the problem is my (triple) use of torch.Tensor.unfold. The reason for using it is that I’m replacing convolutional layers with tensorized versions, which involves a manual contraction between the unfolded input and a (formatted) weight tensor. From what I’ve gathered so far, I can try activation checkpointing to offload some of the memory usage, and use sharding. I’m trying to see if I can get either of those optimizations to work. If anyone has experience with reducing the memory cost of unfold, additional insights would be greatly appreciated!
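
In case it’s useful to anyone hitting the same wall, this is the kind of wrapper I have in mind for the activation checkpointing part. It’s only a minimal sketch; CheckpointedBlock and the enc_1 assignment are purely illustrative, not my actual code:

import torch
from torch.utils.checkpoint import checkpoint

class CheckpointedBlock(torch.nn.Module):
    """Recomputes the wrapped block's forward pass during backward instead of storing its activations."""
    def __init__(self, block):
        super().__init__()
        self.block = block

    def forward(self, x):
        # Trades compute for memory during training: the unfolded patches inside
        # `block` are recomputed in the backward pass rather than kept alive.
        return checkpoint(self.block, x)

# e.g. in the UNet constructor (illustrative):
# self.enc_1 = CheckpointedBlock(self.enc_1)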