I studied the SimCLR self-supervised-learning code provided in the Colab notebook, and implemented similar code of my own.
I have modified the datamodule to load my own data, and the code runs fine. However, I cannot make it work with SSLOnlineEvaluator: if I pass it in the model's callbacks, I get the following error:
TypeError: on_train_batch_end() takes 6 positional arguments but 7 were given
I have the following code:
def to_device(batch, device):
    """Move one SimCLR batch's first image view and its labels to *device*.

    Expects ``batch`` shaped as ``((img1, img2), y)``; the second augmented
    view is discarded and only ``(img1, y)`` is returned, both on *device*.
    """
    (first_view, _), labels = batch
    return first_view.to(device), labels.to(device)
# Online linear probe: trains a small classifier on the encoder's frozen
# representations alongside SSL training. z_dim must match the encoder's
# flattened feature size; 7 classes for this dataset.
# NOTE(review): the reported TypeError ("on_train_batch_end() takes 6
# positional arguments but 7 were given") points at a version mismatch —
# the installed pytorch_lightning trainer passes a dataloader_idx argument
# that this SSLOnlineEvaluator's on_train_batch_end() does not accept.
# Presumably upgrading pl_bolts to a release matching the installed
# lightning version (or subclassing and overriding on_train_batch_end with
# the extra argument) resolves it — confirm against the installed versions.
online_finetuner = SSLOnlineEvaluator(z_dim=2048 * 2 * 2, num_classes = 7)
# Override the callback's batch-unpacking hook with the custom to_device
# defined above (the evaluator calls it as self.to_device(batch, device)).
online_finetuner.to_device = to_device
# Log the learning-rate schedule during training.
lr_logger = LearningRateMonitor()
callbacks = [online_finetuner, lr_logger]
# pick data
# Input image height and SimCLR batch size for this run.
rafdb_height=100
batch_size=64
# data
dm = RAFDBDataModule(num_workers=0)  # custom datamodule for the RAF-DB dataset
# SimCLR augmentation pipelines: the train transform produces augmented
# views per image; the eval transform is used for validation and test.
dm.train_transforms = SimCLRTrainDataTransform(input_height=rafdb_height)
dm.val_transforms = SimCLREvalDataTransform(input_height=rafdb_height)
dm.test_transforms = SimCLREvalDataTransform(input_height=rafdb_height)
# model
# num_samples is taken from the datamodule — presumably used by SimCLR to
# size its learning-rate schedule; verify against the SimCLR constructor.
model = SimCLR(num_samples=dm.num_samples, batch_size=batch_size)
# fit
trainer = pl.Trainer(gpus=1, max_epochs=20, callbacks=callbacks)
trainer.fit(model, dm)
Full output:
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
| Name | Type | Params
------------------------------------------------------
0 | encoder | ResNet | 25 M
1 | projection | Projection | 4 M
2 | non_linear_evaluator | SSLEvaluator | 8 M
Epoch 0: 0%| | 0/383 [00:00<?, ?it/s] Traceback (most recent call last):
File "H:/Project/PycharmProject/FER_Long-tailed/train_SimCLR.py", line 236, in <module>
trainer.fit(model, dm)
File "F:\Program\Anaconda3\envs\pytorch-lightning\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 440, in fit
results = self.accelerator_backend.train()
File "F:\Program\Anaconda3\envs\pytorch-lightning\lib\site-packages\pytorch_lightning\accelerators\gpu_accelerator.py", line 54, in train
results = self.train_or_test()
File "F:\Program\Anaconda3\envs\pytorch-lightning\lib\site-packages\pytorch_lightning\accelerators\accelerator.py", line 66, in train_or_test
results = self.trainer.train()
File "F:\Program\Anaconda3\envs\pytorch-lightning\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 483, in train
self.train_loop.run_training_epoch()
File "F:\Program\Anaconda3\envs\pytorch-lightning\lib\site-packages\pytorch_lightning\trainer\training_loop.py", line 557, in run_training_epoch
self.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx)
File "F:\Program\Anaconda3\envs\pytorch-lightning\lib\site-packages\pytorch_lightning\trainer\training_loop.py", line 249, in on_train_batch_end
self.trainer.call_hook('on_train_batch_end', epoch_end_outputs, batch, batch_idx, dataloader_idx)
File "F:\Program\Anaconda3\envs\pytorch-lightning\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 823, in call_hook
trainer_hook(*args, **kwargs)
File "F:\Program\Anaconda3\envs\pytorch-lightning\lib\site-packages\pytorch_lightning\trainer\callback_hook.py", line 147, in on_train_batch_end
callback.on_train_batch_end(self, self.get_model(), outputs, batch, batch_idx, dataloader_idx)
TypeError: on_train_batch_end() takes 6 positional arguments but 7 were given
Epoch 0: 0%| | 0/383 [00:02<?, ?it/s]
Process finished with exit code 1