Pytorch one-hot编码

发布时间 2023-04-14 10:11:22作者: 抚琴尘世客

1. 引言

  在我们做分割任务时,通常会给一个mask,但训练时要进行onehot编码。

2. code

import torch

if __name__ == '__main__':
    label = torch.zeros(size=(1, 4, 4), dtype=torch.int)
    label[:, 2:4] = 1
    print(label.shape)
    print(label)
    label_one_hot = torch.zeros([2, 4, 4])
    label_one_hot.scatter_(0, label.long(), 1)
    print(label_one_hot)
    label_one_hot = torch.softmax(label_one_hot, 0)
    print(label_one_hot)
    label_one_hot = torch.max(label_one_hot, 0)[1]
    print(label_one_hot)

运行结果

torch.Size([1, 4, 4])
tensor([[[0, 0, 0, 0],
         [0, 0, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]], dtype=torch.int32)
tensor([[[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]],

        [[0., 0., 0., 0.],
         [0., 0., 0., 0.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]]])
tensor([[[0.7311, 0.7311, 0.7311, 0.7311],
         [0.7311, 0.7311, 0.7311, 0.7311],
         [0.2689, 0.2689, 0.2689, 0.2689],
         [0.2689, 0.2689, 0.2689, 0.2689]],

        [[0.2689, 0.2689, 0.2689, 0.2689],
         [0.2689, 0.2689, 0.2689, 0.2689],
         [0.7311, 0.7311, 0.7311, 0.7311],
         [0.7311, 0.7311, 0.7311, 0.7311]]])
tensor([[0, 0, 0, 0],
        [0, 0, 0, 0],
        [1, 1, 1, 1],
        [1, 1, 1, 1]])

3. 结语

  努力去爱周围的每一个人,付出,不一定有收获,但是不付出就一定没有收获! 给街头卖艺的人零钱,不和深夜还在摆摊的小贩讨价还价。愿我的博客对你有所帮助(*^▽^*)(*^▽^*)!

  如果客官喜欢小生的园子,记得关注小生哟,小生会持续更新(#^.^#)(#^.^#)。