torch.unique

发布时间 2023-09-22 17:45:22作者: kksk43

写代码的时候想把一个张量\(X\)中的最后一个维度进行类似集合那样的操作,于是网上找到了torch.unique这个方法(官方文档

torch.unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None) → Tuple[Tensor, Tensor, Tensor]

其中参数sortedreturn_inversereturn_counts网上有很多介绍(可以参考这篇),这里就不再赘述
对于参数dim,就是先按指定维度将原张量\(X\)划分多个子张量\(x_0,x_1,...,x_n\),再对这些子张量做unique操作

举例子:

X = torch.tensor([[3, 3, 3, 2, 4],
                  [0, 3, 0, 2, 4]])

# dim=0的情况
uni_X = torch.unique(X, dim=0)
print(uni_X)
# tensor([[0, 3, 0, 2, 4],
#         [3, 3, 3, 2, 4]])

# dim=1的情况
uni_X = torch.unique(X, dim=1)
print(uni_X)
# tensor([[2, 3, 3, 4],
#         [2, 0, 3, 4]])

以上例子,在dim=0情况时,按以下步骤理解:

  1. 先按第0维(即按行)把原张量\(X\)划分为2个子张量:\(x_0=[3, 3, 3, 2, 4],x_1=[0, 3, 0, 2, 4]\)
  2. 由于\(x_0\neq x_1\)即没有重复的,所以结果还是由\(x_0,x_1\)组成
  3. 别忘了默认参数sorted=True,即对\(x_0\)\(x_1\)进行字典序升序后再返回,这里按字典序有\(x_1<x_0\),所以返回张量\([x_1; x_0]=[[0, 3, 0, 2, 4];[3, 3, 3, 2, 4]]\)(按第0维划分的就按第0维拼回去)

在dim=1情况时,按以下步骤理解:

  1. 先按第1维(即按列)把原张量\(X\)划分为5个子张量:\(x_0=[3, 0],x_1=[3, 3],x_2=[3, 0],x_3=[2, 2],x_4=[4, 4]\)
  2. 可以发现只有\(x_0= x_2\)即只有这对重复,所以结果由\(x_0,x_1,x_3,x_4\)组成(或由\(x_1,x_2,x_3,x_4\)组成)
  3. sorted=True影响,对\(x_0,x_1,x_3,x_4\)进行字典序升序后再返回,这里按字典序有\(x_3<x_0<x_2<x_4\),所以返回张量\([x_3,x_0,x_2,x_4]=[[2, 3, 3, 4];[2, 0, 3, 4]]\)(按第1维划分的就按第1维拼回去)