BasePredictionWriter
- class pytorch_lightning.callbacks.BasePredictionWriter(write_interval='batch')
Bases: pytorch_lightning.callbacks.base.Callback
Base class to implement how the predictions should be stored.
- Parameters
write_interval (str) – When to write. One of "batch", "epoch" or "batch_and_epoch".
Example:
    import os
    from typing import Any, List

    import torch
    from pytorch_lightning.callbacks import BasePredictionWriter


    class CustomWriter(BasePredictionWriter):

        def __init__(self, output_dir: str, write_interval: str):
            super().__init__(write_interval)
            self.output_dir = output_dir

        def write_on_batch_end(
            self, trainer, pl_module: 'LightningModule', prediction: Any, batch_indices: List[int],
            batch: Any, batch_idx: int, dataloader_idx: int
        ):
            # One sub-directory per dataloader; os.path.join requires str components.
            batch_dir = os.path.join(self.output_dir, str(dataloader_idx))
            os.makedirs(batch_dir, exist_ok=True)
            torch.save(prediction, os.path.join(batch_dir, f"{batch_idx}.pt"))

        def write_on_epoch_end(
            self, trainer, pl_module: 'LightningModule', predictions: List[Any], batch_indices: List[Any]
        ):
            # All predictions of the epoch go into a single file.
            torch.save(predictions, os.path.join(self.output_dir, "predictions.pt"))
- on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
Called when the predict batch ends.
- Return type
None
- on_predict_epoch_end(trainer, pl_module, outputs)
Called when the predict epoch ends.
- Return type
None
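The relationship between write_interval and the two predict hooks can be illustrated with a simplified stand-in. The sketch below is not the library's implementation: WriteInterval, SketchPredictionWriter and simulate are hypothetical names. It assumes the batch hook forwards to write_on_batch_end only for the "batch" and "batch_and_epoch" intervals, and the epoch hook forwards to write_on_epoch_end only for "epoch" and "batch_and_epoch". It uses only the standard library, so it runs without Lightning installed.

```python
from enum import Enum


class WriteInterval(str, Enum):
    """Hypothetical stand-in for the accepted write_interval values."""
    BATCH = "batch"
    EPOCH = "epoch"
    BATCH_AND_EPOCH = "batch_and_epoch"


class SketchPredictionWriter:
    """Simplified stand-in; NOT the library's BasePredictionWriter."""

    def __init__(self, write_interval="batch"):
        # Enum lookup raises ValueError for an unknown interval.
        self.interval = WriteInterval(write_interval)
        self.calls = []  # records which write_* hooks fired, for illustration

    def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if self.interval in (WriteInterval.BATCH, WriteInterval.BATCH_AND_EPOCH):
            # batch_indices is simply passed as None in this sketch.
            self.write_on_batch_end(trainer, pl_module, outputs, None, batch, batch_idx, dataloader_idx)

    def on_predict_epoch_end(self, trainer, pl_module, outputs):
        if self.interval in (WriteInterval.EPOCH, WriteInterval.BATCH_AND_EPOCH):
            self.write_on_epoch_end(trainer, pl_module, outputs, None)

    def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx):
        self.calls.append(("batch", batch_idx))

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        self.calls.append(("epoch", len(predictions)))


def simulate(write_interval, num_batches=2):
    """Drive the hooks once per batch, then once at the end of the epoch."""
    writer = SketchPredictionWriter(write_interval)
    outputs = []
    for batch_idx in range(num_batches):
        prediction = [batch_idx]  # dummy prediction for this batch
        outputs.append(prediction)
        writer.on_predict_batch_end(None, None, prediction, None, batch_idx, 0)
    writer.on_predict_epoch_end(None, None, outputs)
    return writer.calls
```

Under these assumptions, a subclass normally overrides only the write_* methods and leaves the on_predict_* hooks alone, since the hooks decide when a write happens.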
- write_on_batch_end(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx)
Override with the logic to write a single batch.
- Return type
None
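As a rough illustration of what an override of write_on_batch_end might do, the sketch below writes each batch to its own file under a per-dataloader directory. JsonBatchWriter, batch_path and demo are hypothetical names. To stay runnable without Lightning or torch, the class is a plain object rather than a BasePredictionWriter subclass, and it serialises with json instead of torch.save. It also assumes the predictions are JSON-serialisable lists.

```python
import json
import os
import tempfile


class JsonBatchWriter:
    """Hypothetical writer; in real use it would subclass BasePredictionWriter."""

    def __init__(self, output_dir):
        self.output_dir = output_dir

    def batch_path(self, batch_idx, dataloader_idx):
        # Layout assumed here: <output_dir>/<dataloader_idx>/<batch_idx>.json
        return os.path.join(self.output_dir, str(dataloader_idx), f"{batch_idx}.json")

    def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx):
        path = self.batch_path(batch_idx, dataloader_idx)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        # Store the sample indices next to the predictions so rows can be matched later.
        with open(path, "w") as f:
            json.dump({"batch_indices": batch_indices, "prediction": prediction}, f)


def demo():
    """Write one dummy batch to a temporary directory and read it back."""
    with tempfile.TemporaryDirectory() as tmp:
        writer = JsonBatchWriter(tmp)
        writer.write_on_batch_end(None, None, [0.1, 0.9], [4, 5], None, 0, 1)
        path = writer.batch_path(0, 1)
        with open(path) as f:
            stored = json.load(f)
        return os.path.relpath(path, tmp), stored
```

In an actual prediction run the writer would not be called by hand: an instance is passed to the Trainer through its callbacks argument and the hooks fire during trainer.predict.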