Is there a PyTorch or PyTorch Lightning implementation of SimCLR?

Yes!

The code and docs are in PyTorch Lightning Bolts. Direct usage looks like this:

Getting the model

First, install Bolts:
pip install pytorch-lightning-bolts

Train

import pytorch_lightning as pl
from pl_bolts.models.self_supervised import SimCLR
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised.simclr.simclr_transforms import (
    SimCLREvalDataTransform, SimCLRTrainDataTransform)

# data: CIFAR-10 with SimCLR augmentations sized for 32x32 images
dm = CIFAR10DataModule(num_workers=0)
dm.train_transforms = SimCLRTrainDataTransform(32)
dm.val_transforms = SimCLREvalDataTransform(32)

# model: num_samples and batch_size are read from the datamodule
model = SimCLR(num_samples=dm.num_samples, batch_size=dm.batch_size)

# fit
trainer = pl.Trainer()
trainer.fit(model, dm)
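
For anyone who wants to see what that model is optimizing, here is a minimal from-scratch sketch of SimCLR's contrastive objective (the NT-Xent, or normalized temperature-scaled cross-entropy, loss) in plain PyTorch. It is not the Bolts source code: the function name nt_xent_loss and the default temperature of 0.5 are illustrative assumptions. z1 and z2 stand for projection-head outputs for two augmented views of the same batch of images.

```python
import torch
import torch.nn.functional as F


def nt_xent_loss(z1: torch.Tensor, z2: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    """Sketch of the NT-Xent loss used by SimCLR (not the Bolts implementation).

    z1, z2: (N, D) projections of two augmented views; row i of z1 and
    row i of z2 come from the same source image (the positive pair).
    """
    n = z1.size(0)
    # Stack both views and L2-normalise so dot products are cosine similarities.
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)  # (2N, D)
    logits = z @ z.t() / temperature                     # (2N, 2N)
    # An embedding must not be compared with itself: mask the diagonal out of the softmax.
    self_mask = torch.eye(2 * n, dtype=torch.bool, device=z.device)
    logits = logits.masked_fill(self_mask, float("-inf"))
    # Row i's positive sits at i + N (first view) or i - N (second view).
    targets = torch.cat([torch.arange(n, 2 * n), torch.arange(0, n)]).to(z.device)
    # Cross-entropy per row = -log softmax at the positive, averaged over all 2N rows.
    return F.cross_entropy(logits, targets)


# Tiny example: two orthogonal embeddings, each view identical to its partner.
z1 = torch.eye(2)
z2 = torch.eye(2)
loss = nt_xent_loss(z1, z2, temperature=0.5)
print(loss.item())
```

Every other embedding in the batch (2N - 2 of them per row) acts as a negative, which is why SimCLR tends to benefit from large batch sizes.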

Or use pre-trained weights

from pl_bolts.models.self_supervised import SimCLR

weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/simclr-cifar10-v1-exp12_87_52/epoch%3D960.ckpt'
simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)

simclr.freeze()
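
A frozen SimCLR model is typically used as a fixed feature extractor with a small trainable classifier on top (a linear probe). The sketch below shows that pattern in plain PyTorch. The freeze helper mirrors what LightningModule.freeze() does (turn off gradients, switch to eval mode). The toy encoder is a hypothetical stand-in: in Bolts the pretrained network lives at simclr.encoder (the parameter names in the error below start with encoder.), but its exact output format may vary by Bolts version, so check it before wiring it in. The names LinearProbe and feat_dim are illustrative assumptions.

```python
import torch
import torch.nn as nn


def freeze(module: nn.Module) -> nn.Module:
    """Plain-PyTorch equivalent of LightningModule.freeze(): no grads, eval mode."""
    for p in module.parameters():
        p.requires_grad = False
    return module.eval()


class LinearProbe(nn.Module):
    """Hypothetical linear probe: frozen encoder followed by a trainable linear head."""

    def __init__(self, encoder: nn.Module, feat_dim: int, num_classes: int):
        super().__init__()
        self.encoder = freeze(encoder)
        self.head = nn.Linear(feat_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():  # features stay fixed; only self.head is trained
            feats = self.encoder(x)
        return self.head(feats)


# Hypothetical stand-in for the pretrained encoder (a ResNet-50 would give 2048-d features).
encoder = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 64))
probe = LinearProbe(encoder, feat_dim=64, num_classes=10)
optimizer = torch.optim.Adam(probe.head.parameters(), lr=1e-3)

# One training step on a random CIFAR-sized batch.
x, y = torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,))
loss = nn.functional.cross_entropy(probe(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Note that calling probe.train() would also put the encoder back into train mode, so with a real encoder that has batch-norm layers you would call freeze() on it again afterwards.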

Running the pre-trained example from your link exactly as given, I get this error:

RuntimeError Traceback (most recent call last)
in
2
3 weight_path = 'https://pl-bolts-weights.s3.us-east-2.amazonaws.com/simclr/simclr-cifar10-v1-exp12_87_52/epoch%3D960.ckpt'
----> 4 simclr = SimCLR.load_from_checkpoint(weight_path, strict=False)
5
6 simclr.freeze()

~/.local/lib/python3.6/site-packages/pytorch_lightning/core/saving.py in load_from_checkpoint(cls, checkpoint_path, map_location, hparams_file, strict, **kwargs)
155 checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
156
--> 157 model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
158 return model
159

~/.local/lib/python3.6/site-packages/pytorch_lightning/core/saving.py in _load_model_state(cls, checkpoint, strict, **cls_kwargs_new)
203
204 # load the state_dict on the model automatically
--> 205 model.load_state_dict(checkpoint['state_dict'], strict=strict)
206
207 return model

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
1043 if len(error_msgs) > 0:
1044 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1045 self.__class__.__name__, "\n\t".join(error_msgs)))
1046 return _IncompatibleKeys(missing_keys, unexpected_keys)
1047

RuntimeError: Error(s) in loading state_dict for SimCLR:
size mismatch for encoder.conv1.weight: copying a param with shape torch.Size([64, 3, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 3, 7, 7]).
size mismatch for projection.model.3.weight: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([128, 2048]).
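
Two things stand out in that error. First, strict=False cannot help here: in PyTorch's load_state_dict it only stops missing and unexpected keys from raising, while shape mismatches are always reported as errors, which is exactly what the traceback shows. Second, the checkpoint has a 3x3 first convolution (common for 32x32 CIFAR inputs) where the model expects a 7x7 ImageNet-style stem, and a 1-D tensor where the model expects a 128x2048 linear weight; that suggests, though the error does not say so, that the checkpoint was saved from a differently configured or differently versioned SimCLR than the one installed. A quick way to list every incompatibility before loading is to compare shapes directly. The helper compare_state_shapes below is a hypothetical name; it is a minimal sketch that takes name-to-shape dicts.

```python
from typing import Dict, List, Mapping, Sequence


def compare_state_shapes(
    model_shapes: Mapping[str, Sequence[int]],
    ckpt_shapes: Mapping[str, Sequence[int]],
) -> Dict[str, List]:
    """Hypothetical helper: report how a checkpoint's parameters line up with a model's.

    Both arguments map parameter names to shapes, e.g. built with
    {k: tuple(v.shape) for k, v in state_dict.items()}.
    """
    shared = model_shapes.keys() & ckpt_shapes.keys()
    mismatched = [
        (key, tuple(ckpt_shapes[key]), tuple(model_shapes[key]))
        for key in sorted(shared)
        if tuple(model_shapes[key]) != tuple(ckpt_shapes[key])
    ]
    return {
        # Same name, different shape: these raise even with strict=False.
        "mismatched": mismatched,
        # In the model but absent from the checkpoint (tolerated with strict=False).
        "missing": sorted(model_shapes.keys() - ckpt_shapes.keys()),
        # In the checkpoint but unknown to the model (tolerated with strict=False).
        "unexpected": sorted(ckpt_shapes.keys() - model_shapes.keys()),
    }


# Shapes copied from the error message above.
report = compare_state_shapes(
    model_shapes={"encoder.conv1.weight": (64, 3, 7, 7), "projection.model.3.weight": (128, 2048)},
    ckpt_shapes={"encoder.conv1.weight": (64, 3, 3, 3), "projection.model.3.weight": (2048,)},
)
for key, ckpt_shape, model_shape in report["mismatched"]:
    print(f"{key}: checkpoint {ckpt_shape} vs model {model_shape}")

# Hypothetical real usage with a downloaded copy of the checkpoint:
#   import torch
#   ckpt = torch.load("epoch=960.ckpt", map_location="cpu")
#   report = compare_state_shapes(
#       {k: tuple(v.shape) for k, v in model.state_dict().items()},
#       {k: tuple(v.shape) for k, v in ckpt["state_dict"].items()},
#   )
```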

I tried other weight links, but then I get an error that the LARS optimizer is not available.

Is there any working code?