Best practices for double precision training

Hi, I am writing a PyTorch Lightning wrapper for training an ML interatomic potential, where double precision is used frequently. All input data to my model is generally in double precision. What are the best practices for making sure I can use either single or double precision as needed?
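For reference, here is a minimal sketch of one way to keep a single dtype throughout. It assumes the config string maps to both the Trainer flag and a torch dtype. The `PRECISION` table, the `cast_floats` helper and the batch keys are hypothetical and not part of Lightning. The idea is to cast the model once, then cast every floating-point tensor in each batch to the same dtype, leaving integer tensors such as atomic numbers untouched.

```python
import torch

# Hypothetical mapping: config string -> (pl.Trainer precision flag, torch dtype)
PRECISION = {"single": (32, torch.float32), "double": (64, torch.float64)}


def cast_floats(obj, dtype):
    """Recursively cast floating-point tensors in a batch to `dtype`.

    Integer tensors (e.g. atomic numbers, neighbor indices) are left unchanged.
    Lists and tuples are rebuilt as plain lists/tuples.
    """
    if isinstance(obj, torch.Tensor):
        return obj.to(dtype) if obj.is_floating_point() else obj
    if isinstance(obj, dict):
        return {k: cast_floats(v, dtype) for k, v in obj.items()}
    if isinstance(obj, list):
        return [cast_floats(v, dtype) for v in obj]
    if isinstance(obj, tuple):
        return tuple(cast_floats(v, dtype) for v in obj)
    return obj


# In the wrapper this key would come from self.model_manifest["precision"]
trainer_precision, dtype = PRECISION["single"]

model = torch.nn.Linear(3, 1).to(dtype)  # stand-in for the real potential
batch = cast_floats(
    {
        "positions": torch.rand(4, 3, dtype=torch.float64),  # double input data
        "species": torch.tensor([1, 6, 8, 1]),                # int64, kept as-is
        "energy": torch.rand(1, dtype=torch.float64),
    },
    dtype,
)

energy = model(batch["positions"]).sum()
loss = (energy - batch["energy"]).pow(2).sum()
loss.backward()  # model, inputs and targets share one dtype, so no mismatch
```

Casting explicitly keeps the behavior identical whether or not the code runs under the Trainer. Integer tensors are skipped because indices and species labels must stay integral.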

In my wrapper, pl.Trainer works fine with precision=32, but with precision=64 I get the following error:

  File "/opt/mambaforge/mambaforge/envs/colabfit/lib/python3.9/site-packages/torch/autograd/", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Found dtype Float but expected Double

This is surprising and the opposite of what I expected: since my data is in double precision, I thought single precision would be the problematic case.
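Since the model code isn't shown, the following is an assumption rather than a diagnosis. One common way to end up with a stray float32 tensor under precision=64 is a tensor that the module-level cast does not reach. With precision=64, Lightning converts the LightningModule with `.double()`. However, `nn.Module.double()` only converts parameters and registered floating-point buffers. It does not convert tensors stored as plain attributes or tensors created inside `forward` with the default (float32) dtype. A minimal sketch with a hypothetical toy module (`ToyPotential`, `cutoff`, `cutoff_attr` are made-up names):

```python
import torch
from torch import nn


class ToyPotential(nn.Module):  # hypothetical stand-in for the real model
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 1)
        # Plain attribute: .double() / .to() will NOT convert this tensor.
        self.cutoff_attr = torch.tensor([5.0])
        # Registered buffer: converted together with the parameters.
        self.register_buffer("cutoff", torch.tensor([5.0]))

    def forward(self, positions):
        # Build new constants from an existing tensor so they follow the
        # input's dtype and device instead of the global default dtype.
        offset = positions.new_zeros(1)
        return self.linear(positions / self.cutoff + offset).sum()


model = ToyPotential().double()
print(model.linear.weight.dtype)  # torch.float64
print(model.cutoff.dtype)         # torch.float64
print(model.cutoff_attr.dtype)    # torch.float32 (left behind)

positions = torch.rand(4, 3, dtype=torch.float64, requires_grad=True)
energy = model(positions)
# Forces as the negative gradient of the energy w.r.t. positions.
(grad,) = torch.autograd.grad(energy, positions)
forces = -grad
print(forces.dtype)  # torch.float64
```

Registering constants as buffers, and creating tensors with `new_zeros`, `zeros_like` or `dtype=x.dtype`, keeps everything in the model's dtype whichever precision the Trainer uses. The same check is worth doing for targets and loss inputs.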

Below is the code where I create the trainer:

        precision = 32 if self.model_manifest["precision"] == "single" else 64
        return pl.Trainer(
            logger=[self.tb_logger, self.csv_logger],