pytorch-tensor高级OP

发布时间 2023-07-31 23:16:33作者: 哎呦哎(iui)

Tensor advanced operation
▪ Where
▪ Gather

where

返回的最终的tensor中的每一个值有可能来着A,也有可能来自B。

torch.where(condition,A,B)->tensor
满足condition条件则该位置为A[..],否则为B[..]。
这个condition也是一个相同shape的tensor
比如说:torch.where(cond>0,a,b)
\( \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ \end{bmatrix} \) =>\( \begin{bmatrix} A & B \\ A & B \\ \end{bmatrix} \)

image

cond=torch.tensor([[0.6769,0.7271],[0.8884,0.4163]])
cond
# tensor([[0.6769, 0.7271],
#         [0.8884, 0.4163]])



a=torch.ones(2,2)
a
# tensor([[1., 1.],
#         [1., 1.]])



b=torch.zeros(2,2)
b
# tensor([[0., 0.],
#         [0., 0.]])



torch.where(cond>0.5,a,b)
# tensor([[1., 1.],
#         [1., 0.]])

image

gather

image
这个API设计的初衷就是这样的,下面有一个场景。
下面有三类动物,编号分别为0,1,2

\[\begin{bmatrix} dog \\ cat \\ pig \\ \end{bmatrix} \begin{matrix} 0 \\ 1 \\ 2 \end{matrix} \]

然后我们识别之后的结果是一些编号,然后我们希望将这个结果编号变为类别

\[\begin{bmatrix} 1 \\ 0 \\ 1 \\ 2 \end{bmatrix} => \begin{bmatrix} cat \\ dog \\ cat \\ pig \end{bmatrix} \]

所以这个API中input和index都是tensor。
然后我们看一个具体的例子:
image

idx
# tensor([[8, 2, 0],
#         [4, 8, 1],
#         [0, 4, 9],
#         [0, 9, 2]])

label
# tensor([100, 101, 102, 103, 104, 105, 106, 107, 108, 109])

torch.gather(label.expand(4,10),dim=1,index=idx.long())
# tensor([[108, 102, 100],
#         [104, 108, 101],
#         [100, 104, 109],
#         [100, 109, 102]])

我们来分析一下torch.gather(label.expand(4,10),dim=1,index=idx.long())
首先这个label.expand(4,10)是这样的
image
我们来看看这个label.expand(4,10),是将input的shape变为(4,10),然后idx的每一行都可以按照label变化之后每一行的下标输出了,所以这个dim=1,就是按照10这个下标输出的。