torch和numpy的维度交换方法

发布时间 2023-09-10 12:19:45作者: Lumeng

Tensor的维度转置方法

​ 在搭建神经网络的时候,经常会遇到需要交换维度的时候,比如将HWCN的Tensor维度顺序变换为NCHW顺序,此时需要用到Tensor的转置方法。
​ 一般有以下三种方法:

1、numpy.transpose

​ 如果Tensor是由np.Array转换而来,那么可以在变量还是np.Array的时候先进行通道转置,此时可以使用np.transpose方法:

>>> import numpy as np
>>> aa = np.ndarray((1,3,3,4))
>>> aa.shape
(1,3,3,4)
>>> aa.transpose((3,1,0,2)).shape
(4,3,1,3)
>>> np.transpose(aa,(3,1,0,2)).shape
(4,3,1,3)

​ arr.transpose(new_shape)和np.transpose(arr,new_shape)都合法,结果完全一样。

​ * 如果只是二维数组转置或者只交换第一和最后两个维度,那么也可以用arr.T方法:

>>> aa.T.shape
(4,3,3,1)
2、torch.tranpose

​ torch.transpose方法和np.transpose方法有一个最大的区别,torch.transpose只能支持两个维度的交换,函数原型为:

torch.transpose(tensor,dim0,dim1)

​ 如果超过两个维度,会报错。使用方式为:

>>> aaTensor = torch.from_numpy(aa) 

>>> aaTensor.transpose(0,3).shape
torch.Size([4,3,3,1])

>>> torch.transpose(aaTensor,0,3).shape
torch.Size([4,3,3,1])

>>> aaTensor.transpose(3,1,0,2).shape
Traceback (most recent call last):
  File "<string>", line 1, in <module>
TypeError: transpose() received an invalid combination of arguments - got (int, int, int, int), but expected one of:
 * (int dim0, int dim1)
 * (name dim0, name dim1)

​ torch.transpose方法有一个后缀格式函数tensor.transpose_(),是transpose的inplace版本,调用该函数不返回结果,直接修改原始tensor的维度:

>>> aaTensor.transpose_(3,0)
>>> aaTesor.shape
torch.Size([4,3,3,1])
3、torch.permute

torch.permute用法和numpy.transpose完全相同,接受多个指定的维度,将输入Tensor的维度按照指定的维度顺序重排:

>>> torch.permute(aaTensor,3,1,0,2).shape
torch.Size([4,3,1,3])
>>>aaTensor.permute(3,1,0,2).shape
torch.Size([4,3,1,3])

注意torch.transpose、torch.permute、arr.transpose可接受tuple、list、多个整数作为输入,而numpy.transpose只能接受tuple和list。