Translate PyTorch Geometric script to PyTorch Lightning

Dear experts,

I would like to “translate” a PyTorch Geometric script to PyTorch Lightning (to be used with 4 GPUs, or even more across several nodes).
The issue is that with PyTorch Lightning the GPU memory fills up much more than it does with the plain PyTorch script.

For example, this is what happens with PyTorch:

[0] Tesla V100-SXM2-16GB | 51'C,  59 % |  5937 / 16160 MB | dzulian1(4885M) root(8M)
[1] Tesla V100-SXM2-16GB | 68'C,  60 % |  4924 / 16160 MB | dzulian1(4913M) root(8M)
[2] Tesla V100-SXM2-16GB | 59'C,  56 % |  4892 / 16160 MB | dzulian1(4881M) root(8M)
[3] Tesla V100-SXM2-16GB | 67'C,  61 % |  4924 / 16160 MB | dzulian1(4913M) root(8M)

and this is with PyTorch Lightning (same number of training events, same batch size):

[0] Tesla V100-SXM2-16GB | 54'C, 100 % |  9223 / 16160 MB | dzulian1(8171M) root(8M)
[1] Tesla V100-SXM2-16GB | 70'C,  99 % |  7364 / 16160 MB | dzulian1(7353M) root(8M)
[2] Tesla V100-SXM2-16GB | 62'C,  85 % |  8182 / 16160 MB | dzulian1(8171M) root(8M)
[3] Tesla V100-SXM2-16GB | 68'C,  72 % |  8166 / 16160 MB | dzulian1(8155M) root(8M)

The host RAM usage is also roughly doubled with PyTorch Lightning.
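(In case it helps, here is a small snippet one could drop into both scripts to get a per-process figure for this; it is only a sketch and assumes psutil is installed.)

import os
import psutil

def host_rss_mb():
    # resident set size of the current process, in MB (hypothetical helper)
    return psutil.Process(os.getpid()).memory_info().rss / 2**20

print(f"host RSS: {host_rss_mb():.0f} MB")  # with DDP, every spawned rank has its own RSS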

This is the PyTorch code:

if MULTI_GPU:
    train_loader = DataListLoader(train_set, **loader_params)
    val_loader = DataListLoader(val_set, **loader_params)
else:
    train_loader = DataLoader(train_set, **loader_params)
    val_loader = DataLoader(val_set, **loader_params)

ttl = AdaMT_loss(model=DataParallel(Net(hparams=hparams, max_overlap=5).float()), sigma_E=[1.0], sigma_X=[1.0], sigma_Y=[1.0], sigma_C=[1.0])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ttl.model.to(device)
ttl.to(device)
optimizer = torch.optim.AdamW(ttl.parameters(),lr=hparams.learning_rate,
        weight_decay=hparams.lr_weight_decay)
scheduler = ReduceLROnPlateau(optimizer, 'min', verbose=True, patience=20)

print("\n\n")
print("================ MODEL SUMMARY =====================")
print(ttl.model)
print()
model_trainable_params = sum(p.numel() for p in ttl.model.parameters() if p.requires_grad)
print(f">>>> NUMBER OF TRAINABLE PARAMETERS: {model_trainable_params}")
print("====================================================\n")


#++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
#                   === START TRAINING LOOP ===
start = timeit.default_timer()

master_train_loss = []
master_val_loss = []


master_sigma_E = []
master_sigma_X = []
master_sigma_Y = []
master_sigma_C = []
master_percHist = []

# train
block_evolution_counter = 0.0
for epoch in tqdm(range(NUM_EPOCHS+1)):
    
    train_loss_container = []
    val_loss_container = []
    std_E_container = []
    std_X_container = []
    std_Y_container = []
    std_C_container = []
    std_hits_container = []
    
    
    ttl.model.train()
    
    for data_list in train_loader:
        optimizer.zero_grad()
        total_loss, std_E, std_X, std_Y, std_C = ttl(data_list)
        total_loss.to(device)
        
        total_loss.backward()
        train_loss_container.append(total_loss.item() ) 
       
        std_E_container.append(std_E.item() ) 
        std_X_container.append(std_X.item() ) 
        std_Y_container.append(std_Y.item() ) 
        std_C_container.append(std_C.item() ) 
        
        optimizer.step() 
    # validate    
    with torch.no_grad():
        ttl.model.eval()
        for data_list in val_loader:
            val_loss, val_std_E, val_std_X, val_std_Y, val_std_C = ttl(data_list)
            val_loss.to(device)
            val_loss_container.append(val_loss.item())
            
    
    if  np.mean(val_loss_container) > np.mean(train_loss_container):
        block_evolution_counter+=1.0
    else:
        block_evolution_counter=0
    
    scheduler.step(np.mean(val_loss_container))
    master_train_loss.append(np.mean(train_loss_container))
    master_sigma_E.append(np.mean(std_E_container))
    master_sigma_X.append(np.mean(std_X_container))
    master_sigma_Y.append(np.mean(std_Y_container))
    master_sigma_C.append(np.mean(std_C_container))
    master_percHist.append(np.mean(std_hits_container))
    master_val_loss.append(np.mean(val_loss_container))
    
    print(f"TRAIN LOSS: {np.mean(train_loss_container)}; VAL LOSS: {np.mean(val_loss_container)};")
    
    if epoch%1==0:
        torch.save({
            'epoch': epoch,
            'model_state_dict': ttl.model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': total_loss.item(),
            'val_loss': val_loss.item(),
            'master_train_loss' : master_train_loss,
            'master_val_loss' : master_val_loss,
            'master_sigma_E' : master_sigma_E,
            'master_sigma_X' : master_sigma_X,
            'master_sigma_Y' : master_sigma_Y,
            'master_sigma_C' : master_sigma_C
            }, f"./training_checkpoints/CLIrun_checkpoint_{epoch}.pt")

    if block_evolution_counter==5.0:
        print("threshold achieved. === STOP ===")
        break

with Net and AdaMT_loss defined in the following way:

class Net(torch.nn.Module):
    """CaloGNN"""
    def __init__(self, hparams, max_overlap=5, yolo_x_size=6, yolo_y_size=6):
        super(Net, self).__init__()
        # network layers and hyperparameter scans
        gnn_flts = hparams.edge_flts
        self.master_flts = hparams.conv_flts
        kNN = hparams.kNN
        aggr = hparams.agrr
        self.conv1 = DynamicEdgeConv(MLP(hparams, [2 * 4, gnn_flts, gnn_flts]),  k=kNN, aggr=aggr)
        self.conv2 = DynamicEdgeConv(MLP(hparams, [2 * gnn_flts, gnn_flts, gnn_flts]), k=kNN, aggr=aggr)
        self.conv3 = DynamicEdgeConv(MLP(hparams, [2 * gnn_flts, gnn_flts, gnn_flts]), k=kNN, aggr=aggr)
        self.lin1 = MLP(hparams, [3 * gnn_flts, self.master_flts])
        
        # variable number of resnet blocks
        self.n_resnet = hparams.n_resnet
        for i in range(self.n_resnet):
            setattr(self, f"resnet_{i}", ResNet_block(hparams, self.master_flts) )
        
        self.head_E = Conv2d_stack(hparams, self.master_flts, max_overlap)
        self.head_x = Conv2d_stack(hparams, self.master_flts, max_overlap)
        self.head_y = Conv2d_stack(hparams, self.master_flts, max_overlap)
        self.head_kappa = Conv2d_stack(hparams, self.master_flts, max_overlap)

        # spurious numerical parameters
        self.yolo_x_size = yolo_x_size
        self.yolo_y_size = yolo_y_size
        self.max_overlap = max_overlap
   
    def forward(self, data):
        #print(type(data))
        x0, batch, bars = data.x, data.batch, data.bars
        idxs = torch.arange(bars.size(0))
        batchY = idxs//(self.yolo_x_size*self.yolo_y_size)
        batchY = batchY.to(x0.device)

        # edgeconvs and skip connection into filter MLP
        x1 = self.conv1(x0, batch)
        x2 = self.conv2(x1, batch)
        x3 = self.conv3(x2, batch)
        out_edgeconv_stack = self.lin1(torch.cat([x1, x2, x3], dim=-1))
     
        # map pixels onto yolo voxels and maxpool
        x_coords = x0[:,1:3]
        cluster = nearest(x=x_coords, y=bars.float(), batch_x=batch, batch_y = batchY)
        x, batch = max_pool_x(cluster, out_edgeconv_stack, data.batch)
        
        # format for CNNs
        x = x.view(-1, self.master_flts, self.yolo_y_size, self.yolo_x_size) 

        # NN awareness
        for i in range(self.n_resnet):
            x = getattr(self, f"resnet_{i}")(x)
        
        # heads 
        outputs = torch.stack([self.head_x(x).view(-1, self.max_overlap), \
                               self.head_y(x).view(-1, self.max_overlap), \
                               self.head_E(x).view(-1, self.max_overlap), \
                               torch.sigmoid( self.head_kappa(x) ).view(-1, self.max_overlap)], dim=-1)
        return outputs  

and

class AdaMT_loss(tnn.Module):
    def __init__(self, model, sigma_E, sigma_X, sigma_Y, sigma_C):
        super(AdaMT_loss, self).__init__()
        self.model = model
        self.sigma_E = tnn.Parameter(torch.Tensor(sigma_E))
        self.sigma_X = tnn.Parameter(torch.Tensor(sigma_X))
        self.sigma_Y = tnn.Parameter(torch.Tensor(sigma_Y))
        self.sigma_C = tnn.Parameter(torch.Tensor(sigma_C))
    
    def forward(self, data_list):

        #print(type(data_list))
        #print(data_list)
        chained_heads = self.model(data_list)
        targets = torch.cat([data.y for data in data_list]).to(chained_heads.device)
        overlap_cluster_mask = torch.sum(targets[...,-1], dim=-1)>0.0 #check dim[?]

    #         cellwise_assignment = list(Munkres_wrapper(distance_matrix(chained_heads[...,0:3], targets, overlap_size=max_overlap).tolist()).numpy())
    #         chained_heads = chained_heads[torch.arange(chained_heads.size(0)).unsqueeze(1), cellwise_assignment]
    
        # regression heads
        mse_E = torch.mean((chained_heads[overlap_cluster_mask][...,2] - targets[overlap_cluster_mask][...,-1])**2).to(chained_heads.device)
        mse_X = torch.mean((chained_heads[overlap_cluster_mask][...,0] - targets[overlap_cluster_mask][...,0])**2).to(chained_heads.device)
        mse_Y = torch.mean((chained_heads[overlap_cluster_mask][...,1] - targets[overlap_cluster_mask][...,1])**2).to(chained_heads.device)
        
        # object detection
        is_on_truth = (targets[...,-1]>0).float()
        obj_detection_loss = FocalLoss(gamma=2, alpha=.25).to(chained_heads.device)
        is_on_mask = obj_detection_loss(chained_heads[...,-1], is_on_truth)        
        
        adaptive_mse =  .5*mse_E/(self.sigma_E**2) + .5*mse_X/(self.sigma_X**2) + \
                        .5*mse_Y/(self.sigma_Y**2) + torch.log(self.sigma_E*self.sigma_X*self.sigma_Y) + \
                        is_on_mask/(self.sigma_C**2) + torch.log(self.sigma_C)        

        loss = adaptive_mse.float().to(chained_heads.device)  
   
        return loss, self.sigma_E, self.sigma_X, self.sigma_Y, self.sigma_C

This is my “translation” to PyTorch Lightning:

class LightningML(pl.LightningModule):
    def __init__(self, hparams, max_overlap=5, yolo_x_size=6, yolo_y_size=6, sigma_E=[1.0], sigma_X=[1.0], sigma_Y=[1.0], sigma_C=[1.0]):
        super(LightningML, self).__init__()
        # network layers and hyperparameter scans   
        gnn_flts = hparams.edge_flts
        self.master_flts = hparams.conv_flts
        kNN = hparams.kNN
        aggr = hparams.agrr
        self.conv1 = DynamicEdgeConv(MLP(hparams, [2 * 4, gnn_flts, gnn_flts]),  k=kNN, aggr=aggr)
        self.conv2 = DynamicEdgeConv(MLP(hparams, [2 * gnn_flts, gnn_flts, gnn_flts]), k=kNN, aggr=aggr)
        self.conv3 = DynamicEdgeConv(MLP(hparams, [2 * gnn_flts, gnn_flts, gnn_flts]), k=kNN, aggr=aggr)
        self.lin1 = MLP(hparams, [3 * gnn_flts, self.master_flts])
        
        # variable number of resnet blocks
        self.n_resnet = hparams.n_resnet
        for i in range(self.n_resnet):
            setattr(self, f"resnet_{i}", ResNet_block(hparams, self.master_flts) )
        
        self.head_E = Conv2d_stack(hparams, self.master_flts, max_overlap)
        self.head_x = Conv2d_stack(hparams, self.master_flts, max_overlap)
        self.head_y = Conv2d_stack(hparams, self.master_flts, max_overlap)
        self.head_kappa = Conv2d_stack(hparams, self.master_flts, max_overlap)

        # spurious numerical parameters
        self.yolo_x_size = yolo_x_size
        self.yolo_y_size = yolo_y_size
        self.max_overlap = max_overlap

        self.sigma_E = tnn.Parameter(torch.Tensor(sigma_E))
        self.sigma_X = tnn.Parameter(torch.Tensor(sigma_X))
        self.sigma_Y = tnn.Parameter(torch.Tensor(sigma_Y))
        self.sigma_C = tnn.Parameter(torch.Tensor(sigma_C))

        self.batch_size = 250

    
    def forward(self, x, batch, bars):
        idxs = torch.arange(bars.size(0))
        batchY = torch.div(idxs,self.yolo_x_size*self.yolo_y_size, rounding_mode='floor')
        batchY = batchY.to(x.device)

        # edgeconvs and skip connection into filter MLP
        x1 = self.conv1(x, batch)
        x2 = self.conv2(x1, batch)
        x3 = self.conv3(x2, batch)
        out_edgeconv_stack = self.lin1(torch.cat([x1, x2, x3], dim=-1))
     
        # map pixels onto yolo voxels and maxpool
        x_coords = x[:,1:3]
        cluster = nearest(x=x_coords, y=bars.float(), batch_x=batch, batch_y = batchY)
        print(cluster)
        x, batch = max_pool_x(cluster, out_edgeconv_stack, batch)
        
        # format for CNNs
        x = x.view(-1, self.master_flts, self.yolo_y_size, self.yolo_x_size) 

        # NN awareness
        for i in range(self.n_resnet):
            x = getattr(self, f"resnet_{i}")(x)
        
        # heads 
        outputs = torch.stack([self.head_x(x).view(-1, self.max_overlap), \
                               self.head_y(x).view(-1, self.max_overlap), \
                               self.head_E(x).view(-1, self.max_overlap), \
                               torch.sigmoid( self.head_kappa(x) ).view(-1, self.max_overlap)], dim=-1)
        return outputs  
    
    def training_step(self,batch,batch_index):

        x, batchh, bars = batch.x, batch.batch, batch.bars

        chained_heads = self.forward(x, batchh, bars)
        targets = torch.cat([batch.y])
        overlap_cluster_mask = torch.sum(targets[...,-1], dim=-1)>0.0 #check dim[?]

        # regression heads
        mse_E = torch.mean((chained_heads[overlap_cluster_mask][...,2] - targets[overlap_cluster_mask][...,-1])**2)
        mse_X = torch.mean((chained_heads[overlap_cluster_mask][...,0] - targets[overlap_cluster_mask][...,0])**2)
        mse_Y = torch.mean((chained_heads[overlap_cluster_mask][...,1] - targets[overlap_cluster_mask][...,1])**2)
        
        # object detection
        is_on_truth = (targets[...,-1]>0).float()
        obj_detection_loss = FocalLoss(gamma=2, alpha=.25)
        is_on_mask = obj_detection_loss(chained_heads[...,-1], is_on_truth)        
        
        adaptive_mse =  .5*mse_E/(self.sigma_E**2) + .5*mse_X/(self.sigma_X**2) + \
                        .5*mse_Y/(self.sigma_Y**2) + torch.log(self.sigma_E*self.sigma_X*self.sigma_Y) + \
                        is_on_mask/(self.sigma_C**2) + torch.log(self.sigma_C)        

        loss = adaptive_mse.float()
        if self.global_step % 10 == 0:
            torch.cuda.empty_cache()  

        return {'loss' : loss}
    
    def train_dataloader(self):
        return DataLoader(train_set, batch_size=self.batch_size)

    def training_step_end(self,outputs):
        self.log('train_loss',outputs['loss'],on_epoch=True,sync_dist=True)
        return outputs
    
    def training_epoch_end(self,outputs):
        return None

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(),lr=hparams.learning_rate,weight_decay=hparams.lr_weight_decay)
    

if __name__ == '__main__':

    model = LightningML(hparams)


    start = timeit.default_timer()
    trainer = pl.Trainer(strategy='ddp_find_unused_parameters_false', max_epochs=NUM_EPOCHS, 
        accelerator='gpu', devices=tot_GPUs, precision=16, profiler="simple")
    
    print("module size (MB)= ", pl.utilities.memory.get_model_size_mb(model))

    trainer.fit(model)

    stop = timeit.default_timer()
    print('Training Time: ', stop - start) 

where I basically merged Net and AdaMT_loss into the Lightning module.
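To be sure the two versions really carry the same weights, this is the quick parameter-count check I would use (a sketch that assumes Net, AdaMT_loss, LightningML and hparams are all available in one session; the DataParallel wrapper adds no parameters of its own, so it is omitted here):

def count_trainable(module):
    # same counting as in the MODEL SUMMARY print of the PyTorch script
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

pt_version = AdaMT_loss(model=Net(hparams=hparams, max_overlap=5).float(),
                        sigma_E=[1.0], sigma_X=[1.0], sigma_Y=[1.0], sigma_C=[1.0])
pl_version = LightningML(hparams)
print(count_trainable(pt_version), count_trainable(pl_version))  # these should match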

Maybe it’s a bit messy, but I can’t see what I’m missing that would justify this increase in memory usage.
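In case it is useful for the discussion, this is the kind of per-rank memory logging I could add to LightningML to see where the extra allocation shows up (just a sketch, assuming a recent Lightning version where this hook takes (outputs, batch, batch_idx); note that memory_allocated only counts tensors held by PyTorch, not the CUDA context itself):

    def on_train_batch_end(self, outputs, batch, batch_idx):
        # hypothetical hook: log current and peak CUDA memory per rank, in MB
        if batch_idx % 50 == 0:
            self.log("cuda_mem_alloc_MB", torch.cuda.memory_allocated() / 2**20)
            self.log("cuda_mem_peak_MB", torch.cuda.max_memory_allocated() / 2**20)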

Thank you very much for your help :smiley: