I’m trying to add checkpoint saving capabilities in my Lightning code, but that somehow stops the terminal logging.
Following the doc, I added the ModelCheckpoint
callback to my trainer as follows:
def get_callbacks(args):
checkpoint_callback = ModelCheckpoint(
monitor='valid_loss_epoch',
# dirpath=f'saved_models',
# filename=f'{args.name}',
save_top_k=3,
mode='min',
verbose=True,
save_last=True,
)
return [checkpoint_callback]
args.default_root_dir = f'saved_models/'
trainer = pl.Trainer.from_argparse_args(args)
trainer.callbacks = get_callbacks(args)
This makes the log in the terminal disappear i.e. the program doesn’t log to the terminal anymore. Commenting out the trainer.callbacks = get_callbacks(args)
starts the terminal logging. I’ve tried combinations of adding and commenting the dirpath
and default_root_dir
, but none of these helps with the logging. Any help would be great!
Details of my Lightning Module step/epoch functions where I have the self.log calls:
def training_step(self, batch, batch_idx, optimizer_idx):
(opt1, opt2) = self.optimizers()
outputs = self(batch)
preds = torch.argmax(outputs, axis=1)
targets = batch["target"]
loss = self.calc_loss(outputs, targets)
acc = self.calc_acc(preds, targets)
self.manual_backward(loss, opt1)
opt1.step()
opt1.zero_grad()
opt2.step()
opt2.zero_grad()
self.log('train_loss_step', loss, prog_bar=True)
self.log('train_acc_step', acc, prog_bar=True)
return {'loss': loss, "preds": preds, "targets": targets}
def backward(self, loss, optimizer, optimizer_idx):
loss.backward()
def training_epoch_end(self, outputs):
preds = torch.cat([x['preds'] for x in outputs])
targets = torch.cat([x['targets'] for x in outputs])
loss = torch.stack([x['loss'] for x in outputs]).mean()
self.log('train_loss_epoch', loss.item())
self.log('train_acc_epoch', self.calc_acc(preds, targets))
def validation_step(self, batch, batch_idx, dataloader_idx=0):
outputs = self(batch)
preds = torch.argmax(outputs, axis=1)
targets = batch["target"]
loss = self.calc_loss(outputs, targets)
acc = self.calc_acc(preds, targets)
self.log('valid_loss_step', loss, prog_bar=True)
self.log('valid_acc_step', acc, prog_bar=True)
return {'loss': loss, "preds": preds, "targets": targets}
def validation_epoch_end(self, outputs):
preds = torch.cat([x['preds'] for x in outputs])
targets = torch.cat([x['targets'] for x in outputs])
loss = torch.stack([x['loss'] for x in outputs]).mean()
self.log('valid_loss_epoch', loss.item())
self.log('valid_acc_epoch', self.calc_acc(preds, targets))