How to use SSLOnlineEvaluator

I studied the SSL SimCLR example provided here on Colab and implemented similar code.

I modified the datamodule to load my own dataset, and the code runs. However, I cannot get SSLOnlineEvaluator to work: when I pass it as a callback to the Trainer, I get this error:
TypeError: on_train_batch_end() takes 6 positional arguments but 7 were given

I have the following code:

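```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor
from pl_bolts.callbacks.self_supervised import SSLOnlineEvaluator
from pl_bolts.models.self_supervised import SimCLR
from pl_bolts.models.self_supervised.simclr.transforms import (
    SimCLREvalDataTransform,
    SimCLRTrainDataTransform,
)

# RAFDBDataModule is my own LightningDataModule for RAF-DB (definition not shown)


def to_device(batch, device):
    # The batch is ((view1, view2), label); the online evaluator
    # only needs the first augmented view and the label.
    (img1, _), y = batch
    img1 = img1.to(device)
    y = y.to(device)
    return img1, y


online_finetuner = SSLOnlineEvaluator(z_dim=2048 * 2 * 2, num_classes=7)
# Override how the evaluator unpacks each batch
online_finetuner.to_device = to_device

lr_logger = LearningRateMonitor()

callbacks = [online_finetuner, lr_logger]

# pick data
rafdb_height = 100
batch_size = 64

# data
dm = RAFDBDataModule(num_workers=0)
dm.train_transforms = SimCLRTrainDataTransform(input_height=rafdb_height)
dm.val_transforms = SimCLREvalDataTransform(input_height=rafdb_height)
dm.test_transforms = SimCLREvalDataTransform(input_height=rafdb_height)

# model
model = SimCLR(num_samples=dm.num_samples, batch_size=batch_size)

# fit
trainer = pl.Trainer(gpus=1, max_epochs=20, callbacks=callbacks)
trainer.fit(model, dm)
```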
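A note on the `online_finetuner.to_device = to_device` line: a plain function assigned to an instance attribute is not bound, so when the callback calls `self.to_device(batch, device)` the function receives only `batch` and `device`, which matches the two-parameter `to_device` above. The sketch below demonstrates this with hypothetical stand-ins (`FakeTensor`, `FakeEvaluator`); it assumes, without checking the Bolts source, that the evaluator calls `self.to_device(batch, device)` with exactly those two arguments.

```python
class FakeTensor:
    """Hypothetical stand-in for a tensor: .to() just records the device."""

    def __init__(self, name, device="cpu"):
        self.name = name
        self.device = device

    def to(self, device):
        return FakeTensor(self.name, device)


class FakeEvaluator:
    """Hypothetical stand-in for a callback that unpacks batches via self.to_device."""

    def to_device(self, batch, device):
        # Default behaviour: expects a plain (x, y) batch
        x, y = batch
        return x.to(device), y.to(device)

    def prepare(self, batch, device):
        # Attribute lookup checks the instance first, then the class
        return self.to_device(batch, device)


def to_device(batch, device):
    # Same unpacking as in the question: ((view1, view2), label)
    (img1, _), y = batch
    return img1.to(device), y.to(device)


evaluator = FakeEvaluator()
evaluator.to_device = to_device  # instance attribute: not bound, no implicit self

batch = ((FakeTensor("view1"), FakeTensor("view2")), FakeTensor("label"))
x, y = evaluator.prepare(batch, "cuda:0")
print(x.name, x.device, y.name, y.device)  # view1 cuda:0 label cuda:0
```

If the same function were attached to the class or defined as a method in a subclass instead, it would need a leading `self` parameter.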

Full output:

```text
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                 | Type         | Params
------------------------------------------------------
0 | encoder              | ResNet       | 25 M
1 | projection           | Projection   | 4 M
2 | non_linear_evaluator | SSLEvaluator | 8 M
Epoch 0:   0%|          | 0/383 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "H:/Project/PycharmProject/FER_Long-tailed/train_SimCLR.py", line 236, in <module>
    trainer.fit(model, dm)
  File "F:\Program\Anaconda3\envs\pytorch-lightning\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 440, in fit
    results = self.accelerator_backend.train()
  File "F:\Program\Anaconda3\envs\pytorch-lightning\lib\site-packages\pytorch_lightning\accelerators\gpu_accelerator.py", line 54, in train
    results = self.train_or_test()
  File "F:\Program\Anaconda3\envs\pytorch-lightning\lib\site-packages\pytorch_lightning\accelerators\accelerator.py", line 66, in train_or_test
    results = self.trainer.train()
  File "F:\Program\Anaconda3\envs\pytorch-lightning\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 483, in train
    self.train_loop.run_training_epoch()
  File "F:\Program\Anaconda3\envs\pytorch-lightning\lib\site-packages\pytorch_lightning\trainer\training_loop.py", line 557, in run_training_epoch
    self.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx)
  File "F:\Program\Anaconda3\envs\pytorch-lightning\lib\site-packages\pytorch_lightning\trainer\training_loop.py", line 249, in on_train_batch_end
    self.trainer.call_hook('on_train_batch_end', epoch_end_outputs, batch, batch_idx, dataloader_idx)
  File "F:\Program\Anaconda3\envs\pytorch-lightning\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 823, in call_hook
    trainer_hook(*args, **kwargs)
  File "F:\Program\Anaconda3\envs\pytorch-lightning\lib\site-packages\pytorch_lightning\trainer\callback_hook.py", line 147, in on_train_batch_end
    callback.on_train_batch_end(self, self.get_model(), outputs, batch, batch_idx, dataloader_idx)
TypeError: on_train_batch_end() takes 6 positional arguments but 7 were given
Epoch 0:   0%|          | 0/383 [00:02<?, ?it/s]

Process finished with exit code 1
```
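The last frame of the traceback shows the cause. Lightning calls `callback.on_train_batch_end(self, self.get_model(), outputs, batch, batch_idx, dataloader_idx)`: six explicit arguments plus the implicit `self` of the bound method, seven in total. "Takes 6 positional arguments" therefore means the installed `SSLOnlineEvaluator` defines the hook with one parameter fewer; presumably it lacks `outputs`. The pure-Python sketch below reproduces the error with hypothetical classes (`OldStyleCallback`, `NewStyleCallback`) and a hypothetical `call_hook`; it only mimics the calling convention visible in the traceback, not Lightning's actual code.

```python
class OldStyleCallback:
    # Hypothetical: hook defined without `outputs` (5 parameters + self = 6)
    def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        return "old"


class NewStyleCallback:
    # Hypothetical: hook that also accepts `outputs` (6 parameters + self = 7)
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        return "new"


def call_hook(callback, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
    # Mimics the call in callback_hook.py from the traceback: six explicit arguments
    return callback.on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)


args = ("trainer", "model", "outputs", "batch", 0, 0)
print(call_hook(NewStyleCallback(), *args))  # new
try:
    call_hook(OldStyleCallback(), *args)
except TypeError as err:
    print(err)  # ...on_train_batch_end() takes 6 positional arguments but 7 were given
```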

I believe this is fixed on the master branch of Bolts; see this issue:
https://github.com/PyTorchLightning/pytorch-lightning-bolts/issues/309
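Until a release containing the fix is available, one possible workaround (a sketch, not the fix from the issue) is to subclass the evaluator and override the hook with the newer argument list, dropping `outputs` before delegating to the old implementation. Below, `OldStyleEvaluator` is a hypothetical stand-in for the installed `SSLOnlineEvaluator`. With the real library you would subclass `SSLOnlineEvaluator` instead, assuming (not verified against the Bolts source) that its old hook takes `(trainer, pl_module, batch, batch_idx, dataloader_idx)` in that order.

```python
class OldStyleEvaluator:
    """Hypothetical stand-in for the installed evaluator with the old hook signature."""

    def __init__(self):
        self.seen = []

    def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx):
        # Pretend to run one online-evaluation step on this batch
        self.seen.append((batch, batch_idx))


class PatchedEvaluator(OldStyleEvaluator):
    """Accepts the newer argument list and forwards everything except `outputs`."""

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        super().on_train_batch_end(trainer, pl_module, batch, batch_idx, dataloader_idx)


cb = PatchedEvaluator()
# Same six-argument call shape as in the traceback
cb.on_train_batch_end("trainer", "model", "outputs", "batch-0", 0, 0)
print(cb.seen)  # [('batch-0', 0)]
```

Because the subclass inherits `__init__`, a real subclass of `SSLOnlineEvaluator` would be constructed exactly like the original call in the question, only with the new class name.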


I don't think the bug is fully fixed. There is now a new error: a shape mismatch in a matrix multiplication inside the network. Please see my comment on the issue for details.
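The thread does not say which multiplication fails. One thing worth checking (an assumption, not a confirmed diagnosis) is `z_dim`: the evaluator's first linear layer expects inputs of size `z_dim`, and `2048 * 2 * 2` corresponds to a flattened 2048-channel, 2x2 feature map, whereas the encoder's real output size depends on the input resolution (here `input_height=100`). The pure-Python sketch below, with hypothetical helper names, shows the shape rule for `(batch, n) @ (z_dim, hidden)`: it only works when `n == z_dim`.

```python
from math import prod


def flat_dim(feature_shape):
    """Number of features after flattening a (C, H, W) feature map."""
    return prod(feature_shape)


def check_linear_input(feature_shape, z_dim):
    """Mimic the shape rule of (batch, n) @ (z_dim, hidden): n must equal z_dim."""
    n = flat_dim(feature_shape)
    if n != z_dim:
        raise ValueError(
            f"flattened features have size {n}, but the evaluator expects z_dim={z_dim}"
        )
    return n


z_dim = 2048 * 2 * 2  # value used in the question
print(check_linear_input((2048, 2, 2), z_dim))  # 8192

try:
    # Hypothetical feature-map shape for a larger input image
    check_linear_input((2048, 4, 4), z_dim)
except ValueError as err:
    print(err)
```

To find the real size, one option is to pass a single batch through `model.encoder` (the module listed in the summary table), print the output's shape, and set `z_dim` to its flattened size.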