Using multiple optimizers inside a loop with Lightning

I am working on a big project for which I need to call manual_backward and optimizer.step inside a loop for every batch.
Here is some reference code for a loss function that works, and another that doesn't:

# imports assumed by the snippets below
from typing import Any

import torch
import torch.nn.functional as F

def loss_fn_working(self, batch: Any, batch_idx: int):
    env = self.envs[self.p]
    actions = None
    prev_log_rewards = torch.empty(env.done.shape[0]).type_as(env.state)
    prev_forward_logprob = None
    loss = torch.tensor(0.0, requires_grad=True)
    TERM = env.terminal_index

    while not torch.all(env.done):
        active = ~env.done
        forward_logprob, back_logprob = self.forward(env)

        log_rewards = -self.get_rewards()

        if actions is not None:
            error = log_rewards - prev_log_rewards[active]
            error += back_logprob.gather(1, actions[actions != TERM, None]).squeeze(1)
            error += prev_forward_logprob[active, -1]
            error -= forward_logprob[:, -1].detach()
            error -= (
                prev_forward_logprob[active]
                .gather(1, actions[actions != TERM, None])
                .squeeze(1)
            )
            loss = loss + F.huber_loss(
                error,
                torch.zeros_like(error),
                delta=1.0,
                reduction="none",
            )
            loss = loss * log_rewards.softmax(0)
            loss = loss.mean(0)

        actions = self.sample_actions(forward_logprob, active, TERM)
        env.step(actions)

        # save previous log-probs and log-rewards
        if prev_forward_logprob is None:
            prev_forward_logprob = torch.empty_like(forward_logprob)
        prev_forward_logprob[active] = forward_logprob
        prev_log_rewards[active] = log_rewards

    return loss, log_rewards


def calculate_loss(
    self,
    loss,
    log_rewards,
    prev_log_rewards,
    back_log_prob,
    actions,
    stop_prob,
    prefix,  # Added for debugging
    prev_stop_prob=None,
    prev_forward_log_prob=None,
):
    error = torch.tensor(0.0, requires_grad=True) + log_rewards - prev_log_rewards  # [B]
    error = error + back_log_prob.gather(1, actions.unsqueeze(1)).squeeze(1)  # P_B(s|s')
    error = error - stop_prob.detach()  # P(s_f|s')
    if prev_stop_prob is not None and prev_forward_log_prob is not None:
        error = error + prev_stop_prob.detach()  # P(s_f|s)
        error = error - prev_forward_log_prob.gather(
            1, actions.unsqueeze(1)
        ).squeeze(1)

    loss = loss + F.huber_loss(  # accumulate losses
        error,
        torch.zeros_like(error),
        delta=1.0,
        reduction="none",
    )

    loss = loss * log_rewards.softmax(0)
    return loss.mean(0)

def loss_fn_not_working(self, batch, batch_size, prefix, batch_idx):
    gfn_opt, rep_opt = self.optimizers()
    # some code here (omitted: initializes counter, next_id, indices, prev_log_rewards, etc.)
    losses = []
    rep_losses = []
    prev_forward_log_prob = None
    prev_stop_prob = torch.zeros(batch_size, device='cuda')
    loss = torch.tensor(0.0, requires_grad=True, device='cuda')
    active = torch.ones((batch_size,), dtype=bool, device='cuda')
    graph = torch.diag_embed(torch.ones(batch_size, self.n_dim)).cuda()
    while active.any():
        graph_hat = graph[active].clone()
        adj_mat = graph_hat.clone()
        rep_loss, latent_var = self.rep_model(
            torch.cat((adj_mat, next_id.unsqueeze(-1)), dim=-1)
        )
        rep_loss_tensor = torch.tensor(0.0, requires_grad=True) + rep_loss

        forward_log_prob, Fs_masked, back_log_prob, next_prob, stop_prob = (
            self.gfn_model(latent_var)
        )
        with torch.no_grad():
            actions = self.sample_actions(Fs_masked)
            graph = self.update_graph(actions)
            #######################

            log_rewards = -self.energy_model(graph_hat, batch, False, self.current_epoch)

        if counter == 0:
            loss = self.calculate_loss(
                loss, log_rewards, prev_log_rewards, back_log_prob, actions, stop_prob, prefix
            )
        else:
            loss = self.calculate_loss(
                loss, log_rewards, prev_log_rewards, back_log_prob, actions, stop_prob, prefix,
                prev_stop_prob[active], prev_forward_log_prob[active],
            )
        losses.append(loss.item())
        rep_losses.append(rep_loss.item())
        if prefix == 'train':
            rep_opt.zero_grad()
            self.manual_backward(rep_loss_tensor, retain_graph=True)
            rep_opt.step()
            gfn_opt.zero_grad()
            self.manual_backward(loss)
            self.clip_gradients(gfn_opt, gradient_clip_val=0.5, gradient_clip_algorithm="norm") # NEEDED??
            gfn_opt.step()

        with torch.no_grad():
            active[indices_to_deactivate] = False  # active updated appropriately
            indices = indices[~current_stop]
            # active_indices = ~current_stop # Not being used?
            next_id = F.one_hot(indices, num_classes=self.n_dim)
            prev_log_rewards = log_rewards[~current_stop]
            counter += 1
        if prev_forward_log_prob is None:
            prev_forward_log_prob = torch.empty_like(forward_log_prob)
        prev_forward_log_prob[active] = forward_log_prob[~current_stop]
        prev_stop_prob[active] = stop_prob[~current_stop]

    return losses, graph, log_rewards, counter, rep_losses

Here, the main variable of interest is prev_forward_log_prob in loss_fn_not_working. The loss is accumulated across loop iterations by the calculate_loss() function above.
I am using manual optimization (self.automatic_optimization = False).
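For completeness, the surrounding LightningModule setup I am assuming is roughly the following (a minimal sketch; the Linear layers, Adam optimizers and learning rates are placeholders, not my real models):

import torch
import pytorch_lightning as pl

class SketchModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # manual optimization
        self.gfn_model = torch.nn.Linear(8, 8)  # placeholder
        self.rep_model = torch.nn.Linear(8, 8)  # placeholder

    def configure_optimizers(self):
        # two optimizers; self.optimizers() hands them back in this order
        gfn_opt = torch.optim.Adam(self.gfn_model.parameters(), lr=1e-3)
        rep_opt = torch.optim.Adam(self.rep_model.parameters(), lr=1e-3)
        return gfn_opt, rep_opt

    def training_step(self, batch, batch_idx):
        # in my real code, loss_fn_not_working is called from here and performs
        # manual_backward / zero_grad / step for both optimizers inside its while-loop
        gfn_opt, rep_opt = self.optimizers()
        ...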

When I use loss_fn_not_working and keep retain_graph=False for loss, I get the following error:

Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
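
For context, this message can be reproduced in a small standalone snippet (illustrative only, not my actual code) when a running loss that still references a previous iteration's graph is backpropagated again:

import torch

w = torch.nn.Parameter(torch.randn(3))
opt = torch.optim.SGD([w], lr=0.1)

loss = torch.tensor(0.0, requires_grad=True)
for step in range(2):
    out = (w * torch.randn(3)).sum()
    loss = loss + out   # running loss keeps the graph from the previous step alive
    opt.zero_grad()
    loss.backward()     # second iteration raises "Trying to backward through the graph a second time ..."
    opt.step()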

If I instead keep retain_graph=True for loss (i.e. the loss for the second optimizer), I get the following error:

one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [64, 10]], which is output 0 of AsStridedBackward0, is at version 3; expected version 1 instead. 
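
The same class of in-place error can also be reproduced standalone (again illustrative, not my actual code) when a buffer that a retained graph still needs is overwritten in place before backward runs through that graph again:

import torch

w = torch.nn.Parameter(torch.randn(64, 10))
prev = torch.empty(64, 10)          # plays the role of prev_forward_log_prob

out = w.exp()
prev[:] = out                       # in-place write into the buffer
loss = (prev ** 2).sum()            # pow saves `prev` for its backward pass
loss.backward(retain_graph=True)    # first backward succeeds

prev[:] = out.detach() + 1.0        # buffer overwritten in place on the "next iteration"
loss.backward()                     # RuntimeError: one of the variables needed for gradient
                                    # computation has been modified by an inplace operation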

If I use loss_fn_working, there is no problem, so I understand that the issue arises from calling backward inside the loop. As far as I can tell, I am not making any in-place operations myself, so how can I make the second loss_fn work? I tried cloning the relevant variables, but that only works if I also detach them, which I cannot do.
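
For concreteness, the cloning I tried looked roughly like this (a simplified sketch, not the exact code), replacing the in-place masked writes at the bottom of the loop with out-of-place copies:

# simplified sketch of the cloning attempt (same variable names as in loss_fn_not_working)
if prev_forward_log_prob is None:
    prev_forward_log_prob = torch.empty_like(forward_log_prob)
new_prev = prev_forward_log_prob.clone()
new_prev[active] = forward_log_prob[~current_stop]   # still errors unless forward_log_prob is detached here
prev_forward_log_prob = new_prev

new_stop = prev_stop_prob.clone()
new_stop[active] = stop_prob[~current_stop]
prev_stop_prob = new_stop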