Question about how Fabric models are saved/loaded

Hi again Fabric team. Slowly been getting used to working with this library, and loving what you guys have done with it more every time I use it. It seems intuitively-written code is the rarest element in the universe but you guys are sitting on a huge cache of it somehow lol. Anyway…
I’m sure this is covered in the docs somewhere and I’m just blind but I can’t seem to find it. Suppose I’m training a model across 8 GPUs with DDP through Fabric. Ideally what I would like to do is, on each epoch e, check whether the val loss on e is lower than the minimum vloss seen so far during training. If so, get the state dict in a form which I can load into a standard vanilla PyTorch nn.Module later, store it as a variable, then at the end of training, save that state dict using pickle.dump. I know this is a little unorthodox, but if it’s possible, it just makes my life a little easier due to some idiosyncrasies with the setup of this code that I’m currently integrating Fabric into.
Failing that, I am happy to use Fabric.save( … ) or whatever built-in methods there are to do this - the important thing here is just that I need to be able to take whatever’s been saved and load it later as something that can be used to set the weights of a standard PyTorch Module to do inference with. I don’t need any fancy stuff like the state of the optimizer or any other checkpoint information etc etc - all I need is the raw state dict in a form that can be used with an nn.Module outside of Fabric later.
Forgive me if this is in some really obvious place in the docs, I swear I saw something about doing this last time I was looking through them but I cannot for the life of me find it now.
Thanks in advance, keep up the awesome work guys :]

@Clion

Thanks for the kind words, glad you find it useful!
That’s an excellent question. For the record, here are the docs for saving and loading in Fabric. They are quite minimal, so I think we should take your feedback into account to make that page better.

Let me walk you through your use case.
First of all, it’s important to know that there is almost no magic happening under the hood in Fabric.save. For most strategies (single device, DDP) it essentially does this:

if fabric.global_rank == 0:
    torch.save(state, filename)
fabric.barrier()

There are some optimizations for other strategies like DeepSpeed and FSDP that make the saving of huge models more efficient. In any case, Fabric.save() and Fabric.load() are completely optional.

Ideally what I would like to do is, on each epoch e, check whether the val loss on e is lower than the minimum vloss seen so far during training. If so, get the state dict in a form which I can load into a standard vanilla PyTorch nn.Module later, store it as a variable, then at the end of training, save that state dict using pickle.dump.

Let’s do that:

model = MyPyTorchModel()
state = {"model": model}
...
if val_loss < best_val_loss:
    fabric.save(path, state)
    best_val_loss = val_loss

# later on in the same script, load state dict in-place:
fabric.load(path, state)

# Alternatively:
checkpoint = torch.load(path)
model.load_state_dict(checkpoint["model"])

The equivalent code without calling the Fabric methods would be:

model = MyPyTorchModel()
...
if val_loss < best_val_loss:
    if fabric.global_rank == 0:
        # Save the state dict rather than the module object itself,
        # so that load_state_dict() below receives what it expects.
        torch.save({"model": model.state_dict()}, path)
    fabric.barrier()
    best_val_loss = val_loss

# Later
checkpoint = torch.load(path)
model.load_state_dict(checkpoint["model"])
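For the inference side, here is a minimal sketch of loading such a checkpoint into a plain nn.Module with no Fabric import at all. The `build_model` helper and its architecture are hypothetical stand-ins for your own model class, and the checkpoint is written to a temporary file only so the snippet runs on its own. `map_location="cpu"` is there on the assumption that the inference machine may not have the GPUs used for training.

```python
import os
import tempfile

import torch
from torch import nn


def build_model():
    # Hypothetical architecture; substitute your own nn.Module class.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))


# Stand-in for the training side: write a checkpoint holding only the state dict.
trained = build_model()
path = os.path.join(tempfile.mkdtemp(), "best.ckpt")
torch.save({"model": trained.state_dict()}, path)

# Inference side: plain PyTorch only.
checkpoint = torch.load(path, map_location="cpu")  # remap GPU tensors to CPU
model = build_model()
model.load_state_dict(checkpoint["model"])
model.eval()  # switch layers such as dropout/batchnorm to inference behaviour

x = torch.randn(2, 4)
with torch.no_grad():
    out = model(x)
```

The freshly built model should produce the same outputs as the trained one for the same input, which is a quick sanity check worth running once after the first real checkpoint is written.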

the important thing here is just that I need to be able to take whatever’s been saved and load it later as something that can be used to set the weights of a standard PyTorch Module to do inference with. I don’t need any fancy stuff like the state of the optimizer or any other checkpoint information etc etc - all I need is the raw state dict in a form that can be used with an nn.Module outside of Fabric later.

With the above code I hope I made it clearer that:
a) Fabric saves a regular torch-pickle file that contains only the objects you put in, nothing more.
b) You can load the checkpoint with or without Fabric and there are no dependencies on Fabric in the checkpoint format.
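For the in-memory variant from the original question (keep the best weights in a variable and write them out once with pickle.dump at the end), a minimal sketch in plain PyTorch might look like the following. The losses and the "training" step are made up for illustration, and an in-memory buffer stands in for the output file. The one thing to watch is that state_dict() returns references to the live tensors, so the best weights need to be copied or they will keep changing as training continues. In a DDP run you would presumably do the final dump on rank 0 only.

```python
import io
import pickle

import torch
from torch import nn

model = nn.Linear(4, 2)  # stand-in for the real network
best_val_loss = float("inf")
best_state = None

# Hypothetical validation losses, one per epoch, for illustration only.
for val_loss in [0.9, 0.5, 0.7]:
    with torch.no_grad():
        model.weight.add_(1.0)  # stand-in for one epoch of training

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # Copy every tensor to CPU so later training steps cannot modify the snapshot.
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

# End of training: dump the best snapshot (a file opened with "wb" works the same way).
buffer = io.BytesIO()
pickle.dump(best_state, buffer)

# Later, possibly in another process: load into a fresh vanilla module.
buffer.seek(0)
fresh = nn.Linear(4, 2)
fresh.load_state_dict(pickle.load(buffer))
```

If the model has been passed through a setup call, one way to sidestep any question about key prefixes is to keep a reference to the original nn.Module and call state_dict() on that; printing the keys once is a cheap way to confirm they match what the plain module expects.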

Would adding a section to the docs for loading a Fabric model into non-Fabric code base help?

cheers

Thanks, @awaelchli - very helpful answer, clarifies the issue for me 100%.

Regarding what’s in the docs, here’s my perspective as a novice Fabric user. Correct me if I’m wrong, but I think Fabric’s primary purpose is to make the distributed training of absolutely enormous models easier, so fair warning here, my use case has nothing to do with that…except for the ‘distributed training’ part. I’m training a pretty small model (about 550k parameters iirc), but I have an utterly enormous (and growing) amount of data, so iterating through it on a single GPU was just taking forever. I had never done distributed nn training before, so as I began looking into the process, it looked like Fabric was the most straightforward drop-in solution to make that work - and so far I’ve certainly had to spend less time on making it work for me than I likely would have had to if I were directly using Torch’s DDP interface.
I think there are probably quite a few people out there who are in a position similar to mine - with access to cheap multi-GPU compute time, there’s really not much of a reason to not go distributed if you have a sufficiently large dataset, even if you’re not using a super-heavyweight model. Fabric as an interface wrapper around DDP really makes this very easy. So far, the only thing that’s caused me any real headache at all is determining exactly where, as you put it, ‘magic’ happens inside Fabric. For cases such as mine where what I’m looking for is an as-out-of-the-box-as-possible solution to turn a serial training process into a distributed one, then take the model produced and use it as a drop-in replacement for a serially-trained model, Fabric is a very easy-to-use solution, with the sole exception that it’s not made entirely clear where it introduces incompatibilities with the serial code it replaces.

Which is just a very verbose way of saying “yes” to “Would adding a section to the docs for loading a Fabric model into non-Fabric code base help?” lol. Think what you guys have built could, with very minimal additions/changes (mainly to the docs), be used as a sort of data-distributed training black box for small models where users can feed in data exactly as they would to a serial process, not really have to worry about what’s going on under the hood, and be able to expect an output that they can use with [approx.] 0 changes in place of one that’s been trained in a typical serial setting…I mean, it’s basically this already anyway lol. I suspect that the number of users out there who would love to make use of a pipeline like this is uh…very large.

Anyway thanks again for the reply, very helpful :+1:

Uh actually, one final clarification, just to be absolutely sure I’ve got this right. Say I’m currently doing the following:

aNet = ANet( ... ) #Standard PyTorch nn.Module
opt  = torch.optim.Adam( aNet.parameters(), ... )
nGPU = torch.cuda.device_count()
Fab  = Fabric( accelerator='cuda', devices=nGPU, strategy='ddp', ... )
Fab.launch()

aNet,opt = Fab.setup( aNet,opt ) 
#variable aNet now points ^here, i.e. not to the aNet stored in state={...}?

...

if val_loss < best_val_loss:
    if Fab.global_rank == 0:
        torch.save( state, path )
        #So then is what's saved ^here still the un-updated aNet?
    Fab.barrier()
    best_val_loss = val_loss

Is the model that goes in state = {"model": model} the one declared with the standard ANet(...) constructor, or the one output by Fabric.setup( ... )?

Since your solution puts the declaration of state = {"model": model} right after MyPyTorchModel(), I’d guess the answer is the first one. This is interesting to me, as I would have thought that upon re-assigning aNet to the first output of Fabric.setup, and the Fabric training process updating the parameters of this new aNet object that this variable now points to, the thing that the aNet living in state = { ... } points to would not have its parameters updated. Where’s the error in my reasoning here?

You don’t have to worry about it. In either case, Fabric.save() will take care of unwrapping the model. So you will always get the pure PyTorch model state in the checkpoint dictionary. So you can save a checkpoint before or after calling fabric.setup(). It will result in the same checkpoint.
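On the "where's the error in my reasoning" part: a sketch of why the module stored in state still sees the updated weights, assuming the wrapper holds a reference to your module rather than a copy (which is my understanding of DDP-style wrapping). The `Wrapper` class below is a hypothetical stand-in written for this example, not Fabric's actual wrapper class, and the snippet uses plain PyTorch on CPU.

```python
import torch
from torch import nn


class Wrapper(nn.Module):
    """Hypothetical stand-in for the object a setup call hands back."""

    def __init__(self, inner):
        super().__init__()
        self.inner = inner  # a reference to the same module, not a copy

    def forward(self, x):
        return self.inner(x)


original = nn.Linear(3, 1)
state = {"model": original}  # declared before wrapping
before = original.weight.detach().clone()

# Rebinding a name to the wrapper does not change what state["model"] points to.
wrapped = Wrapper(original)
opt = torch.optim.SGD(wrapped.parameters(), lr=0.1)

# One training step through the wrapper.
loss = wrapped(torch.ones(2, 3)).sum()
loss.backward()
opt.step()

same_object = wrapped.inner is state["model"]
weights_changed = not torch.equal(state["model"].weight, before)
```

Both names end up referring to the same parameter tensors, so training through the wrapper updates the module held in state as well. Under that assumption, only the variable was rebound; the underlying parameters were never duplicated.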