I am trying to convert my PyTorch code to PyTorch lightning but I am getting an error.
In init, I compute the detect
by Detect
but it is not recognized (?) in the forward
module. Do I need to include the Detect
inside of the lightning module somehow?
import pytorch_lightning as pl
from utils.make import make_vgg, make_extras, make_loc_conf
from utils.others import L2Norm, DBox, Detect
from utils.loss import MultiBoxLoss
class SSD(pl.LightningModule):
def __init__(self, hparams, phase, cfg):
super(SSD, self).__init__()
self.save_hyperparameters(hparams)
self.phase = phase
self.num_classes = cfg["num_classes"]
self.vgg = make_vgg()
self.extras = make_extras()
self.L2Norm = L2Norm()
self.loc, self.conf = make_loc_conf(
cfg["num_classes"], cfg["bbox_aspect_num"]
)
dbox = DBox(cfg)
self.dbox_list = dbox.make_dbox_list()
if self.phase == 'inference':
self.detect = Detect.apply
def loss_function(self, outputs, targets):
criterion = MultiBoxLoss(jaccard_thresh=0.5, neg_pos = 3)
# Loss function = loss_l(position of bbox) + loss_c(classification)
print('Type of output is: ')
print(type(outputs))
loss_l, loss_c = criterion(outputs, targets)
return loss_l + loss_c
def forward(self, x):
sources = list()
loc = list()
conf = list()
for k in range(23):
x = self.vgg[k](x)
source1 = self.L2Norm(x)
sources.append(source1)
for k in range(23,len(self.vgg)):
x = self.vgg[k](x)
sources.append(x)
for k, v in enumerate(self.extras):
x = F.relu(v(x),inplace=True)
if k%2==1:
sources.append(x)
for (x,l,c) in zip(sources,self.loc,self.conf):
loc.append(l(x).permute(0,2,3,1).contiguous())
conf.append(c(x).permute(0,2,3,1).contiguous())
loc = torch.cat([o.view(o.size(0),-1) for o in loc], 1)
conf = torch.cat([o.view(o.size(0),-1) for o in conf], 1)
loc = loc.view(loc.size(0), -1, 4)
conf = conf.view(conf.size(0), -1, self.num_classes)
output = (loc, conf, self.dbox_list)
if self.phase == "inference":
return self.detect(output[0],output[1],output[2])
else:
return output
The class Detect
is defined like this:
class Detect(Function):
def __init__(self,conf_thresh=0.01, top_k = 200, nms_thresh=0.45):
self.softmax = nn.Softmax(dim=-1)
self.conf_thresh = conf_thresh
self.top_k = top_k
self.nms_thresh = nms_thresh # uncommented by Daigo
def forward(self, loc_data, conf_data, dbox_list):
num_batch = loc_data.size(0)
num_dbox = loc_data.size(1)
num_classes = conf_data.size(2)
conf_data = self.softmax(conf_data)
output = torch.zeros(num_batch, num_classes, self.top_k, 5)
conf_preds = conf_data.transpose(2,1)
for i in range(num_batch):
decoded_boxes = decode(loc_data[i], dbox_list)
conf_scores = conf_preds[i].clone()
for cl in range(1,num_classes):
c_mask = conf_scores[cl].gt(self.conf_thresh)
scores = conf_scores[cl][c_mask]
if scores.nelement()==0:
continue
l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes)
boxes = decoded_boxes[l_mask].view(-1,4)
ids, count = nm_suppression(
boxes, scores, self.nms_thresh, self.top_k
)
output[i,cl,:count] = torch.cat((scores[ids[:count]].unsqueeze(1), boxes[ids[:count]]), 1)
return output
I got the following error:
File “C:\Users\user\Object Detection - spyder\script_Sandia_0608_lightning\ssd_light.py”, line 226, in
trainer.fit(ssd)
File “C:\Users\user\anaconda3\lib\site-packages\pytorch_lightning\trainer\trainer.py”, line 458, in fit
self._run(model)
File “C:\Users\user\anaconda3\lib\site-packages\pytorch_lightning\trainer\trainer.py”, line 756, in _run
self.dispatch()
File “C:\Users\user\anaconda3\lib\site-packages\pytorch_lightning\trainer\trainer.py”, line 797, in dispatch
self.accelerator.start_training(self)
File “C:\Users\user\anaconda3\lib\site-packages\pytorch_lightning\accelerators\accelerator.py”, line 96, in start_training
self.training_type_plugin.start_training(trainer)
File “C:\Users\user\anaconda3\lib\site-packages\pytorch_lightning\plugins\training_type\training_type_plugin.py”, line 144, in start_training
self._results = trainer.run_stage()
File “C:\Users\user\anaconda3\lib\site-packages\pytorch_lightning\trainer\trainer.py”, line 807, in run_stage
return self.run_train()
File “C:\Users\user\anaconda3\lib\site-packages\pytorch_lightning\trainer\trainer.py”, line 842, in run_train
self.run_sanity_check(self.lightning_module)
File “C:\Users\user\anaconda3\lib\site-packages\pytorch_lightning\trainer\trainer.py”, line 1107, in run_sanity_check
self.run_evaluation()
File “C:\Users\user\anaconda3\lib\site-packages\pytorch_lightning\trainer\trainer.py”, line 962, in run_evaluation
output = self.evaluation_loop.evaluation_step(batch, batch_idx, dataloader_idx)
File “C:\Users\user\anaconda3\lib\site-packages\pytorch_lightning\trainer\evaluation_loop.py”, line 174, in evaluation_step
output = self.trainer.accelerator.validation_step(args)
File “C:\Users\user\anaconda3\lib\site-packages\pytorch_lightning\accelerators\accelerator.py”, line 226, in validation_step
return self.training_type_plugin.validation_step(*args)
File “C:\Users\user\anaconda3\lib\site-packages\pytorch_lightning\plugins\training_type\training_type_plugin.py”, line 161, in validation_step
return self.lightning_module.validation_step(*args, **kwargs)
File “C:\Users\user\Object Detection - spyder\script_Sandia_0608_lightning\ssd_light.py”, line 136, in validation_step
outputs = self(images)
File “C:\Users\user\anaconda3\lib\site-packages\torch\nn\modules\module.py”, line 889, in _call_impl
result = self.forward(*input, **kwargs)
File “C:\Users\user\Object Detection - spyder\script_Sandia_0608_lightning\ssd_light.py”, line 116, in forward
return self.detect(output[0],output[1],output[2])
File “C:\Users\user\anaconda3\lib\site-packages\torch\nn\modules\module.py”, line 947, in getattr
raise AttributeError(“‘{}’ object has no attribute ‘{}’”.format(
AttributeError: ‘SSD’ object has no attribute ‘detect’