apply_func

Functions

apply_to_collection – Recursively applies a function to all elements of a certain dtype.
apply_to_collections – Zips two collections and applies a function to their items of a certain dtype.
convert_to_tensors – Return type: Any
from_numpy – Return type: Tensor
move_data_to_device – Transfers a collection of data to the given device.
to_dtype_tensor – Return type: Tensor

Classes

TransferableDataType – A custom type for data that can be moved to a torch device via .to(...).

Utilities used for collections.

class pytorch_lightning.utilities.apply_func.TransferableDataType

Bases: abc.ABC

A custom type for data that can be moved to a torch device via .to(...).

Example

>>> import torch
>>> from pytorch_lightning.utilities.apply_func import TransferableDataType
>>> isinstance(dict, TransferableDataType)
False
>>> isinstance(torch.rand(2, 3), TransferableDataType)
True
>>> class CustomObject:
...     def __init__(self):
...         self.x = torch.rand(2, 2)
...     def to(self, device):
...         self.x = self.x.to(device)
...         return self
>>> isinstance(CustomObject(), TransferableDataType)
True
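The isinstance() checks above succeed even though CustomObject does not inherit from TransferableDataType. Below is a minimal sketch of how such a structural check can be built with abc.ABC and __subclasshook__. It assumes only that "transferable" means "has a callable to attribute". TransferableSketch and Movable are hypothetical names, and this is not necessarily Lightning's exact code.

```python
from abc import ABC


class TransferableSketch(ABC):
    """Hypothetical duck-typed ABC: any class with a callable ``to`` attribute counts."""

    @classmethod
    def __subclasshook__(cls, subclass):
        # isinstance()/issubclass() consult this hook, so matching classes need
        # neither inheritance nor explicit registration.
        if cls is TransferableSketch:
            return callable(getattr(subclass, "to", None))
        # Defer to the normal ABC machinery for classes derived from the sketch.
        return NotImplemented


class Movable:
    """Hypothetical object that knows how to 'move' itself."""

    def __init__(self):
        self.device = "cpu"

    def to(self, device):
        self.device = device
        return self


print(isinstance(Movable(), TransferableSketch))  # True
print(isinstance({}, TransferableSketch))         # False
print(isinstance(dict, TransferableSketch))       # False: the class object's type is `type`, which has no `to`
```

Returning NotImplemented for any class other than the sketch itself keeps ordinary inheritance checks intact for subclasses of TransferableSketch.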

pytorch_lightning.utilities.apply_func.apply_to_collection(data, dtype, function, *args, wrong_dtype=None, include_none=True, **kwargs)

Recursively applies a function to all elements of a certain dtype.

Parameters:
  • data (Any) – the collection to apply the function to

  • dtype (Union[type, Any, Tuple[Union[type, Any]]]) – the given function will be applied to all elements of this dtype

  • function (Callable) – the function to apply

  • *args – positional arguments (will be forwarded to calls of function)

  • wrong_dtype (Union[type, Tuple[type, …], None]) – if specified, the function is not applied to elements of this type, even if they are also of type dtype

  • include_none (bool) – Whether to keep an element in the output collection when its result is None; if False, such elements are dropped.

  • **kwargs – keyword arguments (will be forwarded to calls of function)

Return type:

Any

Returns:

The resulting collection
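The following simplified pure-Python sketch illustrates the recursion and the wrong_dtype and include_none semantics described above. apply_to_collection_sketch is a hypothetical name. Unlike the library function, it handles only dicts, lists, tuples and namedtuples, and it rebuilds every mapping as a plain dict.

```python
from typing import Any, Callable


def apply_to_collection_sketch(data: Any, dtype, function: Callable, *args,
                               wrong_dtype=None, include_none: bool = True, **kwargs) -> Any:
    """Hypothetical simplified re-implementation, not Lightning's code."""
    # Leaf case: the element has the target dtype and is not excluded by wrong_dtype.
    if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)):
        return function(data, *args, **kwargs)

    def recurse(value: Any) -> Any:
        return apply_to_collection_sketch(value, dtype, function, *args,
                                          wrong_dtype=wrong_dtype, include_none=include_none, **kwargs)

    if isinstance(data, dict):
        out = {}
        for key, value in data.items():
            result = recurse(value)
            # Drop entries whose result is None unless include_none is set.
            if include_none or result is not None:
                out[key] = result
        return out

    if isinstance(data, tuple) and hasattr(data, "_fields"):
        # namedtuple: rebuild field by field; None results are kept so the field count matches.
        return type(data)(*(recurse(value) for value in data))

    if isinstance(data, (list, tuple)):
        results = [recurse(value) for value in data]
        if not include_none:
            results = [r for r in results if r is not None]
        return type(data)(results)

    # Anything else (e.g. strings, or floats when dtype=int) is returned unchanged.
    return data


batch = {"x": [1, 2.5, "a"], "y": (3, None)}
print(apply_to_collection_sketch(batch, int, lambda v: v * 2))
# {'x': [2, 2.5, 'a'], 'y': (6, None)}
```

Because bool is a subclass of int, apply_to_collection_sketch([1, True], int, lambda v: v * 10, wrong_dtype=bool) returns [10, True]. With include_none=False, apply_to_collection_sketch([1, 2, 3], int, lambda v: v if v % 2 else None, include_none=False) returns [1, 3].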

pytorch_lightning.utilities.apply_func.apply_to_collections(data1, data2, dtype, function, *args, wrong_dtype=None, **kwargs)

Zips two collections and applies a function to their items of a certain dtype.

Parameters:
  • data1 (Optional[Any]) – The first collection

  • data2 (Optional[Any]) – The second collection

  • dtype (Union[type, Any, Tuple[Union[type, Any]]]) – the given function will be applied to all elements of this dtype

  • function (Callable) – the function to apply

  • *args – positional arguments (will be forwarded to calls of function)

  • wrong_dtype (Union[type, Tuple[type], None]) – if specified, the function is not applied to elements of this type, even if they are also of type dtype

  • **kwargs – keyword arguments (will be forwarded to calls of function)

Return type:

Any

Returns:

The resulting collection

Raises:

AssertionError – If two sequence collections being zipped have different lengths.
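Here is a minimal sketch of the zipping behaviour. It assumes that both collections share the same nesting structure and that the dtype check is made on the element from the first collection. apply_to_collections_sketch is hypothetical and handles only dicts and plain lists/tuples. It is not the library's implementation, but it keeps the documented AssertionError for sequences of different lengths.

```python
from typing import Any, Callable


def apply_to_collections_sketch(data1: Any, data2: Any, dtype, function: Callable, *args,
                                wrong_dtype=None, **kwargs) -> Any:
    """Hypothetical simplified sketch: walks two same-shaped collections in lockstep."""

    def recurse(a: Any, b: Any) -> Any:
        return apply_to_collections_sketch(a, b, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)

    # Leaf case (assumption of this sketch): the dtype test is made on the element from data1.
    if isinstance(data1, dtype) and (wrong_dtype is None or not isinstance(data1, wrong_dtype)):
        return function(data1, data2, *args, **kwargs)

    if isinstance(data1, dict) and isinstance(data2, dict):
        # Pair values by key; a key missing from data2 raises KeyError.
        return {key: recurse(value, data2[key]) for key, value in data1.items()}

    if isinstance(data1, (list, tuple)) and isinstance(data2, (list, tuple)):
        # Same failure mode as documented: sequences of different lengths are rejected.
        assert len(data1) == len(data2), "Sequence collections have different sizes."
        # Plain lists/tuples only; namedtuples are out of scope for this sketch.
        return type(data1)(recurse(a, b) for a, b in zip(data1, data2))

    # Non-matching leaves are taken from the first collection unchanged.
    return data1


step_1 = {"loss": 1.0, "acc": [0.5, 0.25]}
step_2 = {"loss": 2.0, "acc": [0.5, 0.75]}
print(apply_to_collections_sketch(step_1, step_2, float, lambda a, b: a + b))
# {'loss': 3.0, 'acc': [1.0, 1.0]}
```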

pytorch_lightning.utilities.apply_func.move_data_to_device(batch, device)

Transfers a collection of data to the given device. Any object that defines a method to(device) will be moved and all other objects in the collection will be left untouched.

Parameters:
  • batch (Any) – A tensor, a collection of tensors, or anything that has a .to(...) method. See apply_to_collection() for a list of supported collection types.

  • device (Union[str, torch.device]) – The device to which the data should be moved

Return type:

Any

Returns:

The same collection, but with all contained tensors residing on the new device.
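The transfer rule described above can be sketched in plain Python: move anything with a to(device) method and leave everything else alone. move_data_to_device_sketch and FakeTensor are hypothetical names. FakeTensor only records a device string, so the example runs without torch. With the real library, you would call move_data_to_device(batch, "cuda:0") on a batch of tensors.

```python
from typing import Any


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: only tracks which device it is on."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        # Return a moved copy instead of mutating self.
        return FakeTensor(device)


def move_data_to_device_sketch(batch: Any, device: str) -> Any:
    """Hypothetical simplified sketch, not Lightning's implementation."""
    # Duck typing: any object with a callable .to(...) is treated as transferable.
    if callable(getattr(batch, "to", None)):
        return batch.to(device)
    if isinstance(batch, dict):
        return {key: move_data_to_device_sketch(value, device) for key, value in batch.items()}
    if isinstance(batch, (list, tuple)):
        # Plain lists/tuples only; namedtuples are out of scope here.
        return type(batch)(move_data_to_device_sketch(value, device) for value in batch)
    # Everything else (ints, strings, None, ...) is left untouched.
    return batch


batch = {"inputs": [FakeTensor(), FakeTensor()], "meta": (7, "sample-a")}
moved = move_data_to_device_sketch(batch, "cuda:0")
print([t.device for t in moved["inputs"]], moved["meta"])
# ['cuda:0', 'cuda:0'] (7, 'sample-a')
```

Checking for a callable to attribute, rather than testing isinstance against a tensor class, mirrors the TransferableDataType definition above. It lets custom objects such as CustomObject be moved alongside tensors.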