I have a `Trainer` with the `truncated_bptt_steps` option set like this:

```python
import pytorch_lightning as pl

trainer = pl.Trainer(truncated_bptt_steps=100)
```
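As far as I understand it, `truncated_bptt_steps=100` makes Lightning split each batch along the time dimension into chunks of up to 100 steps and run `training_step` once per chunk, passing in the `hiddens` returned by the previous chunk. Below is a plain-PyTorch sketch of that idea, not Lightning's actual implementation. The helper names (`tbptt_split`, `truncated_bptt_batch`), the batch-first `[batch, time, features]` layout, the linear head, and the MSE loss are all my own assumptions.

```python
import torch
import torch.nn as nn


def tbptt_split(x, split_size):
    """Split a [batch, time, features] tensor into chunks along the time dimension."""
    return [x[:, t:t + split_size] for t in range(0, x.size(1), split_size)]


def truncated_bptt_batch(lstm, head, optimizer, x, y, split_size):
    """Hypothetical helper: train on one batch with truncated BPTT."""
    hiddens = None  # nn.LSTM starts from zero (h_0, c_0) when no state is given
    losses = []
    for x_chunk, y_chunk in zip(tbptt_split(x, split_size), tbptt_split(y, split_size)):
        out, hiddens = lstm(x_chunk, hiddens)
        loss = nn.functional.mse_loss(head(out), y_chunk)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Cut the graph so gradients do not flow back into earlier chunks,
        # while the hidden values themselves carry over to the next chunk.
        hiddens = tuple(h.detach() for h in hiddens)
        losses.append(loss.item())
    return losses


torch.manual_seed(0)
lstm = nn.LSTM(input_size=3, hidden_size=8, batch_first=True)
head = nn.Linear(8, 1)
opt = torch.optim.SGD(list(lstm.parameters()) + list(head.parameters()), lr=0.01)
x = torch.randn(4, 250, 3)
y = torch.randn(4, 250, 1)
losses = truncated_bptt_batch(lstm, head, opt, x, y, split_size=100)
print(len(losses))  # 3 chunks: 100 + 100 + 50 time steps
```

The `detach()` step is what makes the backpropagation "truncated": each optimizer step only sees at most `split_size` time steps of graph.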
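My `training_step` method looks like this:

```python
def training_step(self, batch, batch_idx, hiddens):
    data, target = batch  # batch assumed to be (data, target)
    out, hiddens = self.lstm(data, hiddens)
    loss = F.mse_loss(out, target)  # placeholder loss; F = torch.nn.functional
    result = pl.TrainResult(minimize=loss, hiddens=hiddens)
    return result
```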
The problem is that `test_step` doesn't receive a `hiddens` argument, but the `forward` method of my network needs it:
```python
def forward(self, input, hiddens):
    output, hiddens = self.lstm(input, hiddens)
    return output, hiddens
```
My current `test_step` looks like this:
```python
from typing import Any

def test_step(self, batch, batch_idx, hiddens) -> Any:
    X_batch, y_batch = batch
    logits, hiddens = self.lstm(X_batch, hiddens)
    ...
```
But when I run the whole model, I get an error as soon as it reaches `test_step`:
```
  File "/Users/ken/opt/anaconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 333, in _evaluate
    output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
  File "/Users/ken/opt/anaconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 685, in evaluation_forward
    output = model.test_step(*args)
TypeError: test_step() missing 1 required positional argument: 'hiddens'
```
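Since the traceback shows Lightning calling `model.test_step(*args)` without any hidden state, one workaround I could try is to make `hiddens` optional in `forward` and drop it from `test_step`: `nn.LSTM` starts from zero `(h_0, c_0)` when no hidden state is passed. Below is a minimal sketch on a plain `nn.Module` so it runs without Lightning. `SeqModel`, the layer sizes, the linear head, and the cross-entropy loss are hypothetical, and it assumes each test sequence is processed in a single pass without truncation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SeqModel(nn.Module):  # hypothetical stand-in for the LightningModule
    def __init__(self, input_size=3, hidden_size=8, num_classes=5):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.head = nn.Linear(hidden_size, num_classes)

    def forward(self, input, hiddens=None):
        # hiddens=None lets nn.LSTM start from a zero (h_0, c_0)
        output, hiddens = self.lstm(input, hiddens)
        return self.head(output), hiddens

    def test_step(self, batch, batch_idx):
        # No `hiddens` parameter: only (batch, batch_idx) are expected here
        X_batch, y_batch = batch
        logits, _ = self(X_batch)
        # Flatten [batch, time, classes] -> [batch*time, classes] for cross_entropy
        return F.cross_entropy(logits.reshape(-1, logits.size(-1)), y_batch.reshape(-1))


model = SeqModel()
X = torch.randn(2, 7, 3)
y = torch.randint(0, 5, (2, 7))
loss = model.test_step((X, y), 0)
print(loss.dim())  # 0, i.e. a scalar loss
```

The training path is unaffected: `training_step` can still pass the `hiddens` it receives into `forward` explicitly.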