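The following code

```python
import torch

coords_h = torch.tensor([0, 1, 2])
coords_w = torch.tensor([0, 1, 2])
# indexing="ij" (PyTorch >= 1.10): the first grid varies along rows (h),
# the second along columns (w); passing it explicitly avoids the deprecation warning
xys = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")).flatten(1)
print(xys)
```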
gives
```
tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2],
        [0, 1, 2, 0, 1, 2, 0, 1, 2]])
```
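The ordering above comes from flattening the two "ij" grids in row-major order, so the w index varies fastest. A minimal pure-Python sketch of the same computation (the helper name `meshgrid_ij_flat` is hypothetical, not a PyTorch API):

```python
def meshgrid_ij_flat(coords_h, coords_w):
    """Pure-Python sketch of torch.stack(torch.meshgrid(h, w, indexing="ij")).flatten(1).

    Returns two rows: the h-coordinate and the w-coordinate of every grid
    point, visited in row-major order (w varies fastest).
    """
    hs = [h for h in coords_h for _ in coords_w]  # each h repeated len(coords_w) times
    ws = [w for _ in coords_h for w in coords_w]  # coords_w tiled len(coords_h) times
    return [hs, ws]


print(meshgrid_ij_flat([0, 1, 2], [0, 1, 2]))
# [[0, 0, 0, 1, 1, 1, 2, 2, 2], [0, 1, 2, 0, 1, 2, 0, 1, 2]]
```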
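Now we rearrange it so that each row holds one (h, w) coordinate pair:

```python
from einops import rearrange

# xys has shape (2, 9): axis 0 is the coordinate (h or w), axis 1 the point index
xys = rearrange(xys, "coord i -> i coord")
print(xys)
```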
to get
```
tensor([[0, 0],
        [0, 1],
        [0, 2],
        [1, 0],
        [1, 1],
        [1, 2],
        [2, 0],
        [2, 1],
        [2, 2]])
```
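Because the rearrange pattern only swaps the two axes, plain PyTorch gives the same (9, 2) tensor with `.T`, and `torch.cartesian_prod` builds it directly from the 1-D coordinate tensors in the same row-major order. A short sketch, assuming PyTorch >= 1.10 for the `indexing` argument:

```python
import torch

coords_h = torch.tensor([0, 1, 2])
coords_w = torch.tensor([0, 1, 2])

# Same construction as above: shape (2, 9), one row per coordinate axis
xys = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")).flatten(1)

# Swapping the two axes is all the rearrange step does: shape (9, 2)
pairs_t = xys.T

# torch.cartesian_prod enumerates (h, w) pairs with w varying fastest
pairs_cp = torch.cartesian_prod(coords_h, coords_w)

print(torch.equal(pairs_t, pairs_cp))  # True
```

Using `einops.rearrange` keeps the axis names visible in the code, which helps once more axes are involved; for a single 2-D transpose, `.T` or `torch.cartesian_prod` is equivalent.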