深度学习模型训练中,输入数据维度和标签数据维度调整方法

发布时间 2023-09-08 23:11:11作者: 欣杰科技
for inputs, labels in train_loader:
        # 使用numpy的transpose函数调整维度顺序
        inputs = np.transpose(inputs, (0, 3, 1, 2)) #将原输入数据最后一个维度换到第二个维度
        inputs = inputs.to(device)
        print(inputs.shape) #调试代码用
        m = labels.shape  #hdf5文件有时候标签数据大小为[batch_siza,1,1,1],需要做如下调整,先获取标签数据大小
        n = m[0] #获取标签数据第一个维度值,也就是batch_size
        labels = torch.reshape(labels, (n,)) #将标签数据转换为一维数据,也就是将原来的[[[1]]]里面的标签1提取出来
        labels = labels.to(device)
        print(labels.shape) #调试代码用
        # 前向传播
        outputs = model(inputs)
        print(outputs.shape) #调试代码用
        loss = criterion(outputs, labels.long())