Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0!

Hi!
I am trying to run the code below with the DDP strategy on 2 GPUs.

# using a base model here for an unsupervised trial
model_id = "meta-llama/Llama-2-7b-chat-hf"
cache_dir='.'

local_path = model_id
local_save_path = f"{cache_dir}/{local_path}"

import json
import os
import random

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import CosineAnnealingLR

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from datasets import Dataset as HFDataset  # aliased so it does not shadow torch.utils.data.Dataset

dataset = []

with open('data.json') as f:
    lines = f.readlines()
    random.shuffle(lines)

    for line in lines:
        json_line = json.loads(line)
        input_text = json_line['input']
        output_text = json_line['output']

        # the [INST] ... [/INST] formatting is applied later in prepare_dataset();
        # only the raw fields are stored here
        dataset.append({'input': input_text, 'output': output_text})



hf_dataset = HFDataset.from_list(dataset)


class TextDataset(Dataset):
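    """Wraps the tokenized encodings and, per item, builds next-token labels with an EOS marker at the end of the response span."""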
    def __init__(self, encodings, response_lengths, input_lengths):
        self.encodings = encodings
        self.response_lengths = response_lengths
        self.input_lengths = input_lengths

    def __getitem__(self, idx):
        if isinstance(idx, int):
            # print(f"__getitem__ called with index {idx}")
            item = {key: val[idx].clone().detach() for key, val in self.encodings.items()}
            response_start_position = self.input_lengths[idx]
            response_end_position = self.input_lengths[idx] + self.response_lengths[idx]
        elif isinstance(idx, list):
            # print(f"__getitem__ called with list {idx}")
            item = {key: torch.stack([val[i].clone().detach() for i in idx]) for key, val in self.encodings.items()}
            response_start_position = [self.input_lengths[i] for i in idx]
            response_end_position = [self.input_lengths[i] + self.response_lengths[i] for i in idx]

        item["labels"] = item["input_ids"].clone()

        item["labels"][:-1] = item["input_ids"][1:]

        # Replace the token after the response with an EOS token
        item["labels"][response_end_position - 1] = 2

        # Replace the token after the response with an 1 in the loss mask

        return item

    def __len__(self):
        return len(self.encodings["input_ids"])


data_length = 500


def prepare_dataset(dataset, tokenizer):
    # Define the roles and markers
    B_SYS, E_SYS = "<<SYS>>", "<</SYS>>"
    B_INST, E_INST = "[INST]", "[/INST]"

    formatted_dataset = dataset.map(
        lambda x: {
            "input_text": (
                # f"{B_INST} {B_SYS}{system_prompt.strip()}{E_SYS}{x['prompt'].strip()} {E_INST}\n\n"
                f"{B_INST} {x['input'].strip()} {E_INST}"
                f"{x['output'].strip()}"  # the EOS token is appended later in TextDataset
            ),
            "response_text": f"{x['output'].strip()}",  # the EOS token is appended later in TextDataset
        }
    )

    # Tokenize the datasets
    encodings = tokenizer([dialogue["input_text"] for dialogue in formatted_dataset], truncation=True, padding=True, max_length=data_length, return_tensors='pt', add_special_tokens=True)

    # Tokenize the response one by one without padding and special tokens for the purpose of calculating length
    response_lengths = [len(tokenizer.encode(dialogue["response_text"], truncation=True, max_length=data_length, padding=False, add_special_tokens=False)) for dialogue in formatted_dataset]

    # Tokenize the input one by one without padding and with the initial special token for the purpose of calculating length
    total_lengths = [len(tokenizer.encode(dialogue["input_text"], truncation=True, max_length=data_length, padding=False, add_special_tokens=True)) for dialogue in formatted_dataset]
    input_lengths = [total_length - response_length for total_length, response_length in zip(total_lengths, response_lengths)]

    # Create TextDataset
    text_dataset = TextDataset(encodings, response_lengths, input_lengths)

    return text_dataset


tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

train_dataset = prepare_dataset(hf_dataset, tokenizer)

wandb_logger = WandbLogger(name='personal_assistant2', project='MyProject2')

total_steps = (len(dataset) // 4) * 10
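# total_steps approximates (dataset size // batch size of 4) * 10 epochs and is used as T_max for the cosine LR schedule below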

class ConversationalModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
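        # load the base model from the local cache in 4-bit NF4 precision via bitsandbytes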
        self.model = AutoModelForCausalLM.from_pretrained(
            local_save_path,
            quantization_config=BitsAndBytesConfig(
                load_in_4bit=True,
                device_map='auto',
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            ),
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.tokenizer.pad_token = self.tokenizer.eos_token


    def forward(self, input_ids, attention_mask=None, labels=None):
        output = self.model(input_ids, attention_mask=attention_mask, labels=labels)
        return output

    def training_step(self, batch, batch_idx):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']
        outputs = self(input_ids, attention_mask=attention_mask, labels=labels)
        # log the loss so that ModelCheckpoint and the scheduler config can monitor 'train_loss'
        self.log('train_loss', outputs.loss, on_step=True, prog_bar=True)
        return outputs.loss
    
    def configure_optimizers(self):

        optimizer = torch.optim.AdamW(self.model.parameters(), lr=2e-3)
        scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=0)
        
        # In PyTorch Lightning, schedulers require a dict for configuration
        scheduler_config = {
            'scheduler': scheduler,
            'interval': 'step',  # or 'epoch' for epoch-wise scheduling
            'frequency': 1,  # determines how often the scheduler is updated
            'monitor': 'train_loss',  # metric to monitor for other types of schedulers
            'strict': True,  # whether to enforce that the `monitor` value is available
        }
        
        return [optimizer], [scheduler_config]

class DataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.num_workers = os.cpu_count()
        print(f"num_workers set to {self.num_workers}")

    def setup(self, stage=None) -> None:
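        # NOTE: with strategy='ddp', Lightning wraps this loader's sampler in a DistributedSampler by default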
        self._dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)

    def train_dataloader(self):
        return self._dataloader

lr_monitor = LearningRateMonitor(logging_interval='step')

checkpoint_callback = ModelCheckpoint(
    monitor='train_loss',
    dirpath='checkpoints/',
    filename='best-checkpoint',
    save_top_k=1,
    mode='min',
    verbose=True,
)

if __name__ == '__main__':
    trainer = Trainer(
        max_epochs=10,
        precision='32',
        callbacks=[checkpoint_callback, lr_monitor],
        logger=wandb_logger,
        accelerator="gpu",
        devices=2,
        num_nodes=1,
        strategy='ddp'
    )
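    # DDP launches one process per GPU; each process is expected to keep its full model replica on a single device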
    
    dm = DataModule()
    model = ConversationalModel()
    trainer.fit(model, dm)

I save the code to a file (train.py) and run it from a Jupyter notebook via

python train.py

and I get the following error:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

Any help would be greatly appreciated!