TPU training (Advanced)

Audience: Users looking to apply advanced performance techniques to TPU training.


This is an experimental feature.

Weight Sharing/Tying

Weight tying (or sharing) is a technique in which module weights are shared among two or more layers. It is a common way to reduce memory consumption and is used in many state-of-the-art architectures.
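To make the memory saving concrete, here is a small sketch that compares the parameter counts of one model with and without a tied weight. It uses plain PyTorch rather than Lightning. The class name `TinyNet`, its `tie` flag and the helper `num_weights` are illustrative assumptions, not part of any library.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical three-layer model; `tie` controls whether layer_3 reuses layer_1's weight."""

    def __init__(self, tie: bool):
        super().__init__()
        self.layer_1 = nn.Linear(32, 10, bias=False)
        self.layer_2 = nn.Linear(10, 32, bias=False)
        self.layer_3 = nn.Linear(32, 10, bias=False)
        if tie:
            # Both modules now hold a reference to the same Parameter object.
            self.layer_3.weight = self.layer_1.weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer_3(self.layer_2(self.layer_1(x)))


def num_weights(model: nn.Module) -> int:
    # parameters() yields each shared Parameter only once, so a tied weight is counted once.
    return sum(p.numel() for p in model.parameters())


untied = num_weights(TinyNet(tie=False))  # 3 weight matrices of 10 x 32
tied = num_weights(TinyNet(tie=True))     # 2 distinct weight matrices of 10 x 32
print(untied, tied)  # 960 640
```

In this sketch the tied model stores 640 weights instead of 960, because `layer_1.weight` and `layer_3.weight` are the same tensor.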

PyTorch XLA requires these weights to be tied (shared) after the model has been moved to the XLA device. To meet this requirement, Lightning automatically finds the shared weights and re-ties them once the modules are on the XLA device. This ensures that the modules keep referencing the same weights instead of holding independent copies.
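The "find the shared weights, move, then re-tie" idea can be sketched in plain PyTorch as follows. This is a minimal illustration under stated assumptions, not Lightning's actual implementation: the helpers `find_tied_groups` and `retie` are hypothetical, a CPU move stands in for the transfer to the XLA device, and the loss of tying during the transfer is simulated by hand.

```python
from collections import defaultdict
from typing import List

from torch import nn


def find_tied_groups(model: nn.Module) -> List[List[str]]:
    """Group fully qualified parameter names that refer to the same Parameter object."""
    groups = defaultdict(list)
    for mod_name, mod in model.named_modules():
        # recurse=False: each module reports only its own parameters, so a Parameter
        # shared by two modules shows up once under each owning module.
        for p_name, param in mod.named_parameters(recurse=False):
            full_name = f"{mod_name}.{p_name}" if mod_name else p_name
            groups[id(param)].append(full_name)
    return [names for names in groups.values() if len(names) > 1]


def retie(model: nn.Module, groups: List[List[str]]) -> None:
    """Point every name in each group back at the Parameter of the group's first name."""
    for names in groups:
        source = model.get_parameter(names[0])
        for name in names[1:]:
            owner_path, _, attr = name.rpartition(".")
            setattr(model.get_submodule(owner_path), attr, source)


model = nn.Sequential(
    nn.Linear(32, 10, bias=False),
    nn.Linear(10, 32, bias=False),
    nn.Linear(32, 10, bias=False),
)
model[2].weight = model[0].weight  # tie the first and last layers

groups = find_tied_groups(model)  # recorded before the move
model.to("cpu")  # stands in for moving to the XLA device in this sketch
# Simulate a transfer that gave the last layer its own copy of the weight.
model[2].weight = nn.Parameter(model[0].weight.detach().clone())
retie(model, groups)
print(model[2].weight is model[0].weight)  # True
```

Recording the groups before the move matters: once the weights have been copied independently, identity-based detection can no longer tell which parameters were meant to be shared.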

PyTorch Lightning has a built-in check that verifies the number of model parameters is unchanged after the model is moved to the device. If the counts do not match, Lightning emits a warning.
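One way to picture such a check is to compare the length of `model.parameters()` before and after the transfer. Because `parameters()` yields a shared Parameter only once, a tie that is lost during the move shows up as an increase in the count. The sketch below is an assumption-based illustration, not Lightning's code: `move_and_check`, `make_tied_model` and `untie` are hypothetical helpers, and Lightning's actual warning text is not reproduced.

```python
import warnings
from typing import Callable

from torch import nn


def move_and_check(model: nn.Module, move: Callable[[nn.Module], object]) -> nn.Module:
    """Apply `move` to the model and warn if the number of parameters changes."""
    # parameters() de-duplicates shared Parameters, so tied weights count once.
    before = len(list(model.parameters()))
    move(model)
    after = len(list(model.parameters()))
    if before != after:
        warnings.warn(
            f"Parameter count changed from {before} to {after} after the move; "
            "tied weights may have been copied instead of shared."
        )
    return model


def make_tied_model() -> nn.Module:
    model = nn.Sequential(
        nn.Linear(32, 10, bias=False),
        nn.Linear(10, 32, bias=False),
        nn.Linear(32, 10, bias=False),
    )
    model[2].weight = model[0].weight  # tie the first and last layers
    return model


def untie(model: nn.Module) -> None:
    # Hypothetical "bad" transfer: gives the last layer its own copy of the weight.
    model[2].weight = nn.Parameter(model[0].weight.detach().clone())


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    move_and_check(make_tied_model(), lambda m: m.to("cpu"))  # tie preserved: 2 -> 2
    move_and_check(make_tied_model(), untie)                  # tie lost: 2 -> 3

print(len(caught))  # 1
```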


from lightning.pytorch.core.module import LightningModule
from torch import nn
from lightning.pytorch.trainer.trainer import Trainer

class WeightSharingModule(LightningModule):
    def __init__(self):
        # nn.Module must be initialized before submodules can be assigned.
        super().__init__()
        self.layer_1 = nn.Linear(32, 10, bias=False)
        self.layer_2 = nn.Linear(10, 32, bias=False)
        self.layer_3 = nn.Linear(32, 10, bias=False)
        # Lightning automatically ties these weights after moving to the XLA device,
        # so all you need is to write the following just like on other accelerators.
        self.layer_3.weight = self.layer_1.weight

    def forward(self, x):
        x = self.layer_1(x)
        x = self.layer_2(x)
        x = self.layer_3(x)
        return x

model = WeightSharingModule()
trainer = Trainer(max_epochs=1, accelerator="tpu")

See XLA Documentation


XLA is the library that interfaces PyTorch with TPUs. For more information, see the XLA documentation.

Guide for troubleshooting XLA