Saving a LightningModule without a Trainer

Hi all,
I have a fairly specific use case. I am experimenting with neural architecture search, and am trying to convert my code from vanilla PyTorch to PL. My algorithm goes something like this:

  • Generate a neural net’s Blueprint in the form of a graph.
  • Instantiate a Model class object by translating the Blueprint into a sequence of nn.Module components wrapped in a LightningModule.
  • Train the Model, using checkpointing to save either the best-seen model (with or without early stopping) or the model at the end of the last epoch, depending on the stage of the algorithm where this training takes place.
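A minimal sketch of the Blueprint-to-Model translation step, assuming a hypothetical Blueprint format (a linear chain of `(operation, kwargs)` entries rather than a full graph) and a hypothetical `OPS` table. The LightningModule wrapper is replaced by a plain `nn.Module` to keep it short; the point is that each Blueprint entry maps to exactly one module, which is what makes per-layer comparison possible later.

```python
import torch
from torch import nn

# Hypothetical blueprint format: one (operation, kwargs) entry per layer.
# A real blueprint is a graph; a linear chain keeps the sketch short.
Blueprint = list[tuple[str, dict]]

# Hypothetical mapping from operation names to nn.Module constructors.
OPS = {
    "linear": nn.Linear,
    "relu": nn.ReLU,
    "dropout": nn.Dropout,
}


def build_layers(blueprint: Blueprint) -> nn.ModuleList:
    """Translate each blueprint entry into one nn.Module, in order."""
    return nn.ModuleList(OPS[op](**kwargs) for op, kwargs in blueprint)


class Net(nn.Module):
    """Plain nn.Module stand-in for the LightningModule-wrapped Model."""

    def __init__(self, blueprint: Blueprint):
        super().__init__()
        self.blueprint = blueprint
        self.layers = build_layers(blueprint)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Apply the layers sequentially, as in a chain-shaped blueprint.
        for layer in self.layers:
            x = layer(x)
        return x


blueprint = [
    ("linear", {"in_features": 8, "out_features": 16}),
    ("relu", {}),
    ("linear", {"in_features": 16, "out_features": 4}),
]
net = Net(blueprint)
print(net(torch.randn(2, 8)).shape)  # torch.Size([2, 4])
```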

I have already implemented all of the above in PL, which let me use the DDP training strategy across multiple GPUs. Now, occasionally, I take an existing Blueprint and apply random mutations: for instance, changing a layer’s operation or its hyperparameters. When this happens to a Blueprint whose Model has already been trained, I want to retain the trained weights in the unmodified layers. For instance, if the model has four layers and only layer 2 is mutated, I want to keep the trained weights in layers 1, 3, and 4.
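The mutation step can be sketched as follows, assuming a hypothetical `SEARCH_SPACE` that lists the allowed values of each operation's hyperparameters. Exactly one layer changes per call: either its operation is swapped or one hyperparameter is resampled. The original Blueprint is deep-copied so it stays intact for the later weight comparison.

```python
import copy
import random

# Hypothetical search space: allowed values for each operation's hyperparameters.
SEARCH_SPACE = {
    "relu": {},
    "tanh": {},
    "dropout": {"p": [0.1, 0.3, 0.5]},
}


def mutate(blueprint, rng):
    """Return a mutated copy of `blueprint` and the index of the changed layer."""
    new_bp = copy.deepcopy(blueprint)  # leave the original blueprint untouched
    i = rng.randrange(len(new_bp))
    op, kwargs = new_bp[i]
    tunable = [k for k, vals in SEARCH_SPACE[op].items() if len(vals) > 1]
    if tunable and rng.random() < 0.5:
        # Resample one hyperparameter to a value different from the current one.
        key = rng.choice(tunable)
        kwargs[key] = rng.choice([v for v in SEARCH_SPACE[op][key] if v != kwargs.get(key)])
    else:
        # Swap the operation and draw its hyperparameters from the search space.
        new_op = rng.choice([o for o in SEARCH_SPACE if o != op])
        new_kwargs = {k: rng.choice(vals) for k, vals in SEARCH_SPACE[new_op].items()}
        new_bp[i] = (new_op, new_kwargs)
    return new_bp, i


rng = random.Random(0)
old_bp = [("relu", {}), ("dropout", {"p": 0.1}), ("tanh", {})]
new_bp, changed = mutate(old_bp, rng)
print(changed, new_bp[changed])
```

Returning the changed index is optional; comparing the two Blueprints entry by entry recovers the same information.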

To do this, I instantiate the old and new Models, and compare their layers. If they are identically defined, I copy the weights over. The resulting model has pretrained weights in most layers, and random ones in the modified layers. I would now like to save this model to train it later, starting from these inherited weights.
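A sketch of the compare-and-copy step in plain PyTorch, assuming (hypothetically) that both models keep one module per Blueprint entry in an `nn.ModuleList`, and that mutations never insert or remove layers, so entries can be matched by index. Layers whose definition changed are simply skipped: they keep the random initialization they received at construction time.

```python
import torch
from torch import nn


def build(blueprint):
    """Hypothetical builder: one nn.Module per blueprint entry (op name, kwargs)."""
    ops = {"linear": nn.Linear, "relu": nn.ReLU}
    return nn.ModuleList(ops[op](**kwargs) for op, kwargs in blueprint)


def inherit_weights(old_bp, old_layers, new_bp, new_layers):
    """Copy parameters and buffers of every layer whose definition is unchanged.

    Returns the indices of the layers that were copied.
    """
    copied = []
    for i, (old_spec, new_spec) in enumerate(zip(old_bp, new_bp)):
        if old_spec == new_spec:
            # Identical definition, so the state_dict shapes match exactly.
            new_layers[i].load_state_dict(old_layers[i].state_dict())
            copied.append(i)
    return copied


old_bp = [
    ("linear", {"in_features": 4, "out_features": 8}),
    ("relu", {}),
    ("linear", {"in_features": 8, "out_features": 2}),
]
# Mutate only the last layer's output width.
new_bp = [old_bp[0], old_bp[1], ("linear", {"in_features": 8, "out_features": 3})]
old_layers, new_layers = build(old_bp), build(new_bp)
print(inherit_weights(old_bp, old_layers, new_bp, new_layers))  # [0, 1]
```

Using `load_state_dict` per layer (rather than assigning tensors) also copies buffers such as BatchNorm running statistics, if the layers have any.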

This is where I start struggling, because as far as I can tell, saving a model’s weights with PL can only be done through a Trainer, which is not linked to a model until fit/test/predict is called. But at that point, I haven’t trained the new model yet. Also, with DDP, instantiating a Trainer and running a fit/test/predict step is quite slow, so I don’t want to do that for hundreds of models.

What I would like is something like this:

During training (this is taken from a custom Evaluator class):

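import os

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

checkpoint_cb = ModelCheckpoint(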
    filename=ntw.filename,
    dirpath=to_path,
    mode="max",
    monitor="val_acc",
    every_n_epochs=1,
    save_last=not self.interim_checkpoints,
    save_weights_only=True,
    verbose=True,
)
checkpoint_cb.FILE_EXTENSION = ""
checkpoint_cb.CHECKPOINT_NAME_LAST = ntw.filename

callbacks = [checkpoint_cb]

if self.early_stop:
    early_stop_cb = EarlyStopping(
        monitor="val_acc",
        min_delta=self.hparams["thresh"],
        patience=self.hparams["patience"],
        mode="max",
        log_rank_zero_only=True,
    )
    callbacks.append(early_stop_cb)

if pretrained:
    model = Model.load_from_checkpoint(
        os.path.join(from_path, ntw.filename),
        ntw=ntw,
        data_provider=self.data_provider,
        hparams=self.hparams,
    )
else:
    model = Model(ntw, self.data_provider, hparams=self.hparams)
    model.random_init()

trainer = Trainer(
    default_root_dir=self.to_path,
    accelerator='gpu',
    strategy='ddp',
    callbacks=callbacks,
    max_epochs=max_epochs,
    gradient_clip_val=self.hparams["grad_norm_clip"],
    gradient_clip_algorithm="norm",
    logger=self.logger,
    check_val_every_n_epoch=1,
    num_sanity_val_steps=0,
    enable_model_summary=False,
)

trainer.fit(model)
[...]

And during mutation:

# old_model has already been trained:
old_model = Model.load_from_checkpoint(
        os.path.join(from_path, old_ntw.filename),
        ntw=old_ntw,
        data_provider=self.data_provider,
        hparams=self.hparams,
)
# Create a new Model from the mutated blueprint
new_model = Model(
        ntw=new_ntw,
        data_provider=self.data_provider,
        hparams=self.hparams,
)
for layer_i in range(len(new_ntw)):
    if new_ntw[layer_i] == old_ntw[layer_i]:
        copy_parameters(old_model, new_model, layer_i)
    else:
        random_init(new_model, layer_i)

< save new_model weights >

I cannot figure out how to perform the “save new_model weights” step in a format that is identical to that produced by ModelCheckpoint and compatible with LightningModule.load_from_checkpoint(). Possible solutions:

  • Since I don’t really care about the model’s or the trainer’s states, I could override ModelCheckpoint to only keep the PyTorch modules’ state_dict, but how will that play with DDP?
  • Keep using PL’s mechanics and extend them to the ‘save’ step: instantiate a Trainer, run a validation step (for instance) to attach the model to it, and call trainer.save_checkpoint(). But as I said above, I fear this will be slow if I specify DDP as the trainer’s strategy. And if I don’t, I’m not sure what will happen when I subsequently load this model to train it with DDP.

I would appreciate any ideas!
Thanks in advance