torch使用bool类型做检索

发布时间 2023-10-31 11:56:22作者: 蔚蓝色の天空

一、背景

在使用torch的时候,可以通过bool类型对数组进行检索操作。传统的list或者dict都是使用下标和关键字检索。而在torch中可以使用bool类型进行检索,它的的目标主要是以下功能:

  • 替换torch中的某个值

二、使用

torch在bool检索的情况下就是将为检索位置为True的地方用另一个数据进行替换。

import torch

x = torch.Tensor([1, 2, 3, 4, 5])
# print(x)
noise_labels = torch.randint(len(x), x.shape)
print(noise_labels)
labels = x.clone()
print(labels)
probability_matrix = torch.full(x.shape, 0.15)
# print(probability_matrix)
masked_indices = torch.bernoulli(probability_matrix).bool()
print(masked_indices)
labels[masked_indices] = noise_labels[masked_indices]  # 将True的部分进行修改
print(labels)

# output:
"""
masked_indices第四个位置为True,因此修改labels中第四个位置,由于噪声数据第四个的位置是1,因此labels中的数据为1
tensor([3, 2, 0, 1, 1])
tensor([1., 2., 3., 4., 5.])
tensor([False, False, False, False,  True])
tensor([1., 2., 3., 4., 1.])
"""


import torch

x = torch.Tensor([1, 2, 3, 4, 5])
# print(x)
noise_labels = torch.randint(len(x)+1999, x.shape)
print(noise_labels)
labels = x.clone()
print(labels)
probability_matrix = torch.full(x.shape, 0.15)
# print(probability_matrix)
masked_indices = torch.bernoulli(probability_matrix).bool()
print(masked_indices)
labels[masked_indices] = noise_labels[masked_indices]
print(labels)

# output:
"""
tensor([1516,  408,  274,  426,  126])
tensor([1., 2., 3., 4., 5.])
tensor([False, False, False,  True, False])
tensor([  1.,   2.,   3., 426.,   5.])
"""