Using neural network output within a Numba-jitted function


I have a Numba-jitted function within which I need to use the output of a neural network (trained using PyTorch Lightning). The pseudocode below should make this clearer:

while True:
    x = sample_from_model()      # NumPy array, hence compatible with Numba
    out = NN(torch.Tensor(x))    # PyTorch call, incompatible with Numba

Is there a way to circumvent this problem? The first thing that comes to mind is to manually extract the weights and compute the forward pass myself inside the jitted function.
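
For illustration, here is a rough sketch of what the manual forward pass could look like. It assumes the network is a small MLP with two linear layers and a ReLU in between, which is only a guess at the architecture. The names dense and mlp_forward, and the random weights, are made up for the example. With a real model, the arrays would presumably come from something like layer.weight.detach().cpu().numpy().astype(np.float64), exported once outside the jitted code. The try/except is only there so the sketch still runs as plain Python when Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback so the sketch still runs (uncompiled) without Numba.
    def njit(fn):
        return fn


@njit
def dense(x, W, b):
    # W has shape (out_features, in_features), the layout used by torch.nn.Linear.
    out = b.copy()
    for i in range(W.shape[0]):
        for j in range(W.shape[1]):
            out[i] += W[i, j] * x[j]
    return out


@njit
def mlp_forward(x, W1, b1, W2, b2):
    # Assumed architecture: Linear -> ReLU -> Linear.
    h = dense(x, W1, b1)
    for i in range(h.shape[0]):
        if h[i] < 0.0:
            h[i] = 0.0
    return dense(h, W2, b2)


# Stand-in weights; in practice these would be exported from the trained model.
rng = np.random.default_rng(0)
W1, b1 = rng.standard_normal((4, 3)), rng.standard_normal(4)
W2, b2 = rng.standard_normal((2, 4)), rng.standard_normal(2)

x = rng.standard_normal(3)
out = mlp_forward(x, W1, b1, W2, b2)
```

Because mlp_forward only touches NumPy arrays and scalars, it can be called from another jitted function, for example from inside the sampling loop. Any layer type beyond linear and ReLU (normalisation, dropout in eval mode, other activations) would have to be reimplemented by hand in the same way.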

Thanks in advance.

Did you achieve your goal? I have a similar problem. I computed the forward pass within a jitted function, but unfortunately that didn't solve all of my issues.