Pytorch | view()函数的使用

发布时间 2023-06-26 10:15:05作者: 张Zong在修行

函数简介

Pytorch中的view函数主要用于Tensor维度的重构,即返回一个有相同数据但不同维度的Tensor。

根据上面的描述可知,view函数的操作对象应该是Tensor类型。如果不是Tensor类型,可以通过tensor = torch.tensor(data)来转换。

普通用法 (手动调整size)

view(参数a,参数b,…),其中,总的参数个数表示将张量重构后的维度。

view()相当于reshaperesize,重新调整Tensor的形状。

import torch
a1 = torch.arange(0,16)
print(a1) # tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])
a2 = a1.view(8, 2) # 将a1的维度改为8*2
a3 = a1.view(2, 8) # 将a1的维度改为2*8
a4 = a1.view(4, 4) # 将a1的维度改为4*4

# a5 = a1.view(2,2,1,4)
# 更多的维度也没有问题,只要保证维度改变前后的元素个数相同就行,即 2*2*1*4=16。

print(a2)
print(a3)
print(a4)
tensor([[ 0,  1],
        [ 2,  3],
        [ 4,  5],
        [ 6,  7],
        [ 8,  9],
        [10, 11],
        [12, 13],
        [14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7],
        [ 8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])

特殊用法:参数-1 (自动调整size)

view(参数a,参数b,…)中一个参数定为-1,代表自动调整这个维度上的元素个数,则表示该维度取决于其它维度,由Pytorch自己补充,以保证元素的总数不变。

import torch
a1 = torch.arange(0,16)
print(a1) # tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])
a2 = a1.view(-1, 16)
a3 = a1.view(-1, 8)
a4 = a1.view(-1, 4)
a5 = a1.view(-1, 2)
a6 = a1.view(4*4, -1)
a7 = a1.view(1*4, -1)
a8 = a1.view(2*4, -1)

print(a2)
print(a3)
print(a4)
print(a5)
print(a6)
print(a7)
print(a8)
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3,  4,  5,  6,  7],
        [ 8,  9, 10, 11, 12, 13, 14, 15]])
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
tensor([[ 0,  1],
        [ 2,  3],
        [ 4,  5],
        [ 6,  7],
        [ 8,  9],
        [10, 11],
        [12, 13],
        [14, 15]])
tensor([[ 0],
        [ 1],
        [ 2],
        [ 3],
        [ 4],
        [ 5],
        [ 6],
        [ 7],
        [ 8],
        [ 9],
        [10],
        [11],
        [12],
        [13],
        [14],
        [15]])
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])
tensor([[ 0,  1],
        [ 2,  3],
        [ 4,  5],
        [ 6,  7],
        [ 8,  9],
        [10, 11],
        [12, 13],
        [14, 15]])

view(-1)表示将Tensor转为一维Tensor。

a9 = a1.view(-1)

print(a1)
print(a9) # 因此,转变后还是一维,没什么变换
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])

到此这篇关于pytorch中的 .view()函数的用法介绍的文章就介绍到这了,更多相关pytorch .view()函数内容请去pytorch官网文档查看。