After training and validation for the first epoch, my Colab runtime crashes with a "RAM full" notification. I have searched for ways to reduce memory usage, but nothing works. Here is my code for the LightningModule:
class Bert(pl.LightningModule):
    """Multimodal binary classifier fusing BERT text and ViT image features.

    A BERT encoder (frozen by default) and a ViT backbone each produce a
    768-d vector. An ``nn.Bilinear`` layer fuses them into a 512-d vector,
    and a BatchNorm + Linear head maps that to a single logit.

    Memory note: ``nn.Bilinear(768, 768, 512)`` stores a weight of shape
    (512, 768, 768), i.e. ~302M parameters (~1.2 GB in fp32). That is more
    than BERT-base and ViT-base combined. Training also keeps a gradient and
    two AdamW moment buffers of the same size, so this one layer costs
    ~4.8 GB. The end-of-epoch checkpoint that Lightning writes by default
    serializes weights plus optimizer state (several GB) through host RAM.
    That is a likely cause of a RAM crash right after the first epoch.
    Passing ``proj_dim`` (e.g. 128) projects both features down before
    fusion and shrinks the bilinear weight by a factor of
    (768 / proj_dim) ** 2.

    Args:
        bert_path: Path or hub id passed to ``BertModel.from_pretrained``.
        freeze_bert: If True, BERT parameters get ``requires_grad=False``,
            are not given to the optimizer, and BERT stays in eval mode.
        proj_dim: Optional width of a linear projection applied to each
            768-d feature before the bilinear fusion. ``None`` keeps the
            original 768x768 bilinear layer and the original state_dict keys.
        lr: AdamW learning rate.
    """

    def __init__(self, bert_path="/content/bert-base-uncased-hate",
                 freeze_bert=True, proj_dim=None, lr=1e-5):
        super().__init__()
        self.freeze_bert = freeze_bert
        self.lr = lr

        self.bert = BertModel.from_pretrained(bert_path)
        if freeze_bert:
            # `trainable = False` is a Keras attribute. On a torch module it only
            # sets an unused Python attribute, so BERT kept training. Freezing a
            # torch module means turning off requires_grad on its parameters.
            self.bert.requires_grad_(False)

        self.vit = timm.models.vit_base_patch16_224_in21k(pretrained=True, num_classes=0)

        feat_dim = 768
        if proj_dim is None:
            # nn.Identity has no parameters, so the state_dict matches the original model.
            self.img_proj = nn.Identity()
            self.text_proj = nn.Identity()
            fused_in = feat_dim
        else:
            self.img_proj = nn.Linear(feat_dim, proj_dim)
            self.text_proj = nn.Linear(feat_dim, proj_dim)
            fused_in = proj_dim

        self.bFc1 = nn.Bilinear(fused_in, fused_in, 512)
        self.classifier = nn.Sequential(nn.BatchNorm1d(512), nn.Linear(512, 1))

    def train(self, mode=True):
        """Set train/eval mode, but keep a frozen BERT in eval mode (dropout off)."""
        super().train(mode)
        if self.freeze_bert:
            self.bert.eval()
        return self

    def forward(self, input_ids, images):
        """Return one logit per example (output of ``nn.Linear(512, 1)``)."""
        img = self.img_proj(self.vit(images))
        text = self.text_proj(self.bert(input_ids=input_ids)["pooler_output"])
        fused = self.bFc1(img, text)  # renamed from `repr`, which shadowed the builtin
        return self.classifier(fused)

    def _shared_step(self, batch):
        """Compute (loss, accuracy) for one batch; used by train and val steps."""
        logits = self(input_ids=batch["input_ids"], images=batch["images"])
        # BCE-with-logits needs float targets with exactly the logits' shape. The
        # explicit view also keeps the accuracy comparison from broadcasting to (B, B).
        labels = batch["labels"].float().view_as(logits)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels)
        # The model emits logits, so probability > 0.5 is equivalent to logit > 0.
        preds = (logits > 0).to(labels.dtype)
        acc = (preds == labels).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)
        self.log("train_loss", loss)
        self.log("train_accuracy", acc)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)
        # NOTE(review): key names kept for compatibility; "val_loss"/"val_accuracy"
        # would avoid confusion with the training loss shown in the progress bar.
        self.log("loss", loss)
        self.log("accuracy", acc)

    def configure_optimizers(self):
        # Only hand trainable parameters to AdamW, so frozen BERT weights are skipped.
        params = [p for p in self.parameters() if p.requires_grad]
        return AdamW(params, lr=self.lr)
# Build the classifier with its default configuration; the constructor loads
# pretrained BERT weights (from_pretrained) and a pretrained ViT (pretrained=True).
model = Bert()
And here is my Trainer:
# Trainer settings, collected in one place so they are easy to tweak.
trainer_config = dict(
    gpus=1,
    max_epochs=10,
    # Gradients from 4 batches are summed before each optimizer step.
    accumulate_grad_batches=4,
    # Run 10 validation batches before training starts as a sanity check.
    num_sanity_val_steps=10,
)
trainer = pl.Trainer(**trainer_config)
trainer.fit(model, train_loader, val_loader)
Could you tell me what is wrong in my code that causes this crash?