class TrialModel(pl.LightningModule):
    """LightningModule built around a ResNet-18 backbone.

    Training minimises ``pytorch_neg_multi_log_likelihood_batch`` between the
    batch's ``"target_positions"`` and the ``(pred, confidences)`` pair
    returned by :meth:`forward`.
    """

    def __init__(self, cfg: Dict, pretrained=True):
        """Build the backbone.

        Args:
            cfg: Configuration dictionary (not read in the part of the
                constructor shown here).
            pretrained: Whether the ResNet-18 backbone is created with
                pretrained weights.
        """
        super().__init__()
        # Honour the caller's choice; previously this was hard-coded to True,
        # so TrialModel(cfg, pretrained=False) still loaded pretrained weights.
        self.backbone = resnet18(pretrained=pretrained, progress=True)
        # ---- (rest of the constructor is not shown here)

    def forward(self, x):
        """Run the network on ``x`` and return ``(pred, confidences)``."""
        # ---- (forward computation is not shown here)
        return pred, confidences

    def training_step(self, batch, batch_idx):
        """Compute the training loss for one batch.

        Args:
            batch: Mapping with at least ``"target_positions"`` and
                ``"image"`` entries.
            batch_idx: Index of the batch (unused).

        Returns:
            Dict with the loss under ``'loss'`` and a ``'progress_bar'``
            dict carrying the same value as ``'train_loss'``.
        """
        # as_tensor accepts tensors and array-likes alike; unlike
        # torch.tensor it does not force a copy (or emit a UserWarning) when
        # the entry is already a tensor on the right device.
        targets = torch.as_tensor(batch["target_positions"], device=self.device)
        data = torch.as_tensor(batch["image"], device=self.device)

        outputs, confidence = self(data)
        loss = pytorch_neg_multi_log_likelihood_batch(targets, outputs, confidence)

        pbar = {'train_loss': loss}
        return {'loss': loss, 'progress_bar': pbar}

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr=1e-3."""
        return optim.Adam(self.parameters(), lr=1e-3)
# data loader
os.environ["DATA_FOLDER"] = DIR_INPUT
dm = LocalDataManager(None)

train_cfg = cfg["train_data_loader"]
rasterizer = build_rasterizer(cfg, dm)  # cfg is predefined dictionary
train_zarr = ChunkedDataset(dm.require(train_cfg["key"])).open()
train_dataset = AgentDataset(cfg, train_zarr, rasterizer, min_frame_future=10)

# NOTE(review): with tpu_cores=8 the trainer spawns one process per core (see
# the xmp.spawn frames in the traceback). A worker dying with SIGKILL in that
# setup is often the host running out of RAM -- presumably every process
# starts its own `num_workers` loader workers. Try a smaller num_workers /
# batch_size, or tpu_cores=1, to confirm.
train_dataloader = DataLoader(
    train_dataset,
    shuffle=train_cfg["shuffle"],  # shuffle=True
    batch_size=train_cfg["batch_size"],  # batch_size=24
    num_workers=train_cfg["num_workers"],  # num_workers=4
    pin_memory=True,
)

# Was `TrialtModel`, which is not defined anywhere and raises NameError.
model = TrialModel(cfg, pretrained=False)
trainer = pl.Trainer(tpu_cores=8, max_steps=500)
trainer.fit(model, train_dataloader)
ERROR:
GPU available: False, used: False
TPU available: True, using: 8 TPU cores
training on 8 TPU cores
INIT TPU local core: 0, global rank: 0 with XLA_USE_BF16=None
INIT TPU local core: 6, global rank: 6 with XLA_USE_BF16=None
INIT TPU local core: 5, global rank: 5 with XLA_USE_BF16=None
INIT TPU local core: 7, global rank: 7 with XLA_USE_BF16=None
INIT TPU local core: 3, global rank: 3 with XLA_USE_BF16=None
INIT TPU local core: 1, global rank: 1 with XLA_USE_BF16=None
INIT TPU local core: 4, global rank: 4 with XLA_USE_BF16=None
INIT TPU local core: 2, global rank: 2 with XLA_USE_BF16=None
| Name | Type | Params
----------------------------------------
0 | backbone | ResNet | 11 M
1 | head | Sequential | 2 M
2 | logit | Linear | 1 M
Epoch 0: 0%
0/88561 [00:00<?, ?it/s]
---------------------------------------------------------------------------
Exception Traceback (most recent call last)
<ipython-input-38-0766566ac519> in <module>()
24 trainer = pl.Trainer(tpu_cores=8, max_steps=500)#,
25
---> 26 trainer.fit(model,train_dataloader)
27
5 frames
/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/states.py in wrapped_fn(self, *args, **kwargs)
46 if entering is not None:
47 self.state = entering
---> 48 result = fn(self, *args, **kwargs)
49
50 # The INTERRUPTED state can be set inside the run function. To indicate that run was interrupted
/usr/local/lib/python3.6/dist-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule)
1076 self.accelerator_backend = TPUBackend(self)
1077 self.accelerator_backend.setup()
-> 1078 self.accelerator_backend.train(model)
1079 self.accelerator_backend.teardown(model)
1080
/usr/local/lib/python3.6/dist-packages/pytorch_lightning/accelerators/tpu_backend.py in train(self, model)
85 args=(model, self.trainer, self.mp_queue),
86 nprocs=self.trainer.tpu_cores,
---> 87 start_method=self.start_method
88 )
89
/usr/local/lib/python3.6/dist-packages/torch_xla/distributed/xla_multiprocessing.py in spawn(fn, args, nprocs, join, daemon, start_method)
393 join=join,
394 daemon=daemon,
--> 395 start_method=start_method)
396
397
/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py in start_processes(fn, args, nprocs, join, daemon, start_method)
155
156 # Loop on join until it returns True or raises an exception.
--> 157 while not context.join():
158 pass
159
/usr/local/lib/python3.6/dist-packages/torch/multiprocessing/spawn.py in join(self, timeout)
105 raise Exception(
106 "process %d terminated with signal %s" %
--> 107 (error_index, name)
108 )
109 else:
Exception: process 6 terminated with signal SIGKILL
I am getting the above error while running on a TPU. There seems to be some issue with data loading (the start_method of xmp.spawn / torch multiprocessing). It runs without any error on a single GPU. Am I loading the data correctly, or is there some other issue?
Thanks