Hello there,
I am currently encountering a memory consumption issue with my PyTorch-Lightning pipeline. During the 3rd epoch, I hit a CUDA out-of-memory error:

torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 5.00 GiB (GPU 0; 79.15 GiB total capacity; 66.54 GiB already allocated; 3.43 GiB free; 74.69 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
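As far as I understand, the max_split_size_mb hint in that message refers to the caching-allocator option that is set through the PYTORCH_CUDA_ALLOC_CONF environment variable before the first CUDA allocation; a minimal sketch (the value 128 is only an example):

    import os
    # must be set before CUDA is initialized; the value is an arbitrary example
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

However, given the slow and steady memory growth described below, fragmentation does not look like the root cause to me.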
The problem is that my batch size is already equal to 1 (I work with large objects), and before the run I check that every sample fits through the model by doing a dry run on the largest sample of the dataset (with precheck_memory(), shown below), which completes correctly. I also looked at the GPU Memory Allocated panel on Weights & Biases (see below): the batches fit in memory (peak around t = 10 min), but memory then increases slowly and steadily over the course of training, ultimately leading to the overflow at the end.
My assumption is that some results or metrics of the training are kept in GPU memory and, since they accumulate as training progresses, eventually cause this overflow.
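To check whether the growth corresponds to tensors actually held by PyTorch (rather than just reserved memory), the allocated memory could be logged once per epoch; a minimal sketch using Lightning's on_train_epoch_end hook (pt = torch, as in the code below; untested):

    def on_train_epoch_end(self):
        # log the CUDA memory currently held by tensors, once per epoch
        allocated_gb = pt.cuda.memory_allocated(self.device) / 1024**3
        self.log('cuda_allocated_gb', allocated_gb, sync_dist=True)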
Here is the model I am using (I have shortened the code to keep it readable):
class Model(ptl.LightningModule):
    def __init__(self, config, pos_weight_factor=0.9, global_step=0, criterion=pt.nn.BCEWithLogitsLoss(reduction="none")):
        super(Model, self).__init__()
        self.em = pt.nn.Sequential(
            pt.nn.Linear(N0, N1),
            pt.nn.ELU(),
            ......
        )
        self.sum = pt.nn.Sequential(....)
        self.spl = StatePoolLayer(.....)
        self.dm = pt.nn.Sequential(
            pt.nn.Linear(2*N0, N1),
            pt.nn.ELU(),
            ...
        )
        self.criterion = criterion
        self.pos_ratios = 0.05*pt.ones(N2, dtype=pt.float)
        self.pos_weight_factor = pos_weight_factor
        self.metrics = {'auc': tm.AUROC(task='multiclass', num_classes=20),
                        'f1': tm.F1Score(task='multiclass', num_classes=20),
                        'mcc': tm.MatthewsCorrCoef(task='multiclass', num_classes=20),
                        'mlrap': tm.classification.MultilabelRankingAveragePrecision(num_labels=20),
                        'mlrl': tm.classification.MultilabelRankingLoss(num_labels=20)}
    def setup(self, stage: str) -> None:
        # Change device here - setup is called after the model is moved to the device
        self.criterion = self.criterion.to(self.device)
        self.pos_ratios = self.pos_ratios.to(self.device)
        self.metrics = {k: v.to(self.device) for k, v in self.metrics.items()}
        return super().setup(stage)
    def precheck_memory(self, dataset):
        # quick dry run on the largest data: memory check and pre-allocation
        try:
            self.eval()
            with pt.no_grad():  # no autograd graph needed for the dry run
                data = dataset.get_largest()
                p = self.predict_step(data)
        except RuntimeError:
            raise MemoryError(f"Not enough memory to pre-allocate a model containing {data[0].shape[0]} residues (pdb of {np.max(dataset.sizes[dataset.m,0])}o)")
    def forward(self, X, ids_topk, q0, M):
        ........
        return z
    def apply_loss(self, gt, pred, is_training=False):
        # compute weighted loss
        if is_training:
            self.pos_ratios += ((pt.mean(gt,dim=0) - self.pos_ratios) / (1.0 + np.sqrt(self.global_step))).to(self.device)
            self.criterion.pos_weight = self.pos_weight_factor * (1.0 - self.pos_ratios) / (self.pos_ratios + 1e-6)
            self.log_dict({'pos_ratios_mean': pt.mean(self.pos_ratios), 'pos_weight_mean': pt.mean(self.criterion.pos_weight)}, sync_dist=True)
        dloss = self.criterion(pred, gt).to(self.device)
        # re-weighted losses
        loss_factors = (self.pos_ratios / pt.sum(self.pos_ratios)).reshape(1,-1).to(self.device)
        losses = (loss_factors * dloss) / dloss.shape[0]
        # reduce to a scalar loss (the backward pass is handled by Lightning on the returned loss)
        loss = pt.sum(losses)
        return loss
    def apply_recovery(self, gt, pred):
        aa_mask = pt.any(gt > 0.5, axis=1)
        if pt.any(aa_mask):
            gt_seq = pt.argmax(gt[aa_mask], axis=1)
            # Compute ability to recover sequence from a certain confidence
            acc_any = pt.sum(pt.gt(pred[aa_mask, gt_seq], self.pred_threshold)) / gt_seq.shape[0]
            # Compute direct sequence recovery
            acc = pt.sum(pt.argmax(pred[aa_mask], axis=1) == gt_seq) / gt_seq.shape[0]
        else:
            raise NotImplementedError("....")
        return acc_any.detach().cpu().item(), acc.detach().cpu().item()
    def apply_metrics(self, gt, pred):
        return {
            'auc': pt.mean(self.metrics['auc'](pred, pt.argmax(gt, axis=1))).detach().cpu().item(),
            'f1': pt.mean(self.metrics['f1'](pred, gt)).detach().cpu().item(),
            'mcc': pt.mean(self.metrics['mcc'](pred, pt.argmax(gt, axis=1)).float()).detach().cpu().item(),
            'mlrap': pt.mean(self.metrics['mlrap'](pred, gt)).detach().cpu().item(),
            'mlrl': pt.mean(self.metrics['mlrl'](pred, gt)).detach().cpu().item()
        }
    def training_step(self, batch):
        X, q, M, y, rids_sel = batch
        ids_topk, _, _, _, _ = extract_topology(X, 64)
        p = self.forward(X, ids_topk, q, M)
        loss = self.apply_loss(y, p, is_training=True)
        acc_any, acc = self.apply_recovery(y, p)
        scores = self.apply_metrics(y, p)
        self.log_dict({'train_loss': loss, 'train_accuracy_any': acc_any, 'train_recovery': acc, **{'train_'+k:v for k,v in scores.items()}}, sync_dist=True)
        return loss
    def validation_step(self, batch):
        X, q, M, y, rids_sel = batch
        ids_topk, _, _, _, _ = extract_topology(X, 64)
        p = self.forward(X, ids_topk, q, M)
        loss = self.apply_loss(y, p, is_training=False)
        acc_any, acc = self.apply_recovery(y, p)
        scores = self.apply_metrics(y, p)
        self.log_dict({'val_loss': loss, 'val_accuracy_any': acc_any, 'val_recovery': acc, **{'val_'+k:v for k,v in scores.items()}}, sync_dist=True)
        return loss
    def test_step(self, batch):
        X, q, M, y, rids_sel = batch
        ids_topk, _, _, _, _ = extract_topology(X, 64)
        p = self.forward(X, ids_topk, q, M)
        loss = self.apply_loss(y, p, is_training=False)
        acc_any, acc = self.apply_recovery(y, p)
        scores = self.apply_metrics(y, p)
        self.log_dict({'test_loss': loss, 'test_accuracy_any': acc_any, 'test_recovery': acc, **{'test_'+k:v for k,v in scores.items()}}, sync_dist=True)
        return loss
    def predict_step(self, batch):
        X, q, M, y, rids_sel = batch
        ids_topk, _, _, _, _ = extract_topology(X, 64)
        p = self.forward(X, ids_topk, q, M)
        return p
    def configure_optimizers(self, lr: float = 1e-3) -> pt.optim.Optimizer:
        return pt.optim.Adam(self.parameters(), lr=lr)
Would you have any idea how to solve this issue, or how to save some memory by placing the monitoring tensors on the CPU?
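For instance, would something along these lines be a reasonable way to keep the metric state off the GPU? It is just a sketch of what I have in mind (untested); it would replace apply_metrics above and drop the .to(self.device) for the metrics in setup:

    def apply_metrics(self, gt, pred):
        # compute the torchmetrics on detached CPU copies so that no metric state
        # or intermediate result stays on the GPU
        gt_cpu, pred_cpu = gt.detach().cpu(), pred.detach().cpu()
        scores = {}
        for name, metric in self.metrics.items():
            # 'auc' and 'mcc' expect class indices, the multilabel metrics take the raw targets
            target = pt.argmax(gt_cpu, axis=1) if name in ('auc', 'mcc') else gt_cpu
            scores[name] = pt.mean(metric(pred_cpu, target).float()).item()
        return scores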