Building a custom dataset with PyTorch's torchvision.datasets.ImageFolder

Published: 2023-06-05 14:10:30  Author: cold_moon
import torchvision

class ClassificationDataset(torchvision.datasets.ImageFolder):
	"""
	YOLOv5 Classification Dataset.
	Arguments
		root:  Dataset path
	"""

	def __init__(self, root):
		super().__init__(root=root)  # calling the parent class's __init__ gives the instance the following self attributes

		classes = self.classes  # list: the class names (one per subfolder)
		# ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
		class_to_idx = self.class_to_idx  # dict: class folder name -> class label (integer)
		# {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9}
		samples = self.samples  # list of (image path, class label) tuples
		# [('/data/huyuzhen/proje...in/0/1.png', 0), ('/data/huyuzhen/proje...0/1000.png', 0),...
		targets = self.targets  # list: the class label (0, 1, 2, ...) of every image
		# [0, 0, 0, 0, 0, 0, 0...


path = '/data/huyuzhen/projects/datasets/mnist/train'
dataset = ClassificationDataset(root=path)
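The comments in the class above show that `classes` is just the list of subfolder names and `class_to_idx` numbers them from 0. A minimal sketch of that discovery step follows, assuming it works as "sort the immediate subdirectory names as strings, then enumerate them" (this matches the "sorted alphabetically" wording in the attribute docs). `find_classes_sketch` is a hypothetical name, not torchvision's code, and it runs without torchvision installed.

```python
import os
import tempfile


def find_classes_sketch(directory):
    """Hypothetical stand-in for ImageFolder's class discovery (not torchvision's code)."""
    # Every immediate subdirectory is a class; names are sorted as strings.
    with os.scandir(directory) as entries:
        classes = sorted(entry.name for entry in entries if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
    # Labels are assigned in sorted order, starting at 0.
    class_to_idx = {name: i for i, name in enumerate(classes)}
    return classes, class_to_idx


with tempfile.TemporaryDirectory() as root:
    for name in ["0", "1", "2", "10"]:
        os.mkdir(os.path.join(root, name))
    classes, class_to_idx = find_classes_sketch(root)

print(classes)       # ['0', '1', '10', '2']
print(class_to_idx)  # {'0': 0, '1': 1, '10': 2, '2': 3}
```

For MNIST's folders `0` to `9` the string order happens to equal the numeric order, so folder `3` gets label 3. With ten or more numerically named folders that no longer holds: `10` sorts before `2`, so the folder name and the integer label diverge.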

The class above defines a custom image-classification dataset. The MNIST data is organized as follows:

mnist
	├── test
	│   ├── 0
	│   ├── 1
	...
	├── train
	│   ├── 0
	│   ├── 1
	...
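To experiment without the real MNIST files, a tiny tree with the same `split/class/image` shape can be generated. This is a hypothetical helper (`make_toy_layout` is not part of any library), and the "images" it writes are empty placeholder files: they are enough to show the directory layout, but PIL would fail to open them if the dataset tried to load one.

```python
import tempfile
from pathlib import Path


def make_toy_layout(root, splits=("train", "test"), classes=("0", "1"), n_per_class=2):
    """Hypothetical helper: create root/<split>/<class>/<n>.png placeholder files."""
    root = Path(root)
    for split in splits:
        for cls in classes:
            class_dir = root / split / cls
            class_dir.mkdir(parents=True, exist_ok=True)
            for i in range(1, n_per_class + 1):
                (class_dir / f"{i}.png").touch()  # empty file, not a real image
    return root


with tempfile.TemporaryDirectory() as tmp:
    root = make_toy_layout(Path(tmp) / "mnist")
    created = sorted(p.relative_to(root).as_posix() for p in root.rglob("*.png"))

print(created)
```

Each split directory (`mnist/train`, `mnist/test`) is what gets passed as `root`, exactly like `path` in the code above.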

ImageFolder is a subclass of DatasetFolder and has the following attributes:

Attributes:
    classes (list): List of the class names sorted alphabetically.
    class_to_idx (dict): Dict with items (class_name, class_index).
    samples (list): List of (sample path, class_index) tuples
    targets (list): The class_index value for each image in the dataset
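The relationship between `samples` and `targets` can be illustrated with a simplified sketch. It assumes the scan walks each class folder in sorted order, keeps only files whose lowercased name ends with an image extension, and pairs each path with its class index. `make_samples_sketch` is a hypothetical name, and `IMG_EXTENSIONS` here is an assumed copy of torchvision's default extension list, so treat both as illustrative rather than as the library's code.

```python
import os
import tempfile

# Assumed to mirror torchvision's default image extensions.
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")


def make_samples_sketch(directory, class_to_idx, extensions=IMG_EXTENSIONS):
    """Hypothetical stand-in for building `samples` (not torchvision's code)."""
    samples = []
    for class_name in sorted(class_to_idx):
        class_dir = os.path.join(directory, class_name)
        # Walk the class folder (nested subfolders included) in sorted order.
        for dirpath, _, filenames in sorted(os.walk(class_dir, followlinks=True)):
            for fname in sorted(filenames):
                # The extension check is case-insensitive; other files are skipped.
                if fname.lower().endswith(extensions):
                    samples.append((os.path.join(dirpath, fname), class_to_idx[class_name]))
    return samples


with tempfile.TemporaryDirectory() as root:
    for cls in ("0", "1"):
        os.mkdir(os.path.join(root, cls))
        for fname in ("1.png", "2.PNG", "notes.txt"):
            open(os.path.join(root, cls, fname), "w").close()
    samples = make_samples_sketch(root, {"0": 0, "1": 1})
    rel = [(os.path.relpath(p, root).replace(os.sep, "/"), t) for p, t in samples]

# targets is simply the label half of every (path, label) pair in samples.
targets = [t for _, t in samples]
print(rel)
print(targets)
```

`notes.txt` does not appear in `rel`, and `targets` lines up index by index with `samples`, which is why the two attributes always have the same length.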

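To use torchvision.datasets.ImageFolder, the dataset must be organized as shown above: the root directory (e.g. mnist/train) contains one subfolder per class, and each class's images go inside its subfolder.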
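Because the layout carries the labels, it can help to check a directory before handing it to ImageFolder. The function below is a hypothetical helper, not a torchvision API. It flags two common mistakes: image files placed directly under the root (only class subfolders are scanned, so such files are silently ignored) and class folders with no image files (recent torchvision versions raise an error for these by default).

```python
import os
import tempfile

# Assumed to mirror torchvision's default image extensions.
IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")


def check_imagefolder_layout(root, extensions=IMG_EXTENSIONS):
    """Hypothetical helper: list layout problems for an ImageFolder-style root."""
    problems = []
    with os.scandir(root) as entries:
        entries = list(entries)
    stray_files = sorted(e.name for e in entries if e.is_file())
    if stray_files:
        problems.append(f"files directly under root are ignored: {stray_files}")
    class_dirs = sorted(e.name for e in entries if e.is_dir())
    if not class_dirs:
        problems.append("no class subfolders found")
    for name in class_dirs:
        # Count image files anywhere under this class folder.
        n_images = sum(
            1
            for _, _, files in os.walk(os.path.join(root, name))
            for f in files
            if f.lower().endswith(extensions)
        )
        if n_images == 0:
            problems.append(f"class folder {name!r} contains no image files")
    return problems


with tempfile.TemporaryDirectory() as root:
    os.mkdir(os.path.join(root, "0"))
    os.mkdir(os.path.join(root, "1"))
    open(os.path.join(root, "0", "1.png"), "w").close()
    open(os.path.join(root, "1.png"), "w").close()  # misplaced: not inside a class folder
    report = check_imagefolder_layout(root)

for line in report:
    print(line)
```

An empty `report` means the root at least has the shape ImageFolder expects; it does not verify that the files are decodable images.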