I have a jitted (Numba) function in which I need to use the output of a neural network trained with PyTorch Lightning. The following pseudocode makes this clearer:
x = sample_from_model()    # returns a NumPy array, so it works with Numba
out = NN(torch.Tensor(x))  # calls a PyTorch module, which Numba does not support
Is there a way to work around this? The first thing that comes to mind is to extract the weights manually and reimplement the forward pass myself.
Thanks in advance,