End all distributed processes after DDP

After running model training or inference with DDP, I want to do some follow-up action such as checking the accuracy or inspecting the model. But if, for example, I have 3 GPUs, DDP starts 3 processes for them, and every action after training or inference finishes is repeated 3 times, which I would like to avoid. Is there any way to stop the extra processes from the Python script itself? I have attached an example script for reference.

import torch.nn.functional as F
from pytorch_lightning import seed_everything, LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping
from torch import nn, optim, rand, sum as tsum, reshape, save
from torch.utils.data import Dataset
from distributed_proxy_sampler import (
    DistributedProxySampler,
)
import torch.utils.data as data
from multiprocessing import active_children
import pandas as pd
import torch
import random
import sys
import string
import os
SAMPLE_DIM = 21000

class CustomDataset(Dataset):
    def __init__(self, samples=42):
        self.dataset = rand(samples, SAMPLE_DIM).cpu().float() * 2 - 1
        letters = string.ascii_lowercase
        random.seed(100)
        self.id = []
        for x in range(samples):
            self.id.append(''.join(random.choice(letters) for i in range(20)))
        
    def __getitem__(self, index):
        return (self.id[index], self.dataset[index], (tsum(self.dataset[index]) > 0).cpu().float())

    def __len__(self):
        return self.dataset.size()[0]

class OurModel(LightningModule):
    def __init__(self):
        super(OurModel, self).__init__()
   
        # Network layers
        self.linear = nn.Linear(SAMPLE_DIM, 2048)
        self.linear2 = nn.Linear(2048, 1)
        self.output = nn.Sigmoid()
        # Hyper-parameters, that we will auto-tune using lightning!
        self.lr = 0.000001
        self.batch_size = 256
        self.num_process = 10
        
    def forward(self, x):
        x = self.linear(x)
        x = self.linear2(x)
        output = self.output(x)
        return reshape(output, (-1,))
    """
    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)

    def train_dataloader(self):
        
        train_ds = CustomDataset(samples=5000)
        sampler = data.RandomSampler(train_ds, replacement=True)
        sampler = DistributedProxySampler(sampler)
        loader = data.DataLoader(
            train_ds,
            batch_size=self.batch_size,
            sampler=sampler,  # use the distributed proxy sampler built above
            num_workers=self.num_process,
            pin_memory=True,
            persistent_workers=True,
        )
        return loader

    def training_step(self, batch, batch_nb):
        id_, x, y = batch
        loss = F.binary_cross_entropy(self(x), y)
        return {'loss': loss, 'log': {'train_loss': loss}}

    def val_dataloader(self):

        val_ds = CustomDataset(samples=1234)
        loader = data.DataLoader(
            val_ds,
            batch_size=self.batch_size,
            num_workers=self.num_process,
            pin_memory=True,
            shuffle=False,
            persistent_workers=True,
        )
        return loader
        
    def validation_step(self, batch, batch_nb):
        id_, x, y = batch
        loss = F.binary_cross_entropy(self(x), y)
        return {'val_loss': loss, 'log': {'val_loss': loss}}
        
    def validation_epoch_end(self, outputs):
        val_loss_mean = sum([o['val_loss'] for o in outputs]) / len(outputs)
        # show val_acc in progress bar but only log val_loss
        results = {'progress_bar': {'val_loss': val_loss_mean.item()}, 'log': {'val_loss': val_loss_mean.item()},
                   'val_loss': val_loss_mean.item()}
        print("OUR LR:",self.lr)
        return results
    """
    def predict_dataloader(self):
        val_ds = CustomDataset(samples=100000)
        loader = data.DataLoader(
            val_ds,
            batch_size=self.batch_size,
            num_workers=self.num_process,
            pin_memory=True,
            shuffle=False,
            persistent_workers=True,
        )
        return loader

    
    
    def predict_step(self, batch, batch_idx):
        id_, x, y = batch
        logits = self(x)
        #loss = F.binary_cross_entropy(self(x), y)
        return logits

    def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
        output = {}
        id_, x, y = batch
        output["id"] = id_
        output["prediction"] = outputs.cpu().numpy()
        device = outputs.get_device()
        if device == 0:
            df = pd.DataFrame(output)
            # Append to out.csv, writing the header only once so that
            # pd.read_csv does not count repeated header rows as data.
            df.to_csv("out.csv", mode='a', header=not os.path.exists("out.csv"), index=False)

def check():
    df = pd.read_csv("out.csv")
    print(len(df))
    
if __name__ == '__main__':
    seed_everything(42)
    model = OurModel()
    trainer = Trainer(enable_model_summary=True,
            max_epochs=5,
            detect_anomaly=False,
            auto_lr_find=False,
            devices=-1,
            strategy="ddp",
            auto_select_gpus=False,
            auto_scale_batch_size=False,
            accelerator="auto",
            replace_sampler_ddp=False,
            sync_batchnorm=True,
            benchmark=True)
    
    sd = torch.load("model.pt")
    print(model.load_state_dict(sd, strict=True))
    out = trainer.predict(model)
    check()
    """
    active = active_children()
    
    for child in active:
        child.terminate()
        
    for child in active:
        child.join()
    
    check()
    exit()
    sd = torch.load("model2.pt")
    print(model.load_state_dict(sd, strict=True))
    out = trainer.predict(model)
    save(model.state_dict(), 'model.pt')
    """

In this case, after the prediction/inference is done, it runs the check function 3 times, and I want it to run only once. Is there any way to achieve this?


I want to avoid printing len(df) three times after the predict dataloader finishes; it should only be printed once. Any help would be appreciated, thanks.

How about doing just this:

if trainer.global_rank == 0:
    # do something only on the main process,
    # for example, your check()
    check()

trainer.strategy.barrier()  # all processes meet

# rest of code here

Killing the processes shouldn’t be necessary, but if you really want to, you could do:

if trainer.global_rank > 0:
    exit(0)

Yep, this works, thanks for the help!
Btw, how would you approach a situation where, after training is complete, you want to run another simple checking step or something not training-related? Would you run a subprocess or use the method mentioned above?
Maybe I need to reconfigure my pipeline.

Happy this worked for you 🙂

Personally I prefer to keep these things separated in different scripts, because it’s more practical to run each part individually if I have to, instead of managing the processes.

Btw, in my code above I forgot to add a barrier() after the if-block (edited). It’s needed so that the processes don’t fall out of sync. A caveat is that the processes waiting at the barrier can time out if the process inside the if-condition takes too long. Another reason why maybe this should only be used for simple workloads.

For more complex pipelines, where for example I want to combine model training with model deployment or other training-unrelated tasks, there is the Lightning App framework for coordinating these things easily. There, each component runs separately, so one component can use multiprocessing for training while the others do not.


Hi awaelchli

Thanks for the edit and the detailed response. After working with my script yesterday, I felt it would be more practical to work with them individually.

The reason I wanted to integrate them is that I usually run my models overnight, and afterwards I need to run a post-training routine where I check the model's test results and get metrics such as AUROC or accuracy. Integrating both would make it easier, as I could just wake up and see the test results too. However, I had a few CUDA errors and model checkpointing issues that were interfering with the pipeline, so I decided to run them individually.

I've never used the Lightning App framework; I'll check it out and see if it helps me.

Thanks for the help.
