【pytorch】土堆pytorch教程学习(五)torchvision 中的数据集的使用

发布时间 2023-05-03 01:37:51作者: hzyuan

torchvision 中的数据集使用

torchvision.datasets模块中提供了许多内置的数据集。

内置的数据集有 CIFAR10、MNIST、COCO等,更多可进入 pytorch 官网查看。

所有内置的数据集都继承了 torch.utils.data.Dataset 类,并且实现了 __getitem____len__

所有的数据集几乎都有相似的API。下面以 CIFAR10 数据集的使用为例来认识下内置数据集的用法。

import torchvision
from torch.utils.tensorboard import SummaryWriter

dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

'''
dataset = torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
Args:
root(string):数据集存放的根目录。
train(bool):如果True则从训练集创建数据集,False则从测试集创建数据集。
transform(callable):需要对图像进行的转换操作
target_transforms(callable):需要对 target 进行的转换操作
download(bool):True则从网络下载数据集到根目录。如果数据集已经存在,则不再下载。
'''
# 创建训练集
train_set = torchvision.datasets.CIFAR10(root='./dataset', train=True, transform=dataset_transform, download=True) 
# 创建测试集
test_set = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=dataset_transform, download=True)

img, target = test_set[0]  # 取出图像和target
print(img, test_set.classes[target])

# 在 tensorboard 里打开十张图像
writer = SummaryWriter('logs')
for i in range(10):
    img, target = test_set[i]
    writer.add_image('test_set', img, i)
writer.close()

内置数据集很方便地供我们下载使用。根据源码或者官方文档可以了解到创建数据集所需传入的参数,然后需要关注__getitem__ 方法返回的结果是什么

自定义数据集

自己定义的数据集可以参照内置数据集,即继承 torch.utils.data.Dataset 类,并且重写 __getitem____len__
数据存放在 dataset/train里,分为两个目录 antsbees,也分别是数据的标签,如下图所示:

from PIL import Image
from torch.utils.data import Dataset
import os

class MyDataSet(Dataset):

    # 在__init__里加载
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir  # 根目录
        self.label_dir = label_dir  # 标签目录
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path = os.listdir(self.path)  # 图片路径列表

    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.path, img_name)
        img = Image.open(img_item_path) # 获取数据 
        label = self.label_dir # 获取label
        return img, label

    def __len__(self):
        return len(self.img_path) # 获取数据集长度

# test
root_dir = 'dataset/train'
ants_label_dir = 'ants'
bees_label_dir = 'bees'
# 生成两个数据集
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)
train_dataset = ants_dataset + bees_dataset  # 拼接两个数据集

img1, label1 = ants_dataset[0]
img1.show()
print('label1:', label1)
img2, label2 = train_dataset[130]
img2.show()
print('label2:', label2)