torch.cat

发布时间 2023-11-28 15:05:15作者: 黑逍逍

拼接tensor

  • torch.cat(tensors, dim): 沿指定维度拼接张量。
tensor1 = torch.tensor([[1, 2, 3], [4, 5, 6]])
tensor2 = torch.tensor([[7, 8, 9], [10, 11, 12]])

# dim=0 表示沿着第一个维度(行的方向)进行连接。
concatenated_tensor = torch.cat([tensor1, tensor2], dim=0)
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])

# dim=1 表示沿着第二个维度(列的方向)进行连接。
concatenated_tensor = torch.cat([tensor1, tensor2], dim=1)
tensor([[ 1, 2, 3, 7, 8, 9],
[ 4, 5, 6, 10, 11, 12]])