class pytorch_lightning.strategies.ColossalAIStrategy(use_chunk=True, chunk_size=None, enable_distributed_storage=True, placement_policy='auto', force_outputs_fp32=False, gpu_margin_mem_ratio=0.0, chunk_search_range=67108864, chunk_search_n_grids=4096, min_chunk_size=33554432, initial_scale=65536, min_scale=1, growth_factor=2, backoff_factor=0.5, growth_interval=1000, hysteresis=2, max_scale=4294967296, accelerator=None, parallel_devices=None, cluster_environment=None, checkpoint_io=None, precision_plugin=None)[source]

Bases: pytorch_lightning.strategies.ddp.DDPStrategy

ColossalAI strategy. It currently supports only a single optimizer, which must be colossalai.nn.optimizer.CPUAdam or colossalai.nn.optimizer.HybridAdam. Your model must be created in LightningModule.configure_sharded_model(), so you should override that function. See the example below for details.

This strategy configures the accelerator and precision itself, so you should not configure them when initializing the Trainer. It requires CUDA, so make sure CUDA is available.


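from pytorch_lightning import LightningModule, Trainer
from transformers import BertForSequenceClassification

class GLUETransformer(LightningModule):
    def configure_sharded_model(self) -> None:
        # Create the model inside this hook, as the strategy requires.
        self.model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

trainer = Trainer(..., accelerator="gpu", precision=16, strategy="colossalai")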
  • use_chunk (bool) – Whether to use chunk-based memory management. It can speed up training, but slightly more memory will be used.

  • chunk_size (Optional[int]) – The size of a chunk. It is ignored when use_chunk=False. If it’s None, a suitable chunk size is searched for based on chunk_search_range, chunk_search_n_grids and min_chunk_size.

  • enable_distributed_storage (bool) – Whether to store the model in a distributed manner. It reduces memory from 1 to 1/N, but it may slow down training.

  • placement_policy (str) –

    It can be “cpu”, “cuda” or “auto”.

    • If it’s “cpu”, parameters, gradients and optimizer states will be offloaded to CPU, which means the minimum amount of CUDA memory will be used.

    • If it’s “cuda”, they won’t be offloaded, which means the maximum amount of CUDA memory will be used. This is the fastest option.

    • If it’s “auto”, they are moved dynamically based on CPU and CUDA memory usage, so the heterogeneous memory space is used evenly. Note that the “auto” policy only works well when no other processes use CUDA during your training.

  • force_outputs_fp32 (bool) – Whether to cast outputs to fp32.

  • gpu_margin_mem_ratio (float) – The fraction of the GPU memory remaining after the first forward-backward pass that will be used by the optimizer. This argument is ignored when placement_policy is not “auto”.

  • chunk_search_range (int) – The range of chunk size to search. The actual search range will be from max(min_chunk_size, max_param_size) to max(min_chunk_size, max_param_size) + chunk_search_range.

  • chunk_search_n_grids (int) – The number of intervals in the search range.

  • min_chunk_size (int) – The minimum size for a chunk in bytes.

  • initial_scale (float) – The initial dynamic loss scale value.

  • min_scale (float) – The minimum dynamic loss scaling value.

  • growth_factor (float) – The multiplication factor for increasing loss scale.

  • backoff_factor (float) – The multiplication factor for decreasing loss scale.

  • growth_interval (int) – The number of consecutive steps without overflow after which the loss scale is increased.

  • hysteresis (int) – The number of overflows before decreasing loss scale.

  • max_scale (float) – The maximum dynamic loss scaling value.
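
The search range described for chunk_search_range, chunk_search_n_grids and min_chunk_size can be worked out by hand. The snippet below is a minimal sketch of that arithmetic only: the function name chunk_search_bounds is hypothetical, max_param_size is assumed to be expressed in the same unit as min_chunk_size, and how ColossalAI picks a size inside the range is not covered here.

```python
def chunk_search_bounds(
    max_param_size,
    min_chunk_size=33554432,
    chunk_search_range=67108864,
    chunk_search_n_grids=4096,
):
    """Return (start, end, interval_width) of the chunk size search range.

    Hypothetical helper for illustration; not part of the strategy's API.
    """
    # The range starts at the larger of the minimum chunk size and the
    # largest parameter size.
    start = max(min_chunk_size, max_param_size)
    # The range extends chunk_search_range beyond the start.
    end = start + chunk_search_range
    # chunk_search_n_grids is the number of intervals in the range.
    interval_width = chunk_search_range / chunk_search_n_grids
    return start, end, interval_width


if __name__ == "__main__":
    print(chunk_search_bounds(max_param_size=1_000_000))
    print(chunk_search_bounds(max_param_size=50_000_000))
```

With the default arguments, a model whose largest parameter is smaller than min_chunk_size gets a range that starts at min_chunk_size; a larger parameter shifts the whole range upward.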

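The loss scale parameters (initial_scale, min_scale, growth_factor, backoff_factor, growth_interval, hysteresis, max_scale) describe a dynamic loss scaling scheme. The following is a minimal sketch of one way such a scaler could behave, written only from the parameter descriptions above. The class name DynamicLossScaleSketch and its counters are hypothetical, and the exact bookkeeping inside ColossalAI (for example, when the overflow counter resets) is an assumption here.

```python
class DynamicLossScaleSketch:
    """Toy dynamic loss scaler; illustrative only, not ColossalAI's code."""

    def __init__(self, initial_scale=65536, min_scale=1, growth_factor=2,
                 backoff_factor=0.5, growth_interval=1000, hysteresis=2,
                 max_scale=4294967296):
        self.scale = float(initial_scale)
        self.min_scale = min_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.hysteresis = hysteresis
        self.max_scale = max_scale
        self._good_steps = 0   # consecutive steps without overflow
        self._overflows = 0    # overflows seen since the last backoff

    def update(self, overflow):
        """Update the scale after one step and return the new value."""
        if overflow:
            self._good_steps = 0
            self._overflows += 1
            # Assumption: back off once `hysteresis` overflows have accumulated.
            if self._overflows >= self.hysteresis:
                self.scale = max(self.scale * self.backoff_factor, self.min_scale)
                self._overflows = 0
        else:
            self._good_steps += 1
            # Grow after `growth_interval` consecutive overflow-free steps.
            if self._good_steps >= self.growth_interval:
                self.scale = min(self.scale * self.growth_factor, self.max_scale)
                self._good_steps = 0
        return self.scale


if __name__ == "__main__":
    scaler = DynamicLossScaleSketch(growth_interval=2, hysteresis=2)
    for overflow in (True, True, False, False):
        print(overflow, scaler.update(overflow))
```

The scale is always clamped to the interval between min_scale and max_scale, so repeated overflows or long overflow-free runs cannot push it outside those bounds.
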
all_gather(tensor, group=None, sync_grads=False)[source]

Performs an all_gather on all processes.

Return type:


broadcast(obj, src=0)[source]

Broadcasts an object to all processes.

  • obj (TypeVar(TBroadcast)) – the object to broadcast

  • src (int) – source rank

Return type:



Returns a dictionary containing the whole state of the module. All tensors in the dictionary are detached from their parameters and located in CPU memory.


  • rank_zero_only (bool) – If True, only process rank 0 gets the correct dictionary. Otherwise, all processes get the same dictionary.

Return type:

Dict[str, Any]


Provides a hook to create modules in a distributed-aware context. This is useful when the model should be sharded as soon as it is created, which saves memory and initialization time for extremely large models.

Returns: Model parallel context.

Return type:



Moves the model to the correct device.

Return type:


optimizer_step(optimizer, opt_idx, closure, model=None, **kwargs)[source]

Performs the actual optimizer step.

  • optimizer (Optimizer) – the optimizer performing the step

  • opt_idx (int) – index of the current optimizer

  • closure (Callable[[], Any]) – closure calculating the loss value

  • model (Union[LightningModule, Module, None]) – reference to the model, optionally defining optimizer step related hooks

  • **kwargs (Any) – Any extra arguments to optimizer.step

Return type:


predict_step(*args, **kwargs)[source]

The actual predict step.

See predict_step() for more details.

Return type:

Union[Tensor, Dict[str, Any]]

reduce(tensor, group=None, reduce_op='sum')[source]

Reduces a tensor from several distributed processes to one aggregated tensor.

  • tensor (Tensor) – the tensor to sync and reduce

  • group (Optional[Any]) – the process group to gather results from. Defaults to all processes (world)

  • reduce_op (Union[ReduceOp, str, None]) – the reduction operation. Accepts ‘mean’/‘avg’, or the string ‘sum’ to calculate the sum during reduction. The signature above sets the default to ‘sum’.

Return type:



Returns: the reduced value. If the input was not a tensor, it is returned unchanged.
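
The difference between the ‘sum’ and ‘mean’/‘avg’ reduction operations can be illustrated without any distributed setup. The snippet below is a single-process toy model, not the strategy's implementation: simulate_reduce is a hypothetical helper, plain lists of floats stand in for the tensors held by each process, and no communication takes place.

```python
def simulate_reduce(per_rank_values, reduce_op="sum"):
    """Combine one list of floats per simulated rank into a single list.

    Hypothetical helper for illustration; real reductions run across
    processes and operate on tensors.
    """
    op = reduce_op.lower()
    if op not in ("sum", "mean", "avg"):
        raise ValueError(f"unsupported reduce_op: {reduce_op!r}")

    # Element-wise sum over all simulated ranks.
    reduced = [sum(column) for column in zip(*per_rank_values)]

    # 'mean' and 'avg' divide the sum by the number of ranks.
    if op in ("mean", "avg"):
        world_size = len(per_rank_values)
        reduced = [value / world_size for value in reduced]
    return reduced


if __name__ == "__main__":
    ranks = [[1.0, 2.0], [3.0, 4.0]]  # two simulated processes
    print(simulate_reduce(ranks, "sum"))
    print(simulate_reduce(ranks, "mean"))
```

In this toy model every caller sees the same aggregated list, which mirrors the description of reducing values from several processes to one aggregated value.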


Sets up plugins for the trainer fit and creates optimizers.


  • trainer (Trainer) – the trainer instance

Return type:



Attaches the precision plugin to the accelerator.

Return type:



This method is called to tear down the training process.

It is the right place to release memory and free other resources.

Return type:


test_step(*args, **kwargs)[source]

The actual test step.

See test_step() for more details.

Return type:

Union[Tensor, Dict[str, Any], None]

validation_step(*args, **kwargs)[source]

The actual validation step.

See validation_step() for more details.

Return type:

Union[Tensor, Dict[str, Any], None]

property handles_gradient_accumulation: bool

Whether the plugin handles gradient accumulation internally.

property restore_checkpoint_after_setup: bool

Override to delay restoring from checkpoint till after pre-dispatch.

property root_device: torch.device

Return the root device.