Source code for pytorch_lightning.utilities.apply_func
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities used for collections."""

import dataclasses
import operator
from abc import ABC
from collections import defaultdict, OrderedDict
from copy import copy, deepcopy
from functools import partial
from typing import Any, Callable, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from torch import Tensor

from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_LEGACY
from pytorch_lightning.utilities.warnings import rank_zero_deprecation

if _TORCHTEXT_LEGACY:
    if _compare_version("torchtext", operator.ge, "0.9.0"):
        from torchtext.legacy.data import Batch
    else:
        from torchtext.data import Batch
else:
    Batch = type(None)

_BLOCKING_DEVICE_TYPES = ("cpu", "mps")


def to_dtype_tensor(
    value: Union[int, float, List[Union[int, float]]], dtype: torch.dtype, device: Union[str, torch.device]
) -> Tensor:
    return torch.tensor(value, dtype=dtype, device=device)


def from_numpy(value: np.ndarray, device: Union[str, torch.device]) -> Tensor:
    return torch.from_numpy(value).to(device)


CONVERSION_DTYPES: List[Tuple[Any, Callable[[Any, Any], Tensor]]] = [
    # bool -> uint8 as bool -> torch.bool triggers RuntimeError: Unsupported data type for NCCL process group
    (bool, partial(to_dtype_tensor, dtype=torch.uint8)),
    (int, partial(to_dtype_tensor, dtype=torch.int)),
    (float, partial(to_dtype_tensor, dtype=torch.float)),
    (np.ndarray, from_numpy),
]


def _is_namedtuple(obj: object) -> bool:
    # https://github.com/pytorch/pytorch/blob/v1.8.1/torch/nn/parallel/scatter_gather.py#L4-L8
    return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")


def _is_dataclass_instance(obj: object) -> bool:
    # https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions
    return dataclasses.is_dataclass(obj) and not isinstance(obj, type)
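As a hedged sketch of how ``CONVERSION_DTYPES`` is meant to be consumed (the table itself carries no logic): each entry pairs a Python or NumPy scalar type with a converter that tensorizes matching leaves, and a caller can loop over it with ``apply_to_collection`` (defined next). The loop below is illustrative only, not code from this module:

    >>> data = {"flag": True, "count": 3, "score": 0.5}
    >>> for src_type, conversion_fn in CONVERSION_DTYPES:
    ...     data = apply_to_collection(data, src_type, conversion_fn, device="cpu")
    >>> data["count"]
    tensor(3, dtype=torch.int32)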
def apply_to_collection(
    data: Any,
    dtype: Union[type, Any, Tuple[Union[type, Any]]],
    function: Callable,
    *args: Any,
    wrong_dtype: Optional[Union[type, Tuple[type, ...]]] = None,
    include_none: bool = True,
    **kwargs: Any,
) -> Any:
    """Recursively applies a function to all elements of a certain dtype.

    Args:
        data: the collection to apply the function to
        dtype: the given function will be applied to all elements of this dtype
        function: the function to apply
        *args: positional arguments (will be forwarded to calls of ``function``)
        wrong_dtype: the given function won't be applied if this type is specified and the given collection
            is of the ``wrong_dtype`` even if it is of type ``dtype``
        include_none: Whether to include an element if the output of ``function`` is ``None``.
        **kwargs: keyword arguments (will be forwarded to calls of ``function``)

    Returns:
        The resulting collection
    """
    # Breaking condition
    if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)):
        return function(data, *args, **kwargs)

    elem_type = type(data)

    # Recursively apply to collection items
    if isinstance(data, Mapping):
        out = []
        for k, v in data.items():
            v = apply_to_collection(
                v, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs
            )
            if include_none or v is not None:
                out.append((k, v))
        if isinstance(data, defaultdict):
            return elem_type(data.default_factory, OrderedDict(out))
        return elem_type(OrderedDict(out))

    is_namedtuple = _is_namedtuple(data)
    is_sequence = isinstance(data, Sequence) and not isinstance(data, str)
    if is_namedtuple or is_sequence:
        out = []
        for d in data:
            v = apply_to_collection(
                d, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs
            )
            if include_none or v is not None:
                out.append(v)
        return elem_type(*out) if is_namedtuple else elem_type(out)

    if _is_dataclass_instance(data):
        # make a deepcopy of the data,
        # but do not deepcopy mapped fields since the computation would
        # be wasted on values that likely get immediately overwritten
        fields = {}
        memo = {}
        for field in dataclasses.fields(data):
            field_value = getattr(data, field.name)
            fields[field.name] = (field_value, field.init)
            memo[id(field_value)] = field_value
        result = deepcopy(data, memo=memo)
        # apply function to each field
        for field_name, (field_value, field_init) in fields.items():
            v = None
            if field_init:
                v = apply_to_collection(
                    field_value,
                    dtype,
                    function,
                    *args,
                    wrong_dtype=wrong_dtype,
                    include_none=include_none,
                    **kwargs,
                )
            if not field_init or (not include_none and v is None):
                # retain old value
                v = getattr(data, field_name)
            try:
                setattr(result, field_name, v)
            except dataclasses.FrozenInstanceError as e:
                raise MisconfigurationException(
                    "A frozen dataclass was passed to `apply_to_collection` but this is not allowed."
                    " HINT: is your batch a frozen dataclass?"
                ) from e
        return result

    # data is neither of dtype, nor a collection
    return data
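For orientation, a minimal illustrative use of ``apply_to_collection`` (not part of the module source): doubling every tensor in a nested batch while leaving non-tensor leaves untouched:

    >>> batch = {"x": torch.ones(2), "meta": ["label", torch.ones(3)]}
    >>> doubled = apply_to_collection(batch, Tensor, lambda t: t * 2)
    >>> doubled["meta"][1]
    tensor([2., 2., 2.])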
def apply_to_collections(
    data1: Optional[Any],
    data2: Optional[Any],
    dtype: Union[type, Any, Tuple[Union[type, Any]]],
    function: Callable,
    *args: Any,
    wrong_dtype: Optional[Union[type, Tuple[type]]] = None,
    **kwargs: Any,
) -> Any:
    """Zips two collections and applies a function to their items of a certain dtype.

    Args:
        data1: The first collection
        data2: The second collection
        dtype: the given function will be applied to all elements of this dtype
        function: the function to apply
        *args: positional arguments (will be forwarded to calls of ``function``)
        wrong_dtype: the given function won't be applied if this type is specified and the given collection
            is of the ``wrong_dtype`` even if it is of type ``dtype``
        **kwargs: keyword arguments (will be forwarded to calls of ``function``)

    Returns:
        The resulting collection

    Raises:
        AssertionError:
            If sequence collections have different data sizes.
    """
    if data1 is None:
        if data2 is None:
            return
        # in case they were passed reversed
        data1, data2 = data2, None

    elem_type = type(data1)

    if isinstance(data1, dtype) and data2 is not None and (wrong_dtype is None or not isinstance(data1, wrong_dtype)):
        return function(data1, data2, *args, **kwargs)

    if isinstance(data1, Mapping) and data2 is not None:
        # use union because we want to fail if a key does not exist in both
        zipped = {k: (data1[k], data2[k]) for k in data1.keys() | data2.keys()}
        return elem_type(
            {
                k: apply_to_collections(*v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
                for k, v in zipped.items()
            }
        )

    is_namedtuple = _is_namedtuple(data1)
    is_sequence = isinstance(data1, Sequence) and not isinstance(data1, str)
    if (is_namedtuple or is_sequence) and data2 is not None:
        assert len(data1) == len(data2), "Sequence collections have different sizes."
        out = [
            apply_to_collections(v1, v2, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
            for v1, v2 in zip(data1, data2)
        ]
        return elem_type(*out) if is_namedtuple else elem_type(out)

    if _is_dataclass_instance(data1) and data2 is not None:
        if not _is_dataclass_instance(data2):
            raise TypeError(
                "Expected inputs to be dataclasses of the same type or to have identical fields"
                f" but got input 1 of type {type(data1)} and input 2 of type {type(data2)}."
            )
        if not (
            len(dataclasses.fields(data1)) == len(dataclasses.fields(data2))
            and all(map(lambda f1, f2: isinstance(f1, type(f2)), dataclasses.fields(data1), dataclasses.fields(data2)))
        ):
            raise TypeError("Dataclasses fields do not match.")
        # make a deepcopy of the data,
        # but do not deepcopy mapped fields since the computation would
        # be wasted on values that likely get immediately overwritten
        data = [data1, data2]
        fields: List[dict] = [{}, {}]
        memo: dict = {}
        for i in range(len(data)):
            for field in dataclasses.fields(data[i]):
                field_value = getattr(data[i], field.name)
                fields[i][field.name] = (field_value, field.init)
                if i == 0:
                    memo[id(field_value)] = field_value
        result = deepcopy(data1, memo=memo)
        # apply function to each field
        for (field_name, (field_value1, field_init1)), (_, (field_value2, field_init2)) in zip(
            fields[0].items(), fields[1].items()
        ):
            v = None
            if field_init1 and field_init2:
                v = apply_to_collections(
                    field_value1,
                    field_value2,
                    dtype,
                    function,
                    *args,
                    wrong_dtype=wrong_dtype,
                    **kwargs,
                )
            if not field_init1 or not field_init2 or v is None:  # retain old value
                return apply_to_collection(data1, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
            try:
                setattr(result, field_name, v)
            except dataclasses.FrozenInstanceError as e:
                raise MisconfigurationException(
                    "A frozen dataclass was passed to `apply_to_collections` but this is not allowed."
                    " HINT: is your batch a frozen dataclass?"
                ) from e
        return result

    return apply_to_collection(data1, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
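Likewise, an illustrative sketch of ``apply_to_collections`` (not from the module source): zipping two structurally identical collections and combining their tensor leaves pairwise:

    >>> a = {"x": torch.tensor([1.0]), "y": torch.tensor([2.0])}
    >>> b = {"x": torch.tensor([10.0]), "y": torch.tensor([20.0])}
    >>> summed = apply_to_collections(a, b, Tensor, lambda t1, t2: t1 + t2)
    >>> summed["y"]
    tensor([22.])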
class TransferableDataType(ABC):
    """A custom type for data that can be moved to a torch device via ``.to(...)``.

    Example:

        >>> isinstance(dict, TransferableDataType)
        False
        >>> isinstance(torch.rand(2, 3), TransferableDataType)
        True
        >>> class CustomObject:
        ...     def __init__(self):
        ...         self.x = torch.rand(2, 2)
        ...     def to(self, device):
        ...         self.x = self.x.to(device)
        ...         return self
        >>> isinstance(CustomObject(), TransferableDataType)
        True
    """

    @classmethod
    def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
        if cls is TransferableDataType:
            to = getattr(subclass, "to", None)
            return callable(to)
        return NotImplemented
def move_data_to_device(batch: Any, device: Union[str, torch.device]) -> Any:
    """Transfers a collection of data to the given device. Any object that defines a method ``to(device)`` will be
    moved and all other objects in the collection will be left untouched.

    Args:
        batch: A tensor or collection of tensors or anything that has a method ``.to(...)``.
            See :func:`apply_to_collection` for a list of supported collection types.
        device: The device to which the data should be moved

    Return:
        the same collection but with all contained tensors residing on the new device.

    See Also:
        - :meth:`torch.Tensor.to`
        - :class:`torch.device`
    """
    if isinstance(device, str):
        device = torch.device(device)

    def batch_to(data: Any) -> Any:
        # try to move torchtext data first
        if _TORCHTEXT_LEGACY and isinstance(data, Batch):
            # TODO: also remove the torchtext dependency with Lightning 1.8
            rank_zero_deprecation(
                "The `torchtext.legacy.Batch` object is deprecated and Lightning will remove support for it in v1.8."
                " We recommend you to migrate away from Batch by following the TorchText README:"
                " https://github.com/pytorch/text#bc-breaking-legacy"
            )
            # Shallow copy because each Batch has a reference to Dataset which contains all examples
            device_data = copy(data)
            for field, field_value in data.dataset.fields.items():
                if field_value is None:
                    continue
                device_field = move_data_to_device(getattr(data, field), device)
                setattr(device_data, field, device_field)
            return device_data

        kwargs = {}
        # Don't issue non-blocking transfers to CPU
        # Same with MPS due to a race condition bug: https://github.com/pytorch/pytorch/issues/83015
        if isinstance(data, Tensor) and isinstance(device, torch.device) and device.type not in _BLOCKING_DEVICE_TYPES:
            kwargs["non_blocking"] = True
        data_output = data.to(device, **kwargs)
        if data_output is not None:
            return data_output
        # user wrongly implemented the `TransferableDataType` and forgot to return `self`.
        return data

    dtype = (TransferableDataType, Batch) if _TORCHTEXT_LEGACY else TransferableDataType
    return apply_to_collection(batch, dtype=dtype, function=batch_to)
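Finally, a brief illustrative example of ``move_data_to_device`` (again, not part of the module itself), assuming a CPU target for reproducibility; any leaf exposing a ``to`` method, per ``TransferableDataType`` above, is moved, while other leaves pass through unchanged:

    >>> batch = {"images": torch.rand(4, 3), "targets": [torch.tensor([1]), torch.tensor([0])]}
    >>> moved = move_data_to_device(batch, "cpu")
    >>> moved["images"].device
    device(type='cpu')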