Timer
- class lightning.pytorch.callbacks.Timer(duration=None, interval=Interval.step, verbose=True)
Bases: Callback
The Timer callback tracks the time spent in the training, validation, and test loops and interrupts the Trainer if the given time limit for the training loop is reached.
- Parameters:
  - duration (Union[str, timedelta, Dict[str, int], None]) – A string in the format DD:HH:MM:SS (days, hours, minutes, seconds), a datetime.timedelta, or a dict of key-value pairs compatible with timedelta.
  - interval (str) – Determines whether the interruption happens at the epoch level or mid-epoch. Can be either "epoch" or "step".
  - verbose (bool) – Set this to False to suppress logging messages.
- Raises:
  MisconfigurationException – If interval is not one of the supported choices.
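To illustrate what the duration and interval parameters accept, here is a minimal standalone sketch of how the three duration forms could be normalized to a number of seconds and how interval could be validated. The helpers normalize_duration and validate_interval and the ConfigError class are hypothetical stand-ins, not part of Lightning; only the accepted input formats come from the parameter descriptions above.

```python
import re
from datetime import timedelta
from typing import Dict, Optional, Union

# Hypothetical: the two interval choices named in the parameter description
SUPPORTED_INTERVALS = ("step", "epoch")


class ConfigError(ValueError):
    """Hypothetical stand-in for Lightning's MisconfigurationException."""


def normalize_duration(duration: Union[str, timedelta, Dict[str, int], None]) -> Optional[float]:
    """Convert any accepted ``duration`` form into a number of seconds (or None)."""
    if duration is None:
        return None  # no time limit
    if isinstance(duration, str):
        # DD:HH:MM:SS -> days, hours, minutes, seconds
        match = re.fullmatch(r"(\d+):(\d\d):(\d\d):(\d\d)", duration.strip())
        if match is None:
            raise ConfigError(f"`duration` {duration!r} is not in the format DD:HH:MM:SS")
        days, hours, minutes, seconds = (int(g) for g in match.groups())
        duration = timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
    elif isinstance(duration, dict):
        # keys must be valid timedelta keyword arguments, e.g. weeks, days, hours
        duration = timedelta(**duration)
    return duration.total_seconds()


def validate_interval(interval: str) -> str:
    """Reject anything other than the supported interval choices."""
    if interval not in SUPPORTED_INTERVALS:
        raise ConfigError(f"`interval` must be one of {SUPPORTED_INTERVALS}, got {interval!r}")
    return interval
```

For example, normalize_duration("00:12:00:00") yields 43200.0 seconds, and dict(weeks=4, days=2) is treated exactly like timedelta(weeks=4, days=2).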
Example:
```python
from datetime import timedelta

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import Timer

# stop training after 12 hours
timer = Timer(duration="00:12:00:00")

# or provide a datetime.timedelta
timer = Timer(duration=timedelta(weeks=1))

# or provide a dictionary
timer = Timer(duration=dict(weeks=4, days=2))

# force training to stop after given time limit
trainer = Trainer(callbacks=[timer])

# query training/validation/test time (in seconds)
timer.time_elapsed("train")
timer.start_time("validate")
timer.end_time("test")
```
- end_time(stage=RunningStage.TRAINING)
Return the end time of a particular stage (in seconds).
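start_time and end_time return per-stage timestamps, and time_elapsed (used in the example above) can be derived from them. The sketch below shows one way such bookkeeping could work; the StageClock class and its on_stage_start/on_stage_end methods are hypothetical, and the sketch assumes a monotonic clock rather than describing Lightning's internals.

```python
import time
from typing import Dict, Optional


class StageClock:
    """Hypothetical per-stage start/end bookkeeping, similar in spirit to Timer."""

    def __init__(self) -> None:
        self._start: Dict[str, Optional[float]] = {}
        self._end: Dict[str, Optional[float]] = {}

    def on_stage_start(self, stage: str) -> None:
        # time.monotonic() is unaffected by system clock adjustments
        self._start[stage] = time.monotonic()
        self._end[stage] = None

    def on_stage_end(self, stage: str) -> None:
        self._end[stage] = time.monotonic()

    def start_time(self, stage: str) -> Optional[float]:
        return self._start.get(stage)  # None if the stage never started

    def end_time(self, stage: str) -> Optional[float]:
        return self._end.get(stage)  # None if the stage has not ended yet

    def time_elapsed(self, stage: str) -> float:
        start, end = self.start_time(stage), self.end_time(stage)
        if start is None:
            return 0.0  # stage never started
        if end is None:
            return time.monotonic() - start  # stage still running
        return end - start  # stage finished
```

Because time.monotonic() has an undefined reference point, a single start or end value is only meaningful when compared with another reading from the same clock, which is why the elapsed time is computed as a difference.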
- load_state_dict(state_dict)
Called when loading a checkpoint. Implement this to reload the callback's state from the given state_dict.
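Restoring state matters for a time limit: without it, a resumed run would presumably count its limit from zero again. The minimal sketch below carries elapsed training time across a checkpoint as an offset, assuming the saved state is a dict of the form {"time_elapsed": {"train": seconds}}, a layout chosen here for illustration. ResumableTrainingClock is hypothetical, not Lightning's implementation.

```python
import time
from typing import Any, Dict, Optional


class ResumableTrainingClock:
    """Hypothetical sketch: carry elapsed training time across a checkpoint."""

    def __init__(self) -> None:
        self._start: Optional[float] = None
        self._offset = 0.0  # seconds already spent before the last resume

    def start(self) -> None:
        self._start = time.monotonic()

    def time_elapsed(self) -> float:
        if self._start is None:
            return self._offset  # not running yet: only previously saved time
        return self._offset + (time.monotonic() - self._start)

    def state_dict(self) -> Dict[str, Any]:
        # what would be written into the checkpoint (assumed layout)
        return {"time_elapsed": {"train": self.time_elapsed()}}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # restore previously elapsed time; default to 0 for old or empty state
        self._offset = state_dict.get("time_elapsed", {}).get("train", 0.0)
```

Using an offset keeps the running clock itself untouched: after load_state_dict, time_elapsed simply adds the restored seconds to whatever the new run has spent.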
- on_train_batch_end(trainer, *args, **kwargs)
Called when the train batch ends.
- Return type: None
Note
The value outputs["loss"] here will be the loss returned from training_step, normalized with respect to accumulate_grad_batches.
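With interval="step", the time limit can be checked at the end of every training batch. The sketch below gates the check on the configured interval; FakeTrainer, TimeLimitCheck and the injected elapsed callable are hypothetical stand-ins chosen so the logic runs without Lightning, and setting should_stop only models a stop request.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class FakeTrainer:
    """Hypothetical stand-in exposing only the flag the check needs."""
    should_stop: bool = False


class TimeLimitCheck:
    """Hypothetical sketch of an interval-gated time-limit check."""

    def __init__(self, duration: Optional[float], interval: str, elapsed: Callable[[], float]) -> None:
        self.duration = duration  # seconds, or None for no limit
        self.interval = interval  # "step" or "epoch"
        self.elapsed = elapsed    # returns training seconds elapsed so far

    def _check(self, trainer: FakeTrainer) -> None:
        if self.elapsed() >= self.duration:
            trainer.should_stop = True  # request a stop from the (fake) trainer

    def on_train_batch_end(self, trainer: FakeTrainer) -> None:
        # only the "step" interval checks mid-epoch
        if self.interval == "step" and self.duration is not None:
            self._check(trainer)

    def on_train_epoch_end(self, trainer: FakeTrainer) -> None:
        # only the "epoch" interval checks at epoch boundaries
        if self.interval == "epoch" and self.duration is not None:
            self._check(trainer)
```

Injecting the elapsed-time function keeps the decision logic testable without a real clock.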
- on_train_epoch_end(trainer, *args, **kwargs)
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.core.LightningModule and access them in this hook:
```python
import torch
import lightning as L


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        # cache for the per-step losses of the current epoch
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...  # compute the loss for this batch
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()
```
- Return type: None
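The interval choice determines which hook performs the check, and therefore where training can stop. This standalone simulation assumes a fake clock in which every batch takes a fixed number of seconds and nothing else consumes time; simulate_stop is a hypothetical helper, not part of Lightning.

```python
from typing import Tuple


def simulate_stop(interval: str, duration: float, epochs: int, batches_per_epoch: int,
                  seconds_per_batch: float = 1.0) -> Tuple[int, int]:
    """Return the zero-based (epoch, batch) after which training would stop.

    Hypothetical simulation with a fake clock: every batch takes
    ``seconds_per_batch`` seconds.
    """
    elapsed = 0.0
    for epoch in range(epochs):
        for batch in range(batches_per_epoch):
            elapsed += seconds_per_batch
            # "step": checked after every training batch
            if interval == "step" and elapsed >= duration:
                return epoch, batch
        # "epoch": checked only once the whole epoch has finished
        if interval == "epoch" and elapsed >= duration:
            return epoch, batches_per_epoch - 1
    return epochs - 1, batches_per_epoch - 1  # limit never reached
```

With a 25 second limit and 10 one-second batches per epoch, simulate_stop("step", 25, 10, 10) returns (2, 4), while simulate_stop("epoch", 25, 10, 10) returns (2, 9): in this simulation the "epoch" interval finishes the third epoch and overruns the limit by 5 seconds.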
- on_validation_start(trainer, pl_module)
Called when the validation loop begins.
- Return type: None
- start_time(stage=RunningStage.TRAINING)
Return the start time of a particular stage (in seconds).