Masking data with ByteTensor() in PyTorch

Posted: 2023-06-18 00:09:51  Author: 绘守辛玥

Case description

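In DataWhale's object-detection tutorial on the VOC dataset there is the following piece of code (I have filled in the missing parts with my own variables):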

import torch

keep_difficult = False
json_file = [
    {"bbox": [[34, 11, 448, 293]], "difficulty": [1], "label": [20]},
    ...
]  # bbox: bounding-box coordinates, difficulty: difficulty flag, label: class
obj = json_file[0]  # take the entry shown above

# Convert the fields to tensors
bbox = torch.FloatTensor(obj['bbox'])
label = torch.LongTensor(obj['label'])
difficulty = torch.ByteTensor(obj['difficulty'])

# If keep_difficult is False, objects flagged as difficult are not kept,
# so they are removed here
if not keep_difficult:
    boxes = bbox[1 - difficulty]
    labels = label[1 - difficulty]
    difficulties = difficulty[1 - difficulty]

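While analyzing this code I was puzzled: isn't this just selecting row 0? What does that have to do with deleting anything? After printing a lot of intermediate values and reading a Zhihu article on the topic, I worked out the answer.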

Solution

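An integer tensor used as an index in PyTorch is not a mask. With LongTensor (int64, PyTorch's default integer type) or IntTensor (int32), each element is read as a position to select, much like indexing a Python list, except that several positions can be picked at once (as with NumPy integer-array indexing):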

import torch

a = torch.randn((3, 4))
b = torch.LongTensor([1, 2])

print("a: ", a)
print("b: ", b)
print("a[b]: ", a[b])  # select rows 1 and 2 (0-based)
print("a[:, b]: ", a[:, b])  # select columns 1 and 2 (0-based)


# >>> a: tensor([[ 0.3405,  0.2542, -0.8521, -0.2498],
#                [-0.1697, -0.6549,  0.6513,  1.3177],
#                [ 1.2851, -2.0055, -0.6256,  0.2196]])
#
# >>> b: tensor([1, 2])
#
# >>> a[b]: tensor([[-0.1697, -0.6549,  0.6513,  1.3177],
#                   [ 1.2851, -2.0055, -0.6256,  0.2196]])
#
# >>> a[:, b]: tensor([[ 0.2542, -0.8521],
#                      [-0.6549,  0.6513],
#                      [-2.0055, -0.6256]])

Note: PyTorch does not accept ShortTensor (int16) as an index tensor.
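The distinction that resolves the puzzle above can be seen on a single box. In this minimal sketch (the box coordinates are illustrative, shaped like `obj['bbox']`), an integer index `[0]` selects row 0, while a boolean mask `[False]` selects nothing at all:

```python
import torch

# One illustrative box, same shape as obj['bbox'] in the case above: (1, 4)
bbox = torch.tensor([[34.0, 11.0, 448.0, 293.0]])

# Integer index tensor: "take row 0"
by_index = bbox[torch.tensor([0])]
# Boolean mask tensor: "keep the rows whose mask entry is True"
by_mask = bbox[torch.tensor([False])]

print(by_index.shape)  # torch.Size([1, 4])
print(by_mask.shape)   # torch.Size([0, 4])
```

So the same-looking `[0]` either keeps the box or removes it, depending only on how the index tensor is interpreted.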

When the index is a ByteTensor (uint8), however, the behavior changes: the tensor is treated as a boolean mask, as shown below:

import torch

a = torch.randn((3, 4))
b = torch.ByteTensor([0, 1, 2])
c = torch.ByteTensor([3, 0, 1, 2])

print("a: ", a)
print("b: ", b)
print("c: ", c)
print("a[b]: ", a[b])  # rows 1 and 2: b acts as [False, True, True]
print("a[:, c]: ", a[:, c])  # columns 0, 2 and 3: c acts as [True, False, True, True]


# >>> a: tensor([[-0.1412,  0.8143,  1.0005,  0.0090],
#                [ 0.2290,  0.3913,  1.9158,  0.2050],
#                [ 0.8630, -0.6538,  0.6923,  1.2283]])
#
# >>> b: tensor([0, 1, 2], dtype=torch.uint8)
#
# >>> c: tensor([3, 0, 1, 2], dtype=torch.uint8)
#
# >>> a[b]: tensor([[ 0.2290,  0.3913,  1.9158,  0.2050],
#                   [ 0.8630, -0.6538,  0.6923,  1.2283]])
#
# >>> a[:, c]: tensor([[-0.1412,  1.0005,  0.0090],
#                      [ 0.2290,  1.9158,  0.2050],
#                      [ 0.8630,  0.6923,  1.2283]])

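So when the index is a ByteTensor, its values are not positions: any non-zero value counts as True, 0 counts as False, and a is then indexed with the resulting boolean mask. The mask must have the same length as the dimension it indexes (3 entries for the rows in b, 4 for the columns in c). This also explains the original case: 1 - difficulty turns a difficulty of 1 into 0 (False) and a difficulty of 0 into 1 (True), so bbox[1 - difficulty] keeps only the objects that are not marked difficult instead of selecting row 0.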
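Applied to the original case, a minimal sketch with a hypothetical two-object annotation (the second object flagged as difficult) shows that `1 - difficulty` keeps exactly the non-difficult objects. The uint8 mask is converted with `.bool()` here, which spells out the "non-zero means True" rule and avoids the deprecation warning; `nonzero` then recovers the equivalent integer positions. The box values and labels are made up for illustration.

```python
import torch

# Hypothetical annotation: two objects, the second one flagged as difficult
bbox = torch.tensor([[34.0, 11.0, 448.0, 293.0],
                     [10.0, 20.0, 100.0, 200.0]])
label = torch.tensor([20, 5])
difficulty = torch.tensor([0, 1], dtype=torch.uint8)

# 1 - difficulty flips 0 <-> 1; .bool() applies the "non-zero means True" rule
keep = (1 - difficulty).bool()
print(keep)           # tensor([ True, False])

boxes = bbox[keep]    # only the first (non-difficult) box remains
labels = label[keep]
print(labels)         # tensor([20])

# The same selection expressed as integer positions
idx = keep.nonzero(as_tuple=True)[0]
print(idx)            # tensor([0])
assert torch.equal(bbox[idx], boxes)

# A mask whose length differs from the indexed dimension is rejected
try:
    bbox[torch.tensor([True, False, True])]
except IndexError as err:
    print("IndexError:", err)
```

Converting between a mask and integer positions this way is handy when a later step (for example `torch.index_select`) expects indices rather than a mask.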

Going further

Masking a tensor with a ByteTensor, as above, also triggers this warning:

UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead.

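In other words, PyTorch advises against indexing with ByteTensor: uint8 indexing is deprecated, and a bool tensor should be used instead. Here is the code above rewritten accordingly: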

import torch

a = torch.randn((3, 4))
b = torch.BoolTensor([0, 1, 2])  # changed to BoolTensor(); non-zero values become True
c = torch.BoolTensor([3, 0, 1, 2])

print("a: ", a)
print("b: ", b)
print("c: ", c)
print("a[b]: ", a[b])  # rows 1 and 2
print("a[:, c]: ", a[:, c])  # columns 0, 2 and 3


# >>> a: tensor([[-1.0239,  0.8308, -0.1544,  1.2652],
#                [-1.4537, -0.6388, -0.0800, -0.0714],
#                [-2.4132, -1.0962, -0.0569,  1.7111]])
#
# >>> b: tensor([False,  True,  True])
#
# >>> c: tensor([ True, False,  True,  True])
#
# >>> a[b]: tensor([[-1.4537, -0.6388, -0.0800, -0.0714],
#                   [-2.4132, -1.0962, -0.0569,  1.7111]])
#
# >>> a[:, c]: tensor([[-1.0239, -0.1544,  1.2652],
#                      [-1.4537, -0.0800, -0.0714],
#                      [-2.4132, -0.0569,  1.7111]])
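The same masks can also be built with `torch.tensor(..., dtype=torch.bool)` or `.bool()`, which is the more common style in current PyTorch code than the legacy `torch.BoolTensor` constructor, and an existing uint8 mask can be converted with `!= 0`. This sketch uses `torch.arange` instead of `torch.randn` so that the selected values are predictable:

```python
import torch

a = torch.arange(12).reshape(3, 4)  # deterministic stand-in for torch.randn((3, 4))

# Build the masks directly as bool tensors; non-zero values become True
b = torch.tensor([0, 1, 2]).bool()
c = torch.tensor([3, 0, 1, 2], dtype=torch.bool)

# An existing uint8 mask converts to the same bool mask, without the warning
legacy = torch.tensor([3, 0, 1, 2], dtype=torch.uint8)
assert torch.equal(legacy != 0, c)

print(a[b])     # rows 1 and 2: [[4, 5, 6, 7], [8, 9, 10, 11]]
print(a[:, c])  # columns 0, 2 and 3: [[0, 2, 3], [4, 6, 7], [8, 10, 11]]
```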

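This version no longer triggers the warning. To apply the same idea to the original tutorial code, difficulty itself must be loaded as a bool tensor: on a uint8 tensor, ~ is bitwise NOT (~0 is 255 and ~1 is 254, both non-zero), so it would not invert the mask. The rewritten code:

# Load difficulty as a bool tensor instead of a ByteTensor
difficulty = torch.BoolTensor(obj['difficulty'])

# ~ on a bool tensor is logical NOT: keep only the non-difficult objects
if not keep_difficult:
    boxes = bbox[~difficulty]
    labels = label[~difficulty]
    difficulties = difficulty[~difficulty]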
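Whether `bbox[~difficulty]` removes the difficult objects depends on the dtype of `difficulty`, because `~` means different things for integer and bool tensors. A small sketch comparing it on a uint8 and a bool version of a hypothetical difficulty tensor, plus `torch.logical_not`, which gives the logical result for either dtype:

```python
import torch

difficulty_u8 = torch.tensor([0, 1], dtype=torch.uint8)
difficulty_bool = difficulty_u8.bool()

# On uint8, ~ is bitwise NOT: 0 -> 255, 1 -> 254, so every entry stays non-zero
print(~difficulty_u8)    # tensor([255, 254], dtype=torch.uint8)
# On bool, ~ is logical NOT
print(~difficulty_bool)  # tensor([ True, False])

# torch.logical_not always returns a bool tensor with the logical result
print(torch.logical_not(difficulty_u8))  # tensor([ True, False])
```

Used as a (deprecated) uint8 mask, `~difficulty_u8` would therefore keep every object, which is exactly the opposite of what `keep_difficult = False` asks for.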