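Case description
In DataWhale's object detection tutorial on the VOC dataset, there is a snippet like the following (I have filled in the missing variables with my own values so that it runs on its own):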
import torch

keep_difficult = False
json_file = [
    {"bbox": [[34, 11, 448, 293]], "difficulty": [1], "label": [20]},
    ...
]  # bbox: bounding-box coordinates, difficulty: "difficult" flag, label: class index
obj = json_file[0]  # take the first record of the example
# Pull out the individual fields
bbox = torch.FloatTensor(obj['bbox'])
label = torch.LongTensor(obj['label'])
difficulty = torch.ByteTensor(obj['difficulty'])
# If keep_difficult is False, objects flagged as difficult are not kept,
# so the corresponding entries are removed here
if not keep_difficult:
    boxes = bbox[1 - difficulty]
    labels = label[1 - difficulty]
    difficulties = difficulty[1 - difficulty]
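While analyzing this code I couldn't help wondering: 1 - difficulty is tensor([0]) here, so isn't this just taking row 0? What does that have to do with deleting anything? So I printed a lot of intermediate values and, with the help of a Zhihu article, worked out the answer.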
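Solution
When a tensor is used as an index in PyTorch, what happens depends on its dtype. With a LongTensor (int64, PyTorch's default integer dtype) or an IntTensor (int32), the values are treated as positions, just like picking elements from a Python list by index: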
a = torch.randn((3, 4))
b = torch.LongTensor([1, 2])
print("a: ", a)
print("b: ", b)
print("a[b]: ", a[b])        # rows 1 and 2 (0-based)
print("a[:, b]: ", a[:, b])  # columns 1 and 2 (0-based)
# >>> a: tensor([[ 0.3405,  0.2542, -0.8521, -0.2498],
#                [-0.1697, -0.6549,  0.6513,  1.3177],
#                [ 1.2851, -2.0055, -0.6256,  0.2196]])
#
# >>> b: tensor([1, 2])
#
# >>> a[b]: tensor([[-0.1697, -0.6549,  0.6513,  1.3177],
#                   [ 1.2851, -2.0055, -0.6256,  0.2196]])
#
# >>> a[:, b]: tensor([[ 0.2542, -0.8521],
#                      [-0.6549,  0.6513],
#                      [-2.0055, -0.6256]])
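Note: PyTorch does not accept a ShortTensor (int16) as an index.
When the index is a ByteTensor (uint8), however, it is treated as a boolean mask instead of a list of positions, as shown below: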
a = torch.randn((3, 4))
b = torch.ByteTensor([0, 1, 2])
c = torch.ByteTensor([3, 0, 1, 2])
print("a: ", a)
print("b: ", b)
print("c: ", c)
print("a[b]: ", a[b])        # keeps rows 1 and 2 (0-based), where b is non-zero
print("a[:, c]: ", a[:, c])  # keeps columns 0, 2 and 3 (0-based), where c is non-zero
# >>> a: tensor([[-0.1412,  0.8143,  1.0005,  0.0090],
#                [ 0.2290,  0.3913,  1.9158,  0.2050],
#                [ 0.8630, -0.6538,  0.6923,  1.2283]])
#
# >>> b: tensor([0, 1, 2], dtype=torch.uint8)
#
# >>> c: tensor([3, 0, 1, 2], dtype=torch.uint8)
#
# >>> a[b]: tensor([[ 0.2290,  0.3913,  1.9158,  0.2050],
#                   [ 0.8630, -0.6538,  0.6923,  1.2283]])
#
# >>> a[:, c]: tensor([[-0.1412,  1.0005,  0.0090],
#                      [ 0.2290,  1.9158,  0.2050],
#                      [ 0.8630,  0.6923,  1.2283]])
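As you can see, when the index is a ByteTensor its actual values don't matter: anything non-zero counts as True and 0 counts as False, and the result is the same as indexing tensor a with the corresponding boolean mask. This answers the original question: 1 - difficulty is a ByteTensor, so it is 0 for difficult objects and 1 for the others, and indexing with it drops the difficult ones.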
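The difference between the two behaviours is easiest to see when the same values are used once as an int64 index and once as a boolean mask. The following is only a minimal sketch: the tensor `a` and the values `[1, 0, 1]` are made up for illustration and are not taken from the tutorial.

```python
import torch

# A small deterministic tensor so that the rows are easy to recognise
a = torch.arange(12).reshape(3, 4)

values = [1, 0, 1]  # illustrative values, not from the tutorial
idx = torch.tensor(values, dtype=torch.long)   # int64: values are row positions
mask = torch.tensor(values, dtype=torch.bool)  # bool: non-zero means "keep this row"

picked = a[idx]   # rows 1, 0, 1 in that order, shape (3, 4)
masked = a[mask]  # rows 0 and 2, shape (2, 4)

print(picked)
print(masked)
```

With an int64 index the result has one row per index value, and rows may repeat. With a mask the result has one row per True entry, and the mask must be as long as the dimension it indexes.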
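Going further
Indexing a tensor with a ByteTensor also triggers this warning:
UserWarning: indexing with dtype torch.uint8 is now deprecated, please use a dtype torch.bool instead.
So the PyTorch developers discourage ByteTensor masks and have deprecated them in favor of the bool dtype. Let's update the code above accordingly: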
a = torch.randn((3, 4))
b = torch.BoolTensor([0, 1, 2])  # BoolTensor instead of ByteTensor; non-zero values become True
c = torch.BoolTensor([3, 0, 1, 2])
print("a: ", a)
print("b: ", b)
print("c: ", c)
print("a[b]: ", a[b])        # keeps rows 1 and 2 (0-based)
print("a[:, c]: ", a[:, c])  # keeps columns 0, 2 and 3 (0-based)
# >>> a: tensor([[-1.0239,  0.8308, -0.1544,  1.2652],
#                [-1.4537, -0.6388, -0.0800, -0.0714],
#                [-2.4132, -1.0962, -0.0569,  1.7111]])
#
# >>> b: tensor([False,  True,  True])
#
# >>> c: tensor([ True, False,  True,  True])
#
# >>> a[b]: tensor([[-1.4537, -0.6388, -0.0800, -0.0714],
#                   [-2.4132, -1.0962, -0.0569,  1.7111]])
#
# >>> a[:, c]: tensor([[-1.0239, -0.1544,  1.2652],
#                      [-1.4537, -0.0800, -0.0714],
#                      [-2.4132, -0.0569,  1.7111]])
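With this change the warning no longer appears, so the tutorial's snippet can be rewritten the same way. Note that difficulty has to be converted to bool first: on a uint8 tensor, ~ is a bitwise NOT (1 becomes 254, 0 becomes 255), so every entry would be non-zero and nothing would be filtered out.
if not keep_difficult:
    keep = ~difficulty.bool()  # True for objects that are not difficult
    boxes = bbox[keep]
    labels = label[keep]
    difficulties = difficulty[keep]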
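To check a bool-based version of this filter end to end, here is a small sketch. It assumes a record with a hypothetical second, non-difficult object (the box `[10, 20, 30, 40]` and label `5` are invented for illustration), and it builds the tensors with `torch.tensor(..., dtype=...)` rather than the legacy `torch.ByteTensor` constructor. It also prints what `~` does to a uint8 tensor, which is why the mask should be bool before it is negated.

```python
import torch

keep_difficult = False

# One record with two objects; the second object is hypothetical,
# added only so that something survives the filter.
obj = {
    "bbox": [[34, 11, 448, 293], [10, 20, 30, 40]],
    "difficulty": [1, 0],
    "label": [20, 5],
}

boxes = torch.tensor(obj["bbox"], dtype=torch.float32)
labels = torch.tensor(obj["label"], dtype=torch.long)
difficulties = torch.tensor(obj["difficulty"], dtype=torch.bool)

if not keep_difficult:
    keep = ~difficulties  # logical NOT on a bool tensor
    boxes = boxes[keep]
    labels = labels[keep]
    difficulties = difficulties[keep]

print(boxes)         # only the non-difficult box remains
print(labels)
print(difficulties)

# For comparison: on uint8, ~ is a bitwise NOT and never yields 0 for inputs 0 or 1
flipped_uint8 = ~torch.tensor([1, 0], dtype=torch.uint8)
print(flipped_uint8)
```

All three tensors are indexed with the same `keep` mask, so boxes, labels and difficulty flags stay aligned after filtering.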