Torchmetrics not moved to device in callbacks

While refactoring my code, I moved some metrics from the LightningModule to a callback and got a RuntimeError: Encountered different devices in metric calculation.

I managed to recreate it using the bug report Colab template here.

I also found a fix: moving the metrics to the device in the setup stage.
But is this intended behavior? To quote the TorchMetrics documentation: "Modular metrics are automatically placed on the correct device when properly defined inside a LightningModule. This means that your data will always be placed on the same device as your metrics. No need to call .to(device) anymore!"
Do metrics defined in a callback not count as "properly defined inside a LightningModule"?