Hi, I can't find the answer to this in the docs. When I train on multiple GPUs with Fabric and then run the test, I get the performance metrics twice, once per GPU. How can I make sure the results from all GPUs are combined, so the reported numbers cover the whole validation and test set?
import os
import argparse

import numpy as np
import torch
import torchvision as tv
import lightning as L
from torch.utils.data import DataLoader, Subset, DistributedSampler
def load_data(data_dir, train=True):
    transform = tv.transforms.Compose([
        tv.transforms.ToTensor(),
        tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    return tv.datasets.CIFAR10(root=data_dir, train=train, download=True, transform=transform)
def split_train_val(dataset, val_size, seed=42):
    """Randomly split the dataset into training and validation subsets."""
    # Fixed seed so every rank computes the identical split; with an unseeded
    # shuffle each launched process would get a different (overlapping) split.
    rng = np.random.default_rng(seed)
    indices = rng.permutation(len(dataset))
    split = int(np.floor(val_size * len(dataset)))
    train_indices, val_indices = indices[split:], indices[:split]
    train_subset = Subset(dataset, train_indices)
    val_subset = Subset(dataset, val_indices)
    return train_subset, val_subset
def train(fabric, model, optimizer, train_dataloader, val_dataloader, num_epochs, validate_every_n_epoch):
    for epoch in range(num_epochs):
        model.train()  # validate() switches to eval mode, so reset at every epoch
        running_loss = 0.0
        correct = 0
        total = 0
        for i, (inputs, labels) in enumerate(train_dataloader):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = torch.nn.functional.cross_entropy(outputs, labels)
            fabric.backward(loss)  # Use Fabric's backward method for distributed backward pass
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
        train_loss = running_loss / total
        train_accuracy = correct / total * 100.0
        print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%")
        if epoch % validate_every_n_epoch == 0:
            val_loss, val_accuracy = validate(fabric, model, val_dataloader)
            print(f"Validation - Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.2f}%")
def validate(fabric, model, dataloader):
    model.eval()
    correct = 0
    total = 0
    running_loss = 0.0
    criterion = torch.nn.CrossEntropyLoss()
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloader):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    accuracy = correct / total * 100.0
    return running_loss / total, accuracy
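# NOTE: running_loss, correct, and total in validate()/test() are accumulated per
# process. With a DistributedSampler each rank only iterates over its own shard,
# so every rank reports metrics for roughly half the data, which matches the
# duplicated output shown at the bottom of this post.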
def test(fabric, model, dataloader):
    model.eval()
    correct = 0
    total = 0
    running_loss = 0.0
    criterion = torch.nn.CrossEntropyLoss()
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloader):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    accuracy = correct / total * 100.0
    print(f"Test Loss: {running_loss / total:.4f}, Test Accuracy: {accuracy:.2f}%")
def main():
    # Argument parsing
    parser = argparse.ArgumentParser(description="Train a CIFAR-10 model")
    parser.add_argument('--batch_size', type=int, default=128, help='Batch size for training and validation')
    parser.add_argument('--num_workers', type=int, default=4, help='Number of workers for DataLoader')
    parser.add_argument('--num_epochs', type=int, default=1, help='Number of epochs for training')
    parser.add_argument('--learning_rate', type=float, default=0.01, help='Initial learning rate')
    parser.add_argument('--devices', type=int, default=1, help='Number of GPUs for training')
    parser.add_argument('--num_nodes', type=int, default=1, help='Number of compute nodes')
    parser.add_argument('--strategy', type=str, default='auto', help='Data parallel strategy')
    parser.add_argument('--val_size', type=float, default=0.1, help='Fraction of training data to use for validation')
    parser.add_argument('--validate_every_n_epoch', type=int, default=1, help='Validate every n epochs')
    args = parser.parse_args()

    data_dir = os.getenv("DATASETS_ROOT", "./data")
    dataset = load_data(data_dir)
    # Randomly split dataset into training and validation subsets (same split on every rank)
    train_dataset, val_dataset = split_train_val(dataset, args.val_size)

    # Initialize Fabric with distributed settings
    fabric = L.Fabric(
        accelerator="gpu",         # Type of accelerator to use
        devices=args.devices,      # Number of devices per node
        num_nodes=args.num_nodes,  # Number of nodes
        strategy=args.strategy,    # Distributed training strategy
        precision="16-mixed",
    )
    fabric.launch()  # Launch the Fabric environment

    # Initialize ResNet18 model and optimizer, then set them up for distributed training
    model = tv.models.resnet18()
    optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
    model, optimizer = fabric.setup(model, optimizer)

    # Data loaders with DistributedSampler so each rank gets its own shard
    train_sampler = DistributedSampler(train_dataset)
    val_sampler = DistributedSampler(val_dataset, shuffle=False)
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False,
                                  sampler=train_sampler, num_workers=args.num_workers)
    val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False,
                                sampler=val_sampler, num_workers=args.num_workers)
    # Prepare dataloaders for distributed training
    train_dataloader = fabric.setup_dataloaders(train_dataloader)
    val_dataloader = fabric.setup_dataloaders(val_dataloader)
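    # As far as I can tell, fabric.setup_dataloaders keeps a DistributedSampler
    # that was passed in explicitly and only injects its own when the loader has
    # a plain sampler, so each rank should see a disjoint shard here.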
    # Run training loop with validation
    train(fabric, model, optimizer, train_dataloader,
          val_dataloader, args.num_epochs,
          args.validate_every_n_epoch)

    # Test set: reuse load_data so the normalization matches training
    test_dataset = load_data(data_dir, train=False)
    test_sampler = DistributedSampler(test_dataset, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                                 sampler=test_sampler)
    # Prepare dataloader for distributed training
    test_dataloader = fabric.setup_dataloaders(test_dataloader)

    # Test on unseen data
    test(fabric, model, test_dataloader)

if __name__ == "__main__":
    main()
Running this with two devices prints:
Epoch 1/1 - Train Loss: 2.1072, Train Accuracy: 34.96%
Epoch 1/1 - Train Loss: 2.0986, Train Accuracy: 34.78%
Validation - Loss: 1.5283, Accuracy: 45.60%
Validation - Loss: 1.4918, Accuracy: 44.92%
Files already downloaded and verified
Files already downloaded and verified
Test Loss: 1.5673, Test Accuracy: 43.10%
Test Loss: 1.5509, Test Accuracy: 43.26%
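My guess is that the per-rank sums (running loss, number correct, number of samples) need to be reduced across all processes before the final metrics are computed, and then printed on rank 0 only. Below is a minimal, untested sketch of what I mean for validate(); I'm assuming fabric.all_reduce with reduce_op="sum" and fabric.print behave as described in the Fabric docs:

def validate(fabric, model, dataloader):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            outputs = model(inputs)
            # Sum (not mean) per batch so the counters can be added across ranks
            running_loss += torch.nn.functional.cross_entropy(
                outputs, labels, reduction="sum").item()
            correct += outputs.argmax(1).eq(labels).sum().item()
            total += labels.size(0)
    # Combine the per-rank counters over all processes
    totals = fabric.all_reduce(
        torch.tensor([running_loss, correct, total], device=fabric.device),
        reduce_op="sum",
    )
    loss_sum, correct_sum, total_sum = totals.tolist()
    # fabric.print only emits on global rank 0, so the line appears once
    fabric.print(f"Validation - Loss: {loss_sum / total_sum:.4f}, "
                 f"Accuracy: {correct_sum / total_sum * 100.0:.2f}%")
    return loss_sum / total_sum, correct_sum / total_sum * 100.0

One thing I'm still unsure about: DistributedSampler pads the dataset by repeating samples when its size isn't divisible by the number of ranks, so a few examples would be counted twice in these totals. Is there a recommended way to handle that?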