Hi all,
I have a fairly specific use case. I am experimenting with neural architecture search, and am trying to convert my code from vanilla PyTorch to PL. My algorithm goes something like this:
- Generate a neural net’s Blueprint in the form of a graph.
- Instantiate a Model class object by translating the Blueprint into a sequence of nn.Module components wrapped in a LightningModule.
- Train the Model, using checkpointing to save either the best-seen model (with or without early stopping) or the model at the end of the last epoch, depending on the stage of the algorithm where this training takes place.
I have already implemented all of the above in PL, which lets me use the DDP training strategy with multiple GPUs. Now, occasionally, I will take an existing Blueprint and apply random mutations: for instance, change a layer’s operation or its hyperparameters. When this happens to a Blueprint whose Model has already been trained, I want to retain the trained weights in the unmodified layers. For instance, if the model has four layers and only layer 2 is mutated, I want to keep the trained weights in layers 1, 3, and 4.
To do this, I instantiate the old and new Models, and compare their layers. If they are identically defined, I copy the weights over. The resulting model has pretrained weights in most layers, and random ones in the modified layers. I would now like to save this model to train it later, starting from these inherited weights.
This is where I start struggling, because as far as I can tell, saving a model’s weights with PL can only be done through a Trainer, which itself is not linked to a model until fit/test/predict is called. But as you can see, at that point I haven’t trained the model yet. Also, with DDP, instantiating a Trainer and running a fit/test/predict step is quite slow, so I don’t want to do that for hundreds of models.
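For reference, the file that ModelCheckpoint writes seems to be a plain dict saved with torch, which is essentially what I want to reproduce by hand. A quick inspection sketch (the path is just illustrative):

import os
import torch

# Peek at a checkpoint produced by ModelCheckpoint (path shown only as an example).
ckpt = torch.load(os.path.join(from_path, ntw.filename), map_location="cpu")
# With save_weights_only=True this typically contains entries such as
# 'state_dict', 'hyper_parameters', 'epoch', 'global_step', 'pytorch-lightning_version'.
print(list(ckpt.keys()))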
What I would like is something like this:
During training (this is taken from a custom Evaluator class):
checkpoint_cb = ModelCheckpoint(
    filename=ntw.filename,
    dirpath=to_path,
    mode="max",
    monitor="val_acc",
    every_n_epochs=1,
    save_last=not self.interim_checkpoints,
    save_weights_only=True,
    verbose=True,
)
checkpoint_cb.FILE_EXTENSION = ""
checkpoint_cb.CHECKPOINT_NAME_LAST = ntw.filename
callbacks = [checkpoint_cb]
if self.early_stop:
    early_stop_cb = EarlyStopping(
        monitor="val_acc",
        min_delta=self.hparams["thresh"],
        patience=self.hparams["patience"],
        mode="max",
        log_rank_zero_only=True,
    )
    callbacks.append(early_stop_cb)
if pretrained:
    model = Model.load_from_checkpoint(
        os.path.join(from_path, ntw.filename),
        ntw=ntw,
        data_provider=self.data_provider,
        hparams=self.hparams,
    )
else:
    model = Model(ntw, self.data_provider, hparams=self.hparams)
    model.random_init()
trainer = Trainer(
    default_root_dir=self.to_path,
    accelerator='gpu',
    strategy='ddp',
    callbacks=callbacks,
    max_epochs=max_epochs,
    gradient_clip_val=self.hparams["grad_norm_clip"],
    gradient_clip_algorithm="norm",
    logger=self.logger,
    check_val_every_n_epoch=1,
    num_sanity_val_steps=0,
    enable_model_summary=False,
)
trainer.fit(model)
[...]
And during mutation:
# old_model has already been trained:
old_model = Model.load_from_checkpoint(
    os.path.join(from_path, ntw.filename),
    ntw=old_ntw,
    data_provider=self.data_provider,
    hparams=self.hparams,
)
# Create a new Model from the mutated blueprint
new_model = Model(
    ntw=new_ntw,
    data_provider=self.data_provider,
    hparams=self.hparams,
)
for layer_i in range(len(new_ntw)):
    if new_ntw[layer_i] == old_ntw[layer_i]:
        copy_parameters(old_model, new_model, layer_i)
    else:
        random_init(new_model, layer_i)
< save new_model weights >
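(For completeness, copy_parameters amounts to a per-layer state_dict copy; a simplified sketch, assuming each Model keeps its layers in an nn.ModuleList attribute named layers:)

def copy_parameters(old_model, new_model, layer_i):
    # Copy the trained weights of one layer from the old model into the new one.
    # Assumes both models expose their layers as an nn.ModuleList named `layers`;
    # strict=True makes sure the two layer definitions really match.
    new_model.layers[layer_i].load_state_dict(
        old_model.layers[layer_i].state_dict(), strict=True
    )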
I cannot figure out how to perform the “save new_model weights” step in a format that is identical to that produced by ModelCheckpoint and compatible with LightningModule.load_from_checkpoint(). Possible solutions:
- Since I don’t really care about the model’s or the trainer’s state, I could override ModelCheckpoint to only keep the PyTorch modules’ state_dict (see the sketch after this list), but how will that play with DDP?
- Keep using PL’s mechanics and extend them to the ‘save’ step: instantiate a Trainer, run a validation step (for instance) to attach a model to it, and call trainer.save_checkpoint(). But as I said above, I fear this will be slow if I specify DDP as the trainer’s strategy, and if I don’t, I’m not sure what will happen when I subsequently load this model to train it with DDP.
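Concretely, for the first option I have something like this in mind (just a sketch; I’m assuming that load_from_checkpoint() only really needs the 'state_dict' entry, plus 'hyper_parameters' and 'pytorch-lightning_version', and that this would be written from a single process / rank zero):

import os
import torch
import pytorch_lightning as pl

# Hand-rolled minimal checkpoint in (what I believe is) the layout that
# load_from_checkpoint() expects; no trainer/optimizer state is kept.
checkpoint = {
    "pytorch-lightning_version": pl.__version__,
    "state_dict": new_model.state_dict(),
    "hyper_parameters": dict(new_model.hparams),
}
torch.save(checkpoint, os.path.join(to_path, ntw.filename))

# Later, loading it the same way as a ModelCheckpoint file:
model = Model.load_from_checkpoint(
    os.path.join(to_path, ntw.filename),
    ntw=new_ntw,
    data_provider=self.data_provider,
    hparams=self.hparams,
)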
I would appreciate any ideas!
Thanks in advance