I use PyTorch Lightning with TorchElastic. My training function looks like this:
import pytorch_lightning as pl

# Each train() call runs as a single worker (one process)
def train(config: InputConfig):
    checkpoint_callback = pl.callbacks.ModelCheckpoint(...)
    module = MyLightningModule(config)
    trainer = pl.Trainer(num_nodes=..., gpus=..., checkpoint_callback=checkpoint_callback, ...)
    trainer.fit(module)
    return Results(...)
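For concreteness, the checkpoint callback is set up roughly like this — the monitor and save_top_k values below are illustrative placeholders rather than my exact arguments. Since I don't pass a fixed dirpath, Lightning derives the checkpoint directory from the logger's run/version directory, which is part of why I don't know the path up front:

# Illustrative placeholder arguments, not my exact config.
# With no dirpath set, the checkpoint directory is derived from the logger
# (roughly default_root_dir/<logger name>/<version>/checkpoints), so it varies per run.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor="val_loss",
    save_top_k=1,
)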
I launch this train() function with TorchElastic as follows:
import torchelastic.distributed.local_launch as pet
from uuid import uuid4
...

def elastic_train(config: InputConfig):
    lc = pet.LocalLaunchConfig(
        # Assuming devgpu testing, min = max nodes = 1
        min_nodes=1,
        max_nodes=1,
        nproc_per_node=config.trainer.gpus if config.trainer.gpus else 1,
        # run_id just has to be globally unique
        run_id=f"your_run_identifier_{uuid4()}",
        # for fault tolerance; for testing set it to 0 (no fault tolerance)
        max_restarts=0,
        function_start_method="spawn",
    )
    # The "train" function is called inside the elastic_launch
    ret = pet.elastic_launch(lc, fn=train)(config)
    print(f"Rank 0 results = {ret[0]}")

def main(config: InputConfig):
    elastic_train(config)
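For context on the restart behavior (this is my understanding of the TorchElastic semantics, so treat it as an assumption): once fault tolerance is enabled, a worker failure causes the agent to tear down the workers and re-invoke train() in fresh processes, up to max_restarts times, so any resume logic would have to live inside train() itself. The eventual config would look something like this (max_restarts=3 is just an assumed example value):

lc = pet.LocalLaunchConfig(
    min_nodes=1,
    max_nodes=1,
    nproc_per_node=config.trainer.gpus if config.trainer.gpus else 1,
    run_id=f"your_run_identifier_{uuid4()}",
    # Assumed example value: each restart calls train() again from scratch
    max_restarts=3,
    function_start_method="spawn",
)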
Sometimes training fails, and when it does I'd like to resume from the latest checkpoint. However, because my train function is wrapped by TorchElastic, I don't know the path to the latest checkpoint ahead of time — or even whether a valid checkpoint exists at all — so resume_from_checkpoint may not work for this use case.
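One direction I've been considering is to resolve the latest checkpoint inside train() itself at startup and pass it to resume_from_checkpoint only if one exists. A rough sketch, under the assumption that all checkpoints land somewhere under a single known root; CHECKPOINT_ROOT and find_latest_checkpoint are made-up names for illustration:

import glob
import os

# Hypothetical root directory that the checkpoint callback writes under.
CHECKPOINT_ROOT = "/some/shared/checkpoints"

def find_latest_checkpoint(root: str):
    """Return the most recently modified *.ckpt under root, or None if none exist yet."""
    candidates = glob.glob(os.path.join(root, "**", "*.ckpt"), recursive=True)
    return max(candidates, key=os.path.getmtime) if candidates else None

# Inside train(), something like:
#   latest = find_latest_checkpoint(CHECKPOINT_ROOT)
#   trainer = pl.Trainer(..., resume_from_checkpoint=latest)  # None -> start fresh

This feels brittle (stale checkpoints from unrelated runs, races between workers), so I'm wondering whether there's a better-supported way to resolve the latest checkpoint in this setup.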