Hi, I am using Lightning CLI to configure my experiments and I want to know the recommended practice for using torch.compile
with Lightning CLI. By the way, I am training with 16-mixed precision in Lightning on an RTX 3090, and I got the following warning:
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
What should I do in Lightning? Thanks so much.
The correct practice is to compile your model, not the LightningModule. Here is pseudocode:
import torch
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # compile the inner model, not the LightningModule itself
        self.model = torch.compile(load_model())
    ...

model = LitModel()
trainer = pl.Trainer()
trainer.fit(model)
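As for the Tensor Cores warning itself, that is independent of torch.compile: you set the float32 matmul precision once at startup, before building the trainer. Note that the `'medium' | 'high'` notation in the warning just lists the two accepted options; it is not Python syntax. A minimal sketch, picking 'high' as one reasonable choice:

```python
import torch

# Allow TF32 on Tensor Cores for float32 matmuls.
# 'high' keeps more precision than 'medium'; both are faster than the
# default 'highest' on Ampere GPUs such as the RTX 3090.
torch.set_float32_matmul_precision('high')
```

Since this is a process-wide setting, calling it at the top of your training script (before `trainer.fit`) is enough.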
Interesting! In my case it’s a simple embedding model:
class EmbeddingModel(pl.LightningModule):
    def __init__(self, n_samples, embedding_dim, learning_rate):
        super().__init__()
        self.learning_rate = learning_rate
        self.embedding = nn.Embedding(n_samples, embedding_dim)
and before I create the trainer I call set_float32_matmul_precision:
# pytorch TensorCore
# @see https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
try:
    # pass one of 'medium' or 'high'; the `|` shown in the warning is not Python syntax
    torch.set_float32_matmul_precision('high')
except Exception as e:
    log_message('unable to activate TensorCore')
    log_error(e)
and then I train:
trainer = pl.Trainer(
    logger=logger,
    callbacks=[checkpoint_callback],
    max_epochs=hparams["epochs"],
    devices=devices,
    accelerator=accelerator,
    strategy=train_strategy,
)
trainer.fit(model, trainloader)
So how do I wrap it in the pl.LightningModule in that case?