How do I load weights from one LightningModule into a different LightningModule?

I defined two different LightningModules, T5Finetuner and RankT5Finetuner. Both classes inherit from LightningModule and implement their respective training_step, validation_step, and so on.
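For context, the two classes look roughly like this (a simplified sketch: the real training_step/validation_step bodies are omitted, and the wrapped models are my custom T5 subclasses rather than the stock class used here):

import pytorch_lightning as pl
from transformers import T5ForConditionalGeneration

class T5Finetuner(pl.LightningModule):
    def __init__(self, model_name):
        super().__init__()
        self.save_hyperparameters()
        # in my real code this is ModifiedT5ForConditionalGeneration
        self.T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained(model_name)

    def training_step(self, batch, batch_idx):
        ...  # standard seq2seq fine-tuning loss

class RankT5Finetuner(pl.LightningModule):
    def __init__(self, model_name):
        super().__init__()
        self.save_hyperparameters()
        # in my real code this is RankT5ForConditionalGeneration
        self.T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained(model_name)

    def training_step(self, batch, batch_idx):
        ...  # ranking loss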

I want to train T5Finetuner first and then use the resulting model's weights to initialize RankT5Finetuner for a second round of training. However, I do not know how to accomplish this.

As shown in the following printout, I have compared the two Finetuner-trained models: they have different class names but the same structure.

T5Finetuner(
  (T5ForConditionalGeneration): ModifiedT5ForConditionalGeneration(
    (shared)
    (encoder)
    (decoder)
    (lm_head)
  )
)
RankT5Finetuner(
  (T5ForConditionalGeneration): RankT5ForConditionalGeneration(
    (shared)
    (encoder)
    (decoder)
    (lm_head)
  )
)

I have tried the following solutions.

First:

model = RankT5Finetuner.load_from_checkpoint('trained_by_T5Finetuner.ckpt')

This does not work, because the checkpoint saved by T5Finetuner contains T5Finetuner's hyperparameters and cannot be read directly by RankT5Finetuner.

Second:

model = torch.load('trained_by_T5Finetuner.ckpt')
new_model = RankT5Finetuner.load_state_dict(model['state_dict'])

This does not work either, and the following error is reported:

Traceback (most recent call last):
  File "/root/.pycharm_helpers/pydev/pydevconsole.py", line 364, in runcode
    coro = func()
  File "<input>", line 1, in <module>
TypeError: load_state_dict() missing 1 required positional argument: 'state_dict'

This problem has been bothering me for a long time; I hope someone can help me. :pleading_face: :pleading_face:

Hi @yczhangnaxin

This was almost working:

model = torch.load('trained_by_T5Finetuner.ckpt')
new_model = RankT5Finetuner.load_state_dict(model['state_dict'])

load_state_dict is an instance method, so you need to construct the model first and then call it on the instance. I think you should change it to:

import torch

checkpoint = torch.load('trained_by_T5Finetuner.ckpt')

new_model = RankT5Finetuner()  # <- pass in hyperparameters here
new_model.load_state_dict(checkpoint['state_dict'])

This should work if I don’t have a typo. This of course assumes your model definition in both classes is the same (the layers have the same names).
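If a few layer names do not line up between the two classes, one option (a sketch, not specific to your models) is to load non-strictly and inspect what was skipped:

# strict=False skips keys that don't match instead of raising an error
missing, unexpected = new_model.load_state_dict(checkpoint['state_dict'], strict=False)
print('missing keys:', missing)
print('unexpected keys:', unexpected)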

Hi @awaelchli, thanks for the advice. I followed your code, but it still reports the same error:

Traceback (most recent call last):
  File "/root/.pycharm_helpers/pydev/pydevconsole.py", line 364, in runcode
    coro = func()
  File "<input>", line 1, in <module>
TypeError: load_state_dict() missing 1 required positional argument: 'state_dict'

Hi @awaelchli, thank you for all your help. I found a solution that I hope will help others with the same problem. Loading the checkpoint through T5Finetuner.load_from_checkpoint first restores the trained model together with its saved hyperparameters; its plain state_dict can then be loaded into a freshly constructed RankT5Finetuner:

# restore the trained T5Finetuner (its hyperparameters are read from the checkpoint)
trained_model = T5Finetuner.load_from_checkpoint('trained_by_T5Finetuner.ckpt')
checkpoint_weights = trained_model.state_dict()
# build a fresh RankT5Finetuner and load the matching weights into it
new_model = RankT5Finetuner('t5-small')
new_model.load_state_dict(checkpoint_weights)
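After this, the warm-started module can go straight into a fresh Trainer for the RankT5 stage (a sketch; the Trainer settings and rank_train_loader are placeholders):

import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=3)  # settings are illustrative
trainer.fit(new_model, train_dataloaders=rank_train_loader)  # rank_train_loader is hypothetical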