Gradient computation in meta-learning: the computation graph is breaking in the outer loop. First Order MAML
I am relying on this pytorch lightning tutorial https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial16/Meta_Learning.html
In outer_loop()
p_global
is None
(all param.grad
in self.model.parameters()
are None
, while those of local_model_parameters()
are not). As a consequence p_global.grad += p_local.grad
can’t be achieved
for p_global, p_local in zip(self.model.parameters(), local_model.parameters()):
p_global.grad += p_local.grad # First-order approx. -> add gradients of finetuned and base model
It could be related to
local_model = deepcopy(self.model)
In adapt_few_shot()
but deepcopy()
is needed, or may be create_graph=True
is needed ?
I tried the following but l am not sure that it’s true
for p_global, p_local in zip(self.model.parameters(), local_model.parameters()):
if p_global.grad is None:
p_global.grad = torch.zeros_like(p_local.grad)
p_global.grad += p_local.grad
Doing this leads to :
p_global.grad = p_local.grad
instead of p_global.grad += p_local.grad
since after each gradient update, opt.step(), opt.zero_grad()
every p_global.grad
in self.model.parameters()
becomes None
Code:
def adapt_few_shot(self, support_imgs, support_targets):
# Determine prototype initialization
support_feats = self.model(support_imgs)
prototypes, classes = ProtoNet.calculate_prototypes(support_feats, support_targets)
support_labels = (classes[None,:] == support_targets[:,None]).long().argmax(dim=-1)
# Create inner-loop model and optimizer
local_model = deepcopy(self.model)
local_model.train()
local_optim = optim.SGD(local_model.parameters(), lr=self.hparams.lr_inner)
local_optim.zero_grad()
# Create output layer weights with prototype-based initialization
init_weight = 2 * prototypes
init_bias = -torch.norm(prototypes, dim=1)**2
output_weight = init_weight.detach().requires_grad_()
output_bias = init_bias.detach().requires_grad_()
# Optimize inner loop model on support set
for _ in range(self.hparams.num_inner_steps):
# Determine loss on the support set
loss, _, _ = self.run_model(local_model, output_weight, output_bias, support_imgs, support_labels)
# Calculate gradients and perform inner loop update
loss.backward()
local_optim.step()
# Update output layer via SGD
# (https://discuss.pytorch.org/t/the-difference-between-torch-tensor-data-and-torch-tensor/25995/4):
with torch.no_grad():
output_weight.copy_(output_weight - self.hparams.lr_output * output_weight.grad)
output_bias.copy_(output_bias - self.hparams.lr_output * output_bias.grad)
# Reset gradients
local_optim.zero_grad()
output_weight.grad.fill_(0)
output_bias.grad.fill_(0)
# Re-attach computation graph of prototypes
output_weight = (output_weight - init_weight).detach() + init_weight
output_bias = (output_bias - init_bias).detach() + init_bias
return local_model, output_weight, output_bias, classes
def outer_loop(self, batch, mode="train"):
accuracies = []
losses = []
self.model.zero_grad()
# Determine gradients for batch of tasks
for task_batch in batch:
imgs, targets = task_batch
support_imgs, query_imgs, support_targets, query_targets = split_batch(imgs, targets)
# Perform inner loop adaptation
local_model, output_weight, output_bias, classes = self.adapt_few_shot(support_imgs, support_targets)
# Determine loss of query set
query_labels = (classes[None,:] == query_targets[:,None]).long().argmax(dim=-1)
loss, preds, acc = self.run_model(local_model, output_weight, output_bias, query_imgs, query_labels)
# Calculate gradients for query set loss
if mode == "train":
loss.backward()
for p_global, p_local in zip(self.model.parameters(), local_model.parameters()):
p_global.grad += p_local.grad # First-order approx. -> add gradients of finetuned and base model
accuracies.append(acc.mean().detach())
losses.append(loss.detach())
# Perform update of base model
if mode == "train":
opt = self.optimizers()
opt.step()
opt.zero_grad()
self.log(f"{mode}_loss", sum(losses) / len(losses))
self.log(f"{mode}_acc", sum(accuracies) / len(accuracies))
def training_step(self, batch, batch_idx):
self.outer_loop(batch, mode="train")
return None # Returning None means we skip the default training optimizer steps by PyTorch Lightning
I have ckecked the following discussions: