Source code for pytorch_lightning.strategies.single_hpu

# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Dict, Optional, Union

from torch.nn import Module
from torch.optim.optimizer import Optimizer

import pytorch_lightning as pl
from lightning_fabric.plugins import CheckpointIO
from lightning_fabric.utilities.types import _DEVICE
from pytorch_lightning.accelerators.hpu import _HPU_AVAILABLE
from import HPUCheckpointIO
from import _WrappingCheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.strategies.single_device import SingleDeviceStrategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import STEP_OUTPUT

    import habana_frameworks.torch.core as htcore

class SingleHPUStrategy(SingleDeviceStrategy):
    """Strategy for training on single HPU device."""

    strategy_name = "hpu_single"

    def __init__(
        self,
        device: _DEVICE = "hpu",
        accelerator: Optional["pl.accelerators.Accelerator"] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[PrecisionPlugin] = None,
    ):
        """Validate HPU availability and forward all arguments to the parent strategy.

        Args:
            device: Device to run on; defaults to ``"hpu"``.
            accelerator: Optional accelerator passed through to the parent constructor.
            checkpoint_io: Optional checkpoint plugin passed through to the parent constructor.
            precision_plugin: Optional precision plugin passed through to the parent constructor.

        Raises:
            MisconfigurationException: If ``_HPU_AVAILABLE`` is falsy.
        """
        if not _HPU_AVAILABLE:
            raise MisconfigurationException("`SingleHPUStrategy` requires HPU devices to run")

        super().__init__(
            accelerator=accelerator,
            device=device,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )

    @property
    def checkpoint_io(self) -> CheckpointIO:
        """Return the checkpoint plugin, defaulting to :class:`HPUCheckpointIO`.

        When no plugin has been set, an ``HPUCheckpointIO`` is created and cached. When the
        stored plugin is a ``_WrappingCheckpointIO``, its inner ``checkpoint_io`` is replaced
        with a fresh ``HPUCheckpointIO`` on every access.
        """
        if self._checkpoint_io is None:
            self._checkpoint_io = HPUCheckpointIO()
        elif isinstance(self._checkpoint_io, _WrappingCheckpointIO):
            self._checkpoint_io.checkpoint_io = HPUCheckpointIO()

        return self._checkpoint_io

    @checkpoint_io.setter
    def checkpoint_io(self, io: Optional[CheckpointIO]) -> None:
        """Store ``io`` as the checkpoint plugin; ``None`` re-enables the lazy default."""
        self._checkpoint_io = io

    @property
    def is_distributed(self) -> bool:
        """Always ``False``: this strategy drives a single device."""
        return False

    def setup(self, trainer: "pl.Trainer") -> None:
        """Move the model to the device, then run the parent's setup.

        Args:
            trainer: The trainer this strategy is attached to.
        """
        # The model is moved before the parent's setup runs.
        self.model_to_device()
        super().setup(trainer)

    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
        """Delegate optimizer setup to the parent strategy.

        Args:
            trainer: The trainer this strategy is attached to.
        """
        super().setup_optimizers(trainer)

    def model_to_device(self) -> None:
        """Move the wrapped model onto this strategy's root device."""
        # NOTE(review): the body of this method was missing from the extracted source (only the
        # trailing ``# type: ignore`` survived); restored as the plain model-to-device move.
        # Confirm against the upstream file.
        self.model.to(self.root_device)  # type: ignore

    def on_after_backward(self) -> None:
        """Mark a step boundary once forward and backward have run."""
        # Break lazy accumulation of graph after fwd+bwd
        htcore.mark_step()

    def optimizer_step(
        self,
        optimizer: Optimizer,
        opt_idx: int,
        closure: Callable[[], Any],
        model: Optional[Union["pl.LightningModule", Module]] = None,
        **kwargs: Any,
    ) -> Any:
        """Run the parent's optimizer step, then mark a step boundary.

        Args:
            optimizer: The optimizer to step.
            opt_idx: Index of ``optimizer``, forwarded unchanged to the parent.
            closure: Closure forwarded unchanged to the parent.
            model: Optional module forwarded unchanged to the parent.
            **kwargs: Extra keyword arguments forwarded unchanged to the parent.

        Returns:
            Whatever the parent's ``optimizer_step`` returns.
        """
        optimizer_output = super().optimizer_step(optimizer, opt_idx, closure, model, **kwargs)
        # Break lazy accumulation of graph after optimizer
        htcore.mark_step()
        return optimizer_output

    def validation_step_end(self, step_output: STEP_OUTPUT) -> STEP_OUTPUT:
        """Mark a step boundary and return ``step_output`` unchanged."""
        # Break lazy accumulation of graph after every step
        htcore.mark_step()
        return step_output

    def test_step_end(self, step_output: STEP_OUTPUT) -> STEP_OUTPUT:
        """Mark a step boundary and return ``step_output`` unchanged."""
        # Break lazy accumulation of graph after every step
        htcore.mark_step()
        return step_output

    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        """Register this strategy under ``cls.strategy_name``.

        Args:
            strategy_registry: Registry object exposing ``register(name, strategy, description=...)``.
        """
        strategy_registry.register(
            cls.strategy_name,
            cls,
            # ``cls`` is already the class here; ``cls.__class__`` would be its metaclass and
            # would put the metaclass name in the description instead of the strategy's name.
            description=cls.__name__,
        )

© Copyright (c) 2018-2023, Lightning AI et al.

Built with Sphinx using a theme provided by Read the Docs.