[note] Comparing several ways to manipulate dimensions in PyTorch

Published: 2023-08-28 16:35:15  Author: NoNoe

pre

While reading some code today, I wondered what the difference is between torch.unbind + torch.cat and torch.reshape. Intuitively, reshape seems more convenient.

chatgpt

Question

x is a tensor of three dimension, what is the difference between torch.cat(torch.unbind(x, dim=1), dim=0) and torch.reshape(x, (x.shape[0]*x.shape[1], -1))

Answer

Both torch.cat(torch.unbind(x, dim=1), dim=0) and torch.reshape(x, (x.shape[0]*x.shape[1], -1)) are operations that reshape and reorganize a 3-dimensional tensor x, but they achieve this in slightly different ways.

  1. torch.cat(torch.unbind(x, dim=1), dim=0):

    • This operation involves two steps. First, it unbinds the tensor x along its second dimension (dim=1), which effectively splits the tensor into a list of tensors along that dimension. Each tensor in the list corresponds to one "column" from the original 3D tensor.
    • Then, it uses torch.cat to concatenate these split tensors along the specified dimension (dim=0), effectively stacking them on top of each other to create a new tensor with a reshaped layout.
    • The final shape of the tensor will be (x.shape[0] * x.shape[1], x.shape[2]).
  2. torch.reshape(x, (x.shape[0]*x.shape[1], -1)):

    • This operation directly reshapes the entire tensor x without splitting or concatenating it.
    • The desired shape is specified as a tuple (x.shape[0]*x.shape[1], -1), where -1 in the second dimension means that PyTorch will automatically calculate the size of that dimension to preserve the total number of elements in the tensor.
    • The final shape of the tensor will also be (x.shape[0] * x.shape[1], x.shape[2]).
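The shape claims in the two bullets above are correct, and they are easy to confirm. The sketch below is a minimal check. It uses an arbitrary shape (3, 2, 4) that was chosen only for illustration:

```python
import torch

# Arbitrary 3-D tensor; the shape (3, 2, 4) is only an illustration
x = torch.randn(3, 2, 4)

a = torch.cat(torch.unbind(x, dim=1), dim=0)         # unbind along dim 1, join pieces along dim 0
b = torch.reshape(x, (x.shape[0] * x.shape[1], -1))  # merge dims 0 and 1 directly

# Both results have shape (x.shape[0] * x.shape[1], x.shape[2])
print(a.shape, b.shape)  # torch.Size([6, 4]) torch.Size([6, 4])
```

Matching shapes do not show that the two operations are equivalent. The order of the elements is a separate question.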

In summary, both operations reshape the tensor x into a 2-dimensional tensor with the same number of elements, but they approach it differently. The first approach involves splitting and concatenating along dimensions, while the second approach directly reshapes the tensor. The end result will be the same in terms of shape and arrangement of elements.

Test

At first I believed ChatGPT, but the more I thought about it the less right it seemed. So I ran the following experiment:

import torch
from einops import rearrange

x = torch.randint(0, 100, (3, 2, 1))
print(x.squeeze())
# Compare four different ways of manipulating dimensions (cat+unbind, reshape, view, rearrange)
a = torch.cat(torch.unbind(x, dim=1), dim=0)
b = torch.reshape(x, (x.shape[0]*x.shape[1], -1))
b2 = torch.reshape(x, (-1, x.shape[-1]))
c = x.view(x.shape[0]*x.shape[1], -1)
c2 = x.view(-1, x.shape[-1])
d = rearrange(x, 'b p c -> (b p) c')

# cat+unbind gives a different result from the other three (reshape, view, rearrange)
print('a =', a.squeeze())
print('b =', b.squeeze())
print('b2 =', b2.squeeze())
print('c =', c.squeeze())
print('c2 =', c2.squeeze())
print('d =', d.squeeze())

# Unlike c, the cat+unbind result (a) cannot be turned back into x with this same rearrange pattern
x2 = rearrange(c, '(b p) c -> b p c', b=3, p=2)
print(f'x==x2 = {(x==x2).squeeze()}')

Output:

tensor([[43, 84],
        [90, 80],
        [59, 23]])
a = tensor([43, 90, 59, 84, 80, 23])
b = tensor([43, 84, 90, 80, 59, 23])
b2 = tensor([43, 84, 90, 80, 59, 23])
c = tensor([43, 84, 90, 80, 59, 23])
c2 = tensor([43, 84, 90, 80, 59, 23])
d = tensor([43, 84, 90, 80, 59, 23])
x==x2 = tensor([[True, True],
        [True, True],
        [True, True]])
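torch.randint produces different values on every run, so the printed numbers change each time. A deterministic input makes the ordering easier to follow. The sketch below replaces the random values with torch.arange, so it does not use the author's original input. It checks the two relationships the experiment suggests.

```python
import torch

# Deterministic input: values 0..5 laid out as shape (3, 2, 1)
x = torch.arange(6).reshape(3, 2, 1)
print(x.squeeze())
# tensor([[0, 1],
#         [2, 3],
#         [4, 5]])

a = torch.cat(torch.unbind(x, dim=1), dim=0)  # joins x column by column
b = torch.reshape(x, (-1, x.shape[-1]))       # joins x row by row
r = torch.cat(torch.unbind(x, dim=0), dim=0)  # unbind along dim 0 instead

print(a.squeeze())  # tensor([0, 2, 4, 1, 3, 5])
print(b.squeeze())  # tensor([0, 1, 2, 3, 4, 5])
print(torch.equal(a, b), torch.equal(b, r))  # False True
```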

ep

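Overall, suppose that after squeezing, x = [[97, 14], [0, 16], [55, 62]].

- torch.cat(torch.unbind(x, dim=1), dim=0) splits x into columns and joins them, giving [97, 0, 55, 14, 16, 62].
- reshape, view and rearrange split x into rows and join them, giving [97, 14, 0, 16, 55, 62]. This matches the result of torch.cat(torch.unbind(x, dim=0), dim=0).

So ChatGPT was right about the resulting shape but wrong in claiming that the elements end up in the same arrangement.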
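Put another way, unbinding along dim=1 and concatenating along dim=0 amounts to two steps:

1. Swap the first two axes.
2. Merge them.

Doing these steps in reverse maps the result back to x. The sketch below uses plain torch calls and illustrative sizes. In einops, the equivalent patterns would be 'b p c -> (p b) c' and '(p b) c -> b p c'.

```python
import torch

B, P, C = 3, 2, 4  # illustrative sizes
x = torch.randn(B, P, C)

# Shape (P * B, C); rows are ordered by the P index first, then the B index
a = torch.cat(torch.unbind(x, dim=1), dim=0)

# Same result: swap dims 0 and 1, then merge them. Here reshape has to copy,
# because the transposed tensor is not contiguous.
a_alt = x.transpose(0, 1).reshape(-1, C)
print(torch.equal(a, a_alt))  # True

# Inverse: split the merged axis as (P, B), then swap back to (B, P, C)
x_back = a.reshape(P, B, C).transpose(0, 1)
print(torch.equal(x_back, x))  # True
```

With this inverse, a can be recovered as x after all. It just needs the (p b) split rather than the (b p) split that works for reshape, view and rearrange.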