Timer
- class lightning.pytorch.callbacks.Timer(duration=None, interval=Interval.step, verbose=True)
Bases: Callback
The Timer callback tracks the time spent in the training, validation, and test loops and interrupts the Trainer if the given time limit for the training loop is reached.
- Parameters:
duration (Union[str, timedelta, Dict[str, int], None]) – A string in the format DD:HH:MM:SS (days, hours, minutes, seconds), a datetime.timedelta, or a dict containing key-value pairs compatible with timedelta.
interval (str) – Determines whether the interruption happens at the epoch level or mid-epoch. Can be either "epoch" or "step".
verbose (bool) – Set this to False to suppress logging messages.
- Raises:
MisconfigurationException – If interval is not one of the supported choices.
Example:

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import Timer

    # stop training after 12 hours
    timer = Timer(duration="00:12:00:00")

    # or provide a datetime.timedelta
    from datetime import timedelta
    timer = Timer(duration=timedelta(weeks=1))

    # or provide a dictionary
    timer = Timer(duration=dict(weeks=4, days=2))

    # force training to stop after given time limit
    trainer = Trainer(callbacks=[timer])

    # query training/validation/test time (in seconds)
    timer.time_elapsed("train")
    timer.start_time("validate")
    timer.end_time("test")
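The three accepted `duration` formats can all be normalized into a single `datetime.timedelta`. The following is a minimal sketch, not Lightning's own parsing code: the helper names `parse_duration` and `validate_interval` are hypothetical, and `validate_interval` raises a plain `ValueError` where the Timer documents a `MisconfigurationException`.

```python
from datetime import timedelta
from typing import Dict, Optional, Union

# Hypothetical constant: the two interval values documented for the Timer
SUPPORTED_INTERVALS = ("step", "epoch")


def parse_duration(duration: Union[str, timedelta, Dict[str, int], None]) -> Optional[timedelta]:
    """Hypothetical helper: normalize a duration into a timedelta (None means no limit)."""
    if duration is None or isinstance(duration, timedelta):
        return duration
    if isinstance(duration, str):
        # "DD:HH:MM:SS" -> days, hours, minutes, seconds
        days, hours, minutes, seconds = (int(part) for part in duration.strip().split(":"))
        return timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
    if isinstance(duration, dict):
        # keys must be valid timedelta keyword arguments, e.g. {"weeks": 4, "days": 2}
        return timedelta(**duration)
    raise TypeError(f"Unsupported duration type: {type(duration).__name__}")


def validate_interval(interval: str) -> str:
    """Hypothetical helper: reject anything other than "step" or "epoch"."""
    if interval not in SUPPORTED_INTERVALS:
        raise ValueError(f"interval must be one of {SUPPORTED_INTERVALS}, got {interval!r}")
    return interval
```

For example, `parse_duration("01:00:30:00")` yields one day and thirty minutes, while a malformed string such as `"12:00"` fails when unpacking into four fields.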
- end_time(stage=RunningStage.TRAINING)
Return the end time of a particular stage (in seconds).
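To make the start/end bookkeeping concrete, here is a framework-free sketch of per-stage timing. The `StageClock` class and its methods are hypothetical. It assumes a monotonic clock (`time.monotonic`), so the recorded values are only meaningful relative to each other, and a stage that has started but not yet ended is measured up to the current moment.

```python
import time
from typing import Dict, Optional


class StageClock:
    """Hypothetical helper: records when each stage starts and ends, in seconds."""

    def __init__(self) -> None:
        self._start: Dict[str, Optional[float]] = {}
        self._end: Dict[str, Optional[float]] = {}

    def start(self, stage: str) -> None:
        self._start[stage] = time.monotonic()
        self._end[stage] = None  # a (re)started stage has not ended yet

    def end(self, stage: str) -> None:
        self._end[stage] = time.monotonic()

    def start_time(self, stage: str) -> Optional[float]:
        return self._start.get(stage)

    def end_time(self, stage: str) -> Optional[float]:
        return self._end.get(stage)

    def time_elapsed(self, stage: str) -> float:
        start = self.start_time(stage)
        if start is None:
            return 0.0  # the stage never started
        end = self.end_time(stage)
        # a running stage is measured up to "now"; a finished one is frozen
        return (end if end is not None else time.monotonic()) - start
```

Returning `None` from `end_time` for a stage that is still running keeps "not finished" distinct from a real timestamp.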
- load_state_dict(state_dict)
Called when loading a checkpoint. Implement this to reload the callback's state given the callback's state_dict.
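A time budget only survives a restart if the time already spent is saved in the checkpoint and restored as an offset. The sketch below illustrates that pattern with a hypothetical `ResumableBudget` class and a hypothetical `"time_elapsed"` key; it is not the Timer's actual checkpoint format.

```python
import time
from typing import Any, Dict


class ResumableBudget:
    """Hypothetical sketch: a training-time budget that survives checkpoint/resume."""

    def __init__(self, duration_s: float) -> None:
        self.duration_s = duration_s
        self._offset = 0.0  # seconds already spent in earlier runs
        self._start = time.monotonic()  # start of the current run

    def elapsed(self) -> float:
        return self._offset + (time.monotonic() - self._start)

    def state_dict(self) -> Dict[str, Any]:
        # persist only what is needed to keep counting after a restart
        return {"time_elapsed": self.elapsed()}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # earlier runs count against the budget; restart the live clock from now
        self._offset = state_dict["time_elapsed"]
        self._start = time.monotonic()

    def exhausted(self) -> bool:
        return self.elapsed() >= self.duration_s
```

Restarting the live clock inside `load_state_dict` avoids counting the time the process spent down between runs.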
- on_train_batch_end(trainer, *args, **kwargs)
Called when the train batch ends.
- Return type:
None
Note
The value outputs["loss"] here will be the normalized value w.r.t. accumulate_grad_batches of the loss returned from training_step.
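As a worked example of the note, assuming "normalized" means divided by `accumulate_grad_batches` (the usual convention for gradient accumulation): if `training_step` returns a loss of 0.8 and `accumulate_grad_batches=4`, the hook sees 0.2. The helper `normalized_loss` is hypothetical.

```python
def normalized_loss(step_loss: float, accumulate_grad_batches: int) -> float:
    """Hypothetical helper: scale a per-batch loss by the number of accumulated batches."""
    if accumulate_grad_batches < 1:
        raise ValueError("accumulate_grad_batches must be >= 1")
    return step_loss / accumulate_grad_batches


# training_step returned 0.8 and gradients are accumulated over 4 batches
reported = normalized_loss(0.8, accumulate_grad_batches=4)
# multiplying back recovers the value that training_step returned
recovered = reported * 4
print(f"reported={reported:.2f}, recovered={recovered:.2f}")  # prints: reported=0.20, recovered=0.80
```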
- on_train_epoch_end(trainer, *args, **kwargs)
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.LightningModule and access them in this hook:

    import lightning as L
    import torch

    class MyLightningModule(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.training_step_outputs = []

        def training_step(self, batch, batch_idx):
            loss = ...  # compute the loss for this batch
            self.training_step_outputs.append(loss)
            return loss

    class MyCallback(L.Callback):
        def on_train_epoch_end(self, trainer, pl_module):
            # do something with all training_step outputs, for example:
            epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
            pl_module.log("training_epoch_mean", epoch_mean)
            # free up the memory
            pl_module.training_step_outputs.clear()

- Return type:
None
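The interval parameter decides which hook performs the time check: with "step" the check can run after every training batch, with "epoch" only at the end of an epoch. The sketch below models that decision in plain Python. `FakeTrainer`, `check_time`, and the hook labels `"batch_end"`/`"epoch_end"` are hypothetical stand-ins; the only behavior it mirrors is requesting a stop once the elapsed training time reaches `duration`.

```python
from dataclasses import dataclass
from datetime import timedelta
from typing import Optional


@dataclass
class FakeTrainer:
    """Hypothetical stand-in for a trainer: only the stop flag this sketch needs."""
    should_stop: bool = False


def check_time(trainer: FakeTrainer, elapsed_s: float, duration: Optional[timedelta],
               interval: str, hook: str) -> None:
    """Request a stop once the budget is used up, but only from the hook matching `interval`.

    Assumes `interval` was already validated to be "step" or "epoch".
    `hook` is "batch_end" or "epoch_end" (hypothetical labels for the two training hooks).
    """
    if duration is None:
        return  # no time limit configured
    expected_hook = "batch_end" if interval == "step" else "epoch_end"
    if hook != expected_hook:
        return  # this hook does not check time for the configured interval
    if elapsed_s >= duration.total_seconds():
        trainer.should_stop = True
```

With `interval="epoch"`, an over-budget run still finishes the current epoch, because the batch-end check is skipped; `interval="step"` stops mid-epoch at the cost of a check after every batch.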
- on_validation_start(trainer, pl_module)
Called when the validation loop begins.
- Return type:
None
- start_time(stage=RunningStage.TRAINING)
Return the start time of a particular stage (in seconds).