BasePredictionWriter
class pytorch_lightning.callbacks.BasePredictionWriter(write_interval='batch')
Bases: pytorch_lightning.callbacks.callback.Callback
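Base class for callbacks that store the predictions produced by `Trainer.predict`. Subclasses override `write_on_batch_end`, `write_on_epoch_end`, or both. The `write_interval` argument (`"batch"`, `"epoch"`, or `"batch_and_epoch"`) selects which of them is called.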
Example:
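```python
import os
from typing import Any, List

import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks import BasePredictionWriter


class CustomWriter(BasePredictionWriter):
    def __init__(self, output_dir: str, write_interval: str):
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_batch_end(
        self, trainer, pl_module: LightningModule, prediction: Any, batch_indices: List[int],
        batch: Any, batch_idx: int, dataloader_idx: int
    ):
        # One subdirectory per dataloader; os.path.join needs strings, not ints.
        batch_dir = os.path.join(self.output_dir, str(dataloader_idx))
        os.makedirs(batch_dir, exist_ok=True)
        torch.save(prediction, os.path.join(batch_dir, f"{batch_idx}.pt"))

    def write_on_epoch_end(
        self, trainer, pl_module: LightningModule, predictions: List[Any], batch_indices: List[Any]
    ):
        # All predictions from the epoch go into a single file.
        os.makedirs(self.output_dir, exist_ok=True)
        torch.save(predictions, os.path.join(self.output_dir, "predictions.pt"))
```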
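`write_interval` accepts `"batch"`, `"epoch"`, or `"batch_and_epoch"`, which amounts to two switches: whether batch-level writing runs and whether epoch-level writing runs. The sketch below models that mapping with a hypothetical helper, `interval_flags`. It is a simplified stand-in, not Lightning's own implementation; the `ValueError` it raises for unknown values is an assumption, and Lightning itself may report invalid values differently.

```python
from typing import Tuple

# Hypothetical stand-in for how write_interval could map to the two writers;
# this is NOT Lightning's implementation.
VALID_INTERVALS = ("batch", "epoch", "batch_and_epoch")


def interval_flags(write_interval: str) -> Tuple[bool, bool]:
    """Return (writes_on_batch, writes_on_epoch) for a write_interval value."""
    if write_interval not in VALID_INTERVALS:
        # Assumption: the sketch uses ValueError; Lightning may use another exception.
        raise ValueError(
            f"write_interval must be one of {VALID_INTERVALS}, got {write_interval!r}"
        )
    writes_on_batch = write_interval in ("batch", "batch_and_epoch")
    writes_on_epoch = write_interval in ("epoch", "batch_and_epoch")
    return writes_on_batch, writes_on_epoch


print(interval_flags("batch"))            # (True, False)
print(interval_flags("epoch"))            # (False, True)
print(interval_flags("batch_and_epoch"))  # (True, True)
```

With `"batch_and_epoch"`, a subclass should expect both writer methods to be called, so it typically overrides both.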
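on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
Called when the predict batch ends. If `write_interval` is `"batch"` or `"batch_and_epoch"`, this forwards the batch's outputs to `write_on_batch_end`.
Return type: None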
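The batch-end hook only forwards to the writer when per-batch writing is enabled. The following stdlib-only sketch mimics that flow with a hypothetical `SketchBatchWriter` class so it runs without Lightning installed. It substitutes `json` for `torch.save`, passes `None` for `trainer`, `pl_module`, and `batch_indices`, and mirrors the file layout of the example above (one subdirectory per dataloader).

```python
import json
import os
import tempfile
from typing import Any, List, Optional


class SketchBatchWriter:
    """Hypothetical stand-in for a BasePredictionWriter subclass; not Lightning code."""

    def __init__(self, output_dir: str, write_interval: str = "batch") -> None:
        self.output_dir = output_dir
        self.write_interval = write_interval

    def on_predict_batch_end(self, trainer, pl_module, outputs: Any, batch: Any,
                             batch_idx: int, dataloader_idx: int = 0) -> None:
        # Skip unless per-batch writing is enabled.
        if self.write_interval not in ("batch", "batch_and_epoch"):
            return
        # Assumption: this sketch does not track batch indices, so None is passed.
        self.write_on_batch_end(trainer, pl_module, outputs, None, batch, batch_idx, dataloader_idx)

    def write_on_batch_end(self, trainer, pl_module, prediction: Any,
                           batch_indices: Optional[List[int]], batch: Any,
                           batch_idx: int, dataloader_idx: int) -> None:
        # One subdirectory per dataloader; os.path.join needs str, not int.
        target_dir = os.path.join(self.output_dir, str(dataloader_idx))
        os.makedirs(target_dir, exist_ok=True)
        with open(os.path.join(target_dir, f"{batch_idx}.json"), "w") as f:
            json.dump(prediction, f)


with tempfile.TemporaryDirectory() as tmp:
    writer = SketchBatchWriter(tmp, write_interval="batch")
    # Simulate two predict batches from dataloader 0.
    for batch_idx, preds in enumerate([[0.1, 0.9], [0.7, 0.3]]):
        writer.on_predict_batch_end(None, None, preds, batch=None, batch_idx=batch_idx)
    written = sorted(os.listdir(os.path.join(tmp, "0")))

print(written)  # ['0.json', '1.json']
```

In a real subclass Lightning calls the hook itself, so only `write_on_batch_end` needs to be overridden.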
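on_predict_epoch_end(trainer, pl_module, outputs)
Called when the predict epoch ends. If `write_interval` is `"epoch"` or `"batch_and_epoch"`, this calls `write_on_epoch_end` with the predictions collected over the epoch.
Return type: None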
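Epoch-level writing happens once, after all batches have run. The sketch below is a hypothetical, Lightning-free `SketchEpochWriter` that shows the gating: the epoch hook forwards to `write_on_epoch_end` only for `"epoch"` or `"batch_and_epoch"`. The shape of `outputs` (one inner list per dataloader, one entry per batch) and the empty `batch_indices` placeholders are assumptions of this sketch, which records what it receives instead of saving a file.

```python
from typing import Any, List


class SketchEpochWriter:
    """Hypothetical stand-in for epoch-level writing; not Lightning code."""

    def __init__(self, write_interval: str = "epoch") -> None:
        self.write_interval = write_interval
        self.written: List[Any] = []  # what write_on_epoch_end received, per call

    def on_predict_epoch_end(self, trainer, pl_module, outputs: List[List[Any]]) -> None:
        # Skip unless epoch-level writing is enabled.
        if self.write_interval not in ("epoch", "batch_and_epoch"):
            return
        # Assumption: batch indices are not tracked, so one empty list per dataloader.
        batch_indices: List[List[int]] = [[] for _ in outputs]
        self.write_on_epoch_end(trainer, pl_module, outputs, batch_indices)

    def write_on_epoch_end(self, trainer, pl_module, predictions: List[Any],
                           batch_indices: List[Any]) -> None:
        # A real writer might call torch.save(predictions, ...); the sketch keeps a reference.
        self.written.append(predictions)


# Assumed shape: one inner list per dataloader, one entry per batch.
epoch_outputs = [[[0.1, 0.9], [0.7, 0.3]]]

for interval in ("batch", "epoch", "batch_and_epoch"):
    w = SketchEpochWriter(interval)
    w.on_predict_epoch_end(None, None, epoch_outputs)
    print(interval, len(w.written))
# batch 0
# epoch 1
# batch_and_epoch 1
```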