torch.cat()

发布时间 2023-08-15 09:22:51作者: 海_纳百川

要将一个大小为(2, 2)的PyTorch张量和一个大小为(1, 2)的张量拼接在一起,以形成一个新的大小为(3, 2)的张量,你可以使用PyTorch库中的torch.cat()函数。以下是一个示例代码:

import torch

# 假设你有一个大小为(2, 2)的张量 tensor1 和一个大小为(1, 2)的张量 tensor2
tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6]])

# 使用 torch.cat() 进行拼接,指定维度为 0,表示在行方向进行拼接
new_tensor = torch.cat((tensor1, tensor2), dim=0)

print(new_tensor)

在这个示例中,我们首先创建了两个大小不同的张量 tensor1 和 tensor2。然后,我们使用 torch.cat() 函数将它们在维度0上拼接起来,得到了新的大小为(3, 2)的张量 new_tensor。这里的 dim=0 表示在行方向上进行拼接。请注意,拼接的张量维度要保持一致,除了拼接维度外的其他维度应该是一样的。