RAM usage increases quickly over the training steps

Why does RAM usage increase so quickly over the training steps?

Can I see your full ULSEModel? This issue description might help you debug it: "Add a docs performance section" (Lightning-AI/lightning issue #12398 on GitHub).

Sure, the ULSEModel looks like this:

import torch
from pytorch_lightning.core.lightning import LightningModule
from torchmetrics import MetricCollection, F1Score

class ULSEModel(LightningModule):
    """Classify a pair of token-id inputs ``a`` and ``b`` into ``vocab_size`` classes.

    Each input is embedded with a shared embedding table and summed over dim 1.
    The two summed representations are combined with fixed weights, and a linear
    head maps the result to log-probabilities over ``vocab_size`` classes. The
    model is trained with negative log-likelihood loss and validated with
    micro/macro F1.

    Args:
        hparams: hyper-parameter container read via attribute access. Fields used
            here: ``vocab_size``, ``embedding_dim``, ``rpr_dim``, ``lr`` and
            ``weight_decay``.
    """

    def __init__(self, hparams):
        super().__init__()
        # Register the hyper-parameters so ``self.hparams`` (read by _get_metrics
        # and configure_optimizers) is actually populated and checkpointed.
        self.save_hyperparameters(hparams)

        # encoders
        self.embeddings = torch.nn.Embedding(hparams.vocab_size, hparams.embedding_dim)

        # Classification head.
        # LogSoftmax is required because NLLLoss expects log-probabilities;
        # feeding it raw Linear outputs makes the loss unbounded below.
        # NOTE(review): forward() passes summed embeddings (size embedding_dim)
        # into this Linear, so rpr_dim must equal embedding_dim — confirm.
        self.cls_head = torch.nn.Sequential(
            torch.nn.Linear(hparams.rpr_dim, hparams.vocab_size),
            torch.nn.LogSoftmax(dim=-1),
        )

        # Validation metrics: accumulated in validation_step, then computed and
        # reset once per epoch in validation_epoch_end.
        self.val_metrics = self._get_metrics(prefix="val_")

        # loss function
        self.loss = torch.nn.NLLLoss()

    def _get_metrics(self, prefix):
        """Build the micro/macro F1 metric collection, keys prefixed with ``prefix``."""
        num_classes = self.hparams.vocab_size
        # Predictions are an argmax over vocab_size classes, so this is a
        # multiclass problem, not a binary one.
        return MetricCollection(
            {
                "Mic-F1": F1Score(task="multiclass", num_classes=num_classes, average="micro"),
                "Mac-F1": F1Score(task="multiclass", num_classes=num_classes, average="macro"),
            },
            prefix=prefix,
        )

    def forward(self, a, b):
        """Return class log-probabilities for the input pair ``(a, b)``.

        ``a`` and ``b`` are token-id tensors; presumably shaped (batch, seq_len)
        — TODO confirm against the datamodule. Their embeddings are summed over
        dim 1 before being combined.
        """
        a_rpr = torch.sum(self.embeddings(a), 1)
        b_rpr = torch.sum(self.embeddings(b), 1)
        # NOTE(review): the asymmetric weights (5 vs .5) look deliberate but are
        # unexplained — confirm they are intended.
        rpr = 5 * a_rpr + .5 * b_rpr
        return self.cls_head(rpr)

    def training_step(self, batch, batch_idx, optimizer_idx=None):
        """Compute, log and return the NLL loss for one training batch."""
        a, b, true_cls = batch["a"], batch["b"], batch["cls"]
        pred_cls = self(a, b)
        # log training loss
        train_loss = self.loss(pred_cls, true_cls)
        self.log('train_loss', train_loss)

        return train_loss

    def validation_step(self, batch, batch_idx):
        """Accumulate F1 statistics for one validation batch.

        Nothing is returned, so Lightning has no per-step outputs to keep around
        for validation_epoch_end.
        """
        a, b, true_cls = batch["a"], batch["b"], batch["cls"]
        pred_cls = self(a, b)

        # Only update the metric state here. Logging per-batch F1 values would
        # make the epoch value an average of batch scores, which is not the
        # F1 over the whole validation set.
        self.val_metrics.update(torch.argmax(pred_cls, dim=-1), true_cls)

    def validation_epoch_end(self, outs):
        """Log F1 over the whole validation epoch, then clear the metric state.

        ``outs`` is unused because validation_step returns nothing.
        """
        self.log_dict(self.val_metrics.compute(), prog_bar=True)
        # Reset so the next validation epoch starts from empty state.
        self.val_metrics.reset()

    def configure_optimizers(self):
        """Return an AMSGrad AdamW optimizer configured from ``self.hparams``."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr, betas=(0.9, 0.999),
                                      eps=1e-08, weight_decay=self.hparams.weight_decay, amsgrad=True)
        return optimizer