Source code for pytorch_lightning.lite.lite
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from typing import List, Optional, Tuple, Union
from lightning_utilities.core.rank_zero import rank_zero_deprecation, rank_zero_warn
from lightning_lite.connector import _PLUGIN_INPUT as _LITE_PLUGIN_INPUT
from lightning_lite.connector import _PRECISION_INPUT
from lightning_lite.lite import LightningLite as _NewLightningLite
from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
from lightning_lite.plugins import DeepSpeedPrecision as LiteDeepSpeedPrecision
from lightning_lite.plugins import DoublePrecision as LiteDoublePrecision
from lightning_lite.plugins import NativeMixedPrecision as LiteNativeMixedPrecision
from lightning_lite.plugins import Precision as LitePrecision
from lightning_lite.plugins import TPUBf16Precision as LiteTPUBf16Precision
from lightning_lite.plugins import TPUPrecision as LiteTPUPrecision
from lightning_lite.strategies import DataParallelStrategy as LiteDataParallelStrategy
from lightning_lite.strategies import DDPShardedStrategy as LiteDDPShardedStrategy
from lightning_lite.strategies import DDPSpawnShardedStrategy as LiteDDPSpawnShardedStrategy
from lightning_lite.strategies import DDPSpawnStrategy as LiteDDPSpawnStrategy
from lightning_lite.strategies import DDPStrategy as LiteDDPStrategy
from lightning_lite.strategies import DeepSpeedStrategy as LiteDeepSpeedStrategy
from lightning_lite.strategies import SingleDeviceStrategy as LiteSingleDeviceStrategy
from lightning_lite.strategies import SingleTPUStrategy as LiteSingleTPUStrategy
from lightning_lite.strategies import Strategy as LiteStrategy
from lightning_lite.strategies import XLAStrategy
from pytorch_lightning.accelerators import Accelerator as PLAccelerator
from pytorch_lightning.plugins import DeepSpeedPrecisionPlugin as PLDeepSpeedPrecisionPlugin
from pytorch_lightning.plugins import DoublePrecisionPlugin as PLDoublePrecisionPlugin
from pytorch_lightning.plugins import NativeMixedPrecisionPlugin as PLNativeMixedPrecisionPlugin
from pytorch_lightning.plugins import PrecisionPlugin as PLPrecisionPlugin
from pytorch_lightning.plugins import TPUBf16PrecisionPlugin as PLTPUBf16PrecisionPlugin
from pytorch_lightning.plugins import TPUPrecisionPlugin as PLTPUPrecisionPlugin
from pytorch_lightning.strategies import DataParallelStrategy as PLDataParallelStrategy
from pytorch_lightning.strategies import DDPShardedStrategy as PLDDPShardedStrategy
from pytorch_lightning.strategies import DDPSpawnShardedStrategy as PLDDPSpawnShardedStrategy
from pytorch_lightning.strategies import DDPSpawnStrategy as PLDDPSpawnStrategy
from pytorch_lightning.strategies import DDPStrategy as PLDDPStrategy
from pytorch_lightning.strategies import DeepSpeedStrategy as PLDeepSpeedStrategy
from pytorch_lightning.strategies import SingleDeviceStrategy as PLSingleDeviceStrategy
from pytorch_lightning.strategies import SingleTPUStrategy as PLSingleTPUStrategy
from pytorch_lightning.strategies import Strategy as PLStrategy
from pytorch_lightning.strategies import TPUSpawnStrategy as PLTPUSpawnStrategy
_PL_PLUGIN = Union[PLPrecisionPlugin, ClusterEnvironment, CheckpointIO]
_PL_PLUGIN_INPUT = Union[_PL_PLUGIN, str]
class LightningLite(_NewLightningLite, ABC):
"""Lite accelerates your PyTorch training or inference code with minimal changes required.
- Automatic placement of models and data onto the device.
- Automatic support for mixed and double precision (smaller memory footprint).
- Seamless switching between hardware (CPU, GPU, TPU) and distributed training strategies
(data-parallel training, sharded training, etc.).
- Automated spawning of processes, no launch utilities required.
- Multi-node support.
Args:
accelerator: The hardware to run on. Possible choices are:
``"cpu"``, ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"``, ``"auto"``.
strategy: Strategy for how to run across multiple devices. Possible choices are:
``"dp"``, ``"ddp"``, ``"ddp_spawn"``, ``"deepspeed"``, ``"ddp_sharded"``.
devices: Number of devices to train on (``int``), which GPUs to train on (``list`` or ``str``), or ``"auto"``.
The value applies per node.
num_nodes: Number of GPU nodes for distributed training.
precision: Double precision (``64``), full precision (``32``), half precision (``16``),
or bfloat16 precision (``"bf16"``).
plugins: One or several custom plugins
gpus: Provides the same function as the ``devices`` argument but implies ``accelerator="gpu"``.
.. deprecated:: v1.8.0
``gpus`` has been deprecated in v1.8.0 and will be removed in v1.10.0.
Please use ``accelerator='gpu'`` and ``devices=x`` instead.
tpu_cores: Provides the same function as the ``devices`` argument but implies ``accelerator="tpu"``.
.. deprecated:: v1.8.0
``tpu_cores`` has been deprecated in v1.8.0 and will be removed in v1.10.0.
Please use ``accelerator='tpu'`` and ``devices=x`` instead.
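
    Example:
        A minimal sketch of how the class is typically used. The model, optimizer, and dataloader
        construction is assumed to happen elsewhere, and the model is assumed to return a scalar loss::

            class MyLite(LightningLite):
                def run(self, num_epochs):
                    # create model, optimizer, and dataloader as in plain PyTorch
                    model, optimizer, dataloader = ...
                    model, optimizer = self.setup(model, optimizer)
                    dataloader = self.setup_dataloaders(dataloader)
                    for _ in range(num_epochs):
                        for batch in dataloader:
                            optimizer.zero_grad()
                            loss = model(batch)
                            self.backward(loss)  # replaces loss.backward()
                            optimizer.step()

            MyLite(accelerator="cpu", devices=1).run(num_epochs=1)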
"""
    def __init__(
self,
accelerator: Optional[Union[str, PLAccelerator]] = None,
strategy: Optional[Union[str, PLStrategy]] = None,
devices: Optional[Union[List[int], str, int]] = None,
num_nodes: int = 1,
precision: _PRECISION_INPUT = 32,
plugins: Optional[Union[_PL_PLUGIN_INPUT, List[_PL_PLUGIN_INPUT]]] = None,
gpus: Optional[Union[List[int], str, int]] = None,
tpu_cores: Optional[Union[List[int], str, int]] = None,
) -> None:
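        # The deprecated ``gpus`` / ``tpu_cores`` flags are translated into ``accelerator`` and ``devices``
        # before delegating to the ``lightning_lite`` implementation. They take precedence over an explicit
        # ``devices`` value (a warning is emitted in that case).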
if gpus is not None or tpu_cores is not None:
devices, accelerator = _convert_deprecated_device_flags(
accelerator=accelerator,
devices=devices,
gpus=gpus,
tpu_cores=tpu_cores,
)
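        # Re-instantiate PyTorch Lightning precision plugins as their ``lightning_lite`` equivalents.
        # Cluster environments and checkpoint IO plugins are shared between the two packages and pass
        # through unchanged.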
lite_plugins: Optional[Union[_LITE_PLUGIN_INPUT, List[_LITE_PLUGIN_INPUT]]]
if isinstance(plugins, PLPrecisionPlugin):
lite_plugins = _to_lite_precision(plugins)
elif isinstance(plugins, list):
lite_plugins = [
_to_lite_precision(plugin) if isinstance(plugin, PLPrecisionPlugin) else plugin for plugin in plugins
]
else:
lite_plugins = plugins
super().__init__(
accelerator=accelerator,
strategy=(_to_lite_strategy(strategy) if isinstance(strategy, PLStrategy) else strategy),
devices=devices,
num_nodes=num_nodes,
precision=precision,
plugins=lite_plugins,
)
def _convert_deprecated_device_flags(
accelerator: Optional[Union[str, PLAccelerator]],
devices: Optional[Union[List[int], str, int]],
gpus: Optional[Union[List[int], str, int]],
tpu_cores: Optional[Union[List[int], str, int]],
) -> Tuple[Optional[Union[List[int], str, int]], Optional[Union[str, PLAccelerator]]]:
"""Emit deprecation warnings for gpus and tpu_cores and translate them into the new accelerator and devices.
Similar implementation as in ``pytorch_lightning.trainer.connectors.accelerator_connector``.
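
    For example (illustrative): ``gpus=2`` with no explicit ``accelerator`` or ``devices`` emits a
    deprecation warning and returns ``(2, "cuda")``.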
"""
if gpus is not None:
rank_zero_deprecation(
f"Setting `Lite(gpus={gpus!r})` is deprecated in v1.8.0 and will be removed"
f" in v1.10.0. Please use `Lite(accelerator='gpu', devices={gpus!r})` instead."
)
if tpu_cores is not None:
rank_zero_deprecation(
f"Setting `Lite(tpu_cores={tpu_cores!r})` is deprecated in v1.8.0 and will be removed"
f" in v1.10.0. Please use `Lite(accelerator='tpu', devices={tpu_cores!r})` instead."
)
deprecated_devices_specific_flag = gpus or tpu_cores
if deprecated_devices_specific_flag and deprecated_devices_specific_flag not in ([], 0, "0"):
if devices:
rank_zero_warn(
f"The option `devices={devices}` will be ignored and the device specific number"
f"{deprecated_devices_specific_flag} will be used instead."
)
if gpus is not None and tpu_cores is not None:
rank_zero_warn(
f"Both `Lite(gpus={gpus!r}, tpu_cores={tpu_cores!r})` were specified. Please choose only one of"
" the two."
)
if accelerator is None:
if tpu_cores:
accelerator = "tpu"
if gpus:
accelerator = "cuda"
return deprecated_devices_specific_flag, accelerator
def _to_lite_strategy(strategy: PLStrategy) -> LiteStrategy:
"""Re-instantiates a PL-Strategy as the corresponding Lite-Strategy."""
strategy_cls = type(strategy)
if strategy_cls is PLDDPStrategy:
return LiteDDPStrategy(
accelerator=strategy.accelerator,
parallel_devices=strategy.parallel_devices,
cluster_environment=strategy.cluster_environment,
checkpoint_io=strategy.checkpoint_io,
precision=_to_lite_precision(strategy.precision_plugin),
process_group_backend=strategy.process_group_backend,
timeout=strategy._timeout,
**strategy._ddp_kwargs,
)
if strategy_cls is PLDDPSpawnStrategy:
return LiteDDPSpawnStrategy(
accelerator=strategy.accelerator,
parallel_devices=strategy.parallel_devices,
cluster_environment=strategy.cluster_environment,
checkpoint_io=strategy.checkpoint_io,
precision=_to_lite_precision(strategy.precision_plugin),
process_group_backend=strategy.process_group_backend,
timeout=strategy._timeout,
start_method=strategy._start_method,
**strategy._ddp_kwargs,
)
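    # ``TPUSpawnStrategy`` has no same-named counterpart in ``lightning_lite``; it maps to ``XLAStrategy``.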
if strategy_cls is PLTPUSpawnStrategy:
return XLAStrategy(
accelerator=strategy.accelerator,
parallel_devices=strategy.parallel_devices,
checkpoint_io=strategy.checkpoint_io,
precision=_to_lite_precision(strategy.precision_plugin),
)
if strategy_cls is PLDeepSpeedStrategy:
return LiteDeepSpeedStrategy(
accelerator=strategy.accelerator,
parallel_devices=strategy.parallel_devices,
cluster_environment=strategy.cluster_environment,
precision=_to_lite_precision(strategy.precision_plugin),
process_group_backend=strategy.process_group_backend,
config=strategy.config,
remote_device=strategy.remote_device,
load_full_weights=strategy.load_full_weights,
loss_scale=strategy.loss_scale,
initial_scale_power=strategy.initial_scale_power,
loss_scale_window=strategy.loss_scale_window,
hysteresis=strategy.hysteresis,
min_loss_scale=strategy.min_loss_scale,
)
if strategy_cls is PLDataParallelStrategy:
return LiteDataParallelStrategy(
accelerator=strategy.accelerator,
parallel_devices=strategy.parallel_devices,
checkpoint_io=strategy.checkpoint_io,
precision=_to_lite_precision(strategy.precision_plugin),
)
if strategy_cls is PLDDPShardedStrategy:
return LiteDDPShardedStrategy(
accelerator=strategy.accelerator,
parallel_devices=strategy.parallel_devices,
cluster_environment=strategy.cluster_environment,
checkpoint_io=strategy.checkpoint_io,
precision=_to_lite_precision(strategy.precision_plugin),
process_group_backend=strategy.process_group_backend,
timeout=strategy._timeout,
**strategy._ddp_kwargs,
)
if strategy_cls is PLDDPSpawnShardedStrategy:
return LiteDDPSpawnShardedStrategy(
accelerator=strategy.accelerator,
parallel_devices=strategy.parallel_devices,
cluster_environment=strategy.cluster_environment,
checkpoint_io=strategy.checkpoint_io,
precision=_to_lite_precision(strategy.precision_plugin),
process_group_backend=strategy.process_group_backend,
timeout=strategy._timeout,
start_method=strategy._start_method,
**strategy._ddp_kwargs,
)
if strategy_cls is PLSingleDeviceStrategy:
return LiteSingleDeviceStrategy(
device=strategy.root_device,
accelerator=strategy.accelerator,
checkpoint_io=strategy.checkpoint_io,
precision=_to_lite_precision(strategy.precision_plugin),
)
if strategy_cls is PLSingleTPUStrategy:
return LiteSingleTPUStrategy(
device=strategy.root_device.index,
accelerator=strategy.accelerator,
checkpoint_io=strategy.checkpoint_io,
precision=_to_lite_precision(strategy.precision_plugin),
)
raise NotImplementedError(f"Unsupported strategy: `{strategy_cls.__name__}`")
def _to_lite_precision(plugin: Optional[PLPrecisionPlugin]) -> LitePrecision:
"""Re-instantiates a PL-PrecisionPlugin as the corresponding Lite-Precision plugin."""
if type(plugin) is PLPrecisionPlugin:
return LitePrecision()
if type(plugin) is PLNativeMixedPrecisionPlugin:
return LiteNativeMixedPrecision(
precision=plugin.precision, device=plugin.device, scaler=plugin.scaler # type: ignore[arg-type]
)
if type(plugin) is PLDoublePrecisionPlugin:
return LiteDoublePrecision()
if type(plugin) is PLDeepSpeedPrecisionPlugin:
return LiteDeepSpeedPrecision(
precision=plugin.precision, amp_type=plugin.amp_type, amp_level=plugin.amp_level # type: ignore[arg-type]
)
if type(plugin) is PLTPUPrecisionPlugin:
return LiteTPUPrecision()
if type(plugin) is PLTPUBf16PrecisionPlugin:
return LiteTPUBf16Precision()
# No backward compatibility for custom plugins / subclasses, as we can't re-instantiate these plugins
raise TypeError(
"You passed an unsupported plugin as input to Lite(plugins=...) or to a strategy. If you built a custom plugin,"
" please change it to subclass the `lightning_lite.plugins.precision.Precision` class. Otherwise, please open"
" an issue on the Lightning GitHub repository with your use case."
)