Creating a custom LightningModule for fine-tuning LLMs

I am creating a TextSummaryModel that takes a base model and fine-tunes it.
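
For context, I pass in a Hugging Face seq2seq model (the forward signature below assumes an encoder-decoder). I instantiate the module roughly like this, where t5-base is just a placeholder for my actual base model:

from transformers import AutoModelForSeq2SeqLM

base_model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
textsummarymodel = TextSummaryModel(base_model, epochs=2)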

Here is the code for the model definition:

import lightning as L
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup


class TextSummaryModel(L.LightningModule):
    def __init__(self, model, epochs=2):
        super().__init__()
        self.model = model
        self.epochs = epochs

    def set_model(self, model):
        self.model = model

    def forward(self,
                input_ids,
                attention_mask,
                labels=None,
                decoder_attention_mask=None):
        # The wrapped Hugging Face model computes the loss itself
        # whenever labels are provided.
        outputs = self.model(input_ids=input_ids,
                             attention_mask=attention_mask,
                             labels=labels,
                             decoder_attention_mask=decoder_attention_mask)

        return outputs.loss, outputs.logits

    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        decoder_attention_mask = batch["summary_mask"]

        loss, output = self(input_ids=input_ids,
                            attention_mask=attention_mask,
                            labels=labels,
                            decoder_attention_mask=decoder_attention_mask)

        return loss

    def validation_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        decoder_attention_mask = batch["summary_mask"]

        loss, output = self(input_ids=input_ids,
                            attention_mask=attention_mask,
                            labels=labels,
                            decoder_attention_mask=decoder_attention_mask)
        return loss

    def test_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        # Without labels the wrapped model returns loss=None,
        # so this step currently returns None.
        loss, output = self(input_ids=input_ids,
                            attention_mask=attention_mask)
        return loss


    def configure_optimizers(self):
        optimizer = AdamW(self.model.parameters(), lr=0.0001)
        # total_documents (training steps per epoch) is defined
        # elsewhere in the notebook.
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=0,
            num_training_steps=self.epochs * total_documents)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler}
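
For reference, training runs with a standard Trainer and a ModelCheckpoint callback, along these lines (the dirpath and dataloader names here are placeholders):

from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(dirpath="checkpoints/", save_top_k=1)
trainer = L.Trainer(max_epochs=2, callbacks=[checkpoint_callback])
trainer.fit(textsummarymodel, train_dataloader, val_dataloader)

Training itself finishes fine and the checkpoint file is written.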

When I try to load the model from the best checkpoint like this:

textsummarymodel = TextSummaryModel.load_from_checkpoint(
    trainer.checkpoint_callback.best_model_path
)
textsummarymodel.freeze()

I get this error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/home/verma.shi/LLM/LitArt/notebooks/TextSummarizationPOC.ipynb Cell 6 line 1
----> 1 textsummarymodel = TextSummaryModel.load_from_checkpoint(
      2     trainer.checkpoint_callback.best_model_path
      3 )
      4 textsummarymodel.freeze()

File /work/LitArt/verma/capstone/lib/python3.11/site-packages/lightning/pytorch/utilities/model_helpers.py:125, in _restricted_classmethod_impl.__get__.<locals>.wrapper(*args, **kwargs)
    120 if instance is not None and not is_scripting:
    121     raise TypeError(
    122         f"The classmethod `{cls.__name__}.{self.method.__name__}` cannot be called on an instance."
    123         " Please call it on the class type and make sure the return value is used."
    124     )
--> 125 return self.method(cls, *args, **kwargs)

File /work/LitArt/verma/capstone/lib/python3.11/site-packages/lightning/pytorch/core/module.py:1581, in LightningModule.load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)
   1492 @_restricted_classmethod
   1493 def load_from_checkpoint(
   1494     cls,
   (...)
   1499     **kwargs: Any,
   1500 ) -> Self:
   1501     r"""Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments
   1502     passed to ``__init__``  in the checkpoint under ``"hyper_parameters"``.
...
--> 158 obj = cls(**_cls_kwargs)
    160 if isinstance(obj, pl.LightningDataModule):
    161     if obj.__class__.__qualname__ in checkpoint:

TypeError: TextSummaryModel.__init__() missing 1 required positional argument: 'model'
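
From the traceback, load_from_checkpoint re-creates the module with cls(**kwargs), using the hyperparameters stored in the checkpoint. Since my __init__ never calls self.save_hyperparameters(), I assume nothing was stored, so the required model argument is missing at load time. My understanding is that load_from_checkpoint forwards any extra kwargs to __init__, so something like this should work (base_model being the same architecture I fine-tuned):

base_model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
textsummarymodel = TextSummaryModel.load_from_checkpoint(
    trainer.checkpoint_callback.best_model_path,
    model=base_model,
)
textsummarymodel.freeze()

Is this the right way to restore a LightningModule whose __init__ takes the model object itself, or should I be saving and restoring the model differently?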