
TPU training (Intermediate)

Audience: Users looking to use cloud TPUs.

Warning

This is an experimental feature.


DistributedSamplers

Lightning automatically inserts the correct samplers, so there is no need to do this yourself!

Usually, with TPUs (and DDP), you would need to define a DistributedSampler so that each TPU core receives its own chunk of the data. As mentioned, this is not needed in Lightning.

Note

Don’t add DistributedSamplers. Lightning does this automatically.

If for some reason you still need to, this is how to construct the sampler for TPU use:

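import os

import torch_xla.core.xla_model as xm
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms
from torchvision.datasets import MNIST


# Defined as a method of your LightningModule
def train_dataloader(self):
    dataset = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())

    # Required for TPU support: give each TPU core its own shard of the dataset.
    # `use_tpu` is a flag you define yourself; it is not provided by Lightning.
    sampler = None
    if use_tpu:
        sampler = DistributedSampler(
            dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal(), shuffle=True
        )

    # Shuffle in the DataLoader only when no sampler is set;
    # DataLoader rejects shuffle=True combined with a sampler.
    loader = DataLoader(dataset, sampler=sampler, shuffle=(sampler is None), batch_size=32)

    return loader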
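Lightning inserts this sampler for you, but it can help to see how the data is split across cores. The sketch below is a simplified, pure-Python re-implementation of the index arithmetic that `DistributedSampler` performs with `shuffle=False` and the default `drop_last=False`. `shard_indices` is a hypothetical helper, not part of PyTorch or Lightning, and it assumes the dataset has at least `num_replicas` samples.

```python
import math


def shard_indices(dataset_len, num_replicas, rank):
    """Hypothetical helper mirroring DistributedSampler's split
    (shuffle=False, drop_last=False); not the real torch class."""
    # Every replica gets the same number of samples.
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    # Pad by repeating indices from the start (assumes dataset_len >= num_replicas).
    indices += indices[: total_size - dataset_len]
    # Each replica (TPU core) takes every num_replicas-th index, starting at its rank.
    return indices[rank:total_size:num_replicas]


for rank in range(4):
    print(rank, shard_indices(10, num_replicas=4, rank=rank))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]
```

Note how ranks 2 and 3 reuse indices 0 and 1: the padding keeps all replicas in lockstep at the cost of seeing a few samples twice per epoch.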

Configure the number of TPU cores in the Trainer. You can only choose 1 or 8. To use a full TPU pod, skip to the TPU Pod section.

import lightning.pytorch as pl

my_model = MyLightningModule()
trainer = pl.Trainer(accelerator="tpu", devices=8)
trainer.fit(my_model)

That’s it! Your model will train on all 8 TPU cores.
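For intuition about what training on 8 cores means for batch sizes, here is the arithmetic as a short sketch. It assumes the MNIST training split (60,000 images), the per-core `batch_size=32` from the DataLoader example, data-parallel training where gradients are synchronized across cores every step, and the default `drop_last=False` for both the sampler and the DataLoader.

```python
import math

# Assumptions for illustration (see the lead-in above).
dataset_size = 60_000
per_core_batch_size = 32
num_cores = 8

# The distributed sampler gives each core ceil(N / cores) samples.
samples_per_core = math.ceil(dataset_size / num_cores)
# Each core's DataLoader yields ceil(samples / batch_size) batches per epoch.
batches_per_core = math.ceil(samples_per_core / per_core_batch_size)
# All cores step together, so one optimizer step sees cores * batch_size samples.
global_batch_size = per_core_batch_size * num_cores

print(samples_per_core, batches_per_core, global_batch_size)
# 7500 235 256
```

If you tuned a learning rate for a single device with batch size 32, keep in mind that each step on 8 cores effectively averages gradients over 256 samples.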


TPU VM

Lightning supports training on the new Cloud TPU VMs. Previously, a separate user VM was needed to connect to the TPU machines. Cloud TPU VMs run directly on the TPU host machines, so users get direct SSH access to them. This architecture makes working with TPUs cheaper and significantly improves performance and usability.

TPU VMs come pre-installed with the latest versions of PyTorch and PyTorch XLA. After connecting to the VM and before running your Lightning code, you need to set the XRT TPU device configuration:

export XRT_TPU_CONFIG="localservice;0;localhost:51011"

# Set the number of visible TPU devices.
# You might need to change this value depending on how many chips you have.
export TPU_NUM_DEVICES=4

# Allow libtpu to be loaded by multiple processes
export ALLOW_MULTIPLE_LIBTPU_LOAD=1
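If you prefer to keep this configuration inside your training script, the same variables can be set from Python. This is a minimal sketch with a hypothetical helper, `configure_tpu_vm_env`. It assumes the variables take effect as long as they are set before `torch_xla` is imported and the XLA runtime starts, and it uses `os.environ.setdefault` so that values already exported in the shell take precedence.

```python
import os


def configure_tpu_vm_env(num_devices=4):
    """Hypothetical helper: set the TPU VM variables from the shell exports,
    without overwriting values that are already set in the environment."""
    defaults = {
        "XRT_TPU_CONFIG": "localservice;0;localhost:51011",
        # Depends on how many chips your TPU VM has.
        "TPU_NUM_DEVICES": str(num_devices),
        "ALLOW_MULTIPLE_LIBTPU_LOAD": "1",
    }
    for name, value in defaults.items():
        # setdefault keeps any value already exported in the shell.
        os.environ.setdefault(name, value)
    # Return the effective values so callers can log or inspect them.
    return {name: os.environ[name] for name in defaults}
```

Call `configure_tpu_vm_env()` at the very top of your training script, before importing `torch_xla` or creating the Trainer.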

You can learn more about the Cloud TPU VM architecture here.


TPU Pod

To train on more cores than a single node provides, your code doesn’t need to change!

All TPU VMs in a Pod setup need access to the model code and data. One easy way to achieve this is to pass a startup script when creating the TPU VM pod; the script then performs the data download on every TPU VM. Note that you need to export the corresponding environment variables following the instructions in Create TPU Node.

gcloud alpha compute tpus tpu-vm create ${TPU_POD_NAME} \
    --zone ${ZONE} --project ${PROJECT_ID} \
    --accelerator-type ${ACCELERATOR_TYPE} --version ${RUNTIME_VERSION} \
    --metadata-from-file startup-script=setup.py
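The gcloud command refers to a `setup.py` startup script whose contents are not shown here. Below is a hypothetical sketch of such a script. It assumes the VM runs the startup script with the interpreter named on its shebang line (as Compute Engine startup scripts do), and `fetch_assets`, `DATA_DIR`, and the `DOWNLOAD_COMPLETE` marker are placeholders of your own choosing. Writing the marker file last gives you a simple way to check whether the download has finished on a given worker.

```python
#!/usr/bin/env python3
# Hypothetical startup script (setup.py) run on every TPU VM in the pod.
import os
import pathlib
import tempfile

# Placeholder location; point DATA_DIR at wherever your training script reads data.
DATA_DIR = pathlib.Path(
    os.environ.get("DATA_DIR", os.path.join(tempfile.gettempdir(), "lightning_data"))
)
DONE_MARKER = DATA_DIR / "DOWNLOAD_COMPLETE"


def fetch_assets(target: pathlib.Path) -> None:
    # Placeholder: clone your model code and download your dataset into `target` here.
    (target / "README.txt").write_text("replace with your real download logic\n")


def download_finished() -> bool:
    # The marker is written only after fetch_assets returns, so its presence means success.
    return DONE_MARKER.exists()


def main() -> None:
    DATA_DIR.mkdir(parents=True, exist_ok=True)
    if download_finished():
        return  # Already downloaded, e.g. on a previous boot of this VM.
    fetch_assets(DATA_DIR)
    DONE_MARKER.touch()


if __name__ == "__main__":
    main()
```

Because the script exits early when the marker exists, rebooting a worker does not trigger a second download.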

Then SSH into any TPU worker (e.g. worker 0), check that the data and model downloads have finished, and generate the SSH keys needed to connect between the VM workers in the pod. To start training, submit the following command:

python3 -m torch_xla.distributed.xla_dist --tpu=$TPU_POD_NAME -- python3 train.py --max_epochs=5 --batch_size=32

See this guide on how to set up the instance groups and VMs needed to run TPU Pods.


16-bit precision

Lightning also supports training in 16-bit precision with TPUs. By default, TPU training uses 32-bit precision. To enable 16-bit precision, set the precision argument of the Trainer:

import lightning.pytorch as pl

my_model = MyLightningModule()
trainer = pl.Trainer(accelerator="tpu", precision="16-true")
trainer.fit(my_model)

Under the hood, the XLA library uses the bfloat16 type.
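bfloat16 keeps float32's 8-bit exponent (so it covers the same numeric range) but stores only 7 mantissa bits, which trades precision for range. The sketch below illustrates this trade-off in pure Python by keeping only the top 16 bits of a value's float32 encoding. `to_bfloat16` is a hypothetical helper, and it truncates, whereas hardware conversions typically round to nearest, so individual results can differ in the last bit.

```python
import struct


def to_bfloat16(x: float) -> float:
    """Hypothetical helper: round-trip a float through bfloat16 by truncating
    its float32 bit pattern to the upper 16 bits (sign, exponent, 7 mantissa bits)."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]


print(to_bfloat16(1 / 3))    # 0.33203125: only about 3 significant decimal digits survive
print(to_bfloat16(70000.0))  # 69632.0: representable, whereas float16 overflows above 65504
```

The second line shows why bfloat16 is attractive for training: large activations and gradients that would overflow float16 stay finite, only less precise.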