Using Fabric with Distributed RPC

Hello,

I would like to use Fabric for multi-device training on a single node as well as on multiple nodes.
In my setup, I have to share one global object across all workers, which is done via the Distributed
RPC framework. I would like to use Fabric for process creation and all the other functionality it offers.

I tried several conceptually different implementations, but none of them worked. In general, all I need to do is create a particular object on rank 0. Rank 0 then starts all workers, including itself, with an RRef to that object.
A minimal example is below. The expected output is that all ranks print rank 0 as the ref owner.

from lightning_fabric import Fabric
from torch.distributed import rpc
from torch.distributed.rpc import RRef

def test(fabric: Fabric, rref: RRef):
    print(f"running on rank {fabric.global_rank} with ref owner {rref.owner().id}\n")


def main():
    fabric = Fabric(accelerator="cpu", devices=4, strategy="ddp")
    fabric.launch()

    rpc.init_rpc(name=f"worker-{fabric.global_rank}", rank=fabric.global_rank, world_size=fabric.world_size)

    fabric.launch(test, RRef("test"))

    rpc.shutdown()


if __name__ == "__main__":
    main()
 

How would I use Fabric properly to a) configure RPC (I don’t want to do it manually; is it possible to avoid that?) and b) instantiate the object on rank 0 only and then make calls to all ranks with that reference?

My question is based on this issue: "Correct construction of TensorDictReplayBuffer in DDP" (Issue #1397, pytorch/rl on GitHub)

Thank you very much for your help.

I found a workaround (more of a hack) that does what I want by introducing a global variable.

from lightning_fabric import Fabric
from torch.distributed import rpc
from torch.distributed.rpc import RRef

REF = None


def set_ref(value):
    global REF
    REF = value


def get_ref():
    return REF


def test(fabric: Fabric):
    rpc.init_rpc(name=f"worker-{fabric.global_rank}", rank=fabric.global_rank, world_size=fabric.world_size)

    if fabric.is_global_zero:
        set_ref(RRef("test"))

    fabric.barrier()

    ref = rpc.rpc_sync("worker-0", get_ref)
    print(f"got ref with owner {ref.owner().id} on rank {fabric.global_rank}")
    rpc.shutdown()


def main():
    fabric = Fabric(accelerator="cpu", devices=4, strategy="ddp")
    fabric.launch(test)


if __name__ == "__main__":
    main()

Does anybody have a better suggestion than this?
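
One possible alternative, offered as an untested sketch rather than a confirmed Fabric recipe: instead of storing the RRef in a global and having every rank fetch it, rank 0 could pass the RRef as an argument to an RPC call on each worker. The helper names `worker_name`, `report_owner`, and `run` are made up for illustration. The sketch assumes, as in the workaround above, that `rpc.init_rpc` can be called inside the function handed to `fabric.launch`, and that `fabric` exposes `global_rank`, `world_size`, and `is_global_zero`.

```python
from torch.distributed import rpc
from torch.distributed.rpc import RRef


def worker_name(rank: int) -> str:
    """RPC worker name for a rank, matching the naming used above (hypothetical helper)."""
    return f"worker-{rank}"


def report_owner(rref: RRef) -> str:
    """Runs on the callee. The RRef arrives as an RPC argument, so no global is needed."""
    me = rpc.get_worker_info()  # WorkerInfo of the process executing this call
    msg = f"got ref with owner {rref.owner().id} on rank {me.id}"
    print(msg)
    return msg


def run(fabric) -> None:
    """Meant to be passed to fabric.launch(run), like `test` in the workaround."""
    rpc.init_rpc(
        name=worker_name(fabric.global_rank),
        rank=fabric.global_rank,
        world_size=fabric.world_size,
    )

    if fabric.is_global_zero:
        # Created on rank 0 only, so rank 0 is the owner of the RRef.
        rref = RRef("test")
        # Call every worker (including rank 0 itself) with the same reference.
        futures = [
            rpc.rpc_async(worker_name(rank), report_owner, args=(rref,))
            for rank in range(fabric.world_size)
        ]
        for fut in futures:
            fut.wait()

    # Graceful shutdown waits for all workers, so the other ranks stay alive
    # long enough to serve the calls sent by rank 0.
    rpc.shutdown()
```

This would be started the same way as the workaround, with `fabric.launch(run)` inside `main()`. It does not avoid the manual `rpc.init_rpc` call, so it only addresses part b) of the question.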