ThroughputMonitor¶
- class lightning.pytorch.callbacks.ThroughputMonitor(batch_size_fn, length_fn=None, **kwargs)[source]¶
Bases: Callback
Computes and logs throughput with the Throughput class.
Example:
```python
class MyModel(LightningModule):
    def setup(self, stage):
        with torch.device("meta"):
            model = MyModel()

            def sample_forward():
                batch = torch.randn(..., device="meta")
                return model(batch)

            self.flops_per_batch = measure_flops(model, sample_forward, loss_fn=torch.Tensor.sum)


logger = ...
throughput = ThroughputMonitor(batch_size_fn=lambda batch: batch.size(0))
trainer = Trainer(max_steps=1000, log_every_n_steps=10, callbacks=throughput, logger=logger)
model = MyModel()
trainer.fit(model)
```
Notes
- It assumes that the batch size is the same during all iterations.
- It will try to access a flops_per_batch attribute on your LightningModule on every iteration. We suggest using the measure_flops() function for this. You might want to compute it differently each time based on your setup; see the sketch below.
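If the compute per batch varies (for example with variable sequence lengths), the attribute can be refreshed before each step instead of being set once in setup. A minimal sketch, assuming a decoder-style transformer whose FLOPs are approximated as 6 * parameters * tokens and dict-style batches with an "input_ids" tensor; the class name and both assumptions are illustrative, not part of the callback's API:

```python
from lightning.pytorch import LightningModule


class MyTransformer(LightningModule):
    def on_train_batch_start(self, batch, batch_idx):
        # ThroughputMonitor reads `self.flops_per_batch` on every iteration,
        # so updating it here keeps the estimate in sync with the current batch.
        num_params = sum(p.numel() for p in self.parameters())
        num_tokens = batch["input_ids"].numel()  # assumed dict batch with an "input_ids" tensor
        # Rough forward+backward estimate for decoder-style transformers.
        self.flops_per_batch = 6 * num_params * num_tokens
```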
- Parameters:
  - batch_size_fn – A function to compute the number of samples given a batch.
  - length_fn – A function to compute the number of items in a sample given a batch.
  - **kwargs – See available parameters in Throughput.
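For illustration, here is how the callback might be configured for dict-style batches; the "input_ids" key and the window_size value are assumptions about your setup, with window_size being one of the keyword arguments forwarded to Throughput:

```python
from lightning.pytorch.callbacks import ThroughputMonitor

# Assumes each batch is a dict like {"input_ids": Tensor[batch_size, seq_len], ...};
# adapt the lambda to whatever structure your dataloader produces.
throughput = ThroughputMonitor(
    batch_size_fn=lambda batch: batch["input_ids"].size(0),  # samples in the batch
    window_size=50,  # forwarded to Throughput: number of iterations to average over
)
```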
- on_predict_batch_end(trainer, pl_module, outputs, batch, *_, **__)[source]¶
Called when the predict batch ends.
- Return type:
  None
- on_test_batch_end(trainer, pl_module, outputs, batch, *_, **__)[source]¶
Called when the test batch ends.
- Return type:
  None
- on_train_batch_end(trainer, pl_module, outputs, batch, *_)[source]¶
Called when the train batch ends.
- Return type:
  None
Note
The value outputs["loss"] here will be the normalized value w.r.t. accumulate_grad_batches of the loss returned from training_step.
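For example, with accumulate_grad_batches=4 a training_step that returns a loss of 1.0 arrives here as outputs["loss"] == 0.25. A hedged sketch of undoing that scaling from a callback; the class name LossLogger and the logged metric name are illustrative:

```python
from lightning.pytorch.callbacks import Callback


class LossLogger(Callback):
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # `outputs["loss"]` is already divided by `accumulate_grad_batches`,
        # so multiply it back to recover the value returned from `training_step`.
        unscaled = outputs["loss"] * trainer.accumulate_grad_batches
        pl_module.log("train/unscaled_loss", unscaled)
```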