# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from collections.abc import Generator, Iterable, Mapping, Sized
from dataclasses import fields
from typing import Any, Optional, Union

import torch
from lightning_utilities.core.apply_func import is_dataclass_instance
from torch import Tensor
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler, Sampler, SequentialSampler
from typing_extensions import TypeGuard

import lightning.pytorch as pl
from lightning.fabric.utilities.data import (
    _reinstantiate_wrapped_cls,
    _replace_value_in_saved_args,
    has_iterable_dataset,
    sized_len,
)
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
from lightning.pytorch.trainer.states import RunningStage
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn

BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]]

warning_cache = WarningCache()


def _extract_batch_size(batch: BType) -> Generator[Optional[int], None, None]:
    if isinstance(batch, Tensor):
        if batch.ndim == 0:
            yield 1
        else:
            yield batch.size(0)
    elif isinstance(batch, (Iterable, Mapping)) and not isinstance(batch, str):
        if isinstance(batch, Mapping):
            batch = batch.values()
        for sample in batch:
            yield from _extract_batch_size(sample)
    elif is_dataclass_instance(batch):
        for field in fields(batch):  # type: ignore[arg-type]
            yield from _extract_batch_size(getattr(batch, field.name))
    else:
        yield None


def extract_batch_size(batch: BType) -> int:
    """Unpack a batch to find a ``torch.Tensor``.

    Returns:
        ``len(tensor)`` when found, or ``1`` when it hits an empty or non-iterable value.

    """
    error_msg = (
        "We could not infer the batch_size from the batch. Either simplify its structure"
        " or provide the batch_size as `self.log(..., batch_size=batch_size)`."
    )
    batch_size = None
    try:
        for bs in _extract_batch_size(batch):
            if batch_size is None:
                batch_size = bs
            elif batch_size != bs:
                warning_cache.warn(
                    "Trying to infer the `batch_size` from an ambiguous collection. The batch size we"
                    f" found is {batch_size}. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`."
                )
                break
    except RecursionError:
        raise RecursionError(error_msg)

    if batch_size is None:
        raise MisconfigurationException(error_msg)

    return batch_size
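

# Usage sketch (the `_example_extract_batch_size` helper below is illustrative and not part of the
# original module): `extract_batch_size` walks nested mappings/iterables and returns the leading
# dimension of the first tensor it finds, warning once if the collection is ambiguous.
def _example_extract_batch_size() -> None:  # pragma: no cover - usage sketch only
    batch = {"images": torch.zeros(32, 3, 224, 224), "labels": torch.zeros(32, dtype=torch.long)}
    assert extract_batch_size(batch) == 32  # leading dimension of the first tensor found
    assert extract_batch_size(torch.tensor(3.0)) == 1  # 0-dim tensors count as batch size 1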


def has_len_all_ranks(
    dataloader: object,
    strategy: "pl.strategies.Strategy",
    allow_zero_length_dataloader_with_multiple_devices: bool = False,
) -> TypeGuard[Sized]:
    """Checks if a given object has the ``__len__`` method implemented on all ranks."""
    local_length = sized_len(dataloader)
    if local_length is None:
        # __len__ is not defined, skip these checks
        return False

    total_length = strategy.reduce(torch.tensor(local_length, device=strategy.root_device), reduce_op="sum")
    if total_length == 0:
        rank_zero_warn(
            f"Total length of `{type(dataloader).__name__}` across ranks is zero."
            " Please make sure this was your intention."
        )
    if total_length > 0 and local_length == 0:
        dataloader_cls_name = type(dataloader).__name__
        if not allow_zero_length_dataloader_with_multiple_devices:
            raise RuntimeError(
                f"`{dataloader_cls_name}` within local rank has zero length."
                " Please make sure that it returns at least 1 batch."
            )
        rank_zero_warn(
            f"Total length of `{dataloader_cls_name}` across ranks is greater than zero, but local rank has zero"
            " length. Please be cautious of uneven batch length."
        )

    if has_iterable_dataset(dataloader):
        rank_zero_warn(
            "Your `IterableDataset` has `__len__` defined."
            " In combination with multi-process data loading (when num_workers > 1),"
            " `__len__` could be inaccurate if each worker is not configured independently"
            " to avoid having duplicate data."
        )
    return True
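

# Usage sketch (the `_example_has_len_all_ranks` helper is illustrative, assuming a single-device CPU
# setup): `has_len_all_ranks` sums the local `__len__` across ranks, so a sized dataloader on a single
# CPU device passes the check.
def _example_has_len_all_ranks() -> None:  # pragma: no cover - usage sketch only
    from torch.utils.data import TensorDataset

    from lightning.pytorch.strategies import SingleDeviceStrategy

    loader = DataLoader(TensorDataset(torch.arange(8)), batch_size=2)
    assert has_len_all_ranks(loader, SingleDeviceStrategy(device="cpu"))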


def _update_dataloader(
    dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None
) -> DataLoader:
    dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode)
    return _reinstantiate_wrapped_cls(dataloader, *dl_args, **dl_kwargs)


def _get_dataloader_init_args_and_kwargs(
    dataloader: DataLoader,
    sampler: Union[Sampler, Iterable],
    mode: Optional[RunningStage] = None,
) -> tuple[tuple[Any], dict[str, Any]]:
    if not isinstance(dataloader, DataLoader):
        raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")

    was_wrapped = hasattr(dataloader, "__pl_saved_args")
    if was_wrapped:
        dl_args = dataloader.__pl_saved_args
        dl_kwargs = dataloader.__pl_saved_kwargs
        arg_names = dataloader.__pl_saved_arg_names
        original_dataset = dataloader.__dataset  # we have this saved from _wrap_init
    else:
        # get the dataloader instance attributes
        attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}
        # We cannot be 100% sure the class sets dataset argument. Let's set it to None to be safe
        # and hope we can get it from the instance attributes
        original_dataset = None
        # not part of `vars`
        attrs["multiprocessing_context"] = dataloader.multiprocessing_context
        arg_names = ()

    # get the dataloader instance `__init__` parameters
    params = dict(inspect.signature(dataloader.__init__).parameters)  # type: ignore[misc]
    has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
    if has_variadic_kwargs:
        # if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)`

        if was_wrapped:
            # if the dataloader was wrapped in a hook, only take arguments with default values
            # and assume user passes their kwargs correctly
            params.update(
                {k: v for k, v in inspect.signature(DataLoader.__init__).parameters.items() if v.default is not v.empty}
            )
        else:
            params.update(inspect.signature(DataLoader.__init__).parameters)
        params.pop("self", None)

    if not was_wrapped:
        # keep only the params whose default is different to the current attr value
        non_defaults = {name for name, p in params.items() if name in attrs and p.default is not attrs[name]}

        # add `dataset` as it might have been replaced with `*args`
        non_defaults.add("dataset")

        # kwargs to re-construct the dataloader
        dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
        dl_args = ()

    dataset = dl_kwargs.get("dataset", original_dataset)
    if isinstance(dataset, IterableDataset):
        dl_kwargs["batch_sampler"] = None
        dl_kwargs["sampler"] = None
    else:
        dl_kwargs.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler, mode))

    required_args = {
        p.name
        for p in params.values()
        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
        and p.default is p.empty
        and p.name not in dl_kwargs
        and p.name not in arg_names
    }
    # the dataloader has required args which we could not extract from the existing attributes
    if required_args:
        sorted_required_args = sorted(required_args)
        dataloader_cls_name = dataloader.__class__.__name__
        missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in sorted_required_args)
        raise MisconfigurationException(
            f"Trying to inject custom `Sampler` into the `{dataloader_cls_name}` instance. "
            "This would fail as some of the `__init__` arguments are not available as instance attributes. "
            f"The missing attributes are {sorted_required_args}. If you instantiate your `{dataloader_cls_name}` "
            "inside a `*_dataloader` hook of your module, we will do this for you."
            f" Otherwise, define {missing_args_message} inside your `__init__`."
        )

    if not has_variadic_kwargs:
        # the dataloader signature does not allow keyword arguments that need to be passed
        missing_kwargs = (set(dl_kwargs) | set(arg_names)) - params.keys()
        if missing_kwargs:
            sorted_missing_kwargs = sorted(missing_kwargs)
            dataloader_cls_name = dataloader.__class__.__name__
            raise MisconfigurationException(
                f"Trying to inject parameters into the `{dataloader_cls_name}` instance. "
                "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
                f"The missing arguments are {sorted_missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` "
                "class, add the `__init__` arguments or allow passing `**kwargs`"
            )

    return dl_args, dl_kwargs


def _dataloader_init_kwargs_resolve_sampler(
    dataloader: DataLoader,
    sampler: Union[Sampler, Iterable],
    mode: Optional[RunningStage] = None,
) -> dict[str, Any]:
    """This function is used to handle the sampler and batch_sampler arguments associated with a DataLoader for its
    re-instantiation.

    If the dataloader is being used for prediction, the sampler will be wrapped into an `_IndexBatchSamplerWrapper`, so
    Lightning can keep track of its indices.

    """
    is_predicting = mode == RunningStage.PREDICTING
    batch_sampler = getattr(dataloader, "batch_sampler")
    batch_sampler_cls = type(batch_sampler)

    if batch_sampler is not None and (batch_sampler_cls is not BatchSampler or is_predicting):
        if hasattr(batch_sampler, "__pl_saved_args"):
            # This is a PyTorch `BatchSampler` subclass for which we captured the init args
            args = batch_sampler.__pl_saved_args
            kwargs = batch_sampler.__pl_saved_kwargs
            default_kwargs = batch_sampler.__pl_saved_default_kwargs
            arg_names = batch_sampler.__pl_saved_arg_names

            if is_predicting:
                success, args, kwargs = _replace_value_in_saved_args(
                    "drop_last", False, args, kwargs, default_kwargs, arg_names
                )
                if not success:
                    rank_zero_warn(
                        f"Trying to inject `drop_last=False` into batch sampler since you are predicting, however "
                        f"it seems the class `{batch_sampler_cls.__qualname__}` does not support it. "
                        "Your predictions might be incomplete. To mitigate this, expose `drop_last` in "
                        "the `__init__` method of your custom class."
                    )

            success, args, kwargs = _replace_value_in_saved_args(
                "sampler", sampler, args, kwargs, default_kwargs, arg_names
            )
            if not success:
                raise TypeError(
                    "Trying to inject a modified sampler into the batch sampler; however, it seems the class "
                    f"`{batch_sampler_cls.__qualname__}` does not have an argument called `sampler`. To mitigate "
                    "this, expose an argument `sampler` in the `__init__` method of your custom class."
                )

            batch_sampler = _reinstantiate_wrapped_cls(batch_sampler, *args, **kwargs)
        elif hasattr(batch_sampler, "batch_size") and hasattr(batch_sampler, "drop_last"):
            # This is a sampler for which we could not capture the init args, but it kinda looks like a batch sampler
            # even if it does not inherit from PyTorch's interface.
            try:
                batch_sampler = batch_sampler_cls(
                    sampler,
                    batch_size=batch_sampler.batch_size,
                    drop_last=(False if is_predicting else batch_sampler.drop_last),
                )
            except TypeError as ex:
                import re

                match = re.match(r".*__init__\(\) (got multiple values)|(missing \d required)", str(ex))
                if not match:
                    # an unexpected `TypeError`, continue failure
                    raise

                # There could either be too few or too many arguments. Customizing the message based on this doesn't
                # make much sense since our TypeError is going to be raised from the original one.
                raise TypeError(
                    " Lightning can't inject a (distributed) sampler into your batch sampler, because it doesn't"
                    " subclass PyTorch's `BatchSampler`. To mitigate this, either follow the API of `BatchSampler` and"
                    " instantiate your custom batch sampler inside the `*_dataloader` hook of your module,"
                    " or set `Trainer(use_distributed_sampler=False)`. If you choose the latter, you will be"
                    " responsible for handling the distributed sampling within your batch sampler."
                ) from ex
        elif is_predicting:
            rank_zero_warn(
                f"You are using a custom batch sampler `{batch_sampler_cls.__qualname__}` for prediction."
                " Lightning would normally set `drop_last=False` to ensure all samples are returned, but for"
                " custom samplers it can't guarantee this. Make sure your sampler is configured correctly to return"
                " all indices.",
                category=PossibleUserWarning,
            )
        else:
            # The sampler is not a PyTorch `BatchSampler`, we don't know how to inject a custom sampler or
            # how to adjust the `drop_last` value
            raise TypeError(
                " Lightning can't inject a (distributed) sampler into your batch sampler, because it doesn't"
                " subclass PyTorch's `BatchSampler`. To mitigate this, either follow the API of `BatchSampler`"
                " or set `Trainer(use_distributed_sampler=False)`. If you choose the latter, you will be"
                " responsible for handling the distributed sampling within your batch sampler."
            )

        if is_predicting:
            batch_sampler = _IndexBatchSamplerWrapper(batch_sampler)

        # batch_sampler option is mutually exclusive with batch_size, shuffle, sampler, and drop_last
        return {
            "sampler": None,
            "shuffle": False,
            "batch_sampler": batch_sampler,
            "batch_size": 1,
            "drop_last": False,
        }

    return {"sampler": sampler, "shuffle": False, "batch_sampler": None}


def _is_dataloader_shuffled(dataloader: object) -> bool:
    if hasattr(dataloader, "__pl_saved_kwargs"):
        # this attribute is not part of PyTorch's DataLoader, but could have been set by
        # our `_replace_init_method` context manager
        if "shuffle" in dataloader.__pl_saved_kwargs:
            return dataloader.__pl_saved_kwargs["shuffle"]
        if "shuffle" in dataloader.__pl_saved_arg_names:
            return dataloader.__pl_saved_args[dataloader.__pl_saved_arg_names.index("shuffle")]
    if hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset):
        # shuffling is useless with iterable datasets
        return False
    if not hasattr(dataloader, "sampler"):
        # shuffling is enabled via a sampler. No sampler, no shuffling
        return False
    batch_sampler = dataloader.batch_sampler
    if batch_sampler is not None:
        # custom batch samplers may not have an internal .sampler
        sampler = batch_sampler.sampler if hasattr(batch_sampler, "sampler") else batch_sampler
    else:
        sampler = dataloader.sampler
    if isinstance(sampler, SequentialSampler):
        return False
    return isinstance(sampler, RandomSampler)
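

# Usage sketch (the `_example_sampler_injection` helper is illustrative and not part of the original
# module): `_is_dataloader_shuffled` inspects the (batch) sampler to detect shuffling, and
# `_update_dataloader` re-instantiates a loader with an injected sampler, which is how Lightning swaps
# in a distributed sampler on your behalf.
def _example_sampler_injection() -> None:  # pragma: no cover - usage sketch only
    from torch.utils.data import TensorDataset

    dataset = TensorDataset(torch.arange(10))
    assert _is_dataloader_shuffled(DataLoader(dataset, shuffle=True))
    assert not _is_dataloader_shuffled(DataLoader(dataset))

    # inject a plain `SequentialSampler` the same way a distributed sampler would be injected
    new_loader = _update_dataloader(DataLoader(dataset, batch_size=2), SequentialSampler(dataset))
    assert isinstance(new_loader.sampler, SequentialSampler)
    assert new_loader.batch_size == 2  # other constructor arguments are carried over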