Source code for pytorch_lightning.utilities.rank_zero
# Copyright The PyTorch Lightning team.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License."""Utilities that can be used for calling functions on a particular rank."""importloggingimportosimportwarningsfromfunctoolsimportpartial,wrapsfromplatformimportpython_versionfromtypingimportAny,Callable,Optional,Unionlog=logging.getLogger(__name__)
def rank_zero_only(fn: Callable) -> Callable:
    """Decorate ``fn`` so that it only runs on the process whose rank is 0.

    The rank is looked up on ``rank_zero_only.rank`` every time the decorated function is
    called, so the attribute may be assigned (or reassigned) after decoration. When the rank
    is not 0 the call is skipped entirely and ``None`` is returned.
    """

    @wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Optional[Any]:
        # Guard clause: every non-zero rank skips the call.
        if rank_zero_only.rank != 0:
            return None
        return fn(*args, **kwargs)

    return wrapper
# TODO: this should be part of the cluster environment
def _get_rank() -> int:
    """Return the process rank taken from the environment, or 0 if no known variable is set.

    The variables are checked in the order ``RANK``, ``LOCAL_RANK``, ``SLURM_PROCID``,
    ``JSM_NAMESPACE_RANK``; the first one that is set wins and is converted with ``int``
    (a non-integer value therefore raises ``ValueError``).
    """
    # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
    # therefore LOCAL_RANK needs to be checked first
    rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
    for key in rank_keys:
        rank = os.environ.get(key)
        if rank is not None:
            return int(rank)
    return 0


# add the attribute to the function but don't overwrite in case Trainer has already set it
rank_zero_only.rank = getattr(rank_zero_only, "rank", _get_rank())

# ``platform.python_version()`` returns a string such as "3.10.4". Comparing that string against
# "3.8.0" is lexicographic and therefore wrong from Python 3.10 on ("3.10.4" < "3.8.0"), which used
# to silently drop ``stacklevel``. Compare the numeric (major, minor) pair instead; the minor
# component is always plain digits (pre-release tags attach to the micro part).
_PYTHON_GREATER_EQUAL_3_8_0 = tuple(int(part) for part in python_version().split(".")[:2]) >= (3, 8)


def _info(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
    """Log at INFO level through the module logger.

    Args:
        *args: Message and format arguments for :meth:`logging.Logger.info`.
        stacklevel: Forwarded to the logger on Python >= 3.8 (where the parameter exists) so the
            record can be attributed to a frame above this helper; the default of 2 is the caller.
        **kwargs: Further keyword arguments for :meth:`logging.Logger.info`.
    """
    if _PYTHON_GREATER_EQUAL_3_8_0:
        kwargs["stacklevel"] = stacklevel
    log.info(*args, **kwargs)


def _debug(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
    """Log at DEBUG level through the module logger.

    Args:
        *args: Message and format arguments for :meth:`logging.Logger.debug`.
        stacklevel: Forwarded to the logger on Python >= 3.8 (where the parameter exists) so the
            record can be attributed to a frame above this helper; the default of 2 is the caller.
        **kwargs: Further keyword arguments for :meth:`logging.Logger.debug`.
    """
    if _PYTHON_GREATER_EQUAL_3_8_0:
        kwargs["stacklevel"] = stacklevel
    log.debug(*args, **kwargs)
@rank_zero_only
def rank_zero_debug(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None:
    """Log a DEBUG-level message, but only on the rank-0 process.

    Args:
        *args: Message and format arguments handed to ``_debug``.
        stacklevel: Forwarded to ``_debug``. The default of 4 counts ``_debug``, this function and
            the ``rank_zero_only`` wrapper, so the record is meant to point at the caller.
        **kwargs: Further keyword arguments handed to ``_debug``.
    """
    # ``stacklevel`` is a named parameter, so ``kwargs`` can never already contain that key.
    forwarded = dict(kwargs, stacklevel=stacklevel)
    _debug(*args, **forwarded)
@rank_zero_only
def rank_zero_info(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None:
    """Log an INFO-level message, but only on the rank-0 process.

    Args:
        *args: Message and format arguments handed to ``_info``.
        stacklevel: Forwarded to ``_info``. The default of 4 counts ``_info``, this function and
            the ``rank_zero_only`` wrapper, so the record is meant to point at the caller.
        **kwargs: Further keyword arguments handed to ``_info``.
    """
    # ``stacklevel`` is a named parameter, so ``kwargs`` can never already contain that key.
    forwarded = dict(kwargs, stacklevel=stacklevel)
    _info(*args, **forwarded)
def _warn(message: Union[str, Warning], stacklevel: int = 2, **kwargs: Any) -> None:
    """Issue ``message`` through :func:`warnings.warn`.

    Args:
        message: The warning text or a ``Warning`` instance.
        stacklevel: Forwarded to :func:`warnings.warn`; the default of 2 points at the caller.
            For backward compatibility a ``Warning`` subclass may be passed here positionally.
            It is then used as ``category`` (after emitting a deprecation notice), and the stack
            level is taken from ``kwargs["stacklevel"]``, defaulting to 2.
        **kwargs: Further keyword arguments for :func:`warnings.warn`, e.g. ``category``.
    """
    # Legacy call style ``_warn("msg", SomeWarning)`` puts the category in the ``stacklevel`` slot.
    # ``isinstance(..., type)`` also recognises classes whose metaclass is not exactly ``type``.
    if isinstance(stacklevel, type) and issubclass(stacklevel, Warning):
        rank_zero_deprecation(
            "Support for passing the warning category positionally is deprecated in v1.6 and will be removed in v1.8."
            f" Please, use `category={stacklevel.__name__}`."
        )
        kwargs["category"] = stacklevel
        stacklevel = kwargs.pop("stacklevel", 2)
    warnings.warn(message, stacklevel=stacklevel, **kwargs)
@rank_zero_only
def rank_zero_warn(message: Union[str, Warning], stacklevel: int = 4, **kwargs: Any) -> None:
    """Function used to log warn-level messages only on rank 0.

    Args:
        message: The warning text or a ``Warning`` instance.
        stacklevel: Forwarded to ``_warn``. The default of 4 counts ``_warn``, this function and
            the ``rank_zero_only`` wrapper, so the warning points at the caller. A ``Warning``
            subclass passed here positionally is still accepted (deprecated) as the category.
        **kwargs: Further keyword arguments for :func:`warnings.warn`, e.g. ``category``.
    """
    if type(stacklevel) is type and issubclass(stacklevel, Warning):
        # Legacy ``rank_zero_warn(msg, Category)``: the category occupies the ``stacklevel`` slot,
        # so the default of 4 is lost and ``_warn`` would fall back to 2, attributing the warning
        # to this module instead of the caller. Hand the intended level over via ``kwargs``.
        # The strict ``type(...) is type`` test guarantees that ``_warn`` also treats the argument
        # as a category and therefore consumes ``kwargs["stacklevel"]``.
        # ``stacklevel`` is a named parameter here, so ``kwargs`` cannot already contain it.
        kwargs["stacklevel"] = 4
    _warn(message, stacklevel=stacklevel, **kwargs)
class LightningDeprecationWarning(DeprecationWarning):
    """Warning category for deprecations raised by PyTorch Lightning.

    It is a :class:`DeprecationWarning` subclass, so it can be filtered either on its own or
    together with all other deprecation warnings.
    """
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Read PyTorch Lightning's Privacy Policy.