# Imports reconstructed for this excerpt (the module header is not shown on this page);
# the installed module may organise these slightly differently.
import ipaddress
import logging
import os
import platform
from typing import Any, Callable, Dict, List, Optional, Union

import torch

# torch >= 2.0 exposes this publicly as `torch.optim.lr_scheduler.LRScheduler`
from torch.optim.lr_scheduler import ReduceLROnPlateau, _LRScheduler as LRScheduler

from pytorch_lightning.strategies.strategy import Strategy
from pytorch_lightning.utilities.data import extract_batch_size
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_warn

# the real module computes this flag via Lightning's import utilities; a plain
# try/import is used here so the excerpt stays self-contained
try:
    import hivemind

    _HIVEMIND_AVAILABLE = True
except ImportError:
    _HIVEMIND_AVAILABLE = False

log = logging.getLogger(__name__)


class HivemindStrategy(Strategy):
    INITIAL_PEERS_ENV: str = "PL_INITIAL_PEERS"

    def __init__(
        self,
        target_batch_size: int,
        run_id: str = "lightning_run",
        batch_size: Optional[int] = None,
        delay_state_averaging: bool = False,
        delay_optimizer_step: Optional[bool] = None,
        delay_grad_averaging: bool = False,
        offload_optimizer: Optional[bool] = None,
        reuse_grad_buffers: bool = False,
        scheduler_fn: Optional[Callable] = None,
        matchmaking_time: float = 5.0,
        averaging_timeout: float = 30.0,
        verbose: bool = False,
        averager_opts: Optional[Dict] = None,
        host_maddrs: Optional[List] = None,
        initial_peers: Optional[Union[str, List]] = None,
        **optimizer_kwargs: Any,
    ):
        """Provides capabilities to train using the Hivemind Library, training collaboratively across the internet
        with unreliable machines. For more information, `refer to the docs
        <https://pytorch-lightning.readthedocs.io/en/latest/strategies/hivemind.html>`__.

        .. warning:: ``HivemindStrategy`` is experimental and subject to change.

        Arguments:
            target_batch_size: When training, the batch size to accumulate to before running a step. The larger this
                batch size, the more work can be done asynchronously without communication.

            run_id: A unique identifier of this training run, used as a common prefix for all DHT keys.
                See ``https://learning-at-home.readthedocs.io/en/latest/user/dht.html``.

            batch_size: The local batch size per process. If not provided, we infer this from the first batch of data
                passed in at training (lazy). Note that this should not change throughout training.

            delay_state_averaging: If enabled, average parameters and extra tensors in a background thread;
                if set to False, average parameters synchronously within the corresponding
                :meth:`hivemind.Optimizer.step` call.

            delay_optimizer_step: Run optimizer in background, apply results in future ``.step``. Requires
                :paramref:`~pytorch_lightning.strategies.hivemind.HivemindStrategy.offload_optimizer`.

            delay_grad_averaging: Average gradients in background; requires
                :paramref:`~pytorch_lightning.strategies.hivemind.HivemindStrategy.offload_optimizer` and
                :paramref:`~pytorch_lightning.strategies.hivemind.HivemindStrategy.delay_optimizer_step`.

            offload_optimizer: Offload the optimizer to host memory, saving GPU memory for parameters and gradients.

            reuse_grad_buffers: Use the model's gradient buffers (``params.grad``) for gradient accumulation, which is
                more memory efficient. Lightning will automatically disable ``zero_grad``
                in the ``LightningModule``.

            scheduler_fn: ``callable(optimizer) -> PyTorch LRScheduler``, or a pre-initialized PyTorch scheduler.
                When using ``offload_optimizer``/``delay_optimizer_step``/``delay_state_averaging``, ``scheduler_fn``
                is required to be passed to the ``HivemindStrategy``. This is because the optimizer
                is re-created and the scheduler needs to be re-created as well.

            matchmaking_time: When looking for a group, wait for peers to join for up to this many seconds.
                Increase if you see "averaged gradients with N peers" where N is below 0.9x on >=25% of epochs.
                When training over a low-latency network, decreasing ``matchmaking_time`` allows training
                with smaller batch sizes.

            averaging_timeout: If an averaging step hangs for this long, it will be cancelled automatically.
                Increase ``averaging_timeout`` if you see "Proceeding with local gradients" at least 25% of the time.
                Do not set this timeout too high, as it may cause your optimizer to hang
                after some types of network errors.

            verbose: Report internal Hivemind events, such as accumulating gradients and running background tasks.
            averager_opts: Additional keyword arguments forwarded to both ``GradientAverager`` and
                ``TrainingStateAverager``.

            host_maddrs: List of multi-addrs to create visible peers for other processes.
                `https://learning-at-home.readthedocs.io/en/latest/user/dht.html#running-across-the-internet`

            initial_peers: If connecting to a running process, a list of initial peers needs to be passed in.
                This can also be set via the env variable ``PL_INITIAL_PEERS``.

            **optimizer_kwargs: kwargs are passed to the :class:`hivemind.Optimizer` class.
        """
        if not _HIVEMIND_AVAILABLE or platform.system() != "Linux":
            raise MisconfigurationException(
                "To use the `HivemindStrategy`, you must have Hivemind installed and be running on Linux."
                " Install it by running `pip install -U hivemind`."
            )

        super().__init__()
        self._initial_peers = initial_peers
        self._target_batch_size = target_batch_size
        self._batch_size = batch_size
        self._scheduler_fn = scheduler_fn
        self._require_scheduler_fn = delay_optimizer_step or delay_state_averaging or offload_optimizer
        self._opt = None
        self._optimizer_zero_grad_original: Optional[Callable] = None
        self._run_id = run_id
        self._reuse_grad_buffers = reuse_grad_buffers
        self._optimizer_kwargs = dict(
            matchmaking_time=matchmaking_time,
            averaging_timeout=averaging_timeout,
            delay_optimizer_step=delay_optimizer_step,
            delay_state_averaging=delay_state_averaging,
            delay_grad_averaging=delay_grad_averaging,
            offload_optimizer=offload_optimizer,
            # fixed: guard on `averager_opts` itself; `averaging_timeout` always has a value here
            averager_opts=averager_opts if averager_opts is not None else dict(request_timeout=1.0),
            verbose=verbose,
            reuse_grad_buffers=reuse_grad_buffers,
            **optimizer_kwargs,
        )

        self._parse_env_initial_peers()

        self.dht = hivemind.DHT(
            start=True,
            # fixed: use the peers parsed from the env variable/constructor, not the raw argument
            initial_peers=self._initial_peers,
            # default: listen on all interfaces with OS-assigned ports, over both TCP and QUIC
            host_maddrs=host_maddrs if host_maddrs is not None else ["/ip4/0.0.0.0/tcp/0", "/ip4/0.0.0.0/udp/0/quic"],
        )

        visible_addresses = [
            str(a) for a in self.dht.get_visible_maddrs() if not ipaddress.ip_address(a.values()[0]).is_loopback
        ]

        if self._initial_peers is None:
            log.info(
                "\nOther machines can connect running the same command:\n"
                f"{self.INITIAL_PEERS_ENV}={','.join(visible_addresses)} python ...\n"
                "or passing the peers to the strategy:\n"
                f"HivemindStrategy(initial_peers='{','.join(visible_addresses)}')"
            )

        self._hivemind_initialized = False

    def _parse_env_initial_peers(self) -> None:
        initial_peers = os.environ.get(self.INITIAL_PEERS_ENV, self._initial_peers)
        self._initial_peers = initial_peers.split(",") if isinstance(initial_peers, str) else self._initial_peers

    @property
    def num_peers(self) -> int:
        if self._opt:
            return self._opt.tracker.global_progress.num_peers
        return 1

    @property
    def root_device(self) -> torch.device:
        from pytorch_lightning.accelerators.cpu import CPUAccelerator
        from pytorch_lightning.accelerators.cuda import CUDAAccelerator

        if isinstance(self.accelerator, CUDAAccelerator):
            return torch.device(f"cuda:{torch.cuda.current_device()}")
        elif isinstance(self.accelerator, CPUAccelerator):
            return torch.device("cpu")
        raise MisconfigurationException(
            f"Was unable to infer device type from the accelerator: {self.accelerator.__class__.__name__}."
        )

    @property
    def global_rank(self) -> int:
        return 0

    @property
    def is_global_zero(self) -> bool:
        return True
    def _initialize_hivemind(self) -> None:
        if len(self.optimizers) > 1:
            raise MisconfigurationException("Hivemind only supports training with one optimizer.")
        optimizer = self.optimizers[0]

        if self._require_scheduler_fn and self._scheduler_fn is None:
            rank_zero_warn(
                "Enabling `delay_optimizer_step`, `delay_state_averaging` or `offload_optimizer` "
                "requires a `scheduler_fn` to be passed to the strategy if a scheduler is being used "
                "(this is because the optimizer is re-created within Hivemind)."
            )

        # when hivemind re-creates the optimizer internally, pass its class and param groups
        # instead of the built instance, along with the scheduler factory
        scheduler = self._scheduler_fn if self._require_scheduler_fn else None
        params = optimizer.param_groups if self._require_scheduler_fn else None
        optimizer = type(optimizer) if self._require_scheduler_fn else optimizer
        opt = hivemind.Optimizer(
            dht=self.dht,
            run_id=self._run_id,
            params=params,
            optimizer=optimizer,
            scheduler=scheduler,
            target_batch_size=self._target_batch_size,
            batch_size_per_step=self._batch_size,
            **self._optimizer_kwargs,
        )

        if not self._scheduler_fn:
            self._wrap_schedulers(opt)
        opt.load_state_from_peers()
        self.optimizers = [opt]
        self._opt = opt

        if self._reuse_grad_buffers:
            assert self.lightning_module is not None
            self._optimizer_zero_grad_original = self.lightning_module.optimizer_zero_grad
            self._disable_zero_grad()

    def _disable_zero_grad(self) -> None:
        lightning_module = self.lightning_module
        if is_overridden("optimizer_zero_grad", lightning_module):
            assert lightning_module is not None  # `is_overridden` returns False otherwise
            rank_zero_warn(
                "You have overridden `optimizer_zero_grad` which will be disabled."
                " When `HivemindStrategy(reuse_grad_buffers=True)`, the optimizer cannot call zero grad,"
                " as this would delete the gradients before they are averaged."
            )
        assert lightning_module is not None
        lightning_module.optimizer_zero_grad = None  # type: ignore[assignment]

    def _wrap_schedulers(self, opt: "hivemind.Optimizer") -> None:
        # wrap schedulers so that they only update when the hivemind optimizer updates
        for scheduler_config in self.lr_scheduler_configs:
            scheduler = scheduler_config.scheduler
            if isinstance(scheduler, ReduceLROnPlateau):
                raise ValueError(
                    f"The `ReduceLROnPlateau` scheduler is not currently supported with `{self.__class__.__name__}`."
                )
            scheduler_config.scheduler = HiveMindScheduler(
                optimizer=opt,
                scheduler=scheduler,
            )
    def on_train_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        if not self._hivemind_initialized:
            self._hivemind_initialized = True

            # todo (sean): we could technically support a dynamic batch size by inferring each step
            # and passing it to the ``hivemind.Optimizer``.
            if self._batch_size is None:
                try:
                    self._batch_size = extract_batch_size(batch)
                    log.info(f"Found per machine batch size automatically from the batch: {self._batch_size}")
                except (MisconfigurationException, RecursionError) as e:
                    raise MisconfigurationException(
                        "We tried to infer the batch size from the first batch of data. "
                        "Please provide the batch size to the Strategy by "
                        "``Trainer(strategy=HivemindStrategy(batch_size=x))``. "
                    ) from e
            self._initialize_hivemind()
    def teardown(self) -> None:
        if self._optimizer_zero_grad_original is not None and self.lightning_module is not None:
            # re-enable `optimizer_zero_grad`
            self.lightning_module.optimizer_zero_grad = (  # type: ignore[method-assign]
                self._optimizer_zero_grad_original
            )

        if self._opt:
            self._opt.shutdown()
        log.info("Shutting down hivemind DHT.")
        self.dht.shutdown()
        super().teardown()
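
# Usage sketch (illustrative, not part of the strategy implementation): launching a collaborative
# run, and passing a scheduler factory for the case where hivemind re-creates the optimizer.
# `BoringModel` is Lightning's demo module; `target_batch_size=8192` and `gamma=0.9` are
# arbitrary example values.
if __name__ == "__main__":
    from functools import partial

    import pytorch_lightning as pl
    from pytorch_lightning.demos.boring_classes import BoringModel
    from pytorch_lightning.strategies import HivemindStrategy

    # the first peer logs its visible multiaddrs; further machines join the run by exporting
    # the env variable read in `_parse_env_initial_peers` before launching the same script:
    #   PL_INITIAL_PEERS=<printed addresses> python train.py
    trainer = pl.Trainer(strategy=HivemindStrategy(target_batch_size=8192), accelerator="cpu")
    trainer.fit(BoringModel())

    # with `offload_optimizer`/`delay_optimizer_step`/`delay_state_averaging`, pass a factory so
    # the scheduler can be re-created alongside the re-created optimizer
    strategy = HivemindStrategy(
        target_batch_size=8192,
        offload_optimizer=True,
        scheduler_fn=partial(torch.optim.lr_scheduler.ExponentialLR, gamma=0.9),
    )
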
class HiveMindScheduler:
    """Wrapper for schedulers to prevent Lightning from stepping the scheduler too soon.

    This code ensures that we only step when the HiveMind optimizer reaches the global step.
    """

    base_lrs: List[float]

    def __init__(self, optimizer: "hivemind.Optimizer", scheduler: LRScheduler) -> None:
        # copy most of the `Scheduler` methods into this instance. `__del__` is skipped in case the scheduler has
        # implemented custom logic which we would not want to call on destruction of the `HiveMindScheduler`
        self.__dict__ = {k: v for k, v in scheduler.__dict__.items() if k not in ("step", "__del__")}

        self.optimizer = optimizer
        self.scheduler = scheduler
        self.current_step = -1

    def step(self, epoch: Optional[int] = None) -> None:
        while self.current_step < self.optimizer.local_epoch:
            self.scheduler.step(epoch=epoch)
            self.current_step += 1

    def load_state_dict(self, state_dict: Dict) -> None:
        self.scheduler.load_state_dict(state_dict)

    def state_dict(self) -> Dict:
        return self.scheduler.state_dict()
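
# Illustrative sketch (not part of the module): `HiveMindScheduler` defers wrapped-scheduler
# steps until the hivemind optimizer advances its `local_epoch`. A `SimpleNamespace` stands in
# for a real `hivemind.Optimizer` here, which exposes `local_epoch` the same way.
if __name__ == "__main__":
    from types import SimpleNamespace

    model = torch.nn.Linear(2, 2)
    inner_opt = torch.optim.SGD(model.parameters(), lr=0.1)
    sched = torch.optim.lr_scheduler.StepLR(inner_opt, step_size=1, gamma=0.5)

    fake_opt = SimpleNamespace(local_epoch=0)
    wrapper = HiveMindScheduler(optimizer=fake_opt, scheduler=sched)

    wrapper.step()  # local_epoch == 0: steps the wrapped scheduler once (current_step -1 -> 0)
    wrapper.step()  # local_epoch unchanged: no-op, so per-epoch calls from Lightning are safe
    fake_opt.local_epoch = 3
    wrapper.step()  # catches up: steps the wrapped scheduler three times (current_step 0 -> 3)
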