Hi
I have tried to convert a plain PyTorch training loop and test loop into a PyTorch Lightning module with the exact same model architecture. The manual loops reduce the training loss and work as expected, so I know the issue must be in my LightningModule implementation: with Lightning, performance on the training set itself does not improve (the model architecture and data loaders are fine, as the traditional PyTorch loops show).
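For context, the manual training loop I converted from follows the standard pattern below (a simplified sketch from memory: the lr, epoch count and device handling are illustrative, while create_model() and trainLoader are the same objects used further down):

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = create_model().to(device)            # same architecture as in the LightningModule
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

for epoch in range(10):
    model.train()
    running_loss = 0.0
    for x, y in trainLoader:                 # same DataLoader passed to trainer.fit below
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"epoch {epoch}: train loss {running_loss / len(trainLoader):.4f}")

This version steadily reduces the training loss.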
Here is the PL code; I would appreciate it if anyone spots anything wrong:
class dogClassifier(LightningModule):
    def __init__(self, lr=0.001):
        super().__init__()
        self.lr = lr
        self.model = create_model()
        self.val_accuracy = Accuracy()
        self.test_accuracy = Accuracy()

    def forward(self, x):
        x = self.model(x)
        return F.softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        outputs = self(x)
        loss = nn.CrossEntropyLoss()(outputs, y)
        self.log('loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        outputs = self(x)
        loss = nn.CrossEntropyLoss()(outputs, y)
        preds = torch.argmax(outputs, dim=1)
        self.val_accuracy.update(preds, y)
        self.log('val_loss', loss)
        self.log("val_acc", self.val_accuracy, prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        outputs = self(x)
        loss = nn.CrossEntropyLoss()(outputs, y)
        preds = torch.argmax(outputs, dim=1)
        self.test_accuracy.update(preds, y)
        self.log('test_loss', loss)
        self.log('test_acc', self.test_accuracy, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr)
The trainer and run steps:
from pytorch_lightning.trainer.trainer import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger

trainer = Trainer(
    accelerator='auto',
    default_root_dir=".",
    max_epochs=10,
    callbacks=[
        # ModelCheckpoint(dirpath='models/', monitor='val_acc', save_top_k=3, mode="max"),
        TQDMProgressBar(refresh_rate=10),
    ],
    logger=CSVLogger(save_dir="logs/"),
    log_every_n_steps=10,
)
dogbreedModel = dogClassifier(lr=0.001)
trainer.fit(dogbreedModel, train_dataloaders=trainLoader, val_dataloaders=validLoader)
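After fit I run the converted test loop the same way (testLoader is just what I call the test-split DataLoader, built like the other two):

trainer.test(dogbreedModel, dataloaders=testLoader)  # testLoader: test-split DataLoader, same construction as trainLoader/validLoader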