Data collate_fn makes training process super slow!

Dear friends, this is my custom dataset:

class CustomDataset(Dataset):
...
def tokenizer(self,...):
    x = tokenize(
    padding = "max_length"
    ...
    )
    y = tokenize(
    padding = "max_length"
    ...
    )
def __getitem__(self,...)
...
    return {
                "input_ids": x.input_ids.squeeze(),
                "attention_mask": x.attention_mask.squeeze(),
                "target_ids": y.input_ids.squeeze(),
                "target_mask": y.attention_mask.squeeze(),
    }

and the train_dataloader function in my model:

dataset = CustomDataset(...)
return DataLoader(
    dataset=...,
    batchsize=...,
    ....
)

This code is probably working fine. However, I want to customize a Dataloader a little bit using dynamic padding and add additional collate_fn to the Dataloader. Then I remove padding = “max_length” in CustomDataset, and add the collate_fn as follows:

class DataCollatorForSeq2Seq:
    tokenizer: PreTrainedTokenizerBase
    model: Optional[Any] = None
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pt"

    def __call__(self,features,return_tensors=None):
        # print("___"*30)
        # print("Feature in COLLATOR:")
        # print(features)
        start = time.time()
        if return_tensors is None:
            return_tensors = self.return_tensors
        
        target_ids = [feature["target_ids"] for feature in features] if "target_ids" in features[0].keys() else None
        # print(len(target_ids))
        if target_ids is not None:
            max_target_id_length = max(len(l) for l in target_ids)
            if self.pad_to_multiple_of is not None:
                max_target_id_length = (
                    (max_target_id_length + self.pad_to_multiple_of - 1)
                    // self.pad_to_multiple_of
                    * self.pad_to_multiple_of
                )
            # print(max_target_id_length)
            padding_side = self.tokenizer.padding_side
            for feature in features:
                remainder = [self.label_pad_token_id] * (max_target_id_length - len(feature["target_ids"]))
                remainder_attetion_mask = [0] * (max_target_id_length - len(feature["target_ids"]))
                if isinstance(feature["target_ids"], list):
                    feature["target_ids"] = (
                        feature["target_ids"] + remainder if padding_side == "right" else remainder + feature["target_ids"]
                    )
                    feature["target_mask"] = (feature["target_mask"] + remainder_attetion_mask if padding_side == "right" else remainder_attetion_mask + feature["target_mask"])
                elif padding_side == "right":
                    feature["target_ids"] = np.concatenate([feature["target_ids"], remainder]).astype(np.int64)
                    feature["target_mask"] = np.concatenate([feature["target_mask"], remainder_attetion_mask]).astype(np.int64)
                else:
                    feature["target_ids"] = np.concatenate([remainder, feature["target_ids"]]).astype(np.int64)
                    feature["target_mask"] = np.concatenate([remainder, feature["target_mask"]]).astype(np.int64)
        features = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=return_tensors,
        )
        end = time.time()
        print("_"*20)
        print("Time taken for collator: ", end-start)
        return features

The training process works without bugs. However, the execution time for the model to predict 1 batch is super slow, as shown in the following picture:


It takes nearly 230s for 1 batch since the code without the collate_fn function works pretty well with only 0.1s for 1 batch.
I have tried to pass num_workers=0 to DataLoader, but it doesn’t work.
Any comment and explanation would be helpful for me to improve my code. Thank you for your help.