apply_func

Functions

apply_to_collection – Recursively applies a function to all elements of a certain dtype.

apply_to_collections – Zips two collections and applies a function to their items of a certain dtype.

convert_to_tensors – Return type: Any

from_numpy – Return type: Tensor

move_data_to_device – Transfers a collection of data to the given device.

to_dtype_tensor – Return type: Tensor

Classes

TransferableDataType – A custom type for data that can be moved to a torch device via .to(...).

Utilities used for collections.

class pytorch_lightning.utilities.apply_func.TransferableDataType

Bases: abc.ABC

A custom type for data that can be moved to a torch device via .to(...).

Example

>>> import torch
>>> from pytorch_lightning.utilities.apply_func import TransferableDataType
>>> isinstance(dict, TransferableDataType)
False
>>> isinstance(torch.rand(2, 3), TransferableDataType)
True
>>> class CustomObject:
...     def __init__(self):
...         self.x = torch.rand(2, 2)
...     def to(self, device):
...         self.x = self.x.to(device)
...         return self
>>> isinstance(CustomObject(), TransferableDataType)
True
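The isinstance checks above succeed even though CustomObject does not inherit from TransferableDataType. One way to get this structural behavior, sketched here with the standard abc module, is a __subclasshook__ that treats any class with a callable to attribute as a virtual subclass. TransferableSketch is a hypothetical name used only for illustration; this is a sketch of the mechanism, not the library's source.

```python
from abc import ABC


class TransferableSketch(ABC):
    """Hypothetical re-creation of a structural "has a callable .to()" check."""

    @classmethod
    def __subclasshook__(cls, subclass):
        if cls is TransferableSketch:
            # Any class whose `to` attribute is callable counts as a virtual subclass.
            return callable(getattr(subclass, "to", None))
        # Defer to the normal subclass machinery for classes derived from the ABC.
        return NotImplemented


class WithTo:
    def to(self, device):
        return self


print(isinstance(WithTo(), TransferableSketch))  # True
print(isinstance({}, TransferableSketch))        # False: dict has no `to`
print(isinstance(dict, TransferableSketch))      # False: the class object's type is `type`
```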

pytorch_lightning.utilities.apply_func.apply_to_collection(data, dtype, function, *args, wrong_dtype=None, include_none=True, **kwargs)

Recursively applies a function to all elements of a certain dtype.

Parameters
  • data (Any) – the collection to apply the function to

  • dtype (Union[type, Any, Tuple[Union[type, Any]]]) – the given function will be applied to all elements of this dtype

  • function (Callable) – the function to apply

  • *args – positional arguments (will be forwarded to calls of function)

  • wrong_dtype (Union[type, Tuple[type], None]) – if specified, the given function won’t be applied to elements of this type, even if they are also of type dtype

  • include_none (bool) – whether to keep an element in the resulting collection when the output of function for it is None

  • **kwargs – keyword arguments (will be forwarded to calls of function)

Return type

Any

Returns

The resulting collection
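The recursion described above can be sketched in pure Python, without torch, so it runs anywhere. apply_to_collection_sketch is a hypothetical, simplified stand-in that assumes the collection is built only from dict, list and tuple; the library version handles further collection types (for example namedtuples), so treat this as an illustration of the dtype, wrong_dtype and include_none semantics rather than the actual implementation.

```python
from typing import Any, Callable, Optional, Tuple, Union

DType = Union[type, Tuple[type, ...]]


def apply_to_collection_sketch(
    data: Any,
    dtype: DType,
    function: Callable,
    *args: Any,
    wrong_dtype: Optional[DType] = None,
    include_none: bool = True,
    **kwargs: Any,
) -> Any:
    """Hypothetical, simplified stand-in for apply_to_collection (dict/list/tuple only)."""
    # Leaf: the element has the requested dtype and is not excluded by wrong_dtype.
    if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)):
        return function(data, *args, **kwargs)

    def recurse(value: Any) -> Any:
        # Forward every option unchanged to the nested call.
        return apply_to_collection_sketch(
            value, dtype, function, *args,
            wrong_dtype=wrong_dtype, include_none=include_none, **kwargs,
        )

    # Mappings: recurse into the values; drop None results if include_none is False.
    if isinstance(data, dict):
        out = {}
        for key, value in data.items():
            result = recurse(value)
            if include_none or result is not None:
                out[key] = result
        return out

    # Plain lists and tuples: recurse element-wise and keep the container type.
    if type(data) in (list, tuple):
        results = (recurse(item) for item in data)
        return type(data)(r for r in results if include_none or r is not None)

    # Anything else (strings, numbers, None, unknown objects) is returned as-is.
    return data


batch = {"x": [1, 2.5, "a"], "y": (3, None)}
print(apply_to_collection_sketch(batch, int, lambda v: v * 2))
# {'x': [2, 2.5, 'a'], 'y': (6, None)}

# bool is a subclass of int; wrong_dtype=bool leaves True untouched.
print(apply_to_collection_sketch([1, True], int, lambda v: v + 10, wrong_dtype=bool))
# [11, True]

# With include_none=False, entries whose result is None are dropped.
print(apply_to_collection_sketch({"a": 1, "b": 2}, int, lambda v: v if v > 1 else None, include_none=False))
# {'b': 2}
```

Extra positional and keyword arguments are forwarded to every call of function, which is why the nested helper passes *args and **kwargs through unchanged.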

pytorch_lightning.utilities.apply_func.apply_to_collections(data1, data2, dtype, function, *args, wrong_dtype=None, **kwargs)

Zips two collections and applies a function to their items of a certain dtype.

Parameters
  • data1 (Optional[Any]) – The first collection

  • data2 (Optional[Any]) – The second collection

  • dtype (Union[type, Any, Tuple[Union[type, Any]]]) – the given function will be applied to all elements of this dtype

  • function (Callable) – the function to apply

  • *args – positional arguments (will be forwarded to calls of function)

  • wrong_dtype (Union[type, Tuple[type], None]) – if specified, the given function won’t be applied to elements of this type, even if they are also of type dtype

  • **kwargs – keyword arguments (will be forwarded to calls of function)

Return type

Any

Returns

The resulting collection

Raises

AssertionError – If paired sequence collections have different lengths.
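A minimal pure-Python sketch of the zip-and-apply behavior, assuming both collections share the same nesting structure and are built only from dict, list and tuple. apply_to_collections_sketch is a hypothetical name; raising KeyError for a key missing from data2 is a choice of this sketch, while the AssertionError on mismatched sequence lengths mirrors the documented behavior above.

```python
from typing import Any, Callable, Optional, Tuple, Union

DType = Union[type, Tuple[type, ...]]


def apply_to_collections_sketch(
    data1: Any,
    data2: Any,
    dtype: DType,
    function: Callable,
    *args: Any,
    wrong_dtype: Optional[DType] = None,
    **kwargs: Any,
) -> Any:
    """Hypothetical, simplified stand-in for apply_to_collections (dict/list/tuple only)."""
    # Leaf: call function on the pair of elements found at the same position.
    if isinstance(data1, dtype) and (wrong_dtype is None or not isinstance(data1, wrong_dtype)):
        return function(data1, data2, *args, **kwargs)

    def recurse(a: Any, b: Any) -> Any:
        return apply_to_collections_sketch(a, b, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)

    # Mappings: pair values by key (a key missing from data2 raises KeyError here).
    if isinstance(data1, dict):
        return {key: recurse(value, data2[key]) for key, value in data1.items()}

    # Plain lists and tuples: zip element-wise after checking the lengths agree.
    if type(data1) in (list, tuple):
        assert len(data1) == len(data2), "Sequence collections have different sizes."
        return type(data1)(recurse(a, b) for a, b in zip(data1, data2))

    # Anything else: keep the element from the first collection.
    return data1


print(apply_to_collections_sketch({"a": [1, 2], "b": 3}, {"a": [10, 20], "b": 4}, int, lambda x, y: x + y))
# {'a': [11, 22], 'b': 7}
```

Passing sequences of different lengths, such as [1, 2] and [1], trips the assertion before any pairing happens.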

pytorch_lightning.utilities.apply_func.move_data_to_device(batch, device)

Transfers a collection of data to the given device. Any object that defines a method to(device) will be moved, and all other objects in the collection will be left untouched.

Parameters
  • batch (Any) – A tensor, a collection of tensors, or anything that has a .to(...) method. See apply_to_collection() for a list of supported collection types.

  • device (Union[str, device]) – The device to which the data should be moved

Return type

Any

Returns

the same collection, but with all contained tensors (and other objects with a .to(...) method) residing on the new device.
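The behavior described above can be sketched without torch. FakeTensor is a hypothetical stand-in for torch.Tensor that only records a device string, and move_data_to_device_sketch is a hypothetical, simplified version that assumes the batch is built only from dict, list and tuple. With real tensors you would call move_data_to_device(batch, device) directly; this only illustrates which objects get moved and which are left untouched.

```python
from typing import Any


def _is_transferable(obj: Any) -> bool:
    # Duck-typing check in the spirit of TransferableDataType:
    # the object's class defines a callable `to` attribute.
    return callable(getattr(type(obj), "to", None))


def move_data_to_device_sketch(batch: Any, device: str) -> Any:
    """Hypothetical, simplified stand-in for move_data_to_device (dict/list/tuple only)."""
    if _is_transferable(batch):
        return batch.to(device)
    if isinstance(batch, dict):
        return {key: move_data_to_device_sketch(value, device) for key, value in batch.items()}
    if type(batch) in (list, tuple):
        return type(batch)(move_data_to_device_sketch(item, device) for item in batch)
    # Strings, numbers, None and other objects without .to() are left untouched.
    return batch


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor that only tracks its device."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        # Like Tensor.to, return a new object instead of mutating this one.
        return FakeTensor(device)


batch = {"images": FakeTensor(), "meta": ["id-1", 7], "pair": (FakeTensor(), None)}
moved = move_data_to_device_sketch(batch, "cuda:0")
print(moved["images"].device, moved["pair"][0].device, moved["meta"])
# cuda:0 cuda:0 ['id-1', 7]
```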
