小灰灰深度学习之关于三维张量的一些索引

发布时间 2023-06-14 00:37:55作者: 啥都不会的灰太狼

首先要感谢CSDN中http://t.csdn.cn/XyT4e这篇文章(我接下来写的内容,也和这篇文章基本一样)

下面是我实际操作得到的结果:

我们看第一种情况的代码:

import torch
b = torch.arange(1, 61).reshape(3, 4, 5)
idx1 = torch.tensor([0, 0, 2]).unsqueeze(-1).repeat(1, 4)
bb = b[idx1, :
print(bb)

我们先来看一下张量b的内容:

'''
张量b内容:
 tensor([[[ 1,  2,  3,  4,  5],
          [ 6,  7,  8,  9, 10],
          [11, 12, 13, 14, 15],
          [16, 17, 18, 19, 20]],
 
         [[21, 22, 23, 24, 25],
          [26, 27, 28, 29, 30],
          [31, 32, 33, 34, 35],
          [36, 37, 38, 39, 40]],
 
         [[41, 42, 43, 44, 45],
          [46, 47, 48, 49, 50],
          [51, 52, 53, 54, 55],
          [56, 57, 58, 59, 60]]]))
'''

接下来我们看一下索引得到的张量bb的内容(这个有点长):

'''
张量bb的内容为:
tensor([[[[ 1,  2,  3,  4,  5],
          [ 6,  7,  8,  9, 10],
          [11, 12, 13, 14, 15],
          [16, 17, 18, 19, 20]],

         [[ 1,  2,  3,  4,  5],
          [ 6,  7,  8,  9, 10],
          [11, 12, 13, 14, 15],
          [16, 17, 18, 19, 20]],

         [[ 1,  2,  3,  4,  5],
          [ 6,  7,  8,  9, 10],
          [11, 12, 13, 14, 15],
          [16, 17, 18, 19, 20]],

         [[ 1,  2,  3,  4,  5],
          [ 6,  7,  8,  9, 10],
          [11, 12, 13, 14, 15],
          [16, 17, 18, 19, 20]]],


        [[[ 1,  2,  3,  4,  5],
          [ 6,  7,  8,  9, 10],
          [11, 12, 13, 14, 15],
          [16, 17, 18, 19, 20]],

         [[ 1,  2,  3,  4,  5],
          [ 6,  7,  8,  9, 10],
          [11, 12, 13, 14, 15],
          [16, 17, 18, 19, 20]],

         [[ 1,  2,  3,  4,  5],
          [ 6,  7,  8,  9, 10],
          [11, 12, 13, 14, 15],
          [16, 17, 18, 19, 20]],

         [[ 1,  2,  3,  4,  5],
          [ 6,  7,  8,  9, 10],
          [11, 12, 13, 14, 15],
          [16, 17, 18, 19, 20]]],


        [[[41, 42, 43, 44, 45],
          [46, 47, 48, 49, 50],
          [51, 52, 53, 54, 55],
          [56, 57, 58, 59, 60]],

         [[41, 42, 43, 44, 45],
          [46, 47, 48, 49, 50],
          [51, 52, 53, 54, 55],
          [56, 57, 58, 59, 60]],

         [[41, 42, 43, 44, 45],
          [46, 47, 48, 49, 50],
          [51, 52, 53, 54, 55],
          [56, 57, 58, 59, 60]],

         [[41, 42, 43, 44, 45],
          [46, 47, 48, 49, 50],
          [51, 52, 53, 54, 55],
          [56, 57, 58, 59, 60]]]])
'''

首先b是(3,4,5)的张量,然后b[idx1, :, ]这里索引就是将张量b的后面两个维度(4, 5)当作一个整体。然后根据idx1中的内容进行索引。又因为idx1的shape是(3, 4)

所以索引后的bb的shape为(3, 4, 4, 5)。然后也就得到了那个结果。

 

第二种情况的代码为(这里我们的张量b仍选用第一种情况的张量b):

import torch
b = torch.arange(1, 61).reshape(3, 4, 5)
idx1 = torch.tensor([0, 0, 2]).unsqueeze(-1).repeat(1, 4)
idx2 = torch.randint(0, 4, (3, 4), dtype = torch.long)
cc = b[idx1, idx2]
cc.shape,cc

此时的输出结果为:

'''
idx2得到的随机值为:
tensor([[1, 2, 1, 3],
            [2, 1, 1, 2],
            [2, 1, 0, 3]])

cc.shape:
torch.Size([3, 4, 5])

张量cc为:
 tensor([[[ 6,  7,  8,  9, 10],
          [11, 12, 13, 14, 15],
          [ 6,  7,  8,  9, 10],
          [16, 17, 18, 19, 20]],
 
         [[11, 12, 13, 14, 15],
          [ 6,  7,  8,  9, 10],
          [ 6,  7,  8,  9, 10],
          [11, 12, 13, 14, 15]],
 
         [[51, 52, 53, 54, 55],
          [46, 47, 48, 49, 50],
          [41, 42, 43, 44, 45],
          [56, 57, 58, 59, 60]]])
'''

此时我们可以看到idx1与idx2他们都是(3,4)d的矩阵,所以对于b[idx1, idx2]会先将idx1与idx2组合起来然后作为索引,去索引张量b中的第三个维度[5]。然后就得到结果了。

1.如果我们要多维度的索引,我们需要保证dim(idx1) = dim(idx2)

2.多个维度一起索引时,我们先把两个维度叠加在一起,然后根据新的索引去索引