Shortcuts

BasePredictionWriter

class lightning.pytorch.callbacks.BasePredictionWriter(write_interval='batch')[source]

Bases: lightning.pytorch.callbacks.callback.Callback

Base class to implement how the predictions should be stored.

Parameters

write_interval (Literal['batch', 'epoch', 'batch_and_epoch']) – When to write predictions: after each batch, at the end of the epoch, or both.

Example:

import os

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import BasePredictionWriter
from lightning.pytorch.demos.boring_classes import BoringModel


class CustomWriter(BasePredictionWriter):
    """Prediction writer that stores predictions under ``output_dir`` with ``torch.save``."""

    def __init__(self, output_dir, write_interval):
        """Remember the target directory and forward ``write_interval`` to the base class.

        Args:
            output_dir: Directory under which the prediction files are written.
            write_interval: When to write; ``"batch"``, ``"epoch"`` or ``"batch_and_epoch"``.
        """
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_batch_end(
        self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx
    ):
        """Save one batch's prediction to ``<output_dir>/<dataloader_idx>/<batch_idx>.pt``."""
        # `os.path.join` only accepts str/bytes/path-like components, so the
        # integer dataloader index has to be converted explicitly.
        target_dir = os.path.join(self.output_dir, str(dataloader_idx))
        # `torch.save` does not create missing parent directories.
        os.makedirs(target_dir, exist_ok=True)
        torch.save(prediction, os.path.join(target_dir, f"{batch_idx}.pt"))

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        """Save all predictions of the epoch to ``<output_dir>/predictions.pt``."""
        # `torch.save` does not create missing parent directories.
        os.makedirs(self.output_dir, exist_ok=True)
        torch.save(predictions, os.path.join(self.output_dir, "predictions.pt"))


# Build the writer with `write_interval="epoch"` and register it as a Trainer callback.
pred_writer = CustomWriter(output_dir="pred_path", write_interval="epoch")
trainer = Trainer(callbacks=[pred_writer])
model = BoringModel()
# NOTE(review): `return_predictions=False` presumably keeps `predict` from also returning
# the predictions, leaving the callback as the only output — confirm against `Trainer.predict`.
trainer.predict(model, return_predictions=False)

Example:

# multi-device inference example

import os

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import BasePredictionWriter
from lightning.pytorch.demos.boring_classes import BoringModel


class CustomWriter(BasePredictionWriter):
    """Prediction writer that stores one set of files per process rank in ``output_dir``."""

    def __init__(self, output_dir, write_interval):
        """Remember the target directory and forward ``write_interval`` to the base class.

        Args:
            output_dir: Directory under which the prediction files are written.
            write_interval: When to write; ``"batch"``, ``"epoch"`` or ``"batch_and_epoch"``.
        """
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        """Save this rank's predictions and batch indices, named by ``trainer.global_rank``."""
        # `torch.save` does not create missing parent directories; `exist_ok=True`
        # keeps this safe when several processes reach this line.
        os.makedirs(self.output_dir, exist_ok=True)

        # this will create N (num processes) files in `output_dir`, each containing
        # the predictions of its respective rank
        torch.save(predictions, os.path.join(self.output_dir, f"predictions_{trainer.global_rank}.pt"))

        # optionally, you can also save `batch_indices` to get the information about the data index
        # from your prediction data
        torch.save(batch_indices, os.path.join(self.output_dir, f"batch_indices_{trainer.global_rank}.pt"))


# or you can set `write_interval="batch"` and override `write_on_batch_end` to save
# predictions at batch level
pred_writer = CustomWriter(output_dir="pred_path", write_interval="epoch")
# multi-device run: the writer is registered as a callback on a Trainer using 8 GPUs with DDP
trainer = Trainer(accelerator="gpu", strategy="ddp", devices=8, callbacks=[pred_writer])
model = BoringModel()
trainer.predict(model, return_predictions=False)
on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)[source]

Called when the predict batch ends.

Return type

None

on_predict_epoch_end(trainer, pl_module)[source]

Called when the predict epoch ends.

Return type

None

write_on_batch_end(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx)[source]

Override with the logic to write a single batch.

Return type

None

write_on_epoch_end(trainer, pl_module, predictions, batch_indices)[source]

Override with the logic to write all batches.

Return type

None