如何将list转为tensor

发布时间 2023-04-17 18:45:59作者: 瓜瓜没有瓜子

如何将list转为tensor

在遇到需要将list转换为tensor的情况时,往往不能直接转换,而是需要借助 torch.cat 方法进行。为防止需要的时候找不到教程,本文给出示例进行该操作。

操作方法

问题

给定对于数据 xyx 的形状为 (2, 3, 4) ,表示 batch_size为2,每个batch有3个向量,每个向量维数是4y 的形状为 (2, 4) ,表示 batch_size为2,每个batch内仅有一个4维的向量。现在目标是将同一个 batchx 的3个向量分别和 y 的1个向量进行拼接,得到一个形状为 (2, 3, 4) 的数据。

代码

首先随机地生成 xy

import torch
x = torch.rand((2, 3, 4))
y = torch.rand((2, 4))
print(x)
print(y)

输出如下:

tensor([[[0.3170, 0.5800, 0.2717, 0.3887],
[0.0862, 0.4881, 0.1419, 0.1491],
[0.1860, 0.4508, 0.2637, 0.9106]],

    [[0.0923, 0.1211, 0.8768, 0.7573],
     [0.9067, 0.0651, 0.2780, 0.6712],
     [0.0755, 0.1534, 0.9984, 0.8169]]])

tensor([[0.1451, 0.0273, 0.5603, 0.3951],
[0.8981, 0.8639, 0.3545, 0.4461]])

第二步,拆分拼接

s = []
for xx, yy in zip(x, y):
    ss = []
    for i in xx:
        ss.append(torch.cat((i, yy), 0).unsqueeze(0))
    print(ss)
    ss = torch.cat(ss, dim=0)
    s.append(ss.unsqueeze(0))
s

输出如下:

[tensor([[[0.3170, 0.5800, 0.2717, 0.3887, 0.1451, 0.0273, 0.5603, 0.3951],
[0.0862, 0.4881, 0.1419, 0.1491, 0.1451, 0.0273, 0.5603, 0.3951],
[0.1860, 0.4508, 0.2637, 0.9106, 0.1451, 0.0273, 0.5603, 0.3951]]]),
tensor([[[0.0923, 0.1211, 0.8768, 0.7573, 0.8981, 0.8639, 0.3545, 0.4461],
[0.9067, 0.0651, 0.2780, 0.6712, 0.8981, 0.8639, 0.3545, 0.4461],
[0.0755, 0.1534, 0.9984, 0.8169, 0.8981, 0.8639, 0.3545, 0.4461]]])]

这一步完成了每一个 batch 中的拼接,但 batch 之间还是以 list 的方式链接的。

第三步,合成 batch

s = torch.cat(s, dim=0)
s

输出如下:

tensor([[[0.3170, 0.5800, 0.2717, 0.3887, 0.1451, 0.0273, 0.5603, 0.3951],
[0.0862, 0.4881, 0.1419, 0.1491, 0.1451, 0.0273, 0.5603, 0.3951],
[0.1860, 0.4508, 0.2637, 0.9106, 0.1451, 0.0273, 0.5603, 0.3951]],

    [[0.0923, 0.1211, 0.8768, 0.7573, 0.8981, 0.8639, 0.3545, 0.4461],
     [0.9067, 0.0651, 0.2780, 0.6712, 0.8981, 0.8639, 0.3545, 0.4461],
     [0.0755, 0.1534, 0.9984, 0.8169, 0.8981, 0.8639, 0.3545, 0.4461]]])

至此,目标达成!