I am creating a TextSummaryModel (a LightningModule) that wraps a base seq2seq model and fine-tunes it for text summarization.
Here is the model definition:
import lightning as L
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup


class TextSummaryModel(L.LightningModule):
    def __init__(self, model, epochs=2):
        super().__init__()
        self.model = model
        self.epochs = epochs

    def set_model(self, model):
        self.model = model

    def forward(self, input_ids, attention_mask,
                labels=None, decoder_attention_mask=None):
        # The wrapped Hugging Face seq2seq model computes the loss itself
        # whenever labels are provided.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask,
        )
        return outputs.loss, outputs.logits

    def training_step(self, batch, batch_idx):
        loss, output = self(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
            decoder_attention_mask=batch["summary_mask"],
        )
        return loss

    def validation_step(self, batch, batch_idx):
        loss, output = self(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
            decoder_attention_mask=batch["summary_mask"],
        )
        return loss

    def test_step(self, batch, batch_idx):
        # No labels are passed here, so the returned loss is None.
        loss, output = self(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )
        return loss

    def configure_optimizers(self):
        optimizer = AdamW(self.model.parameters(), lr=1e-4)
        # total_documents is defined elsewhere in the notebook.
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=0,
            num_training_steps=self.epochs * total_documents,
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler}
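For context, this is roughly how I build and train the module (a minimal sketch; t5-small stands in here for my actual base model, and data_module is the LightningDataModule defined elsewhere in the notebook):

from transformers import T5ForConditionalGeneration
import lightning as L

# Any Hugging Face seq2seq model works as the base; T5 shown as an example.
base_model = T5ForConditionalGeneration.from_pretrained("t5-small")
summary_module = TextSummaryModel(model=base_model, epochs=2)

trainer = L.Trainer(max_epochs=2)
trainer.fit(summary_module, datamodule=data_module)  # data_module defined elsewhere

Training itself runs fine and checkpoints are written.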
After training, when I try to load the best checkpoint like this:
textsummarymodel = TextSummaryModel.load_from_checkpoint(
    trainer.checkpoint_callback.best_model_path
)
textsummarymodel.freeze()
I get this error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/home/verma.shi/LLM/LitArt/notebooks/TextSummarizationPOC.ipynb Cell 6 line 1
----> 1 textsummarymodel = TextSummaryModel.load_from_checkpoint(
2 trainer.checkpoint_callback.best_model_path
3 )
4 textsummarymodel.freeze()
File /work/LitArt/verma/capstone/lib/python3.11/site-packages/lightning/pytorch/utilities/model_helpers.py:125, in _restricted_classmethod_impl.__get__.<locals>.wrapper(*args, **kwargs)
120 if instance is not None and not is_scripting:
121 raise TypeError(
122 f"The classmethod `{cls.__name__}.{self.method.__name__}` cannot be called on an instance."
123 " Please call it on the class type and make sure the return value is used."
124 )
--> 125 return self.method(cls, *args, **kwargs)
File /work/LitArt/verma/capstone/lib/python3.11/site-packages/lightning/pytorch/core/module.py:1581, in LightningModule.load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)
1492 @_restricted_classmethod
1493 def load_from_checkpoint(
1494 cls,
(...)
1499 **kwargs: Any,
1500 ) -> Self:
1501 r"""Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments
1502 passed to ``__init__`` in the checkpoint under ``"hyper_parameters"``.
...
--> 158 obj = cls(**_cls_kwargs)
160 if isinstance(obj, pl.LightningDataModule):
161 if obj.__class__.__qualname__ in checkpoint:
TypeError: TextSummaryModel.__init__() missing 1 required positional argument: 'model'
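Reading the traceback, load_from_checkpoint ends up calling cls(**_cls_kwargs) with only whatever was stored under "hyper_parameters" in the checkpoint, and my __init__ never calls self.save_hyperparameters(), so nothing was saved for model. My guess at a workaround (untested sketch; base_model is the same kind of Hugging Face model I trained with):

# Extra kwargs to load_from_checkpoint are forwarded to __init__,
# so supplying the base model again should satisfy the missing argument.
base_model = T5ForConditionalGeneration.from_pretrained("t5-small")  # same as training

textsummarymodel = TextSummaryModel.load_from_checkpoint(
    trainer.checkpoint_callback.best_model_path,
    model=base_model,  # fills the missing positional argument
)
textsummarymodel.freeze()

Presumably adding self.save_hyperparameters(ignore=["model"]) in __init__ would also let Lightning restore epochs without trying to store the whole model as a hyperparameter. Is passing model= at load time the intended approach, or is there a cleaner pattern for a LightningModule whose __init__ takes a full nn.Module?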