Torch compile and Lightning CLI

Hi, I am using Lightning CLI to configure my experiments, and I want to know the recommended practice for using torch.compile with Lightning CLI. Also, I am using 16-mixed precision in Lightning on an RTX 3090, and I got the following warning:

You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision

What should I do in Lightning? Thanks so much.

The second part of your question has been answered here: "Correct behavior when using precision=16 with Tensor-Cores" (Lightning-AI/lightning, Discussion #16698 on GitHub).

The recommended practice is to compile the underlying model (the plain nn.Module) inside your LightningModule, not the LightningModule itself. Here is some pseudocode:

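import torch
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # compile the inner nn.Module; load_model() stands for your own model factory
        self.model = torch.compile(load_model())

    def forward(self, x):
        return self.model(x)

    ...  # training_step, configure_optimizers, dataloaders, etc.

model = LitModel()
trainer = pl.Trainer()
trainer.fit(model)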

Interesting! In my case it’s a simple embedding model:

import torch.nn as nn
import pytorch_lightning as pl

class EmbeddingModel(pl.LightningModule):
    def __init__(self, n_samples, embedding_dim, learning_rate):
        super().__init__()
        self.learning_rate = learning_rate
        self.embedding = nn.Embedding(n_samples, embedding_dim)

And before creating the trainer, I call set_float32_matmul_precision:

import torch

# PyTorch Tensor Cores
# @see https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
try:
    # pass ONE mode: 'high' or 'medium'
    torch.set_float32_matmul_precision('high')
except Exception as e:
    log_message('unable to activate TensorCore')
    log_error(e)
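A note on the warning text: `'medium' | 'high'` there means "choose one of these strings". Typed literally into Python, it is a bitwise OR on two `str` objects and raises `TypeError` before torch is ever called, and a broad `except Exception` turns that into the "unable to activate TensorCore" log line while the setting stays at its default. The sketch below shows both the working call and the failure; `enable_tensor_cores` and `literal_warning_text_error` are hypothetical helper names.

```python
import torch


def enable_tensor_cores(mode: str = "high") -> str:
    """Hypothetical helper: set ONE mode ("high" or "medium") and return the active setting."""
    torch.set_float32_matmul_precision(mode)
    return torch.get_float32_matmul_precision()


def literal_warning_text_error() -> str:
    """Show why copying `'medium' | 'high'` verbatim from the warning cannot work."""
    try:
        'medium' | 'high'  # bitwise OR on two str objects -> TypeError
    except TypeError as exc:
        return str(exc)
    return "no error"
```

`literal_warning_text_error()` returns Python's `unsupported operand type(s) for |: 'str' and 'str'` message, which is exactly the exception the `try` block above would have swallowed.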

And the training:

trainer = pl.Trainer(logger=logger,
                     callbacks=[checkpoint_callback],
                     max_epochs=hparams["epochs"],
                     devices=devices,
                     accelerator=accelerator,
                     strategy=train_strategy)
trainer.fit(model, trainloader)

So how should I wrap it in the pl.LightningModule in that case?