# Source code for lightning.fabric.plugins.environments.mpi
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import socket
from functools import lru_cache
from typing import Optional

import numpy as np
from lightning_utilities.core.imports import RequirementCache

from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
from lightning.fabric.plugins.environments.lightning import find_free_network_port

# Module-level logger named after this module.
log = logging.getLogger(__name__)

# Truthy when the `mpi4py` package requirement is satisfied; ``str()`` of it yields a
# human-readable description that is used as the ModuleNotFoundError message.
_MPI4PY_AVAILABLE = RequirementCache("mpi4py")
class MPIEnvironment(ClusterEnvironment):
    """An environment for running on clusters with processes created through MPI.

    Requires the installation of the `mpi4py` package. See also: https://github.com/mpi4py/mpi4py

    """

    def __init__(self) -> None:
        """Bind to ``MPI.COMM_WORLD``; all other topology information is computed lazily.

        Raises:
            ModuleNotFoundError: If the `mpi4py` package is not available.

        """
        if not _MPI4PY_AVAILABLE:
            raise ModuleNotFoundError(str(_MPI4PY_AVAILABLE))

        # Imported here so that merely importing this module does not require mpi4py.
        from mpi4py import MPI

        self._comm_world = MPI.COMM_WORLD
        # Per-node communicator and node rank, filled in by `_init_comm_local`.
        self._comm_local: Optional[MPI.Comm] = None
        self._node_rank: Optional[int] = None
        # Cached results of the collective lookups behind `main_address` / `main_port`.
        self._main_address: Optional[str] = None
        self._main_port: Optional[int] = None

    @property
    def creates_processes_externally(self) -> bool:
        """Always ``True``: the processes are created through MPI, not by this environment."""
        return True

    @property
    def main_address(self) -> str:
        """Hostname of rank 0, computed on first access and cached afterwards.

        The first access performs an MPI broadcast, so it must happen on every rank.

        """
        if self._main_address is None:
            self._main_address = self._get_main_address()
        return self._main_address

    @property
    def main_port(self) -> int:
        """A free port chosen by rank 0, computed on first access and cached afterwards.

        The first access performs an MPI broadcast, so it must happen on every rank.

        """
        if self._main_port is None:
            self._main_port = self._get_main_port()
        return self._main_port

    @staticmethod
    def detect() -> bool:
        """Returns ``True`` if the `mpi4py` package is installed and MPI returns a world size greater than 1."""
        if not _MPI4PY_AVAILABLE:
            return False

        from mpi4py import MPI

        return MPI.COMM_WORLD.Get_size() > 1

    def _get_main_address(self) -> str:
        """Broadcast rank 0's hostname to every rank and return it.

        Only rank 0's value survives the broadcast, so the other ranks contribute ``None``.

        """
        hostname = socket.gethostname() if self._comm_world.Get_rank() == 0 else None
        return self._comm_world.bcast(hostname, root=0)

    def _get_main_port(self) -> int:
        """Broadcast a free port chosen on rank 0 to every rank and return it.

        Only rank 0 looks up a free port. The other ranks contribute ``None`` and receive rank 0's
        value, so the lookup runs once instead of on every rank with all but one result discarded.

        """
        port = find_free_network_port() if self._comm_world.Get_rank() == 0 else None
        return self._comm_world.bcast(port, root=0)

    def _init_comm_local(self) -> None:
        """Compute this process's node rank and create the per-node communicator.

        Every rank sends its hostname to rank 0, which sorts the distinct names. The sorted list is
        broadcast so that each rank can take its node rank as the index of its own hostname.
        ``COMM_WORLD`` is then split using that node rank as the color, so processes reporting the
        same hostname end up in the same ``_comm_local`` communicator.

        NOTE(review): this assumes the hostname uniquely identifies a physical node; confirm for
        clusters where that may not hold.

        """
        hostname = socket.gethostname()
        all_hostnames = self._comm_world.gather(hostname, root=0)
        # `gather` only delivers the list on the root; the other ranks get None and wait for the bcast.
        # `sorted(set(...))` gives the same lexicographic order that `np.unique` produced for strings.
        unique_hosts = sorted(set(all_hostnames)) if all_hostnames is not None else None
        unique_hosts = self._comm_world.bcast(unique_hosts, root=0)
        # A plain list lookup replaces `int(np.where(...)[0])`. That form converted a 1-element array
        # to a Python scalar, which NumPy deprecated in 1.25.
        self._node_rank = unique_hosts.index(hostname)
        self._comm_local = self._comm_world.Split(color=self._node_rank)
# To analyze traffic and optimize your experience, we serve cookies on this site.
# By clicking or navigating, you agree to allow our usage of cookies.
# Read PyTorch Lightning's Privacy Policy.