Source code for pytorch_lightning.strategies.colossalai
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Mapping, Optional, TYPE_CHECKING, Union
import torch
from lightning_utilities.core.imports import RequirementCache
from lightning_utilities.core.rank_zero import rank_zero_warn
from torch import Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer
from typing_extensions import OrderedDict
import pytorch_lightning as pl
from lightning_lite.accelerators.cuda import _patch_cuda_is_available
from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment
from lightning_lite.utilities.distributed import ReduceOp
from pytorch_lightning.accelerators.cuda import CUDAAccelerator
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import ColossalAIPrecisionPlugin
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import STEP_OUTPUT
_COLOSSALAI_AVAILABLE = RequirementCache("colossalai")
if TYPE_CHECKING and _COLOSSALAI_AVAILABLE:
with _patch_cuda_is_available():
from colossalai.utils.model.colo_init_context import ColoInitContext
else:
ColoInitContext = Any
[docs]class ColossalAIStrategy(DDPStrategy):
"""ColossalAI strategy. It only supports a single optimizer, which must be
:class:`colossalai.nn.optimizer.CPUAdam` or :class:`colossalai.nn.optimizer.HybridAdam` now. Your model must
be created in the function ``LightningModule.configure_sharded_model()``. Thus, you should overwrite this function.
More details can be found in the below example.
It configures accelerator and precision, and you should not configure them when initializing ``Trainer``.
CUDA is essential for this strategy. Please make sure CUDA is available.
Example::
class GLUETransformer(LightningModule):
...
def configure_sharded_model(self) -> None:
self.model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
trainer = Trainer(..., accelerator="gpu", precision=16, strategy="colossalai")
Args:
use_chunk: Whether to use chunk-based memory management.
It can speed up training, but slightly more memory will be used.
chunk_size: The size of a chunk.
It will be ignored when ``use_chunk=False``.
If it's None, a best chunk size will be searched out based on ``chunk_search_range``,
``chunk_search_n_grids`` and ``min_chunk_size``.
enable_distributed_storage: Whether to storage model in a distributed manner.
It reduces memory from 1 to 1/N, but it may slow down training.
placement_policy: It can be "cpu", "cuda" and "auto".
* If it's "cpu", parameters, gradients and optimizer states will be offloaded to CPU,
which means min CUDA memory will be used.
* If it's "cuda", they won't be offloaded, which means max CUDA memory will be used. It's the fastest.
* If it's "auto", they are moving dynamically based on CPU and CUDA memory usage.
It will utilize heterogeneous memory space evenly and well.
Note that "auto" policy can only work well when no other processes use CUDA during your training.
force_outputs_fp32: Whether to cast outputs to fp32.
gpu_margin_mem_ratio: The ratio of GPU remaining memory (after the first forward-backward)
which will be used by optimizer.
This argument will be ignored when ``placement_policy`` is not "auto".
chunk_search_range: The range of chunk size to search.
The actual search range will be from
``max(min_chunk_size, max_param_size)`` to ``max(min_chunk_size, max_param_size) + chunk_search_range``.
chunk_search_n_grids: The number of intervals in the search range.
min_chunk_size: The minimum size for a chunk.
initial_scale: The initial dynamic loss scale value.
min_scale: The minimum dynamic loss scaling value.
growth_factor: The multiplication factor for increasing loss scale.
backoff_factor: The multiplication factor for decreasing loss scale.
growth_interval: The number of steps to increase loss scale when no overflow occurs.
hysteresis: The number of overflows before decreasing loss scale.
max_scale: The maximum dynamic loss scaling value.
.. _colossalai.nn.optimizer.CPUAdam:
https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.nn.optimizer.cpu_adam.html
.. _colossalai.nn.optimizer.HybridAdam:
https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.nn.optimizer.hybrid_adam.html
"""
strategy_name = "colossalai"
def __init__(
self,
use_chunk: bool = True,
chunk_size: Optional[int] = None,
enable_distributed_storage: bool = True,
placement_policy: str = "auto",
force_outputs_fp32: bool = False,
gpu_margin_mem_ratio: float = 0.0,
chunk_search_range: int = 64 * 1024**2,
chunk_search_n_grids: int = 1024,
min_chunk_size: Optional[int] = None,
initial_scale: float = 2**16,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
accelerator: Optional["pl.accelerators.Accelerator"] = None,
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[ColossalAIPrecisionPlugin] = None,
) -> None:
if not _COLOSSALAI_AVAILABLE:
raise MisconfigurationException(
"To use the `ColossalAIStrategy`, please install `colossalai` first. "
"Download `colossalai` by consulting `https://colossalai.org/download`."
)
with _patch_cuda_is_available():
from colossalai.logging import get_dist_logger
super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
checkpoint_io=checkpoint_io,
precision_plugin=precision_plugin,
)
self.use_chunk = use_chunk
self.chunk_size = chunk_size
self.enable_distributed_storage = enable_distributed_storage
self.placement_policy = placement_policy
self.force_outputs_fp32 = force_outputs_fp32
self.gpu_margin_mem_ratio = gpu_margin_mem_ratio
self.chunk_size_search_kwargs = {
"search_range": chunk_search_range,
"n_grids": chunk_search_n_grids,
"min_chunk_size": min_chunk_size,
}
self.amp_kwargs = {
"initial_scale": initial_scale,
"min_scale": min_scale,
"growth_factor": growth_factor,
"backoff_factor": backoff_factor,
"growth_interval": growth_interval,
"hysteresis": hysteresis,
"max_scale": max_scale,
}
self._num_nodes = 1
self._logger = get_dist_logger()
@property
def root_device(self) -> torch.device:
with _patch_cuda_is_available():
from colossalai.utils import get_current_device
if self.parallel_devices is not None:
return self.parallel_devices[self.local_rank]
return get_current_device()
@property
def handles_gradient_accumulation(self) -> bool:
"""Whether the plugin handles gradient accumulation internally."""
return True
@property
def restore_checkpoint_after_setup(self) -> bool:
"""Override to delay restoring from checkpoint till after pre-dispatch."""
return True
def setup_distributed(self) -> None:
with _patch_cuda_is_available():
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers
assert self.cluster_environment is not None
self.set_world_ranks()
if not gpc.is_initialized(ParallelMode.GLOBAL):
disable_existing_loggers()
gpc.init_global_dist(
rank=self.global_rank,
world_size=self.world_size,
backend="nccl",
host=self.cluster_environment.main_address,
port=self.cluster_environment.main_port,
)
gpc.set_device(self.local_rank)
[docs] def model_sharded_context(self) -> "ColoInitContext":
"""Provide hook to create modules in a distributed aware context. This is useful for when we'd like to
shard the model instantly, which is useful for extremely large models which can save memory and
initialization time.
Returns: Model parallel context.
"""
with _patch_cuda_is_available():
from colossalai.utils.model.colo_init_context import ColoInitContext
class ModelShardedContext(ColoInitContext):
def _post_init_method(self, module: torch.nn.Module, *args: Any, **kwargs: Any) -> None:
if getattr(module, "_colossalai_module", False) is True:
return
super()._post_init_method(module, *args, **kwargs)
module._colossalai_module = True # type: ignore[assignment]
return ModelShardedContext()
[docs] def setup_precision_plugin(self) -> None:
with _patch_cuda_is_available():
from colossalai.gemini import ChunkManager, GeminiManager
from colossalai.nn.optimizer import CPUAdam, HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ProcessGroup
from colossalai.zero import ZeroOptimizer
super().setup_precision_plugin()
assert self.lightning_module is not None
is_training = self.lightning_module.trainer and self.lightning_module.trainer.training
if is_training:
if len(self.optimizers) > 1:
raise ValueError("`ColossalAIStrategy` only supports single Optimizer now.")
optimizer = self.optimizers[0]
if not isinstance(optimizer, (CPUAdam, HybridAdam)):
raise ValueError(
"`ColossalAIStrategy` only supports `colossalai.nn.optimizer.CPUAdam` "
"and `colossalai.nn.optimizer.HybridAdam` as its optimizer."
)
assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
pl_module = self.model
process_group = ProcessGroup()
if not hasattr(pl_module, "_colossalai_zero"):
if self.use_chunk:
chunk_size = self.chunk_size or ChunkManager.search_chunk_size(
self.model, **self.chunk_size_search_kwargs
)
else:
chunk_size = None
chunk_manager = ChunkManager(
chunk_size,
process_group,
self.enable_distributed_storage,
GeminiManager.get_default_device(self.placement_policy),
)
gemini_manager = GeminiManager(self.placement_policy, chunk_manager)
model = _LightningModuleWrapperBase(self.model)
self.model = ZeroDDP(model, gemini_manager, self.force_outputs_fp32)
assert self.model is not None
pl_module._colossalai_zero = [self.model] # type: ignore[assignment]
else:
self.model = pl_module._colossalai_zero[0] # type: ignore[index, assignment]
if is_training:
self.optimizers = [
ZeroOptimizer(optimizer, self.model, gpu_margin_mem_ratio=self.gpu_margin_mem_ratio, **self.amp_kwargs)
]
[docs] def setup(self, trainer: "pl.Trainer") -> None:
precision = self.precision_plugin.precision
if not (precision == PrecisionType.HALF):
raise ValueError(
f"`Trainer(strategy='colossalai', precision={precision!r})` is not supported."
" Consider setting `precision=16`."
)
if not isinstance(self.accelerator, CUDAAccelerator):
raise ValueError(
"`ColossalAIStrategy` is only supported on `CUDAAccelerator`, "
f"but `{self.accelerator.__class__.__name__}` is used."
)
if trainer.state.fn == TrainerFn.FITTING:
if is_overridden("backward", trainer.lightning_module):
rank_zero_warn(
"You have overridden the `LightningModule.backward` hook"
" but it will be ignored since ColossalAI handles"
" the backward logic internally."
)
if trainer.accumulate_grad_batches > 1:
raise ValueError(
"ColossalAI does not support gradient accumulation now. Please set `accumulate_grad_batches` to 1."
)
accumulation_scheduler = trainer.accumulation_scheduler
if accumulation_scheduler.epochs != [0]:
raise ValueError(
"ColossalAI currently does not support different `accumulate_grad_batches` at different epochs."
)
if not isinstance(self.precision_plugin, ColossalAIPrecisionPlugin):
raise ValueError("`ColossalAIStrategy` is only compatible with `ColossalAIPrecisionPlugin`.")
self.accelerator.setup(trainer)
assert self.lightning_module is not None
self.lightning_module._device = self.root_device
self.setup_optimizers(trainer)
self.setup_precision_plugin()
self.model_to_device()
[docs] def model_to_device(self) -> None:
assert self.lightning_module is not None
pl_module = self.lightning_module
for child in pl_module.modules():
if child is not pl_module and not getattr(child, "_colossalai_module", False):
child.to(self.root_device)
[docs] def teardown(self) -> None:
optimizers = self.optimizers
self.optimizers = list()
zero_model = self.model
self.model = None
pl_module = self._lightning_module
self._lightning_module = None
super().teardown()
self.optimizers = optimizers
self.model = zero_model
self._lightning_module = pl_module
[docs] def optimizer_step(
self,
optimizer: Optimizer,
opt_idx: int,
closure: Callable[[], Any],
model: Optional[Union["pl.LightningModule", Module]] = None,
**kwargs: Any,
) -> Any:
model = model or self.lightning_module
# TODO(lite): remove assertion once strategy's optimizer_step typing is fixed
assert isinstance(model, pl.LightningModule)
return self.precision_plugin.optimizer_step(
optimizer, model=model, optimizer_idx=opt_idx, closure=closure, **kwargs
)
[docs] def lightning_module_state_dict(self, rank_zero_only: bool = False) -> Dict[str, Any]:
"""Returns a dictionary containing a whole state of the module. But all the tensors in the dictionary are
detached from their parameters and located in cpu memory.
Args:
rank_zero_only: If True, only process rank 0 gets the correct dictionary.
Otherwise, all processes get the same dictionary.
"""
with _patch_cuda_is_available():
from colossalai.nn.parallel import ZeroDDP
assert isinstance(self.model, ZeroDDP)
org_dict = self.model.state_dict(only_rank_0=rank_zero_only)
children = list(self.model.named_children())
assert len(children) == 1
prefix, child = children[0]
prefix += "."
assert child is self.lightning_module
mapping_dict = dict()
for key in org_dict.keys():
mapping_dict[key] = key.replace(prefix, "") # remove "_forward_module." from the key
return {mapping_dict[key]: value for key, value in org_dict.items()}
def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
orig_dict = checkpoint["state_dict"]
assert self.model is not None
children = list(self.model.named_children())
assert len(children) == 1
prefix, child = children[0]
prefix += "."
assert child is self.lightning_module
mapping_dict = dict()
for key in orig_dict.keys():
mapping_dict[key] = prefix + key # add "_forward_module." to the key
load_dict = OrderedDict({mapping_dict[key]: value for key, value in orig_dict.items()})
self.model.load_state_dict(load_dict)
[docs] def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
assert self.model is not None
with self.precision_plugin.val_step_context():
return self.model(*args, **kwargs)
[docs] def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
assert self.model is not None
with self.precision_plugin.test_step_context():
return self.model(*args, **kwargs)
[docs] def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
assert self.model is not None
with self.precision_plugin.predict_step_context():
return self.model(*args, **kwargs)
@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register("colossalai", cls, description="Default ColossalAI Strategy")
[docs] def reduce(
self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "sum"
) -> Tensor:
with _patch_cuda_is_available():
from colossalai.communication.collective import reduce
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
if not isinstance(tensor, Tensor):
return tensor
if isinstance(reduce_op, str):
if reduce_op.lower() in ("avg", "mean"):
reduce_op = ReduceOp.SUM
div_factor = gpc.get_world_size(parallel_mode=ParallelMode.GLOBAL)
with torch.no_grad():
tensor = tensor / div_factor
else:
reduce_op = getattr(ReduceOp, reduce_op.upper())
tensor = reduce(tensor, dst=0, parallel_mode=ParallelMode.GLOBAL, op=reduce_op)
return tensor
[docs] def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
"""Broadcasts an object to all processes.
Args:
obj: the object to broadcast
src: source rank
"""
with _patch_cuda_is_available():
from colossalai.communication.collective import broadcast
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
if isinstance(obj, Tensor):
return broadcast(obj, src=src, parallel_mode=ParallelMode.GLOBAL)
else:
obj_list = [obj]
torch.distributed.broadcast_object_list(obj_list, src, group=gpc.get_group(ParallelMode.GLOBAL))
return obj_list[0]
[docs] def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
"""Perform a all_gather on all processes."""
with _patch_cuda_is_available():
from colossalai.communication.collective import all_gather
from colossalai.context import ParallelMode
assert sync_grads is False
return all_gather(tensor, dim=0, parallel_mode=ParallelMode.GLOBAL)