Hello,
I am trying to implement the DINO self-supervised model in PyTorch Lightning (I haven't seen any public implementation yet).
I have two models, a teacher and a student. In native PyTorch, the authors backward the loss like this:
```python
# student update
optimizer.zero_grad()
param_norms = None
loss.backward()
if args.clip_grad:
    param_norms = utils.clip_gradients(student, args.clip_grad)
utils.cancel_gradients_last_layer(epoch, student,
                                  args.freeze_last_layer)
optimizer.step()

# EMA update for the teacher
with torch.no_grad():
    m = momentum_schedule[it]  # momentum parameter
    for param_q, param_k in zip(student.module.parameters(), teacher_without_ddp.parameters()):
        param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
```
How can I do this in Lightning?