Test Time Augmentation

发布时间 2023-05-24 17:12:58作者: dctwan

1.概念

1.1 数据增强

Data Augmentation,训练过程中经常使用数据增强技术

大型数据集是成功应用深度神经网络的先决条件。 图像增广在对训练图像进行一系列的随机变化之后,生成相似但不同的训练样本,从而扩大了训练集的规模。 此外,应用图像增广的原因是,随机改变训练样本可以减少模型对某些属性的依赖,从而提高模型的泛化能力。 例如,我们可以以不同的方式裁剪图像,使感兴趣的对象出现在不同的位置,减少模型对于对象出现位置的依赖。 我们还可以调整亮度、颜色等因素来降低模型对颜色的敏感度。

总结,应用数据增强的原因主要有

  1. 扩大了训练集的规模
  2. 减少模型对某些属性的依赖,从而提高模型的泛化能力

1.2 测试时数据增强

test time augmentation,TTA

测试时将原始数据做不同形式的增强,然后取结果的平均值作为最终结果,可以进一步提升最终结果的精度

可以对一幅图像做多种变换,创造出多个不同版本,包括不同区域裁剪更改缩放程度等,然后对多个版本数据进行计算最后得到平均输出作为最终结果,提高了结果的稳定性和精准度

2.常用数据增强的方法

主要介绍torchvision.transforms,官网:https://pytorch.org/vision/stable/auto_examples/plot_transforms.html

2.1 翻转

以给定概率p,进行翻转

torchvision.transforms.RandomHorizontalFlip(p=0.5)	# 水平翻转,默认p=0.5
torchvision.transforms.RandomVerticalFlip(p=0.5)	# 垂直翻转,默认p=0.5

除翻转之外,还有随机按0~180角度旋转

torchvision.transforms.RandomRotation(degrees=(0, 180))

2.2 剪裁

  • (200,200)

    剪裁后图像缩放到(200,200)

  • scale=(0.1,1)

    随机裁剪一个面积为原始面积10%到100%的区域

  • ratio=(0.5,2)

    该区域的宽高比从0.5~2之间随机取值

torchvision.transforms.RandomResizedCrop((200, 200), scale=(0.1, 1), ratio=(0.5, 2))

2.3 改变颜色

  • brightness:亮度
  • contrast:对比度
  • saturation:饱和度
  • hue:色调
torchvision.transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)

3.TTA实现

以李宏毅机器学习课程HW3中TTA部分作为例子

image-20230524102428924
# Test Time Augmentation
# 1个使用test_tfm测试集
test_set = FoodDataset(os.path.join(_dataset_dir, "test"), tfm=test_tfm)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

# 5个使用train_tfm测试集
test_loaders = []
for i in range(5):
    test_set_i = FoodDataset(os.path.join(_dataset_dir, 'test'), tfm=train_tfm)
    test_loader_i = DataLoader(test_loader_i, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    test_loaders.append(test_loader_i)
# model_best = Classifier().to(device)
model_best = torchvision.models.resnet18()
model_best.fc.out_features = 11
model_best.load_state_dict(torch.load(f"{_exp_name}_best.ckpt"))
model_best = model_best.to(device)
model_best.eval()

# preds存放在6个测试集(1+5)上的测试结果矩阵,每个矩阵是(3347,11)
preds = [[], [], [], [], [], []]
prediction = []
with torch.no_grad():
    # 用test_tfm的测试集
    for data, _ in test_loader:
        test_preds = model_best(data.to(device)).cpu().data.numpy()
        preds[0].extend(test_preds)
    # 5个用train_tfm的测试集
    for i, loader in enumerate(test_loaders):
        for data, _ in loader:
            test_preds = model_best(data.to(device).cpu().data.numpy())
            preds[i+1].extend(test_preds)

preds_np = np.array(preds, dtype=object)
print('preds_np shape: {}'.format(preds_np.shape))
# 对6个测试结果加权求和
bb = 0.5 * preds_np[0] + 0.1 * preds_np[1] + 0.1 * preds_np[2] + 0.1 * preds_np[3] + 0.1 * preds_np[4] + 0.1 * preds_np[5]
print('bb shape: {}'.format(bb.shape))
prediction = np.argmax(bb, axis=1)

4.ttach

使用pytorch实现的TTA包

待学习...

参考:

  1. 13.1. 图像增广 — 动手学深度学习 2.0.0 documentation (d2l.ai)
  2. 李宏毅_机器学习_作业3(详解)
  3. 深度学习中的TTA(Test Time Augmentation)--测试时数据增强技术