BasePredictionWriter

class pytorch_lightning.callbacks.BasePredictionWriter(write_interval='batch')

Bases: pytorch_lightning.callbacks.base.Callback

Base class for implementing how the predictions produced during prediction are stored (for example, written to disk).

Parameters

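write_interval (str) – When to write: "batch" (after every predict batch, through write_on_batch_end), "epoch" (once at the end of the predict epoch, through write_on_epoch_end) or "batch_and_epoch" (both). Defaults to "batch".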
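The three accepted values, and which write method each one enables, can be modelled with the small sketch below. It is a hypothetical stand-in, not the library's own class: the names WriteIntervalSketch and parse_write_interval are made up for illustration, and the real callback reports an invalid value through its own exception type rather than ValueError.

```python
from enum import Enum


class WriteIntervalSketch(Enum):
    """Hypothetical mirror of the write_interval choices (not the library class)."""

    BATCH = "batch"
    EPOCH = "epoch"
    BATCH_AND_EPOCH = "batch_and_epoch"

    @property
    def on_batch(self) -> bool:
        # write_on_batch_end runs for "batch" and "batch_and_epoch"
        return self in (WriteIntervalSketch.BATCH, WriteIntervalSketch.BATCH_AND_EPOCH)

    @property
    def on_epoch(self) -> bool:
        # write_on_epoch_end runs for "epoch" and "batch_and_epoch"
        return self in (WriteIntervalSketch.EPOCH, WriteIntervalSketch.BATCH_AND_EPOCH)


def parse_write_interval(value: str) -> WriteIntervalSketch:
    """Validate a write_interval string; raises ValueError for unknown values."""
    try:
        return WriteIntervalSketch(value)
    except ValueError:
        allowed = [member.value for member in WriteIntervalSketch]
        raise ValueError(f"write_interval should be one of {allowed}, got {value!r}") from None


print(parse_write_interval("batch").on_epoch)  # False
```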

Example:

import os
from typing import Any, List

import torch
from pytorch_lightning.callbacks import BasePredictionWriter


class CustomWriter(BasePredictionWriter):

    def __init__(self, output_dir: str, write_interval: str):
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_batch_end(
        self, trainer, pl_module: 'LightningModule', prediction: Any, batch_indices: List[int], batch: Any,
        batch_idx: int, dataloader_idx: int
    ):
        # One sub-folder per dataloader; os.path.join needs a str, not an int index.
        out_dir = os.path.join(self.output_dir, str(dataloader_idx))
        os.makedirs(out_dir, exist_ok=True)
        torch.save(prediction, os.path.join(out_dir, f"{batch_idx}.pt"))

    def write_on_epoch_end(
        self, trainer, pl_module: 'LightningModule', predictions: List[Any], batch_indices: List[Any]
    ):
        # Save everything collected during the predict epoch into a single file.
        os.makedirs(self.output_dir, exist_ok=True)
        torch.save(predictions, os.path.join(self.output_dir, "predictions.pt"))
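
on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)

Called when the predict batch ends. In this class it forwards the batch outputs to write_on_batch_end when write_interval is "batch" or "batch_and_epoch".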

Return type

None

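on_predict_epoch_end(trainer, pl_module, outputs)

Called when the predict epoch ends. In this class it forwards the collected predictions to write_on_epoch_end when write_interval is "epoch" or "batch_and_epoch".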

Return type

None
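How these two hooks hand off to the write methods can be pictured with the framework-free sketch below. It is a simplified, hypothetical model (PredictionWriterSketch and RecordingWriter are illustrative names): it drops the trainer and pl_module arguments, passes batch_indices through as None instead of collecting them, and treats the epoch hook's outputs as the full list of predictions.

```python
from typing import Any, List, Optional, Sequence


class PredictionWriterSketch:
    """Hypothetical, framework-free model of the callback's hook dispatch."""

    _VALID = ("batch", "epoch", "batch_and_epoch")

    def __init__(self, write_interval: str = "batch") -> None:
        if write_interval not in self._VALID:
            raise ValueError(f"write_interval should be one of {list(self._VALID)}")
        self.write_interval = write_interval

    def on_predict_batch_end(self, outputs: Any, batch: Any, batch_idx: int, dataloader_idx: int) -> None:
        # Only forward per-batch outputs when batch-level writing is enabled.
        if self.write_interval in ("batch", "batch_and_epoch"):
            self.write_on_batch_end(outputs, None, batch, batch_idx, dataloader_idx)

    def on_predict_epoch_end(self, outputs: List[Any]) -> None:
        # Only forward the collected predictions when epoch-level writing is enabled.
        if self.write_interval in ("epoch", "batch_and_epoch"):
            self.write_on_epoch_end(outputs, None)

    def write_on_batch_end(self, prediction: Any, batch_indices: Optional[Sequence[int]],
                           batch: Any, batch_idx: int, dataloader_idx: int) -> None:
        raise NotImplementedError

    def write_on_epoch_end(self, predictions: List[Any], batch_indices: Optional[List[Any]]) -> None:
        raise NotImplementedError


class RecordingWriter(PredictionWriterSketch):
    """Records which write method fired instead of writing files."""

    def __init__(self, write_interval: str) -> None:
        super().__init__(write_interval)
        self.calls: List[str] = []

    def write_on_batch_end(self, prediction, batch_indices, batch, batch_idx, dataloader_idx):
        self.calls.append(f"batch {batch_idx}")

    def write_on_epoch_end(self, predictions, batch_indices):
        self.calls.append(f"epoch ({len(predictions)} batches)")


writer = RecordingWriter("batch_and_epoch")
for i, out in enumerate([[0.1], [0.2]]):
    writer.on_predict_batch_end(out, batch=None, batch_idx=i, dataloader_idx=0)
writer.on_predict_epoch_end([[0.1], [0.2]])
print(writer.calls)  # ['batch 0', 'batch 1', 'epoch (2 batches)']
```

With "batch_and_epoch" both write methods fire; with "batch" or "epoch" only the matching one does.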

write_on_batch_end(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx)

Override with the logic to write a single batch.

Return type

None
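When prediction runs in several processes, each process typically numbers its own batches from 0, so a path built only from dataloader_idx and batch_idx (as in the example above) can be overwritten by another process. The helper below is a hypothetical sketch that adds the process rank to the file name; inside a real writer the rank would come from something like trainer.global_rank, which is an assumption to check against your Lightning version.

```python
import os
import tempfile


def batch_prediction_path(output_dir: str, global_rank: int, dataloader_idx: int, batch_idx: int) -> str:
    """Hypothetical helper: one file per (dataloader, rank, batch), creating folders as needed."""
    # os.path.join needs strings, so the integer indices are formatted explicitly.
    folder = os.path.join(output_dir, f"dataloader_{dataloader_idx}")
    os.makedirs(folder, exist_ok=True)
    return os.path.join(folder, f"rank_{global_rank}_batch_{batch_idx}.pt")


with tempfile.TemporaryDirectory() as tmp:
    # Two processes that both reach batch 0 of dataloader 0 get distinct files.
    path_rank0 = batch_prediction_path(tmp, global_rank=0, dataloader_idx=0, batch_idx=0)
    path_rank1 = batch_prediction_path(tmp, global_rank=1, dataloader_idx=0, batch_idx=0)
    print(os.path.basename(path_rank0), os.path.basename(path_rank1))
    # rank_0_batch_0.pt rank_1_batch_0.pt
```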

write_on_epoch_end(trainer, pl_module, predictions, batch_indices)

Override with the logic to write all batches.

Return type

None
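batch_indices is what makes it possible to put predictions back into dataset order after a distributed run, where each process only sees a shard of the samples. The sketch below assumes each process saved a pair (predictions, batch_indices) for one dataloader, holding one prediction list and one matching sample-index list per batch; the function name merge_rank_outputs and that file content are assumptions for illustration, not part of the Lightning API.

```python
from typing import Any, Dict, List, Sequence, Tuple

# One process's output: per-batch prediction lists and matching per-batch sample indices.
RankOutput = Tuple[List[List[Any]], List[Sequence[int]]]


def merge_rank_outputs(rank_outputs: List[RankOutput]) -> List[Any]:
    """Hypothetical helper: restore dataset order from several processes' outputs."""
    by_index: Dict[int, Any] = {}
    for predictions, batch_indices in rank_outputs:
        for batch_preds, indices in zip(predictions, batch_indices):
            for pred, sample_idx in zip(batch_preds, indices):
                # A sample seen twice (e.g. padding added by a distributed sampler) keeps one value.
                by_index[sample_idx] = pred
    return [by_index[i] for i in sorted(by_index)]


# Rank 0 saw samples 0, 2, 4; rank 1 saw samples 1, 3 (batch size 2).
rank0: RankOutput = ([["a", "c"], ["e"]], [[0, 2], [4]])
rank1: RankOutput = ([["b", "d"]], [[1, 3]])
print(merge_rank_outputs([rank0, rank1]))  # ['a', 'b', 'c', 'd', 'e']
```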