# Source code for lightning.pytorch.callbacks.spike

import os
from collections.abc import Mapping
from typing import Any, Union

import torch

import lightning.pytorch as pl
from lightning.fabric.utilities.spike import SpikeDetection as FabricSpikeDetection
from lightning.pytorch.callbacks.callback import Callback


class SpikeDetection(FabricSpikeDetection, Callback):
    """Trainer callback adapter around Fabric's ``SpikeDetection``.

    Reuses the detection logic of
    :class:`lightning.fabric.utilities.spike.SpikeDetection` and exposes it
    through the :class:`~lightning.pytorch.callbacks.callback.Callback` hook
    ``on_train_batch_end``. The loss is pulled out of the training-step
    outputs and forwarded, together with the trainer, to the Fabric
    implementation.
    """

    @torch.no_grad()
    def on_train_batch_end(  # type: ignore
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: Union[torch.Tensor, Mapping[str, torch.Tensor]],
        batch: Any,
        batch_idx: int,
    ) -> None:
        """Extract the loss from ``outputs`` and run Fabric's spike check on it.

        Args:
            trainer: The running trainer. It is forwarded to
                ``FabricSpikeDetection.on_train_batch_end`` in the argument
                position that method uses for its Fabric object. Its
                ``default_root_dir`` supplies the default exclusion-file
                location.
            pl_module: The module being trained (not used by this hook).
            outputs: Either the loss tensor itself, or a mapping holding the
                loss tensor under the ``"loss"`` key.
            batch: The current batch, forwarded unchanged.
            batch_idx: Index of the current batch, forwarded unchanged.

        Raises:
            TypeError: If ``outputs`` is neither a tensor nor a mapping.
            KeyError: If ``outputs`` is a mapping without a ``"loss"`` entry.
        """
        if isinstance(outputs, torch.Tensor):
            loss = outputs.detach()
        elif isinstance(outputs, Mapping):
            try:
                step_loss = outputs["loss"]
            except KeyError:
                # A bare ``KeyError: 'loss'`` does not tell the user why the
                # callback failed, so re-raise the same exception type with a
                # message that names the missing key and what was present.
                raise KeyError(
                    "SpikeDetection needs the training loss, but the training step outputs have no 'loss' entry"
                    f" (available keys: {list(outputs)}). Make sure `training_step` returns the loss."
                ) from None
            loss = step_loss.detach()
        else:
            raise TypeError(f"outputs have to be of type torch.Tensor or Mapping, got {type(outputs).__qualname__}")

        # With no user-configured location, default to a ``skip_batches.json``
        # file inside the trainer's root dir. The value is stored on the
        # instance, so this default is only computed on the first call.
        if self.exclude_batches_path is None:
            self.exclude_batches_path = os.path.join(trainer.default_root_dir, "skip_batches.json")

        # Call Fabric's implementation explicitly (not via ``super()``) so the
        # target is unambiguous given the two base classes. The trailing
        # ``type: ignore`` suggests the Fabric signature expects a different
        # type than ``pl.Trainer`` for its first argument.
        return FabricSpikeDetection.on_train_batch_end(self, trainer, loss, batch, batch_idx)  # type: ignore