Is there a recommended way to save a model mid-epoch with Fabric when training on multiple nodes/devices (i.e. save every n training steps instead of only at the end of an epoch)? I'm currently trying the following:
```python
if fabric.global_rank == 0:
    if num_steps % 100 == 0:
        state = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
            "step": num_steps,
        }
        fabric.save(checkpoint_path, state)
```
However, training seems to hang right after the checkpoint is saved. Am I doing something wrong here?
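From the Fabric docs I suspect `fabric.save()` is a collective call that every process has to make, so the `global_rank == 0` guard may be what leaves the other ranks stuck. Below is a self-contained sketch of what I think the loop should look like instead. The toy model/data, the `strategy`/`devices` arguments, and the checkpoint path are just placeholders to make it runnable; is this the intended pattern?

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from lightning.fabric import Fabric

# Toy stand-ins for my real model and data, only so the example runs end to end.
model = nn.Linear(32, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
dataset = TensorDataset(torch.randn(4096, 32), torch.randn(4096, 1))

# Two CPU processes just to exercise the distributed code path.
fabric = Fabric(accelerator="cpu", devices=2, strategy="ddp")
fabric.launch()

model, optimizer = fabric.setup(model, optimizer)
dataloader = fabric.setup_dataloaders(DataLoader(dataset, batch_size=8))

checkpoint_path = "checkpoint.ckpt"
num_steps = 0

for epoch in range(2):
    for x, y in dataloader:
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        fabric.backward(loss)
        optimizer.step()
        num_steps += 1

        # Mid-epoch checkpoint: every rank reaches fabric.save() here, on the
        # assumption that it is a collective call and Fabric decides internally
        # which process actually writes the file (no global_rank guard).
        if num_steps % 100 == 0:
            state = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "epoch": epoch,
                "step": num_steps,
            }
            fabric.save(checkpoint_path, state)
```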