How to combine test and validation results from all GPUs in Fabric

Hi, I cannot find the answer to this in the docs. When I train on multiple GPUs and then run validation and test, the performance metrics are printed twice, once per GPU, and each copy is computed only on that GPU's shard of the data. How can I combine the results from all GPUs so the reported metrics cover the whole validation and test sets?
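
From skimming the docs I suspect the answer involves fabric.all_reduce, i.e. summing the per-rank counters before computing the final numbers. A minimal sketch of what I have in mind, assuming running_loss, correct and total are the per-rank accumulators from my validate function below:

# Hypothetical aggregation step: all_reduce with reduce_op="sum" should
# add up the per-rank counters across all processes before the division
counts = torch.tensor([running_loss, correct, total], device=fabric.device, dtype=torch.float64)
counts = fabric.all_reduce(counts, reduce_op="sum")
global_loss = (counts[0] / counts[2]).item()
global_accuracy = (counts[1] / counts[2] * 100.0).item()

Is that the intended pattern? Here is my full script: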

import os
import torch
import torchvision as tv
import lightning as L
import argparse
from torch.utils.data import DataLoader, Subset
import numpy as np

def load_data(data_dir, train=True):
    """ Load the CIFAR-10 train or test split with the standard normalization """
    transform = tv.transforms.Compose([
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    return tv.datasets.CIFAR10(root=data_dir, train=train, download=True, transform=transform)

def split_train_val(dataset, val_size, seed=42):
    """ Randomly split the dataset into training and validation subsets.
    The fixed seed keeps the split identical on every rank; without it,
    each process would draw its own split. """
    rng = np.random.default_rng(seed)
    indices = rng.permutation(len(dataset))
    split = int(np.floor(val_size * len(dataset)))
    train_indices, val_indices = indices[split:], indices[:split]
    return Subset(dataset, train_indices), Subset(dataset, val_indices)

def train(fabric, model, optimizer, train_dataloader, val_dataloader, num_epochs, validate_every_n_epoch):
    for epoch in range(num_epochs):
        model.train()  # validate() switches the model to eval mode, so reset it every epoch
        running_loss = 0.0
        correct = 0
        total = 0
        for i, (inputs, labels) in enumerate(train_dataloader):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = torch.nn.functional.cross_entropy(outputs, labels)
            fabric.backward(loss)  # Use Fabric's backward method for distributed backward pass
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

        train_loss = running_loss / total
        train_accuracy = correct / total * 100.0
        print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%")

        if epoch % validate_every_n_epoch == 0:
            val_loss, val_accuracy = validate(fabric, model, val_dataloader)
            print(f"Validation - Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.2f}%")

def validate(fabric, model, dataloader):
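    # NOTE: the counters below are accumulated per process, i.e. they only cover this GPU's shard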
    model.eval()
    correct = 0
    total = 0
    running_loss = 0.0
    criterion = torch.nn.CrossEntropyLoss()
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloader):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    accuracy = correct / total * 100.0
    return running_loss / total, accuracy

def test(fabric, model, dataloader):
    # The metric computation is identical to validate(), so reuse it
    loss, accuracy = validate(fabric, model, dataloader)
    print(f"Test Loss: {loss:.4f}, Test Accuracy: {accuracy:.2f}%")

def main():
    # Argument parsing
    parser = argparse.ArgumentParser(description="Train a CIFAR-10 model")
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training and validation')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of workers for DataLoader')
    parser.add_argument('--num_epochs', type=int, default=1, help='Number of epochs for training')
    parser.add_argument('--learning_rate', type=float, default=0.01, help='Initial learning rate')
    parser.add_argument('--devices', type=int, default=1, help='Number of GPUs for training')
    parser.add_argument('--num_nodes', type=int, default=1, help='Number of compute nodes')
    parser.add_argument('--strategy', type=str, default='auto', help='Data parallel strategy')
    parser.add_argument('--val_size', type=float, default=0.1, help='Fraction of training data to use for validation')
    parser.add_argument('--validate_every_n_epoch', type=int, default=1, help='Validate every n epochs')
    args = parser.parse_args()

    data_dir = os.getenv("DATASETS_ROOT", "./data")
    dataset = load_data(data_dir)

    # Randomly split dataset into training and validation subsets
    train_dataset, val_dataset = split_train_val(dataset, args.val_size)

    # Initialize Fabric with distributed settings
    fabric = L.Fabric(
        accelerator="gpu",  # Specify the type of accelerator to use
        devices=args.devices,  # Number of devices per node
        num_nodes=args.num_nodes,  # Number of nodes
        strategy=args.strategy,  # Distributed training strategy
        precision="16-mixed"
    )
    fabric.launch()  # Launch the Fabric environment

    # Initialize ResNet18 model and optimizer with Fabric
    model = tv.models.resnet18()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
    # Setup model and optimizer for distributed training
    model, optimizer = fabric.setup(model, optimizer)

    # Plain dataloaders; fabric.setup_dataloaders() injects a DistributedSampler
    # automatically when running on multiple devices, so no manual sampler is needed
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                  num_workers=args.num_workers)
    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                                num_workers=args.num_workers)

    # Prepare dataloaders for distributed training (each process gets its own shard)
    train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)

    # Run training loop with validation
    train(fabric, model, optimizer, train_dataloader, 
          val_dataloader, args.num_epochs, 
          args.validate_every_n_epoch)

    # Test on the held-out CIFAR-10 test split
    test_dataset = load_data(data_dir, train=False)
    test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                                 num_workers=args.num_workers)
    # setup_dataloaders shards the test set across processes as well
    test_dataloader = fabric.setup_dataloaders(test_dataloader)

    # Test on unseen data
    test(fabric, model, test_dataloader)

if __name__ == "__main__":
    main()
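
For completeness, I also looked at TorchMetrics, which as I understand it synchronizes metric state across processes when compute() is called. A sketch of how validate() might look with it, assuming torchmetrics is installed (MulticlassAccuracy and MeanMetric are its classes; average="micro" should match the manual accuracy above):

import torchmetrics

def validate_tm(fabric, model, dataloader):
    # Metric state lives on the local device; compute() syncs across ranks
    accuracy = torchmetrics.classification.MulticlassAccuracy(num_classes=10, average="micro").to(fabric.device)
    mean_loss = torchmetrics.MeanMetric().to(fabric.device)
    model.eval()
    with torch.no_grad():
        for inputs, labels in dataloader:
            outputs = model(inputs)
            loss = torch.nn.functional.cross_entropy(outputs, labels)
            mean_loss.update(loss, weight=inputs.size(0))  # weight by batch size
            accuracy.update(outputs, labels)  # logits are argmax'd internally
    return mean_loss.compute().item(), accuracy.compute().item() * 100.0

I have not verified that this gives the same numbers as a manual all_reduce, though.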

With two devices, the original script prints:

Epoch 1/1 - Train Loss: 2.1072, Train Accuracy: 34.96%
Epoch 1/1 - Train Loss: 2.0986, Train Accuracy: 34.78%
Validation - Loss: 1.5283, Accuracy: 45.60%
Validation - Loss: 1.4918, Accuracy: 44.92%
Files already downloaded and verified
Files already downloaded and verified
Test Loss: 1.5673, Test Accuracy: 43.10%
Test Loss: 1.5509, Test Accuracy: 43.26%
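
I know I could hide the duplicate lines with fabric.print, which only prints on global rank 0:

fabric.print(f"Test Loss: {running_loss / total:.4f}, Test Accuracy: {accuracy:.2f}%")  # rank 0 only

but that just silences one GPU rather than combining anything; the printed numbers would still only cover half the data. What is the recommended way to reduce validation and test metrics across all GPUs in Fabric?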