Does anyone know why this code gets stuck at epoch 3 when training with DDP? Both files are pasted below.
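
A way to see which rank and which collective is blocking when it hangs (these are standard PyTorch/NCCL debug switches, nothing specific to this code; substitute your own script name):

    TORCH_DISTRIBUTED_DEBUG=DETAIL NCCL_DEBUG=INFO python <your_train_script>.py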

import os
from pathlib import Path
from random import randint
import datetime

import torch
import wandb
import randomname
from pytorch_lightning.strategies import DDPStrategy

from util.distinct_colors import DistinctColors
from util.misc import visualize_depth, probability_to_normalized_entropy, get_boundary_mask
from util.warmup_scheduler import GradualWarmupScheduler
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger

from util.filesystem_logger import FilesystemLogger


def generate_experiment_name(name, config):
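    # the generated name is cached in an env var so that worker processes
    # spawned later (e.g. by DDP) resolve to the same experiment directory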
    if config.resume is not None:
        experiment = Path(config.resume).parents[1].name
        os.environ['experiment'] = experiment
    elif not os.environ.get('experiment'):
        experiment = f"{datetime.datetime.now().strftime('%m%d%H%M')}_{name}_{config.experiment}_{randomname.get_name()}"
        os.environ['experiment'] = experiment
    else:
        experiment = os.environ['experiment']
    return experiment


def create_trainer(name, config):
    if not config.wandb_main and config.suffix == '':
        config.suffix = '-dev'
    config.experiment = generate_experiment_name(name, config)
    if config.val_check_interval > 1:
        config.val_check_interval = int(config.val_check_interval)
    if config.seed is None:
        config.seed = randint(0, 999)
    if isinstance(config.image_dim, int):
        config.image_dim = [config.image_dim, config.image_dim]
    assert config.image_dim[0] == config.image_dim[1], "only 1:1 supported"  # TODO: fix dataprocessing bug limiting this

    seed_everything(config.seed)

    # noinspection PyUnusedLocal
    filesystem_logger = FilesystemLogger(config)

    # select the experiment logger (wandb or tensorboard)
    if config.logger == 'wandb':
        logger = WandbLogger(project=f'{name}{config.suffix}', name=config.experiment, id=config.experiment, settings=wandb.Settings(start_method='thread'))
    else:
        logger = TensorBoardLogger(name='tb', save_dir=(Path("runs") / config.experiment))

    checkpoint_callback = ModelCheckpoint(dirpath=(Path("runs") / config.experiment / "checkpoints"),
                                          save_top_k=-1,
                                          verbose=False,
                                          every_n_epochs=config.save_epoch)
    gpu_count = torch.cuda.device_count()

    if gpu_count > 1:
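        # find_unused_parameters=True lets DDP tolerate parameters that receive
        # no gradient in a step, but it does not protect against ranks running
        # a different number of backward passes (see the epoch-gated branches
        # in training_step of the training module below)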
        trainer = Trainer(accelerator='gpu',
                          strategy=DDPStrategy(find_unused_parameters=True),
                          num_nodes=1,
                          devices=gpu_count,
                          num_sanity_val_steps=config.sanity_steps,
                          max_epochs=config.max_epoch,
                          limit_val_batches=config.val_check_percent,
                          callbacks=[checkpoint_callback],
                          val_check_interval=float(min(config.val_check_interval, 1)),
                          check_val_every_n_epoch=max(1, config.val_check_interval),
                          resume_from_checkpoint=config.resume,
                          logger=logger,
                          benchmark=True)
    else:
        # fall back to a single-device trainer so `trainer` is always defined
        # (settings assumed to mirror the DDP branch)
        trainer = Trainer(accelerator='gpu',
                          devices=1,
                          num_sanity_val_steps=config.sanity_steps,
                          max_epochs=config.max_epoch,
                          limit_val_batches=config.val_check_percent,
                          callbacks=[checkpoint_callback],
                          val_check_interval=float(min(config.val_check_interval, 1)),
                          check_val_every_n_epoch=max(1, config.val_check_interval),
                          resume_from_checkpoint=config.resume,
                          logger=logger,
                          benchmark=True)
    return trainer


def step(opt, modules):
    for module in modules:
        for param in module.parameters():
            if param.grad is not None:
                torch.nan_to_num(param.grad, nan=0, posinf=1e5, neginf=-1e5, out=param.grad)
    opt.step()


def get_optimizer_and_scheduler(params, config, betas=None):
    opt = torch.optim.Adam(params, lr=config.lr, weight_decay=config.weight_decay, betas=betas)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=config.decay_step, gamma=config.decay_gamma)
    if config.warmup_epochs > 0:
        scheduler = GradualWarmupScheduler(opt, multiplier=config.warmup_multiplier, total_epoch=config.warmup_epochs, after_scheduler=scheduler)
    return opt, scheduler


def visualize_panoptic_outputs(p_rgb, p_semantics, p_instances, p_depth, rgb, semantics, instances, H, W, thing_classes, visualize_entropy=True):
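    # builds a stack of [rgb, semantics, instances, depth, entropy] panels
    # (preceded by a ground-truth row when rgb/semantics/instances are given),
    # laid out for make_grid/save_image with nrow=5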
    alpha = 0.65
    distinct_colors = DistinctColors()
    img = p_rgb.view(H, W, 3).cpu()
    img = torch.clamp(img, 0, 1).permute(2, 0, 1)
    if visualize_entropy:
        img_sem_entropy = visualize_depth(probability_to_normalized_entropy(torch.nn.functional.softmax(p_semantics, dim=-1)).reshape(H, W), minval=0.0, maxval=1.00, use_global_norm=True)
    else:
        img_sem_entropy = torch.zeros_like(img)
    if p_depth is not None:
        depth = visualize_depth(p_depth.view(H, W))
    else:
        depth = torch.zeros_like(img)
    if len(p_instances.shape) > 1 and len(p_semantics.shape) > 1:
        p_instances = p_instances.argmax(dim=1)
        p_semantics = p_semantics.argmax(dim=1)
    img_semantics = distinct_colors.apply_colors_fast_torch(p_semantics.cpu()).view(H, W, 3).permute(2, 0, 1) * alpha + img * (1 - alpha)
    boundaries_img_semantics = get_boundary_mask(p_semantics.cpu().view(H, W))
    img_semantics[:, boundaries_img_semantics > 0] = 0
    colored_img_instance = distinct_colors.apply_colors_fast_torch(p_instances.cpu()).float()
    boundaries_img_instances = get_boundary_mask(p_instances.cpu().view(H, W))
    colored_img_instance[boundaries_img_instances.reshape(-1) > 0, :] = 0
    # keep the mask on CPU: colored_img_instance lives on CPU, and indexing a
    # CPU tensor with a CUDA boolean mask fails on recent PyTorch versions
    thing_mask = torch.logical_not(sum(p_semantics.cpu() == s for s in thing_classes).bool())
    colored_img_instance[thing_mask, :] = p_rgb.cpu()[thing_mask, :]
    img_instances = colored_img_instance.view(H, W, 3).permute(2, 0, 1) * alpha + img * (1 - alpha)
    if rgb is not None and semantics is not None and instances is not None:
        img_gt = rgb.view(H, W, 3).permute(2, 0, 1).cpu()
        img_semantics_gt = distinct_colors.apply_colors_fast_torch(semantics.cpu()).view(H, W, 3).permute(2, 0, 1) * alpha + img_gt * (1 - alpha)
        boundaries_img_semantics_gt = get_boundary_mask(semantics.cpu().view(H, W))
        img_semantics_gt[:, boundaries_img_semantics_gt > 0] = 0
        colored_img_instance_gt = distinct_colors.apply_colors_fast_torch(instances.cpu()).float()
        boundaries_img_instances_gt = get_boundary_mask(instances.cpu().view(H, W))
        stuff_mask_gt = instances.cpu() == 0  # CPU mask for the CPU tensors below
        colored_img_instance_gt[stuff_mask_gt, :] = rgb.cpu()[stuff_mask_gt, :]
        img_instances_gt = colored_img_instance_gt.view(H, W, 3).permute(2, 0, 1) * alpha + img_gt * (1 - alpha)
        img_instances_gt[:, boundaries_img_instances_gt > 0] = 0
        stack = torch.cat([torch.stack([img_gt, img_semantics_gt, img_instances_gt, torch.zeros_like(img_gt), torch.zeros_like(img_gt)]), torch.stack([img, img_semantics, img_instances, depth, img_sem_entropy])], dim=0)
    else:
        stack = torch.stack([img, img_semantics, img_instances, depth, img_sem_entropy])
    return stack
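
The second file is the training script itself; it imports create_trainer, get_optimizer_and_scheduler and visualize_panoptic_outputs from the trainer module above: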
# Copyright (c) Meta Platforms, Inc. All Rights Reserved

import math

import numpy as np
import scipy.optimize  # explicit submodule import; `import scipy` alone may not expose scipy.optimize
import torch.multiprocessing
from pathlib import Path
import torch
import hydra
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only
from tabulate import tabulate
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader
from torch_scatter import scatter_mean

from dataset import get_dataset, get_inconsistent_single_dataset, get_segment_dataset
from model.loss.loss import TVLoss, get_semantic_weights, SCELoss
from model.radiance_field.tensoRF import TensorVMSplit
from model.renderer.panopli_tensoRF_renderer import TensoRFRenderer
from util.distinct_colors import DistinctColors
from trainer import create_trainer, get_optimizer_and_scheduler, visualize_panoptic_outputs
from util.metrics import psnr, ConfusionMatrix
from util.panoptic_quality import panoptic_quality

torch.multiprocessing.set_sharing_strategy('file_system')  # possible fix for the "OSError: too many files" exception

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.deterministic = False


class TensoRFTrainer(pl.LightningModule):

    def __init__(self, config):
        super().__init__()
        self.train_set, self.val_set = get_dataset(config)
        if config.visualized_indices is None:
            config.visualized_indices = list(range(0, len(self.val_set), int(1 / 0.15)))
            if len(config.visualized_indices) < 16:
                config.visualized_indices = list(range(0, len(self.val_set)))[:16]
        config.instance_optimization_epoch = config.instance_optimization_epoch + config.late_semantic_optimization
        config.segment_optimization_epoch = config.segment_optimization_epoch + config.late_semantic_optimization
        self.config = config
        self.current_lambda_dist_reg = 0
        if self.config.segment_grouping_mode != "none":
            self.train_segment_set = get_segment_dataset(self.config)
        self.save_hyperparameters(config)
        total_classes = len(self.train_set.segmentation_data.bg_classes) + len(self.train_set.segmentation_data.fg_classes)
        output_mlp_semantics = torch.nn.Identity() if self.config.semantic_weight_mode != "softmax" else torch.nn.Softmax(dim=-1)
        self.model = TensorVMSplit([config.min_grid_dim, config.min_grid_dim, config.min_grid_dim], num_semantics_comps=(32, 32, 32),
                                   num_semantic_classes=total_classes, dim_feature_instance=self.config.max_instances,
                                   output_mlp_semantics=output_mlp_semantics, use_semantic_mlp=self.config.use_mlp_for_semantics,
                                   use_feature_reg=self.config.use_feature_regularization)
        self.renderer = TensoRFRenderer(self.train_set.scene_bounds, [config.min_grid_dim, config.min_grid_dim, config.min_grid_dim],
                                        semantic_weight_mode=self.config.semantic_weight_mode, stop_semantic_grad=config.stop_semantic_grad)
        semantic_weights = get_semantic_weights(config.reweight_fg, self.train_set.segmentation_data.fg_classes, self.train_set.segmentation_data.num_semantic_classes)
        semantic_weights[0] = config.weight_class_0
        self.loss = torch.nn.MSELoss(reduction='mean')
        self.loss_feat = torch.nn.L1Loss(reduction='mean')
        self.tv_regularizer = TVLoss()
        if not self.config.use_symmetric_ce:
            self.loss_semantics = torch.nn.CrossEntropyLoss(reduction='none', weight=semantic_weights)
        else:
            self.loss_semantics = SCELoss(self.config.ce_alpha, self.config.ce_beta, semantic_weights)
        self.loss_instances_cluster = torch.nn.CrossEntropyLoss(reduction='none')
        self.output_dir_result_images = Path(f'runs/{self.config.experiment}/images')
        self.output_dir_result_images.mkdir(exist_ok=True)
        self.output_dir_result_clusters = Path(f'runs/{self.config.experiment}/instance_clusters')
        self.output_dir_result_clusters.mkdir(exist_ok=True)
        self.automatic_optimization = False
        self.distinct_colors = DistinctColors()

    def configure_optimizers(self):
        params = self.model.get_optimizable_parameters(self.config.lr * 20, self.config.lr, weight_decay=self.config.weight_decay)
        optimizer, scheduler = get_optimizer_and_scheduler(params, self.config, betas=(0.9, 0.99))
        param_instance = self.model.get_optimizable_instance_parameters(self.config.lr * 20, self.config.lr)
        optimizer_instance, scheduler_instance = get_optimizer_and_scheduler(param_instance, self.config, betas=(0.9, 0.999))
        return [optimizer, optimizer_instance], [scheduler, scheduler_instance]

    def forward(self, rays, is_train):
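        # render the rays in chunks of config.chunk to bound peak GPU memory,
        # then concatenate the per-chunk outputs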
        B = rays.shape[0]
        out_rgb, out_semantics, out_instances, out_depth, out_regfeat, out_dist_regularizer = [], [], [], [], [], []
        for i in range(0, B, self.config.chunk):
            out_rgb_, out_semantics_, out_instances_, out_depth_, out_regfeat_, out_dist_reg_ =\
                self.renderer(self.model, rays[i: i + self.config.chunk], self.config.perturb, self.train_set.white_bg, is_train)
            out_rgb.append(out_rgb_)
            out_semantics.append(out_semantics_)
            out_instances.append(out_instances_)
            out_regfeat.append(out_regfeat_)
            out_depth.append(out_depth_)
            out_dist_regularizer.append(out_dist_reg_.unsqueeze(0))
        out_rgb = torch.cat(out_rgb, 0)
        out_instances = torch.cat(out_instances, 0)
        out_depth = torch.cat(out_depth, 0)
        out_semantics = torch.cat(out_semantics, 0)
        out_regfeat = torch.cat(out_regfeat, 0)
        out_dist_regularizer = torch.mean(torch.cat(out_dist_regularizer, 0))
        return out_rgb, out_semantics, out_instances, out_depth, out_regfeat, out_dist_regularizer

    def forward_instance(self, rays, is_train):
        B = rays.shape[0]
        out_feats_instance = []
        for i in range(0, B, self.config.chunk):
            batch_rays = rays[i: i + self.config.chunk]
            out_feats_instance_ = self.renderer.forward_instance_feature(self.model, batch_rays, self.config.perturb, is_train)
            out_feats_instance.append(out_feats_instance_)
        out_feats_instance = torch.cat(out_feats_instance, 0)
        return out_feats_instance

    def forward_segments(self, rays, is_train):
        B = rays.shape[0]
        out_feats_segments = []
        for i in range(0, B, self.config.chunk_segment):
            batch_rays = rays[i: i + self.config.chunk_segment]
            out_feats_segments_ = self.renderer.forward_segment_feature(self.model, batch_rays, self.config.perturb, is_train)
            out_feats_segments.append(out_feats_segments_)
        out_feats_segments = torch.cat(out_feats_segments, 0)
        return out_feats_segments

    def training_step(self, batch, batch_idx):
        opts = self.optimizers()
        opts[0].zero_grad(set_to_none=True)
        rays, rgbs, semantics, probs, confs, masks, feats = batch[0]['rays'], batch[0]['rgbs'], batch[0]['semantics'], batch[0]['probabilities'], batch[0]['confidences'], batch[0]['mask'], batch[0]['feats']
        output_rgb, output_semantics, output_instances, _output_depth, output_feats, loss_dist_reg = self(rays, True)
        output_rgb[~masks, :] = 0
        rgbs[~masks, :] = 0
        confs[~masks] = 0
        loss = torch.zeros(1, device=batch[0]['rays'].device, requires_grad=True)
        if self.config.lambda_rgb > 0:
            loss_rgb = self.loss(output_rgb, rgbs)
            loss_tv_density = self.model.tv_loss_density(self.tv_regularizer)
            loss_tv_semantics = self.model.tv_loss_semantics(self.tv_regularizer) * (1 if self.current_epoch >= self.config.late_semantic_optimization else 0)
            loss_tv_appearance = self.model.tv_loss_appearance(self.tv_regularizer)
            loss_tv = loss_tv_density * self.config.lambda_tv_density + loss_tv_appearance * self.config.lambda_tv_appearance + loss_tv_semantics * self.config.lambda_tv_semantics
            loss_feat = torch.zeros_like(loss_tv)
            if self.config.use_feature_regularization:
                loss_feat = self.loss_feat(output_feats, feats).mean()
            loss = self.config.lambda_rgb * (loss_rgb + loss_tv + loss_dist_reg * self.current_lambda_dist_reg + loss_feat * self.config.lambda_feat)
            self.log("train/loss_rgb", loss_rgb, on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True)
            self.log("train/loss_feat", loss_feat, on_step=True, on_epoch=False, prog_bar=False, logger=True, sync_dist=True)
            self.log("train/loss_tv_regularizer", loss_tv_appearance + loss_tv_density, on_step=True, on_epoch=False, prog_bar=False, logger=True, sync_dist=True)
            self.log("train/loss_dist_regularizer", loss_dist_reg + loss_tv_density, on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True)
        if self.config.probabilistic_ce_mode == "TTAConf":
            loss_semantics = (self.loss_semantics(output_semantics, probs) * confs).mean()
        elif self.config.probabilistic_ce_mode == "NoTTAConf":
            loss_semantics = (self.loss_semantics(output_semantics, semantics) * confs).mean()
        else:
            loss_semantics = self.loss_semantics(output_semantics, semantics).mean()
        loss_segment_clustering = torch.zeros(1, device=self.device, requires_grad=True)
        if self.config.segment_grouping_mode != "none" and self.current_epoch >= self.config.segment_optimization_epoch:
            batch[2]["rays"] = torch.cat(batch[2]['rays'], dim=0)
            batch[2]['group'] = torch.cat(batch[2]['group'], dim=0)
            batch[2]['confidences'] = torch.cat(batch[2]['confidences'], dim=0)
            semantic_features = self.forward_segments(batch[2]["rays"], True)
            batch_target_mean = torch.zeros(self.config.batch_size_segments, semantic_features.shape[-1], device=semantic_features.device)
            scatter_mean(semantic_features, batch[2]['group'], 0, batch_target_mean)
            target = batch_target_mean[batch[2]['group'], :].argmax(-1)
            # mask_incorrect = semantic_features.detach().argmax(-1) != target
            # loss_segment_clustering = (self.loss_semantics(semantic_features, target) * batch[2]['confidences'])[mask_incorrect].mean()
            loss_segment_clustering = (self.loss_semantics(semantic_features, target) * batch[2]['confidences']).mean()
            self.log(f"train/loss_segment", loss_segment_clustering, on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True)
        if self.current_epoch >= self.config.late_semantic_optimization:
            loss = loss + self.config.lambda_semantics * loss_semantics + self.config.lambda_semantics * self.config.lambda_segment * loss_segment_clustering
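        # the epoch gates above are safe under DDP because current_epoch is the
        # same on every rank, so all ranks build the same backward graph here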
        self.manual_backward(loss)
        opts[0].step()
        metric_psnr = psnr(output_rgb, rgbs)
        self.log("train/loss_semantics", loss_semantics, on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True)
        self.log("lr", opts[0].param_groups[0]['lr'], on_step=True, on_epoch=False, prog_bar=False, logger=True, sync_dist=True)
        self.log("train/psnr", metric_psnr, on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True)
        output_semantics_without_invalid = output_semantics.detach().argmax(dim=1)
        output_semantics_without_invalid[semantics == 0] = 0
        train_cm = ConfusionMatrix(num_classes=self.model.num_semantic_classes, ignore_class=[0])
        metric_iou = train_cm.add_batch(output_semantics_without_invalid.cpu().numpy(), semantics.cpu().numpy(), return_miou=True)
        self.log("train/iou", metric_iou, on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True)
        if self.current_epoch >= self.config.instance_optimization_epoch:
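            # NOTE: if instance_optimization_epoch resolves to 3 in your config,
            # this is the branch that first activates at the epoch where training
            # hangs. Under DDP every rank must backpropagate through the same
            # parameters each step; see the caveat in
            # calculate_instance_clustering_loss below.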
            opts[1].zero_grad(set_to_none=True)
            loss_instance_clustering = torch.zeros(1, device=self.device, requires_grad=True)
            for img_idx in range(len(batch[1]['rays'])):
                instance_features = self.forward_instance(batch[1]['rays'][img_idx], True)
                loss_instance_clustering_ = self.calculate_instance_clustering_loss(batch[1]['instances'][img_idx], instance_features, batch[1]['confidences'][img_idx])
                loss_instance_clustering = loss_instance_clustering + loss_instance_clustering_
            self.manual_backward(loss_instance_clustering)
            opts[1].step()
            self.log(f"train/loss_clustering", loss_instance_clustering, on_step=True, on_epoch=False, prog_bar=True, logger=True, sync_dist=True)

    def calculate_instance_clustering_loss(self, labels_gt, instance_features, confidences):
        virtual_gt_labels = self.create_virtual_gt_with_linear_assignment(labels_gt, instance_features)
        predicted_labels = instance_features.argmax(dim=-1)
        if torch.any(virtual_gt_labels != predicted_labels):  # should never reinforce correct labels
            return (self.loss_instances_cluster(instance_features, virtual_gt_labels) * confidences).mean()
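        # CAUTION: this constant fallback carries no graph through the model, so
        # under DDP a rank that takes it while another rank backprops through the
        # model leaves the gradient reducer waiting on buckets that never fill --
        # a plausible cause of a deadlock once instance optimization kicks in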
        return torch.tensor(0., device=self.device, requires_grad=True)

    @torch.no_grad()
    def create_virtual_gt_with_linear_assignment(self, labels_gt, predicted_scores):
        labels = sorted(torch.unique(labels_gt).cpu().tolist())[:predicted_scores.shape[-1]]
        predicted_probabilities = torch.softmax(predicted_scores, dim=-1)
        cost_matrix = np.zeros([len(labels), predicted_probabilities.shape[-1]])
        for lidx, label in enumerate(labels):
            cost_matrix[lidx, :] = -(predicted_probabilities[labels_gt == label, :].sum(dim=0) / ((labels_gt == label).sum() + 1e-4)).cpu().numpy()
        assignment = scipy.optimize.linear_sum_assignment(np.nan_to_num(cost_matrix))
        new_labels = torch.zeros_like(labels_gt)
        for aidx, lidx in enumerate(assignment[0]):
            new_labels[labels_gt == labels[lidx]] = assignment[1][aidx]
        return new_labels

    def calculate_segment_clustering_loss(self, sem_features, confidences):
        target = torch.mean(torch.softmax(sem_features, dim=-1), dim=0).detach().unsqueeze(0).expand(sem_features.shape[0], -1)
        if self.config.segment_grouping_mode == "argmax_noconf":
            return self.loss_semantics(sem_features, target.argmax(-1)).mean()
        elif self.config.segment_grouping_mode == "argmax_conf":
            return (self.loss_semantics(sem_features, target.argmax(-1)) * confidences).mean()
        elif self.config.segment_grouping_mode == "prob_noconf":
            return self.loss_semantics(sem_features, target).mean()
        elif self.config.segment_grouping_mode == "prob_conf":
            return (self.loss_semantics(sem_features, target) * confidences).mean()
        raise NotImplementedError

    def validation_step(self, batch, batch_idx):
        rays, rgbs, semantics, instances, mask = batch['rays'].squeeze(), batch['rgbs'].squeeze(), batch['semantics'].squeeze(), batch['instances'].squeeze(), batch['mask'].squeeze()
        rs_semantics, rs_instances = batch['rs_semantics'].squeeze(), batch['rs_instances'].squeeze()
        probs, confs = batch['probabilities'].squeeze(), batch['confidences'].squeeze()
        output_rgb, output_semantics, output_instances, _output_depth, _, _ = self(rays, False)
        output_rgb[torch.logical_not(mask), :] = 0
        rgbs[torch.logical_not(mask), :] = 0
        loss = self.loss(output_rgb, rgbs)
        metric_psnr = psnr(output_rgb, rgbs)
        self.log("val/loss_rgb", loss, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
        self.log("val/psnr", metric_psnr, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
        output_semantics[torch.logical_not(mask)] = 0
        if self.config.probabilistic_ce_mode == "TTAConf":
            loss_semantics = (self.loss_semantics(output_semantics, probs) * confs).mean()
        elif self.config.probabilistic_ce_mode == "NoTTAConf":
            loss_semantics = (self.loss_semantics(output_semantics, semantics) * confs).mean()
        else:
            loss_semantics = self.loss_semantics(output_semantics, semantics).mean()
        output_semantics_without_invalid = output_semantics.detach().argmax(dim=1)
        output_semantics_without_invalid[semantics == 0] = 0
        val_cm = ConfusionMatrix(num_classes=self.model.num_semantic_classes, ignore_class=[0])
        metric_iou = val_cm.add_batch(output_semantics_without_invalid.cpu().numpy(), semantics.cpu().numpy(), return_miou=True)
        pano_pred = torch.cat([output_semantics_without_invalid.unsqueeze(1), output_instances.argmax(dim=1).unsqueeze(1)], dim=1)
        pano_target = torch.cat([semantics.unsqueeze(1), instances.unsqueeze(1)], dim=1)
        metric_pq, metric_sq, metric_rq = panoptic_quality(pano_pred, pano_target, self.train_set.things_filtered, self.train_set.stuff_filtered, allow_unknown_preds_category=True)
        self.log("val/iou", metric_iou, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
        self.log("val/pq", metric_pq, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
        self.log("val/sq", metric_sq, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
        self.log("val/rq", metric_rq, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
        self.log("val/loss_semantics", loss_semantics, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
        val_rs_cm = ConfusionMatrix(num_classes=self.model.num_semantic_classes, ignore_class=list(self.train_set.faulty_classes))
        output_semantics_with_invalid = output_semantics.detach().argmax(dim=1)
        metric_rs_iou = val_rs_cm.add_batch(output_semantics_with_invalid.cpu().numpy(), rs_semantics.cpu().numpy(), return_miou=True)
        pano_rs_pred = torch.cat([output_semantics_with_invalid.unsqueeze(1), output_instances.argmax(dim=1).unsqueeze(1)], dim=1)
        pano_rs_target = torch.cat([rs_semantics.unsqueeze(1), rs_instances.unsqueeze(1)], dim=1)
        metric_rs_pq, metric_rs_sq, metric_rs_rq = panoptic_quality(pano_rs_pred, pano_rs_target, self.train_set.things_filtered, self.train_set.stuff_filtered, allow_unknown_preds_category=True)
        self.log("val_rs/iou", metric_rs_iou, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
        self.log("val_rs/pq", metric_rs_pq, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
        self.log("val_rs/sq", metric_rs_sq, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
        self.log("val_rs/rq", metric_rs_rq, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)
        return {'loss_rgb': loss.item(), 'loss_sem': loss_semantics.item(), 'psnr': metric_psnr.item(),
                'iou': metric_iou.item(), 'pq': metric_pq.item(), 'sq': metric_sq.item(), 'rq': metric_rq.item(),
                'rs_iou': metric_rs_iou.item(), 'rs_pq': metric_rs_pq.item(), 'rs_sq': metric_rs_sq.item(), 'rs_rq': metric_rs_rq.item()}

    @rank_zero_only
    def validation_epoch_end(self, val_step_outputs):
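        # NOTE: rank_zero_only means only rank 0 runs the (slow) rendering loop
        # below; the other ranks continue into the next epoch and block in their
        # first collective until rank 0 catches up or NCCL times out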
        print()
        table = [('loss_rgb', 'loss_sem', 'psnr', 'iou', 'pq', 'sq', 'rq', 'rs_iou', 'rs_pq', 'rs_sq', 'rs_rq'), ]
        table.append(tuple([np.array([val_step_outputs[i][key] for i in range(len(val_step_outputs))]).mean() for key in table[0]]))
        print(tabulate(table, headers='firstrow', tablefmt='fancy_grid'))
        H, W = self.config.image_dim[0], self.config.image_dim[1]
        (self.output_dir_result_clusters / f"{self.global_step:06d}").mkdir(exist_ok=True)
        # self.renderer.export_instance_clusters(self.model, self.output_dir_result_clusters / f"{self.global_step:06d}")
        for batch_idx, batch in enumerate(self.val_dataloader()):
            if batch_idx in self.config.visualized_indices:
                rays, rgbs, semantics, instances = batch['rays'].squeeze().to(self.device), batch['rgbs'].squeeze().to(self.device), \
                                                   batch['semantics'].squeeze().to(self.device), batch['instances'].squeeze().to(self.device)
                rs_semantics, rs_instances = batch['rs_semantics'].squeeze().to(self.device), batch['rs_instances'].squeeze().to(self.device)
                mask = batch['mask'].squeeze().to(self.device)
                output_rgb, output_semantics, output_instances, output_depth, _, _ = self(rays, False)
                output_rgb[torch.logical_not(mask), :] = 0
                output_semantics[torch.logical_not(mask)] = 0
                output_instances[torch.logical_not(mask)] = 0
                rgbs[torch.logical_not(mask), :] = 0
                stack = visualize_panoptic_outputs(output_rgb, output_semantics, output_instances, output_depth, rgbs, rs_semantics, rs_instances, H, W, thing_classes=self.train_set.segmentation_data.fg_classes)
                save_image(stack, self.output_dir_result_images / f"{self.global_step:06d}_{batch_idx:04d}.jpg", value_range=(0, 1), nrow=5, normalize=True)
                if self.config.logger == 'wandb':
                    self.logger.log_image(key=f"images/{batch_idx:04d}", images=[make_grid(stack, value_range=(0, 1), nrow=5, normalize=True)])
                else:
                    self.logger.experiment.add_image(f'visuals/{batch_idx:04d}', make_grid(stack, value_range=(0, 1), nrow=5, normalize=True), global_step=self.global_step)

    def train_dataloader(self):
        loaders = {
            0: DataLoader(self.train_set, self.config.batch_size, shuffle=True, pin_memory=True, drop_last=True, num_workers=self.config.num_workers)
        }
        train_instance_set = get_inconsistent_single_dataset(self.config)
        assert len(train_instance_set) > 0, "instance dataset is empty"
        loaders[1] = DataLoader(train_instance_set, self.config.batch_size_contrastive, shuffle=True, drop_last=True, collate_fn=train_instance_set.collate_fn, num_workers=0)
        if self.config.segment_grouping_mode != "none":
            loaders[2] = DataLoader(self.train_segment_set, self.config.batch_size_segments, shuffle=False, drop_last=True, collate_fn=self.train_segment_set.collate_fn, num_workers=0)
        return loaders

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=1, shuffle=False, drop_last=False, num_workers=0)

    def on_train_epoch_start(self):
        self.current_lambda_dist_reg = self.config.lambda_dist_reg * (1 - math.exp(-0.25 * self.current_epoch))
        if self.current_epoch in self.config.bbox_aabb_reset_epochs:
            self.renderer.update_bbox_aabb_and_shrink(self.model)
        if self.current_epoch in self.config.grid_upscale_epochs:
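            # NOTE: upsample_volume_grid recreates the model's parameters, and
            # setup_optimizers rebuilds the optimizers, but the DDP wrapper
            # constructed at fit() start may still reference the old parameters;
            # if grid_upscale_epochs contains 3, this is another candidate for
            # a hang at that epoch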
            num_voxel_list = (torch.round(torch.exp(torch.linspace(np.log(self.config.min_grid_dim**3), np.log(self.config.max_grid_dim**3), len(self.config.grid_upscale_epochs)+1))).long()).tolist()[1:]
            target_num_voxels = num_voxel_list[self.config.grid_upscale_epochs.index(self.current_epoch)]
            target_resolution = self.renderer.get_target_resolution(target_num_voxels)
            self.config.weight_decay = 0
            self.model.upsample_volume_grid(target_resolution)
            self.renderer.update_step_size(target_resolution)
            self.trainer.strategy.setup_optimizers(self.trainer)
        if self.config.segment_grouping_mode != "none" and self.current_epoch >= self.config.segment_optimization_epoch:
            self.train_segment_set.enabled = True

    def on_load_checkpoint(self, checkpoint):
        for epoch in self.config.grid_upscale_epochs[::-1]:
            if checkpoint['epoch'] >= epoch:
                grid_dim = checkpoint["state_dict"]["renderer.grid_dim"].cpu()
                self.model.upsample_volume_grid(grid_dim)
                self.renderer.update_step_size(grid_dim)
                self.config.weight_decay = 0
                self.trainer.strategy.setup_optimizers(self.trainer)
                break


@hydra.main(config_path='../config', config_name='panopli', version_base='1.2')
def main(config):
    trainer = create_trainer("PanopLi", config)
    model = TensoRFTrainer(config)
    trainer.fit(model)


if __name__ == '__main__':
    main()