Dataset类

发布时间 2023-08-09 14:18:52作者: ydky

用来创建自己的数据集,提供一种方式去获取数据及其label。

1.如何获取每一个数据及其label;2.告诉我们总共有多少数据

help:所有的数据集都需要继承该类,所有的子类都应该重写__getitem__方法(获取每一个数据及其label),选择性重写__len__类(返回数据集的大小)

(b站土堆蚂蚁和蜜蜂案例数据集下载:https://download.pytorch.org/tutorial/hymenoptera_data.zip)

创建自己的数据集:1.新建数据类继承Dataset类;2.重写方法;3.实例化使用(注意文件的路径修改为自己的路径)

from torch.utils.data import Dataset
from PIL import Image
import os

class Mydata(Dataset):

   def __init__(self, root_dir, label_dir):
       self.root_dir = root_dir
       self.label_dir = label_dir
       #路径相加
       self.path = os.path.join(self.root_dir, self.label_dir)
       #左侧为list类型,右侧函数是将该文件夹下的所有文件变成一个列表,保存的是文件名
       self.img_path = os.listdir(self.path)

   def __getitem__(self, idx):
       img_name = self.img_path[idx]
       #获取图片的路径
       img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
       #获取数据集中的图片
       img = Image.open(img_item_path)
       #获取图片对应的标签(此处标签为父目录的文件名)
       label = self.label_dir
       return img, label

   def __len__(self):
       return len(self.img_path)


root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
#实例化
ants_dataset = Mydata(root_dir, ants_label_dir)
bees_dataset = Mydata(root_dir, bees_label_dir)

train_dataset = ants_dataset + bees_dataset
#获取数据集中的图片和标签
img, label = train_dataset[0]