pytorch张量广播机制示例

发布时间 2023-07-25 12:16:48作者: Picassooo
import torch 
box = torch.tensor([            # 边界框的坐标,(x1, y1, x2, y2). box'shape: (3, 4)
    [0.1, 0.2, 0.5, 0.3], 
    [0.6, 0.6, 0.9, 0.9], 
    [0.1, 0.1, 0.2, 0.2]
    ])

whwh = torch.tensor([200, 400, 200, 400])

box_new = box * whwh[None, :]   
# None的作用是给whwh扩增第0维,扩增之后,whwh的shape是(1, 4). 然后就可以用pytorch的张量广播机制相乘了。
print(box_new)

输出:

 

我做张量广播时的一个办法是,先用None的方式把两个张量的维度数目变得相等,然后进行广播。虽然广播机制在维度数目不相等时也可以应用,但是为方便理解,我习惯上还是先把维度数目变为相等,再进行广播机制,就像上面的示例一样。

 

下面这几句总结的话,摘自:张量的广播机制。这篇博客写的挺清晰的。