# Source code for lightning.pytorch.callbacks.spike

import os
from typing import Any, Mapping, Union

import torch

import lightning.pytorch as pl
from lightning.fabric.utilities.spike import SpikeDetection as FabricSpikeDetection
from lightning.pytorch.callbacks.callback import Callback


class SpikeDetection(FabricSpikeDetection, Callback):
    """Trainer callback that forwards each batch's loss to the Fabric spike detection.

    The class combines ``FabricSpikeDetection`` with the ``Callback`` base and
    only overrides ``on_train_batch_end`` to adapt the hook arguments before
    delegating to the Fabric implementation.
    """

    @torch.no_grad()
    def on_train_batch_end(  # type: ignore
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: Union[torch.Tensor, Mapping[str, torch.Tensor]],
        batch: Any,
        batch_idx: int,
    ) -> None:
        """Pull the loss out of ``outputs`` and hand it to the Fabric hook.

        Runs with gradient tracking disabled (``torch.no_grad``).

        Args:
            trainer: Trainer driving the run; its ``default_root_dir`` is used
                to build the default exclude-batches path.
            pl_module: Not used by this method.
            outputs: Either the loss tensor itself, or a mapping whose
                ``"loss"`` entry holds it.
            batch: Current batch, passed through unchanged.
            batch_idx: Index of the current batch, passed through unchanged.

        Returns:
            Whatever ``FabricSpikeDetection.on_train_batch_end`` returns.

        Raises:
            TypeError: If ``outputs`` is neither a ``torch.Tensor`` nor a
                ``Mapping``.
        """
        # Reject unsupported output types up front.
        if not isinstance(outputs, (torch.Tensor, Mapping)):
            raise TypeError(f"outputs have to be of type torch.Tensor or Mapping, got {type(outputs).__qualname__}")

        # A bare tensor is the loss; otherwise look it up under "loss".
        raw_loss = outputs if isinstance(outputs, torch.Tensor) else outputs["loss"]
        detached_loss = raw_loss.detach()

        # Fall back to a file inside the trainer's root dir when no path is set.
        if self.exclude_batches_path is None:
            root_dir = trainer.default_root_dir
            self.exclude_batches_path = os.path.join(root_dir, "skip_batches.json")

        return FabricSpikeDetection.on_train_batch_end(  # type: ignore
            self, trainer, detached_loss, batch, batch_idx
        )