Is there a PyTorch / PyTorch Lightning implementation of SimCLR?
Yes! It ships with PyTorch Lightning Bolts. The most direct way to use it:
Getting the model
First, install bolts:
pip install pytorch-lightning-bolts
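Note that the bolts API has moved between releases (for example, the transforms module imported below was later renamed to pl_bolts.models.self_supervised.simclr.transforms), so if the imports fail, pinning the package version may help. The exact pin here is only an illustrative guess for the API used in this snippet, not a tested requirement:
pip install pytorch-lightning-bolts==0.1.1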
Train
import pytorch_lightning as pl
from pl_bolts.models.self_supervised import SimCLR
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised.simclr.simclr_transforms import (
    SimCLREvalDataTransform, SimCLRTrainDataTransform)
# data: CIFAR-10 with SimCLR's train/eval augmentations (32 = input height)
dm = CIFAR10DataModule(num_workers=0)
dm.train_transforms = SimCLRTrainDataTransform(32)
dm.val_transforms = SimCLREvalDataTransform(32)
# model: num_samples and batch_size are used to set up the LR schedule
model = SimCLR(num_samples=dm.num_samples, batch_size=dm.batch_size)
# fit
trainer = pl.Trainer()
trainer.fit(model, dm)
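Before launching a full run, it can be worth a quick smoke test; fast_dev_run is a standard Trainer flag that pushes a single batch through train and val just to verify everything is wired up:
# smoke test: runs one train/val batch end-to-end before a full run
trainer = pl.Trainer(fast_dev_run=True)  # pl imported in the snippet above
trainer.fit(model, dm)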
Or use a pre-trained model:
from pl_bolts.models.self_supervised import SimCLR
weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/simclr-cifar10-v1-exp12_87_52/epoch%3D960.ckpt'
simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)
simclr.freeze()
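Once frozen, the model can be used as a feature extractor. A minimal sketch, assuming the forward pass returns the encoder representation h (bolts changed the exact return type between versions, so check the shape on your install):
import torch

x = torch.rand(4, 3, 32, 32)  # dummy batch of CIFAR-10-sized images

with torch.no_grad():
    feats = simclr(x)  # representation from the encoder, not the projection head
print(feats.shape)     # e.g. (4, 2048) for a ResNet-50 backbone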
I'm getting this error when dry-running the code in your link:
RuntimeError                              Traceback (most recent call last)
<ipython-input> in <module>
      2
      3 weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/simclr-cifar10-v1-exp12_87_52/epoch%3D960.ckpt'
----> 4 simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)
      5
      6 simclr.freeze()

~/.local/lib/python3.6/site-packages/pytorch_lightning/core/saving.py in load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)
    155             checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
    156
--> 157         model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
    158         return model
    159

~/.local/lib/python3.6/site-packages/pytorch_lightning/core/saving.py in _load_model_state(cls, checkpoint, strict, **cls_kwargs_new)
    203
    204         # load the state_dict on the model automatically
--> 205         model.load_state_dict(checkpoint['state_dict'], strict=strict)
    206
    207         return model

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1043         if len(error_msgs) > 0:
   1044             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
->  1045                 self.__class__.__name__, "\n\t".join(error_msgs)))
   1046         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1047

RuntimeError: Error(s) in loading state_dict for SimCLR:
    size mismatch for encoder.conv1.weight: copying a param with shape torch.Size([64, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 3, 7, 7]).
    size mismatch for projection.model.3.weight: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([128, 2048]).
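For reference, the tensor shapes stored inside the checkpoint can be listed with a short snippet (just a debugging sketch; torch.hub.load_state_dict_from_url downloads and caches the file):
import torch

weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/simclr-cifar10-v1-exp12_87_52/epoch%3D960.ckpt'

ckpt = torch.hub.load_state_dict_from_url(weight_path, map_location='cpu')
for name, tensor in ckpt['state_dict'].items():
    print(name, tuple(tensor.shape))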
**I also tried other checkpoint links, and those fail with an error that the LARS optimizer is not available.**
Any working code?