Hi @goku, apologies for the late reply. Here is the expected workflow:
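model.train()
torch.set_grad_enabled(True)
# zip stops at the shorter loader, so one pass covers min(len(labeled), len(unlabeled)) batches
train_loader = zip(labeled_trainloader, unlabeled_trainloader)
for batch_idx, (data_x, data_u) in enumerate(train_loader):
    inputs_x, targets_x = data_x
    # each unlabeled batch yields two views of the same samples; its label is ignored
    (inputs_u_w, inputs_u_s), _ = data_u
    # one forward pass over the labeled batch and both unlabeled views
    inputs = torch.cat((inputs_x, inputs_u_w, inputs_u_s))
    logits = model(inputs)
    Lx = (calculate labeled loss using targets_x and logits)
    Lu = (calculate unlabeled loss using some targets_u and logits)
    loss = Lx + Lu
    optimizer.zero_grad()  # clear gradients left over from the previous step
    loss.backward()
    optimizer.step()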
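
In case it helps, here is one possible way to fill in the Lx and Lu placeholders. This is only a sketch and rests on assumptions that are not stated above: I am assuming the `_w` / `_s` suffixes mean weakly and strongly augmented views and that you want FixMatch-style pseudo-labels. The helper name `semi_supervised_losses` and the `threshold` parameter are made up for illustration, so adapt them to your setup.

```python
import torch
import torch.nn.functional as F


def semi_supervised_losses(logits, targets_x, batch_size, threshold=0.95):
    """Hypothetical helper: FixMatch-style labeled and unlabeled losses.

    Assumes `logits` came from torch.cat((inputs_x, inputs_u_w, inputs_u_s))
    and that the two unlabeled views have the same batch size.
    """
    # Undo the concatenation: labeled rows first, then the two unlabeled views.
    logits_x = logits[:batch_size]
    logits_u_w, logits_u_s = logits[batch_size:].chunk(2)

    # Labeled loss: plain cross-entropy against the ground-truth labels.
    Lx = F.cross_entropy(logits_x, targets_x)

    # Pseudo-labels from the first view, detached so no gradient flows through them.
    probs_u_w = torch.softmax(logits_u_w.detach(), dim=-1)
    max_probs, targets_u = probs_u_w.max(dim=-1)
    mask = (max_probs >= threshold).float()

    # Unlabeled loss: cross-entropy on the second view, kept only where confident.
    Lu = (F.cross_entropy(logits_u_s, targets_u, reduction="none") * mask).mean()
    return Lx, Lu


# Toy check with random logits: 2 labeled rows, 4 + 4 unlabeled rows, 3 classes.
torch.manual_seed(0)
logits = torch.randn(2 + 4 + 4, 3, requires_grad=True)
targets_x = torch.tensor([0, 2])
Lx, Lu = semi_supervised_losses(logits, targets_x, batch_size=2, threshold=0.5)
loss = Lx + Lu
loss.backward()
```

Inside the training loop you would pass `inputs_x.shape[0]` as `batch_size`. Running a single forward pass on the concatenated batch, as in the workflow above, also means any batch-norm statistics are computed over labeled and unlabeled samples together.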