September 7, 2022

Mesh Grid Trick

deep-learning

pytorch

The following code builds every (h, w) coordinate pair on a 3×3 grid:

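import torch

coords_h = torch.tensor([0, 1, 2])
coords_w = torch.tensor([0, 1, 2])
# indexing="ij" is the default; passing it explicitly avoids a UserWarning in PyTorch >= 1.10
xys = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")).flatten(1)  # shape (2, 9)
print(xys)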

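which prints a (2, 9) tensor whose first row holds the h-coordinates and whose second row holds the w-coordinates: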

tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2],
        [0, 1, 2, 0, 1, 2, 0, 1, 2]])
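To see where that layout comes from, here is a small sketch that prints the intermediate tensors one step at a time. It assumes PyTorch 1.10 or later, since that is when `torch.meshgrid` gained the `indexing` argument; the names `grid_h`, `grid_w` and `stacked` are introduced here for illustration only.

```python
import torch

coords_h = torch.tensor([0, 1, 2])
coords_w = torch.tensor([0, 1, 2])

# With indexing="ij", grid_h[i, j] == coords_h[i] and grid_w[i, j] == coords_w[j].
grid_h, grid_w = torch.meshgrid(coords_h, coords_w, indexing="ij")
print(grid_h)
# tensor([[0, 0, 0],
#         [1, 1, 1],
#         [2, 2, 2]])
print(grid_w)
# tensor([[0, 1, 2],
#         [0, 1, 2],
#         [0, 1, 2]])

# Stacking the two grids adds a leading axis of size 2.
stacked = torch.stack((grid_h, grid_w))
print(stacked.shape)  # torch.Size([2, 3, 3])

# flatten(1) merges every axis from dim 1 onward, reading each grid row by row.
xys = stacked.flatten(1)
print(xys.shape)  # torch.Size([2, 9])
```

Because `flatten` reads each 3×3 grid in row-major order, position k of both rows refers to the same grid cell.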

Now we rearrange it with einops so that each row is one (h, w) pair:

from einops import rearrange

xys = rearrange(xys, "coord n -> n coord")  # (2, 9) -> (9, 2): one row per grid point
print(xys)

to get

tensor([[0, 0],
        [0, 1],
        [0, 2],
        [1, 0],
        [1, 1],
        [1, 2],
        [2, 0],
        [2, 1],
        [2, 2]])