Script freezes when Trainer is instantiated

I can run a training script with PyTorch Lightning once. However, after training finishes, if I try to run it again, the code freezes when `L.Trainer` is instantiated. There are no error messages.

Only after I shut down and restart can I run it again, and only once; the problem then reappears on the next run.

This happens with different scripts, even with the “Lightning in 15 minutes” example.

Here is a minimal script that reproduces it.

import os
import torch
from torch import optim, nn, utils
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import pytorch_lightning as L

# Any nn.Module can be plugged in here (including ones you already have).
# The encoder maps 28 * 28 input features down to a 3-dimensional code;
# the decoder mirrors it back up to 28 * 28 features.
_PIXELS = 28 * 28
_HIDDEN = 64
_LATENT = 3

encoder = nn.Sequential(
    nn.Linear(_PIXELS, _HIDDEN),
    nn.ReLU(),
    nn.Linear(_HIDDEN, _LATENT),
)
decoder = nn.Sequential(
    nn.Linear(_LATENT, _HIDDEN),
    nn.ReLU(),
    nn.Linear(_HIDDEN, _PIXELS),
)


# define the LightningModule
class LitAutoEncoder(L.LightningModule):
    """Autoencoder trained to reconstruct its (flattened) input.

    Args:
        encoder: module applied to a flattened input batch to produce a code.
        decoder: module applied to the code to reconstruct the flattened input.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        """Compute and log the reconstruction loss for one batch.

        ``batch`` is unpacked as ``(x, y)``; ``y`` is ignored because the
        reconstruction target is ``x`` itself.

        Returns:
            The scalar MSE loss between the reconstruction and the input.
        """
        # training_step defines the train loop.
        # it is independent of forward
        x, _ = batch
        # Flatten every non-batch dimension so x matches the encoder's input.
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        # With automatic optimization Lightning backpropagates the value
        # returned here, so it must be the loss (not the input batch).
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr=1e-3."""
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


if __name__ == "__main__":
    # DataLoaders below use worker processes. When workers are started with the
    # "spawn" method (the default on Windows and macOS), they re-import this
    # module; the guard keeps them from re-running setup and Trainer creation.

    # init the autoencoder
    autoencoder = LitAutoEncoder(encoder, decoder)

    # setup data
    dataset = MNIST(os.getcwd(), download=True, train=True, transform=ToTensor())
    # use 20% of training data for validation
    train_set_size = int(len(dataset) * 0.8)
    valid_set_size = len(dataset) - train_set_size
    # Fixed-seed generator so the train/validation split is reproducible.
    split_generator = torch.Generator().manual_seed(42)
    train_set, val_set = utils.data.random_split(
        dataset, [train_set_size, valid_set_size], generator=split_generator
    )
    # Worker processes per DataLoader.
    # NOTE(review): 15 may exceed the available cores; consider tuning.
    num_workers = 15
    train_loader = utils.data.DataLoader(train_set, num_workers=num_workers)
    valid_loader = utils.data.DataLoader(val_set, num_workers=num_workers)

    print("Before instantiate Trainer")
    # train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
    # trainer.fit(...) is not called: this script only checks that Trainer
    # construction returns.
    trainer = L.Trainer(
        limit_train_batches=100,
        max_epochs=10,
        check_val_every_n_epoch=10,
        accelerator="gpu",
    )
    print("After instantiate Trainer")