Dear experts,
I would like to “translate” a PyTorch Geometric script to PyTorch Lightning, so that it can run on 4 GPUs (and eventually on several nodes).
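(Eventually I am aiming for something like the sketch below; the num_nodes / devices values are just placeholders, not the configuration used for the measurements in this post.)

import pytorch_lightning as pl

# placeholder multi-node configuration I am aiming for
trainer = pl.Trainer(
    accelerator='gpu',
    devices=4,            # GPUs per node
    num_nodes=2,          # several nodes, eventually
    strategy='ddp',       # one process per GPU
    max_epochs=NUM_EPOCHS,
)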
The issue is that with PyTorch Lightning the GPU memory fills up far more than it does with the plain PyTorch script.
For example, this is what happens with plain PyTorch:
[0] Tesla V100-SXM2-16GB | 51'C, 59 % | 5937 / 16160 MB | dzulian1(4885M) root(8M)
[1] Tesla V100-SXM2-16GB | 68'C, 60 % | 4924 / 16160 MB | dzulian1(4913M) root(8M)
[2] Tesla V100-SXM2-16GB | 59'C, 56 % | 4892 / 16160 MB | dzulian1(4881M) root(8M)
[3] Tesla V100-SXM2-16GB | 67'C, 61 % | 4924 / 16160 MB | dzulian1(4913M) root(8M)
and this is with PyTorch Lightning (same number of training events, same batch size):
[0] Tesla V100-SXM2-16GB | 54'C, 100 % | 9223 / 16160 MB | dzulian1(8171M) root(8M)
[1] Tesla V100-SXM2-16GB | 70'C, 99 % | 7364 / 16160 MB | dzulian1(7353M) root(8M)
[2] Tesla V100-SXM2-16GB | 62'C, 85 % | 8182 / 16160 MB | dzulian1(8171M) root(8M)
[3] Tesla V100-SXM2-16GB | 68'C, 72 % | 8166 / 16160 MB | dzulian1(8155M) root(8M)
The host RAM usage also roughly doubles with PyTorch Lightning.
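For a more apples-to-apples comparison I could also log the allocator-level memory from inside each process; a minimal sketch of what I have in mind (the helper name is just illustrative, and would be called e.g. at the end of each epoch in both scripts):

import torch

def log_gpu_memory(tag=""):
    # allocator-level view, per process (does not include the CUDA context
    # that the GPU monitoring tool above also counts)
    if torch.cuda.is_available():
        dev = torch.cuda.current_device()
        alloc = torch.cuda.memory_allocated(dev) / 2**20      # currently allocated tensors (MB)
        reserved = torch.cuda.memory_reserved(dev) / 2**20    # cached by the allocator (MB)
        peak = torch.cuda.max_memory_allocated(dev) / 2**20   # peak allocation so far (MB)
        print(f"[{tag}] cuda:{dev} allocated={alloc:.0f} MB reserved={reserved:.0f} MB peak={peak:.0f} MB")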
This is the plain PyTorch code:
if MULTI_GPU:
    train_loader = DataListLoader(train_set, **loader_params)
    val_loader = DataListLoader(val_set, **loader_params)
else:
    train_loader = DataLoader(train_set, **loader_params)
    val_loader = DataLoader(val_set, **loader_params)

ttl = AdaMT_loss(model=DataParallel(Net(hparams=hparams, max_overlap=5).float()),
                 sigma_E=[1.0], sigma_X=[1.0], sigma_Y=[1.0], sigma_C=[1.0])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ttl.model.to(device)
ttl.to(device)

optimizer = torch.optim.AdamW(ttl.parameters(), lr=hparams.learning_rate,
                              weight_decay=hparams.lr_weight_decay)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, patience=20)

print("\n\n")
print("================ MODEL SUMMARY =====================")
print(ttl.model)
print()
model_trainable_params = sum(p.numel() for p in ttl.model.parameters() if p.requires_grad)
print(f">>>> NUMBER OF TRAINABLE PARAMETERS: {model_trainable_params}")
print("====================================================\n")

#++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# === START TRAINING LOOP ===
start = timeit.default_timer()
master_train_loss = []
master_val_loss = []
master_sigma_E = []
master_sigma_X = []
master_sigma_Y = []
master_sigma_C = []
master_percHist = []

block_evolution_counter = 0.0
for epoch in tqdm(range(NUM_EPOCHS + 1)):
    train_loss_container = []
    val_loss_container = []
    std_E_container = []
    std_X_container = []
    std_Y_container = []
    std_C_container = []
    std_hits_container = []

    # train
    ttl.model.train()
    for data_list in train_loader:
        optimizer.zero_grad()
        total_loss, std_E, std_X, std_Y, std_C = ttl(data_list)
        total_loss.to(device)
        total_loss.backward()
        train_loss_container.append(total_loss.item())
        std_E_container.append(std_E.item())
        std_X_container.append(std_X.item())
        std_Y_container.append(std_Y.item())
        std_C_container.append(std_C.item())
        optimizer.step()

    # validate
    with torch.no_grad():
        ttl.model.eval()
        for data_list in val_loader:
            val_loss, val_std_E, val_std_X, val_std_Y, val_std_C = ttl(data_list)
            val_loss.to(device)
            val_loss_container.append(val_loss.item())

    if np.mean(val_loss_container) > np.mean(train_loss_container):
        block_evolution_counter += 1.0
    else:
        block_evolution_counter = 0

    scheduler.step(np.mean(val_loss_container))

    master_train_loss.append(np.mean(train_loss_container))
    master_sigma_E.append(np.mean(std_E_container))
    master_sigma_X.append(np.mean(std_X_container))
    master_sigma_Y.append(np.mean(std_Y_container))
    master_sigma_C.append(np.mean(std_C_container))
    master_percHist.append(np.mean(std_hits_container))
    master_val_loss.append(np.mean(val_loss_container))
    print(f"TRAIN LOSS: {np.mean(train_loss_container)}; VAL LOSS: {np.mean(val_loss_container)};")

    if epoch % 1 == 0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': ttl.model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': total_loss.item(),
            'val_loss': val_loss.item(),
            'master_train_loss': master_train_loss,
            'master_val_loss': master_val_loss,
            'master_sigma_E': master_sigma_E,
            'master_sigma_X': master_sigma_X,
            'master_sigma_Y': master_sigma_Y,
            'master_sigma_C': master_sigma_C
        }, f"./training_checkpoints/CLIrun_checkpoint_{epoch}.pt")

    if block_evolution_counter == 5.0:
        print("threshold achieved. === STOP ===")
        break
with Net and AdaMT_loss defined as follows:
class Net(torch.nn.Module):
    """CaloGNN"""
    def __init__(self, hparams, max_overlap=5, yolo_x_size=6, yolo_y_size=6):
        super(Net, self).__init__()
        # network layers and hyperparameter scans
        gnn_flts = hparams.edge_flts
        self.master_flts = hparams.conv_flts
        kNN = hparams.kNN
        aggr = hparams.agrr
        self.conv1 = DynamicEdgeConv(MLP(hparams, [2 * 4, gnn_flts, gnn_flts]), k=kNN, aggr=aggr)
        self.conv2 = DynamicEdgeConv(MLP(hparams, [2 * gnn_flts, gnn_flts, gnn_flts]), k=kNN, aggr=aggr)
        self.conv3 = DynamicEdgeConv(MLP(hparams, [2 * gnn_flts, gnn_flts, gnn_flts]), k=kNN, aggr=aggr)
        self.lin1 = MLP(hparams, [3 * gnn_flts, self.master_flts])
        # variable number of resnet blocks
        self.n_resnet = hparams.n_resnet
        for i in range(self.n_resnet):
            setattr(self, f"resnet_{i}", ResNet_block(hparams, self.master_flts))
        self.head_E = Conv2d_stack(hparams, self.master_flts, max_overlap)
        self.head_x = Conv2d_stack(hparams, self.master_flts, max_overlap)
        self.head_y = Conv2d_stack(hparams, self.master_flts, max_overlap)
        self.head_kappa = Conv2d_stack(hparams, self.master_flts, max_overlap)
        # spurious numerical parameters
        self.yolo_x_size = yolo_x_size
        self.yolo_y_size = yolo_y_size
        self.max_overlap = max_overlap

    def forward(self, data):
        x0, batch, bars = data.x, data.batch, data.bars
        idxs = torch.arange(bars.size(0))
        batchY = idxs // (self.yolo_x_size * self.yolo_y_size)
        batchY = batchY.to(x0.device)
        # edgeconvs and skip connection into filter MLP
        x1 = self.conv1(x0, batch)
        x2 = self.conv2(x1, batch)
        x3 = self.conv3(x2, batch)
        out_edgeconv_stack = self.lin1(torch.cat([x1, x2, x3], dim=-1))
        # map pixels onto yolo voxels and maxpool
        x_coords = x0[:, 1:3]
        cluster = nearest(x=x_coords, y=bars.float(), batch_x=batch, batch_y=batchY)
        x, batch = max_pool_x(cluster, out_edgeconv_stack, data.batch)
        # format for CNNs
        x = x.view(-1, self.master_flts, self.yolo_y_size, self.yolo_x_size)
        # NN awareness
        for i in range(self.n_resnet):
            x = getattr(self, f"resnet_{i}")(x)
        # heads
        outputs = torch.stack([self.head_x(x).view(-1, self.max_overlap),
                               self.head_y(x).view(-1, self.max_overlap),
                               self.head_E(x).view(-1, self.max_overlap),
                               torch.sigmoid(self.head_kappa(x)).view(-1, self.max_overlap)], dim=-1)
        return outputs
and
class AdaMT_loss(tnn.Module):
    def __init__(self, model, sigma_E, sigma_X, sigma_Y, sigma_C):
        super(AdaMT_loss, self).__init__()
        self.model = model
        self.sigma_E = tnn.Parameter(torch.Tensor(sigma_E))
        self.sigma_X = tnn.Parameter(torch.Tensor(sigma_X))
        self.sigma_Y = tnn.Parameter(torch.Tensor(sigma_Y))
        self.sigma_C = tnn.Parameter(torch.Tensor(sigma_C))

    def forward(self, data_list):
        chained_heads = self.model(data_list)
        targets = torch.cat([data.y for data in data_list]).to(chained_heads.device)
        overlap_cluster_mask = torch.sum(targets[..., -1], dim=-1) > 0.0  # check dim[?]
        # cellwise_assignment = list(Munkres_wrapper(distance_matrix(chained_heads[...,0:3], targets, overlap_size=max_overlap).tolist()).numpy())
        # chained_heads = chained_heads[torch.arange(chained_heads.size(0)).unsqueeze(1), cellwise_assignment]
        # regression heads
        mse_E = torch.mean((chained_heads[overlap_cluster_mask][..., 2] - targets[overlap_cluster_mask][..., -1])**2).to(chained_heads.device)
        mse_X = torch.mean((chained_heads[overlap_cluster_mask][..., 0] - targets[overlap_cluster_mask][..., 0])**2).to(chained_heads.device)
        mse_Y = torch.mean((chained_heads[overlap_cluster_mask][..., 1] - targets[overlap_cluster_mask][..., 1])**2).to(chained_heads.device)
        # object detection
        is_on_truth = (targets[..., -1] > 0).float()
        obj_detection_loss = FocalLoss(gamma=2, alpha=.25).to(chained_heads.device)
        is_on_mask = obj_detection_loss(chained_heads[..., -1], is_on_truth)
        # adaptive multi-task weighting
        adaptive_mse = .5*mse_E/(self.sigma_E**2) + .5*mse_X/(self.sigma_X**2) + \
                       .5*mse_Y/(self.sigma_Y**2) + torch.log(self.sigma_E*self.sigma_X*self.sigma_Y) + \
                       is_on_mask/(self.sigma_C**2) + torch.log(self.sigma_C)
        loss = adaptive_mse.float().to(chained_heads.device)
        return loss, self.sigma_E, self.sigma_X, self.sigma_Y, self.sigma_C
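For reference, the total loss returned above is essentially

$$
L = \frac{\mathrm{MSE}_E}{2\sigma_E^2} + \frac{\mathrm{MSE}_X}{2\sigma_X^2} + \frac{\mathrm{MSE}_Y}{2\sigma_Y^2} + \log(\sigma_E\sigma_X\sigma_Y) + \frac{L_{\mathrm{focal}}}{\sigma_C^2} + \log\sigma_C
$$

with the four sigmas learned jointly with the network weights (adaptive multi-task weighting).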
This is my “translation” to PyTorch Lightning:
class LightningML(pl.LightningModule):
    def __init__(self, hparams, max_overlap=5, yolo_x_size=6, yolo_y_size=6,
                 sigma_E=[1.0], sigma_X=[1.0], sigma_Y=[1.0], sigma_C=[1.0]):
        super(LightningML, self).__init__()
        # network layers and hyperparameter scans
        gnn_flts = hparams.edge_flts
        self.master_flts = hparams.conv_flts
        kNN = hparams.kNN
        aggr = hparams.agrr
        self.conv1 = DynamicEdgeConv(MLP(hparams, [2 * 4, gnn_flts, gnn_flts]), k=kNN, aggr=aggr)
        self.conv2 = DynamicEdgeConv(MLP(hparams, [2 * gnn_flts, gnn_flts, gnn_flts]), k=kNN, aggr=aggr)
        self.conv3 = DynamicEdgeConv(MLP(hparams, [2 * gnn_flts, gnn_flts, gnn_flts]), k=kNN, aggr=aggr)
        self.lin1 = MLP(hparams, [3 * gnn_flts, self.master_flts])
        # variable number of resnet blocks
        self.n_resnet = hparams.n_resnet
        for i in range(self.n_resnet):
            setattr(self, f"resnet_{i}", ResNet_block(hparams, self.master_flts))
        self.head_E = Conv2d_stack(hparams, self.master_flts, max_overlap)
        self.head_x = Conv2d_stack(hparams, self.master_flts, max_overlap)
        self.head_y = Conv2d_stack(hparams, self.master_flts, max_overlap)
        self.head_kappa = Conv2d_stack(hparams, self.master_flts, max_overlap)
        # spurious numerical parameters
        self.yolo_x_size = yolo_x_size
        self.yolo_y_size = yolo_y_size
        self.max_overlap = max_overlap
        # adaptive multi-task loss weights
        self.sigma_E = tnn.Parameter(torch.Tensor(sigma_E))
        self.sigma_X = tnn.Parameter(torch.Tensor(sigma_X))
        self.sigma_Y = tnn.Parameter(torch.Tensor(sigma_Y))
        self.sigma_C = tnn.Parameter(torch.Tensor(sigma_C))
        self.batch_size = 250

    def forward(self, x, batch, bars):
        idxs = torch.arange(bars.size(0))
        batchY = torch.div(idxs, self.yolo_x_size * self.yolo_y_size, rounding_mode='floor')
        batchY = batchY.to(x.device)
        # edgeconvs and skip connection into filter MLP
        x1 = self.conv1(x, batch)
        x2 = self.conv2(x1, batch)
        x3 = self.conv3(x2, batch)
        out_edgeconv_stack = self.lin1(torch.cat([x1, x2, x3], dim=-1))
        # map pixels onto yolo voxels and maxpool
        x_coords = x[:, 1:3]
        cluster = nearest(x=x_coords, y=bars.float(), batch_x=batch, batch_y=batchY)
        x, batch = max_pool_x(cluster, out_edgeconv_stack, batch)
        # format for CNNs
        x = x.view(-1, self.master_flts, self.yolo_y_size, self.yolo_x_size)
        # NN awareness
        for i in range(self.n_resnet):
            x = getattr(self, f"resnet_{i}")(x)
        # heads
        outputs = torch.stack([self.head_x(x).view(-1, self.max_overlap),
                               self.head_y(x).view(-1, self.max_overlap),
                               self.head_E(x).view(-1, self.max_overlap),
                               torch.sigmoid(self.head_kappa(x)).view(-1, self.max_overlap)], dim=-1)
        return outputs

    def training_step(self, batch, batch_index):
        x, batchh, bars = batch.x, batch.batch, batch.bars
        chained_heads = self.forward(x, batchh, bars)
        targets = torch.cat([batch.y])
        overlap_cluster_mask = torch.sum(targets[..., -1], dim=-1) > 0.0  # check dim[?]
        # regression heads
        mse_E = torch.mean((chained_heads[overlap_cluster_mask][..., 2] - targets[overlap_cluster_mask][..., -1])**2)
        mse_X = torch.mean((chained_heads[overlap_cluster_mask][..., 0] - targets[overlap_cluster_mask][..., 0])**2)
        mse_Y = torch.mean((chained_heads[overlap_cluster_mask][..., 1] - targets[overlap_cluster_mask][..., 1])**2)
        # object detection
        is_on_truth = (targets[..., -1] > 0).float()
        obj_detection_loss = FocalLoss(gamma=2, alpha=.25)
        is_on_mask = obj_detection_loss(chained_heads[..., -1], is_on_truth)
        # adaptive multi-task weighting
        adaptive_mse = .5*mse_E/(self.sigma_E**2) + .5*mse_X/(self.sigma_X**2) + \
                       .5*mse_Y/(self.sigma_Y**2) + torch.log(self.sigma_E*self.sigma_X*self.sigma_Y) + \
                       is_on_mask/(self.sigma_C**2) + torch.log(self.sigma_C)
        loss = adaptive_mse.float()
        if self.global_step % 10 == 0:
            torch.cuda.empty_cache()
        return {'loss': loss}

    def train_dataloader(self):
        return DataLoader(train_set, batch_size=self.batch_size)

    def training_step_end(self, outputs):
        self.log('train_loss', outputs['loss'], on_epoch=True, sync_dist=True)
        return outputs

    def training_epoch_end(self, outputs):
        return None

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=hparams.learning_rate,
                                 weight_decay=hparams.lr_weight_decay)


if __name__ == '__main__':
    model = LightningML(hparams)
    start = timeit.default_timer()
    trainer = pl.Trainer(strategy='ddp_find_unused_parameters_false', max_epochs=NUM_EPOCHS,
                         accelerator='gpu', devices=tot_GPUs, precision=16, profiler="simple")
    print("module size (MB) = ", pl.utilities.memory.get_model_size_mb(model))
    trainer.fit(model)
    stop = timeit.default_timer()
    print('Training Time: ', stop - start)
where I basically merged Net and AdaMT_loss into the LightningModule.
Maybe it’s a bit messy, but I can’t see what I’m missing that would justify this increase in memory usage.
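In case it helps with the diagnosis, this is a small check I am planning to add at the top of training_step to see what each process actually receives and allocates (batch.num_graphs is the standard PyG Batch attribute; the print is just illustrative):

def training_step(self, batch, batch_index):
    # hypothetical debug lines, just to compare against the plain-PyTorch run
    if batch_index == 0:
        print(f"rank {self.global_rank}: "
              f"graphs in this batch = {batch.num_graphs}, "
              f"peak GPU mem = {torch.cuda.max_memory_allocated() / 2**20:.0f} MB")
    # ... rest of the step as above ...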
Thank you very much for your help