Organize Your Code¶
Any raw PyTorch code can be converted to Fabric with zero refactoring required, giving you maximum flexibility in how you organize your projects.
However, when developing a project in a team or sharing the code publicly, it can be beneficial to conform to a standard format for how the core pieces of the code are organized. This is what the LightningModule was made for!
Here is how you can neatly separate the research code (model, loss, optimization, etc.) from the “trainer” code (training loop, checkpointing, logging, etc.).
Step 1: Move your code into LightningModule hooks¶
Take these main ingredients and put them in a LightningModule:
The PyTorch model(s) as an attribute (e.g. self.model)
The forward pass, including the loss computation, goes into training_step()
Setup of optimizer(s) goes into configure_optimizers()
Setup of the training data loader goes into train_dataloader()
import lightning as L
import torch
from torch.utils.data import DataLoader


class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = ...

    def training_step(self, batch, batch_idx):
        # Main forward, loss computation, and metrics go here
        x, y = batch
        y_hat = self.model(x)
        loss = self.loss_fn(y, y_hat)
        acc = self.accuracy(y, y_hat)
        ...
        return loss

    def configure_optimizers(self):
        # Return one or several optimizers
        return torch.optim.Adam(self.parameters(), ...)

    def train_dataloader(self):
        # Return your dataloader for training
        return DataLoader(...)

    def on_train_start(self):
        # Do something at the beginning of training
        ...

    def any_hook_you_like(self, *args, **kwargs):
        ...
This is a minimal LightningModule, but there are many other useful hooks you can use.
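For example, the same pattern extends to evaluation. The sketch below is only an illustration of a few of the standard hooks (validation_step, val_dataloader, on_train_epoch_end); it assumes the same placeholder self.model and self.loss_fn as the class above.

import lightning as L
from torch.utils.data import DataLoader


class LitModel(L.LightningModule):
    # ... __init__, training_step, configure_optimizers, train_dataloader as above ...

    def validation_step(self, batch, batch_idx):
        # Evaluate the model on one validation batch
        x, y = batch
        y_hat = self.model(x)
        return self.loss_fn(y, y_hat)

    def val_dataloader(self):
        # Return your dataloader for validation
        return DataLoader(...)

    def on_train_epoch_end(self):
        # Do something at the end of every training epoch
        ...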
Step 2: Call hooks from your Fabric code¶
In your Fabric training loop, you can now call the hooks of the LightningModule interface. It is up to you to call everything in the right place.
import lightning as L

fabric = L.Fabric(...)

# Instantiate the LightningModule
model = LitModel()

# Get the optimizer(s) from the LightningModule
optimizer = model.configure_optimizers()

# Get the training data loader from the LightningModule
train_dataloader = model.train_dataloader()

# Set up objects
model, optimizer = fabric.setup(model, optimizer)
train_dataloader = fabric.setup_dataloaders(train_dataloader)

# Call the hooks at the right time
model.on_train_start()

model.train()
for epoch in range(num_epochs):
    for i, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        loss = model.training_step(batch, i)
        fabric.backward(loss)
        optimizer.step()

        # Control when hooks are called
        if condition:
            model.any_hook_you_like()
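If your LightningModule also defines validation hooks (such as the validation_step and val_dataloader sketched in Step 1), you call them yourself as well. The placement below, a validation pass after every training epoch, is just one reasonable choice:

import torch

# Set up the validation dataloader alongside the training one
val_dataloader = fabric.setup_dataloaders(model.val_dataloader())

for epoch in range(num_epochs):
    ...  # training loop as above

    # Run a validation pass by calling the hooks yourself
    model.eval()
    with torch.no_grad():
        for i, batch in enumerate(val_dataloader):
            val_loss = model.validation_step(batch, i)
    model.on_train_epoch_end()
    model.train()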
Your code is now modular. You can switch out the entire LightningModule implementation for another one, and you don’t need to touch the training loop:
  # Instantiate the LightningModule
- model = LitModel()
+ model = DopeModel()
  ...
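For illustration, DopeModel could be any other LightningModule that implements the same hooks. The sketch below is hypothetical; its internals (architecture, loss, optimizer) are placeholders, and the point is only that the interface, not the implementation, is what the loop relies on.

import torch
import lightning as L
from torch.utils.data import DataLoader


class DopeModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = ...  # a different architecture

    def training_step(self, batch, batch_idx):
        # Different model and loss, same hook signature
        x, y = batch
        return self.loss_fn(y, self.model(x))

    def configure_optimizers(self):
        # A different optimizer works too; the training loop never needs to know
        return torch.optim.SGD(self.parameters(), lr=0.01)

    def train_dataloader(self):
        return DataLoader(...)

Because the Fabric loop only calls hook methods, any LightningModule that provides them can be swapped in without touching the loop itself.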